llm/chat/
mod.rs

1use std::collections::HashMap;
2use std::fmt;
3
4use async_trait::async_trait;
5use futures::stream::{Stream, StreamExt};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9use crate::{error::LLMError, ToolCall};
10
/// Usage metadata for a chat response.
///
/// The serde `alias` attributes let this struct deserialize from providers
/// that report `input_tokens`/`output_tokens` instead of the OpenAI-style
/// `prompt_tokens`/`completion_tokens` field names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Number of tokens in the prompt
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    /// Number of tokens in the completion
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// Total number of tokens used
    pub total_tokens: u32,
    /// Breakdown of completion tokens, if available
    /// (omitted from serialized output when `None`)
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "output_tokens_details"
    )]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Breakdown of prompt tokens, if available
    /// (omitted from serialized output when `None`)
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "input_tokens_details"
    )]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
35
/// Stream response chunk that mimics OpenAI's streaming response format
/// (`.choices[0].delta.content` plus optional `.usage`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamResponse {
    /// Array of choices in the response
    pub choices: Vec<StreamChoice>,
    /// Usage metadata, typically present in the final chunk
    /// (omitted from serialized output when `None`)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
45
/// Individual choice in a streaming response, mirroring one element of the
/// OpenAI `choices` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// Delta containing the incremental content
    pub delta: StreamDelta,
}
52
/// Delta content in a streaming response; each field is only present when
/// the chunk actually carries that kind of increment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamDelta {
    /// The incremental content, if any
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// The incremental tool calls, if any
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
63
/// Breakdown of completion tokens.
///
/// All fields are optional since providers differ in which details they report.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
    /// Tokens used for reasoning (for reasoning models)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Tokens used for audio output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
74
/// Breakdown of prompt tokens.
///
/// All fields are optional since providers differ in which details they report.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptTokensDetails {
    /// Tokens used for cached content
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
    /// Tokens used for audio input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
85
/// Role of a participant in a chat conversation.
//
// `Copy` is derived because the enum is fieldless and trivially copyable,
// matching the convention for small tag-like enums.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatRole {
    /// The user/human participant in the conversation
    User,
    /// The AI assistant participant in the conversation
    Assistant,
}
94
/// The supported MIME type of an image.
//
// `Copy` is derived because the enum is fieldless; `#[non_exhaustive]` is
// kept so new formats can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ImageMime {
    /// JPEG image
    JPEG,
    /// PNG image
    PNG,
    /// GIF image
    GIF,
    /// WebP image
    WEBP,
}

