langchain_rust/chain/
llm_chain.rs1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5use futures_util::TryStreamExt;
6
7use crate::{
8 language_models::{llm::LLM, GenerateResult},
9 output_parsers::{OutputParser, SimpleParser},
10 prompt::{FormatPrompter, PromptArgs},
11 schemas::StreamData,
12};
13
14use super::{chain_trait::Chain, options::ChainCallOptions, ChainError};
15
16pub struct LLMChainBuilder {
17 prompt: Option<Box<dyn FormatPrompter>>,
18 llm: Option<Box<dyn LLM>>,
19 output_key: Option<String>,
20 options: Option<ChainCallOptions>,
21 output_parser: Option<Box<dyn OutputParser>>,
22}
23
24impl LLMChainBuilder {
25 pub fn new() -> Self {
26 Self {
27 prompt: None,
28 llm: None,
29 options: None,
30 output_key: None,
31 output_parser: None,
32 }
33 }
34 pub fn options(mut self, options: ChainCallOptions) -> Self {
35 self.options = Some(options);
36 self
37 }
38
39 pub fn prompt<P: Into<Box<dyn FormatPrompter>>>(mut self, prompt: P) -> Self {
40 self.prompt = Some(prompt.into());
41 self
42 }
43
44 pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
45 self.llm = Some(llm.into());
46 self
47 }
48
49 pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self {
50 self.output_key = Some(output_key.into());
51 self
52 }
53
54 pub fn output_parser<P: Into<Box<dyn OutputParser>>>(mut self, output_parser: P) -> Self {
55 self.output_parser = Some(output_parser.into());
56 self
57 }
58
59 pub fn build(self) -> Result<LLMChain, ChainError> {
60 let prompt = self
61 .prompt
62 .ok_or_else(|| ChainError::MissingObject("Prompt must be set".into()))?;
63
64 let mut llm = self
65 .llm
66 .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?;
67
68 if let Some(options) = self.options {
69 let llm_options = ChainCallOptions::to_llm_options(options);
70 llm.add_options(llm_options);
71 }
72
73 let chain = LLMChain {
74 prompt,
75 llm,
76 output_key: self.output_key.unwrap_or("output".to_string()),
77 output_parser: self
78 .output_parser
79 .unwrap_or_else(|| Box::new(SimpleParser::default())),
80 };
81
82 Ok(chain)
83 }
84}
85
86pub struct LLMChain {
87 prompt: Box<dyn FormatPrompter>,
88 llm: Box<dyn LLM>,
89 output_key: String,
90 output_parser: Box<dyn OutputParser>,
91}
92
93#[async_trait]
94impl Chain for LLMChain {
95 fn get_input_keys(&self) -> Vec<String> {
96 self.prompt.get_input_variables()
97 }
98
99 fn get_output_keys(&self) -> Vec<String> {
100 vec![self.output_key.clone()]
101 }
102
103 async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
104 let prompt = self.prompt.format_prompt(input_variables.clone())?;
105 log::debug!("Prompt: {:?}", prompt);
106 let mut output = self.llm.generate(&prompt.to_chat_messages()).await?;
107 output.generation = self.output_parser.parse(&output.generation).await?;
108
109 Ok(output)
110 }
111
112 async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
113 let prompt = self.prompt.format_prompt(input_variables.clone())?;
114 log::debug!("Prompt: {:?}", prompt);
115 let output = self
116 .llm
117 .generate(&prompt.to_chat_messages())
118 .await?
119 .generation;
120 Ok(output)
121 }
122
123 async fn stream(
124 &self,
125 input_variables: PromptArgs,
126 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
127 {
128 let prompt = self.prompt.format_prompt(input_variables.clone())?;
129 log::debug!("Prompt: {:?}", prompt);
130 let llm_stream = self.llm.stream(&prompt.to_chat_messages()).await?;
131
132 let mapped_stream = llm_stream.map_err(ChainError::from);
134
135 Ok(Box::pin(mapped_stream))
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        chain::options::ChainCallOptions,
        llm::openai::{OpenAI, OpenAIModel},
        message_formatter,
        prompt::{HumanMessagePromptTemplate, MessageOrTemplate},
        prompt_args, template_fstring,
    };

    /// Smoke test: builds an `LLMChain` from a single human-message template and an
    /// OpenAI model, then checks that `invoke` returns `Ok`.
    ///
    /// Marked `#[ignore]` because it talks to a real model (presumably this needs network
    /// access and OpenAI credentials). Run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_invoke_chain() {
        // One-variable template (`nombre`) rendered as a human message.
        let greeting = HumanMessagePromptTemplate::new(template_fstring!(
            "Mi nombre es: {nombre} ",
            "nombre",
        ));
        let prompt = message_formatter![MessageOrTemplate::Template(greeting.into()),];

        let model = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());
        let chain = LLMChainBuilder::new()
            .prompt(prompt)
            .llm(model)
            .options(ChainCallOptions::default())
            .build()
            .expect("Failed to build LLMChain");

        let args = prompt_args! {
            "nombre" => "luis",
        };
        let outcome = chain.invoke(args).await;
        assert!(
            outcome.is_ok(),
            "Error invoking LLMChain: {:?}",
            outcome.err()
        )
    }
}