langchain_rust/chain/conversational_retrieval_qa/builder.rs

1use std::sync::Arc;
2use tokio::sync::Mutex;
3
4use crate::{
5    chain::{
6        Chain, ChainError, CondenseQuestionGeneratorChain, StuffDocumentBuilder, DEFAULT_OUTPUT_KEY,
7    },
8    language_models::llm::LLM,
9    memory::SimpleMemory,
10    prompt::FormatPrompter,
11    schemas::{BaseMemory, Retriever},
12};
13
14use super::ConversationalRetrieverChain;
15
/// Input key the builder uses when the caller does not set one via `input_key`.
const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY: &str = "question";
17
18///Conversation Retriever Chain Builder
19/// # Usage
20/// ## Convensional way
21/// ```rust,ignore
22/// let chain = ConversationalRetrieverChainBuilder::new()
23///     .llm(llm)
24///     .rephrase_question(true)
25///     .retriever(RetrieverMock {})
26///     .memory(SimpleMemory::new().into())
27///     .build()
28///     .expect("Error building ConversationalChain");
29///
30/// ```
31/// ## Custom way
32/// ```rust,ignore
33///
34/// let llm = Box::new(OpenAI::default().with_model(OpenAIModel::Gpt35.to_string()));
35/// let combine_documents_chain = StuffDocument::load_stuff_qa(llm.clone_box());
36//  let condense_question_chain = CondenseQuestionGeneratorChain::new(llm.clone_box());
37/// let chain = ConversationalRetrieverChainBuilder::new()
38///     .rephrase_question(true)
39///     .combine_documents_chain(Box::new(combine_documents_chain))
40///     .condense_question_chain(Box::new(condense_question_chain))
41///     .retriever(RetrieverMock {})
42///     .memory(SimpleMemory::new().into())
43///     .build()
44///     .expect("Error building ConversationalChain");
45/// ```
46///
47pub struct ConversationalRetrieverChainBuilder {
48    llm: Option<Box<dyn LLM>>,
49    retriever: Option<Box<dyn Retriever>>,
50    memory: Option<Arc<Mutex<dyn BaseMemory>>>,
51    combine_documents_chain: Option<Box<dyn Chain>>,
52    condense_question_chain: Option<Box<dyn Chain>>,
53    prompt: Option<Box<dyn FormatPrompter>>,
54    rephrase_question: bool,
55    return_source_documents: bool,
56    input_key: String,
57    output_key: String,
58}
59impl ConversationalRetrieverChainBuilder {
60    pub fn new() -> Self {
61        ConversationalRetrieverChainBuilder {
62            llm: None,
63            retriever: None,
64            memory: None,
65            combine_documents_chain: None,
66            condense_question_chain: None,
67            prompt: None,
68            rephrase_question: true,
69            return_source_documents: true,
70            input_key: CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY.to_string(),
71            output_key: DEFAULT_OUTPUT_KEY.to_string(),
72        }
73    }
74
75    pub fn retriever<R: Into<Box<dyn Retriever>>>(mut self, retriever: R) -> Self {
76        self.retriever = Some(retriever.into());
77        self
78    }
79
80    ///If you want to add a custom prompt,keep in mind which variables are obligatory.
81    pub fn prompt<P: Into<Box<dyn FormatPrompter>>>(mut self, prompt: P) -> Self {
82        self.prompt = Some(prompt.into());
83        self
84    }
85
86    pub fn input_key<S: Into<String>>(mut self, input_key: S) -> Self {
87        self.input_key = input_key.into();
88        self
89    }
90
91    pub fn memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
92        self.memory = Some(memory);
93        self
94    }
95
96    pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
97        self.llm = Some(llm.into());
98        self
99    }
100
101    ///Chain designed to take the documents and the question and generate an output
102    pub fn combine_documents_chain<C: Into<Box<dyn Chain>>>(
103        mut self,
104        combine_documents_chain: C,
105    ) -> Self {
106        self.combine_documents_chain = Some(combine_documents_chain.into());
107        self
108    }
109
110    ///Chain designed to reformulate the question based on the cat history
111    pub fn condense_question_chain<C: Into<Box<dyn Chain>>>(
112        mut self,
113        condense_question_chain: C,
114    ) -> Self {
115        self.condense_question_chain = Some(condense_question_chain.into());
116        self
117    }
118
119    pub fn rephrase_question(mut self, rephrase_question: bool) -> Self {
120        self.rephrase_question = rephrase_question;
121        self
122    }
123
124    pub fn return_source_documents(mut self, return_source_documents: bool) -> Self {
125        self.return_source_documents = return_source_documents;
126        self
127    }
128
129    pub fn build(mut self) -> Result<ConversationalRetrieverChain, ChainError> {
130        if let Some(llm) = self.llm {
131            let combine_documents_chain = {
132                let mut builder = StuffDocumentBuilder::new().llm(llm.clone_box());
133                if let Some(prompt) = self.prompt {
134                    builder = builder.prompt(prompt);
135                }
136                builder.build()?
137            };
138            let condense_question_chain = CondenseQuestionGeneratorChain::new(llm.clone_box());
139            self.combine_documents_chain = Some(Box::new(combine_documents_chain));
140            self.condense_question_chain = Some(Box::new(condense_question_chain));
141        }
142
143        let retriever = self
144            .retriever
145            .ok_or_else(|| ChainError::MissingObject("Retriever must be set".into()))?;
146
147        let memory = self
148            .memory
149            .unwrap_or_else(|| Arc::new(Mutex::new(SimpleMemory::new())));
150
151        let combine_documents_chain = self.combine_documents_chain.ok_or_else(|| {
152            ChainError::MissingObject(
153                "Combine documents chain must be set or llm must be set".into(),
154            )
155        })?;
156        let condense_question_chain = self.condense_question_chain.ok_or_else(|| {
157            ChainError::MissingObject(
158                "Condense question chain must be set or llm must be set".into(),
159            )
160        })?;
161        Ok(ConversationalRetrieverChain {
162            retriever,
163            memory,
164            combine_documents_chain,
165            condense_question_chain,
166            rephrase_question: self.rephrase_question,
167            return_source_documents: self.return_source_documents,
168            input_key: self.input_key,
169            output_key: self.output_key,
170        })
171    }
172}