//! Chat module (`autoagents_llm/chat/mod.rs`): chat message types, tool
//! definitions, streaming chunk types, and the `ChatProvider` trait.
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;

use async_trait::async_trait;
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::{ToolCall, error::LLMError};

/// Token-usage metadata for a chat response.
///
/// The serde aliases accept both the `prompt_tokens`/`completion_tokens`
/// spelling and the `input_tokens`/`output_tokens` spelling used by some
/// providers, so this struct can deserialize either wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Number of tokens in the prompt
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    /// Number of tokens in the completion
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// Total number of tokens used
    pub total_tokens: u32,
    /// Breakdown of completion tokens, if available
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "output_tokens_details"
    )]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Breakdown of prompt tokens, if available
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "input_tokens_details"
    )]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
/// Stream response chunk that mimics OpenAI's streaming response format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamResponse {
    /// Array of choices in the response
    pub choices: Vec<StreamChoice>,
    /// Usage metadata, typically present in the final chunk
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}

/// Individual choice in a streaming response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// Delta containing the incremental content
    pub delta: StreamDelta,
}

/// Delta (incremental) content in a streaming response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamDelta {
    /// The incremental text content, if any
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// The incremental tool calls, if any
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
/// A streaming chunk that can be either text or a tool call event.
///
/// This enum provides a unified representation of streaming events
/// when using `chat_stream_with_tools`. It allows callers to receive
/// text deltas as they arrive while also handling tool use blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
    /// Text content delta
    Text(String),

    /// Tool use block started (contains tool id and name)
    ToolUseStart {
        /// The index of this content block in the response
        index: usize,
        /// The unique ID for this tool use
        id: String,
        /// The name of the tool being called
        name: String,
    },

    /// Tool use input JSON delta (partial JSON string)
    ToolUseInputDelta {
        /// The index of this content block
        index: usize,
        /// Partial JSON string for the tool input
        partial_json: String,
    },

    /// Tool use block complete with assembled ToolCall
    ToolUseComplete {
        /// The index of this content block
        index: usize,
        /// The complete tool call with id, name, and parsed arguments
        tool_call: ToolCall,
    },

    /// Stream ended with stop reason
    Done {
        /// The reason the stream stopped (e.g., "end_turn", "tool_use")
        stop_reason: String,
    },

    /// Usage metadata for the stream, when the provider reports it
    Usage(Usage),
}
/// Breakdown of completion tokens.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
    /// Tokens used for reasoning (for reasoning models)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Tokens used for audio output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}

/// Breakdown of prompt tokens.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptTokensDetails {
    /// Tokens used for cached content
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
    /// Tokens used for audio input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
