llm_chain/chains/
map_reduce.rs1use crate::traits::ExecutorError;
11use crate::{
12 frame::Frame, output::Output, prompt::Data, serialization::StorableEntity, step::Step, tokens,
13 tokens::PromptTokensError, traits::Executor, Parameters,
14};
15use futures::future::join_all;
16use futures::future::FutureExt;
17use serde::Deserialize;
18use serde::Serialize;
19
20use thiserror::Error;
21
/// Errors that can be returned while running a map-reduce [`Chain`].
#[derive(Error, Debug)]
pub enum MapReduceChainError {
    /// Formatting or executing a step failed. In this file it is also used to
    /// wrap an `ExecutorError` raised while resolving an output to its content
    /// (via `FormatAndExecuteError::Execute`).
    #[error("FormatAndExecuteError: {0}")]
    FormatAndExecuteError(#[from] crate::frame::FormatAndExecuteError),
    /// Token accounting failed, either while splitting the input documents to
    /// fit the map step or while measuring a candidate prompt for the reduce
    /// step.
    #[error("TokenizeError: {0}")]
    TokenizeError(#[from] crate::tokens::PromptTokensError),
    /// There were no documents to process: either the caller passed an empty
    /// vector, or nothing was left after the map outputs were combined.
    #[error("The vector of input documents was empty")]
    InputEmpty,
    /// Rendering the reduce step's prompt template failed.
    #[error("Error templating: {0}")]
    StringTemplate(#[from] crate::prompt::StringTemplateError),
}
36
/// A map-reduce chain: the `map` step is executed once per (chunk of an) input
/// document, and the `reduce` step is then executed repeatedly on the combined
/// results until a single document remains.
///
/// The chain is (de)serializable through `serde`.
#[derive(Serialize, Deserialize)]
pub struct Chain {
    /// Step executed on every input chunk.
    map: Step,
    /// Step executed on the combined outputs of the previous round.
    reduce: Step,
}
46
47impl Chain {
48 pub fn new(map: Step, reduce: Step) -> Chain {
52 Chain { map, reduce }
53 }
54
55 pub async fn run<E: Executor>(
63 &self,
64 documents: Vec<Parameters>,
65 base_parameters: Parameters,
66 executor: &E,
67 ) -> Result<Output, MapReduceChainError> {
68 if documents.is_empty() {
69 return Err(MapReduceChainError::InputEmpty);
70 }
71 let map_frame = Frame::new(executor, &self.map);
72 let reduce_frame = Frame::new(executor, &self.reduce);
73
74 let chunked_docs = self.chunk_documents(
75 documents.clone(),
76 base_parameters.clone(),
77 executor,
78 &self.map,
79 )?;
80
81 let chunked_docs_with_base_parameters: Vec<_> = chunked_docs
83 .iter()
84 .map(|doc| base_parameters.combine(doc))
85 .collect();
86 let mapped_documents: Vec<_> = join_all(
87 chunked_docs_with_base_parameters
88 .iter()
89 .map(|doc| map_frame.format_and_execute(doc))
90 .collect::<Vec<_>>(),
91 )
92 .await;
93 let mapped_documents = mapped_documents
94 .into_iter()
95 .collect::<Result<Vec<Output>, _>>()?;
96 let mapped_documents: Vec<Result<Data<String>, ExecutorError>> = join_all(
97 mapped_documents
98 .into_iter()
99 .map(|x| x.to_immediate().map(|x| x.map(|y| y.as_content())))
100 .collect::<Vec<_>>(),
101 )
102 .await;
103 let mapped_documents: Vec<Data<String>> = mapped_documents
104 .into_iter()
105 .collect::<Result<Vec<Data<String>>, ExecutorError>>()
106 .map_err(|e| {
107 MapReduceChainError::FormatAndExecuteError(
108 crate::frame::FormatAndExecuteError::Execute(e),
109 )
110 })?;
111
112 let mut documents = self
113 .combine_documents_up_to(executor, mapped_documents, &base_parameters)
114 .await?;
115
116 if documents.is_empty() {
117 return Err(MapReduceChainError::InputEmpty);
118 }
119
120 loop {
121 let tasks: Vec<_> = documents
122 .iter()
123 .map(|doc| base_parameters.with_text(doc))
124 .collect();
125 let futures = tasks.iter().map(|p| reduce_frame.format_and_execute(p));
126 let new_docs = join_all(futures).await;
127 let new_docs = new_docs.into_iter().collect::<Result<Vec<_>, _>>()?;
128 let new_docs = join_all(
129 new_docs
130 .into_iter()
131 .map(|x| x.to_immediate().map(|x| x.map(|y| y.as_content()))),
132 )
133 .await;
134 let new_docs = new_docs
135 .into_iter()
136 .collect::<Result<Vec<Data<String>>, ExecutorError>>()
137 .map_err(|e| {
138 MapReduceChainError::FormatAndExecuteError(
139 crate::frame::FormatAndExecuteError::Execute(e),
140 )
141 })?;
142 let n_new_docs = new_docs.len();
143 if n_new_docs == 1 {
144 return Ok(Output::new_immediate(new_docs[0].clone()));
145 }
146 documents = self
147 .combine_documents_up_to(executor, new_docs, &base_parameters)
148 .await?;
149 }
150 }
151
152 async fn combine_documents_up_to<E: Executor>(
153 &self,
154 executor: &E,
155 mut v: Vec<Data<String>>,
156 parameters: &Parameters,
157 ) -> Result<Vec<String>, MapReduceChainError> {
158 let mut new_outputs = Vec::new();
159 while let Some(current) = v.pop() {
160 let mut current_doc = current.extract_last_body().cloned().unwrap_or_default();
161 while let Some(next) = v.last() {
162 let Some(next_doc_content) = next.extract_last_body() else {
163 continue;
164 };
165 let mut new_doc = current_doc.clone();
166 new_doc.push('\n');
167 new_doc.push_str(next_doc_content);
168
169 let params = parameters.with_text(new_doc.clone());
170 let prompt = self.reduce.format(¶ms)?;
171 let count = executor.tokens_used(self.reduce.options(), &prompt)?;
172 if count.has_tokens_remaining() {
173 current_doc = new_doc;
174 v.pop();
175 } else {
176 break;
177 }
178 }
179 new_outputs.push(current_doc);
180 }
181 Ok(new_outputs)
182 }
183
184 fn chunk_documents<'a, E>(
185 &self,
186 v: Vec<Parameters>,
187 base_parameters: Parameters,
188 executor: &E,
189 step: &Step,
190 ) -> Result<Vec<Parameters>, PromptTokensError>
191 where
192 E: Executor + 'a,
193 {
194 let data: Result<Vec<_>, _> = v
195 .iter()
196 .map(|x| {
197 <E as tokens::ExecutorTokenCountExt>::split_to_fit(
198 executor,
199 step,
200 x,
201 &base_parameters,
202 None,
203 )
204 })
205 .collect();
206 let data = data?.iter().flatten().cloned().collect();
207 Ok(data)
208 }
209}
210
211impl StorableEntity for Chain {
215 fn get_metadata() -> Vec<(String, String)> {
222 let base = vec![(
223 "chain-type".to_string(),
224 "llm-chain::chains::map_reduce::Chain".to_string(),
225 )];
226 base
227 }
228}