llm/chat/mod.rs
1use std::collections::HashMap;
2use std::fmt;
3
4use async_trait::async_trait;
5use futures::stream::{Stream, StreamExt};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9use crate::{error::LLMError, ToolCall};
10
/// Identifies which side of a conversation produced a message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChatRole {
    /// Message written by the user (human) participant.
    User,
    /// Message written by the AI assistant participant.
    Assistant,
}
19
/// Image formats that can be attached to a chat message.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ImageMime {
    /// JPEG image
    JPEG,
    /// PNG image
    PNG,
    /// GIF image
    GIF,
    /// WebP image
    WEBP,
}

impl ImageMime {
    /// Returns the MIME type string (for example `"image/png"`) for this format.
    pub fn mime_type(&self) -> &'static str {
        match self {
            Self::JPEG => "image/jpeg",
            Self::PNG => "image/png",
            Self::GIF => "image/gif",
            Self::WEBP => "image/webp",
        }
    }
}
44
45/// The type of a message in a chat conversation.
46#[derive(Debug, Clone, PartialEq, Eq, Default)]
47pub enum MessageType {
48 /// A text message
49 #[default]
50 Text,
51 /// An image message
52 Image((ImageMime, Vec<u8>)),
53 /// PDF message
54 Pdf(Vec<u8>),
55 /// An image URL message
56 ImageURL(String),
57 /// A tool use
58 ToolUse(Vec<ToolCall>),
59 /// Tool result
60 ToolResult(Vec<ToolCall>),
61}
62
/// The amount of reasoning effort requested for a chat interaction.
///
/// The `Display` implementation in this module renders each variant as its
/// lowercase name (`"low"`, `"medium"`, `"high"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReasoningEffort {
    /// Low reasoning effort
    Low,
    /// Medium reasoning effort
    Medium,
    /// High reasoning effort
    High,
}
72
73/// A single message in a chat conversation.
74#[derive(Debug, Clone)]
75pub struct ChatMessage {
76 /// The role of who sent this message (user or assistant)
77 pub role: ChatRole,
78 /// The type of the message (text, image, audio, video, etc)
79 pub message_type: MessageType,
80 /// The text content of the message
81 pub content: String,
82}
83
84/// Represents a parameter in a function tool
85#[derive(Debug, Clone, Serialize)]
86pub struct ParameterProperty {
87 /// The type of the parameter (e.g. "string", "number", "array", etc)
88 #[serde(rename = "type")]
89 pub property_type: String,
90 /// Description of what the parameter does
91 pub description: String,
92 /// When type is "array", this defines the type of the array items
93 #[serde(skip_serializing_if = "Option::is_none")]
94 pub items: Option<Box<ParameterProperty>>,
95 /// When type is "enum", this defines the possible values for the parameter
96 #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
97 pub enum_list: Option<Vec<String>>,
98}
99
100/// Represents the parameters schema for a function tool
101#[derive(Debug, Clone, Serialize)]
102pub struct ParametersSchema {
103 /// The type of the parameters object (usually "object")
104 #[serde(rename = "type")]
105 pub schema_type: String,
106 /// Map of parameter names to their properties
107 pub properties: HashMap<String, ParameterProperty>,
108 /// List of required parameter names
109 pub required: Vec<String>,
110}
111
112/// Represents a function definition for a tool.
113///
114/// The `parameters` field stores the JSON Schema describing the function
115/// arguments. It is kept as a raw `serde_json::Value` to allow arbitrary
116/// complexity (nested arrays/objects, `oneOf`, etc.) without requiring a
117/// bespoke Rust structure.
118///
119/// Builder helpers can still generate simple schemas automatically, but the
120/// user may also provide any valid schema directly.
121#[derive(Debug, Clone, Serialize)]
122pub struct FunctionTool {
123 /// Name of the function
124 pub name: String,
125 /// Human-readable description
126 pub description: String,
127 /// JSON Schema describing the parameters
128 pub parameters: Value,
129}
130
131/// Defines rules for structured output responses based on [OpenAI's structured output requirements](https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format).
132/// Individual providers may have additional requirements or restrictions, but these should be handled by each provider's backend implementation.
133///
134/// If you plan on deserializing into this struct, make sure the source text has a `"name"` field, since that's technically the only thing required by OpenAI.
135///
136/// ## Example
137///
138/// ```
139/// use llm::chat::StructuredOutputFormat;
140/// use serde_json::json;
141///
142/// let response_format = r#"
143/// {
144/// "name": "Student",
145/// "description": "A student object",
146/// "schema": {
147/// "type": "object",
148/// "properties": {
149/// "name": {
150/// "type": "string"
151/// },
152/// "age": {
153/// "type": "integer"
154/// },
155/// "is_student": {
156/// "type": "boolean"
157/// }
158/// },
159/// "required": ["name", "age", "is_student"]
160/// }
161/// }
162/// "#;
163/// let structured_output: StructuredOutputFormat = serde_json::from_str(response_format).unwrap();
164/// assert_eq!(structured_output.name, "Student");
165/// assert_eq!(structured_output.description, Some("A student object".to_string()));
166/// ```
167#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
168
169pub struct StructuredOutputFormat {
170 /// Name of the schema
171 pub name: String,
172 /// The description of the schema
173 pub description: Option<String>,
174 /// The JSON schema for the structured output
175 pub schema: Option<Value>,
176 /// Whether to enable strict schema adherence
177 pub strict: Option<bool>,
178}
179
180/// Represents a tool that can be used in chat
181#[derive(Debug, Clone, Serialize)]
182pub struct Tool {
183 /// The type of tool (e.g. "function")
184 #[serde(rename = "type")]
185 pub tool_type: String,
186 /// The function definition if this is a function tool
187 pub function: FunctionTool,
188}
189
/// Controls how the LLM may use the tools made available to it.
/// The behavior is standardized across different LLM providers.
#[derive(Clone, Debug, Default)]
pub enum ToolChoice {
    /// The model may pick any tool but must call at least one.
    /// Use this to force tool usage.
    Any,