impl ImageMime {
    /// Returns the canonical MIME type string for this image format,
    /// e.g. `"image/png"` for [`ImageMime::PNG`].
    ///
    /// The function is `const` so the mapping can also be used in
    /// compile-time contexts.
    pub const fn mime_type(&self) -> &'static str {
        match self {
            ImageMime::JPEG => "image/jpeg",
            ImageMime::PNG => "image/png",
            ImageMime::GIF => "image/gif",
            ImageMime::WEBP => "image/webp",
        }
    }
}
119
/// The type of a message in a chat conversation.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum MessageType {
    /// A text message (the default)
    #[default]
    Text,
    /// An image message, carrying the MIME type and the raw image bytes
    Image((ImageMime, Vec<u8>)),
    /// PDF message, carrying the raw document bytes
    Pdf(Vec<u8>),
    /// An image URL message
    ImageURL(String),
    /// A tool use
    ToolUse(Vec<ToolCall>),
    /// Tool result
    ToolResult(Vec<ToolCall>),
}
137
/// The type of reasoning effort for a message in a chat conversation.
//
// Derives are added for consistency with the other enums in this module
// (`ChatRole`, `MessageType`, ...); `Copy` is sound because the enum is
// fieldless.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReasoningEffort {
    /// Low reasoning effort
    Low,
    /// Medium reasoning effort
    Medium,
    /// High reasoning effort
    High,
}
147
/// A single message in a chat conversation.
#[derive(Debug, Clone)]
pub struct ChatMessage {
    /// The role of who sent this message (user or assistant)
    pub role: ChatRole,
    /// The type of the message (text, image, PDF, tool use/result, etc)
    pub message_type: MessageType,
    /// The text content of the message
    pub content: String,
}
158
/// Represents a parameter in a function tool, i.e. one property of a
/// JSON-Schema `properties` map.
#[derive(Debug, Clone, Serialize)]
pub struct ParameterProperty {
    /// The type of the parameter (e.g. "string", "number", "array", etc)
    #[serde(rename = "type")]
    pub property_type: String,
    /// Description of what the parameter does
    pub description: String,
    /// When type is "array", this defines the type of the array items
    /// (boxed because the property definition is recursive)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<ParameterProperty>>,
    /// When type is "enum", this defines the possible values for the parameter
    #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
    pub enum_list: Option<Vec<String>>,
}
174
/// Represents the parameters schema for a function tool, following the
/// JSON-Schema object layout (`type` / `properties` / `required`).
#[derive(Debug, Clone, Serialize)]
pub struct ParametersSchema {
    /// The type of the parameters object (usually "object")
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Map of parameter names to their properties
    pub properties: HashMap<String, ParameterProperty>,
    /// List of required parameter names
    pub required: Vec<String>,
}
186
/// Represents a function definition for a tool.
///
/// The `parameters` field stores the JSON Schema describing the function
/// arguments.  It is kept as a raw `serde_json::Value` to allow arbitrary
/// complexity (nested arrays/objects, `oneOf`, etc.) without requiring a
/// bespoke Rust structure.
///
/// Builder helpers can still generate simple schemas automatically, but the
/// user may also provide any valid schema directly.
#[derive(Debug, Clone, Serialize)]
pub struct FunctionTool {
    /// Name of the function
    pub name: String,
    /// Human-readable description
    pub description: String,
    /// JSON Schema describing the parameters
    pub parameters: Value,
}
205
206/// Defines rules for structured output responses based on [OpenAI's structured output requirements](https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format).
207/// Individual providers may have additional requirements or restrictions, but these should be handled by each provider's backend implementation.
208///
209/// If you plan on deserializing into this struct, make sure the source text has a `"name"` field, since that's technically the only thing required by OpenAI.
210///
211/// ## Example
212///
213/// ```
214/// use llm::chat::StructuredOutputFormat;
215/// use serde_json::json;
216///
217/// let response_format = r#"
218///     {
219///         "name": "Student",
220///         "description": "A student object",
221///         "schema": {
222///             "type": "object",
223///             "properties": {
224///                 "name": {
225///                     "type": "string"
226///                 },
227///                 "age": {
228///                     "type": "integer"
229///                 },
230///                 "is_student": {
231///                     "type": "boolean"
232///                 }
233///             },
234///             "required": ["name", "age", "is_student"]
235///         }
236///     }
237/// "#;
238/// let structured_output: StructuredOutputFormat = serde_json::from_str(response_format).unwrap();
239/// assert_eq!(structured_output.name, "Student");
240/// assert_eq!(structured_output.description, Some("A student object".to_string()));
241/// ```
242#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
243
244pub struct StructuredOutputFormat {
245    /// Name of the schema
246    pub name: String,
247    /// The description of the schema
248    pub description: Option<String>,
249    /// The JSON schema for the structured output
250    pub schema: Option<Value>,
251    /// Whether to enable strict schema adherence
252    pub strict: Option<bool>,
253}
254
/// Represents a tool that can be used in chat, pairing the tool type tag
/// with its function definition.
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// The type of tool (e.g. "function")
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function definition if this is a function tool
    pub function: FunctionTool,
}
264
/// Tool choice determines how the LLM uses available tools.
/// The behavior is standardized across different LLM providers.
///
/// Serialization to the wire format is customized via a manual `Serialize`
/// implementation: `Any`/`Auto`/`None` become the strings `"required"`,
/// `"auto"`, `"none"`, while `Tool(name)` becomes a
/// `{"type": "function", "function": {"name": ...}}` object.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Model can use any tool, but it must use at least one.
    /// This is useful when you want to force the model to use tools.
    Any,

    /// Model can use any tool, and may elect to use none.
    /// This is the default behavior and gives the model flexibility.
    #[default]
    Auto,

    /// Model must use the specified tool and only the specified tool.
    /// The string parameter is the name of the required tool.
    /// This is useful when you want the model to call a specific function.
    Tool(String),

    /// Explicitly disables the use of tools.
    /// The model will not use any tools even if they are provided.
    None,
}
287
288impl Serialize for ToolChoice {
289    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
290    where
291        S: serde::Serializer,
292    {
293        match self {
294            ToolChoice::Any => serializer.serialize_str("required"),
295            ToolChoice::Auto => serializer.serialize_str("auto"),
296            ToolChoice::None => serializer.serialize_str("none"),
297            ToolChoice::Tool(name) => {
298                use serde::ser::SerializeMap;
299
300                // For tool_choice: {"type": "function", "function": {"name": "function_name"}}
301                let mut map = serializer.serialize_map(Some(2))?;
302                map.serialize_entry("type", "function")?;
303
304                // Inner function object
305                let mut function_obj = std::collections::HashMap::new();
306                function_obj.insert("name", name.as_str());
307
308                map.serialize_entry("function", &function_obj)?;
309                map.end()
310            }
311        }
312    }
313}
314
/// A provider-agnostic chat response.
///
/// Each backend wraps its raw response in a type implementing this trait so
/// the rest of the crate can consume text, tool calls, and usage uniformly.
pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
    /// The text content of the response, if any.
    fn text(&self) -> Option<String>;
    /// The tool calls requested by the model, if any.
    fn tool_calls(&self) -> Option<Vec<ToolCall>>;
    /// Reasoning ("thinking") text, if the provider exposes it.
    /// Defaults to `None`.
    fn thinking(&self) -> Option<String> {
        None
    }
    /// Token usage metadata, if the provider reports it.
    /// Defaults to `None`.
    fn usage(&self) -> Option<Usage> {
        None
    }
}
325
/// Trait for providers that support chat-style interactions.
#[async_trait]
pub trait ChatProvider: Sync + Send {
    /// Sends a chat request to the provider with a sequence of messages.
    ///
    /// The default implementation delegates to [`Self::chat_with_tools`]
    /// with no tools.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    ///
    /// # Returns
    ///
    /// The provider's response text or an error
    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat_with_tools(messages, None).await
    }

    /// Sends a chat request to the provider with a sequence of messages and tools.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    /// * `tools` - Optional slice of tools to use in the chat
    ///
    /// # Returns
    ///
    /// The provider's response text or an error
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError>;

    /// Sends a chat with web search request to the provider.
    ///
    /// The default implementation returns an error; providers that support
    /// web search override it.
    ///
    /// # Arguments
    ///
    /// * `input` - The input message
    ///
    /// # Returns
    ///
    /// The provider's response text or an error
    async fn chat_with_web_search(
        &self,
        _input: String,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        Err(LLMError::Generic(
            "Web search not supported for this provider".to_string(),
        ))
    }

    /// Sends a streaming chat request to the provider with a sequence of messages.
    ///
    /// The default implementation returns an error; providers that support
    /// streaming override it.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    ///
    /// # Returns
    ///
    /// A stream of text tokens or an error
    async fn chat_stream(
        &self,
        _messages: &[ChatMessage],
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
    {
        Err(LLMError::Generic(
            "Streaming not supported for this provider".to_string(),
        ))
    }

    /// Sends a streaming chat request that returns structured response chunks.
    ///
    /// ⚠️ Getting usage metadata while streaming has been observed to be unstable
    /// depending on the provider (it can be missing).
    ///
    /// This method returns a stream of `StreamResponse` objects that mimic OpenAI's
    /// streaming response format with `.choices[0].delta.content` and `.usage`.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    ///
    /// # Returns
    ///
    /// A stream of `StreamResponse` objects or an error
    async fn chat_stream_struct(
        &self,
        _messages: &[ChatMessage],
    ) -> Result<
        std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
        LLMError,
    > {
        Err(LLMError::Generic(
            "Structured streaming not supported for this provider".to_string(),
        ))
    }

    /// Get current memory contents if provider supports memory.
    /// Defaults to `None` for providers without memory.
    async fn memory_contents(&self) -> Option<Vec<ChatMessage>> {
        None
    }

    /// Summarizes a conversation history into a concise 2-3 sentence summary.
    ///
    /// The default implementation formats the history (role + content per line)
    /// into a prompt and sends it back through this provider's own `chat`.
    ///
    /// # Arguments
    /// * `msgs` - The conversation messages to summarize
    ///
    /// # Returns
    /// A string containing the summary or an error if summarization fails
    async fn summarize_history(&self, msgs: &[ChatMessage]) -> Result<String, LLMError> {
        let prompt = format!(
            "Summarize in 2-3 sentences:\n{}",
            msgs.iter()
                .map(|m| format!("{:?}: {}", m.role, m.content))
                .collect::<Vec<_>>()
                .join("\n"),
        );
        let req = [ChatMessage::user().content(prompt).build()];
        self.chat(&req)
            .await?
            .text()
            .ok_or(LLMError::Generic("no text in summary response".into()))
    }
}
449
450impl fmt::Display for ReasoningEffort {
451    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452        match self {
453            ReasoningEffort::Low => write!(f, "low"),
454            ReasoningEffort::Medium => write!(f, "medium"),
455            ReasoningEffort::High => write!(f, "high"),
456        }
457    }
458}
459
impl ChatMessage {
    /// Create a new builder for a user message
    pub fn user() -> ChatMessageBuilder {
        ChatMessageBuilder::new(ChatRole::User)
    }