/// Role of a participant in a chat conversation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChatRole {
    /// The system prompt
    System,
    /// The user/human participant in the conversation
    User,
    /// The AI assistant participant in the conversation
    Assistant,
    /// Tool/function response
    Tool,
}
144impl fmt::Display for ChatRole {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        let value = match self {
147            ChatRole::System => "system",
148            ChatRole::User => "user",
149            ChatRole::Assistant => "assistant",
150            ChatRole::Tool => "tool",
151        };
152        f.write_str(value)
153    }
154}
155
/// The supported MIME type of an image.
///
/// Use [`ImageMime::mime_type`] to obtain the corresponding media-type string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ImageMime {
    /// JPEG image
    JPEG,
    /// PNG image
    PNG,
    /// GIF image
    GIF,
    /// WebP image
    WEBP,
}
170impl ImageMime {
171    pub fn mime_type(&self) -> &'static str {
172        match self {
173            ImageMime::JPEG => "image/jpeg",
174            ImageMime::PNG => "image/png",
175            ImageMime::GIF => "image/gif",
176            ImageMime::WEBP => "image/webp",
177        }
178    }
179}
180
/// The type of a message in a chat conversation.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MessageType {
    /// A plain text message (the default)
    #[default]
    Text,
    /// An image message: MIME type plus the raw image bytes
    Image((ImageMime, Vec<u8>)),
    /// A PDF message carrying the raw document bytes
    Pdf(Vec<u8>),
    /// An image referenced by URL rather than inline bytes
    ImageURL(String),
    /// Tool-use request(s) made by the assistant
    ToolUse(Vec<ToolCall>),
    /// Result(s) of previously requested tool call(s)
    ToolResult(Vec<ToolCall>),
}
/// The type of reasoning effort for a message in a chat conversation.
///
/// Rendered as the lowercase strings "low"/"medium"/"high" by its
/// `Display` implementation.
// Derives added for consistency with the other enums in this module
// (they all derive at least Debug/Clone/PartialEq); Copy is free for a
// fieldless enum and lets callers pass it by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReasoningEffort {
    /// Low reasoning effort
    Low,
    /// Medium reasoning effort
    Medium,
    /// High reasoning effort
    High,
}
/// A single message in a chat conversation.
///
/// Construct via the [`ChatMessage::user`] / [`ChatMessage::assistant`]
/// builder entry points, or by filling the fields directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// The role of who sent this message
    pub role: ChatRole,
    /// The type of the message (text, image, PDF, tool use/result, etc)
    pub message_type: MessageType,
    /// The text content of the message
    pub content: String,
}
/// Represents a parameter in a function tool (a property in a JSON Schema).
#[derive(Debug, Clone, Serialize)]
pub struct ParameterProperty {
    /// The type of the parameter (e.g. "string", "number", "array", etc)
    #[serde(rename = "type")]
    pub property_type: String,
    /// Description of what the parameter does
    pub description: String,
    /// When type is "array", this defines the type of the array items
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<ParameterProperty>>,
    /// The allowed values for this parameter, serialized as the JSON Schema
    /// "enum" keyword when present
    #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
    pub enum_list: Option<Vec<String>>,
}
/// Represents the parameters schema for a function tool.
#[derive(Debug, Clone, Serialize)]
pub struct ParametersSchema {
    /// The type of the parameters object (usually "object")
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Map of parameter names to their properties
    pub properties: HashMap<String, ParameterProperty>,
    /// List of required parameter names
    pub required: Vec<String>,
}
/// Represents a function definition for a tool.
///
/// The `parameters` field stores the JSON Schema describing the function
/// arguments. It is kept as a raw `serde_json::Value` to allow arbitrary
/// complexity (nested arrays/objects, `oneOf`, etc.) without requiring a
/// bespoke Rust structure.
///
/// Builder helpers can still generate simple schemas automatically, but the
/// user may also provide any valid schema directly.
#[derive(Debug, Clone, Serialize)]
pub struct FunctionTool {
    /// Name of the function
    pub name: String,
    /// Human-readable description
    pub description: String,
    /// JSON Schema describing the parameters
    pub parameters: Value,
}
267/// Defines rules for structured output responses based on [OpenAI's structured output requirements](https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format).
268/// Individual providers may have additional requirements or restrictions, but these should be handled by each provider's backend implementation.
269///
270/// If you plan on deserializing into this struct, make sure the source text has a `"name"` field, since that's technically the only thing required by OpenAI.
271///
272/// ## Example
273///
274/// ```
275/// use autoagents_llm::chat::StructuredOutputFormat;
276/// use serde_json::json;
277///
278/// let response_format = r#"
279///     {
280///         "name": "Student",
281///         "description": "A student object",
282///         "schema": {
283///             "type": "object",
284///             "properties": {
285///                 "name": {
286///                     "type": "string"
287///                 },
288///                 "age": {
289///                     "type": "integer"
290///                 },
291///                 "is_student": {
292///                     "type": "boolean"
293///                 }
294///             },
295///             "required": ["name", "age", "is_student"]
296///         }
297///     }
298/// "#;
299/// let structured_output: StructuredOutputFormat = serde_json::from_str(response_format).unwrap();
300/// assert_eq!(structured_output.name, "Student");
301/// assert_eq!(structured_output.description, Some("A student object".to_string()));
302/// ```
303#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
304
305pub struct StructuredOutputFormat {
306    /// Name of the schema
307    pub name: String,
308    /// The description of the schema
309    pub description: Option<String>,
310    /// The JSON schema for the structured output
311    pub schema: Option<Value>,
312    /// Whether to enable strict schema adherence
313    pub strict: Option<bool>,
314}
315
/// Represents a tool that can be used in chat.
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// The type of tool (e.g. "function")
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function definition if this is a function tool
    pub function: FunctionTool,
}
/// Tool choice determines how the LLM uses available tools.
/// The behavior is standardized across different LLM providers.
///
/// Serialized by a custom `Serialize` impl to the OpenAI-style wire values:
/// `"required"`, `"auto"`, `"none"`, or a `{"type": "function", ...}` object.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Model can use any tool, but it must use at least one.
    /// This is useful when you want to force the model to use tools.
    Any,

    /// Model can use any tool, and may elect to use none.
    /// This is the default behavior and gives the model flexibility.
    #[default]
    Auto,

    /// Model must use the specified tool and only the specified tool.
    /// The string parameter is the name of the required tool.
    /// This is useful when you want the model to call a specific function.
    Tool(String),

    /// Explicitly disables the use of tools.
    /// The model will not use any tools even if they are provided.
    None,
}
349impl Serialize for ToolChoice {
350    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
351    where
352        S: serde::Serializer,
353    {
354        match self {
355            ToolChoice::Any => serializer.serialize_str("required"),
356            ToolChoice::Auto => serializer.serialize_str("auto"),
357            ToolChoice::None => serializer.serialize_str("none"),
358            ToolChoice::Tool(name) => {
359                use serde::ser::SerializeMap;
360
361                // For tool_choice: {"type": "function", "function": {"name": "function_name"}}
362                let mut map = serializer.serialize_map(Some(2))?;
363                map.serialize_entry("type", "function")?;
364
365                // Inner function object
366                let mut function_obj = std::collections::HashMap::new();
367                function_obj.insert("name", name.as_str());
368
369                map.serialize_entry("function", &function_obj)?;
370                map.end()
371            }
372        }
373    }
374}
375
/// A provider-agnostic chat response.
///
/// Implemented per provider to expose the parts of a response callers need
/// without tying them to a provider-specific wire format.
pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
    /// The text content of the response, if any.
    fn text(&self) -> Option<String>;
    /// Tool calls requested by the model, if any.
    fn tool_calls(&self) -> Option<Vec<ToolCall>>;
    /// Reasoning/"thinking" text, if the provider exposes it (defaults to `None`).
    fn thinking(&self) -> Option<String> {
        None
    }
    /// Token usage metadata, if the provider reports it (defaults to `None`).
    fn usage(&self) -> Option<Usage> {
        None
    }
}
/// Trait for providers that support chat-style interactions.
#[async_trait]
pub trait ChatProvider: Sync + Send {
    /// Sends a chat request to the provider with a sequence of messages.
    ///
    /// The default implementation delegates to
    /// [`ChatProvider::chat_with_tools`] with no tools.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    /// * `json_schema` - Optional json_schema for the response format
    ///
    /// # Returns
    ///
    /// The provider's response text or an error
    async fn chat(
        &self,
        messages: &[ChatMessage],
        json_schema: Option<StructuredOutputFormat>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        self.chat_with_tools(messages, None, json_schema).await
    }

