// langchain_rust/chain/conversational/mod.rs

1use std::{pin::Pin, sync::Arc};
2
3use async_stream::stream;
4use async_trait::async_trait;
5use futures::Stream;
6use futures_util::{pin_mut, StreamExt};
7use tokio::sync::Mutex;
8
9use crate::{
10    language_models::GenerateResult,
11    prompt::PromptArgs,
12    prompt_args,
13    schemas::{memory::BaseMemory, messages::Message, StreamData},
14};
15
/// Key under which [`ConversationalChainPromptBuilder::build`] stores the input text.
const DEFAULT_INPUT_VARIABLE: &str = "input";

18use super::{chain_trait::Chain, llm_chain::LLMChain, ChainError};
19
20pub mod builder;
21mod prompt;
22
23///This is only usefull when you dont modify the original prompt
24pub struct ConversationalChainPromptBuilder {
25    input: String,
26}
27
28impl ConversationalChainPromptBuilder {
29    pub fn new() -> Self {
30        Self {
31            input: "".to_string(),
32        }
33    }
34
35    pub fn input<S: Into<String>>(mut self, input: S) -> Self {
36        self.input = input.into();
37        self
38    }
39
40    pub fn build(self) -> PromptArgs {
41        prompt_args! {
42            DEFAULT_INPUT_VARIABLE => self.input,
43        }
44    }
45}
46
47pub struct ConversationalChain {
48    llm: LLMChain,
49    input_key: String,
50    pub memory: Arc<Mutex<dyn BaseMemory>>,
51}
52
53//Conversational Chain is a simple chain to interact with ai as a string of messages
54impl ConversationalChain {
55    pub fn prompt_builder(&self) -> ConversationalChainPromptBuilder {
56        ConversationalChainPromptBuilder::new()
57    }
58}
59
60#[async_trait]
61impl Chain for ConversationalChain {
62    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
63        let input_variable = &input_variables
64            .get(&self.input_key)
65            .ok_or(ChainError::MissingInputVariable(self.input_key.clone()))?;
66        let human_message = Message::new_human_message(input_variable);
67
68        let history = {
69            let memory = self.memory.lock().await;
70            memory.to_string()
71        };
72        let mut input_variables = input_variables;
73        input_variables.insert("history".to_string(), history.into());
74        let result = self.llm.call(input_variables.clone()).await?;
75
76        let mut memory = self.memory.lock().await;
77        memory.add_message(human_message);
78        memory.add_message(Message::new_ai_message(&result.generation));
79        Ok(result)
80    }
81
82    async fn stream(
83        &self,
84        input_variables: PromptArgs,
85    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
86    {
87        let input_variable = &input_variables
88            .get(&self.input_key)
89            .ok_or(ChainError::MissingInputVariable(self.input_key.clone()))?;
90        let human_message = Message::new_human_message(input_variable);
91
92        let history = {
93            let memory = self.memory.lock().await;
94            memory.to_string()
95        };
96
97        let mut input_variables = input_variables;
98        input_variables.insert("history".to_string(), history.into());
99
100        let complete_ai_message = Arc::new(Mutex::new(String::new()));
101        let complete_ai_message_clone = complete_ai_message.clone();
102
103        let memory = self.memory.clone();
104
105        let stream = self.llm.stream(input_variables).await?;
106        let output_stream = stream! {
107            pin_mut!(stream);
108            while let Some(result) = stream.next().await {
109                match result {
110                    Ok(data) => {
111                        let mut complete_ai_message_clone =
112                            complete_ai_message_clone.lock().await;
113                        complete_ai_message_clone.push_str(&data.content);
114
115                        yield Ok(data);
116                    },
117                    Err(e) => {
118                        yield Err(e);
119                    }
120                }
121            }
122
123            let mut memory = memory.lock().await;
124            memory.add_message(human_message);
125            memory.add_message(Message::new_ai_message(&complete_ai_message.lock().await));
126        };
127
128        Ok(Box::pin(output_stream))
129    }
130
131    fn get_input_keys(&self) -> Vec<String> {
132        vec![self.input_key.clone()]
133    }
134}
135
#[cfg(test)]
mod tests {
    use crate::{
        chain::conversational::builder::ConversationalChainBuilder,
        llm::openai::{OpenAI, OpenAIModel},
        prompt_args,
    };

    use super::*;

    /// Smoke test with two consecutive turns. The second question only makes sense if the
    /// first turn is in the conversation history. Marked `#[ignore]`; it talks to the OpenAI
    /// API.
    #[tokio::test]
    #[ignore]
    async fn test_invoke_conversational() {
        let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());
        let chain = ConversationalChainBuilder::new()
            .llm(llm)
            .build()
            .expect("Error building ConversationalChain");

        // First turn: "I'm from Peru".
        let input_variables_first = prompt_args! {
            "input" => "Soy de peru",
        };
        let result_first = chain
            .invoke(input_variables_first)
            .await
            .unwrap_or_else(|e| panic!("Error invoking ConversationalChain: {:?}", e));
        println!("Result: {:?}", result_first);

        // Second turn: "What are typical dishes of my country?"
        let input_variables_second = prompt_args! {
            "input" => "Cuales son platos tipicos de mi pais",
        };
        let result_second = chain
            .invoke(input_variables_second)
            .await
            .unwrap_or_else(|e| panic!("Error invoking ConversationalChain: {:?}", e));
        println!("Result: {:?}", result_second);
    }
}