langchain_rust/chain/conversational/builder.rs

1use std::sync::Arc;
2
3use tokio::sync::Mutex;
4
5use crate::{
6    chain::{
7        llm_chain::LLMChainBuilder, options::ChainCallOptions, ChainError, DEFAULT_OUTPUT_KEY,
8    },
9    language_models::llm::LLM,
10    memory::SimpleMemory,
11    output_parsers::OutputParser,
12    prompt::{FormatPrompter, HumanMessagePromptTemplate},
13    schemas::memory::BaseMemory,
14    template_fstring,
15};
16
17use super::{prompt::DEFAULT_TEMPLATE, ConversationalChain, DEFAULT_INPUT_VARIABLE};
18
19pub struct ConversationalChainBuilder {
20    llm: Option<Box<dyn LLM>>,
21    options: Option<ChainCallOptions>,
22    memory: Option<Arc<Mutex<dyn BaseMemory>>>,
23    output_key: Option<String>,
24    output_parser: Option<Box<dyn OutputParser>>,
25    input_key: Option<String>,
26    prompt: Option<Box<dyn FormatPrompter>>,
27}
28
29impl ConversationalChainBuilder {
30    pub fn new() -> Self {
31        Self {
32            llm: None,
33            options: None,
34            memory: None,
35            output_key: None,
36            output_parser: None,
37            input_key: None,
38            prompt: None,
39        }
40    }
41
42    pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
43        self.llm = Some(llm.into());
44        self
45    }
46
47    pub fn options(mut self, options: ChainCallOptions) -> Self {
48        self.options = Some(options);
49        self
50    }
51
52    pub fn input_key<S: Into<String>>(mut self, input_key: S) -> Self {
53        self.input_key = Some(input_key.into());
54        self
55    }
56
57    pub fn output_parser<P: Into<Box<dyn OutputParser>>>(mut self, output_parser: P) -> Self {
58        self.output_parser = Some(output_parser.into());
59        self
60    }
61
62    pub fn memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
63        self.memory = Some(memory);
64        self
65    }
66
67    pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self {
68        self.output_key = Some(output_key.into());
69        self
70    }
71
72    ///If you want to add a custom prompt,keep in mind which variables are obligatory.
73    pub fn prompt<P: Into<Box<dyn FormatPrompter>>>(mut self, prompt: P) -> Self {
74        self.prompt = Some(prompt.into());
75        self
76    }
77
78    pub fn build(self) -> Result<ConversationalChain, ChainError> {
79        let llm = self
80            .llm
81            .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?;
82        let prompt = match self.prompt {
83            Some(prompt) => prompt,
84            None => Box::new(HumanMessagePromptTemplate::new(template_fstring!(
85                DEFAULT_TEMPLATE,
86                "history",
87                "input"
88            ))),
89        };
90        let llm_chain = {
91            let mut builder = LLMChainBuilder::new()
92                .prompt(prompt)
93                .llm(llm)
94                .output_key(self.output_key.unwrap_or_else(|| DEFAULT_OUTPUT_KEY.into()));
95
96            if let Some(options) = self.options {
97                builder = builder.options(options);
98            }
99
100            if let Some(output_parser) = self.output_parser {
101                builder = builder.output_parser(output_parser);
102            }
103
104            builder.build()?
105        };
106
107        let memory = self
108            .memory
109            .unwrap_or_else(|| Arc::new(Mutex::new(SimpleMemory::new())));
110
111        Ok(ConversationalChain {
112            llm: llm_chain,
113            memory,
114            input_key: self
115                .input_key
116                .unwrap_or_else(|| DEFAULT_INPUT_VARIABLE.to_string()),
117        })
118    }
119}