agent_chain_core/language_models/
chat_models.rs

1//! Chat models for conversational AI.
2//!
3//! This module provides the base abstraction for chat models,
4//! following the LangChain pattern of having a common interface
5//! for different providers.
6//!
7//! Mirrors `langchain_core.language_models.chat_models`.
8
9use std::collections::HashMap;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use futures::Stream;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17
18use super::base::{BaseLanguageModel, LangSmithParams, LanguageModelConfig, LanguageModelInput};
19use super::model_profile::ModelProfile;
20use crate::GenerationType;
21use crate::callbacks::{
22    AsyncCallbackManagerForLLMRun, BaseCallbackHandler, CallbackManagerForLLMRun, Callbacks,
23};
24use crate::error::{Error, Result};
25use crate::messages::{AIMessage, AIMessageChunk, BaseMessage, ChunkPosition, UsageMetadata};
26use crate::outputs::{ChatGeneration, ChatGenerationChunk, ChatResult, Generation, LLMResult};
27use crate::rate_limiters::BaseRateLimiter;
28use crate::tools::{BaseTool, ToolDefinition};
29
/// Type alias for streaming output of [`ChatChunk`]s.
///
/// Boxed and pinned so trait methods can return a stream without naming the
/// concrete stream type.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatChunk>> + Send>>;

/// Type alias for a streaming chat generation output ([`ChatGenerationChunk`] items),
/// as produced by `_stream`/`_astream`.
pub type ChatGenerationStream = Pin<Box<dyn Stream<Item = Result<ChatGenerationChunk>> + Send>>;

/// Type alias for streaming [`AIMessageChunk`] output, as yielded by `stream`/`astream`.
pub type AIMessageChunkStream = Pin<Box<dyn Stream<Item = Result<AIMessageChunk>> + Send>>;
38
/// A chunk of output from streaming.
///
/// This struct carries content deltas during streaming, along with optional
/// metadata that is typically attached to the final chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChunk {
    /// The content delta (may be empty on the final chunk).
    pub content: String,
    /// Whether this is the final chunk of the stream.
    pub is_final: bool,
    /// Usage metadata (token counts) - typically present on the final chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
    /// The reason the model stopped generating (e.g., "stop", "length", "tool_calls").
    /// Typically present on the final chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
57
58impl ChatChunk {
59    /// Create a new content chunk (non-final).
60    pub fn new(content: impl Into<String>) -> Self {
61        Self {
62            content: content.into(),
63            is_final: false,
64            usage_metadata: None,
65            finish_reason: None,
66        }
67    }
68
69    /// Create a final chunk with optional metadata.
70    pub fn final_chunk(
71        usage_metadata: Option<UsageMetadata>,
72        finish_reason: Option<String>,
73    ) -> Self {
74        Self {
75            content: String::new(),
76            is_final: true,
77            usage_metadata,
78            finish_reason,
79        }
80    }
81
82    /// Set usage metadata on this chunk.
83    pub fn with_usage_metadata(mut self, usage: UsageMetadata) -> Self {
84        self.usage_metadata = Some(usage);
85        self
86    }
87
88    /// Set finish reason on this chunk.
89    pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
90        self.finish_reason = Some(reason.into());
91        self
92    }
93}
94
/// Configuration for tool choice.
///
/// Mirrors Python's tool_choice parameter patterns.
/// Serialized untagged: a bare string or a `{"type": ..., "name": ...}` object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ToolChoice {
    /// String value like "auto", "any", "none", or a specific tool name.
    String(String),
    /// Structured tool choice with type and optional name.
    Structured {
        /// Type of tool choice (serialized as the JSON key `"type"`).
        #[serde(rename = "type")]
        choice_type: String,
        /// Optional tool name; omitted from serialization when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
113
114impl ToolChoice {
115    /// Create an "auto" tool choice - let the model decide.
116    pub fn auto() -> Self {
117        ToolChoice::String("auto".to_string())
118    }
119
120    /// Create an "any" tool choice - model must use at least one tool.
121    pub fn any() -> Self {
122        ToolChoice::String("any".to_string())
123    }
124
125    /// Create a "none" tool choice - model should not use any tools.
126    pub fn none() -> Self {
127        ToolChoice::String("none".to_string())
128    }
129
130    /// Create a tool choice for a specific tool by name.
131    pub fn tool(name: impl Into<String>) -> Self {
132        ToolChoice::Structured {
133            choice_type: "tool".to_string(),
134            name: Some(name.into()),
135        }
136    }
137}
138
/// Disable streaming options.
///
/// Mirrors Python's `disable_streaming: bool | Literal["tool_calling"]` field.
/// Serialized untagged so it round-trips as either a bool or a string.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DisableStreaming {
    /// Boolean value: true = always disable, false = never disable.
    Bool(bool),
    /// Literal "tool_calling": disable only when tools are present.
    ToolCalling,
}
150
151impl Default for DisableStreaming {
152    fn default() -> Self {
153        DisableStreaming::Bool(false)
154    }
155}
156
157impl DisableStreaming {
158    /// Check if streaming should be bypassed.
159    ///
160    /// # Arguments
161    ///
162    /// * `has_tools` - Whether tools are present in the current call.
163    pub fn should_disable(&self, has_tools: bool) -> bool {
164        match self {
165            DisableStreaming::Bool(b) => *b,
166            DisableStreaming::ToolCalling => has_tools,
167        }
168    }
169}
170
171impl From<bool> for DisableStreaming {
172    fn from(b: bool) -> Self {
173        DisableStreaming::Bool(b)
174    }
175}
176
/// Configuration specific to chat models.
///
/// Wraps the shared [`LanguageModelConfig`] and adds chat-only knobs
/// (rate limiting, streaming policy, output versioning, model profile).
#[derive(Clone, Default)]
pub struct ChatModelConfig {
    /// Base language model configuration.
    pub base: LanguageModelConfig,

