Skip to main content

autoagents_llm/chat/
mod.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{ToolCall, error::LLMError};
11
/// Usage metadata for a chat response.
///
/// The serde `alias` attributes accept both the OpenAI-style field names
/// (`prompt_tokens`/`completion_tokens`) and the `input_tokens`/`output_tokens`
/// naming used by some other providers when deserializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Number of tokens in the prompt
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    /// Number of tokens in the completion
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// Total number of tokens used
    pub total_tokens: u32,
    /// Breakdown of completion tokens, if available
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "output_tokens_details"
    )]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Breakdown of prompt tokens, if available
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "input_tokens_details"
    )]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
36
/// Stream response chunk that mimics OpenAI's streaming response format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamResponse {
    /// Array of choices in the response
    pub choices: Vec<StreamChoice>,
    /// Usage metadata, typically present in the final chunk;
    /// omitted from serialized output when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
46
/// Individual choice in a streaming response.
///
/// Mirrors the `choices[i]` objects of OpenAI's streaming wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// Delta containing the incremental content
    pub delta: StreamDelta,
}
53
/// Delta content in a streaming response.
///
/// Each field is `None` when the chunk carries no update of that kind;
/// `None` fields are omitted from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamDelta {
    /// The incremental content, if any
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// The incremental reasoning content, if any
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// The incremental tool calls, if any
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
67
/// A streaming chunk that can be either text or a tool call event.
///
/// This enum provides a unified representation of streaming events
/// when using `chat_stream_with_tools`. It allows callers to receive
/// text deltas as they arrive while also handling tool use blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
    /// Text content delta
    Text(String),
    /// Reasoning content delta
    ReasoningContent(String),

    /// Tool use block started (contains tool id and name)
    ToolUseStart {
        /// The index of this content block in the response
        index: usize,
        /// The unique ID for this tool use
        id: String,
        /// The name of the tool being called
        name: String,
    },

    /// Tool use input JSON delta (partial JSON string)
    ToolUseInputDelta {
        /// The index of this content block
        index: usize,
        /// Partial JSON string for the tool input
        partial_json: String,
    },

    /// Tool use block complete with assembled ToolCall
    ToolUseComplete {
        /// The index of this content block
        index: usize,
        /// The complete tool call with id, name, and parsed arguments
        tool_call: ToolCall,
    },

    /// Stream ended with stop reason
    Done {
        /// The reason the stream stopped (e.g., "end_turn", "tool_use")
        stop_reason: String,
    },
    /// Token usage metadata for the request, when the provider reports it
    Usage(Usage),
}
113
/// Breakdown of completion tokens.
///
/// Fields are `None` when the provider does not report the corresponding
/// category, and omitted from serialized output in that case.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
    /// Tokens used for reasoning (for reasoning models)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Tokens used for audio output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
124
/// Breakdown of prompt tokens.
///
/// Fields are `None` when the provider does not report the corresponding
/// category, and omitted from serialized output in that case.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptTokensDetails {
    /// Tokens used for cached content
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
    /// Tokens used for audio input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
