use std::sync::Arc;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message};
/// Default prompt for the map phase. The `{text}` placeholder is replaced
/// with a single document's `page_content` before the LLM call.
const DEFAULT_MAP_PROMPT: &str = "Summarize the following text:\n\n{text}\n\nSummary:";
/// Default prompt for the reduce phase. The `{summaries}` placeholder is
/// replaced with all map-phase outputs joined by blank lines.
const DEFAULT_REDUCE_PROMPT: &str =
    "Combine the following summaries into a coherent final answer:\n\n{summaries}\n\nFinal answer:";
/// A map-reduce chain over documents: each document is summarized
/// independently (map), then the summaries are combined by one final
/// LLM call (reduce).
pub struct MapReduceChain {
    /// Chat model used for both the map and reduce calls.
    llm: Arc<dyn BaseChatModel>,
    /// Map-phase prompt template; expected to contain `{text}`.
    map_prompt: String,
    /// Reduce-phase prompt template; expected to contain `{summaries}`.
    reduce_prompt: String,
    /// Configured concurrency cap for the map phase.
    /// NOTE(review): stored via the builder but not consulted by `call`/`map`,
    /// which currently run sequentially — confirm intended behavior.
    max_concurrency: usize,
}
impl MapReduceChain {
    /// Creates a chain with the default map/reduce prompts and a
    /// `max_concurrency` of 5.
    pub fn new(llm: Arc<dyn BaseChatModel>) -> Self {
        Self {
            llm,
            map_prompt: DEFAULT_MAP_PROMPT.to_string(),
            reduce_prompt: DEFAULT_REDUCE_PROMPT.to_string(),
            max_concurrency: 5,
        }
    }

    /// Replaces the map prompt. The template should contain a `{text}`
    /// placeholder; without one the document content is never injected.
    pub fn with_map_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.map_prompt = prompt.into();
        self
    }

    /// Replaces the reduce prompt. The template should contain a
    /// `{summaries}` placeholder.
    pub fn with_reduce_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.reduce_prompt = prompt.into();
        self
    }

    /// Sets the concurrency cap for the map phase.
    /// NOTE(review): the value is stored but the execution paths below are
    /// sequential and do not read it yet — confirm whether concurrent
    /// execution is planned.
    pub fn with_max_concurrency(mut self, max_concurrency: usize) -> Self {
        self.max_concurrency = max_concurrency;
        self
    }

    /// Interpolates a document's text into the map prompt template.
    fn format_map_prompt(&self, text: &str) -> String {
        self.map_prompt.replace("{text}", text)
    }

    /// Interpolates the joined summaries into the reduce prompt template.
    fn format_reduce_prompt(&self, summaries: &str) -> String {
        self.reduce_prompt.replace("{summaries}", summaries)
    }

    /// Runs the full map-reduce pipeline: one LLM call per document, then a
    /// final LLM call that combines all per-document outputs.
    ///
    /// # Errors
    /// Propagates any error from the underlying model's `invoke_messages`.
    pub async fn call(&self, documents: &[Document]) -> Result<String> {
        // Delegate the per-document phase to `map` instead of duplicating
        // its loop (the previous implementation repeated it verbatim).
        let map_results = self.map(documents).await?;
        let combined = map_results.join("\n\n");
        let reduce_prompt = self.format_reduce_prompt(&combined);
        let messages = vec![Message::Human(HumanMessage::new(&reduce_prompt))];
        let ai_msg = self.llm.invoke_messages(&messages, None).await?;
        Ok(ai_msg.base.content.text())
    }

    /// Runs only the map phase, returning one model output per input
    /// document, in input order. Documents are processed sequentially.
    ///
    /// # Errors
    /// Propagates any error from the underlying model's `invoke_messages`.
    pub async fn map(&self, documents: &[Document]) -> Result<Vec<String>> {
        let mut results = Vec::with_capacity(documents.len());
        for doc in documents {
            let prompt = self.format_map_prompt(&doc.page_content);
            let messages = vec![Message::Human(HumanMessage::new(&prompt))];
            let ai_msg = self.llm.invoke_messages(&messages, None).await?;
            results.push(ai_msg.base.content.text());
        }
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::language_models::fake::{FakeListChatModel, ParrotFakeChatModel};

    /// Builds a fake model that replays `responses` in order.
    fn fake_model(responses: Vec<&str>) -> Arc<dyn BaseChatModel> {
        let scripted: Vec<String> = responses.into_iter().map(String::from).collect();
        Arc::new(FakeListChatModel::new(scripted))
    }

    /// Convenience constructor for a document with the given content.
    fn make_doc(content: &str) -> Document {
        Document::new(content)
    }

    #[tokio::test]
    async fn test_map_phase_produces_per_doc_summaries() {
        // The third scripted response would belong to a reduce call,
        // which `map` must never make.
        let model = fake_model(vec!["Summary of doc 1", "Summary of doc 2", "Final"]);
        let inputs = vec![
            make_doc("Document one content"),
            make_doc("Document two content"),
        ];
        let summaries = MapReduceChain::new(model).map(&inputs).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0], "Summary of doc 1");
        assert_eq!(summaries[1], "Summary of doc 2");
    }

    #[tokio::test]
    async fn test_reduce_phase_combines_summaries() {
        // A parrot model suffices: only prompt formatting is under test.
        let parrot: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let formatted = MapReduceChain::new(parrot)
            .format_reduce_prompt("Summary A\n\nSummary B");
        assert!(formatted.contains("Summary A"));
        assert!(formatted.contains("Summary B"));
        assert!(formatted.contains("Combine the following summaries"));
    }

    #[tokio::test]
    async fn test_end_to_end_call() {
        // Two map responses, then the reduce response.
        let model = fake_model(vec!["mapped-1", "mapped-2", "final-answer"]);
        let inputs = vec![make_doc("First doc"), make_doc("Second doc")];
        let answer = MapReduceChain::new(model).call(&inputs).await.unwrap();
        assert_eq!(answer, "final-answer");
    }

    #[tokio::test]
    async fn test_custom_prompts() {
        // The parrot echoes the prompt, so the final output must carry the
        // custom reduce template's markers.
        let parrot: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let result = MapReduceChain::new(parrot)
            .with_map_prompt("EXTRACT: {text} END")
            .with_reduce_prompt("MERGE: {summaries} DONE")
            .call(&[make_doc("hello world")])
            .await
            .unwrap();
        assert!(result.contains("MERGE:"));
        assert!(result.contains("DONE"));
    }

    #[tokio::test]
    async fn test_single_document() {
        let model = fake_model(vec!["single-map", "single-reduce"]);
        let result = MapReduceChain::new(model)
            .call(&[make_doc("Only document")])
            .await
            .unwrap();
        assert_eq!(result, "single-reduce");
    }

    #[tokio::test]
    async fn test_empty_documents() {
        // With no documents, only the reduce call fires.
        let model = fake_model(vec!["reduce-of-nothing"]);
        let empty: Vec<Document> = Vec::new();
        let result = MapReduceChain::new(model).call(&empty).await.unwrap();
        assert_eq!(result, "reduce-of-nothing");
    }

    #[tokio::test]
    async fn test_max_concurrency_builder() {
        let model = fake_model(vec!["x"]);
        let configured = MapReduceChain::new(model).with_max_concurrency(10);
        assert_eq!(configured.max_concurrency, 10);
    }
}