1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::{
    chain::{
        Chain, ChainError, CondenseQuestionGeneratorChain, StuffDocumentBuilder, DEFAULT_OUTPUT_KEY,
    },
    language_models::llm::LLM,
    memory::SimpleMemory,
    prompt::FormatPrompter,
    schemas::{BaseMemory, Retriever},
};

use super::ConversationalRetrieverChain;

/// Input key the builder uses for the user's question unless `input_key` is called explicitly.
const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY: &str = "question";

/// Builder for [`ConversationalRetrieverChain`].
///
/// A retriever is mandatory. The two inner chains — one that combines the retrieved
/// documents with the question, and one that reformulates the question based on the chat
/// history — can either be supplied explicitly or be derived from an `llm`. If no memory is
/// supplied, a fresh [`SimpleMemory`] is used.
///
/// # Usage
/// ## Conventional way
/// ```rust,ignore
/// let chain = ConversationalRetrieverChainBuilder::new()
///     .llm(llm)
///     .rephrase_question(true)
///     .retriever(RetrieverMock {})
///     .memory(SimpleMemory::new().into())
///     .build()
///     .expect("Error building ConversationalChain");
/// ```
/// ## Custom way
/// ```rust,ignore
/// let llm = Box::new(OpenAI::default().with_model(OpenAIModel::Gpt35.to_string()));
/// let combine_documents_chain = StuffDocument::load_stuff_qa(llm.clone_box());
/// let condense_question_chain = CondenseQuestionGeneratorChain::new(llm.clone_box());
/// let chain = ConversationalRetrieverChainBuilder::new()
///     .rephrase_question(true)
///     .combine_documents_chain(Box::new(combine_documents_chain))
///     .condense_question_chian(Box::new(condense_question_chain))
///     .retriever(RetrieverMock {})
///     .memory(SimpleMemory::new().into())
///     .build()
///     .expect("Error building ConversationalChain");
/// ```
pub struct ConversationalRetrieverChainBuilder {
    // Used by `build` to derive default inner chains.
    llm: Option<Box<dyn LLM>>,
    // Required; `build` fails without it.
    retriever: Option<Box<dyn Retriever>>,
    // Optional; `build` falls back to a new `SimpleMemory`.
    memory: Option<Arc<Mutex<dyn BaseMemory>>>,
    // Chain that turns the retrieved documents plus the question into an answer.
    combine_documents_chain: Option<Box<dyn Chain>>,
    // Chain that reformulates the question from the chat history
    // (field name kept as-is; the public setter shares the spelling).
    condense_question_chian: Option<Box<dyn Chain>>,
    // Custom prompt for the default combine-documents chain built from `llm`.
    prompt: Option<Box<dyn FormatPrompter>>,
    rephrase_question: bool,
    return_source_documents: bool,
    input_key: String,
    output_key: String,
}
impl ConversationalRetrieverChainBuilder {
    /// Creates a builder with no components set.
    ///
    /// Defaults:
    /// - `rephrase_question` and `return_source_documents` are `true`.
    /// - The input key is `"question"`.
    /// - The output key is [`DEFAULT_OUTPUT_KEY`].
    pub fn new() -> Self {
        ConversationalRetrieverChainBuilder {
            llm: None,
            retriever: None,
            memory: None,
            combine_documents_chain: None,
            condense_question_chian: None,
            prompt: None,
            rephrase_question: true,
            return_source_documents: true,
            input_key: CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY.to_string(),
            output_key: DEFAULT_OUTPUT_KEY.to_string(),
        }
    }

    /// Sets the retriever used to fetch documents. This is mandatory.
    pub fn retriever<R: Into<Box<dyn Retriever>>>(mut self, retriever: R) -> Self {
        self.retriever = Some(retriever.into());
        self
    }

    /// Sets a custom prompt for the default combine-documents chain.
    ///
    /// The prompt is only used when the chain is derived from [`Self::llm`]. It is ignored if
    /// a chain is passed to [`Self::combine_documents_chain`]. Keep in mind which input
    /// variables the stuff-documents chain requires.
    pub fn prompt<P: Into<Box<dyn FormatPrompter>>>(mut self, prompt: P) -> Self {
        self.prompt = Some(prompt.into());
        self
    }

