// agent_chain_core/language_models/base.rs

//! Base language model class.
//!
//! This module provides the foundational abstractions for language models,
//! mirroring `langchain_core.language_models.base`.

6use std::collections::HashMap;
7use std::pin::Pin;
8
9use async_trait::async_trait;
10use futures::Stream;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14use crate::caches::BaseCache;
15use crate::callbacks::Callbacks;
16use crate::error::Result;
17use crate::messages::{AIMessage, BaseMessage};
18use crate::outputs::LLMResult;
19
20/// Parameters for LangSmith tracing.
21///
22/// These parameters are used for tracing and monitoring language model calls.
23#[derive(Debug, Clone, Default, Serialize, Deserialize)]
24pub struct LangSmithParams {
25    /// Provider of the model (e.g., "anthropic", "openai").
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub ls_provider: Option<String>,
28
29    /// Name of the model.
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub ls_model_name: Option<String>,
32
33    /// Type of the model. Should be "chat" or "llm".
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub ls_model_type: Option<String>,
36
37    /// Temperature for generation.
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub ls_temperature: Option<f64>,
40
41    /// Max tokens for generation.
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub ls_max_tokens: Option<u32>,
44
45    /// Stop words for generation.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub ls_stop: Option<Vec<String>>,
48}
49
50impl LangSmithParams {
51    /// Create new LangSmith params with default values.
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    /// Set the provider.
57    pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
58        self.ls_provider = Some(provider.into());
59        self
60    }
61
62    /// Set the model name.
63    pub fn with_model_name(mut self, model_name: impl Into<String>) -> Self {
64        self.ls_model_name = Some(model_name.into());
65        self
66    }
67
68    /// Set the model type.
69    pub fn with_model_type(mut self, model_type: impl Into<String>) -> Self {
70        self.ls_model_type = Some(model_type.into());
71        self
72    }
73
74    /// Set the temperature.
75    pub fn with_temperature(mut self, temperature: f64) -> Self {
76        self.ls_temperature = Some(temperature);
77        self
78    }
79
80    /// Set the max tokens.
81    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
82        self.ls_max_tokens = Some(max_tokens);
83        self
84    }
85
86    /// Set the stop sequences.
87    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
88        self.ls_stop = Some(stop);
89        self
90    }
91}
92
93use crate::prompt_values::{ChatPromptValue, ImagePromptValue, StringPromptValue};
94
95/// Input to a language model.
96///
97/// Can be a string, a prompt value, or a sequence of messages.
98#[derive(Debug, Clone)]
99pub enum LanguageModelInput {
100    /// A simple string input.
101    Text(String),
102    /// A string prompt value.
103    StringPrompt(StringPromptValue),
104    /// A chat prompt value (messages).
105    ChatPrompt(ChatPromptValue),
106    /// An image prompt value.
107    ImagePrompt(ImagePromptValue),
108    /// A sequence of messages.
109    Messages(Vec<BaseMessage>),
110}
111
112impl From<String> for LanguageModelInput {
113    fn from(s: String) -> Self {
114        LanguageModelInput::Text(s)
115    }
116}
117
118impl From<&str> for LanguageModelInput {
119    fn from(s: &str) -> Self {
120        LanguageModelInput::Text(s.to_string())
121    }
122}
123
124impl From<StringPromptValue> for LanguageModelInput {
125    fn from(p: StringPromptValue) -> Self {
126        LanguageModelInput::StringPrompt(p)
127    }
128}
129
130impl From<ChatPromptValue> for LanguageModelInput {
131    fn from(p: ChatPromptValue) -> Self {
132        LanguageModelInput::ChatPrompt(p)
133    }
134}
135
136impl From<ImagePromptValue> for LanguageModelInput {
137    fn from(p: ImagePromptValue) -> Self {
138        LanguageModelInput::ImagePrompt(p)
139    }
140}
141
142impl From<Vec<BaseMessage>> for LanguageModelInput {
143    fn from(m: Vec<BaseMessage>) -> Self {
144        LanguageModelInput::Messages(m)
145    }
146}
147
148impl LanguageModelInput {
149    /// Convert the input to messages.
150    pub fn to_messages(&self) -> Vec<BaseMessage> {
151        use crate::prompt_values::PromptValue;
152        match self {
153            LanguageModelInput::Text(s) => {
154                vec![BaseMessage::Human(crate::messages::HumanMessage::new(s))]
155            }
156            LanguageModelInput::StringPrompt(p) => p.to_messages(),
157            LanguageModelInput::ChatPrompt(p) => p.to_messages(),
158            LanguageModelInput::ImagePrompt(p) => p.to_messages(),
159            LanguageModelInput::Messages(m) => m.clone(),
160        }
161    }
162}
163
164impl std::fmt::Display for LanguageModelInput {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        use crate::prompt_values::PromptValue;
167        match self {
168            LanguageModelInput::Text(s) => write!(f, "{}", s),
169            LanguageModelInput::StringPrompt(p) => write!(f, "{}", PromptValue::to_string(p)),
170            LanguageModelInput::ChatPrompt(p) => write!(f, "{}", PromptValue::to_string(p)),
171            LanguageModelInput::ImagePrompt(p) => write!(f, "{}", PromptValue::to_string(p)),
172            LanguageModelInput::Messages(m) => {
173                let joined = m
174                    .iter()
175                    .map(|msg| format!("{}: {}", msg.message_type(), msg.content()))
176                    .collect::<Vec<_>>()
177                    .join("\n");
178                write!(f, "{}", joined)
179            }
180        }
181    }
182}
183
184/// Output from a language model.
185///
186/// Can be either a message (from chat models) or a string (from LLMs).
187#[derive(Debug, Clone)]
188pub enum LanguageModelOutput {
189    /// A message output (from chat models).
190    Message(Box<AIMessage>),
191    /// A string output (from LLMs).
192    Text(String),
193}
194
195impl From<AIMessage> for LanguageModelOutput {
196    fn from(m: AIMessage) -> Self {
197        LanguageModelOutput::Message(Box::new(m))
198    }
199}
200
201impl From<String> for LanguageModelOutput {
202    fn from(s: String) -> Self {
203        LanguageModelOutput::Text(s)
204    }
205}
206
207impl LanguageModelOutput {
208    /// Get the text content of the output.
209    pub fn text(&self) -> &str {
210        match self {
211            LanguageModelOutput::Message(m) => m.content(),
212            LanguageModelOutput::Text(s) => s,
213        }
214    }
215
216    /// Convert to string, consuming the output.
217    pub fn into_text(self) -> String {
218        match self {
219            LanguageModelOutput::Message(m) => m.content().to_string(),
220            LanguageModelOutput::Text(s) => s,
221        }
222    }
223
224    /// Create a Message variant from an AIMessage.
225    pub fn message(m: AIMessage) -> Self {
226        LanguageModelOutput::Message(Box::new(m))
227    }
228}
229
230/// Configuration for a language model.
231#[derive(Debug, Clone, Default, Serialize, Deserialize)]
232pub struct LanguageModelConfig {
233    /// Whether to cache the response.
234    ///
235    /// - If `true`, will use the global cache.
236    /// - If `false`, will not use a cache.
237    /// - If not set (`None`), will use the global cache if it's set.
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub cache: Option<bool>,
240
241    /// Whether to print verbose output.
242    #[serde(default)]
243    pub verbose: bool,
244
245    /// Tags to add to the run trace.
246    #[serde(skip_serializing_if = "Option::is_none")]
247    pub tags: Option<Vec<String>>,
248
249    /// Metadata to add to the run trace.
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub metadata: Option<HashMap<String, Value>>,
252}
253
254impl LanguageModelConfig {
255    /// Create a new configuration with defaults.
256    pub fn new() -> Self {
257        Self::default()
258    }
259
260    /// Enable caching.
261    pub fn with_cache(mut self, cache: bool) -> Self {
262        self.cache = Some(cache);
263        self
264    }
265
266    /// Enable verbose mode.
267    pub fn with_verbose(mut self, verbose: bool) -> Self {
268        self.verbose = verbose;
269        self
270    }
271
272    /// Set tags.
273    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
274        self.tags = Some(tags);
275        self
276    }
277
278    /// Set metadata.
279    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
280        self.metadata = Some(metadata);
281        self
282    }
283}
284
285/// Abstract base trait for interfacing with language models.
286///
287/// All language model wrappers inherit from `BaseLanguageModel`.
288/// This trait provides common functionality for both chat models and traditional LLMs.
289#[async_trait]
290pub trait BaseLanguageModel: Send + Sync {
291    /// Return the type identifier for this language model.
292    ///
293    /// This is used for logging and tracing purposes.
294    fn llm_type(&self) -> &str;
295
296    /// Get the model name/identifier.
297    fn model_name(&self) -> &str;
298
299    /// Get the configuration for this model.
300    fn config(&self) -> &LanguageModelConfig;
301
302    /// Get the cache for this model, if any.
303    fn cache(&self) -> Option<&dyn BaseCache> {
304        None
305    }
306
307    /// Get the callbacks for this model.
308    fn callbacks(&self) -> Option<&Callbacks> {
309        None
310    }
311
312    /// Pass a sequence of prompts to the model and return model generations.
313    ///
314    /// This method should make use of batched calls for models that expose a batched API.
315    ///
316    /// # Arguments
317    ///
318    /// * `prompts` - List of `PromptValue` objects.
319    /// * `stop` - Stop words to use when generating.
320    /// * `callbacks` - Callbacks to pass through.
321    ///
322    /// # Returns
323    ///
324    /// An `LLMResult`, which contains a list of candidate `Generation` objects.
325    async fn generate_prompt(
326        &self,
327        prompts: Vec<LanguageModelInput>,
328        stop: Option<Vec<String>>,
329        callbacks: Option<Callbacks>,
330    ) -> Result<LLMResult>;
331
332    /// Get parameters for tracing/monitoring.
333    fn get_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
334        let mut params = LangSmithParams::new();
335
336        // Try to determine provider from class name
337        let llm_type = self.llm_type();
338        let provider = if llm_type.starts_with("Chat") {
339            llm_type
340                .strip_prefix("Chat")
341                .unwrap_or(llm_type)
342                .to_lowercase()
343        } else if llm_type.ends_with("Chat") {
344            llm_type
345                .strip_suffix("Chat")
346                .unwrap_or(llm_type)
347                .to_lowercase()
348        } else {
349            llm_type.to_lowercase()
350        };
351
352        params.ls_provider = Some(provider);
353        params.ls_model_name = Some(self.model_name().to_string());
354
355        if let Some(stop) = stop {
356            params.ls_stop = Some(stop.to_vec());
357        }
358
359        params
360    }
361
362    /// Get the identifying parameters for this model.
363    fn identifying_params(&self) -> HashMap<String, Value> {
364        let mut params = HashMap::new();
365        params.insert(
366            "_type".to_string(),
367            Value::String(self.llm_type().to_string()),
368        );
369        params.insert(
370            "model".to_string(),
371            Value::String(self.model_name().to_string()),
372        );
373        params
374    }
375
376    /// Get the ordered IDs of tokens in a text.
377    ///
378    /// # Arguments
379    ///
380    /// * `text` - The string input to tokenize.
381    ///
382    /// # Returns
383    ///
384    /// A list of token IDs.
385    fn get_token_ids(&self, text: &str) -> Vec<u32> {
386        // Default implementation: rough estimate based on whitespace
387        // Actual implementations should use proper tokenizers
388        text.split_whitespace()
389            .enumerate()
390            .map(|(i, _)| i as u32)
391            .collect()
392    }
393
394    /// Get the number of tokens present in the text.
395    ///
396    /// # Arguments
397    ///
398    /// * `text` - The string input to tokenize.
399    ///
400    /// # Returns
401    ///
402    /// The number of tokens in the text.
403    fn get_num_tokens(&self, text: &str) -> usize {
404        self.get_token_ids(text).len()
405    }
406
407    /// Get the number of tokens in the messages.
408    ///
409    /// # Arguments
410    ///
411    /// * `messages` - The message inputs to tokenize.
412    ///
413    /// # Returns
414    ///
415    /// The sum of the number of tokens across the messages.
416    fn get_num_tokens_from_messages(&self, messages: &[BaseMessage]) -> usize {
417        messages
418            .iter()
419            .map(|m| {
420                // Add some tokens for the message role/type
421                let role_tokens = 4; // Approximate overhead for role
422                let content_tokens = self.get_num_tokens(m.content());
423                role_tokens + content_tokens
424            })
425            .sum()
426    }
427}
428
/// Type alias for a boxed language model output stream.
///
/// A pinned, boxed, `Send` stream yielding `Result<LanguageModelOutput>`
/// items, suitable for returning from trait methods that stream outputs.
#[allow(dead_code)]
pub type LanguageModelOutputStream =
    Pin<Box<dyn Stream<Item = Result<LanguageModelOutput>> + Send>>;