    /// The model may pick any tool, or none at all.
    /// This is the default and leaves the decision to the model.
    #[default]
    Auto,

    /// The model must call the named tool and only that tool.
    /// The string is the name of the required tool.
    Tool(String),

    /// Tool usage is explicitly disabled.
    /// The model will not call any tool even if tools are provided.
    None,
}
212
213impl Serialize for ToolChoice {
214 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
215 where
216 S: serde::Serializer,
217 {
218 match self {
219 ToolChoice::Any => serializer.serialize_str("required"),
220 ToolChoice::Auto => serializer.serialize_str("auto"),
221 ToolChoice::None => serializer.serialize_str("none"),
222 ToolChoice::Tool(name) => {
223 use serde::ser::SerializeMap;
224
225 // For tool_choice: {"type": "function", "function": {"name": "function_name"}}
226 let mut map = serializer.serialize_map(Some(2))?;
227 map.serialize_entry("type", "function")?;
228
229 // Inner function object
230 let mut function_obj = std::collections::HashMap::new();
231 function_obj.insert("name", name.as_str());
232
233 map.serialize_entry("function", &function_obj)?;
234 map.end()
235 }
236 }
237 }
238}
239
240pub trait ChatResponse: std::fmt::Debug + std::fmt::Display {
241 fn text(&self) -> Option<String>;
242 fn tool_calls(&self) -> Option<Vec<ToolCall>>;
243 fn thinking(&self) -> Option<String> {
244 None
245 }
246}
247
248/// Trait for providers that support chat-style interactions.
249#[async_trait]
250pub trait ChatProvider: Sync + Send {
251 /// Sends a chat request to the provider with a sequence of messages.
252 ///
253 /// # Arguments
254 ///
255 /// * `messages` - The conversation history as a slice of chat messages
256 ///
257 /// # Returns
258 ///
259 /// The provider's response text or an error
260 async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
261 self.chat_with_tools(messages, None).await
262 }
263
264 /// Sends a chat request to the provider with a sequence of messages and tools.
265 ///
266 /// # Arguments
267 ///
268 /// * `messages` - The conversation history as a slice of chat messages
269 /// * `tools` - Optional slice of tools to use in the chat
270 ///
271 /// # Returns
272 ///
273 /// The provider's response text or an error
274 async fn chat_with_tools(
275 &self,
276 messages: &[ChatMessage],
277 tools: Option<&[Tool]>,
278 ) -> Result<Box<dyn ChatResponse>, LLMError>;
279
280 /// Sends a streaming chat request to the provider with a sequence of messages.
281 ///
282 /// # Arguments
283 ///
284 /// * `messages` - The conversation history as a slice of chat messages
285 ///
286 /// # Returns
287 ///
288 /// A stream of text tokens or an error
289 async fn chat_stream(
290 &self,
291 _messages: &[ChatMessage],
292 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError> {
293 Err(LLMError::Generic(
294 "Streaming not supported for this provider".to_string(),
295 ))
296 }
297
298 /// Get current memory contents if provider supports memory
299 async fn memory_contents(&self) -> Option<Vec<ChatMessage>> {
300 None
301 }
302
303 /// Summarizes a conversation history into a concise 2-3 sentence summary
304 ///
305 /// # Arguments
306 /// * `msgs` - The conversation messages to summarize
307 ///
308 /// # Returns
309 /// A string containing the summary or an error if summarization fails
310 async fn summarize_history(&self, msgs: &[ChatMessage]) -> Result<String, LLMError> {
311 let prompt = format!(
312 "Summarize in 2-3 sentences:\n{}",
313 msgs.iter()
314 .map(|m| format!("{:?}: {}", m.role, m.content))
315 .collect::<Vec<_>>()
316 .join("\n"),
317 );
318 let req = [ChatMessage::user().content(prompt).build()];
319 self.chat(&req)
320 .await?
321 .text()
322 .ok_or(LLMError::Generic("no text in summary response".into()))
323 }
324}
325
326impl fmt::Display for ReasoningEffort {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 match self {
329 ReasoningEffort::Low => write!(f, "low"),
330 ReasoningEffort::Medium => write!(f, "medium"),
331 ReasoningEffort::High => write!(f, "high"),
332 }
333 }
334}
335
336impl ChatMessage {
337 /// Create a new builder for a user message
338 pub fn user() -> ChatMessageBuilder {
339 ChatMessageBuilder::new(ChatRole::User)
340 }
341
342 /// Create a new builder for an assistant message
343 pub fn assistant() -> ChatMessageBuilder {
344 ChatMessageBuilder::new(ChatRole::Assistant)
345 }
346}
347
348/// Builder for ChatMessage
349#[derive(Debug)]
350pub struct ChatMessageBuilder {
351 role: ChatRole,
352 message_type: MessageType,
353 content: String,
354}
355
356impl ChatMessageBuilder {
357 /// Create a new ChatMessageBuilder with specified role
358 pub fn new(role: ChatRole) -> Self {
359 Self {
360 role,
361 message_type: MessageType::default(),
362 content: String::new(),
363 }
364 }
365
366 /// Set the message content
367 pub fn content<S: Into<String>>(mut self, content: S) -> Self {
368 self.content = content.into();
369 self
370 }
371
372 /// Set the message type as Image
373 pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
374 self.message_type = MessageType::Image((image_mime, raw_bytes));
375 self
376 }
377
378 /// Set the message type as Image
379 pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
380 self.message_type = MessageType::Pdf(raw_bytes);
381 self
382 }
383
384 /// Set the message type as ImageURL
385 pub fn image_url(mut self, url: impl Into<String>) -> Self {
386 self.message_type = MessageType::ImageURL(url.into());
387 self
388 }
389
390 /// Set the message type as ToolUse
391 pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
392 self.message_type = MessageType::ToolUse(tools);
393 self
394 }
395
396 /// Set the message type as ToolResult
397 pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
398 self.message_type = MessageType::ToolResult(tools);
399 self
400 }
401
402 /// Build the ChatMessage
403 pub fn build(self) -> ChatMessage {
404 ChatMessage {
405 role: self.role,
406 message_type: self.message_type,
407 content: self.content,
408 }
409 }
410}
411
412/// Creates a Server-Sent Events (SSE) stream from an HTTP response.
413///
414/// # Arguments
415///
416/// * `response` - The HTTP response from the streaming API
417/// * `parser` - Function to parse each SSE chunk into optional text content
418///
419/// # Returns
420///
421/// A pinned stream of text tokens or an error
422pub(crate) fn create_sse_stream<F>(
423 response: reqwest::Response,
424 parser: F,
425) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
426where
427 F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
428{
429 let stream = response
430 .bytes_stream()
431 .map(move |chunk| match chunk {
432 Ok(bytes) => {
433 let text = String::from_utf8_lossy(&bytes);
434 parser(&text)
435 }
436 Err(e) => Err(LLMError::HttpError(e.to_string())),
437 })
438 .filter_map(|result| async move {
439 match result {
440 Ok(Some(content)) => Some(Ok(content)),
441 Ok(None) => None,
442 Err(e) => Some(Err(e)),
443 }
444 });
445
446 Box::pin(stream)
447}