// autoagents_llm/chat/mod.rs

1use std::collections::HashMap;
2use std::fmt;
3
4use async_trait::async_trait;
5use futures::stream::{Stream, StreamExt};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9use crate::{error::LLMError, ToolCall};
10
11/// Role of a participant in a chat conversation.
12#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
13pub enum ChatRole {
14    /// The user/human participant in the conversation
15    User,
16    /// The AI assistant participant in the conversation
17    Assistant,
18}
19
20/// The supported MIME type of an image.
21#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
22#[non_exhaustive]
23pub enum ImageMime {
24    /// JPEG image
25    JPEG,
26    /// PNG image
27    PNG,
28    /// GIF image
29    GIF,
30    /// WebP image
31    WEBP,
32}
33
34impl ImageMime {
35    pub fn mime_type(&self) -> &'static str {
36        match self {
37            ImageMime::JPEG => "image/jpeg",
38            ImageMime::PNG => "image/png",
39            ImageMime::GIF => "image/gif",
40            ImageMime::WEBP => "image/webp",
41        }
42    }
43}
44
45/// The type of a message in a chat conversation.
46#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)]
47pub enum MessageType {
48    /// A text message
49    #[default]
50    Text,
51    /// An image message
52    Image((ImageMime, Vec<u8>)),
53    /// PDF message
54    Pdf(Vec<u8>),
55    /// An image URL message
56    ImageURL(String),
57    /// A tool use
58    ToolUse(Vec<ToolCall>),
59    /// Tool result
60    ToolResult(Vec<ToolCall>),
61}
62
/// How much reasoning effort the model is asked to spend on a response.
///
/// A plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    /// Low reasoning effort
    Low,
    /// Medium reasoning effort
    Medium,
    /// High reasoning effort
    High,
}

73/// A single message in a chat conversation.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ChatMessage {
76    /// The role of who sent this message (user or assistant)
77    pub role: ChatRole,
78    /// The type of the message (text, image, audio, video, etc)
79    pub message_type: MessageType,
80    /// The text content of the message
81    pub content: String,
82}
83
84/// Represents a parameter in a function tool
85#[derive(Debug, Clone, Serialize)]
86pub struct ParameterProperty {
87    /// The type of the parameter (e.g. "string", "number", "array", etc)
88    #[serde(rename = "type")]
89    pub property_type: String,
90    /// Description of what the parameter does
91    pub description: String,
92    /// When type is "array", this defines the type of the array items
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub items: Option<Box<ParameterProperty>>,
95    /// When type is "enum", this defines the possible values for the parameter
96    #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
97    pub enum_list: Option<Vec<String>>,
98}
99
100/// Represents the parameters schema for a function tool
101#[derive(Debug, Clone, Serialize)]
102pub struct ParametersSchema {
103    /// The type of the parameters object (usually "object")
104    #[serde(rename = "type")]
105    pub schema_type: String,
106    /// Map of parameter names to their properties
107    pub properties: HashMap<String, ParameterProperty>,
108    /// List of required parameter names
109    pub required: Vec<String>,
110}
111
112/// Represents a function definition for a tool.
113///
114/// The `parameters` field stores the JSON Schema describing the function
115/// arguments.  It is kept as a raw `serde_json::Value` to allow arbitrary
116/// complexity (nested arrays/objects, `oneOf`, etc.) without requiring a
117/// bespoke Rust structure.
118///
119/// Builder helpers can still generate simple schemas automatically, but the
120/// user may also provide any valid schema directly.
121#[derive(Debug, Clone, Serialize)]
122pub struct FunctionTool {
123    /// Name of the function
124    pub name: String,
125    /// Human-readable description
126    pub description: String,
127    /// JSON Schema describing the parameters
128    pub parameters: Value,
129}
130
131/// Defines rules for structured output responses based on [OpenAI's structured output requirements](https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format).
132/// Individual providers may have additional requirements or restrictions, but these should be handled by each provider's backend implementation.
133///
134/// If you plan on deserializing into this struct, make sure the source text has a `"name"` field, since that's technically the only thing required by OpenAI.
135///
136/// ## Example
137///
138/// ```
139/// use llm::chat::StructuredOutputFormat;
140/// use serde_json::json;
141///
142/// let response_format = r#"
143///     {
144///         "name": "Student",
145///         "description": "A student object",
146///         "schema": {
147///             "type": "object",
148///             "properties": {
149///                 "name": {
150///                     "type": "string"
151///                 },
152///                 "age": {
153///                     "type": "integer"
154///                 },
155///                 "is_student": {
156///                     "type": "boolean"
157///                 }
158///             },
159///             "required": ["name", "age", "is_student"]
160///         }
161///     }
162/// "#;
163/// let structured_output: StructuredOutputFormat = serde_json::from_str(response_format).unwrap();
164/// assert_eq!(structured_output.name, "Student");
165/// assert_eq!(structured_output.description, Some("A student object".to_string()));
166/// ```
167#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
168
169pub struct StructuredOutputFormat {
170    /// Name of the schema
171    pub name: String,
172    /// The description of the schema
173    pub description: Option<String>,
174    /// The JSON schema for the structured output
175    pub schema: Option<Value>,
176    /// Whether to enable strict schema adherence
177    pub strict: Option<bool>,
178}
179
180/// Represents a tool that can be used in chat
181#[derive(Debug, Clone, Serialize)]
182pub struct Tool {
183    /// The type of tool (e.g. "function")
184    #[serde(rename = "type")]
185    pub tool_type: String,
186    /// The function definition if this is a function tool
187    pub function: FunctionTool,
188}
189
/// Controls whether, and which, tools the model may call.
///
/// The behavior is standardized across different LLM providers.
/// [`ToolChoice::Auto`] is the default.
#[derive(Debug, Default, Clone)]
pub enum ToolChoice {
    /// The model may pick any tool but must call at least one.
    /// Useful for forcing tool use.
    Any,

