llm_chain/chains/
map_reduce.rs

1//! The `map_reduce` module contains the `Chain` struct, which represents a map-reduce chain.
2//!
3//! A map-reduce chain is a combination of two steps - a `map` step and a `reduce` step.
4//! The `map` step processes each input document and the `reduce` step combines the results
5//! of the `map` step into a single output.
6//!
//! The `Chain` struct pairs a `map` `Step` with a `reduce` `Step` and provides a convenient way
//! to execute map-reduce operations using a provided `Executor`.
9
10use crate::traits::ExecutorError;
11use crate::{
12    frame::Frame, output::Output, prompt::Data, serialization::StorableEntity, step::Step, tokens,
13    tokens::PromptTokensError, traits::Executor, Parameters,
14};
15use futures::future::join_all;
16use futures::future::FutureExt;
17use serde::Deserialize;
18use serde::Serialize;
19
20use thiserror::Error;
21
22/// The `MapReduceChainError` enum represents errors that can occur when executing a map-reduce chain.
23#[derive(Error, Debug)]
24pub enum MapReduceChainError {
25    /// An error relating to the operation of the Executor.
26    #[error("FormatAndExecuteError: {0}")]
27    FormatAndExecuteError(#[from] crate::frame::FormatAndExecuteError),
28    /// An error relating to tokenizing the inputs.
29    #[error("TokenizeError: {0}")]
30    TokenizeError(#[from] crate::tokens::PromptTokensError),
31    #[error("The vector of input documents was empty")]
32    InputEmpty,
33    #[error("Error templating: {0}")]
34    StringTemplate(#[from] crate::prompt::StringTemplateError),
35}
36
37/// The `Chain` struct represents a map-reduce chain, consisting of a `map` step and a `reduce` step.
38///
39/// The struct is generic over the type of the `Step` and provides methods for constructing and
40/// executing the chain using a given `Executor`.
41#[derive(Serialize, Deserialize)]
42pub struct Chain {
43    map: Step,
44    reduce: Step,
45}
46
47impl Chain {
48    /// Constructs a new `Chain` with the given `map` and `reduce` steps.
49    ///
50    /// The `new` function takes two instances of `Step` and returns a new `Chain` instance.
51    pub fn new(map: Step, reduce: Step) -> Chain {
52        Chain { map, reduce }
53    }
54
55    /// Executes the map-reduce chain using the provided `Executor`.
56    ///
57    /// The `run` function takes a vector of input documents, a base set of parameters, and a reference
58    /// to an `Executor`. It processes the input documents using the `map` step and the `reduce` step,
59    /// and returns the result as an `Option<E::Output>`.
60    ///
61    /// The function is asynchronous and must be awaited.
62    pub async fn run<E: Executor>(
63        &self,
64        documents: Vec<Parameters>,
65        base_parameters: Parameters,
66        executor: &E,
67    ) -> Result<Output, MapReduceChainError> {
68        if documents.is_empty() {
69            return Err(MapReduceChainError::InputEmpty);
70        }
71        let map_frame = Frame::new(executor, &self.map);
72        let reduce_frame = Frame::new(executor, &self.reduce);
73
74        let chunked_docs = self.chunk_documents(
75            documents.clone(),
76            base_parameters.clone(),
77            executor,
78            &self.map,
79        )?;
80
81        // Execute the `map` step for each document, combining the base parameters with each document's parameters.
82        let chunked_docs_with_base_parameters: Vec<_> = chunked_docs
83            .iter()
84            .map(|doc| base_parameters.combine(doc))
85            .collect();
86        let mapped_documents: Vec<_> = join_all(
87            chunked_docs_with_base_parameters
88                .iter()
89                .map(|doc| map_frame.format_and_execute(doc))
90                .collect::<Vec<_>>(),
91        )
92        .await;
93        let mapped_documents = mapped_documents
94            .into_iter()
95            .collect::<Result<Vec<Output>, _>>()?;
96        let mapped_documents: Vec<Result<Data<String>, ExecutorError>> = join_all(
97            mapped_documents
98                .into_iter()
99                .map(|x| x.to_immediate().map(|x| x.map(|y| y.as_content())))
100                .collect::<Vec<_>>(),
101        )
102        .await;
103        let mapped_documents: Vec<Data<String>> = mapped_documents
104            .into_iter()
105            .collect::<Result<Vec<Data<String>>, ExecutorError>>()
106            .map_err(|e| {
107                MapReduceChainError::FormatAndExecuteError(
108                    crate::frame::FormatAndExecuteError::Execute(e),
109                )
110            })?;
111
112        let mut documents = self
113            .combine_documents_up_to(executor, mapped_documents, &base_parameters)
114            .await?;
115
116        if documents.is_empty() {
117            return Err(MapReduceChainError::InputEmpty);
118        }
119
120        loop {
121            let tasks: Vec<_> = documents
122                .iter()
123                .map(|doc| base_parameters.with_text(doc))
124                .collect();
125            let futures = tasks.iter().map(|p| reduce_frame.format_and_execute(p));
126            let new_docs = join_all(futures).await;
127            let new_docs = new_docs.into_iter().collect::<Result<Vec<_>, _>>()?;
128            let new_docs = join_all(
129                new_docs
130                    .into_iter()
131                    .map(|x| x.to_immediate().map(|x| x.map(|y| y.as_content()))),
132            )
133            .await;
134            let new_docs = new_docs
135                .into_iter()
136                .collect::<Result<Vec<Data<String>>, ExecutorError>>()
137                .map_err(|e| {
138                    MapReduceChainError::FormatAndExecuteError(
139                        crate::frame::FormatAndExecuteError::Execute(e),
140                    )
141                })?;
142            let n_new_docs = new_docs.len();
143            if n_new_docs == 1 {
144                return Ok(Output::new_immediate(new_docs[0].clone()));
145            }
146            documents = self
147                .combine_documents_up_to(executor, new_docs, &base_parameters)
148                .await?;
149        }
150    }
151
152    async fn combine_documents_up_to<E: Executor>(
153        &self,
154        executor: &E,
155        mut v: Vec<Data<String>>,
156        parameters: &Parameters,
157    ) -> Result<Vec<String>, MapReduceChainError> {
158        let mut new_outputs = Vec::new();
159        while let Some(current) = v.pop() {
160            let mut current_doc = current.extract_last_body().cloned().unwrap_or_default();
161            while let Some(next) = v.last() {
162                let Some(next_doc_content) = next.extract_last_body() else {
163                    continue;
164                };
165                let mut new_doc = current_doc.clone();
166                new_doc.push('\n');
167                new_doc.push_str(next_doc_content);
168
169                let params = parameters.with_text(new_doc.clone());
170                let prompt = self.reduce.format(&params)?;
171                let count = executor.tokens_used(self.reduce.options(), &prompt)?;
172                if count.has_tokens_remaining() {
173                    current_doc = new_doc;
174                    v.pop();
175                } else {
176                    break;
177                }
178            }
179            new_outputs.push(current_doc);
180        }
181        Ok(new_outputs)
182    }
183
184    fn chunk_documents<'a, E>(
185        &self,
186        v: Vec<Parameters>,
187        base_parameters: Parameters,
188        executor: &E,
189        step: &Step,
190    ) -> Result<Vec<Parameters>, PromptTokensError>
191    where
192        E: Executor + 'a,
193    {
194        let data: Result<Vec<_>, _> = v
195            .iter()
196            .map(|x| {
197                <E as tokens::ExecutorTokenCountExt>::split_to_fit(
198                    executor,
199                    step,
200                    x,
201                    &base_parameters,
202                    None,
203                )
204            })
205            .collect();
206        let data = data?.iter().flatten().cloned().collect();
207        Ok(data)
208    }
209}
210
211/// Implements the `StorableEntity` trait for the `Chain` struct.
212///
213/// This implementation provides a method for extracting metadata from a `Chain` instance, in order to identify it
214impl StorableEntity for Chain {
215    /// Returns metadata about the Chain instance.
216    ///
217    /// The metadata is returned as a vector of tuples, where each tuple contains a key-value pair
218    /// representing a metadata field and its value.
219    ///
220    /// This method also extracts metadata from the Step instances associated with the Chain.
221    fn get_metadata() -> Vec<(String, String)> {
222        let base = vec![(
223            "chain-type".to_string(),
224            "llm-chain::chains::map_reduce::Chain".to_string(),
225        )];
226        base
227    }
228}