135
/// Role of a participant in a chat conversation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChatRole {
    /// The system prompt (instructions that frame the conversation)
    System,
    /// The user/human participant in the conversation
    User,
    /// The AI assistant participant in the conversation
    Assistant,
    /// Tool/function response
    Tool,
}
148
149impl fmt::Display for ChatRole {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        let value = match self {
152            ChatRole::System => "system",
153            ChatRole::User => "user",
154            ChatRole::Assistant => "assistant",
155            ChatRole::Tool => "tool",
156        };
157        f.write_str(value)
158    }
159}
160
/// The supported MIME type of an image.
///
/// Marked `#[non_exhaustive]` so new formats can be added without a
/// breaking change for downstream matchers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ImageMime {
    /// JPEG image
    JPEG,
    /// PNG image
    PNG,
    /// GIF image
    GIF,
    /// WebP image
    WEBP,
}
174
175impl ImageMime {
176    pub fn mime_type(&self) -> &'static str {
177        match self {
178            ImageMime::JPEG => "image/jpeg",
179            ImageMime::PNG => "image/png",
180            ImageMime::GIF => "image/gif",
181            ImageMime::WEBP => "image/webp",
182        }
183    }
184}
185
/// The type of a message in a chat conversation.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MessageType {
    /// A text message (the default)
    #[default]
    Text,
    /// An image message: MIME type plus the raw image bytes
    Image((ImageMime, Vec<u8>)),
    /// PDF message, as the raw document bytes
    Pdf(Vec<u8>),
    /// An image referenced by URL rather than inline bytes
    ImageURL(String),
    /// A tool use (tool calls requested by the assistant)
    ToolUse(Vec<ToolCall>),
    /// Tool result
    /// (NOTE(review): the result payload appears to travel in each
    /// `ToolCall`'s `arguments` field — confirm with provider backends)
    ToolResult(Vec<ToolCall>),
}
203
/// The type of reasoning effort for a message in a chat conversation.
// Derives added for consistency with the other public types in this
// module (every other enum/struct here is at least Debug + Clone);
// Copy/PartialEq/Eq are free for a fieldless enum and purely additive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReasoningEffort {
    /// Low reasoning effort
    Low,
    /// Medium reasoning effort
    Medium,
    /// High reasoning effort
    High,
}
213
/// A single message in a chat conversation.
///
/// Construct via the builder API: `ChatMessage::user()` /
/// `ChatMessage::assistant()` or `ChatMessageBuilder::new(role)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// The role of who sent this message (user or assistant)
    pub role: ChatRole,
    /// The type of the message (text, image, audio, video, etc)
    pub message_type: MessageType,
    /// The text content of the message
    pub content: String,
}
224
/// Represents a parameter in a function tool.
///
/// Serialize-only: this type is sent to providers as part of a tool
/// definition and is never parsed back from responses.
#[derive(Debug, Clone, Serialize)]
pub struct ParameterProperty {
    /// The type of the parameter (e.g. "string", "number", "array", etc)
    #[serde(rename = "type")]
    pub property_type: String,
    /// Description of what the parameter does
    pub description: String,
    /// When type is "array", this defines the type of the array items
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<ParameterProperty>>,
    /// When type is "enum", this defines the possible values for the parameter
    #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
    pub enum_list: Option<Vec<String>>,
}
240
/// Represents the parameters schema for a function tool.
///
/// A simplified JSON-Schema-shaped object; for arbitrarily complex
/// schemas, `FunctionTool::parameters` accepts a raw `serde_json::Value`.
#[derive(Debug, Clone, Serialize)]
pub struct ParametersSchema {
    /// The type of the parameters object (usually "object")
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Map of parameter names to their properties
    pub properties: HashMap<String, ParameterProperty>,
    /// List of required parameter names
    pub required: Vec<String>,
}
252
/// Represents a function definition for a tool.
///
/// The `parameters` field stores the JSON Schema describing the function
/// arguments. It is kept as a raw `serde_json::Value` to allow arbitrary
/// complexity (nested arrays/objects, `oneOf`, etc.) without requiring a
/// bespoke Rust structure.
///
/// Builder helpers can still generate simple schemas automatically, but the
/// user may also provide any valid schema directly.
#[derive(Debug, Clone, Serialize)]
pub struct FunctionTool {
    /// Name of the function
    pub name: String,
    /// Human-readable description
    pub description: String,
    /// JSON Schema describing the parameters
    pub parameters: Value,
}
271
272/// Defines rules for structured output responses based on [OpenAI's structured output requirements](https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format).
273/// Individual providers may have additional requirements or restrictions, but these should be handled by each provider's backend implementation.
274///
275/// If you plan on deserializing into this struct, make sure the source text has a `"name"` field, since that's technically the only thing required by OpenAI.
276///
277/// ## Example
278///
279/// ```
280/// use autoagents_llm::chat::StructuredOutputFormat;
281/// use serde_json::json;
282///
283/// let response_format = r#"
284///     {
285///         "name": "Student",
286///         "description": "A student object",
287///         "schema": {
288///             "type": "object",
289///             "properties": {
290///                 "name": {
291///                     "type": "string"
292///                 },
293///                 "age": {
294///                     "type": "integer"
295///                 },
296///                 "is_student": {
297///                     "type": "boolean"
298///                 }
299///             },
300///             "required": ["name", "age", "is_student"]
301///         }
302///     }
303/// "#;
304/// let structured_output: StructuredOutputFormat = serde_json::from_str(response_format).unwrap();
305/// assert_eq!(structured_output.name, "Student");
306/// assert_eq!(structured_output.description, Some("A student object".to_string()));
307/// ```
308#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
309
310pub struct StructuredOutputFormat {
311    /// Name of the schema
312    pub name: String,
313    /// The description of the schema
314    pub description: Option<String>,
315    /// The JSON schema for the structured output
316    pub schema: Option<Value>,
317    /// Whether to enable strict schema adherence
318    pub strict: Option<bool>,
319}
320
/// Represents a tool that can be used in chat.
///
/// Serialize-only: sent to providers in requests, never parsed back.
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// The type of tool (e.g. "function")
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function definition if this is a function tool
    pub function: FunctionTool,
}
330
/// Tool choice determines how the LLM uses available tools.
/// The behavior is standardized across different LLM providers.
///
/// See the manual `Serialize` impl for the wire format: `Any` serializes
/// as the string `"required"`, `Auto` as `"auto"`, `None` as `"none"`,
/// and `Tool(name)` as `{"type": "function", "function": {"name": ...}}`.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Model can use any tool, but it must use at least one.
    /// This is useful when you want to force the model to use tools.
    Any,

    /// Model can use any tool, and may elect to use none.
    /// This is the default behavior and gives the model flexibility.
    #[default]
    Auto,

    /// Model must use the specified tool and only the specified tool.
    /// The string parameter is the name of the required tool.
    /// This is useful when you want the model to call a specific function.
    Tool(String),

    /// Explicitly disables the use of tools.
    /// The model will not use any tools even if they are provided.
    None,
}
353
354impl Serialize for ToolChoice {
355    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
356    where
357        S: serde::Serializer,
358    {
359        match self {
360            ToolChoice::Any => serializer.serialize_str("required"),
361            ToolChoice::Auto => serializer.serialize_str("auto"),
362            ToolChoice::None => serializer.serialize_str("none"),
363            ToolChoice::Tool(name) => {
364                use serde::ser::SerializeMap;
365
366                // For tool_choice: {"type": "function", "function": {"name": "function_name"}}
367                let mut map = serializer.serialize_map(Some(2))?;
368                map.serialize_entry("type", "function")?;
369
370                // Inner function object
371                let mut function_obj = std::collections::HashMap::new();
372                function_obj.insert("name", name.as_str());
373
374                map.serialize_entry("function", &function_obj)?;
375                map.end()
376            }
377        }
378    }
379}
380
/// A provider-agnostic chat completion result.
///
/// Implemented by each backend's concrete response type so callers can
/// extract content without knowing which provider produced it.
pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
    /// The main text content of the response, if any.
    fn text(&self) -> Option<String>;
    /// Tool calls requested by the model, if any.
    fn tool_calls(&self) -> Option<Vec<ToolCall>>;
    /// Reasoning/"thinking" content, if the provider exposes it (defaults to none).
    fn thinking(&self) -> Option<String> {
        None
    }
    /// Token usage metadata, if the provider reports it (defaults to none).
    fn usage(&self) -> Option<Usage> {
        None
    }
}
391
392/// Trait for providers that support chat-style interactions.
393#[async_trait]
394pub trait ChatProvider: Sync + Send {
395    /// Sends a chat request to the provider with a sequence of messages.
396    ///
397    /// # Arguments
398    ///
399    /// * `messages` - The conversation history as a slice of chat messages
400    /// * `json_schema` - Optional json_schema for the response format
401    ///
402    /// # Returns
403    ///
404    /// The provider's response text or an error
405    async fn chat(
406        &self,
407        messages: &[ChatMessage],
408        json_schema: Option<StructuredOutputFormat>,
409    ) -> Result<Box<dyn ChatResponse>, LLMError> {
410        self.chat_with_tools(messages, None, json_schema).await
411    }
412
413    /// Sends a chat request to the provider with a sequence of messages and tools.
414    ///
415    /// # Arguments
416    ///
417    /// * `messages` - The conversation history as a slice of chat messages
418    /// * `tools` - Optional slice of tools to use in the chat
419    /// * `json_schema` - Optional json_schema for the response format
420    ///
421    /// # Returns
422    ///
423    /// The provider's response text or an error
424    async fn chat_with_tools(
425        &self,
426        messages: &[ChatMessage],
427        tools: Option<&[Tool]>,
428        json_schema: Option<StructuredOutputFormat>,
429    ) -> Result<Box<dyn ChatResponse>, LLMError>;
430
431    /// Sends a chat with web search request to the provider
432    ///
433    /// # Arguments
434    ///
435    /// * `input` - The input message
436    ///
437    /// # Returns
438    ///
439    /// The provider's response text or an error
440    async fn chat_with_web_search(
441        &self,
442        _input: String,
443    ) -> Result<Box<dyn ChatResponse>, LLMError> {
444        Err(LLMError::Generic(
445            "Web search not supported for this provider".to_string(),
446        ))
447    }
448
449    /// Sends a streaming chat request to the provider with a sequence of messages.
450    ///
451    /// # Arguments
452    ///
453    /// * `messages` - The conversation history as a slice of chat messages
454    /// * `json_schema` - Optional json_schema for the response format
455    ///
456    /// # Returns
457    ///
458    /// A stream of text tokens or an error
459    async fn chat_stream(
460        &self,
461        _messages: &[ChatMessage],
462        _json_schema: Option<StructuredOutputFormat>,
463    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
464    {
465        Err(LLMError::Generic(
466            "Streaming not supported for this provider".to_string(),
467        ))
468    }
469
470    /// Sends a streaming chat request that returns structured response chunks.
471    ///
472    /// ⚠️ Getting usage metadata while streaming have been noticed to be a unstable depending on the provider
473    /// (it can be missing).
474    ///
475    /// This method returns a stream of `StreamResponse` objects that mimic OpenAI's
476    /// streaming response format with `.choices[0].delta.content` and `.usage`.
477    ///
478    /// # Arguments
479    ///
480    /// * `messages` - The conversation history as a slice of chat messages
481    /// * `tools` - Optional slice of tools to use in the chat
482    /// * `json_schema` - Optional json_schema for the response format
483    ///
484    /// # Returns
485    ///
486    /// A stream of `StreamResponse` objects or an error
487    async fn chat_stream_struct(
488        &self,
489        _messages: &[ChatMessage],
490        _tools: Option<&[Tool]>,
491        _json_schema: Option<StructuredOutputFormat>,
492    ) -> Result<
493        std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
494        LLMError,
495    > {
496        Err(LLMError::Generic(
497            "Structured streaming not supported for this provider".to_string(),
498        ))
499    }
500
501    /// Sends a streaming chat request with tool support.
502    ///
503    /// Returns a stream of `StreamChunk` which can be text deltas or tool call events.
504    /// When `stop_reason` is "tool_use", the caller should execute the tool(s)
505    /// and continue the conversation.
506    ///
507    /// This method is ideal for agentic workflows where you want to stream text
508    /// output to the user while still receiving tool call requests.
509    ///
510    /// # Arguments
511    ///
512    /// * `messages` - The conversation history as a slice of chat messages
513    /// * `tools` - Optional slice of tools available for the model to use
514    /// * `json_schema` - Optional json_schema for the response format
515    ///
516    /// # Returns
517    ///
518    /// A stream of `StreamChunk` items or an error
519    ///
520    /// # Example
521    ///
522    /// ```ignore
523    /// use futures::StreamExt;
524    ///
525    /// let mut stream = client
526    ///     .chat_stream_with_tools(&messages, Some(&tools))
527    ///     .await?;
528    ///
529    /// let mut tool_calls = Vec::new();
530    /// while let Some(chunk) = stream.next().await {
531    ///     match chunk? {
532    ///         StreamChunk::Text(text) => print!("{}", text),
533    ///         StreamChunk::ToolUseComplete { tool_call, .. } => {
534    ///             tool_calls.push(tool_call);
535    ///         }
536    ///         StreamChunk::Done { stop_reason } => {
537    ///             if stop_reason == "tool_use" {
538    ///                 // Execute tool_calls and continue conversation
539    ///             }
540    ///         }
541    ///         _ => {}
542    ///     }
543    /// }
544    /// ```
545    async fn chat_stream_with_tools(
546        &self,
547        _messages: &[ChatMessage],
548        _tools: Option<&[Tool]>,
549        _json_schema: Option<StructuredOutputFormat>,
550    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, LLMError> {
551        Err(LLMError::Generic(
552            "Streaming with tools not supported for this provider".to_string(),
553        ))
554    }
555}
556
557impl fmt::Display for ReasoningEffort {
558    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
559        match self {
560            ReasoningEffort::Low => write!(f, "low"),
561            ReasoningEffort::Medium => write!(f, "medium"),
562            ReasoningEffort::High => write!(f, "high"),
563        }
564    }
565}
566
567impl ChatMessage {
568    /// Create a new builder for a user message
569    pub fn user() -> ChatMessageBuilder {
570        ChatMessageBuilder::new(ChatRole::User)
571    }
572
573    /// Create a new builder for an assistant message
574    pub fn assistant() -> ChatMessageBuilder {
575        ChatMessageBuilder::new(ChatRole::Assistant)
576    }
577}
578
/// Builder for ChatMessage
///
/// Holds the role fixed at construction; message type defaults to
/// `MessageType::Text` and content to the empty string until set.
#[derive(Debug)]
pub struct ChatMessageBuilder {
    role: ChatRole,
    message_type: MessageType,
    content: String,
}
586
impl ChatMessageBuilder {
    /// Create a new ChatMessageBuilder with specified role
    pub fn new(role: ChatRole) -> Self {
        Self {
            role,
            message_type: MessageType::default(),
            content: String::default(),
        }
    }

    /// Set the message content
    pub fn content<S: Into<String>>(mut self, content: S) -> Self {
        self.content = content.into();
        self
    }

    /// Set the message type as Image
    pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
        self.message_type = MessageType::Image((image_mime, raw_bytes));
        self
    }

    /// Set the message type as Pdf
    pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
        self.message_type = MessageType::Pdf(raw_bytes);
        self
    }

    /// Set the message type as ImageURL
    pub fn image_url(mut self, url: impl Into<String>) -> Self {
        self.message_type = MessageType::ImageURL(url.into());
        self
    }

    /// Set the message type as ToolUse
    pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
        self.message_type = MessageType::ToolUse(tools);
        self
    }

    /// Set the message type as ToolResult
    pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
        self.message_type = MessageType::ToolResult(tools);
        self
    }

    /// Build the ChatMessage
    pub fn build(self) -> ChatMessage {
        ChatMessage {
            role: self.role,
            message_type: self.message_type,
            content: self.content,
        }
    }
}
642
/// Creates a Server-Sent Events (SSE) stream from an HTTP response.
///
/// Incoming bytes are accumulated until a complete SSE event (terminated
/// by a blank line, i.e. `"\n\n"`) is available, then the event text is
/// handed to `parser`. Bytes that split a multi-byte UTF-8 character
/// across chunk boundaries are held back until the remainder arrives.
///
/// # Arguments
///
/// * `response` - The HTTP response from the streaming API
/// * `parser` - Function to parse each SSE chunk into optional text content
///
/// # Returns
///
/// A pinned stream of text tokens or an error
#[allow(dead_code)]
pub(crate) fn create_sse_stream<F>(
    response: reqwest::Response,
    parser: F,
) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
where
    F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
{
    let stream = response
        .bytes_stream()
        // Scan state: (decoded-but-unparsed event text, raw bytes waiting
        // for a complete UTF-8 sequence).
        .scan(
            (String::default(), Vec::default()),
            move |(buffer, utf8_buffer), chunk| {
                let result = match chunk {
                    Ok(bytes) => {
                        utf8_buffer.extend_from_slice(&bytes);

                        // Decode as much of the pending bytes as is valid
                        // UTF-8; keep any trailing partial character.
                        match String::from_utf8(utf8_buffer.clone()) {
                            Ok(text) => {
                                buffer.push_str(&text);
                                utf8_buffer.clear();
                            }
                            Err(e) => {
                                let valid_up_to = e.utf8_error().valid_up_to();
                                if valid_up_to > 0 {
                                    // Safe to use from_utf8_lossy here since valid_up_to points to
                                    // a valid UTF-8 boundary - no replacement characters will be introduced
                                    let valid =
                                        String::from_utf8_lossy(&utf8_buffer[..valid_up_to]);
                                    buffer.push_str(&valid);
                                    utf8_buffer.drain(..valid_up_to);
                                }
                            }
                        }

                        let mut results = Vec::default();

                        // A chunk may complete zero, one, or several SSE
                        // events; drain each complete "...\n\n" event.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event = buffer[..pos + 2].to_string();
                            buffer.drain(..pos + 2);

                            match parser(&event) {
                                Ok(Some(content)) => results.push(Ok(content)),
                                Ok(None) => {}
                                Err(e) => results.push(Err(e)),
                            }
                        }

                        // NOTE(review): any partial event still in `buffer`
                        // when the byte stream ends is silently discarded —
                        // confirm this is intended for all providers.
                        Some(results)
                    }
                    Err(e) => Some(vec![Err(LLMError::HttpError(e.to_string()))]),
                };

                async move { result }
            },
        )
        // Flatten the per-chunk Vec of results into individual items.
        .flat_map(futures::stream::iter);

    Box::pin(stream)
}
713
714#[cfg(not(target_arch = "wasm32"))]
715pub mod utils {
716    use crate::error::LLMError;
717    use reqwest::Response;
718    pub async fn check_response_status(response: Response) -> Result<Response, LLMError> {
719        if !response.status().is_success() {
720            let status = response.status();
721            let error_text = response.text().await?;
722            return Err(LLMError::ResponseFormatError {
723                message: format!("API returned error status: {status}"),
724                raw_response: error_text,
725            });
726        }
727        Ok(response)
728    }
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734    use bytes::Bytes;
735    use futures::stream::StreamExt;
736
737    #[test]
738    fn test_chat_message_builder_user() {
739        let msg = ChatMessage::user().content("hello").build();
740        assert_eq!(msg.role, ChatRole::User);
741        assert_eq!(msg.content, "hello");
742        assert!(matches!(msg.message_type, MessageType::Text));
743    }
744
745    #[test]
746    fn test_chat_message_builder_assistant() {
747        let msg = ChatMessage::assistant().content("reply").build();
748        assert_eq!(msg.role, ChatRole::Assistant);
749        assert_eq!(msg.content, "reply");
750    }
751
752    #[test]
753    fn test_chat_message_builder_image() {
754        let msg = ChatMessage::user()
755            .content("describe")
756            .image(ImageMime::PNG, vec![1, 2, 3])
757            .build();
758        assert!(matches!(msg.message_type, MessageType::Image(_)));
759    }
760
761    #[test]
762    fn test_chat_message_builder_pdf() {
763        let msg = ChatMessage::user()
764            .content("read")
765            .pdf(vec![4, 5, 6])
766            .build();
767        assert!(matches!(msg.message_type, MessageType::Pdf(_)));
768    }
769
770    #[test]
771    fn test_chat_message_builder_tool_use() {
772        let tc = crate::ToolCall {
773            id: "t1".to_string(),
774            call_type: "function".to_string(),
775            function: crate::FunctionCall {
776                name: "tool".to_string(),
777                arguments: "{}".to_string(),
778            },
779        };
780        let msg = ChatMessage::assistant()
781            .content("calling tool")
782            .tool_use(vec![tc])
783            .build();
784        assert!(matches!(msg.message_type, MessageType::ToolUse(_)));
785    }
786
787    #[test]
788    fn test_chat_message_builder_tool_result() {
789        let tc = crate::ToolCall {
790            id: "t1".to_string(),
791            call_type: "function".to_string(),
792            function: crate::FunctionCall {
793                name: "tool".to_string(),
794                arguments: "result".to_string(),
795            },
796        };
797        let msg = ChatMessageBuilder::new(ChatRole::Tool)
798            .tool_result(vec![tc])
799            .build();
800        assert!(matches!(msg.message_type, MessageType::ToolResult(_)));
801        assert_eq!(msg.role, ChatRole::Tool);
802    }
803
804    #[test]
805    fn test_chat_role_display() {
806        assert_eq!(format!("{}", ChatRole::System), "system");
807        assert_eq!(format!("{}", ChatRole::User), "user");
808        assert_eq!(format!("{}", ChatRole::Assistant), "assistant");
809        assert_eq!(format!("{}", ChatRole::Tool), "tool");
810    }
811
812    #[test]
813    fn test_image_mime_mime_type() {
814        assert_eq!(ImageMime::JPEG.mime_type(), "image/jpeg");
815        assert_eq!(ImageMime::PNG.mime_type(), "image/png");
816        assert_eq!(ImageMime::GIF.mime_type(), "image/gif");
817        assert_eq!(ImageMime::WEBP.mime_type(), "image/webp");
818    }
819
820    #[test]
821    fn test_reasoning_effort_display() {
822        assert_eq!(format!("{}", ReasoningEffort::Low), "low");
823        assert_eq!(format!("{}", ReasoningEffort::Medium), "medium");
824        assert_eq!(format!("{}", ReasoningEffort::High), "high");
825    }
826
827    #[test]
828    fn test_tool_choice_serialization() {
829        let any_json = serde_json::to_value(&ToolChoice::Any).unwrap();
830        assert_eq!(any_json, "required");
831
832        let auto_json = serde_json::to_value(&ToolChoice::Auto).unwrap();
833        assert_eq!(auto_json, "auto");
834
835        let none_json = serde_json::to_value(&ToolChoice::None).unwrap();
836        assert_eq!(none_json, "none");
837
838        let tool_json = serde_json::to_value(ToolChoice::Tool("my_func".to_string())).unwrap();
839        assert_eq!(tool_json["type"], "function");
840        assert_eq!(tool_json["function"]["name"], "my_func");
841    }
842
843    #[test]
844    fn test_structured_output_format_roundtrip() {
845        let format = StructuredOutputFormat {
846            name: "Test".to_string(),
847            description: Some("A test".to_string()),
848            schema: Some(serde_json::json!({"type": "object"})),
849            strict: Some(true),
850        };
851        let json = serde_json::to_string(&format).unwrap();
852        let parsed: StructuredOutputFormat = serde_json::from_str(&json).unwrap();
853        assert_eq!(parsed, format);
854    }
855
856    #[test]
857    fn test_structured_output_format_minimal() {
858        let json_str = r#"{"name":"Minimal"}"#;
859        let parsed: StructuredOutputFormat = serde_json::from_str(json_str).unwrap();
860        assert_eq!(parsed.name, "Minimal");
861        assert_eq!(parsed.description, None);
862        assert_eq!(parsed.schema, None);
863        assert_eq!(parsed.strict, None);
864    }
865
866    #[test]
867    fn test_chat_message_builder_image_url() {
868        let msg = ChatMessage::user()
869            .image_url("https://example.com/img.png")
870            .content("describe this")
871            .build();
872        assert!(matches!(msg.message_type, MessageType::ImageURL(_)));
873    }
874
875    #[tokio::test]
876    async fn test_create_sse_stream_handles_split_utf8() {
877        let test_data = "data: Positive reactions\n\n".as_bytes();
878
879        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
880            Ok(Bytes::from(&test_data[..10])),
881            Ok(Bytes::from(&test_data[10..])),
882        ];
883
884        let mock_response = create_mock_response(chunks);
885
886        let parser = |event: &str| -> Result<Option<String>, LLMError> {
887            if let Some(content) = event.strip_prefix("data: ") {
888                let content = content.trim();
889                if content.is_empty() {
890                    return Ok(None);
891                }
892                Ok(Some(content.to_string()))
893            } else {
894                Ok(None)
895            }
896        };
897
898        let mut stream = create_sse_stream(mock_response, parser);
899
900        let mut results = Vec::new();
901        while let Some(result) = stream.next().await {
902            results.push(result);
903        }
904
905        assert_eq!(results.len(), 1);
906        assert_eq!(results[0].as_ref().unwrap(), "Positive reactions");
907    }
908
909    #[tokio::test]
910    async fn test_create_sse_stream_handles_split_sse_events() {
911        let event1 = "data: First event\n\n";
912        let event2 = "data: Second event\n\n";
913        let combined = format!("{}{}", event1, event2);
914        let test_data = combined.as_bytes().to_vec();
915
916        let split_point = event1.len() + 5;
917        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
918            Ok(Bytes::from(test_data[..split_point].to_vec())),
919            Ok(Bytes::from(test_data[split_point..].to_vec())),
920        ];
921
922        let mock_response = create_mock_response(chunks);
923
924        let parser = |event: &str| -> Result<Option<String>, LLMError> {
925            if let Some(content) = event.strip_prefix("data: ") {
926                let content = content.trim();
927                if content.is_empty() {
928                    return Ok(None);
929                }
930                Ok(Some(content.to_string()))
931            } else {
932                Ok(None)
933            }
934        };
935
936        let mut stream = create_sse_stream(mock_response, parser);
937
938        let mut results = Vec::new();
939        while let Some(result) = stream.next().await {
940            results.push(result);
941        }
942
943        assert_eq!(results.len(), 2);
944        assert_eq!(results[0].as_ref().unwrap(), "First event");
945        assert_eq!(results[1].as_ref().unwrap(), "Second event");
946    }
947
948    #[tokio::test]
949    async fn test_create_sse_stream_handles_multibyte_utf8_split() {
950        let multibyte_char = "✨";
951        let event = format!("data: Star {}\n\n", multibyte_char);
952        let test_data = event.as_bytes().to_vec();
953
954        let emoji_start = event.find(multibyte_char).unwrap();
955        let split_in_emoji = emoji_start + 1;
956
957        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
958            Ok(Bytes::from(test_data[..split_in_emoji].to_vec())),
959            Ok(Bytes::from(test_data[split_in_emoji..].to_vec())),
960        ];
961
962        let mock_response = create_mock_response(chunks);
963
964        let parser = |event: &str| -> Result<Option<String>, LLMError> {
965            if let Some(content) = event.strip_prefix("data: ") {
966                let content = content.trim();
967                if content.is_empty() {
968                    return Ok(None);
969                }
970                Ok(Some(content.to_string()))
971            } else {
972                Ok(None)
973            }
974        };
975
976        let mut stream = create_sse_stream(mock_response, parser);
977
978        let mut results = Vec::new();
979        while let Some(result) = stream.next().await {
980            results.push(result);
981        }
982
983        assert_eq!(results.len(), 1);
984        assert_eq!(
985            results[0].as_ref().unwrap(),
986            &format!("Star {}", multibyte_char)
987        );
988    }
989
990    fn create_mock_response(chunks: Vec<Result<Bytes, reqwest::Error>>) -> reqwest::Response {
991        use http_body_util::StreamBody;
992        use reqwest::Body;
993
994        let frame_stream = futures::stream::iter(
995            chunks
996                .into_iter()
997                .map(|chunk| chunk.map(hyper::body::Frame::data)),
998        );
999
1000        let body = StreamBody::new(frame_stream);
1001        let body = Body::wrap(body);
1002
1003        let http_response = http::Response::builder().status(200).body(body).unwrap();
1004
1005        http_response.into()
1006    }
1007}