use crate::output::Output;
use crate::traits::ExecutorError;
use crate::{
frame::Frame,
tokens::ExecutorTokenCountExt,
tokens::PromptTokensError,
traits::{Executor, Step},
Parameters,
};
use futures::future::join_all;
#[cfg(feature = "serialization")]
use serde::{
de::{MapAccess, Visitor},
Deserialize,
};
#[cfg(feature = "serialization")]
use crate::serialization::StorableEntity;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum MapReduceChainError<Err: ExecutorError> {
#[error("ExecutorError: {0}")]
ExecutorError(#[from] Err),
#[error("TokenizeError: {0}")]
TokenizeError(#[from] crate::tokens::PromptTokensError),
#[error("The vector of input documents was empty")]
InputEmpty,
}
/// A map-reduce chain: the `map` step is applied to every (chunked) input
/// document, and the `reduce` step folds the mapped outputs together until a
/// single output remains.
// Note: no `S: Step` bound on the struct itself — bounds belong on the impl
// blocks that need them (all impls in this module still require `S: Step`).
// This is a backward-compatible widening of the type.
pub struct Chain<S> {
    // Step executed once per chunked input document.
    map: S,
    // Step used to repeatedly combine mapped outputs.
    reduce: S,
}
impl<S: Step> Chain<S> {
    /// Creates a new map-reduce chain from a `map` step (applied to every
    /// document chunk) and a `reduce` step (used to fold outputs together).
    pub fn new(map: S, reduce: S) -> Chain<S> {
        Chain { map, reduce }
    }

    /// Runs the map-reduce chain.
    ///
    /// The input documents are first split into chunks that fit the
    /// executor's context window, the `map` step is executed over all chunks
    /// concurrently, and the textual outputs are then repeatedly combined and
    /// passed through the `reduce` step until a single output remains.
    ///
    /// # Errors
    /// Returns [`MapReduceChainError::InputEmpty`] if `documents` is empty or
    /// if mapping produced no textual output; executor and tokenization
    /// failures are forwarded.
    pub async fn run<E>(
        &self,
        documents: Vec<Parameters>,
        base_parameters: Parameters,
        executor: &E,
    ) -> Result<E::Output, MapReduceChainError<E::Error>>
    where
        E: Executor<Step = S>,
    {
        if documents.is_empty() {
            return Err(MapReduceChainError::InputEmpty);
        }
        let map_frame = Frame::new(executor, &self.map);
        let reduce_frame = Frame::new(executor, &self.reduce);
        // `documents` is not used again below, so move it instead of cloning
        // the whole vector.
        let chunked_docs = self.chunk_documents(documents, executor, &self.map)?;
        let chunked_docs_with_base_parameters: Vec<_> = chunked_docs
            .iter()
            .map(|doc| base_parameters.combine(doc))
            .collect();
        // Run the map step over every chunk concurrently.
        let futures: Vec<_> = chunked_docs_with_base_parameters
            .iter()
            .map(|doc| map_frame.format_and_execute(doc))
            .collect();
        let mapped_documents = join_all(futures).await;
        let mapped_documents = mapped_documents.into_iter().collect::<Result<_, _>>()?;
        let mut documents = self
            .combine_documents_up_to(executor, mapped_documents, &base_parameters)
            .await?;
        if documents.is_empty() {
            return Err(MapReduceChainError::InputEmpty);
        }
        // Reduce repeatedly until a single document remains.
        loop {
            let tasks: Vec<_> = documents
                .iter()
                .map(|doc| base_parameters.with_text(doc))
                .collect();
            let futures = tasks.iter().map(|p| reduce_frame.format_and_execute(p));
            let new_docs = join_all(futures).await;
            let mut new_docs = new_docs.into_iter().collect::<Result<Vec<_>, _>>()?;
            if new_docs.len() == 1 {
                // Take ownership of the single result rather than cloning it.
                return Ok(new_docs.pop().expect("length checked above"));
            }
            documents = self
                .combine_documents_up_to(executor, new_docs, &base_parameters)
                .await?;
        }
    }

    /// Greedily concatenates the textual outputs in `v` (newline-separated)
    /// into as few documents as possible, such that each combined document
    /// still leaves the `reduce` step with tokens remaining.
    ///
    /// Outputs without a primary textual representation are discarded. `v` is
    /// consumed back-to-front, so combined documents come out in reverse
    /// order relative to the input; the reduce loop does not depend on order.
    async fn combine_documents_up_to<E>(
        &self,
        executor: &E,
        mut v: Vec<<E as Executor>::Output>,
        parameters: &Parameters,
    ) -> Result<Vec<String>, MapReduceChainError<E::Error>>
    where
        E: Executor<Step = S>,
    {
        let mut new_outputs = Vec::new();
        while let Some(current) = v.pop() {
            let mut current_doc = current.primary_textual_output().await.unwrap_or_default();
            while let Some(next) = v.last() {
                let next_doc = match next.primary_textual_output().await {
                    Some(text) => text,
                    None => {
                        // Discard outputs with no text. A bare `continue`
                        // here would re-inspect the same element forever.
                        v.pop();
                        continue;
                    }
                };
                let mut new_doc = current_doc.clone();
                new_doc.push('\n');
                new_doc.push_str(&next_doc);
                let params = parameters.with_text(new_doc.clone());
                let count = executor.tokens_used(&self.reduce, &params)?;
                if count.has_tokens_remaining() {
                    // Still fits: commit the merge and consume the element.
                    current_doc = new_doc;
                    v.pop();
                } else {
                    break;
                }
            }
            new_outputs.push(current_doc);
        }
        Ok(new_outputs)
    }

    /// Splits every document into chunks that fit the context window of
    /// `step` on the given executor.
    fn chunk_documents<E>(
        &self,
        v: Vec<Parameters>,
        executor: &E,
        step: &S,
    ) -> Result<Vec<Parameters>, PromptTokensError>
    where
        E: Executor<Step = S>,
    {
        let chunks: Result<Vec<_>, _> = v.iter().map(|x| executor.split_to_fit(step, x)).collect();
        // Flatten the per-document chunk lists; consuming the vectors avoids
        // cloning every chunk.
        Ok(chunks?.into_iter().flatten().collect())
    }
}
#[cfg(feature = "serialization")]
impl<'de, S: Step + Deserialize<'de>> Deserialize<'de> for Chain<S> {
    /// Deserializes a `Chain` from a struct/map with fields `map` and
    /// `reduce`. Duplicate fields are rejected, missing fields are reported,
    /// and unknown fields are skipped.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ChainVisitor<S>(std::marker::PhantomData<S>);
        impl<'de, S: Step + Deserialize<'de>> Visitor<'de> for ChainVisitor<S> {
            type Value = Chain<S>;
            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("an object with fields `map` and `reduce`")
            }
            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut map_value: Option<S> = None;
                let mut reduce_value: Option<S> = None;
                // Deserialize keys as owned `String`s: borrowed `&str` keys
                // fail for deserializers that cannot hand out borrowed data
                // (e.g. readers, or JSON keys containing escapes).
                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "map" => {
                            if map_value.is_some() {
                                return Err(serde::de::Error::duplicate_field("map"));
                            }
                            map_value = Some(map.next_value()?);
                        }
                        "reduce" => {
                            if reduce_value.is_some() {
                                return Err(serde::de::Error::duplicate_field("reduce"));
                            }
                            reduce_value = Some(map.next_value()?);
                        }
                        _ => {
                            // Every `next_key` must be paired with a
                            // `next_value`, even for ignored fields —
                            // otherwise the deserializer stream gets out of
                            // sync.
                            let _ = map.next_value::<serde::de::IgnoredAny>()?;
                        }
                    }
                }
                let map = map_value.ok_or_else(|| serde::de::Error::missing_field("map"))?;
                let reduce =
                    reduce_value.ok_or_else(|| serde::de::Error::missing_field("reduce"))?;
                Ok(Chain { map, reduce })
            }
        }
        deserializer.deserialize_struct(
            "Chain",
            &["map", "reduce"],
            ChainVisitor(std::marker::PhantomData),
        )
    }
}
#[cfg(feature = "serialization")]
impl<S> StorableEntity for Chain<S>
where
    S: Step + StorableEntity,
{
    /// Returns the metadata for this chain type, followed by the metadata of
    /// the underlying step type.
    fn get_metadata() -> Vec<(String, String)> {
        let mut entries = vec![(
            "chain-type".to_string(),
            "llm-chain::chains::map_reduce::Chain".to_string(),
        )];
        // Inherit whatever metadata the step type declares.
        entries.extend(S::get_metadata());
        entries
    }
}