// agent_chain_core/language_models/base.rs
use std::collections::HashMap;
7use std::pin::Pin;
8
9use async_trait::async_trait;
10use futures::Stream;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14use crate::caches::BaseCache;
15use crate::callbacks::Callbacks;
16use crate::error::Result;
17use crate::messages::{AIMessage, BaseMessage};
18use crate::outputs::LLMResult;
19
/// Tracing metadata reported to LangSmith for a single model invocation.
///
/// All fields are optional; unset fields are omitted from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LangSmithParams {
    /// Provider name, e.g. "openai" (derived from the LLM type by default).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_provider: Option<String>,

    /// Concrete model name, e.g. "gpt-4".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_model_name: Option<String>,

    /// Model type label, e.g. "chat".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_model_type: Option<String>,

    /// Sampling temperature used for the call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_temperature: Option<f64>,

    /// Maximum number of tokens requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_max_tokens: Option<u32>,

    /// Stop sequences supplied for the call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ls_stop: Option<Vec<String>>,
}
49
50impl LangSmithParams {
51 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
58 self.ls_provider = Some(provider.into());
59 self
60 }
61
62 pub fn with_model_name(mut self, model_name: impl Into<String>) -> Self {
64 self.ls_model_name = Some(model_name.into());
65 self
66 }
67
68 pub fn with_model_type(mut self, model_type: impl Into<String>) -> Self {
70 self.ls_model_type = Some(model_type.into());
71 self
72 }
73
74 pub fn with_temperature(mut self, temperature: f64) -> Self {
76 self.ls_temperature = Some(temperature);
77 self
78 }
79
80 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
82 self.ls_max_tokens = Some(max_tokens);
83 self
84 }
85
86 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
88 self.ls_stop = Some(stop);
89 self
90 }
91}
92
93use crate::prompt_values::{ChatPromptValue, ImagePromptValue, StringPromptValue};
94
/// Any value a language model can accept as input.
#[derive(Debug, Clone)]
pub enum LanguageModelInput {
    /// A plain text prompt.
    Text(String),
    /// A string prompt value.
    StringPrompt(StringPromptValue),
    /// A chat prompt value.
    ChatPrompt(ChatPromptValue),
    /// A prompt value carrying image content.
    ImagePrompt(ImagePromptValue),
    /// An explicit list of chat messages.
    Messages(Vec<BaseMessage>),
}
111
112impl From<String> for LanguageModelInput {
113 fn from(s: String) -> Self {
114 LanguageModelInput::Text(s)
115 }
116}
117
118impl From<&str> for LanguageModelInput {
119 fn from(s: &str) -> Self {
120 LanguageModelInput::Text(s.to_string())
121 }
122}
123
124impl From<StringPromptValue> for LanguageModelInput {
125 fn from(p: StringPromptValue) -> Self {
126 LanguageModelInput::StringPrompt(p)
127 }
128}
129
130impl From<ChatPromptValue> for LanguageModelInput {
131 fn from(p: ChatPromptValue) -> Self {
132 LanguageModelInput::ChatPrompt(p)
133 }
134}
135
136impl From<ImagePromptValue> for LanguageModelInput {
137 fn from(p: ImagePromptValue) -> Self {
138 LanguageModelInput::ImagePrompt(p)
139 }
140}
141
142impl From<Vec<BaseMessage>> for LanguageModelInput {
143 fn from(m: Vec<BaseMessage>) -> Self {
144 LanguageModelInput::Messages(m)
145 }
146}
147
148impl LanguageModelInput {
149 pub fn to_messages(&self) -> Vec<BaseMessage> {
151 use crate::prompt_values::PromptValue;
152 match self {
153 LanguageModelInput::Text(s) => {
154 vec![BaseMessage::Human(crate::messages::HumanMessage::new(s))]
155 }
156 LanguageModelInput::StringPrompt(p) => p.to_messages(),
157 LanguageModelInput::ChatPrompt(p) => p.to_messages(),
158 LanguageModelInput::ImagePrompt(p) => p.to_messages(),
159 LanguageModelInput::Messages(m) => m.clone(),
160 }
161 }
162}
163
164impl std::fmt::Display for LanguageModelInput {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 use crate::prompt_values::PromptValue;
167 match self {
168 LanguageModelInput::Text(s) => write!(f, "{}", s),
169 LanguageModelInput::StringPrompt(p) => write!(f, "{}", PromptValue::to_string(p)),
170 LanguageModelInput::ChatPrompt(p) => write!(f, "{}", PromptValue::to_string(p)),
171 LanguageModelInput::ImagePrompt(p) => write!(f, "{}", PromptValue::to_string(p)),
172 LanguageModelInput::Messages(m) => {
173 let joined = m
174 .iter()
175 .map(|msg| format!("{}: {}", msg.message_type(), msg.content()))
176 .collect::<Vec<_>>()
177 .join("\n");
178 write!(f, "{}", joined)
179 }
180 }
181 }
182}
183
/// Output of a language model call: either a structured AI message or text.
#[derive(Debug, Clone)]
pub enum LanguageModelOutput {
    /// A structured AI message (boxed to keep the enum small).
    Message(Box<AIMessage>),
    /// Raw text output.
    Text(String),
}
194
195impl From<AIMessage> for LanguageModelOutput {
196 fn from(m: AIMessage) -> Self {
197 LanguageModelOutput::Message(Box::new(m))
198 }
199}
200
201impl From<String> for LanguageModelOutput {
202 fn from(s: String) -> Self {
203 LanguageModelOutput::Text(s)
204 }
205}
206
207impl LanguageModelOutput {
208 pub fn text(&self) -> &str {
210 match self {
211 LanguageModelOutput::Message(m) => m.content(),
212 LanguageModelOutput::Text(s) => s,
213 }
214 }
215
216 pub fn into_text(self) -> String {
218 match self {
219 LanguageModelOutput::Message(m) => m.content().to_string(),
220 LanguageModelOutput::Text(s) => s,
221 }
222 }
223
224 pub fn message(m: AIMessage) -> Self {
226 LanguageModelOutput::Message(Box::new(m))
227 }
228}
229
/// Configuration options shared by all language models.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LanguageModelConfig {
    /// Whether to cache responses. `None` means unset — presumably the
    /// implementation falls back to a default caching policy; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache: Option<bool>,

    /// Whether to emit verbose output. Defaults to `false` on deserialize.
    #[serde(default)]
    pub verbose: bool,

    /// Tags attached to runs (for tracing/filtering).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,

    /// Arbitrary JSON metadata attached to runs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
