// rivven_llm/types.rs
//! Core types for LLM interactions
//!
//! Provider-agnostic types for chat completions and embeddings.

use serde::{Deserialize, Serialize};
use tracing::warn;

// ============================================================================
// Chat Types
// ============================================================================

12/// Role of a chat message participant
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16    /// System message (sets behavior / persona)
17    System,
18    /// User message (the human prompt)
19    User,
20    /// Assistant message (the LLM response)
21    Assistant,
22    /// Tool/function call result
23    Tool,
24}
25
26impl std::fmt::Display for Role {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            Role::System => write!(f, "system"),
30            Role::User => write!(f, "user"),
31            Role::Assistant => write!(f, "assistant"),
32            Role::Tool => write!(f, "tool"),
33        }
34    }
35}
36
37/// A single message in a chat conversation
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ChatMessage {
40    /// Role of the message sender
41    pub role: Role,
42    /// Message content
43    pub content: String,
44}
45
46impl ChatMessage {
47    /// Create a system message
48    pub fn system(content: impl Into<String>) -> Self {
49        Self {
50            role: Role::System,
51            content: content.into(),
52        }
53    }
54
55    /// Create a user message
56    pub fn user(content: impl Into<String>) -> Self {
57        Self {
58            role: Role::User,
59            content: content.into(),
60        }
61    }
62
63    /// Create an assistant message
64    pub fn assistant(content: impl Into<String>) -> Self {
65        Self {
66            role: Role::Assistant,
67            content: content.into(),
68        }
69    }
70}
71
72/// Request for a chat completion
73///
74/// # Security — Prompt Injection
75///
76/// When building a `ChatRequest` from untrusted input, ensure that
77/// user-supplied text is placed exclusively in `User`-role messages and is
78/// never concatenated into `System` prompts without sanitization. See
79/// [`LlmProvider`](crate::provider::LlmProvider) for detailed guidance.
80#[derive(Debug, Clone)]
81pub struct ChatRequest {
82    /// Messages in the conversation
83    pub messages: Vec<ChatMessage>,
84    /// Model override (uses provider default if None)
85    pub model: Option<String>,
86    /// Sampling temperature (0.0–2.0, lower = more deterministic)
87    pub temperature: Option<f32>,
88    /// Maximum tokens in the response
89    pub max_tokens: Option<u32>,
90    /// Top-p nucleus sampling (0.0–1.0)
91    pub top_p: Option<f32>,
92    /// Stop sequences
93    pub stop: Vec<String>,
94}
95
96impl ChatRequest {
97    /// Create a builder for `ChatRequest`
98    pub fn builder() -> ChatRequestBuilder {
99        ChatRequestBuilder::default()
100    }
101
102    /// Quick single-prompt request
103    pub fn prompt(content: impl Into<String>) -> Self {
104        Self {
105            messages: vec![ChatMessage::user(content)],
106            model: None,
107            temperature: None,
108            max_tokens: None,
109            top_p: None,
110            stop: Vec::new(),
111        }
112    }
113
114    /// Quick single-prompt with system message
115    pub fn with_system(system: impl Into<String>, prompt: impl Into<String>) -> Self {
116        Self {
117            messages: vec![ChatMessage::system(system), ChatMessage::user(prompt)],
118            model: None,
119            temperature: None,
120            max_tokens: None,
121            top_p: None,
122            stop: Vec::new(),
123        }
124    }
125}
126
127/// Builder for `ChatRequest`
128#[derive(Debug, Default)]
129pub struct ChatRequestBuilder {
130    messages: Vec<ChatMessage>,
131    model: Option<String>,
132    temperature: Option<f32>,
133    max_tokens: Option<u32>,
134    top_p: Option<f32>,
135    stop: Vec<String>,
136}
137
138impl ChatRequestBuilder {
139    /// Add a message to the conversation
140    pub fn message(mut self, msg: ChatMessage) -> Self {
141        self.messages.push(msg);
142        self
143    }
144
145    /// Add multiple messages
146    pub fn messages(mut self, msgs: impl IntoIterator<Item = ChatMessage>) -> Self {
147        self.messages.extend(msgs);
148        self
149    }
150
151    /// Set the system prompt
152    pub fn system(self, content: impl Into<String>) -> Self {
153        self.message(ChatMessage::system(content))
154    }
155
156    /// Add a user message
157    pub fn user(self, content: impl Into<String>) -> Self {
158        self.message(ChatMessage::user(content))
159    }
160
161    /// Override model
162    pub fn model(mut self, model: impl Into<String>) -> Self {
163        self.model = Some(model.into());
164        self
165    }
166
167    /// Set temperature (0.0–2.0). Values outside this range are clamped.
168    pub fn temperature(mut self, t: f32) -> Self {
169        let clamped = t.clamp(0.0, 2.0);
170        if (clamped - t).abs() > f32::EPSILON {
171            warn!(
172                requested = t,
173                clamped = clamped,
174                "temperature value {t} out of range [0.0, 2.0], clamped to {clamped}"
175            );
176        }
177        self.temperature = Some(clamped);
178        self
179    }
180
181    /// Set max tokens
182    pub fn max_tokens(mut self, n: u32) -> Self {
183        self.max_tokens = Some(n);
184        self
185    }
186
187    /// Set top-p nucleus sampling (0.0–1.0). Values outside this range are clamped.
188    pub fn top_p(mut self, p: f32) -> Self {
189        let clamped = p.clamp(0.0, 1.0);
190        if (clamped - p).abs() > f32::EPSILON {
191            warn!(
192                requested = p,
193                clamped = clamped,
194                "top_p value {p} out of range [0.0, 1.0], clamped to {clamped}"
195            );
196        }
197        self.top_p = Some(clamped);
198        self
199    }
200
201    /// Add a stop sequence
202    pub fn stop(mut self, s: impl Into<String>) -> Self {
203        self.stop.push(s.into());
204        self
205    }
206
207    /// Build the request
208    pub fn build(self) -> ChatRequest {
209        ChatRequest {
210            messages: self.messages,
211            model: self.model,
212            temperature: self.temperature,
213            max_tokens: self.max_tokens,
214            top_p: self.top_p,
215            stop: self.stop,
216        }
217    }
218}
219
220/// Reason why the model stopped generating
221#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
222#[serde(rename_all = "snake_case")]
223pub enum FinishReason {
224    /// Model completed naturally
225    Stop,
226    /// Hit max_tokens limit
227    Length,
228    /// Content was filtered by safety systems
229    ContentFilter,
230    /// Model made a tool/function call
231    ToolCalls,
232}
233
234impl std::fmt::Display for FinishReason {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        match self {
237            FinishReason::Stop => write!(f, "stop"),
238            FinishReason::Length => write!(f, "length"),
239            FinishReason::ContentFilter => write!(f, "content_filter"),
240            FinishReason::ToolCalls => write!(f, "tool_calls"),
241        }
242    }
243}
244
245/// A single choice in a chat completion response
246#[derive(Debug, Clone)]
247pub struct ChatChoice {
248    /// Index of this choice (for n>1 requests)
249    pub index: u32,
250    /// The generated message
251    pub message: ChatMessage,
252    /// Why the model stopped
253    pub finish_reason: FinishReason,
254}
255
/// Token usage statistics for a chat completion.
///
/// `Default` yields all-zero counts, useful when a provider omits usage data.
#[derive(Debug, Clone, Copy, Default)]
pub struct Usage {
    /// Tokens in the prompt.
    pub prompt_tokens: u32,
    /// Tokens in the completion.
    pub completion_tokens: u32,
    /// Total tokens consumed (prompt + completion, as reported by the provider).
    pub total_tokens: u32,
}
267/// Response from a chat completion
268#[derive(Debug, Clone)]
269pub struct ChatResponse {
270    /// Provider-assigned response ID
271    pub id: String,
272    /// Model that generated the response
273    pub model: String,
274    /// Generated choices
275    pub choices: Vec<ChatChoice>,
276    /// Token usage
277    pub usage: Usage,
278}
279
280impl ChatResponse {
281    /// Get the content of the first choice (convenience)
282    pub fn content(&self) -> &str {
283        self.choices
284            .first()
285            .map(|c| c.message.content.as_str())
286            .unwrap_or("")
287    }
288
289    /// Get the finish reason of the first choice
290    pub fn finish_reason(&self) -> Option<FinishReason> {
291        self.choices.first().map(|c| c.finish_reason)
292    }
293}
294
// ============================================================================
// Embedding Types
// ============================================================================

