langchain_rust/chain/conversational/mod.rs
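//! Conversational chain: pairs an [`LLMChain`] with a shared [`BaseMemory`] so
//! that every call injects the accumulated conversation into the `history`
//! prompt variable and records the new human/AI turns afterwards.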
use std::{pin::Pin, sync::Arc};

use async_stream::stream;
use async_trait::async_trait;
use futures::Stream;
use futures_util::{pin_mut, StreamExt};
use tokio::sync::Mutex;

use crate::{
    language_models::GenerateResult,
    prompt::PromptArgs,
    prompt_args,
    schemas::{memory::BaseMemory, messages::Message, StreamData},
};

const DEFAULT_INPUT_VARIABLE: &str = "input";

use super::{chain_trait::Chain, llm_chain::LLMChain, ChainError};

pub mod builder;
mod prompt;

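/// Builder for the prompt arguments of a [`ConversationalChain`]; it maps the
/// user's text onto the chain's default `input` variable.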
pub struct ConversationalChainPromptBuilder {
    input: String,
}

impl ConversationalChainPromptBuilder {
    pub fn new() -> Self {
        Self {
            input: "".to_string(),
        }
    }

    pub fn input<S: Into<String>>(mut self, input: S) -> Self {
        self.input = input.into();
        self
    }

    pub fn build(self) -> PromptArgs {
        prompt_args! {
            DEFAULT_INPUT_VARIABLE => self.input,
        }
    }
}

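/// A chain that combines an [`LLMChain`] with a conversation [`BaseMemory`];
/// the memory sits behind an `Arc<Mutex<_>>` so both `call` and `stream` can
/// update it across await points.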
pub struct ConversationalChain {
    llm: LLMChain,
    input_key: String,
    pub memory: Arc<Mutex<dyn BaseMemory>>,
}

impl ConversationalChain {
    pub fn prompt_builder(&self) -> ConversationalChainPromptBuilder {
        ConversationalChainPromptBuilder::new()
    }
}

#[async_trait]
impl Chain for ConversationalChain {
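    /// Single-shot invocation: snapshots the memory into the `history`
    /// variable, delegates to the inner `LLMChain`, then records the human
    /// input and the AI generation back into memory.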
    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
        let input_variable = &input_variables
            .get(&self.input_key)
            .ok_or(ChainError::MissingInputVariable(self.input_key.clone()))?;
        let human_message = Message::new_human_message(input_variable);

        // Hold the memory lock only while rendering the history, not across
        // the LLM call.
        let history = {
            let memory = self.memory.lock().await;
            memory.to_string()
        };
        let mut input_variables = input_variables;
        input_variables.insert("history".to_string(), history.into());
        let result = self.llm.call(input_variables.clone()).await?;

        let mut memory = self.memory.lock().await;
        memory.add_message(human_message);
        memory.add_message(Message::new_ai_message(&result.generation));
        Ok(result)
    }

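    /// Streaming invocation: same memory handling as `call`, except the AI
    /// reply is accumulated chunk by chunk and only written back to memory
    /// once the underlying stream is exhausted.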
    async fn stream(
        &self,
        input_variables: PromptArgs,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
    {
        let input_variable = &input_variables
            .get(&self.input_key)
            .ok_or(ChainError::MissingInputVariable(self.input_key.clone()))?;
        let human_message = Message::new_human_message(input_variable);

        let history = {
            let memory = self.memory.lock().await;
            memory.to_string()
        };

        let mut input_variables = input_variables;
        input_variables.insert("history".to_string(), history.into());

        let complete_ai_message = Arc::new(Mutex::new(String::new()));
        let complete_ai_message_clone = complete_ai_message.clone();

        let memory = self.memory.clone();

        let stream = self.llm.stream(input_variables).await?;
        let output_stream = stream! {
            pin_mut!(stream);
            while let Some(result) = stream.next().await {
                match result {
                    Ok(data) => {
                        let mut complete_ai_message_clone =
                            complete_ai_message_clone.lock().await;
                        complete_ai_message_clone.push_str(&data.content);

                        yield Ok(data);
                    },
                    Err(e) => {
                        yield Err(e);
                    }
                }
            }

            let mut memory = memory.lock().await;
            memory.add_message(human_message);
            memory.add_message(Message::new_ai_message(&complete_ai_message.lock().await));
        };

        Ok(Box::pin(output_stream))
    }

    fn get_input_keys(&self) -> Vec<String> {
        vec![self.input_key.clone()]
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        chain::conversational::builder::ConversationalChainBuilder,
        llm::openai::{OpenAI, OpenAIModel},
        prompt_args,
    };

    use super::*;

    #[tokio::test]
    #[ignore]
    async fn test_invoke_conversational() {
        let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());
        let chain = ConversationalChainBuilder::new()
            .llm(llm)
            .build()
            .expect("Error building ConversationalChain");

        let input_variables_first = prompt_args! {
            "input" => "Soy de peru",
        };
        let result_first = chain.invoke(input_variables_first).await;
        assert!(
            result_first.is_ok(),
            "Error invoking LLMChain: {:?}",
            result_first.err()
        );

        if let Ok(result) = result_first {
            println!("Result: {:?}", result);
        }

        let input_variables_second = prompt_args! {
            "input" => "Cuales son platos tipicos de mi pais",
        };
        let result_second = chain.invoke(input_variables_second).await;
        assert!(
            result_second.is_ok(),
            "Error invoking LLMChain: {:?}",
            result_second.err()
        );

        if let Ok(result) = result_second {
            println!("Result: {:?}", result);
        }
    }
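
    // Sketch: a streaming counterpart to `test_invoke_conversational`,
    // assuming the same builder API and a live OpenAI key (hence `#[ignore]`).
    // It drains the stream and prints each chunk's content as it arrives.
    #[tokio::test]
    #[ignore]
    async fn test_stream_conversational() {
        use futures_util::StreamExt;

        let llm = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());
        let chain = ConversationalChainBuilder::new()
            .llm(llm)
            .build()
            .expect("Error building ConversationalChain");

        let input_variables = prompt_args! {
            "input" => "Soy de peru",
        };
        let mut stream = chain
            .stream(input_variables)
            .await
            .expect("Error starting stream");
        while let Some(chunk) = stream.next().await {
            match chunk {
                Ok(data) => print!("{}", data.content),
                Err(e) => panic!("Stream error: {:?}", e),
            }
        }
    }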
}