langchain_rust/chain/stuff_documents/
chain.rs

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;
use serde_json::Value;

use crate::{
    chain::{
        load_stuff_qa, options::ChainCallOptions, Chain, ChainError, LLMChain, StuffQAPromptBuilder,
    },
    language_models::{llm::LLM, GenerateResult},
    prompt::PromptArgs,
    schemas::{Document, StreamData},
};

const COMBINE_DOCUMENTS_DEFAULT_INPUT_KEY: &str = "input_documents";
const COMBINE_DOCUMENTS_DEFAULT_OUTPUT_KEY: &str = "text";
const COMBINE_DOCUMENTS_DEFAULT_DOCUMENT_VARIABLE_NAME: &str = "context";
const STUFF_DOCUMENTS_DEFAULT_SEPARATOR: &str = "\n\n";

pub struct StuffDocument {
    llm_chain: LLMChain,
    input_key: String,
    document_variable_name: String,
    separator: String,
}
impl StuffDocument {
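    /// Creates a `StuffDocument` that wraps an existing `LLMChain`, using the
    /// default input key (`input_documents`), document variable name (`context`),
    /// and document separator (`"\n\n"`).
    ///
    /// A minimal sketch, assuming you already have an `LLMChain` whose prompt
    /// expects a `context` input variable (how that chain is built depends on
    /// your setup):
    ///
    /// ```rust,ignore
    /// let chain = StuffDocument::new(llm_chain);
    /// ```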
    pub fn new(llm_chain: LLMChain) -> Self {
        Self {
            llm_chain,
            input_key: COMBINE_DOCUMENTS_DEFAULT_INPUT_KEY.to_string(),
            document_variable_name: COMBINE_DOCUMENTS_DEFAULT_DOCUMENT_VARIABLE_NAME.to_string(),
            separator: STUFF_DOCUMENTS_DEFAULT_SEPARATOR.to_string(),
        }
    }

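    /// Concatenates the `page_content` of each document, joined with the
    /// configured separator (`"\n\n"` by default).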
    fn join_documents(&self, docs: Vec<Document>) -> String {
        docs.iter()
            .map(|doc| doc.page_content.clone())
            .collect::<Vec<_>>()
            .join(&self.separator)
    }

    /// Only use this if you use the default QA prompt.
    pub fn qa_prompt_builder<'a>(&self) -> StuffQAPromptBuilder<'a> {
        StuffQAPromptBuilder::new()
    }

    /// `load_stuff_qa` returns an instance of `StuffDocument`
    /// with a prompt designed for question answering.
    ///
    /// # Example
    /// ```rust,ignore
    ///
    /// let llm = OpenAI::default();
    /// let chain = StuffDocument::load_stuff_qa(llm);
    ///
    /// let input = chain
    ///     .qa_prompt_builder()
    ///     .documents(&[
    ///         Document::new(format!(
    ///             "\nQuestion: {}\nAnswer: {}\n",
    ///             "Which is the favorite text editor of luis", "Nvim"
    ///         )),
    ///         Document::new(format!(
    ///             "\nQuestion: {}\nAnswer: {}\n",
    ///             "How old is Luis", "24"
    ///         )),
    ///     ])
    ///     .question("How old is luis and whats his favorite text editor")
    ///     .build();
    ///
    /// let output = chain.invoke(input).await.unwrap();
    ///
    /// println!("{}", output);
    /// ```
    ///
    pub fn load_stuff_qa<L: Into<Box<dyn LLM>>>(llm: L) -> Self {
        load_stuff_qa(llm, None)
    }

    /// `load_stuff_qa_with_options` returns an instance of `StuffDocument`
    /// with a prompt designed for question answering.
    ///
    /// # Example
    /// ```rust,ignore
    ///
    /// let llm = OpenAI::default();
    /// let chain = StuffDocument::load_stuff_qa_with_options(llm, ChainCallOptions::default());
    ///
    /// let input = chain
    ///     .qa_prompt_builder()
    ///     .documents(&[
    ///         Document::new(format!(
    ///             "\nQuestion: {}\nAnswer: {}\n",
    ///             "Which is the favorite text editor of luis", "Nvim"
    ///         )),
    ///         Document::new(format!(
    ///             "\nQuestion: {}\nAnswer: {}\n",
    ///             "How old is Luis", "24"
    ///         )),
    ///     ])
    ///     .question("How old is luis and whats his favorite text editor")
    ///     .build();
    ///
    /// let output = chain.invoke(input).await.unwrap();
    ///
    /// println!("{}", output);
    /// ```
    ///
    pub fn load_stuff_qa_with_options<L: LLM + 'static>(llm: L, opt: ChainCallOptions) -> Self {
        load_stuff_qa(llm, Some(opt))
    }
}

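// `Chain` implementation: `call` and `stream` both deserialize the documents from
// the configured input key (`input_documents` by default), join their page contents
// into one string, expose it to the prompt under `document_variable_name`
// ("context" by default), and then delegate to the inner `LLMChain`.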
#[async_trait]
impl Chain for StuffDocument {
    async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
        let docs = input_variables
            .get(&self.input_key)
            .ok_or_else(|| ChainError::MissingInputVariable(self.input_key.clone()))?;

        let documents: Vec<Document> = serde_json::from_value(docs.clone()).map_err(|e| {
            ChainError::IncorrectInputVariable {
                source: e,
                expected_type: "Vec<Document>".to_string(),
            }
        })?;

        let mut input_values = input_variables.clone();
        input_values.insert(
            self.document_variable_name.clone(),
            Value::String(self.join_documents(documents)),
        );

        self.llm_chain.call(input_values).await
    }

    async fn stream(
        &self,
        input_variables: PromptArgs,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
    {
        let docs = input_variables
            .get(&self.input_key)
            .ok_or_else(|| ChainError::MissingInputVariable(self.input_key.clone()))?;

        let documents: Vec<Document> = serde_json::from_value(docs.clone()).map_err(|e| {
            ChainError::IncorrectInputVariable {
                source: e,
                expected_type: "Vec<Document>".to_string(),
            }
        })?;

        let mut input_values = input_variables.clone();
        input_values.insert(
            self.document_variable_name.clone(),
            Value::String(self.join_documents(documents)),
        );
        self.llm_chain.stream(input_values).await
    }

    fn get_input_keys(&self) -> Vec<String> {
        vec![self.input_key.clone()]
    }
}