langchain_rust/chain/conversational_retrieval_qa/
builder.rs1use std::sync::Arc;
2use tokio::sync::Mutex;
3
4use crate::{
5 chain::{
6 Chain, ChainError, CondenseQuestionGeneratorChain, StuffDocumentBuilder, DEFAULT_OUTPUT_KEY,
7 },
8 language_models::llm::LLM,
9 memory::SimpleMemory,
10 prompt::FormatPrompter,
11 schemas::{BaseMemory, Retriever},
12};
13
14use super::ConversationalRetrieverChain;
15
/// Input key that [`ConversationalRetrieverChainBuilder::new`] assigns by default.
const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY: &str = "question";
17
18pub struct ConversationalRetrieverChainBuilder {
48 llm: Option<Box<dyn LLM>>,
49 retriever: Option<Box<dyn Retriever>>,
50 memory: Option<Arc<Mutex<dyn BaseMemory>>>,
51 combine_documents_chain: Option<Box<dyn Chain>>,
52 condense_question_chain: Option<Box<dyn Chain>>,
53 prompt: Option<Box<dyn FormatPrompter>>,
54 rephrase_question: bool,
55 return_source_documents: bool,
56 input_key: String,
57 output_key: String,
58}
59impl ConversationalRetrieverChainBuilder {
60 pub fn new() -> Self {
61 ConversationalRetrieverChainBuilder {
62 llm: None,
63 retriever: None,
64 memory: None,
65 combine_documents_chain: None,
66 condense_question_chain: None,
67 prompt: None,
68 rephrase_question: true,
69 return_source_documents: true,
70 input_key: CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY.to_string(),
71 output_key: DEFAULT_OUTPUT_KEY.to_string(),
72 }
73 }
74
75 pub fn retriever<R: Into<Box<dyn Retriever>>>(mut self, retriever: R) -> Self {
76 self.retriever = Some(retriever.into());
77 self
78 }
79
80 pub fn prompt<P: Into<Box<dyn FormatPrompter>>>(mut self, prompt: P) -> Self {
82 self.prompt = Some(prompt.into());
83 self
84 }
85
86 pub fn input_key<S: Into<String>>(mut self, input_key: S) -> Self {
87 self.input_key = input_key.into();
88 self
89 }
90
91 pub fn memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
92 self.memory = Some(memory);
93 self
94 }
95
96 pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
97 self.llm = Some(llm.into());
98 self
99 }
100
101 pub fn combine_documents_chain<C: Into<Box<dyn Chain>>>(
103 mut self,
104 combine_documents_chain: C,
105 ) -> Self {
106 self.combine_documents_chain = Some(combine_documents_chain.into());
107 self
108 }
109
110 pub fn condense_question_chain<C: Into<Box<dyn Chain>>>(
112 mut self,
113 condense_question_chain: C,
114 ) -> Self {
115 self.condense_question_chain = Some(condense_question_chain.into());
116 self
117 }
118
119 pub fn rephrase_question(mut self, rephrase_question: bool) -> Self {
120 self.rephrase_question = rephrase_question;
121 self
122 }
123
124 pub fn return_source_documents(mut self, return_source_documents: bool) -> Self {
125 self.return_source_documents = return_source_documents;
126 self
127 }
128
129 pub fn build(mut self) -> Result<ConversationalRetrieverChain, ChainError> {
130 if let Some(llm) = self.llm {
131 let combine_documents_chain = {
132 let mut builder = StuffDocumentBuilder::new().llm(llm.clone_box());
133 if let Some(prompt) = self.prompt {
134 builder = builder.prompt(prompt);
135 }
136 builder.build()?
137 };
138 let condense_question_chain = CondenseQuestionGeneratorChain::new(llm.clone_box());
139 self.combine_documents_chain = Some(Box::new(combine_documents_chain));
140 self.condense_question_chain = Some(Box::new(condense_question_chain));
141 }
142
143 let retriever = self
144 .retriever
145 .ok_or_else(|| ChainError::MissingObject("Retriever must be set".into()))?;
146
147 let memory = self
148 .memory
149 .unwrap_or_else(|| Arc::new(Mutex::new(SimpleMemory::new())));
150
151 let combine_documents_chain = self.combine_documents_chain.ok_or_else(|| {
152 ChainError::MissingObject(
153 "Combine documents chain must be set or llm must be set".into(),
154 )
155 })?;
156 let condense_question_chain = self.condense_question_chain.ok_or_else(|| {
157 ChainError::MissingObject(
158 "Condense question chain must be set or llm must be set".into(),
159 )
160 })?;
161 Ok(ConversationalRetrieverChain {
162 retriever,
163 memory,
164 combine_documents_chain,
165 condense_question_chain,
166 rephrase_question: self.rephrase_question,
167 return_source_documents: self.return_source_documents,
168 input_key: self.input_key,
169 output_key: self.output_key,
170 })
171 }
172}