// agent_chain_core/chat_models.rs
1//! Core ChatModel trait and related types.
2//!
3//! This module provides the base abstraction for chat models, following the
4//! LangChain pattern of having a common interface for different providers.
5
6use std::pin::Pin;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use futures::Stream;
11use serde::{Deserialize, Serialize};
12
13use crate::error::Result;
14use crate::messages::{AIMessage, BaseMessage};
15use crate::tools::{Tool, ToolDefinition};
16
/// Output from a chat model generation.
///
/// Bundles the generated [`AIMessage`] with provider-reported metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResult {
    /// The generated message.
    pub message: AIMessage,
    /// Additional metadata from the model.
    ///
    /// `#[serde(default)]` lets deserialization succeed when the field is
    /// absent, falling back to an empty `ChatResultMetadata`.
    #[serde(default)]
    pub metadata: ChatResultMetadata,
}
26
/// Metadata from a chat model generation.
///
/// Every field is optional; providers fill in whatever they report.
/// `None` fields are omitted from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChatResultMetadata {
    /// The model that was used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Stop reason from the model (provider-specific string).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Token usage information.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<UsageMetadata>,
}
40
41/// Token usage metadata.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct UsageMetadata {
44 /// Number of input tokens.
45 pub input_tokens: u32,
46 /// Number of output tokens.
47 pub output_tokens: u32,
48 /// Total tokens (input + output).
49 pub total_tokens: u32,
50}
51
52impl UsageMetadata {
53 /// Create a new usage metadata.
54 pub fn new(input_tokens: u32, output_tokens: u32) -> Self {
55 Self {
56 input_tokens,
57 output_tokens,
58 total_tokens: input_tokens + output_tokens,
59 }
60 }
61}
62
/// A chunk of output from streaming.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChunk {
    /// The content delta for this chunk.
    pub content: String,
    /// Whether this is the final chunk of the stream.
    pub is_final: bool,
    /// Metadata (only present on the final chunk); omitted from the
    /// serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<ChatResultMetadata>,
}
74
/// Type alias for streaming output.
///
/// A pinned, boxed, `Send` stream of `Result<ChatChunk>`: pinned/boxed so it
/// can be returned from trait methods, `Send` so it can cross await points on
/// multi-threaded executors, and `Result` items so errors can surface
/// mid-stream.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatChunk>> + Send>>;
77
/// Parameters for tracing and monitoring.
///
/// All fields are optional and `None` fields are omitted when serialized.
/// The `ls_` prefix matches the LangSmith tracing convention.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LangSmithParams {
    /// Provider name (e.g., "anthropic", "openai").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_provider: Option<String>,
    /// Model name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_model_name: Option<String>,
    /// Model type (always "chat" for chat models).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_model_type: Option<String>,
    /// Temperature setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_temperature: Option<f64>,
    /// Max tokens setting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_max_tokens: Option<u32>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_stop: Option<Vec<String>>,
}
100
101/// Base trait for all chat models.
102///
103/// This trait follows the LangChain pattern where each provider implements
104/// the core generation methods. The trait provides both sync-style (via async)
105/// and streaming interfaces.
106///
107/// # Example Implementation
108///
109/// ```ignore
110/// use agent_chain_core::chat_model::{ChatModel, ChatResult};
111/// use agent_chain_core::messages::BaseMessage;
112///
113/// struct MyChatModel {
114/// model: String,
115/// }
116///
117/// #[async_trait::async_trait]
118/// impl ChatModel for MyChatModel {
119/// fn llm_type(&self) -> &str {
120/// "my-chat-model"
121/// }
122///
123/// async fn generate(
124/// &self,
125/// messages: Vec<BaseMessage>,
126/// stop: Option<Vec<String>>,
127/// ) -> Result<ChatResult> {
128/// // Implementation here
129/// todo!()
130/// }
131/// }
132/// ```
133#[async_trait]
134pub trait ChatModel: Send + Sync {
135 /// Return the type identifier for this chat model.
136 ///
137 /// This is used for logging and tracing purposes.
138 fn llm_type(&self) -> &str;
139
140 /// Get the model name/identifier.
141 fn model_name(&self) -> &str;
142
143 /// Generate a response from the model.
144 ///
145 /// # Arguments
146 ///
147 /// * `messages` - The conversation history.
148 /// * `stop` - Optional stop sequences.
149 ///
150 /// # Returns
151 ///
152 /// A `ChatResult` containing the generated message and metadata.
153 async fn generate(
154 &self,
155 messages: Vec<BaseMessage>,
156 stop: Option<Vec<String>>,
157 ) -> Result<ChatResult>;
158
159 /// Generate a response from the model with tools.
160 ///
161 /// This is the preferred method when tool calling is needed.
162 /// Default implementation ignores tools and calls `generate`.
163 /// Providers should override this to enable tool calling.
164 ///
165 /// # Arguments
166 ///
167 /// * `messages` - The conversation history.
168 /// * `tools` - Tool definitions for the model to use.
169 /// * `tool_choice` - Optional configuration for tool selection.
170 /// * `stop` - Optional stop sequences.
171 ///
172 /// # Returns
173 ///
174 /// A `ChatResult` containing the generated message and metadata.
175 async fn generate_with_tools(
176 &self,
177 messages: Vec<BaseMessage>,
178 tools: &[ToolDefinition],
179 tool_choice: Option<&ToolChoice>,
180 stop: Option<Vec<String>>,
181 ) -> Result<ChatResult> {
182 // Default implementation ignores tools
183 let _ = tools;
184 let _ = tool_choice;
185 self.generate(messages, stop).await
186 }
187
188 /// Generate a streaming response from the model.
189 ///
190 /// Default implementation calls `generate` and wraps the result in a stream.
191 /// Providers should override this for native streaming support.
192 ///
193 /// # Arguments
194 ///
195 /// * `messages` - The conversation history.
196 /// * `stop` - Optional stop sequences.
197 ///
198 /// # Returns
199 ///
200 /// A stream of `ChatChunk`s.
201 async fn stream(
202 &self,
203 messages: Vec<BaseMessage>,
204 stop: Option<Vec<String>>,
205 ) -> Result<ChatStream> {
206 let result = self.generate(messages, stop).await?;
207 let chunk = ChatChunk {
208 content: result.message.content().to_string(),
209 is_final: true,
210 metadata: Some(result.metadata),
211 };
212 Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })))
213 }
214
215 /// Get parameters for tracing/monitoring.
216 fn get_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
217 let mut params = LangSmithParams {
218 ls_model_type: Some("chat".to_string()),
219 ..Default::default()
220 };
221 if let Some(stop) = stop {
222 params.ls_stop = Some(stop.to_vec());
223 }
224 params
225 }
226
227 /// Get identifying parameters for serialization.
228 fn identifying_params(&self) -> serde_json::Value {
229 serde_json::json!({
230 "_type": self.llm_type(),
231 "model": self.model_name(),
232 })
233 }
234}
235
/// Configuration for tool choice.
///
/// Serialized as an internally tagged enum, e.g. `{"type": "auto"}`,
/// `{"type": "any"}`, `{"type": "tool", "name": "..."}`, `{"type": "none"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// Let the model decide whether to use tools.
    Auto,
    /// Model must use at least one tool.
    Any,
    /// Model must use a specific tool.
    Tool {
        /// Name of the tool to use.
        name: String,
    },
    /// Model should not use any tools.
    None,
}
252
/// A chat model that has been bound with tools (generic version).
///
/// This wraps an underlying chat model and includes tool definitions
/// that will be passed to the model on each invocation. Tools are held
/// behind `Arc`, so cloning the binding shares the same tool instances.
pub struct BoundChatModel<M: ChatModel> {
    /// The underlying chat model.
    model: M,
    /// Tools bound to this model.
    tools: Vec<Arc<dyn Tool + Send + Sync>>,
    /// Tool choice configuration (`None` leaves the choice to the provider).
    tool_choice: Option<ToolChoice>,
}
265
266impl<M: ChatModel> BoundChatModel<M> {
267 /// Create a new bound chat model.
268 pub fn new(model: M, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> Self {
269 Self {
270 model,
271 tools,
272 tool_choice: None,
273 }
274 }
275
276 /// Set the tool choice.
277 pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
278 self.tool_choice = Some(tool_choice);
279 self
280 }
281
282 /// Get the tool definitions.
283 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
284 self.tools.iter().map(|t| t.definition()).collect()
285 }
286
287 /// Get a reference to the underlying model.
288 pub fn model(&self) -> &M {
289 &self.model
290 }
291
292 /// Get the tools.
293 pub fn tools(&self) -> &[Arc<dyn Tool + Send + Sync>] {
294 &self.tools
295 }
296
297 /// Get the tool choice.
298 pub fn tool_choice(&self) -> Option<&ToolChoice> {
299 self.tool_choice.as_ref()
300 }
301
302 /// Invoke the model with messages.
303 ///
304 /// This generates a response using the bound tools.
305 pub async fn invoke(&self, messages: Vec<BaseMessage>) -> BaseMessage {
306 let tool_definitions = self.tool_definitions();
307 match self
308 .model
309 .generate_with_tools(messages, &tool_definitions, self.tool_choice.as_ref(), None)
310 .await
311 {
312 Ok(result) => result.message.into(),
313 Err(e) => {
314 // Return an error message
315 AIMessage::new(format!("Error: {}", e)).into()
316 }
317 }
318 }
319}
320
// Manual `Clone`: requires `M: Clone`; the `tools` vector clone is a series
// of cheap `Arc` refcount bumps, so the cloned binding shares tool instances.
impl<M: ChatModel + Clone> Clone for BoundChatModel<M> {
    fn clone(&self) -> Self {
        Self {
            model: self.model.clone(),
            tools: self.tools.clone(),
            tool_choice: self.tool_choice.clone(),
        }
    }
}
330
/// Extension trait for chat models to add tool binding.
pub trait ChatModelExt: ChatModel + Sized {
    /// Bind tools to this chat model.
    ///
    /// Consumes `self` and wraps it together with `tools`.
    ///
    /// # Arguments
    ///
    /// * `tools` - The tools to bind.
    ///
    /// # Returns
    ///
    /// A `BoundChatModel` that includes the tools.
    fn bind_tools(self, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> BoundChatModel<Self> {
        BoundChatModel::new(self, tools)
    }
}

// Blanket impl: every `ChatModel` gets `bind_tools` for free.
impl<T: ChatModel + Sized> ChatModelExt for T {}
349
/// A dynamically-typed chat model bound with tools.
///
/// This is the dynamic dispatch version of `BoundChatModel`, useful when
/// working with `Arc<dyn ChatModel>` or boxed trait objects. `Clone` can be
/// derived here because every field is an `Arc` (or cheap to clone).
#[derive(Clone)]
pub struct DynBoundChatModel {
    /// The underlying chat model.
    model: Arc<dyn ChatModel>,
    /// Tools bound to this model.
    tools: Vec<Arc<dyn Tool + Send + Sync>>,
    /// Tool choice configuration (`None` leaves the choice to the provider).
    tool_choice: Option<ToolChoice>,
}
363
364impl DynBoundChatModel {
365 /// Create a new dynamically-typed bound chat model.
366 pub fn new(model: Arc<dyn ChatModel>, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> Self {
367 Self {
368 model,
369 tools,
370 tool_choice: None,
371 }
372 }
373
374 /// Set the tool choice.
375 pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
376 self.tool_choice = Some(tool_choice);
377 self
378 }
379
380 /// Get the tool definitions.
381 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
382 self.tools.iter().map(|t| t.definition()).collect()
383 }
384
385 /// Get a reference to the underlying model.
386 pub fn model(&self) -> &Arc<dyn ChatModel> {
387 &self.model
388 }
389
390 /// Get the tools.
391 pub fn tools(&self) -> &[Arc<dyn Tool + Send + Sync>] {
392 &self.tools
393 }
394
395 /// Get the tool choice.
396 pub fn tool_choice(&self) -> Option<&ToolChoice> {
397 self.tool_choice.as_ref()
398 }
399
400 /// Invoke the model with messages.
401 ///
402 /// This generates a response using the bound tools.
403 pub async fn invoke(&self, messages: Vec<BaseMessage>) -> BaseMessage {
404 let tool_definitions = self.tool_definitions();
405 match self
406 .model
407 .generate_with_tools(messages, &tool_definitions, self.tool_choice.as_ref(), None)
408 .await
409 {
410 Ok(result) => result.message.into(),
411 Err(e) => {
412 // Return an error message
413 AIMessage::new(format!("Error: {}", e)).into()
414 }
415 }
416 }
417}
418
/// Extension methods for `Arc<dyn ChatModel>`.
pub trait DynChatModelExt {
    /// Bind tools to this chat model, returning a dynamically-typed bound model.
    ///
    /// Consumes the `Arc` (a cheap refcount move) and wraps it with `tools`.
    fn bind_tools(self, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> DynBoundChatModel;
}

impl DynChatModelExt for Arc<dyn ChatModel> {
    fn bind_tools(self, tools: Vec<Arc<dyn Tool + Send + Sync>>) -> DynBoundChatModel {
        DynBoundChatModel::new(self, tools)
    }
}