/// Request for text embeddings.
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    /// Input texts to embed.
    pub input: Vec<String>,
    /// Model override (uses provider default if None).
    pub model: Option<String>,
    /// Number of dimensions (if the model supports it).
    pub dimensions: Option<u32>,
    /// Input type hint (e.g. `"search_document"` or `"search_query"` for Cohere)
    ///
    /// Providers that don't support this field ignore it.
    pub input_type: Option<String>,
}

impl EmbeddingRequest {
    /// Create a builder for `EmbeddingRequest`.
    pub fn builder() -> EmbeddingRequestBuilder {
        EmbeddingRequestBuilder::default()
    }

    /// Quick single-text embedding: one input, all options left at their
    /// defaults.
    pub fn single(text: impl Into<String>) -> Self {
        // Routed through the builder so the field defaults live in one place.
        Self::builder().input(text).build()
    }

    /// Quick batch embedding: one input per text, all options left at their
    /// defaults.
    pub fn batch(texts: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self::builder().inputs(texts).build()
    }
}

/// Builder for `EmbeddingRequest`.
#[derive(Debug, Default)]
pub struct EmbeddingRequestBuilder {
    input: Vec<String>,
    model: Option<String>,
    dimensions: Option<u32>,
    input_type: Option<String>,
}

impl EmbeddingRequestBuilder {
    /// Add a text input.
    pub fn input(mut self, text: impl Into<String>) -> Self {
        self.input.push(text.into());
        self
    }

