// agent_chain_core/language_models/llms.rs

1//! Base interface for traditional large language models (LLMs).
2//!
3//! These are traditionally older models (newer models generally are chat models).
4//! LLMs take a string as input and return a string as output.
5//! Mirrors `langchain_core.language_models.llms`.
6
7use std::collections::HashMap;
8use std::pin::Pin;
9
10use async_trait::async_trait;
11use futures::Stream;
12use serde_json::Value;
13
14use super::base::{BaseLanguageModel, LangSmithParams, LanguageModelConfig, LanguageModelInput};
15use crate::callbacks::CallbackManagerForLLMRun;
16use crate::error::Result;
17use crate::outputs::{Generation, GenerationChunk, GenerationType, LLMResult};
18use crate::prompt_values::PromptValue;
19
/// Type alias for a streaming LLM output.
///
/// A pinned, heap-allocated stream yielding [`GenerationChunk`] results.
/// The `Send` bound lets the stream be polled from any async task.
pub type LLMStream = Pin<Box<dyn Stream<Item = Result<GenerationChunk>> + Send>>;
22
23/// Configuration specific to LLMs.
24#[derive(Debug, Clone, Default)]
25pub struct LLMConfig {
26    /// Base language model configuration.
27    pub base: LanguageModelConfig,
28}
29
30impl LLMConfig {
31    /// Create a new LLM configuration.
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Enable caching.
37    pub fn with_cache(mut self, cache: bool) -> Self {
38        self.base.cache = Some(cache);
39        self
40    }
41
42    /// Enable verbose mode.
43    pub fn with_verbose(mut self, verbose: bool) -> Self {
44        self.base.verbose = verbose;
45        self
46    }
47
48    /// Set tags.
49    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
50        self.base.tags = Some(tags);
51        self
52    }
53
54    /// Set metadata.
55    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
56        self.base.metadata = Some(metadata);
57        self
58    }
59}
60
61/// Helper function to extract text from a GenerationType.
62fn extract_text(generation: &GenerationType) -> String {
63    match generation {
64        GenerationType::Generation(g) => g.text.clone(),
65        GenerationType::GenerationChunk(g) => g.text.clone(),
66        GenerationType::ChatGeneration(g) => g.text.to_string(),
67        GenerationType::ChatGenerationChunk(g) => g.text.to_string(),
68    }
69}
70
71/// Base LLM abstract interface.
72///
73/// It should take in a prompt and return a string.
74///
75/// # Implementation Guide
76///
77/// | Method/Property         | Description                                        | Required |
78/// |------------------------|----------------------------------------------------|---------:|
79/// | `generate_prompts`     | Use to generate from prompts                       | Required |
80/// | `llm_type` (property)  | Used to uniquely identify the type of the model    | Required |
81/// | `stream_prompt`        | Use to implement streaming                         | Optional |
82#[async_trait]
83pub trait BaseLLM: BaseLanguageModel {
84    /// Get the LLM configuration.
85    fn llm_config(&self) -> &LLMConfig;
86
87    /// Run the LLM on the given prompts.
88    ///
89    /// # Arguments
90    ///
91    /// * `prompts` - The prompts to generate from.
92    /// * `stop` - Stop words to use when generating.
93    /// * `run_manager` - Callback manager for the run.
94    ///
95    /// # Returns
96    ///
97    /// The LLM result.
98    async fn generate_prompts(
99        &self,
100        prompts: Vec<String>,
101        stop: Option<Vec<String>>,
102        run_manager: Option<&CallbackManagerForLLMRun>,
103    ) -> Result<LLMResult>;
104
105    /// Stream the LLM on the given prompt.
106    ///
107    /// Default implementation falls back to `generate_prompts` and returns
108    /// the output as a single chunk.
109    ///
110    /// # Arguments
111    ///
112    /// * `prompt` - The prompt to generate from.
113    /// * `stop` - Stop words to use when generating.
114    /// * `run_manager` - Callback manager for the run.
115    ///
116    /// # Returns
117    ///
118    /// A stream of generation chunks.
119    async fn stream_prompt(
120        &self,
121        prompt: String,
122        stop: Option<Vec<String>>,
123        run_manager: Option<&CallbackManagerForLLMRun>,
124    ) -> Result<LLMStream> {
125        let result = self
126            .generate_prompts(vec![prompt], stop, run_manager)
127            .await?;
128
129        // Get the first generation
130        if let Some(generations) = result.generations.first()
131            && let Some(generation) = generations.first()
132        {
133            let text = extract_text(generation);
134            let chunk = GenerationChunk::new(text);
135            return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
136        }
137
138        // Empty result
139        Ok(Box::pin(futures::stream::empty()))
140    }
141
142    /// Convert input to a prompt string.
143    fn convert_input(&self, input: LanguageModelInput) -> Result<String> {
144        match input {
145            LanguageModelInput::Text(s) => Ok(s),
146            LanguageModelInput::StringPrompt(p) => Ok(p.to_string()),
147            LanguageModelInput::ChatPrompt(p) => {
148                // Convert chat prompt to string representation
149                let messages = p.to_messages();
150                let parts: Vec<String> = messages
151                    .iter()
152                    .map(|msg| format!("{}: {}", msg.message_type(), msg.content()))
153                    .collect();
154                Ok(parts.join("\n"))
155            }
156            LanguageModelInput::ImagePrompt(p) => Ok(p.image_url.url.clone().unwrap_or_default()),
157            LanguageModelInput::Messages(m) => {
158                // Convert messages to a string representation
159                let parts: Vec<String> = m
160                    .iter()
161                    .map(|msg| format!("{}: {}", msg.message_type(), msg.content()))
162                    .collect();
163                Ok(parts.join("\n"))
164            }
165        }
166    }
167
168    /// Invoke the model with input.
169    async fn invoke(&self, input: LanguageModelInput) -> Result<String> {
170        let prompt = self.convert_input(input)?;
171        let result = self.generate_prompts(vec![prompt], None, None).await?;
172
173        // Get the first generation's text
174        if let Some(generations) = result.generations.first()
175            && let Some(generation) = generations.first()
176        {
177            return Ok(extract_text(generation));
178        }
179
180        Ok(String::new())
181    }
182
183    /// Get standard params for tracing.
184    fn get_llm_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
185        let mut params = self.get_ls_params(stop);
186        params.ls_model_type = Some("llm".to_string());
187        params
188    }
189}
190
/// Simple interface for implementing a custom LLM.
///
/// Implementors of this trait provide the following:
///
/// - `call` method: Run the LLM on the given prompt.
/// - `llm_type` property: Return a unique identifier for this LLM.
/// - `identifying_params` property: Return identifying parameters for caching/tracing.
///
/// Optional: Override the following methods for more optimizations:
///
/// - `acall`: Provide a native async version of `call`.
/// - `stream_prompt`: Stream the LLM output.
/// - `astream_prompt`: Async version of streaming.
#[async_trait]
pub trait LLM: BaseLLM {
    /// Run the LLM on the given input.
    ///
    /// # Arguments
    ///
    /// * `prompt` - The prompt to generate from.
    /// * `stop` - Stop words to use when generating.
    /// * `run_manager` - Callback manager for the run.
    ///
    /// # Returns
    ///
    /// The model output as a string. Should NOT include the prompt.
    async fn call(
        &self,
        prompt: String,
        stop: Option<Vec<String>>,
        run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<String>;
}
224
225/// Helper function to get prompts from cache.
226///
227/// Returns existing prompts, llm string, missing prompt indices, and missing prompts.
228pub fn get_prompts_from_cache(
229    params: &HashMap<String, Value>,
230    prompts: &[String],
231    cache: Option<&dyn crate::caches::BaseCache>,
232) -> (
233    HashMap<usize, Vec<Generation>>,
234    String,
235    Vec<usize>,
236    Vec<String>,
237) {
238    let llm_string = serde_json::to_string(&params).unwrap_or_default();
239    let mut existing_prompts = HashMap::new();
240    let mut missing_prompt_idxs = Vec::new();
241    let mut missing_prompts = Vec::new();
242
243    if let Some(cache) = cache {
244        for (i, prompt) in prompts.iter().enumerate() {
245            if let Some(cached) = cache.lookup(prompt, &llm_string) {
246                existing_prompts.insert(i, cached);
247            } else {
248                missing_prompts.push(prompt.clone());
249                missing_prompt_idxs.push(i);
250            }
251        }
252    } else {
253        // No cache, all prompts are missing
254        for (i, prompt) in prompts.iter().enumerate() {
255            missing_prompts.push(prompt.clone());
256            missing_prompt_idxs.push(i);
257        }
258    }
259
260    (
261        existing_prompts,
262        llm_string,
263        missing_prompt_idxs,
264        missing_prompts,
265    )
266}
267
268/// Helper function to update cache with new results.
269pub fn update_cache(
270    cache: Option<&dyn crate::caches::BaseCache>,
271    existing_prompts: &mut HashMap<usize, Vec<Generation>>,
272    llm_string: &str,
273    missing_prompt_idxs: &[usize],
274    new_results: &LLMResult,
275    prompts: &[String],
276) -> Option<HashMap<String, Value>> {
277    if let Some(cache) = cache {
278        for (i, result) in new_results.generations.iter().enumerate() {
279            if let Some(&idx) = missing_prompt_idxs.get(i) {
280                let generations: Vec<Generation> = result
281                    .iter()
282                    .filter_map(|g| match g {
283                        GenerationType::Generation(generation) => Some(generation.clone()),
284                        GenerationType::GenerationChunk(chunk) => Some(chunk.clone().into()),
285                        _ => None,
286                    })
287                    .collect();
288
289                existing_prompts.insert(idx, generations.clone());
290
291                if let Some(prompt) = prompts.get(idx) {
292                    cache.update(prompt, llm_string, generations);
293                }
294            }
295        }
296    }
297
298    new_results.llm_output.clone()
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods should land their values on the base config.
    #[test]
    fn test_llm_config_builder() {
        let tags = vec!["test".to_string()];
        let cfg = LLMConfig::new()
            .with_cache(true)
            .with_verbose(true)
            .with_tags(tags.clone());

        assert!(cfg.base.verbose);
        assert_eq!(cfg.base.cache, Some(true));
        assert_eq!(cfg.base.tags, Some(tags));
    }

    /// Without a cache, every prompt must be reported as missing, in order.
    #[test]
    fn test_get_prompts_from_cache_no_cache() {
        let prompts: Vec<String> = ["Hello", "World"].iter().map(|s| s.to_string()).collect();

        let (existing, _llm_string, missing_idxs, missing) =
            get_prompts_from_cache(&HashMap::new(), &prompts, None);

        assert!(existing.is_empty());
        assert_eq!(missing_idxs, vec![0, 1]);
        assert_eq!(missing, prompts);
    }
}
329}