    /// Overrides the input key (defaults to `"question"`).
    pub fn input_key<S: Into<String>>(mut self, input_key: S) -> Self {
        self.input_key = input_key.into();
        self
    }

    /// Overrides the output key (defaults to [`DEFAULT_OUTPUT_KEY`]).
    pub fn output_key<S: Into<String>>(mut self, output_key: S) -> Self {
        self.output_key = output_key.into();
        self
    }

    /// Sets the conversation memory. If not called, a new [`SimpleMemory`] is used.
    pub fn memory(mut self, memory: Arc<Mutex<dyn BaseMemory>>) -> Self {
        self.memory = Some(memory);
        self
    }

    /// Sets the LLM used to derive any inner chain that is not supplied explicitly.
    pub fn llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
        self.llm = Some(llm.into());
        self
    }

    /// Sets the chain that takes the documents and the question and generates an output.
    pub fn combine_documents_chain<C: Into<Box<dyn Chain>>>(
        mut self,
        combine_documents_chain: C,
    ) -> Self {
        self.combine_documents_chain = Some(combine_documents_chain.into());
        self
    }

    /// Sets the chain that reformulates the question based on the chat history.
    pub fn condense_question_chian<C: Into<Box<dyn Chain>>>(
        mut self,
        condense_question_chian: C,
    ) -> Self {
        self.condense_question_chian = Some(condense_question_chian.into());
        self
    }

    /// Sets the `rephrase_question` flag forwarded to the built chain (default `true`).
    pub fn rephrase_question(mut self, rephrase_question: bool) -> Self {
        self.rephrase_question = rephrase_question;
        self
    }

    /// Sets the `return_source_documents` flag forwarded to the built chain (default `true`).
    pub fn return_source_documents(mut self, return_source_documents: bool) -> Self {
        self.return_source_documents = return_source_documents;
        self
    }

    /// Assembles the [`ConversationalRetrieverChain`].
    ///
    /// Chains passed to [`Self::combine_documents_chain`] and
    /// [`Self::condense_question_chian`] always take precedence. When an `llm` is set, it is
    /// only used to create whichever of the two chains is still missing:
    /// - a stuff-documents chain, customised by [`Self::prompt`] if one was given;
    /// - a [`CondenseQuestionGeneratorChain`].
    ///
    /// # Errors
    ///
    /// - Any error returned while building the default stuff-documents chain.
    /// - [`ChainError::MissingObject`] if no retriever was set.
    /// - [`ChainError::MissingObject`] if either inner chain was neither supplied nor
    ///   derivable from an `llm`.
    pub fn build(self) -> Result<ConversationalRetrieverChain, ChainError> {
        let mut combine_documents_chain = self.combine_documents_chain;
        let mut condense_question_chian = self.condense_question_chian;

        // Only fill the gaps: a caller who passes both an `llm` and a custom chain must not
        // have that chain silently replaced by the default one.
        if let Some(llm) = self.llm {
            if combine_documents_chain.is_none() {
                let mut builder = StuffDocumentBuilder::new().llm(llm.clone_box());
                if let Some(prompt) = self.prompt {
                    builder = builder.prompt(prompt);
                }
                combine_documents_chain = Some(Box::new(builder.build()?));
            }
            if condense_question_chian.is_none() {
                // Last use of `llm`, so it can be moved instead of cloned.
                condense_question_chian =
                    Some(Box::new(CondenseQuestionGeneratorChain::new(llm)));
            }
        }

        let retriever = self
            .retriever
            .ok_or_else(|| ChainError::MissingObject("Retriever must be set".into()))?;

        let memory = self
            .memory
            .unwrap_or_else(|| Arc::new(Mutex::new(SimpleMemory::new())));

        let combine_documents_chain = combine_documents_chain.ok_or_else(|| {
            ChainError::MissingObject(
                "Combine documents chain must be set or llm must be set".into(),
            )
        })?;
        let condense_question_chian = condense_question_chian.ok_or_else(|| {
            ChainError::MissingObject(
                "Condense question chain must be set or llm must be set".into(),
            )
        })?;

        Ok(ConversationalRetrieverChain {
            retriever,
            memory,
            combine_documents_chain,
            condense_question_chian,
            rephrase_question: self.rephrase_question,
            return_source_documents: self.return_source_documents,
            input_key: self.input_key,
            output_key: self.output_key,
        })
    }
}