autoagents_llm/chat/mod.rs
1use std::collections::HashMap;
2use std::fmt;
3
4use async_trait::async_trait;
5use futures::stream::{Stream, StreamExt};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9use crate::{error::LLMError, ToolCall};
10
11/// Role of a participant in a chat conversation.
12#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
13pub enum ChatRole {
14 /// The user/human participant in the conversation
15 User,
16 /// The AI assistant participant in the conversation
17 Assistant,
18}
19
20/// The supported MIME type of an image.
21#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
22#[non_exhaustive]
23pub enum ImageMime {
24 /// JPEG image
25 JPEG,
26 /// PNG image
27 PNG,
28 /// GIF image
29 GIF,
30 /// WebP image
31 WEBP,
32}
33
34impl ImageMime {
35 pub fn mime_type(&self) -> &'static str {
36 match self {
37 ImageMime::JPEG => "image/jpeg",
38 ImageMime::PNG => "image/png",
39 ImageMime::GIF => "image/gif",
40 ImageMime::WEBP => "image/webp",
41 }
42 }
43}
44
45/// The type of a message in a chat conversation.
46#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)]
47pub enum MessageType {
48 /// A text message
49 #[default]
50 Text,
51 /// An image message
52 Image((ImageMime, Vec<u8>)),
53 /// PDF message
54 Pdf(Vec<u8>),
55 /// An image URL message
56 ImageURL(String),
57 /// A tool use
58 ToolUse(Vec<ToolCall>),
59 /// Tool result
60 ToolResult(Vec<ToolCall>),
61}
62
/// The amount of reasoning effort requested for a chat interaction.
///
/// The type is a plain, field-less enum, so it is cheap to copy and can be
/// compared and debug-printed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReasoningEffort {
    /// Low reasoning effort
    Low,
    /// Medium reasoning effort
    Medium,
    /// High reasoning effort
    High,
}
72
73/// A single message in a chat conversation.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ChatMessage {
76 /// The role of who sent this message (user or assistant)
77 pub role: ChatRole,
78 /// The type of the message (text, image, audio, video, etc)
79 pub message_type: MessageType,
80 /// The text content of the message
81 pub content: String,
82}
83
84/// Represents a parameter in a function tool
85#[derive(Debug, Clone, Serialize)]
86pub struct ParameterProperty {
87 /// The type of the parameter (e.g. "string", "number", "array", etc)
88 #[serde(rename = "type")]
89 pub property_type: String,
90 /// Description of what the parameter does
91 pub description: String,
92 /// When type is "array", this defines the type of the array items
93 #[serde(skip_serializing_if = "Option::is_none")]
94 pub items: Option<Box<ParameterProperty>>,
95 /// When type is "enum", this defines the possible values for the parameter
96 #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
97 pub enum_list: Option<Vec<String>>,
98}
99
100/// Represents the parameters schema for a function tool
101#[derive(Debug, Clone, Serialize)]
102pub struct ParametersSchema {
103 /// The type of the parameters object (usually "object")
104 #[serde(rename = "type")]
105 pub schema_type: String,
106 /// Map of parameter names to their properties
107 pub properties: HashMap<String, ParameterProperty>,
108 /// List of required parameter names
109 pub required: Vec<String>,
110}
111
112/// Represents a function definition for a tool.
113///
114/// The `parameters` field stores the JSON Schema describing the function
115/// arguments. It is kept as a raw `serde_json::Value` to allow arbitrary
116/// complexity (nested arrays/objects, `oneOf`, etc.) without requiring a
117/// bespoke Rust structure.
118///
119/// Builder helpers can still generate simple schemas automatically, but the
120/// user may also provide any valid schema directly.
121#[derive(Debug, Clone, Serialize)]
122pub struct FunctionTool {
123 /// Name of the function
124 pub name: String,
125 /// Human-readable description
126 pub description: String,
127 /// JSON Schema describing the parameters
128 pub parameters: Value,
129}
130
131/// Defines rules for structured output responses based on [OpenAI's structured output requirements](https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format).
132/// Individual providers may have additional requirements or restrictions, but these should be handled by each provider's backend implementation.
133///
134/// If you plan on deserializing into this struct, make sure the source text has a `"name"` field, since that's technically the only thing required by OpenAI.
135///
136/// ## Example
137///
138/// ```
139/// use llm::chat::StructuredOutputFormat;
140/// use serde_json::json;
141///
142/// let response_format = r#"
143/// {
144/// "name": "Student",
145/// "description": "A student object",
146/// "schema": {
147/// "type": "object",
148/// "properties": {
149/// "name": {
150/// "type": "string"
151/// },
152/// "age": {
153/// "type": "integer"
154/// },
155/// "is_student": {
156/// "type": "boolean"
157/// }
158/// },
159/// "required": ["name", "age", "is_student"]
160/// }
161/// }
162/// "#;
163/// let structured_output: StructuredOutputFormat = serde_json::from_str(response_format).unwrap();
164/// assert_eq!(structured_output.name, "Student");
165/// assert_eq!(structured_output.description, Some("A student object".to_string()));
166/// ```
167#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
168
169pub struct StructuredOutputFormat {
170 /// Name of the schema
171 pub name: String,
172 /// The description of the schema
173 pub description: Option<String>,
174 /// The JSON schema for the structured output
175 pub schema: Option<Value>,
176 /// Whether to enable strict schema adherence
177 pub strict: Option<bool>,
178}
179
180/// Represents a tool that can be used in chat
181#[derive(Debug, Clone, Serialize)]
182pub struct Tool {
183 /// The type of tool (e.g. "function")
184 #[serde(rename = "type")]
185 pub tool_type: String,
186 /// The function definition if this is a function tool
187 pub function: FunctionTool,
188}
189
/// Controls how the LLM is allowed to use the tools offered to it.
/// The behavior is standardized across different LLM providers.
///
/// The default is [`ToolChoice::Auto`].
#[derive(Clone, Debug, Default)]
pub enum ToolChoice {
    /// The model may pick any tool, but it has to call at least one.
    /// Use this to force tool usage.
    Any,

