// langchain_rust/chain/question_answering.rs

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

use crate::{
    language_models::{llm::LLM, GenerateResult},
    prompt::PromptArgs,
    prompt_args,
    schemas::{messages::Message, Document, StreamData},
    template_jinja2,
};

use super::{
    options::ChainCallOptions, Chain, ChainError, LLMChain, LLMChainBuilder, StuffDocument,
};

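/// Prompt used to turn a follow-up question plus the prior conversation into a
/// single standalone question, phrased in the question's original language.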
const DEFAULT_CONDENSE_QUESTION_TEMPLATE: &str = r#"Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{{chat_history}}
Follow Up Input: {{question}}
Standalone question:"#;

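/// Builder for the [`PromptArgs`] consumed by [`CondenseQuestionGeneratorChain`]:
/// the chat history (serialized to a single string) and the follow-up question.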
pub struct CondenseQuestionPromptBuilder {
    chat_history: String,
    question: String,
}

impl CondenseQuestionPromptBuilder {
    pub fn new() -> Self {
        Self {
            chat_history: "".to_string(),
            question: "".to_string(),
        }
    }

    pub fn question<S: Into<String>>(mut self, question: S) -> Self {
        self.question = question.into();
        self
    }

    pub fn chat_history(mut self, chat_history: &[Message]) -> Self {
        self.chat_history = Message::messages_to_string(chat_history);
        self
    }

    pub fn build(self) -> PromptArgs {
        prompt_args! {
            "chat_history" => self.chat_history,
            "question" => self.question
        }
    }
}

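/// Chain that condenses a conversation and a follow-up question into one
/// standalone question by running [`DEFAULT_CONDENSE_QUESTION_TEMPLATE`]
/// through an inner [`LLMChain`].
///
/// A minimal usage sketch (assuming a configured LLM such as the `OpenAI`
/// client used in the test module below, and some `messages: Vec<Message>`):
///
/// ```ignore
/// let chain = CondenseQuestionGeneratorChain::new(OpenAI::default());
/// let input = chain
///     .prompt_builder()
///     .chat_history(&messages)
///     .question("And how old is he?")
///     .build();
/// // `standalone` is a `GenerateResult` holding the rewritten question.
/// let standalone = chain.call(input).await?;
/// ```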
pub struct CondenseQuestionGeneratorChain {
    chain: LLMChain,
}

impl CondenseQuestionGeneratorChain {
    pub fn new<L: Into<Box<dyn LLM>>>(llm: L) -> Self {
        let condense_question_prompt_template =
            template_jinja2!(DEFAULT_CONDENSE_QUESTION_TEMPLATE, "chat_history", "question");

        let chain = LLMChainBuilder::new()
            .llm(llm)
            .prompt(condense_question_prompt_template)
            .build()
            .unwrap();

        Self { chain }
    }

    pub fn prompt_builder(&self) -> CondenseQuestionPromptBuilder {
        CondenseQuestionPromptBuilder::new()
    }
}

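// `Chain` is implemented by delegating both `call` and `stream` straight to the
// inner `LLMChain`, so this chain can be used anywhere a `Chain` is expected.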
#[async_trait]
impl Chain for CondenseQuestionGeneratorChain {
    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
        self.chain.call(input_variables).await
    }

    async fn stream(
        &self,
        input_variables: PromptArgs,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
    {
        self.chain.stream(input_variables).await
    }
}

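/// Prompt for "stuff"-style question answering: the retrieved documents are
/// stuffed into `{{context}}`, and the model is instructed to say it doesn't
/// know rather than invent an answer.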
const DEFAULT_STUFF_QA_TEMPLATE: &str = r#"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{{context}}

Question:{{question}}
Helpful Answer:
"#;

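/// Builder for the [`PromptArgs`] of a stuff-QA call: the documents to use as
/// context and the question to answer. Documents are borrowed, not cloned.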
pub struct StuffQAPromptBuilder<'a> {
    input_documents: Vec<&'a Document>,
    question: String,
}

impl<'a> StuffQAPromptBuilder<'a> {
    pub fn new() -> Self {
        Self {
            input_documents: vec![],
            question: "".to_string(),
        }
    }

    pub fn documents(mut self, documents: &'a [Document]) -> Self {
        self.input_documents = documents.iter().collect();
        self
    }

    pub fn question<S: Into<String>>(mut self, question: S) -> Self {
        self.question = question.into();
        self
    }

    pub fn build(self) -> PromptArgs {
        prompt_args! {
            "input_documents" => self.input_documents,
            "question" => self.question
        }
    }
}

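/// Wires [`DEFAULT_STUFF_QA_TEMPLATE`] into an [`LLMChain`] and wraps it in a
/// [`StuffDocument`] chain. Reached by users through
/// `StuffDocument::load_stuff_qa`, as in the test below.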
pub(crate) fn load_stuff_qa<L: Into<Box<dyn LLM>>>(
    llm: L,
    options: Option<ChainCallOptions>,
) -> StuffDocument {
    let default_qa_prompt_template =
        template_jinja2!(DEFAULT_STUFF_QA_TEMPLATE, "context", "question");

    let llm_chain = LLMChainBuilder::new()
        .prompt(default_qa_prompt_template)
        .options(options.unwrap_or_default())
        .llm(llm)
        .build()
        .unwrap();

    StuffDocument::new(llm_chain)
}

#[cfg(test)]
mod tests {
    use crate::{
        chain::{Chain, StuffDocument},
        llm::openai::OpenAI,
        schemas::Document,
    };

    #[tokio::test]
    #[ignore]
    async fn test_qa() {
        let llm = OpenAI::default();
        let chain = StuffDocument::load_stuff_qa(llm);
        let input = chain
            .qa_prompt_builder()
            .documents(&[
                Document::new(format!(
                    "\nQuestion: {}\nAnswer: {}\n",
                    "Which is the favorite text editor of Luis", "Nvim"
                )),
                Document::new(format!(
                    "\nQuestion: {}\nAnswer: {}\n",
                    "How old is Luis", "24"
                )),
            ])
            .question("How old is Luis and what's his favorite text editor")
            .build();

        let output = chain.invoke(input).await.unwrap();

        println!("{}", output);
    }
}