// agent_chain_core/language_models/llms.rs
use std::collections::HashMap;
8use std::pin::Pin;
9
10use async_trait::async_trait;
11use futures::Stream;
12use serde_json::Value;
13
14use super::base::{BaseLanguageModel, LangSmithParams, LanguageModelConfig, LanguageModelInput};
15use crate::callbacks::CallbackManagerForLLMRun;
16use crate::error::Result;
17use crate::outputs::{Generation, GenerationChunk, GenerationType, LLMResult};
18use crate::prompt_values::PromptValue;
19
/// Pinned, boxed, `Send` stream of [`GenerationChunk`] results, as returned
/// by [`BaseLLM::stream_prompt`].
pub type LLMStream = Pin<Box<dyn Stream<Item = Result<GenerationChunk>> + Send>>;
22
/// Configuration for string-in/string-out LLM implementations.
///
/// Wraps the shared [`LanguageModelConfig`] and is populated through the
/// builder-style `with_*` methods on [`LLMConfig`].
#[derive(Debug, Clone, Default)]
pub struct LLMConfig {
    // Common language-model settings (cache, verbose, tags, metadata).
    pub base: LanguageModelConfig,
}
29
30impl LLMConfig {
31 pub fn new() -> Self {
33 Self::default()
34 }
35
36 pub fn with_cache(mut self, cache: bool) -> Self {
38 self.base.cache = Some(cache);
39 self
40 }
41
42 pub fn with_verbose(mut self, verbose: bool) -> Self {
44 self.base.verbose = verbose;
45 self
46 }
47
48 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
50 self.base.tags = Some(tags);
51 self
52 }
53
54 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
56 self.base.metadata = Some(metadata);
57 self
58 }
59}
60
61fn extract_text(generation: &GenerationType) -> String {
63 match generation {
64 GenerationType::Generation(g) => g.text.clone(),
65 GenerationType::GenerationChunk(g) => g.text.clone(),
66 GenerationType::ChatGeneration(g) => g.text.to_string(),
67 GenerationType::ChatGenerationChunk(g) => g.text.to_string(),
68 }
69}
70
71#[async_trait]
83pub trait BaseLLM: BaseLanguageModel {
84 fn llm_config(&self) -> &LLMConfig;
86
87 async fn generate_prompts(
99 &self,
100 prompts: Vec<String>,
101 stop: Option<Vec<String>>,
102 run_manager: Option<&CallbackManagerForLLMRun>,
103 ) -> Result<LLMResult>;
104
105 async fn stream_prompt(
120 &self,
121 prompt: String,
122 stop: Option<Vec<String>>,
123 run_manager: Option<&CallbackManagerForLLMRun>,
124 ) -> Result<LLMStream> {
125 let result = self
126 .generate_prompts(vec![prompt], stop, run_manager)
127 .await?;
128
129 if let Some(generations) = result.generations.first()
131 && let Some(generation) = generations.first()
132 {
133 let text = extract_text(generation);
134 let chunk = GenerationChunk::new(text);
135 return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
136 }
137
138 Ok(Box::pin(futures::stream::empty()))
140 }
141
142 fn convert_input(&self, input: LanguageModelInput) -> Result<String> {
144 match input {
145 LanguageModelInput::Text(s) => Ok(s),
146 LanguageModelInput::StringPrompt(p) => Ok(p.to_string()),
147 LanguageModelInput::ChatPrompt(p) => {
148 let messages = p.to_messages();
150 let parts: Vec<String> = messages
151 .iter()
152 .map(|msg| format!("{}: {}", msg.message_type(), msg.content()))
153 .collect();
154 Ok(parts.join("\n"))
155 }
156 LanguageModelInput::ImagePrompt(p) => Ok(p.image_url.url.clone().unwrap_or_default()),
157 LanguageModelInput::Messages(m) => {
158 let parts: Vec<String> = m
160 .iter()
161 .map(|msg| format!("{}: {}", msg.message_type(), msg.content()))
162 .collect();
163 Ok(parts.join("\n"))
164 }
165 }
166 }
167
168 async fn invoke(&self, input: LanguageModelInput) -> Result<String> {
170 let prompt = self.convert_input(input)?;
171 let result = self.generate_prompts(vec![prompt], None, None).await?;
172
173 if let Some(generations) = result.generations.first()
175 && let Some(generation) = generations.first()
176 {
177 return Ok(extract_text(generation));
178 }
179
180 Ok(String::new())
181 }
182
183 fn get_llm_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
185 let mut params = self.get_ls_params(stop);
186 params.ls_model_type = Some("llm".to_string());
187 params
188 }
189}
190
/// Trait for simple LLM implementations built on [`BaseLLM`].
#[async_trait]
pub trait LLM: BaseLLM {
    /// Runs the model on a single prompt and returns the generated text.
    ///
    /// `stop` lists sequences at which generation should halt; `run_manager`
    /// receives callback events for this run when provided.
    async fn call(
        &self,
        prompt: String,
        stop: Option<Vec<String>>,
        run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<String>;
}
224
225pub fn get_prompts_from_cache(
229 params: &HashMap<String, Value>,
230 prompts: &[String],
231 cache: Option<&dyn crate::caches::BaseCache>,
232) -> (
233 HashMap<usize, Vec<Generation>>,
234 String,
235 Vec<usize>,
236 Vec<String>,
237) {
238 let llm_string = serde_json::to_string(¶ms).unwrap_or_default();
239 let mut existing_prompts = HashMap::new();
240 let mut missing_prompt_idxs = Vec::new();
241 let mut missing_prompts = Vec::new();
242
243 if let Some(cache) = cache {
244 for (i, prompt) in prompts.iter().enumerate() {
245 if let Some(cached) = cache.lookup(prompt, &llm_string) {
246 existing_prompts.insert(i, cached);
247 } else {
248 missing_prompts.push(prompt.clone());
249 missing_prompt_idxs.push(i);
250 }
251 }
252 } else {
253 for (i, prompt) in prompts.iter().enumerate() {
255 missing_prompts.push(prompt.clone());
256 missing_prompt_idxs.push(i);
257 }
258 }
259
260 (
261 existing_prompts,
262 llm_string,
263 missing_prompt_idxs,
264 missing_prompts,
265 )
266}
267
268pub fn update_cache(
270 cache: Option<&dyn crate::caches::BaseCache>,
271 existing_prompts: &mut HashMap<usize, Vec<Generation>>,
272 llm_string: &str,
273 missing_prompt_idxs: &[usize],
274 new_results: &LLMResult,
275 prompts: &[String],
276) -> Option<HashMap<String, Value>> {
277 if let Some(cache) = cache {
278 for (i, result) in new_results.generations.iter().enumerate() {
279 if let Some(&idx) = missing_prompt_idxs.get(i) {
280 let generations: Vec<Generation> = result
281 .iter()
282 .filter_map(|g| match g {
283 GenerationType::Generation(generation) => Some(generation.clone()),
284 GenerationType::GenerationChunk(chunk) => Some(chunk.clone().into()),
285 _ => None,
286 })
287 .collect();
288
289 existing_prompts.insert(idx, generations.clone());
290
291 if let Some(prompt) = prompts.get(idx) {
292 cache.update(prompt, llm_string, generations);
293 }
294 }
295 }
296 }
297
298 new_results.llm_output.clone()
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods should set the corresponding fields on `base`.
    #[test]
    fn test_llm_config_builder() {
        let config = LLMConfig::new()
            .with_cache(true)
            .with_verbose(true)
            .with_tags(vec!["test".to_string()]);

        assert_eq!(config.base.cache, Some(true));
        assert!(config.base.verbose);
        assert_eq!(config.base.tags, Some(vec!["test".to_string()]));
    }

    /// With no cache, every prompt must be reported missing, in order.
    #[test]
    fn test_get_prompts_from_cache_no_cache() {
        let params = HashMap::new();
        let prompts = vec!["Hello".to_string(), "World".to_string()];

        // BUG FIX: the call previously read `¶ms` (mojibake for
        // `&params`), which did not compile.
        let (existing, _llm_string, missing_idxs, missing) =
            get_prompts_from_cache(&params, &prompts, None);

        assert!(existing.is_empty());
        assert_eq!(missing_idxs, vec![0, 1]);
        assert_eq!(missing, prompts);
    }
}