langchain_rust/chain/llm_chain.rs

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;
use futures_util::TryStreamExt;

use crate::{
    language_models::{llm::LLM, GenerateResult},
    output_parsers::{OutputParser, SimpleParser},
    prompt::{FormatPrompter, PromptArgs},
    schemas::StreamData,
};

use super::{chain_trait::Chain, options::ChainCallOptions, ChainError};

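/// Builder for [`LLMChain`]. `prompt` and `llm` are required; `output_key`
/// defaults to `"output"` and `output_parser` falls back to [`SimpleParser`]
/// when not set.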
pub struct LLMChainBuilder {
    prompt: Option<Box<dyn FormatPrompter>>,
    llm: Option<Box<dyn LLM>>,
    output_key: Option<String>,
    options: Option<ChainCallOptions>,
    output_parser: Option<Box<dyn OutputParser>>,
}

impl LLMChainBuilder {
    pub fn new() -> Self {
        Self {
            prompt: None,
            llm: None,
            options: None,
            output_key: None,
            output_parser: None,
        }
    }

    pub fn options(mut self, options: ChainCallOptions) -> Self {
        self.options = Some(options);
        self
    }

    pub fn prompt<P: Into<Box<dyn FormatPrompter>>>(mut self, prompt: P) -> Self {
        self.prompt = Some(prompt.into());
        self
    }

    pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
        self.llm = Some(llm.into());
        self
    }

    pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self {
        self.output_key = Some(output_key.into());
        self
    }

    pub fn output_parser<P: Into<Box<dyn OutputParser>>>(mut self, output_parser: P) -> Self {
        self.output_parser = Some(output_parser.into());
        self
    }

    pub fn build(self) -> Result<LLMChain, ChainError> {
        let prompt = self
            .prompt
            .ok_or_else(|| ChainError::MissingObject("Prompt must be set".into()))?;

        let mut llm = self
            .llm
            .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?;

        // Forward any chain-level call options down to the underlying LLM.
        if let Some(options) = self.options {
            let llm_options = ChainCallOptions::to_llm_options(options);
            llm.add_options(llm_options);
        }

        let chain = LLMChain {
            prompt,
            llm,
            output_key: self.output_key.unwrap_or_else(|| "output".to_string()),
            output_parser: self
                .output_parser
                .unwrap_or_else(|| Box::new(SimpleParser::default())),
        };

        Ok(chain)
    }
}

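/// A chain that formats a prompt, sends it to an LLM, and runs the generation
/// through an output parser.
///
/// A minimal usage sketch (assumes an OpenAI API key is configured and that
/// `formatter` was built with this crate's `message_formatter!` macro; see the
/// test at the bottom of this file for a complete version):
///
/// ```rust,ignore
/// let chain = LLMChainBuilder::new()
///     .prompt(formatter) // any `FormatPrompter`
///     .llm(OpenAI::default())
///     .build()?;
/// let answer = chain.invoke(prompt_args! { "nombre" => "luis" }).await?;
/// ```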
pub struct LLMChain {
    prompt: Box<dyn FormatPrompter>,
    llm: Box<dyn LLM>,
    output_key: String,
    output_parser: Box<dyn OutputParser>,
}

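// Note on the three entry points below: `call` returns the full
// `GenerateResult` with the generation run through the output parser,
// `invoke` returns only the raw generated text (no output parsing), and
// `stream` yields incremental `StreamData` items as the LLM produces them.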
#[async_trait]
impl Chain for LLMChain {
    fn get_input_keys(&self) -> Vec<String> {
        self.prompt.get_input_variables()
    }

    fn get_output_keys(&self) -> Vec<String> {
        vec![self.output_key.clone()]
    }

    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
        let prompt = self.prompt.format_prompt(input_variables)?;
        log::debug!("Prompt: {:?}", prompt);
        let mut output = self.llm.generate(&prompt.to_chat_messages()).await?;
        // Post-process the raw generation with the configured output parser.
        output.generation = self.output_parser.parse(&output.generation).await?;

        Ok(output)
    }

    async fn invoke(&self, input_variables: PromptArgs) -> Result<String, ChainError> {
        let prompt = self.prompt.format_prompt(input_variables)?;
        log::debug!("Prompt: {:?}", prompt);
        let output = self
            .llm
            .generate(&prompt.to_chat_messages())
            .await?
            .generation;
        Ok(output)
    }

    async fn stream(
        &self,
        input_variables: PromptArgs,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
    {
        let prompt = self.prompt.format_prompt(input_variables)?;
        log::debug!("Prompt: {:?}", prompt);
        let llm_stream = self.llm.stream(&prompt.to_chat_messages()).await?;

        // Map the errors from LLMError to ChainError.
        let mapped_stream = llm_stream.map_err(ChainError::from);

        Ok(Box::pin(mapped_stream))
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        chain::options::ChainCallOptions,
        llm::openai::{OpenAI, OpenAIModel},
        message_formatter,
        prompt::{HumanMessagePromptTemplate, MessageOrTemplate},
        prompt_args, template_fstring,
    };

    use super::*;

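    // Ignored by default: this test calls the live OpenAI API and needs a
    // valid API key in the environment (run with `cargo test -- --ignored`).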
    #[tokio::test]
    #[ignore]
    async fn test_invoke_chain() {
        // Create a human message prompt template
        let human_message_prompt = HumanMessagePromptTemplate::new(template_fstring!(
            "Mi nombre es: {nombre} ",
            "nombre",
        ));

        // Use the `message_formatter` macro to construct the formatter
        let formatter =
            message_formatter![MessageOrTemplate::Template(human_message_prompt.into())];

        let options = ChainCallOptions::default();
        let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());
        let chain = LLMChainBuilder::new()
            .prompt(formatter)
            .llm(llm)
            .options(options)
            .build()
            .expect("Failed to build LLMChain");

        let input_variables = prompt_args! {
            "nombre" => "luis",
        };

        // Execute `chain.invoke` and assert that it succeeds
        let result = chain.invoke(input_variables).await;
        assert!(
            result.is_ok(),
            "Error invoking LLMChain: {:?}",
            result.err()
        );
    }
}
185}