    /// Add multiple text inputs.
    pub fn inputs(mut self, texts: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.input.extend(texts.into_iter().map(Into::into));
        self
    }

    /// Override model.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Set embedding dimensions.
    pub fn dimensions(mut self, d: u32) -> Self {
        self.dimensions = Some(d);
        self
    }

    /// Set input type hint (e.g. `"search_document"`, `"search_query"`).
    pub fn input_type(mut self, t: impl Into<String>) -> Self {
        self.input_type = Some(t.into());
        self
    }

    /// Build the request.
    pub fn build(self) -> EmbeddingRequest {
        EmbeddingRequest {
            input: self.input,
            model: self.model,
            dimensions: self.dimensions,
            input_type: self.input_type,
        }
    }
}
/// A single embedding vector.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// Index of this embedding in the batch (matches the input index).
    pub index: u32,
    /// The embedding vector (f32 values).
    pub values: Vec<f32>,
}

impl Embedding {
    /// Dimensionality of the embedding (length of `values`).
    pub fn dimensions(&self) -> usize {
        self.values.len()
    }
}

/// Token usage for embedding requests.
///
/// `Default` yields all-zero counts, useful when a provider omits usage data.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmbeddingUsage {
    /// Total tokens processed.
    pub prompt_tokens: u32,
    /// Total tokens (same as prompt_tokens for embeddings).
    pub total_tokens: u32,
}

/// Response from an embedding request.
#[derive(Debug, Clone)]
pub struct EmbeddingResponse {
    /// Model that generated the embeddings.
    pub model: String,
    /// The embedding vectors (one per input text).
    pub embeddings: Vec<Embedding>,
    /// Token usage.
    pub usage: EmbeddingUsage,
}

impl EmbeddingResponse {
    /// First embedding vector, if any (convenience for single-text requests).
    pub fn first_embedding(&self) -> Option<&[f32]> {
        self.embeddings.first().map(|e| e.values.as_slice())
    }

    /// Number of embeddings.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// Whether the response contains no embeddings.
    pub fn is_empty(&self) -> bool {
        self.embeddings.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_message_constructors() {
        let sys = ChatMessage::system("You are helpful.");
        assert_eq!(sys.role, Role::System);
        assert_eq!(sys.content, "You are helpful.");

        let usr = ChatMessage::user("Hello");
        assert_eq!(usr.role, Role::User);

        let ast = ChatMessage::assistant("Hi there!");
        assert_eq!(ast.role, Role::Assistant);
    }

    #[test]
    fn test_chat_request_builder() {
        let req = ChatRequest::builder()
            .system("Be concise.")
            .user("What is Rust?")
            .temperature(0.5)
            .max_tokens(100)
            .model("gpt-4o")
            .build();

        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, Role::System);
        assert_eq!(req.messages[1].role, Role::User);
        assert_eq!(req.temperature, Some(0.5));
        assert_eq!(req.max_tokens, Some(100));
        assert_eq!(req.model.as_deref(), Some("gpt-4o"));
    }