    /// Sends a chat request to the provider with a sequence of messages and tools.
    ///
    /// This is the only required method; the other chat methods have defaults
    /// built on top of it or return "not supported" errors.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    /// * `tools` - Optional slice of tools to use in the chat
    /// * `json_schema` - Optional json_schema for the response format
    ///
    /// # Returns
    ///
    /// The provider's response text or an error
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: Option<&[Tool]>,
        json_schema: Option<StructuredOutputFormat>,
    ) -> Result<Box<dyn ChatResponse>, LLMError>;

    /// Sends a chat with web search request to the provider.
    ///
    /// The default implementation returns an error; providers that support
    /// web search override this.
    ///
    /// # Arguments
    ///
    /// * `input` - The input message
    ///
    /// # Returns
    ///
    /// The provider's response text or an error
    async fn chat_with_web_search(
        &self,
        _input: String,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        Err(LLMError::Generic(
            "Web search not supported for this provider".to_string(),
        ))
    }

    /// Sends a streaming chat request to the provider with a sequence of messages.
    ///
    /// The default implementation returns an error; providers that support
    /// streaming override this.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    /// * `json_schema` - Optional json_schema for the response format
    ///
    /// # Returns
    ///
    /// A stream of text tokens or an error
    async fn chat_stream(
        &self,
        _messages: &[ChatMessage],
        _json_schema: Option<StructuredOutputFormat>,
    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
    {
        Err(LLMError::Generic(
            "Streaming not supported for this provider".to_string(),
        ))
    }

    /// Sends a streaming chat request that returns structured response chunks.
    ///
    /// ⚠️ Usage metadata during streaming has been observed to be unstable
    /// depending on the provider (it can be missing).
    ///
    /// This method returns a stream of `StreamResponse` objects that mimic OpenAI's
    /// streaming response format with `.choices[0].delta.content` and `.usage`.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    /// * `tools` - Optional slice of tools to use in the chat
    /// * `json_schema` - Optional json_schema for the response format
    ///
    /// # Returns
    ///
    /// A stream of `StreamResponse` objects or an error
    async fn chat_stream_struct(
        &self,
        _messages: &[ChatMessage],
        _tools: Option<&[Tool]>,
        _json_schema: Option<StructuredOutputFormat>,
    ) -> Result<
        std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
        LLMError,
    > {
        Err(LLMError::Generic(
            "Structured streaming not supported for this provider".to_string(),
        ))
    }

    /// Sends a streaming chat request with tool support.
    ///
    /// Returns a stream of `StreamChunk` which can be text deltas or tool call events.
    /// When `stop_reason` is "tool_use", the caller should execute the tool(s)
    /// and continue the conversation.
    ///
    /// This method is ideal for agentic workflows where you want to stream text
    /// output to the user while still receiving tool call requests.
    ///
    /// # Arguments
    ///
    /// * `messages` - The conversation history as a slice of chat messages
    /// * `tools` - Optional slice of tools available for the model to use
    /// * `json_schema` - Optional json_schema for the response format
    ///
    /// # Returns
    ///
    /// A stream of `StreamChunk` items or an error
    ///
    /// # Example
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = client
    ///     .chat_stream_with_tools(&messages, Some(&tools), None)
    ///     .await?;
    ///
    /// let mut tool_calls = Vec::new();
    /// while let Some(chunk) = stream.next().await {
    ///     match chunk? {
    ///         StreamChunk::Text(text) => print!("{}", text),
    ///         StreamChunk::ToolUseComplete { tool_call, .. } => {
    ///             tool_calls.push(tool_call);
    ///         }
    ///         StreamChunk::Done { stop_reason } => {
    ///             if stop_reason == "tool_use" {
    ///                 // Execute tool_calls and continue conversation
    ///             }
    ///         }
    ///         _ => {}
    ///     }
    /// }
    /// ```
    async fn chat_stream_with_tools(
        &self,
        _messages: &[ChatMessage],
        _tools: Option<&[Tool]>,
        _json_schema: Option<StructuredOutputFormat>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, LLMError> {
        Err(LLMError::Generic(
            "Streaming with tools not supported for this provider".to_string(),
        ))
    }
}
552impl fmt::Display for ReasoningEffort {
553    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
554        match self {
555            ReasoningEffort::Low => write!(f, "low"),
556            ReasoningEffort::Medium => write!(f, "medium"),
557            ReasoningEffort::High => write!(f, "high"),
558        }
559    }
560}
561
562impl ChatMessage {
563    /// Create a new builder for a user message
564    pub fn user() -> ChatMessageBuilder {
565        ChatMessageBuilder::new(ChatRole::User)
566    }
567
568    /// Create a new builder for an assistant message
569    pub fn assistant() -> ChatMessageBuilder {
570        ChatMessageBuilder::new(ChatRole::Assistant)
571    }
572}
573
/// Builder for [`ChatMessage`].
///
/// Obtain one via [`ChatMessage::user`], [`ChatMessage::assistant`], or
/// [`ChatMessageBuilder::new`]; finish with [`ChatMessageBuilder::build`].
#[derive(Debug)]
pub struct ChatMessageBuilder {
    // Role the built message will carry.
    role: ChatRole,
    // Message payload kind; defaults to `MessageType::Text`.
    message_type: MessageType,
    // Text content; defaults to the empty string.
    content: String,
}
impl ChatMessageBuilder {
    /// Create a new ChatMessageBuilder with specified role
    pub fn new(role: ChatRole) -> Self {
        Self {
            role,
            message_type: MessageType::default(),
            content: String::new(),
        }
    }