253
254impl LanguageModelConfig {
255 pub fn new() -> Self {
257 Self::default()
258 }
259
260 pub fn with_cache(mut self, cache: bool) -> Self {
262 self.cache = Some(cache);
263 self
264 }
265
266 pub fn with_verbose(mut self, verbose: bool) -> Self {
268 self.verbose = verbose;
269 self
270 }
271
272 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
274 self.tags = Some(tags);
275 self
276 }
277
278 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
280 self.metadata = Some(metadata);
281 self
282 }
283}
284
/// Base trait implemented by all language models.
///
/// Concrete models must supply `llm_type`, `model_name`, `config`, and
/// `generate_prompt`; the remaining methods have default implementations
/// (tracing parameters, identifying parameters, rough token counting).
#[async_trait]
pub trait BaseLanguageModel: Send + Sync {
    /// Type tag for this model implementation (e.g. "ChatOpenAI").
    fn llm_type(&self) -> &str;

    /// Name of the underlying model (e.g. "gpt-4").
    fn model_name(&self) -> &str;

    /// Shared configuration for this model instance.
    fn config(&self) -> &LanguageModelConfig;

    /// Response cache, if configured. Defaults to no cache.
    fn cache(&self) -> Option<&dyn BaseCache> {
        None
    }

    /// Callbacks registered on the model itself. Defaults to none.
    fn callbacks(&self) -> Option<&Callbacks> {
        None
    }

    /// Run the model over a batch of prompts.
    ///
    /// `stop` lists sequences at which generation should halt; `callbacks`
    /// are invoked for this run — presumably in addition to any model-level
    /// callbacks (confirm against implementations).
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying model invocation.
    async fn generate_prompt(
        &self,
        prompts: Vec<LanguageModelInput>,
        stop: Option<Vec<String>>,
        callbacks: Option<Callbacks>,
    ) -> Result<LLMResult>;

    /// Build LangSmith tracing parameters for a run.
    ///
    /// The provider name is derived heuristically from `llm_type()`: a
    /// leading or trailing "Chat" is stripped and the remainder lowercased
    /// (e.g. "ChatOpenAI" -> "openai").
    fn get_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
        let mut params = LangSmithParams::new();