    #[test]
    fn test_chat_request_prompt() {
        let req = ChatRequest::prompt("Hello");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, Role::User);
        assert_eq!(req.messages[0].content, "Hello");
    }

    #[test]
    fn test_chat_request_with_system() {
        let req = ChatRequest::with_system("Be brief.", "Hi");
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, Role::System);
    }

    #[test]
    fn test_temperature_clamping() {
        let req = ChatRequest::builder().temperature(5.0).build();
        assert_eq!(req.temperature, Some(2.0));

        let req = ChatRequest::builder().temperature(-1.0).build();
        assert_eq!(req.temperature, Some(0.0));
    }

    #[test]
    fn test_top_p_clamping() {
        let req = ChatRequest::builder().top_p(1.5).build();
        assert_eq!(req.top_p, Some(1.0));
    }

    #[test]
    fn test_chat_response_content() {
        let resp = ChatResponse {
            id: "test".to_string(),
            model: "gpt-4o".to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage::assistant("Hello!"),
                finish_reason: FinishReason::Stop,
            }],
            usage: Usage {
                prompt_tokens: 5,
                completion_tokens: 1,
                total_tokens: 6,
            },
        };
        assert_eq!(resp.content(), "Hello!");
        assert_eq!(resp.finish_reason(), Some(FinishReason::Stop));
    }

    #[test]
    fn test_chat_response_empty() {
        let resp = ChatResponse {
            id: "test".to_string(),
            model: "test".to_string(),
            choices: vec![],
            usage: Usage::default(),
        };
        assert_eq!(resp.content(), "");
        assert_eq!(resp.finish_reason(), None);
    }

    #[test]
    fn test_embedding_request_single() {
        let req = EmbeddingRequest::single("Hello world");
        assert_eq!(req.input.len(), 1);
        assert_eq!(req.input[0], "Hello world");
    }

    #[test]
    fn test_embedding_request_batch() {
        let req = EmbeddingRequest::batch(["one", "two", "three"]);
        assert_eq!(req.input.len(), 3);
    }

    #[test]
    fn test_embedding_request_builder() {
        let req = EmbeddingRequest::builder()
            .input("hello")
            .input("world")
            .model("text-embedding-3-small")
            .dimensions(256)
            .build();
        assert_eq!(req.input.len(), 2);
        assert_eq!(req.model.as_deref(), Some("text-embedding-3-small"));
        assert_eq!(req.dimensions, Some(256));
        assert!(req.input_type.is_none());
    }

    #[test]
    fn test_embedding_request_builder_with_input_type() {
        let req = EmbeddingRequest::builder()
            .input("query")
            .model("cohere.embed-english-v3")
            .input_type("search_query")
            .build();
        assert_eq!(req.input_type.as_deref(), Some("search_query"));
    }

    #[test]
    fn test_embedding_response_first() {
        let resp = EmbeddingResponse {
            model: "test".to_string(),
            embeddings: vec![Embedding {
                index: 0,
                values: vec![0.1, 0.2, 0.3],
            }],
            usage: EmbeddingUsage {
                prompt_tokens: 2,
                total_tokens: 2,
            },
        };
        assert_eq!(resp.first_embedding(), Some([0.1, 0.2, 0.3].as_slice()));
        assert_eq!(resp.len(), 1);
        assert!(!resp.is_empty());
        assert_eq!(resp.embeddings[0].dimensions(), 3);
    }

    #[test]
    fn test_embedding_response_empty() {
        let resp = EmbeddingResponse {
            model: "test".to_string(),
            embeddings: vec![],
            usage: EmbeddingUsage::default(),
        };
        assert!(resp.is_empty());
        assert_eq!(resp.first_embedding(), None);
    }

    #[test]
    fn test_role_display() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::User.to_string(), "user");
        assert_eq!(Role::Assistant.to_string(), "assistant");
        assert_eq!(Role::Tool.to_string(), "tool");
    }

    #[test]
    fn test_finish_reason_display() {
        assert_eq!(FinishReason::Stop.to_string(), "stop");
        assert_eq!(FinishReason::Length.to_string(), "length");
        assert_eq!(FinishReason::ContentFilter.to_string(), "content_filter");
        assert_eq!(FinishReason::ToolCalls.to_string(), "tool_calls");
    }

    #[test]
    fn test_role_serde_roundtrip() {
        let json = serde_json::to_string(&Role::System).unwrap();
        assert_eq!(json, r#""system""#);
        let back: Role = serde_json::from_str(&json).unwrap();
        assert_eq!(back, Role::System);
    }

    #[test]
    fn test_finish_reason_serde_roundtrip() {
        let json = serde_json::to_string(&FinishReason::ContentFilter).unwrap();
        assert_eq!(json, r#""content_filter""#);
        let back: FinishReason = serde_json::from_str(&json).unwrap();
        assert_eq!(back, FinishReason::ContentFilter);
    }

    #[test]
    fn test_usage_default() {
        let u = Usage::default();
        assert_eq!(u.prompt_tokens, 0);
        assert_eq!(u.completion_tokens, 0);
        assert_eq!(u.total_tokens, 0);
    }
}