langchain_rust/chain/stuff_documents/
chain.rs1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5use serde_json::Value;
6
7use crate::{
8 chain::{
9 load_stuff_qa, options::ChainCallOptions, Chain, ChainError, LLMChain, StuffQAPromptBuilder,
10 },
11 language_models::{llm::LLM, GenerateResult},
12 prompt::PromptArgs,
13 schemas::{Document, StreamData},
14};
15
/// Default key under which callers pass the (JSON-serialized) documents.
const COMBINE_DOCUMENTS_DEFAULT_INPUT_KEY: &str = "input_documents";

// NOTE(review): not referenced anywhere in this file; presumably the default
// output key for combine-documents chains — confirm it is still needed.
const COMBINE_DOCUMENTS_DEFAULT_OUTPUT_KEY: &str = "text";

/// Default prompt variable that receives the joined document contents.
const COMBINE_DOCUMENTS_DEFAULT_DOCUMENT_VARIABLE_NAME: &str = "context";

/// Default text placed between consecutive documents when they are joined.
const STUFF_DOCUMENTS_DEFAULT_SEPARATOR: &str = "\n\n";
20
21pub struct StuffDocument {
22 llm_chain: LLMChain,
23 input_key: String,
24 document_variable_name: String,
25 separator: String,
26}
27
28impl StuffDocument {
29 pub fn new(llm_chain: LLMChain) -> Self {
30 Self {
31 llm_chain,
32 input_key: COMBINE_DOCUMENTS_DEFAULT_INPUT_KEY.to_string(),
33 document_variable_name: COMBINE_DOCUMENTS_DEFAULT_DOCUMENT_VARIABLE_NAME.to_string(),
34 separator: STUFF_DOCUMENTS_DEFAULT_SEPARATOR.to_string(),
35 }
36 }
37
38 fn join_documents(&self, docs: Vec<Document>) -> String {
39 docs.iter()
40 .map(|doc| doc.page_content.clone())
41 .collect::<Vec<_>>()
42 .join(&self.separator)
43 }
44
45 pub fn qa_prompt_builder<'a>(&self) -> StuffQAPromptBuilder<'a> {
47 StuffQAPromptBuilder::new()
48 }
49
50 pub fn load_stuff_qa<L: Into<Box<dyn LLM>>>(llm: L) -> Self {
80 load_stuff_qa(llm, None)
81 }
82
83 pub fn load_stuff_qa_with_options<L: LLM + 'static>(llm: L, opt: ChainCallOptions) -> Self {
113 load_stuff_qa(llm, Some(opt))
114 }
115}
116
117#[async_trait]
118impl Chain for StuffDocument {
119 async fn call(&self, input_variables: PromptArgs) -> Result<GenerateResult, ChainError> {
120 let docs = input_variables
121 .get(&self.input_key)
122 .ok_or_else(|| ChainError::MissingInputVariable(self.input_key.clone()))?;
123
124 let documents: Vec<Document> = serde_json::from_value(docs.clone()).map_err(|e| {
125 ChainError::IncorrectInputVariable {
126 source: e,
127 expected_type: "Vec<Document>".to_string(),
128 }
129 })?;
130
131 let mut input_values = input_variables.clone();
132 input_values.insert(
133 self.document_variable_name.clone(),
134 Value::String(self.join_documents(documents)),
135 );
136
137 self.llm_chain.call(input_values).await
138 }
139
140 async fn stream(
141 &self,
142 input_variables: PromptArgs,
143 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
144 {
145 let docs = input_variables
146 .get(&self.input_key)
147 .ok_or_else(|| ChainError::MissingInputVariable(self.input_key.clone()))?;
148
149 let documents: Vec<Document> = serde_json::from_value(docs.clone()).map_err(|e| {
150 ChainError::IncorrectInputVariable {
151 source: e,
152 expected_type: "Vec<Document>".to_string(),
153 }
154 })?;
155
156 let mut input_values = input_variables.clone();
157 input_values.insert(
158 self.document_variable_name.clone(),
159 Value::String(self.join_documents(documents)),
160 );
161 self.llm_chain.stream(input_values).await
162 }
163
164 fn get_input_keys(&self) -> Vec<String> {
165 vec![self.input_key.clone()]
166 }
167}