        let llm_type = self.llm_type();
        let provider = if llm_type.starts_with("Chat") {
            llm_type
                .strip_prefix("Chat")
                .unwrap_or(llm_type)
                .to_lowercase()
        } else if llm_type.ends_with("Chat") {
            llm_type
                .strip_suffix("Chat")
                .unwrap_or(llm_type)
                .to_lowercase()
        } else {
            llm_type.to_lowercase()
        };

        params.ls_provider = Some(provider);
        params.ls_model_name = Some(self.model_name().to_string());

        if let Some(stop) = stop {
            params.ls_stop = Some(stop.to_vec());
        }

        params
    }

    /// Parameters identifying this model: its `_type` tag and model name.
    fn identifying_params(&self) -> HashMap<String, Value> {
        let mut params = HashMap::new();
        params.insert(
            "_type".to_string(),
            Value::String(self.llm_type().to_string()),
        );
        params.insert(
            "model".to_string(),
            Value::String(self.model_name().to_string()),
        );
        params
    }

    /// Crude placeholder tokenizer: returns the positional index of each
    /// whitespace-separated word — NOT real tokenizer ids. Implementations
    /// with access to a real tokenizer should override this.
    fn get_token_ids(&self, text: &str) -> Vec<u32> {
        text.split_whitespace()
            .enumerate()
            .map(|(i, _)| i as u32)
            .collect()
    }

    /// Approximate token count for `text`; with the default `get_token_ids`
    /// this is simply the whitespace word count.
    fn get_num_tokens(&self, text: &str) -> usize {
        self.get_token_ids(text).len()
    }

    /// Approximate token count for a list of messages: per-message content
    /// tokens plus a flat overhead for role/formatting tokens.
    fn get_num_tokens_from_messages(&self, messages: &[BaseMessage]) -> usize {
        messages
            .iter()
            .map(|m| {
                // Flat 4-token overhead per message — rough heuristic for
                // role/format tokens; confirm against the target model.
                let role_tokens = 4;
                let content_tokens = self.get_num_tokens(m.content());
                role_tokens + content_tokens
            })
            .sum()
    }
}
428
/// Boxed, pinned stream of incremental model outputs.
/// Currently unused (hence `allow(dead_code)`) — presumably reserved for
/// streaming generation APIs.
#[allow(dead_code)]
pub type LanguageModelOutputStream =
    Pin<Box<dyn Stream<Item = Result<LanguageModelOutput>> + Send>>;
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method should populate its corresponding field.
    #[test]
    fn test_langsmith_params_builder() {
        let params = LangSmithParams::new()
            .with_provider("openai")
            .with_model_name("gpt-4")
            .with_model_type("chat")
            .with_temperature(0.7)
            .with_max_tokens(1000)
            .with_stop(vec!["STOP".to_string()]);

        assert_eq!(params.ls_provider.as_deref(), Some("openai"));
        assert_eq!(params.ls_model_name.as_deref(), Some("gpt-4"));
        assert_eq!(params.ls_model_type.as_deref(), Some("chat"));
        assert_eq!(params.ls_temperature, Some(0.7));
        assert_eq!(params.ls_max_tokens, Some(1000));
        assert_eq!(params.ls_stop.as_deref(), Some(&["STOP".to_string()][..]));
    }

    /// `&str` converts into the `Text` variant.
    #[test]
    fn test_language_model_input_from_str() {
        let input = LanguageModelInput::from("Hello");
        assert!(matches!(input, LanguageModelInput::Text(ref s) if s == "Hello"));
    }

    /// Both the borrowing and consuming accessors expose the text.
    #[test]
    fn test_language_model_output_text() {
        let output = LanguageModelOutput::from("Hello".to_string());
        assert_eq!(output.text(), "Hello");
        assert_eq!(output.into_text(), "Hello");
    }

    /// Config builder methods populate their fields (order-independent).
    #[test]
    fn test_language_model_config_builder() {
        let config = LanguageModelConfig::new()
            .with_tags(vec!["test".to_string()])
            .with_verbose(true)
            .with_cache(true);

        assert_eq!(config.cache, Some(true));
        assert!(config.verbose);
        assert_eq!(config.tags.as_deref(), Some(&["test".to_string()][..]));
    }
}