    /// Set the message content
    pub fn content<S: Into<String>>(mut self, content: S) -> Self {
        self.content = content.into();
        self
    }

    /// Set the message type as Image
    pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
        self.message_type = MessageType::Image((image_mime, raw_bytes));
        self
    }

    /// Set the message type as Pdf
    pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
        self.message_type = MessageType::Pdf(raw_bytes);
        self
    }

    /// Set the message type as ImageURL
    pub fn image_url(mut self, url: impl Into<String>) -> Self {
        self.message_type = MessageType::ImageURL(url.into());
        self
    }

    /// Set the message type as ToolUse
    pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
        self.message_type = MessageType::ToolUse(tools);
        self
    }

    /// Set the message type as ToolResult
    pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
        self.message_type = MessageType::ToolResult(tools);
        self
    }

    /// Build the ChatMessage
    pub fn build(self) -> ChatMessage {
        ChatMessage {
            role: self.role,
            message_type: self.message_type,
            content: self.content,
        }
    }
}
/// Creates a Server-Sent Events (SSE) stream from an HTTP response.
///
/// Incoming bytes are buffered in two stages: `utf8_buffer` accumulates raw
/// bytes until they form valid UTF-8 (a multi-byte character may be split
/// across network chunks), and `buffer` accumulates decoded text until a
/// complete SSE event — terminated by a blank line (`"\n\n"`) — is available.
///
/// # Arguments
///
/// * `response` - The HTTP response from the streaming API
/// * `parser` - Function to parse each SSE chunk into optional text content;
///   returning `Ok(None)` silently drops that event from the output
///
/// # Returns
///
/// A pinned stream of text tokens or an error
///
/// NOTE(review): any text still sitting in `buffer` when the byte stream ends
/// (i.e. a final event missing its terminating blank line) is silently
/// dropped — confirm all target providers terminate their last event.
#[allow(dead_code)]
pub(crate) fn create_sse_stream<F>(
    response: reqwest::Response,
    parser: F,
) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
where
    F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
{
    let stream = response
        .bytes_stream()
        .scan(
            // State: (decoded-text buffer, pending-raw-bytes buffer).
            (String::new(), Vec::new()),
            move |(buffer, utf8_buffer), chunk| {
                let result = match chunk {
                    Ok(bytes) => {
                        utf8_buffer.extend_from_slice(&bytes);

                        // Try to decode everything accumulated so far.
                        match String::from_utf8(utf8_buffer.clone()) {
                            Ok(text) => {
                                buffer.push_str(&text);
                                utf8_buffer.clear();
                            }
                            Err(e) => {
                                // Flush only the valid prefix; keep the
                                // incomplete trailing bytes for the next chunk.
                                let valid_up_to = e.utf8_error().valid_up_to();
                                if valid_up_to > 0 {
                                    // Safe to use from_utf8_lossy here since valid_up_to points to
                                    // a valid UTF-8 boundary - no replacement characters will be introduced
                                    let valid =
                                        String::from_utf8_lossy(&utf8_buffer[..valid_up_to]);
                                    buffer.push_str(&valid);
                                    utf8_buffer.drain(..valid_up_to);
                                }
                            }
                        }

                        let mut results = Vec::new();

                        // Emit one parsed item per complete "\n\n"-terminated event;
                        // a single network chunk may contain several events.
                        while let Some(pos) = buffer.find("\n\n") {
                            let event = buffer[..pos + 2].to_string();
                            buffer.drain(..pos + 2);

                            match parser(&event) {
                                Ok(Some(content)) => results.push(Ok(content)),
                                Ok(None) => {}
                                Err(e) => results.push(Err(e)),
                            }
                        }

                        Some(results)
                    }
                    Err(e) => Some(vec![Err(LLMError::HttpError(e.to_string()))]),
                };

                async move { result }
            },
        )
        // Each scan step yields a Vec of items; flatten into a single stream.
        .flat_map(futures::stream::iter);

    Box::pin(stream)
}
709#[cfg(not(target_arch = "wasm32"))]
710pub mod utils {
711    use crate::error::LLMError;
712    use reqwest::Response;
713    pub async fn check_response_status(response: Response) -> Result<Response, LLMError> {
714        if !response.status().is_success() {
715            let status = response.status();
716            let error_text = response.text().await?;
717            return Err(LLMError::ResponseFormatError {
718                message: format!("API returned error status: {status}"),
719                raw_response: error_text,
720            });
721        }
722        Ok(response)
723    }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream::StreamExt;

    /// SSE parser shared by the tests below: extracts the payload of a
    /// `data: ...` event and ignores empty or non-data events.
    ///
    /// Previously this closure was duplicated verbatim in every test.
    fn data_event_parser(event: &str) -> Result<Option<String>, LLMError> {
        if let Some(content) = event.strip_prefix("data: ") {
            let content = content.trim();
            if content.is_empty() {
                return Ok(None);
            }
            Ok(Some(content.to_string()))
        } else {
            Ok(None)
        }
    }

    /// Builds an SSE stream over `chunks` and drains it into a Vec.
    async fn collect_sse(
        chunks: Vec<Result<Bytes, reqwest::Error>>,
    ) -> Vec<Result<String, LLMError>> {
        let mut stream = create_sse_stream(create_mock_response(chunks), data_event_parser);
        let mut results = Vec::new();
        while let Some(result) = stream.next().await {
            results.push(result);
        }
        results
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_utf8() {
        // A single event whose bytes are split mid-way through the payload.
        let test_data = "data: Positive reactions\n\n".as_bytes();

        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(&test_data[..10])),
            Ok(Bytes::from(&test_data[10..])),
        ];

        let results = collect_sse(chunks).await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap(), "Positive reactions");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_sse_events() {
        // Two back-to-back events, with the chunk boundary inside the second.
        let event1 = "data: First event\n\n";
        let event2 = "data: Second event\n\n";
        let combined = format!("{}{}", event1, event2);
        let test_data = combined.as_bytes().to_vec();

        let split_point = event1.len() + 5;
        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_point].to_vec())),
            Ok(Bytes::from(test_data[split_point..].to_vec())),
        ];

        let results = collect_sse(chunks).await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), "First event");
        assert_eq!(results[1].as_ref().unwrap(), "Second event");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_multibyte_utf8_split() {
        // Split the byte stream inside a multi-byte UTF-8 character to
        // exercise the partial-UTF-8 buffering path.
        let multibyte_char = "✨";
        let event = format!("data: Star {}\n\n", multibyte_char);
        let test_data = event.as_bytes().to_vec();

        let emoji_start = event.find(multibyte_char).unwrap();
        let split_in_emoji = emoji_start + 1;

        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_in_emoji].to_vec())),
            Ok(Bytes::from(test_data[split_in_emoji..].to_vec())),
        ];

        let results = collect_sse(chunks).await;

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].as_ref().unwrap(),
            &format!("Star {}", multibyte_char)
        );
    }

    /// Wraps the given byte chunks in a streaming `reqwest::Response` body.
    fn create_mock_response(chunks: Vec<Result<Bytes, reqwest::Error>>) -> reqwest::Response {
        use http_body_util::StreamBody;
        use reqwest::Body;

        let frame_stream = futures::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| chunk.map(hyper::body::Frame::data)),
        );

        let body = StreamBody::new(frame_stream);
        let body = Body::wrap(body);

        let http_response = http::Response::builder().status(200).body(body).unwrap();

        http_response.into()
    }
}