    /// Rate limiter for limiting API requests.
    // Trait object without `Debug`, hence the manual `Debug` impl below.
    pub rate_limiter: Option<Arc<dyn BaseRateLimiter>>,

    /// Whether to disable streaming for this model.
    ///
    /// If streaming is bypassed, then `stream`/`astream` will defer to `invoke`/`ainvoke`.
    ///
    /// - If `Bool(true)`, will always bypass streaming case.
    /// - If `ToolCalling`, will bypass streaming case only when tools are present.
    /// - If `Bool(false)` (default), will always use streaming case if available.
    pub disable_streaming: DisableStreaming,

    /// Version of `AIMessage` output format.
    ///
    /// - `"v0"`: provider-specific format in content
    /// - `"v1"`: standardized format in content
    ///
    /// Can also be set via `LC_OUTPUT_VERSION` environment variable.
    // NOTE(review): the env-var fallback is not read anywhere in this file —
    // confirm it is handled by the provider implementations.
    pub output_version: Option<String>,

    /// Profile detailing model capabilities.
    pub profile: Option<ModelProfile>,
}
206
207impl std::fmt::Debug for ChatModelConfig {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        f.debug_struct("ChatModelConfig")
210            .field("base", &self.base)
211            .field(
212                "rate_limiter",
213                &self.rate_limiter.as_ref().map(|_| "<rate_limiter>"),
214            )
215            .field("disable_streaming", &self.disable_streaming)
216            .field("output_version", &self.output_version)
217            .field("profile", &self.profile)
218            .finish()
219    }
220}
221
222impl ChatModelConfig {
223    /// Create a new chat model configuration.
224    pub fn new() -> Self {
225        Self::default()
226    }
227
228    /// Set the rate limiter.
229    pub fn with_rate_limiter(mut self, rate_limiter: Arc<dyn BaseRateLimiter>) -> Self {
230        self.rate_limiter = Some(rate_limiter);
231        self
232    }
233
234    /// Set whether to disable streaming.
235    pub fn with_disable_streaming(mut self, disable: impl Into<DisableStreaming>) -> Self {
236        self.disable_streaming = disable.into();
237        self
238    }
239
240    /// Set the output version.
241    pub fn with_output_version(mut self, version: impl Into<String>) -> Self {
242        self.output_version = Some(version.into());
243        self
244    }
245
246    /// Set the model profile.
247    pub fn with_profile(mut self, profile: ModelProfile) -> Self {
248        self.profile = Some(profile);
249        self
250    }
251
252    /// Enable caching.
253    pub fn with_cache(mut self, cache: bool) -> Self {
254        self.base.cache = Some(cache);
255        self
256    }
257
258    /// Enable verbose mode.
259    pub fn with_verbose(mut self, verbose: bool) -> Self {
260        self.base.verbose = verbose;
261        self
262    }
263
264    /// Set tags.
265    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
266        self.base.tags = Some(tags);
267        self
268    }
269
270    /// Set metadata.
271    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
272        self.base.metadata = Some(metadata);
273        self
274    }
275}
276
277/// Base trait for all chat models.
278///
279/// This trait follows the LangChain pattern where each provider implements
280/// the core generation methods. The trait provides both sync-style (via async)
281/// and streaming interfaces.
282///
283/// # Implementation Guide
284///
285/// Custom chat model implementations should override these methods:
286///
287/// | Method/Property           | Description                                        | Required |
288/// |--------------------------|----------------------------------------------------|---------:|
289/// | `_generate`              | Use to generate a chat result from messages        | Required |
290/// | `_llm_type` (property)   | Used to uniquely identify the type of the model    | Required |
291/// | `_identifying_params`    | Represent model parameterization for tracing       | Optional |
292/// | `_stream`                | Use to implement streaming                         | Optional |
293/// | `_agenerate`             | Use to implement a native async method             | Optional |
294/// | `_astream`               | Use to implement async version of `_stream`        | Optional |
295#[async_trait]
296pub trait BaseChatModel: BaseLanguageModel {
    /// Get the chat model configuration.
    ///
    /// Required: implementors return a reference to their stored
    /// [`ChatModelConfig`]. All default methods that need the rate limiter,
    /// streaming policy, or profile go through this accessor.
    fn chat_config(&self) -> &ChatModelConfig;

    /// Get the model profile, if available.
    ///
    /// Convenience accessor for `chat_config().profile`.
    fn profile(&self) -> Option<&ModelProfile> {
        self.chat_config().profile.as_ref()
    }
304
    /// Core abstract method to generate a chat result.
    ///
    /// Implementations must override this method. It is the single required
    /// generation primitive; the default implementations of `generate`,
    /// `invoke`, `generate_with_tools`, and the non-streaming fallbacks all
    /// call it.
    ///
    /// # Arguments
    ///
    /// * `messages` - The messages to generate from.
    /// * `stop` - Optional list of stop words to use when generating.
    /// * `run_manager` - Optional callback manager to use for this call.
    ///
    /// # Returns
    ///
    /// The output chat result containing generations.
    async fn _generate(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<ChatResult>;
324
    /// Async version of `_generate`.
    ///
    /// Default implementation calls `_generate`.
    ///
    /// NOTE(review): the async `_run_manager` is not forwarded (it is a
    /// different callback-manager type than `_generate` accepts); providers
    /// with native async support should override this to use it.
    async fn _agenerate(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        _run_manager: Option<&AsyncCallbackManagerForLLMRun>,
    ) -> Result<ChatResult> {
        self._generate(messages, stop, None).await
    }
336
    /// Stream the output of the model.
    ///
    /// Default implementation returns `Error::NotImplemented`. Providers that
    /// override this should also override `has_stream_impl` to return `true`
    /// so `_should_stream` knows streaming is available.
    ///
    /// # Arguments
    ///
    /// * `messages` - The messages to generate from.
    /// * `stop` - Optional list of stop words to use when generating.
    /// * `run_manager` - Optional callback manager to use for this call.
    ///
    /// # Yields
    ///
    /// The chat generation chunks.
    fn _stream(
        &self,
        _messages: Vec<BaseMessage>,
        _stop: Option<Vec<String>>,
        _run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<ChatGenerationStream> {
        Err(Error::NotImplemented("Streaming not implemented".into()))
    }
358
    /// Async stream the output of the model.
    ///
    /// Default implementation defers to `_stream` (so it also returns
    /// `Error::NotImplemented` unless `_stream` is overridden).
    ///
    /// NOTE(review): the async `_run_manager` is not forwarded to `_stream`;
    /// providers with native async streaming should override this method.
    async fn _astream(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        _run_manager: Option<&AsyncCallbackManagerForLLMRun>,
    ) -> Result<ChatGenerationStream> {
        self._stream(messages, stop, None)
    }
370
371    /// Get the first AI message from a chat result.
372    ///
373    /// Helper method to extract the first generation's message as an AIMessage.
374    fn get_first_message(&self, result: &ChatResult) -> Result<AIMessage> {
375        if result.generations.is_empty() {
376            return Err(Error::Other("No generations returned".into()));
377        }
378
379        match result.generations[0].message.clone() {
380            BaseMessage::AI(message) => Ok(message),
381            other => Ok(AIMessage::new(other.content())),
382        }
383    }
384
385    /// Combine LLM outputs from multiple results.
386    ///
387    /// This method is called after generating results from multiple prompts
388    /// to combine any LLM-specific output information.
389    ///
390    /// Default implementation returns an empty HashMap.
391    /// Subclasses can override to combine provider-specific output data.
392    fn _combine_llm_outputs(
393        &self,
394        _llm_outputs: &[Option<HashMap<String, Value>>],
395    ) -> HashMap<String, Value> {
396        HashMap::new()
397    }
398
399    /// Convert cached Generation objects to ChatGeneration objects.
400    ///
401    /// Handle case where cache contains Generation objects instead of
402    /// ChatGeneration objects. This can happen due to serialization/deserialization
403    /// issues or legacy cache data.
404    fn _convert_cached_generations(&self, cache_val: Vec<Generation>) -> Vec<ChatGeneration> {
405        cache_val
406            .into_iter()
407            .map(|cached_gen| {
408                // Convert Generation to ChatGeneration by creating AIMessage from text
409                let message = AIMessage::new(&cached_gen.text);
410                match cached_gen.generation_info {
411                    Some(info) => ChatGeneration::with_info(message.into(), info),
412                    None => ChatGeneration::new(message.into()),
413                }
414            })
415            .collect()
416    }
417
418    /// Get invocation parameters for tracing.
419    ///
420    /// Returns a HashMap containing the model configuration and stop sequences.
421    fn _get_invocation_params(
422        &self,
423        stop: Option<&[String]>,
424        kwargs: Option<&HashMap<String, Value>>,
425    ) -> HashMap<String, Value> {
426        let mut params = self.get_identifying_params();
427        if let Some(stop) = stop {
428            params.insert(
429                "stop".to_string(),
430                Value::Array(stop.iter().map(|s| Value::String(s.clone())).collect()),
431            );
432        }
433        if let Some(kw) = kwargs {
434            params.extend(kw.clone());
435        }
436        params
437    }
438
439    /// Get the LLM string for cache key generation.
440    ///
441    /// This string uniquely identifies the model configuration for caching purposes.
442    fn _get_llm_string(
443        &self,
444        stop: Option<&[String]>,
445        kwargs: Option<&HashMap<String, Value>>,
446    ) -> String {
447        let params = self._get_invocation_params(stop, kwargs);
448
449        // Sort params for deterministic key
450        let mut sorted_items: Vec<_> = params.iter().collect();
451        sorted_items.sort_by_key(|(k, _)| *k);
452
453        format!("{:?}", sorted_items)
454    }
455
    /// Check if `_stream` is implemented (not the default).
    ///
    /// This is used by `_should_stream` to determine if streaming is available.
    /// Implementations that override `_stream` should also override this to return `true`.
    /// (Rust has no runtime override detection like Python's method-identity
    /// check, hence these explicit capability flags.)
    fn has_stream_impl(&self) -> bool {
        false
    }

    /// Check if `_astream` is implemented (not the default).
    ///
    /// This is used by `_should_stream` to determine if async streaming is available.
    /// Implementations that override `_astream` should also override this to return `true`.
    fn has_astream_impl(&self) -> bool {
        false
    }

    /// Check if streaming is enabled via a model field.
    ///
    /// Override this if the model has a `streaming` field that should be checked.
    /// `None` means "no such field"; `Some(flag)` forces the decision in
    /// `_should_stream` when no explicit kwarg was given.
    fn has_streaming_field(&self) -> Option<bool> {
        None
    }
478
479    /// Determine if a given model call should hit the streaming API.
480    ///
481    /// This method mirrors Python's `_should_stream` behavior:
482    /// 1. Check if streaming is implemented (either sync or async)
483    /// 2. Check if streaming has been disabled on this instance
484    /// 3. Check if streaming is disabled for tool calling and tools are present
485    /// 4. Check if streaming field is set on the model
486    /// 5. Check if any streaming callback handlers are present
487    ///
488    /// # Arguments
489    ///
490    /// * `async_api` - Whether this is an async API call
491    /// * `has_tools` - Whether tools are present in the call
492    /// * `stream_kwarg` - Optional explicit stream kwarg from caller
493    /// * `run_manager` - Optional callback manager for checking streaming handlers
494    ///
495    /// # Returns
496    ///
497    /// `true` if streaming should be used, `false` otherwise.
498    fn _should_stream(
499        &self,
500        async_api: bool,
501        has_tools: bool,
502        stream_kwarg: Option<bool>,
503        run_manager: Option<&[Arc<dyn BaseCallbackHandler>]>,
504    ) -> bool {
505        // Check if streaming is implemented
506        let sync_not_implemented = !self.has_stream_impl();
507        let async_not_implemented = !self.has_astream_impl();
508
509        // Check if streaming is implemented
510        if !async_api && sync_not_implemented {
511            return false;
512        }
513        // Note: since async falls back to sync, we check both here
514        if async_api && async_not_implemented && sync_not_implemented {
515            return false;
516        }
517
518        // Check if streaming has been disabled on this instance
519        if self
520            .chat_config()
521            .disable_streaming
522            .should_disable(has_tools)
523        {
524            return false;
525        }
526
527        // Check if a runtime streaming flag has been passed in
528        if let Some(stream) = stream_kwarg {
529            return stream;
530        }
531
532        // Check if streaming field is set on the model
533        if let Some(streaming) = self.has_streaming_field() {
534            return streaming;
535        }
536
537        // Check if any streaming callback handlers are present
538        if let Some(handlers) = run_manager {
539            // In Python, this checks for `_StreamingCallbackHandler` instances
540            // For Rust, we can check if any handler implements StreamingCallbackHandler
541            // This is a simplified check - in practice, you'd want a more sophisticated check
542            if !handlers.is_empty() {
543                // If there are any handlers, we assume streaming might be wanted
544                return true;
545            }
546        }
547
548        // Default: streaming is available and not disabled
549        true
550    }
551
552    /// Generate from a batch of message lists.
553    ///
554    /// This method should make use of batched calls for models that expose a batched API.
555    ///
556    /// Use this method when you want to:
557    /// 1. Take advantage of batched calls
558    /// 2. Need more output from the model than just the top generated value
559    /// 3. Are building chains that are agnostic to the underlying language model type
560    ///
561    /// # Arguments
562    ///
563    /// * `messages` - List of message lists.
564    /// * `stop` - Stop words to use when generating.
565    /// * `callbacks` - Callbacks to pass through.
566    ///
567    /// # Returns
568    ///
569    /// An `LLMResult` containing a list of candidate `ChatGeneration` objects.
570    async fn generate(
571        &self,
572        messages: Vec<Vec<BaseMessage>>,
573        stop: Option<Vec<String>>,
574        _callbacks: Option<Callbacks>,
575    ) -> Result<LLMResult> {
576        let mut all_generations: Vec<Vec<GenerationType>> = Vec::new();
577
578        for message_list in messages {
579            let result = self._generate(message_list, stop.clone(), None).await?;
580            all_generations.push(result.generations.into_iter().map(|e| e.into()).collect());
581        }
582
583        Ok(LLMResult::new(all_generations))
584    }
585
586    /// Async version of `generate`.
587    async fn agenerate(
588        &self,
589        messages: Vec<Vec<BaseMessage>>,
590        stop: Option<Vec<String>>,
591        _callbacks: Option<Callbacks>,
592    ) -> Result<LLMResult> {
593        let mut all_generations: Vec<Vec<GenerationType>> = Vec::new();
594
595        for message_list in messages {
596            let result = self._agenerate(message_list, stop.clone(), None).await?;
597            all_generations.push(result.generations.into_iter().map(|e| e.into()).collect());
598        }
599
600        Ok(LLMResult::new(all_generations))
601    }
602
603    /// Async call helper.
604    ///
605    /// This is a convenience method that wraps `agenerate` for single-message calls.
606    async fn _call_async(
607        &self,
608        messages: Vec<BaseMessage>,
609        stop: Option<Vec<String>>,
610        callbacks: Option<Callbacks>,
611    ) -> Result<BaseMessage> {
612        let result = self.agenerate(vec![messages], stop, callbacks).await?;
613
614        if result.generations.is_empty() || result.generations[0].is_empty() {
615            return Err(Error::Other("No generations returned".into()));
616        }
617
618        match &result.generations[0][0] {
619            GenerationType::ChatGeneration(chat_gen) => Ok(chat_gen.message.clone()),
620            _ => Err(Error::Other("Unexpected generation type".into())),
621        }
622    }
623
624    /// Generate a response from the model with tools.
625    ///
626    /// This is the preferred method when tool calling is needed.
627    /// Default implementation ignores tools and calls `_generate`.
628    ///
629    /// # Arguments
630    ///
631    /// * `messages` - The conversation history.
632    /// * `tools` - Tool definitions for the model to use.
633    /// * `tool_choice` - Optional configuration for tool selection.
634    /// * `stop` - Optional stop sequences.
635    ///
636    /// # Returns
637    ///
638    /// An `AIMessage` containing the generated response.
639    async fn generate_with_tools(
640        &self,
641        messages: Vec<BaseMessage>,
642        _tools: &[ToolDefinition],
643        _tool_choice: Option<&ToolChoice>,
644        stop: Option<Vec<String>>,
645    ) -> Result<AIMessage> {
646        let result = self._generate(messages, stop, None).await?;
647
648        if result.generations.is_empty() {
649            return Err(Error::Other("No generations returned".into()));
650        }
651
652        match result.generations[0].message.clone() {
653            BaseMessage::AI(message) => Ok(message),
654            _ => Err(Error::Other("Unexpected message type".into())),
655        }
656    }
657
    /// Convert input to messages.
    ///
    /// Infallible in the default implementation (`to_messages` cannot fail),
    /// but returns `Result` so implementors can reject unsupported inputs.
    fn convert_input(&self, input: LanguageModelInput) -> Result<Vec<BaseMessage>> {
        Ok(input.to_messages())
    }
662
663    /// Invoke the model with input.
664    async fn invoke(&self, input: LanguageModelInput) -> Result<AIMessage> {
665        let messages = self.convert_input(input)?;
666        let result = self._generate(messages, None, None).await?;
667
668        if result.generations.is_empty() {
669            return Err(Error::Other("No generations returned".into()));
670        }
671
672        match result.generations[0].message.clone() {
673            BaseMessage::AI(message) => Ok(message),
674            _ => Err(Error::Other("Unexpected message type".into())),
675        }
676    }
677
678    /// Async invoke the model.
679    async fn ainvoke(&self, input: LanguageModelInput) -> Result<AIMessage> {
680        let messages = self.convert_input(input)?;
681        let result = self._agenerate(messages, None, None).await?;
682
683        if result.generations.is_empty() {
684            return Err(Error::Other("No generations returned".into()));
685        }
686
687        match result.generations[0].message.clone() {
688            BaseMessage::AI(message) => Ok(message),
689            _ => Err(Error::Other("Unexpected message type".into())),
690        }
691    }
692
693    /// Bind tools to the model.
694    ///
695    /// This method returns a Runnable that can call tools. The default
696    /// implementation raises NotImplementedError.
697    ///
698    /// # Arguments
699    ///
700    /// * `tools` - Sequence of tools to bind to the model.
701    /// * `tool_choice` - Optional tool choice configuration.
702    ///
703    /// # Returns
704    ///
705    /// A Result with error indicating tools are not supported.
706    ///
707    /// # Note
708    ///
709    /// Provider implementations should override this method to return a configured model.
710    fn bind_tools(
711        &self,
712        _tools: &[Arc<dyn BaseTool>],
713        _tool_choice: Option<ToolChoice>,
714    ) -> Result<()> {
715        Err(Error::NotImplemented(
716            "bind_tools is not implemented for this model".into(),
717        ))
718    }
719
720    /// Get tool definitions from tools.
721    ///
722    /// Helper method to convert tools to their definitions.
723    fn get_tool_definitions(&self, tools: &[Arc<dyn BaseTool>]) -> Vec<ToolDefinition> {
724        tools.iter().map(|t| t.definition()).collect()
725    }
726
    /// Generate a streaming response from the model.
    ///
    /// This is the main streaming API. It yields `AIMessageChunk`s.
    /// Providers should override `_stream` for native streaming support.
    ///
    /// When streaming is unavailable or disabled (per `_should_stream`), the
    /// call falls back to `_generate` and yields the whole response as a
    /// single chunk.
    ///
    /// # Arguments
    ///
    /// * `input` - The input to the model (string, messages, or PromptValue).
    /// * `stop` - Optional stop sequences.
    ///
    /// # Returns
    ///
    /// A stream of `AIMessageChunk`s.
    async fn stream(
        &self,
        input: LanguageModelInput,
        stop: Option<Vec<String>>,
    ) -> Result<AIMessageChunkStream> {
        let messages = self.convert_input(input)?;
        // Tools are never bound through this entry point, so the
        // `ToolCalling` disable rule cannot trigger here.
        let has_tools = false;

        // `Some(true)` requests streaming explicitly; `_should_stream` still
        // returns false when `_stream` is not implemented or streaming is
        // disabled on this instance.
        if !self._should_stream(false, has_tools, Some(true), None) {
            // Model doesn't implement streaming or streaming is disabled,
            // fall back to invoke and yield the result as a single chunk
            let result = self._generate(messages, stop, None).await?;
            let message = self.get_first_message(&result)?;
            let chunk = AIMessageChunk::new(message.content());
            return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
        }

        // Acquire rate limiter if configured (blocking until token available)
        // NOTE(review): `acquire(true)` is a sync call inside an async fn; if
        // it blocks the thread it will stall the executor — confirm the
        // limiter's contract (the async path uses `aacquire`).
        if let Some(ref rate_limiter) = self.chat_config().rate_limiter {
            rate_limiter.acquire(true);
        }

        // Use the _stream method and convert ChatGenerationChunk to AIMessageChunk
        let generation_stream = self._stream(messages, stop, None)?;

        // Transform the stream to yield AIMessageChunk instead of ChatGenerationChunk
        let chunk_stream = async_stream::stream! {
            use futures::StreamExt;

            let mut pinned_stream = generation_stream;
            let mut yielded = false;

            while let Some(result) = pinned_stream.next().await {
                match result {
                    Ok(generation_chunk) => {
                        // Only the textual content is propagated.
                        // NOTE(review): usage metadata / tool-call deltas on
                        // the generation chunk are dropped here — confirm
                        // whether they should be carried over.
                        let ai_chunk = match generation_chunk.message {
                            BaseMessage::AI(ai_msg) => AIMessageChunk::new(ai_msg.content()),
                            other => AIMessageChunk::new(other.content()),
                        };
                        yielded = true;
                        yield Ok(ai_chunk);
                    }
                    Err(e) => {
                        // First error terminates the stream.
                        yield Err(e);
                        return;
                    }
                }
            }

            // Yield a final empty chunk with chunk_position="last" if we yielded anything
            if yielded {
                let mut final_chunk = AIMessageChunk::new("");
                final_chunk.set_chunk_position(Some(ChunkPosition::Last));
                yield Ok(final_chunk);
            }
        };

        Ok(Box::pin(chunk_stream))
    }
801
    /// Async stream the model output.
    ///
    /// This is the async version of `stream`. It yields `AIMessageChunk`s.
    /// Providers should override `_astream` for native async streaming support.
    ///
    /// When neither async nor sync streaming is available, or streaming is
    /// disabled, falls back to `_agenerate` and yields a single chunk.
    ///
    /// # Arguments
    ///
    /// * `input` - The input to the model (string, messages, or PromptValue).
    /// * `stop` - Optional stop sequences.
    ///
    /// # Returns
    ///
    /// A stream of `AIMessageChunk`s.
    async fn astream(
        &self,
        input: LanguageModelInput,
        stop: Option<Vec<String>>,
    ) -> Result<AIMessageChunkStream> {
        let messages = self.convert_input(input)?;
        // Tools are never bound through this entry point.
        let has_tools = false;

        // `Some(true)` requests streaming explicitly; `_should_stream` still
        // vetoes when no stream implementation exists or streaming is disabled.
        if !self._should_stream(true, has_tools, Some(true), None) {
            // No async or sync stream is implemented, fall back to ainvoke
            let result = self._agenerate(messages, stop, None).await?;
            let message = self.get_first_message(&result)?;
            let chunk = AIMessageChunk::new(message.content());
            return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
        }

        // Acquire rate limiter if configured (awaits until a token is available)
        if let Some(ref rate_limiter) = self.chat_config().rate_limiter {
            rate_limiter.aacquire(true).await;
        }

        // Use the _astream method
        let generation_stream = self._astream(messages, stop, None).await?;

        // Transform the stream to yield AIMessageChunk instead of ChatGenerationChunk
        let chunk_stream = async_stream::stream! {
            use futures::StreamExt;

            let mut pinned_stream = generation_stream;
            let mut yielded = false;

            while let Some(result) = pinned_stream.next().await {
                match result {
                    Ok(generation_chunk) => {
                        // Only the textual content is propagated.
                        // NOTE(review): usage metadata / tool-call deltas on
                        // the generation chunk are dropped here — confirm
                        // whether they should be carried over.
                        let ai_chunk = match generation_chunk.message {
                            BaseMessage::AI(ai_msg) => AIMessageChunk::new(ai_msg.content()),
                            other => AIMessageChunk::new(other.content()),
                        };
                        yielded = true;
                        yield Ok(ai_chunk);
                    }
                    Err(e) => {
                        // First error terminates the stream.
                        yield Err(e);
                        return;
                    }
                }
            }

            // Yield a final empty chunk with chunk_position="last" if we yielded anything
            if yielded {
                let mut final_chunk = AIMessageChunk::new("");
                final_chunk.set_chunk_position(Some(ChunkPosition::Last));
                yield Ok(final_chunk);
            }
        };

        Ok(Box::pin(chunk_stream))
    }
875
876    /// Stream ChatGenerationChunk objects from the model.
877    ///
878    /// This is a lower-level streaming API that yields `ChatGenerationChunk`s directly.
879    /// Most users should use `stream()` or `astream()` instead.
880    ///
881    /// # Arguments
882    ///
883    /// * `messages` - The conversation history.
884    /// * `stop` - Optional stop sequences.
885    /// * `run_manager` - Optional callback manager for the run.
886    ///
887    /// # Returns
888    ///
889    /// A stream of `ChatGenerationChunk`s.
890    async fn stream_generations(
891        &self,
892        messages: Vec<BaseMessage>,
893        stop: Option<Vec<String>>,
894        run_manager: Option<&CallbackManagerForLLMRun>,
895    ) -> Result<ChatGenerationStream> {
896        let has_tools = false;
897
898        // Check if streaming should be used
899        if !self._should_stream(false, has_tools, None, None) {
900            // Fall back to non-streaming
901            let result = self._generate(messages, stop, run_manager).await?;
902            if result.generations.is_empty() {
903                return Err(Error::Other("No generations returned".into()));
904            }
905
906            let message = result.generations[0].message.clone();
907            let chunk = ChatGenerationChunk::new(message);
908            return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
909        }
910
911        // Try to use streaming
912        self._stream(messages, stop, run_manager)
913    }
914
915    /// Get standard params for tracing.
916    fn get_chat_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
917        let mut params = self.get_ls_params(stop);
918        params.ls_model_type = Some("chat".to_string());
919        params
920    }
921
922    /// Get a dictionary representation of the model.
923    ///
924    /// Returns identifying parameters plus the model type.
925    fn to_dict(&self) -> HashMap<String, Value> {
926        let mut result = self.get_identifying_params();
927        result.insert(
928            "_type".to_string(),
929            Value::String(self.llm_type().to_string()),
930        );
931        result
932    }
933
934    /// Create a wrapper that structures model output using a schema.
935    ///
936    /// This method returns a Runnable that formats outputs to match the given schema.
937    /// The default implementation raises NotImplementedError.
938    ///
939    /// # Arguments
940    ///
941    /// * `schema` - The output schema (as a JSON value).
942    /// * `include_raw` - If true, include raw model response in output.
943    ///
944    /// # Returns
945    ///
946    /// A Result with error indicating structured output is not supported.
947    ///
948    /// # Note
949    ///
950    /// Provider implementations should override `bind_tools` first, as the default
951    /// implementation uses `bind_tools` internally.
952    fn with_structured_output(&self, _schema: Value, _include_raw: bool) -> Result<()> {
953        Err(Error::NotImplemented(
954            "with_structured_output is not implemented for this model".into(),
955        ))
956    }
957
958    /// Get the identifying parameters for this model.
959    ///
960    /// Returns a map of parameters that uniquely identify this model instance.
961    fn get_identifying_params(&self) -> HashMap<String, Value> {
962        let mut params = HashMap::new();
963        params.insert(
964            "_type".to_string(),
965            Value::String(self.llm_type().to_string()),
966        );
967        params.insert(
968            "model".to_string(),
969            Value::String(self.model_name().to_string()),
970        );
971        params
972    }
973}
974
975/// Simplified implementation for a chat model to inherit from.
976///
977/// This implementation is primarily here for backwards compatibility.
978/// For new implementations, please use `BaseChatModel` directly.
979#[async_trait]
980pub trait SimpleChatModel: BaseChatModel {
981    /// Simple call method that takes messages and returns a string.
982    ///
983    /// Implementations should override this method.
984    async fn _call(
985        &self,
986        messages: Vec<BaseMessage>,
987        stop: Option<Vec<String>>,
988        run_manager: Option<&CallbackManagerForLLMRun>,
989    ) -> Result<String>;
990}
991
992#[async_trait]
993impl<T: SimpleChatModel> BaseChatModel for T {
994    fn chat_config(&self) -> &ChatModelConfig {
995        <T as BaseChatModel>::chat_config(self)
996    }
997
998    async fn _generate(
999        &self,
1000        messages: Vec<BaseMessage>,
1001        stop: Option<Vec<String>>,
1002        run_manager: Option<&CallbackManagerForLLMRun>,
1003    ) -> Result<ChatResult> {
1004        let output_str = self._call(messages, stop, run_manager).await?;
1005        let message = AIMessage::new(output_str);
1006        let generation = ChatGeneration::new(message.into());
1007        Ok(ChatResult::new(vec![generation]))
1008    }
1009}
1010
1011/// Generate from a stream of chunks.
1012///
1013/// Collects all chunks from the stream and generates a final ChatResult.
1014///
1015/// This corresponds to `generate_from_stream` in LangChain Python.
1016///
1017/// # Arguments
1018///
1019/// * `stream` - An iterator of `ChatGenerationChunk` objects.
1020///
1021/// # Returns
1022///
1023/// A `ChatResult` containing the merged generation.
1024///
1025/// # Errors
1026///
1027/// Returns an error if no generations are found in the stream.
1028pub fn generate_from_stream<I>(mut stream: I) -> Result<ChatResult>
1029where
1030    I: Iterator<Item = ChatGenerationChunk>,
1031{
1032    let first = stream.next();
1033    if first.is_none() {
1034        return Err(Error::Other("No generations found in stream.".into()));
1035    }
1036
1037    let mut generation = first.unwrap();
1038
1039    // Merge remaining chunks
1040    for chunk in stream {
1041        generation = generation + chunk;
1042    }
1043
1044    // Convert ChatGenerationChunk to ChatGeneration
1045    let chat_generation: ChatGeneration = generation.into();
1046    Ok(ChatResult::new(vec![chat_generation]))
1047}
1048
1049/// Async generate from a stream of chunks.
1050///
1051/// Collects all chunks from an async stream and generates a final ChatResult.
1052///
1053/// This corresponds to `agenerate_from_stream` in LangChain Python.
1054///
1055/// # Arguments
1056///
1057/// * `stream` - An async stream of `ChatGenerationChunk` objects.
1058///
1059/// # Returns
1060///
1061/// A `ChatResult` containing the merged generation.
1062///
1063/// # Errors
1064///
1065/// Returns an error if no generations are found in the stream.
1066pub async fn agenerate_from_stream(
1067    stream: impl futures::Stream<Item = Result<ChatGenerationChunk>> + Unpin,
1068) -> Result<ChatResult> {
1069    use futures::StreamExt;
1070
1071    let chunks: Vec<ChatGenerationChunk> = stream
1072        .filter_map(|result| async { result.ok() })
1073        .collect()
1074        .await;
1075
1076    if chunks.is_empty() {
1077        return Err(Error::Other("No generations found in stream.".into()));
1078    }
1079
1080    generate_from_stream(chunks.into_iter())
1081}
1082
1083/// Collect a stream of ChatGenerationChunks and merge them.
1084///
1085/// This is a convenience function that collects all chunks from a stream
1086/// and returns the merged result.
1087///
1088/// # Arguments
1089///
1090/// * `stream` - An async stream of `ChatGenerationChunk` results.
1091///
1092/// # Returns
1093///
1094/// The merged `ChatGenerationChunk`, or `None` if the stream was empty.
1095pub async fn collect_and_merge_stream(
1096    mut stream: impl futures::StreamExt<Item = Result<ChatGenerationChunk>> + Unpin,
1097) -> Result<Option<ChatGenerationChunk>> {
1098    let mut chunks = Vec::new();
1099    while let Some(chunk_result) = stream.next().await {
1100        chunks.push(chunk_result?);
1101    }
1102
1103    if chunks.is_empty() {
1104        return Ok(None);
1105    }
1106
1107    Ok(crate::outputs::merge_chat_generation_chunks(chunks))
1108}
1109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods should populate the corresponding config fields.
    #[test]
    fn test_chat_model_config_builder() {
        let config = ChatModelConfig::new();
        let config = config.with_cache(true).with_verbose(true);
        let config = config
            .with_disable_streaming(true)
            .with_output_version("v1");

        assert_eq!(config.base.cache, Some(true));
        assert!(config.base.verbose);
        assert_eq!(config.disable_streaming, DisableStreaming::Bool(true));
        assert_eq!(config.output_version, Some("v1".to_string()));
    }

    #[test]
    fn test_tool_choice_auto() {
        let expected = ToolChoice::String("auto".to_string());
        assert_eq!(ToolChoice::auto(), expected);
    }

    #[test]
    fn test_tool_choice_any() {
        let expected = ToolChoice::String("any".to_string());
        assert_eq!(ToolChoice::any(), expected);
    }

    #[test]
    fn test_tool_choice_none() {
        let expected = ToolChoice::String("none".to_string());
        assert_eq!(ToolChoice::none(), expected);
    }

    #[test]
    fn test_tool_choice_tool() {
        let expected = ToolChoice::Structured {
            choice_type: "tool".to_string(),
            name: Some("my_tool".to_string()),
        };
        assert_eq!(ToolChoice::tool("my_tool"), expected);
    }

    /// `auto` serializes to a bare string; `tool` to a structured object.
    #[test]
    fn test_tool_choice_serialization() {
        let auto_json = serde_json::to_string(&ToolChoice::auto()).unwrap();
        assert_eq!(auto_json, "\"auto\"");

        let tool_json = serde_json::to_string(&ToolChoice::tool("my_tool")).unwrap();
        assert!(tool_json.contains("my_tool"));
        assert!(tool_json.contains("tool"));
    }

    /// Each case: (variant, disables with tool calls, disables without).
    #[test]
    fn test_disable_streaming() {
        let cases = vec![
            (DisableStreaming::Bool(false), false, false),
            (DisableStreaming::Bool(true), true, true),
            (DisableStreaming::ToolCalling, true, false),
        ];
        for (variant, with_tools, without_tools) in cases {
            assert_eq!(variant.should_disable(true), with_tools);
            assert_eq!(variant.should_disable(false), without_tools);
        }
    }

    /// Merging two text chunks yields one generation with concatenated text.
    #[test]
    fn test_generate_from_stream() {
        let first = ChatGenerationChunk::new(AIMessage::new("Hello, ").into());
        let second = ChatGenerationChunk::new(AIMessage::new("world!").into());

        let result = generate_from_stream(vec![first, second].into_iter()).unwrap();
        assert_eq!(result.generations.len(), 1);
        assert_eq!(result.generations[0].message.content(), "Hello, world!");
    }

    /// An empty stream is an error.
    #[test]
    fn test_generate_from_stream_empty() {
        let outcome = generate_from_stream(std::iter::empty::<ChatGenerationChunk>());
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_agenerate_from_stream() {
        let stream = futures::stream::iter(vec![
            Ok(ChatGenerationChunk::new(AIMessage::new("Hello, ").into())),
            Ok(ChatGenerationChunk::new(AIMessage::new("world!").into())),
        ]);

        let result = agenerate_from_stream(stream).await.unwrap();
        assert_eq!(result.generations.len(), 1);
        assert_eq!(result.generations[0].message.content(), "Hello, world!");
    }

    #[tokio::test]
    async fn test_collect_and_merge_stream() {
        let parts: Vec<Result<ChatGenerationChunk>> = vec!["a", "b", "c"]
            .into_iter()
            .map(|text| Ok(ChatGenerationChunk::new(AIMessage::new(text).into())))
            .collect();

        let merged = collect_and_merge_stream(futures::stream::iter(parts))
            .await
            .unwrap();
        let chunk = merged.expect("non-empty stream must merge to Some chunk");
        assert_eq!(chunk.text, "abc");
    }

    #[tokio::test]
    async fn test_collect_and_merge_stream_empty() {
        let empty: Vec<Result<ChatGenerationChunk>> = Vec::new();
        let merged = collect_and_merge_stream(futures::stream::iter(empty))
            .await
            .unwrap();
        assert!(merged.is_none());
    }
}