    /// The model may call any tool or none at all (the default).
    #[default]
    Auto,

    /// The model must call exactly the tool whose name is given.
    /// Useful for forcing a call to one specific function.
    Tool(String),

    /// Tool use is disabled, even if tools are provided.
    None,
}

213impl Serialize for ToolChoice {
214    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
215    where
216        S: serde::Serializer,
217    {
218        match self {
219            ToolChoice::Any => serializer.serialize_str("required"),
220            ToolChoice::Auto => serializer.serialize_str("auto"),
221            ToolChoice::None => serializer.serialize_str("none"),
222            ToolChoice::Tool(name) => {
223                use serde::ser::SerializeMap;
224
225                // For tool_choice: {"type": "function", "function": {"name": "function_name"}}
226                let mut map = serializer.serialize_map(Some(2))?;
227                map.serialize_entry("type", "function")?;
228
229                // Inner function object
230                let mut function_obj = std::collections::HashMap::new();
231                function_obj.insert("name", name.as_str());
232
233                map.serialize_entry("function", &function_obj)?;
234                map.end()
235            }
236        }
237    }
238}
239
240pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
241    fn text(&self) -> Option<String>;
242    fn tool_calls(&self) -> Option<Vec<ToolCall>>;
243    fn thinking(&self) -> Option<String> {
244        None
245    }
246}
247
248/// Trait for providers that support chat-style interactions.
249#[async_trait]
250pub trait ChatProvider: Sync + Send {
251    /// Sends a chat request to the provider with a sequence of messages.
252    ///
253    /// # Arguments
254    ///
255    /// * `messages` - The conversation history as a slice of chat messages
256    ///
257    /// # Returns
258    ///
259    /// The provider's response text or an error
260    async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
261        self.chat_with_tools(messages, None).await
262    }
263
264    /// Sends a chat request to the provider with a sequence of messages and tools.
265    ///
266    /// # Arguments
267    ///
268    /// * `messages` - The conversation history as a slice of chat messages
269    /// * `tools` - Optional slice of tools to use in the chat
270    ///
271    /// # Returns
272    ///
273    /// The provider's response text or an error
274    async fn chat_with_tools(
275        &self,
276        messages: &[ChatMessage],
277        tools: Option<&[Tool]>,
278    ) -> Result<Box<dyn ChatResponse>, LLMError>;
279
280    /// Sends a streaming chat request to the provider with a sequence of messages.
281    ///
282    /// # Arguments
283    ///
284    /// * `messages` - The conversation history as a slice of chat messages
285    ///
286    /// # Returns
287    ///
288    /// A stream of text tokens or an error
289    async fn chat_stream(
290        &self,
291        _messages: &[ChatMessage],
292    ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
293    {
294        Err(LLMError::Generic(
295            "Streaming not supported for this provider".to_string(),
296        ))
297    }
298
299    /// Get current memory contents if provider supports memory
300    async fn memory_contents(&self) -> Option<Vec<ChatMessage>> {
301        None
302    }
303
304    /// Summarizes a conversation history into a concise 2-3 sentence summary
305    ///
306    /// # Arguments
307    /// * `msgs` - The conversation messages to summarize
308    ///
309    /// # Returns
310    /// A string containing the summary or an error if summarization fails
311    async fn summarize_history(&self, msgs: &[ChatMessage]) -> Result<String, LLMError> {
312        let prompt = format!(
313            "Summarize in 2-3 sentences:\n{}",
314            msgs.iter()
315                .map(|m| format!("{:?}: {}", m.role, m.content))
316                .collect::<Vec<_>>()
317                .join("\n"),
318        );
319        let req = [ChatMessage::user().content(prompt).build()];
320        self.chat(&req)
321            .await?
322            .text()
323            .ok_or(LLMError::Generic("no text in summary response".into()))
324    }
325}
326
327impl fmt::Display for ReasoningEffort {
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        match self {
330            ReasoningEffort::Low => write!(f, "low"),
331            ReasoningEffort::Medium => write!(f, "medium"),
332            ReasoningEffort::High => write!(f, "high"),
333        }
334    }
335}
336
337impl ChatMessage {
338    /// Create a new builder for a user message
339    pub fn user() -> ChatMessageBuilder {
340        ChatMessageBuilder::new(ChatRole::User)
341    }
342
343    /// Create a new builder for an assistant message
344    pub fn assistant() -> ChatMessageBuilder {
345        ChatMessageBuilder::new(ChatRole::Assistant)
346    }
347}
348
349/// Builder for ChatMessage
350#[derive(Debug)]
351pub struct ChatMessageBuilder {
352    role: ChatRole,
353    message_type: MessageType,
354    content: String,
355}
356
357impl ChatMessageBuilder {
358    /// Create a new ChatMessageBuilder with specified role
359    pub fn new(role: ChatRole) -> Self {
360        Self {
361            role,
362            message_type: MessageType::default(),
363            content: String::new(),
364        }
365    }
366
367    /// Set the message content
368    pub fn content<S: Into<String>>(mut self, content: S) -> Self {
369        self.content = content.into();
370        self
371    }
372
373    /// Set the message type as Image
374    pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
375        self.message_type = MessageType::Image((image_mime, raw_bytes));
376        self
377    }
378
379    /// Set the message type as Image
380    pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
381        self.message_type = MessageType::Pdf(raw_bytes);
382        self
383    }
384
385    /// Set the message type as ImageURL
386    pub fn image_url(mut self, url: impl Into<String>) -> Self {
387        self.message_type = MessageType::ImageURL(url.into());
388        self
389    }
390
391    /// Set the message type as ToolUse
392    pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
393        self.message_type = MessageType::ToolUse(tools);
394        self
395    }
396
397    /// Set the message type as ToolResult
398    pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
399        self.message_type = MessageType::ToolResult(tools);
400        self
401    }
402
403    /// Build the ChatMessage
404    pub fn build(self) -> ChatMessage {
405        ChatMessage {
406            role: self.role,
407            message_type: self.message_type,
408            content: self.content,
409        }
410    }
411}
412
413/// Creates a Server-Sent Events (SSE) stream from an HTTP response.
414///
415/// # Arguments
416///
417/// * `response` - The HTTP response from the streaming API
418/// * `parser` - Function to parse each SSE chunk into optional text content
419///
420/// # Returns
421///
422/// A pinned stream of text tokens or an error
423pub(crate) fn create_sse_stream<F>(
424    response: reqwest::Response,
425    parser: F,
426) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
427where
428    F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
429{
430    let stream = response
431        .bytes_stream()
432        .map(move |chunk| match chunk {
433            Ok(bytes) => {
434                let text = String::from_utf8_lossy(&bytes);
435                parser(&text)
436            }
437            Err(e) => Err(LLMError::HttpError(e.to_string())),
438        })
439        .filter_map(|result| async move {
440            match result {
441                Ok(Some(content)) => Some(Ok(content)),
442                Ok(None) => None,
443                Err(e) => Some(Err(e)),
444            }
445        });
446
447    Box::pin(stream)
448}