langchain_rust/chain/conversational/
builder.rs1use std::sync::Arc;
2
3use tokio::sync::Mutex;
4
5use crate::{
6 chain::{
7 llm_chain::LLMChainBuilder, options::ChainCallOptions, ChainError, DEFAULT_OUTPUT_KEY,
8 },
9 language_models::llm::LLM,
10 memory::SimpleMemory,
11 output_parsers::OutputParser,
12 prompt::{FormatPrompter, HumanMessagePromptTemplate},
13 schemas::memory::BaseMemory,
14 template_fstring,
15};
16
17use super::{prompt::DEFAULT_TEMPLATE, ConversationalChain, DEFAULT_INPUT_VARIABLE};
18
19pub struct ConversationalChainBuilder {
20 llm: Option<Box<dyn LLM>>,
21 options: Option<ChainCallOptions>,
22 memory: Option<Arc<Mutex<dyn BaseMemory>>>,
23 output_key: Option<String>,
24 output_parser: Option<Box<dyn OutputParser>>,
25 input_key: Option<String>,
26 prompt: Option<Box<dyn FormatPrompter>>,
27}
28
29impl ConversationalChainBuilder {
30 pub fn new() -> Self {
31 Self {
32 llm: None,
33 options: None,
34 memory: None,
35 output_key: None,
36 output_parser: None,
37 input_key: None,
38 prompt: None,
39 }
40 }
41
42 pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
43 self.llm = Some(llm.into());
44 self
45 }
46
47 pub fn options(mut self, options: ChainCallOptions) -> Self {
48 self.options = Some(options);
49 self
50 }
51
52 pub fn input_key<S: Into<String>>(mut self, input_key: S) -> Self {
53 self.input_key = Some(input_key.into());
54 self
55 }
56
57 pub fn output_parser<P: Into<Box<dyn OutputParser>>>(mut self, output_parser: P) -> Self {
58 self.output_parser = Some(output_parser.into());
59 self
60 }
61
62 pub fn memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
63 self.memory = Some(memory);
64 self
65 }
66
67 pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self {
68 self.output_key = Some(output_key.into());
69 self
70 }
71
72 pub fn prompt<P: Into<Box<dyn FormatPrompter>>>(mut self, prompt: P) -> Self {
74 self.prompt = Some(prompt.into());
75 self
76 }
77
78 pub fn build(self) -> Result<ConversationalChain, ChainError> {
79 let llm = self
80 .llm
81 .ok_or_else(|| ChainError::MissingObject("LLM must be set".into()))?;
82 let prompt = match self.prompt {
83 Some(prompt) => prompt,
84 None => Box::new(HumanMessagePromptTemplate::new(template_fstring!(
85 DEFAULT_TEMPLATE,
86 "history",
87 "input"
88 ))),
89 };
90 let llm_chain = {
91 let mut builder = LLMChainBuilder::new()
92 .prompt(prompt)
93 .llm(llm)
94 .output_key(self.output_key.unwrap_or_else(|| DEFAULT_OUTPUT_KEY.into()));
95
96 if let Some(options) = self.options {
97 builder = builder.options(options);
98 }
99
100 if let Some(output_parser) = self.output_parser {
101 builder = builder.output_parser(output_parser);
102 }
103
104 builder.build()?
105 };
106
107 let memory = self
108 .memory
109 .unwrap_or_else(|| Arc::new(Mutex::new(SimpleMemory::new())));
110
111 Ok(ConversationalChain {
112 llm: llm_chain,
113 memory,
114 input_key: self
115 .input_key
116 .unwrap_or_else(|| DEFAULT_INPUT_VARIABLE.to_string()),
117 })
118 }
119}