433
#[cfg(test)]
mod tests {
    use super::*;

    /// All builder setters should land in their corresponding fields.
    #[test]
    fn test_langsmith_params_builder() {
        let p = LangSmithParams::new()
            .with_provider("openai")
            .with_model_name("gpt-4")
            .with_model_type("chat")
            .with_temperature(0.7)
            .with_max_tokens(1000)
            .with_stop(vec!["STOP".to_string()]);

        assert_eq!(p.ls_provider.as_deref(), Some("openai"));
        assert_eq!(p.ls_model_name.as_deref(), Some("gpt-4"));
        assert_eq!(p.ls_model_type.as_deref(), Some("chat"));
        assert_eq!(p.ls_temperature, Some(0.7));
        assert_eq!(p.ls_max_tokens, Some(1000));
        assert_eq!(p.ls_stop, Some(vec!["STOP".to_string()]));
    }

    /// A `&str` converts into the `Text` variant.
    #[test]
    fn test_language_model_input_from_str() {
        let input = LanguageModelInput::from("Hello");
        match input {
            LanguageModelInput::Text(s) => assert_eq!(s, "Hello"),
            other => panic!("Expected Text variant, got {:?}", other),
        }
    }

    /// Both the borrowing and consuming accessors return the text.
    #[test]
    fn test_language_model_output_text() {
        let out = LanguageModelOutput::Text(String::from("Hello"));
        assert_eq!(out.text(), "Hello");
        assert_eq!(out.into_text(), "Hello");
    }

    /// Config builder setters should land in their corresponding fields.
    #[test]
    fn test_language_model_config_builder() {
        let cfg = LanguageModelConfig::new()
            .with_cache(true)
            .with_verbose(true)
            .with_tags(vec!["test".to_string()]);

        assert_eq!(cfg.cache, Some(true));
        assert!(cfg.verbose);
        assert_eq!(cfg.tags, Some(vec!["test".to_string()]));
    }
}