// agent_chain_core/chat_models.rs

1//! Core ChatModel trait and related types.
2//!
3//! This module provides the base abstraction for chat models, following the
4//! LangChain pattern of having a common interface for different providers.
5
6use std::pin::Pin;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::Stream;
11use serde::{Deserialize, Serialize};
12
13use crate::error::Result;
14use crate::messages::{AIMessage, BaseMessage};
15use crate::tools::{Tool, ToolDefinition};
16
17/// Output from a chat model generation.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ChatResult {
20    /// The generated message.
21    pub message: AIMessage,
22    /// Additional metadata from the model.
23    #[serde(default)]
24    pub metadata: ChatResultMetadata,
25}
26
27/// Metadata from a chat model generation.
28#[derive(Debug, Clone, Default, Serialize, Deserialize)]
29pub struct ChatResultMetadata {
30    /// The model that was used.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub model: Option<String>,
33    /// Stop reason from the model.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub stop_reason: Option<String>,
36    /// Token usage information.
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub usage: Option<UsageMetadata>,
39}
40
41/// Token usage metadata.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct UsageMetadata {
44    /// Number of input tokens.
45    pub input_tokens: u32,
46    /// Number of output tokens.
47    pub output_tokens: u32,
48    /// Total tokens (input + output).
49    pub total_tokens: u32,
50}
51
52impl UsageMetadata {
53    /// Create a new usage metadata.
54    pub fn new(input_tokens: u32, output_tokens: u32) -> Self {
55        Self {
56            input_tokens,
57            output_tokens,
58            total_tokens: input_tokens + output_tokens,
59        }
60    }
61}
62
63/// A chunk of output from streaming.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ChatChunk {
66    /// The content delta.
67    pub content: String,
68    /// Whether this is the final chunk.
69    pub is_final: bool,
70    /// Metadata (only present on final chunk).
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub metadata: Option<ChatResultMetadata>,
73}
74
/// Type alias for streaming output.
///
/// A pinned, boxed stream yielding `Result<ChatChunk>` items. The `Send`
/// bound allows the stream to be moved across threads by async executors.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatChunk>> + Send>>;
77
78/// Parameters for tracing and monitoring.
79#[derive(Debug, Clone, Default, Serialize, Deserialize)]
80pub struct LangSmithParams {
81    /// Provider name (e.g., "anthropic", "openai").
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub ls_provider: Option<String>,
84    /// Model name.
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub ls_model_name: Option<String>,
87    /// Model type (always "chat" for chat models).
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub ls_model_type: Option<String>,
90    /// Temperature setting.
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub ls_temperature: Option<f64>,
93    /// Max tokens setting.
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub ls_max_tokens: Option<u32>,
96    /// Stop sequences.
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub ls_stop: Option<Vec<String>>,
99}
100
101/// Base trait for all chat models.
102///
103/// This trait follows the LangChain pattern where each provider implements
104/// the core generation methods. The trait provides both sync-style (via async)
105/// and streaming interfaces.
106///
107/// # Example Implementation
108///
109/// ```ignore
110/// use agent_chain_core::chat_model::{ChatModel, ChatResult};
111/// use agent_chain_core::messages::BaseMessage;
112///
113/// struct MyChatModel {
114///     model: String,
115/// }
116///
117/// #[async_trait::async_trait]
118/// impl ChatModel for MyChatModel {
119///     fn llm_type(&self) -> &str {
120///         "my-chat-model"
121///     }
122///
123///     async fn generate(
124///         &self,
125///         messages: Vec<BaseMessage>,
126///         stop: Option<Vec<String>>,
127///     ) -> Result<ChatResult> {
128///         // Implementation here
129///         todo!()
130///     }
131/// }
132/// ```
133#[async_trait]
134pub trait ChatModel: Send + Sync {
135    /// Return the type identifier for this chat model.
136    ///
137    /// This is used for logging and tracing purposes.
138    fn llm_type(&self) -> &str;
139
140    /// Get the model name/identifier.
141    fn model_name(&self) -> &str;
142
143    /// Generate a response from the model.
144    ///
145    /// # Arguments
146    ///
147    /// * `messages` - The conversation history.
148    /// * `stop` - Optional stop sequences.
149    ///
150    /// # Returns
151    ///
152    /// A `ChatResult` containing the generated message and metadata.
153    async fn generate(
154        &self,
155        messages: Vec<BaseMessage>,
156        stop: Option<Vec<String>>,
157    ) -> Result<ChatResult>;
158
159    /// Generate a response from the model with tools.
160    ///
161    /// This is the preferred method when tool calling is needed.
162    /// Default implementation ignores tools and calls `generate`.
163    /// Providers should override this to enable tool calling.
164    ///
165    /// # Arguments
166    ///
167    /// * `messages` - The conversation history.
168    /// * `tools` - Tool definitions for the model to use.
169    /// * `tool_choice` - Optional configuration for tool selection.
170    /// * `stop` - Optional stop sequences.
171    ///
172    /// # Returns
173    ///
174    /// A `ChatResult` containing the generated message and metadata.
175    async fn generate_with_tools(
176        &self,
177        messages: Vec<BaseMessage>,
178        tools: &[ToolDefinition],
179        tool_choice: Option<&ToolChoice>,
180        stop: Option<Vec<String>>,
181    ) -> Result<ChatResult> {
182        // Default implementation ignores tools
183        let _ = tools;
184        let _ = tool_choice;
185        self.generate(messages, stop).await
186    }
187
188    /// Generate a streaming response from the model.
189    ///
190    /// Default implementation calls `generate` and wraps the result in a stream.
191    /// Providers should override this for native streaming support.
192    ///
193    /// # Arguments
194    ///
195    /// * `messages` - The conversation history.
196    /// * `stop` - Optional stop sequences.
197    ///
198    /// # Returns
199    ///
200    /// A stream of `ChatChunk`s.
201    async fn stream(
202        &self,
203        messages: Vec<BaseMessage>,
204        stop: Option<Vec<String>>,
205    ) -> Result<ChatStream> {
206        let result = self.generate(messages, stop).await?;
207        let chunk = ChatChunk {
208            content: result.message.content().to_string(),
209            is_final: true,
210            metadata: Some(result.metadata),
211        };
212        Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })))
213    }
214
215    /// Get parameters for tracing/monitoring.
216    fn get_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
217        let mut params = LangSmithParams {
218            ls_model_type: Some("chat".to_string()),
219            ..Default::default()
220        };
221        if let Some(stop) = stop {
222            params.ls_stop = Some(stop.to_vec());
223        }
224        params
225    }
226
227    /// Get identifying parameters for serialization.
228    fn identifying_params(&self) -> serde_json::Value {
229        serde_json::json!({
230            "_type": self.llm_type(),
231            "model": self.model_name(),
232        })
233    }
234}
235
236/// Configuration for tool choice.
237#[derive(Debug, Clone, Serialize, Deserialize)]
238#[serde(tag = "type", rename_all = "snake_case")]
239pub enum ToolChoice {
240    /// Let the model decide whether to use tools.
241    Auto,
242    /// Model must use at least one tool.
243    Any,
244    /// Model must use a specific tool.
245    Tool {
246        /// Name of the tool to use.
247        name: String,
248    },
249    /// Model should not use any tools.
250    None,
251}
252
253/// A chat model that has been bound with tools (generic version).
254///
255/// This wraps an underlying chat model and includes tool definitions
256/// that will be passed to the model on each invocation.
257pub struct BoundChatModel<M: ChatModel> {
258    /// The underlying chat model.
259    model: M,
260    /// Tools bound to this model.
261    tools: Vec<Arc<dyn Tool + Send + Sync>>,
262    /// Tool choice configuration.
263    tool_choice: Option<ToolChoice>,
264}
265
266impl<M: ChatModel> BoundChatModel<M> {
267    /// Create a new bound chat model.
268    pub fn new(model: M, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> Self {
269        Self {
270            model,
271            tools,
272            tool_choice: None,
273        }
274    }
275
276    /// Set the tool choice.
277    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
278        self.tool_choice = Some(tool_choice);
279        self
280    }
281
282    /// Get the tool definitions.
283    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
284        self.tools.iter().map(|t| t.definition()).collect()
285    }
286
287    /// Get a reference to the underlying model.
288    pub fn model(&self) -> &M {
289        &self.model
290    }
291
292    /// Get the tools.
293    pub fn tools(&self) -> &[Arc<dyn Tool + Send + Sync>] {
294        &self.tools
295    }
296
297    /// Get the tool choice.
298    pub fn tool_choice(&self) -> Option<&ToolChoice> {
299        self.tool_choice.as_ref()
300    }
301
302    /// Invoke the model with messages.
303    ///
304    /// This generates a response using the bound tools.
305    pub async fn invoke(&self, messages: Vec<BaseMessage>) -> BaseMessage {
306        let tool_definitions = self.tool_definitions();
307        match self
308            .model
309            .generate_with_tools(messages, &tool_definitions, self.tool_choice.as_ref(), None)
310            .await
311        {
312            Ok(result) => result.message.into(),
313            Err(e) => {
314                // Return an error message
315                AIMessage::new(format!("Error: {}", e)).into()
316            }
317        }
318    }
319}
320
321impl<M: ChatModel + Clone> Clone for BoundChatModel<M> {
322    fn clone(&self) -> Self {
323        Self {
324            model: self.model.clone(),
325            tools: self.tools.clone(),
326            tool_choice: self.tool_choice.clone(),
327        }
328    }
329}
330
331/// Extension trait for chat models to add tool binding.
332pub trait ChatModelExt: ChatModel + Sized {
333    /// Bind tools to this chat model.
334    ///
335    /// # Arguments
336    ///
337    /// * `tools` - The tools to bind.
338    ///
339    /// # Returns
340    ///
341    /// A `BoundChatModel` that includes the tools.
342    fn bind_tools(self, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> BoundChatModel<Self> {
343        BoundChatModel::new(self, tools)
344    }
345}
346
347// Implement ChatModelExt for all ChatModel implementations
348impl<T: ChatModel + Sized> ChatModelExt for T {}
349
350/// A dynamically-typed chat model bound with tools.
351///
352/// This is the dynamic dispatch version of `BoundChatModel`, useful when
353/// working with `Arc<dyn ChatModel>` or boxed trait objects.
354#[derive(Clone)]
355pub struct DynBoundChatModel {
356    /// The underlying chat model.
357    model: Arc<dyn ChatModel>,
358    /// Tools bound to this model.
359    tools: Vec<Arc<dyn Tool + Send + Sync>>,
360    /// Tool choice configuration.
361    tool_choice: Option<ToolChoice>,
362}
363
364impl DynBoundChatModel {
365    /// Create a new dynamically-typed bound chat model.
366    pub fn new(model: Arc<dyn ChatModel>, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> Self {
367        Self {
368            model,
369            tools,
370            tool_choice: None,
371        }
372    }
373
374    /// Set the tool choice.
375    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
376        self.tool_choice = Some(tool_choice);
377        self
378    }
379
380    /// Get the tool definitions.
381    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
382        self.tools.iter().map(|t| t.definition()).collect()
383    }
384
385    /// Get a reference to the underlying model.
386    pub fn model(&self) -> &Arc<dyn ChatModel> {
387        &self.model
388    }
389
390    /// Get the tools.
391    pub fn tools(&self) -> &[Arc<dyn Tool + Send + Sync>] {
392        &self.tools
393    }
394
395    /// Get the tool choice.
396    pub fn tool_choice(&self) -> Option<&ToolChoice> {
397        self.tool_choice.as_ref()
398    }
399
400    /// Invoke the model with messages.
401    ///
402    /// This generates a response using the bound tools.
403    pub async fn invoke(&self, messages: Vec<BaseMessage>) -> BaseMessage {
404        let tool_definitions = self.tool_definitions();
405        match self
406            .model
407            .generate_with_tools(messages, &tool_definitions, self.tool_choice.as_ref(), None)
408            .await
409        {
410            Ok(result) => result.message.into(),
411            Err(e) => {
412                // Return an error message
413                AIMessage::new(format!("Error: {}", e)).into()
414            }
415        }
416    }
417}
418
419/// Extension methods for `Arc<dyn ChatModel>`.
420pub trait DynChatModelExt {
421    /// Bind tools to this chat model, returning a dynamically-typed bound model.
422    fn bind_tools(self, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> DynBoundChatModel;
423}
424
425impl DynChatModelExt for Arc<dyn ChatModel> {
426    fn bind_tools(self, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> DynBoundChatModel {
427        DynBoundChatModel::new(self, tools)
428    }
429}