langchain_rust/chain/
question_answering.rs

1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5
6use crate::{
7    language_models::{llm::LLM, GenerateResult},
8    prompt::PromptArgs,
9    prompt_args,
10    schemas::{messages::Message, Document, StreamData},
11    template_jinja2,
12};
13
14use super::{
15    options::ChainCallOptions, Chain, ChainError, LLMChain, LLMChainBuilder, StuffDocument,
16};
17
/// Prompt used by [`CondenseQuestionGeneratorChain`] to rephrase a follow-up question
/// into a standalone question, given the conversation so far.
///
/// Rendered with `template_jinja2!` using the `chat_history` and `question` variables.
const DEFAULTCONDENSEQUESTIONTEMPLATE: &str = r#"Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{{chat_history}}
Follow Up Input: {{question}}
Standalone question:"#;
24
25pub struct CondenseQuestionPromptBuilder {
26    chat_history: String,
27    question: String,
28}
29impl CondenseQuestionPromptBuilder {
30    pub fn new() -> Self {
31        Self {
32            chat_history: "".to_string(),
33            question: "".to_string(),
34        }
35    }
36
37    pub fn question<S: Into<String>>(mut self, question: S) -> Self {
38        self.question = question.into();
39        self
40    }
41
42    pub fn chat_history(mut self, chat_history: &[Message]) -> Self {
43        self.chat_history = Message::messages_to_string(chat_history);
44        self
45    }
46
47    pub fn build(self) -> PromptArgs {
48        prompt_args! {
49            "chat_history" => self.chat_history,
50            "question" => self.question
51        }
52    }
53}
54
55pub struct CondenseQuestionGeneratorChain {
56    chain: LLMChain,
57}
58
59impl CondenseQuestionGeneratorChain {
60    pub fn new<L: Into<Box<dyn LLM>>>(llm: L) -> Self {
61        let condense_question_prompt_template =
62            template_jinja2!(DEFAULTCONDENSEQUESTIONTEMPLATE, "chat_history", "question");
63
64        let chain = LLMChainBuilder::new()
65            .llm(llm)
66            .prompt(condense_question_prompt_template)
67            .build()
68            .unwrap(); //Its safe to unwrap here because we are sure that the prompt and the LLM are
69                       //set.
70        Self { chain }
71    }
72
73    pub fn prompt_builder(&self) -> CondenseQuestionPromptBuilder {
74        CondenseQuestionPromptBuilder::new()
75    }
76}
77
78#[async_trait]
79impl Chain for CondenseQuestionGeneratorChain {
80    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
81        self.chain.call(input_variables).await
82    }
83
84    async fn stream(
85        &self,
86        input_variables: PromptArgs,
87    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
88    {
89        self.chain.stream(input_variables).await
90    }
91}
92
/// Default prompt for the "stuff documents" question-answering chain built by
/// `load_stuff_qa`.
///
/// Rendered with `template_jinja2!` using the `context` and `question` variables.
/// The question line is `Question: {{question}}`, with a space after the colon, matching
/// the LangChain stuff-QA prompt this template ports.
const DEFAULT_STUFF_QA_TEMPLATE: &str = r#"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{{context}}

Question: {{question}}
Helpful Answer:
"#;
100
101pub struct StuffQAPromptBuilder<'a> {
102    input_documents: Vec<&'a Document>,
103    question: String,
104}
105
106impl<'a> StuffQAPromptBuilder<'a> {
107    pub fn new() -> Self {
108        Self {
109            input_documents: vec![],
110            question: "".to_string(),
111        }
112    }
113
114    pub fn documents(mut self, documents: &'a [Document]) -> Self {
115        self.input_documents = documents.iter().collect();
116        self
117    }
118
119    pub fn question<S: Into<String>>(mut self, question: S) -> Self {
120        self.question = question.into();
121        self
122    }
123
124    pub fn build(self) -> PromptArgs {
125        prompt_args! {
126            "input_documents" => self.input_documents,
127            "question" => self.question
128        }
129    }
130}
131
132pub(crate) fn load_stuff_qa<L: Into<Box<dyn LLM>>>(
133    llm: L,
134    options: Option<ChainCallOptions>,
135) -> StuffDocument {
136    let default_qa_prompt_template =
137        template_jinja2!(DEFAULT_STUFF_QA_TEMPLATE, "context", "question");
138
139    let llm_chain_builder = LLMChainBuilder::new()
140        .prompt(default_qa_prompt_template)
141        .options(options.unwrap_or_default())
142        .llm(llm)
143        .build()
144        .unwrap();
145
146    let llm_chain = llm_chain_builder;
147
148    StuffDocument::new(llm_chain)
149}
150
#[cfg(test)]
mod tests {
    use crate::{
        chain::{Chain, StuffDocument},
        llm::openai::OpenAI,
        schemas::Document,
    };

    /// End-to-end run of the stuff-QA chain over two small Q&A documents.
    ///
    /// Marked `#[ignore]`: it uses `OpenAI::default()` as the model, so it presumably needs
    /// network access and API credentials. Run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_qa() {
        let llm = OpenAI::default();
        let chain = StuffDocument::load_stuff_qa(llm);
        let input = chain
            .qa_prompt_builder()
            .documents(&[
                Document::new(format!(
                    "\nQuestion: {}\nAnswer: {}\n",
                    "Which is the favorite text editor of luis", "Nvim"
                )),
                Document::new(format!(
                    "\nQuestion: {}\nAnswer: {}\n",
                    "How old is Luis", "24"
                )),
            ])
            .question("How old is luis and whats his favorite text editor")
            .build();

        let output = chain.invoke(input).await.unwrap();

        println!("{}", output);
    }
}
183}