    /// Create a new builder for an assistant message
    pub fn assistant() -> ChatMessageBuilder {
        ChatMessageBuilder::new(ChatRole::Assistant)
    }
}
471
/// Builder for ChatMessage
#[derive(Debug)]
pub struct ChatMessageBuilder {
    // Role of the message being built (user or assistant)
    role: ChatRole,
    // Message type; defaults to `MessageType::Text`
    message_type: MessageType,
    // Text content; defaults to empty
    content: String,
}
479
480impl ChatMessageBuilder {
481    /// Create a new ChatMessageBuilder with specified role
482    pub fn new(role: ChatRole) -> Self {
483        Self {
484            role,
485            message_type: MessageType::default(),
486            content: String::new(),
487        }
488    }
489
490    /// Set the message content
491    pub fn content<S: Into<String>>(mut self, content: S) -> Self {
492        self.content = content.into();
493        self
494    }
495
496    /// Set the message type as Image
497    pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
498        self.message_type = MessageType::Image((image_mime, raw_bytes));
499        self
500    }
501
502    /// Set the message type as Image
503    pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
504        self.message_type = MessageType::Pdf(raw_bytes);
505        self
506    }
507
508    /// Set the message type as ImageURL
509    pub fn image_url(mut self, url: impl Into<String>) -> Self {
510        self.message_type = MessageType::ImageURL(url.into());
511        self
512    }
513
514    /// Set the message type as ToolUse
515    pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
516        self.message_type = MessageType::ToolUse(tools);
517        self
518    }
519
520    /// Set the message type as ToolResult
521    pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
522        self.message_type = MessageType::ToolResult(tools);
523        self
524    }
525
526    /// Build the ChatMessage
527    pub fn build(self) -> ChatMessage {
528        ChatMessage {
529            role: self.role,
530            message_type: self.message_type,
531            content: self.content,
532        }
533    }
534}
535
/// Creates a Server-Sent Events (SSE) stream from an HTTP response.
///
/// Two buffers are threaded through the stream via `scan`:
/// * `utf8_buffer` accumulates raw bytes until they decode as valid UTF-8,
///   since a multi-byte character may be split across network chunks;
/// * `buffer` accumulates decoded text until a complete SSE event
///   (terminated by a blank line, `"\n\n"`) is available for `parser`.
///
/// # Arguments
///
/// * `response` - The HTTP response from the streaming API
/// * `parser` - Function to parse each SSE chunk into optional text content
///
/// # Returns
///
/// A pinned stream of text tokens or an error
pub(crate) fn create_sse_stream<F>(
    response: reqwest::Response,
    parser: F,
) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
where
    F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
{
    let stream = response
        .bytes_stream()
        .scan(
            (String::new(), Vec::new()),
            move |(buffer, utf8_buffer), chunk| {
                let result = match chunk {
                    Ok(bytes) => {
                        utf8_buffer.extend_from_slice(&bytes);

                        // Try to decode everything accumulated so far.
                        match String::from_utf8(utf8_buffer.clone()) {
                            Ok(text) => {
                                buffer.push_str(&text);
                                utf8_buffer.clear();
                            }
                            Err(e) => {
                                // Decode only the valid prefix; the trailing
                                // incomplete sequence stays buffered until the
                                // next chunk completes it.
                                let valid_up_to = e.utf8_error().valid_up_to();
                                if valid_up_to > 0 {
                                    // Safe to use from_utf8_lossy here since valid_up_to points to
                                    // a valid UTF-8 boundary - no replacement characters will be introduced
                                    let valid =
                                        String::from_utf8_lossy(&utf8_buffer[..valid_up_to]);
                                    buffer.push_str(&valid);
                                    utf8_buffer.drain(..valid_up_to);
                                }
                            }
                        }

                        let mut results = Vec::new();

                        // Emit one parser result per complete SSE event; a
                        // trailing partial event remains in `buffer`.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event = buffer[..pos + 2].to_string();
                            buffer.drain(..pos + 2);

                            match parser(&event) {
                                Ok(Some(content)) => results.push(Ok(content)),
                                Ok(None) => {}
                                Err(e) => results.push(Err(e)),
                            }
                        }

                        Some(results)
                    }
                    // Surface transport errors as a single-item batch.
                    Err(e) => Some(vec![Err(LLMError::HttpError(e.to_string()))]),
                };

                async move { result }
            },
        )
        // Each scan step yields a Vec of results; flatten to single items.
        .flat_map(futures::stream::iter);

    Box::pin(stream)
}
605
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream::StreamExt;

    /// SSE parser shared by all tests: returns the payload of a
    /// `data: ...` event, or `None` for empty payloads and non-data events.
    /// (Previously duplicated as an identical closure in every test.)
    fn parse_sse_event(event: &str) -> Result<Option<String>, LLMError> {
        match event.strip_prefix("data: ") {
            Some(content) => {
                let content = content.trim();
                if content.is_empty() {
                    Ok(None)
                } else {
                    Ok(Some(content.to_string()))
                }
            }
            None => Ok(None),
        }
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_utf8() {
        let test_data = "data: Positive reactions\n\n".as_bytes();

        // Split the event at an arbitrary byte boundary to exercise buffering.
        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(&test_data[..10])),
            Ok(Bytes::from(&test_data[10..])),
        ];

        let stream = create_sse_stream(create_mock_response(chunks), parse_sse_event);
        let results: Vec<_> = stream.collect().await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap(), "Positive reactions");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_sse_events() {
        let event1 = "data: First event\n\n";
        let event2 = "data: Second event\n\n";
        let combined = format!("{}{}", event1, event2);
        let test_data = combined.as_bytes().to_vec();

        // Split inside the second event so the first chunk carries one
        // complete event plus a partial one.
        let split_point = event1.len() + 5;
        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_point].to_vec())),
            Ok(Bytes::from(test_data[split_point..].to_vec())),
        ];

        let stream = create_sse_stream(create_mock_response(chunks), parse_sse_event);
        let results: Vec<_> = stream.collect().await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), "First event");
        assert_eq!(results[1].as_ref().unwrap(), "Second event");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_multibyte_utf8_split() {
        let multibyte_char = "✨";
        let event = format!("data: Star {}\n\n", multibyte_char);
        let test_data = event.as_bytes().to_vec();

        // Split one byte into the multi-byte character so the first chunk
        // ends with an incomplete UTF-8 sequence.
        let emoji_start = event.find(multibyte_char).unwrap();
        let split_in_emoji = emoji_start + 1;

        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_in_emoji].to_vec())),
            Ok(Bytes::from(test_data[split_in_emoji..].to_vec())),
        ];

        let stream = create_sse_stream(create_mock_response(chunks), parse_sse_event);
        let results: Vec<_> = stream.collect().await;

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].as_ref().unwrap(),
            &format!("Star {}", multibyte_char)
        );
    }

    /// Builds a `reqwest::Response` whose body yields the given chunks,
    /// letting tests drive `create_sse_stream` without a network.
    fn create_mock_response(chunks: Vec<Result<Bytes, reqwest::Error>>) -> reqwest::Response {
        use http_body_util::StreamBody;
        use reqwest::Body;

        let frame_stream = futures::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| chunk.map(hyper::body::Frame::data)),
        );

        let body = Body::wrap(StreamBody::new(frame_stream));
        let http_response = http::Response::builder().status(200).body(body).unwrap();

        http_response.into()
    }
}