    /// The model may pick any tool, or decide to call none at all.
    /// This is the default and leaves the decision to the model.
    #[default]
    Auto,

    /// The model must call exactly the named tool and no other.
    /// The string is the name of the required tool.
    /// Use this to make the model call one specific function.
    Tool(String),

    /// Tool usage is explicitly disabled.
    /// The model will not call any tool, even if tools are provided.
    None,
}
212
213impl Serialize for ToolChoice {
214 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
215 where
216 S: serde::Serializer,
217 {
218 match self {
219 ToolChoice::Any => serializer.serialize_str("required"),
220 ToolChoice::Auto => serializer.serialize_str("auto"),
221 ToolChoice::None => serializer.serialize_str("none"),
222 ToolChoice::Tool(name) => {
223 use serde::ser::SerializeMap;
224
225 // For tool_choice: {"type": "function", "function": {"name": "function_name"}}
226 let mut map = serializer.serialize_map(Some(2))?;
227 map.serialize_entry("type", "function")?;
228
229 // Inner function object
230 let mut function_obj = std::collections::HashMap::new();
231 function_obj.insert("name", name.as_str());
232
233 map.serialize_entry("function", &function_obj)?;
234 map.end()
235 }
236 }
237 }
238}
239
240pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
241 fn text(&self) -> Option<String>;
242 fn tool_calls(&self) -> Option<Vec<ToolCall>>;
243 fn thinking(&self) -> Option<String> {
244 None
245 }
246}
247
248/// Trait for providers that support chat-style interactions.
249#[async_trait]
250pub trait ChatProvider: Sync + Send {
251 /// Sends a chat request to the provider with a sequence of messages.
252 ///
253 /// # Arguments
254 ///
255 /// * `messages` - The conversation history as a slice of chat messages
256 ///
257 /// # Returns
258 ///
259 /// The provider's response text or an error
260 async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
261 self.chat_with_tools(messages, None).await
262 }
263
264 /// Sends a chat request to the provider with a sequence of messages and tools.
265 ///
266 /// # Arguments
267 ///
268 /// * `messages` - The conversation history as a slice of chat messages
269 /// * `tools` - Optional slice of tools to use in the chat
270 ///
271 /// # Returns
272 ///
273 /// The provider's response text or an error
274 async fn chat_with_tools(
275 &self,
276 messages: &[ChatMessage],
277 tools: Option<&[Tool]>,
278 ) -> Result<Box<dyn ChatResponse>, LLMError>;
279
280 /// Sends a streaming chat request to the provider with a sequence of messages.
281 ///
282 /// # Arguments
283 ///
284 /// * `messages` - The conversation history as a slice of chat messages
285 ///
286 /// # Returns
287 ///
288 /// A stream of text tokens or an error
289 async fn chat_stream(
290 &self,
291 _messages: &[ChatMessage],
292 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
293 {
294 Err(LLMError::Generic(
295 "Streaming not supported for this provider".to_string(),
296 ))
297 }
298
299 /// Get current memory contents if provider supports memory
300 async fn memory_contents(&self) -> Option<Vec<ChatMessage>> {
301 None
302 }
303
304 /// Summarizes a conversation history into a concise 2-3 sentence summary
305 ///
306 /// # Arguments
307 /// * `msgs` - The conversation messages to summarize
308 ///
309 /// # Returns
310 /// A string containing the summary or an error if summarization fails
311 async fn summarize_history(&self, msgs: &[ChatMessage]) -> Result<String, LLMError> {
312 let prompt = format!(
313 "Summarize in 2-3 sentences:\n{}",
314 msgs.iter()
315 .map(|m| format!("{:?}: {}", m.role, m.content))
316 .collect::<Vec<_>>()
317 .join("\n"),
318 );
319 let req = [ChatMessage::user().content(prompt).build()];
320 self.chat(&req)
321 .await?
322 .text()
323 .ok_or(LLMError::Generic("no text in summary response".into()))
324 }
325}
326
327impl fmt::Display for ReasoningEffort {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 match self {
330 ReasoningEffort::Low => write!(f, "low"),
331 ReasoningEffort::Medium => write!(f, "medium"),
332 ReasoningEffort::High => write!(f, "high"),
333 }
334 }
335}
336
337impl ChatMessage {
338 /// Create a new builder for a user message
339 pub fn user() -> ChatMessageBuilder {
340 ChatMessageBuilder::new(ChatRole::User)
341 }
342
343 /// Create a new builder for an assistant message
344 pub fn assistant() -> ChatMessageBuilder {
345 ChatMessageBuilder::new(ChatRole::Assistant)
346 }
347}
348
349/// Builder for ChatMessage
350#[derive(Debug)]
351pub struct ChatMessageBuilder {
352 role: ChatRole,
353 message_type: MessageType,
354 content: String,
355}
356
357impl ChatMessageBuilder {
358 /// Create a new ChatMessageBuilder with specified role
359 pub fn new(role: ChatRole) -> Self {
360 Self {
361 role,
362 message_type: MessageType::default(),
363 content: String::new(),
364 }
365 }
366
367 /// Set the message content
368 pub fn content<S: Into<String>>(mut self, content: S) -> Self {
369 self.content = content.into();
370 self
371 }
372
373 /// Set the message type as Image
374 pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
375 self.message_type = MessageType::Image((image_mime, raw_bytes));
376 self
377 }
378
379 /// Set the message type as Image
380 pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
381 self.message_type = MessageType::Pdf(raw_bytes);
382 self
383 }
384
385 /// Set the message type as ImageURL
386 pub fn image_url(mut self, url: impl Into<String>) -> Self {
387 self.message_type = MessageType::ImageURL(url.into());
388 self
389 }
390
391 /// Set the message type as ToolUse
392 pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
393 self.message_type = MessageType::ToolUse(tools);
394 self
395 }
396
397 /// Set the message type as ToolResult
398 pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
399 self.message_type = MessageType::ToolResult(tools);
400 self
401 }
402
403 /// Build the ChatMessage
404 pub fn build(self) -> ChatMessage {
405 ChatMessage {
406 role: self.role,
407 message_type: self.message_type,
408 content: self.content,
409 }
410 }
411}
412
413/// Creates a Server-Sent Events (SSE) stream from an HTTP response.
414///
415/// # Arguments
416///
417/// * `response` - The HTTP response from the streaming API
418/// * `parser` - Function to parse each SSE chunk into optional text content
419///
420/// # Returns
421///
422/// A pinned stream of text tokens or an error
423pub(crate) fn create_sse_stream<F>(
424 response: reqwest::Response,
425 parser: F,
426) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
427where
428 F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
429{
430 let stream = response
431 .bytes_stream()
432 .map(move |chunk| match chunk {
433 Ok(bytes) => {
434 let text = String::from_utf8_lossy(&bytes);
435 parser(&text)
436 }
437 Err(e) => Err(LLMError::HttpError(e.to_string())),
438 })
439 .filter_map(|result| async move {
440 match result {
441 Ok(Some(content)) => Some(Ok(content)),
442 Ok(None) => None,
443 Err(e) => Some(Err(e)),
444 }
445 });
446
447 Box::pin(stream)
448}