use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{json, Value};
use cognis_core::documents::Document;
use cognis_core::error::{CognisError, Result};
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message};
use cognis_core::runnables::base::Runnable;
use cognis_core::runnables::config::RunnableConfig;
/// Prompt for summarizing a single blob of text. `{text}` is replaced with the
/// combined document contents.
const DEFAULT_STUFF_PROMPT: &str =
    "Write a concise summary of the following:\n\n\"{text}\"\n\nCONCISE SUMMARY:";
/// Map-phase prompt for map-reduce summarization. Identical wording to the
/// stuff prompt — aliased rather than duplicated so the texts cannot drift.
const DEFAULT_MAP_SUMMARIZE_PROMPT: &str = DEFAULT_STUFF_PROMPT;
/// Reduce-phase prompt; `{summaries}` is replaced with the joined map outputs.
const DEFAULT_REDUCE_SUMMARIZE_PROMPT: &str =
    "The following is a set of summaries:\n\n{summaries}\n\n\
Take these and distill them into a final, consolidated summary of the main themes.\n\n\
FINAL SUMMARY:";
/// Prompt for the first document in refine summarization (same wording as the
/// stuff prompt; aliased to avoid drift).
const DEFAULT_INITIAL_SUMMARIZE_PROMPT: &str = DEFAULT_STUFF_PROMPT;
/// Refine-phase prompt; `{existing_summary}` is the summary so far and
/// `{text}` is the next document's content.
const DEFAULT_REFINE_SUMMARIZE_PROMPT: &str = "Your job is to produce a final summary.\n\
We have provided an existing summary up to a certain point:\n\n\
{existing_summary}\n\n\
We have the opportunity to refine the existing summary (only if needed) \
with some more context below.\n\n\
\"{text}\"\n\n\
Given the new context, refine the original summary. \
If the context is not useful, return the original summary.\n\n\
REFINED SUMMARY:";
/// Summarization chain that "stuffs" every document into one prompt and makes
/// a single model call. Best for document sets that fit in the model context.
pub struct StuffSummarizationChain {
    /// Chat model used to produce the summary.
    llm: Arc<dyn BaseChatModel>,
    /// Prompt template; must contain the `{text}` placeholder.
    prompt: String,
    /// Separator inserted between document contents when they are combined.
    document_separator: String,
}
impl StuffSummarizationChain {
    /// Creates a chain with the default prompt and a blank-line separator.
    pub fn new(llm: Arc<dyn BaseChatModel>) -> Self {
        Self {
            llm,
            prompt: DEFAULT_STUFF_PROMPT.to_string(),
            document_separator: "\n\n".to_string(),
        }
    }

    /// Replaces the prompt template; it must contain `{text}`.
    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompt = prompt.into();
        self
    }

    /// Replaces the separator placed between document contents.
    pub fn with_document_separator(mut self, sep: impl Into<String>) -> Self {
        self.document_separator = sep.into();
        self
    }

    /// Substitutes the combined text into the prompt template.
    fn format_prompt(&self, text: &str) -> String {
        self.prompt.replace("{text}", text)
    }

    /// Concatenates all document contents, formats the prompt, and returns the
    /// model's summary text. An empty `documents` slice simply yields an empty
    /// `{text}` substitution.
    pub async fn call(&self, documents: &[Document]) -> Result<String> {
        let mut combined = String::new();
        for (idx, doc) in documents.iter().enumerate() {
            if idx > 0 {
                combined.push_str(&self.document_separator);
            }
            combined.push_str(&doc.page_content);
        }
        let prompt = self.format_prompt(&combined);
        let request = vec![Message::Human(HumanMessage::new(&prompt))];
        let response = self.llm.invoke_messages(&request, None).await?;
        Ok(response.base.content.text())
    }
}
#[async_trait]
impl Runnable for StuffSummarizationChain {
    fn name(&self) -> &str {
        "StuffSummarizationChain"
    }

    /// Accepts `{"documents": [...]}` and returns `{"summary": "..."}`.
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        let documents = parse_documents_from_input(&input)?;
        Ok(json!({ "summary": self.call(&documents).await? }))
    }
}
/// Summarization chain that summarizes each document independently (map) and
/// then combines the per-document summaries into one (reduce).
pub struct MapReduceSummarizationChain {
    /// Chat model used for both the map and reduce calls.
    llm: Arc<dyn BaseChatModel>,
    /// Map-phase prompt template; must contain the `{text}` placeholder.
    map_prompt: String,
    /// Reduce-phase prompt template; must contain the `{summaries}` placeholder.
    reduce_prompt: String,
    /// Maximum combined length (in bytes) fed into a single reduce call;
    /// 0 disables chunked (recursive) reduction.
    max_reduce_length: usize,
}
impl MapReduceSummarizationChain {
    /// Creates a chain with the default map/reduce prompts and no reduce-length
    /// limit (a single reduce call regardless of combined summary size).
    pub fn new(llm: Arc<dyn BaseChatModel>) -> Self {
        Self {
            llm,
            map_prompt: DEFAULT_MAP_SUMMARIZE_PROMPT.to_string(),
            reduce_prompt: DEFAULT_REDUCE_SUMMARIZE_PROMPT.to_string(),
            max_reduce_length: 0,
        }
    }

    /// Replaces the map-phase prompt; it must contain `{text}`.
    pub fn with_map_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.map_prompt = prompt.into();
        self
    }

    /// Replaces the reduce-phase prompt; it must contain `{summaries}`.
    pub fn with_reduce_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.reduce_prompt = prompt.into();
        self
    }

    /// Sets the maximum combined length for a single reduce call. When the
    /// joined summaries exceed this, they are reduced in chunks recursively.
    /// A value of 0 (the default) disables chunking.
    pub fn with_max_reduce_length(mut self, max_len: usize) -> Self {
        self.max_reduce_length = max_len;
        self
    }

    fn format_map_prompt(&self, text: &str) -> String {
        self.map_prompt.replace("{text}", text)
    }

    fn format_reduce_prompt(&self, summaries: &str) -> String {
        self.reduce_prompt.replace("{summaries}", summaries)
    }

    /// Runs the full map-reduce pipeline and returns the final summary.
    ///
    /// Delegates to [`Self::map`] followed by `reduce` instead of duplicating
    /// the map loop (the original body was a verbatim copy of `map`).
    pub async fn call(&self, documents: &[Document]) -> Result<String> {
        let map_results = self.map(documents).await?;
        self.reduce(&map_results).await
    }

    /// Map phase: summarizes each document with its own model call, in order.
    pub async fn map(&self, documents: &[Document]) -> Result<Vec<String>> {
        let mut results = Vec::with_capacity(documents.len());
        for doc in documents {
            let prompt = self.format_map_prompt(&doc.page_content);
            let messages = vec![Message::Human(HumanMessage::new(&prompt))];
            let ai_msg = self.llm.invoke_messages(&messages, None).await?;
            results.push(ai_msg.base.content.text());
        }
        Ok(results)
    }

    /// Reduce phase: combines summaries into one. When `max_reduce_length` is
    /// set and the joined summaries exceed it, reduces chunk-by-chunk and
    /// recurses on the intermediate results.
    ///
    /// Written as a boxed-future fn (not `async fn`) because async fns cannot
    /// be directly recursive.
    fn reduce<'a>(
        &'a self,
        summaries: &'a [String],
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send + 'a>> {
        Box::pin(async move {
            let combined = summaries.join("\n\n");
            if self.max_reduce_length > 0 && combined.len() > self.max_reduce_length {
                let chunks = split_into_chunks(summaries, self.max_reduce_length);
                let mut intermediate = Vec::with_capacity(chunks.len());
                for chunk in &chunks {
                    let joined = chunk.join("\n\n");
                    let prompt = self.format_reduce_prompt(&joined);
                    let messages = vec![Message::Human(HumanMessage::new(&prompt))];
                    let ai_msg = self.llm.invoke_messages(&messages, None).await?;
                    intermediate.push(ai_msg.base.content.text());
                }
                // Recurse until the combined summaries fit in one reduce call.
                return self.reduce(&intermediate).await;
            }
            let prompt = self.format_reduce_prompt(&combined);
            let messages = vec![Message::Human(HumanMessage::new(&prompt))];
            let ai_msg = self.llm.invoke_messages(&messages, None).await?;
            Ok(ai_msg.base.content.text())
        })
    }
}
#[async_trait]
impl Runnable for MapReduceSummarizationChain {
    fn name(&self) -> &str {
        "MapReduceSummarizationChain"
    }

    /// Accepts `{"documents": [...]}` and returns `{"summary": "..."}`.
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        let documents = parse_documents_from_input(&input)?;
        Ok(json!({ "summary": self.call(&documents).await? }))
    }
}
/// Summarization chain that summarizes the first document and then iteratively
/// refines that summary with each subsequent document.
pub struct RefineSummarizationChain {
    /// Chat model used for the initial summary and every refinement step.
    llm: Arc<dyn BaseChatModel>,
    /// Prompt for the first document; must contain the `{text}` placeholder.
    initial_prompt: String,
    /// Prompt for refinement steps; must contain the `{text}` and
    /// `{existing_summary}` placeholders.
    refine_prompt: String,
}
impl RefineSummarizationChain {
    /// Creates a chain with the default initial and refine prompts.
    pub fn new(llm: Arc<dyn BaseChatModel>) -> Self {
        Self {
            llm,
            initial_prompt: DEFAULT_INITIAL_SUMMARIZE_PROMPT.to_string(),
            refine_prompt: DEFAULT_REFINE_SUMMARIZE_PROMPT.to_string(),
        }
    }

    /// Replaces the prompt used for the first document; it must contain `{text}`.
    pub fn with_initial_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.initial_prompt = prompt.into();
        self
    }

    /// Replaces the refinement prompt; it must contain `{text}` and
    /// `{existing_summary}`.
    pub fn with_refine_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.refine_prompt = prompt.into();
        self
    }

    fn format_initial_prompt(&self, text: &str) -> String {
        self.initial_prompt.replace("{text}", text)
    }

    fn format_refine_prompt(&self, text: &str, existing_summary: &str) -> String {
        self.refine_prompt
            .replace("{text}", text)
            .replace("{existing_summary}", existing_summary)
    }

    /// Summarizes the first document, then sequentially refines that summary
    /// with every remaining document, returning the final summary.
    ///
    /// # Errors
    /// Returns `CognisError::Other` if `documents` is empty, or propagates any
    /// model invocation error.
    pub async fn call(&self, documents: &[Document]) -> Result<String> {
        if documents.is_empty() {
            return Err(CognisError::Other(
                "RefineSummarizationChain requires at least one document".into(),
            ));
        }
        let first_prompt = self.format_initial_prompt(&documents[0].page_content);
        let messages = vec![Message::Human(HumanMessage::new(&first_prompt))];
        let ai_msg = self.llm.invoke_messages(&messages, None).await?;
        let mut current_summary = ai_msg.base.content.text();
        for doc in &documents[1..] {
            // Fix: the original line contained a mojibake-corrupted reference
            // (`¤t_summary`) where `&current_summary` was intended.
            let refine_prompt = self.format_refine_prompt(&doc.page_content, &current_summary);
            let messages = vec![Message::Human(HumanMessage::new(&refine_prompt))];
            let ai_msg = self.llm.invoke_messages(&messages, None).await?;
            current_summary = ai_msg.base.content.text();
        }
        Ok(current_summary)
    }
}
#[async_trait]
impl Runnable for RefineSummarizationChain {
    fn name(&self) -> &str {
        "RefineSummarizationChain"
    }

    /// Accepts `{"documents": [...]}` and returns `{"summary": "..."}`.
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        let documents = parse_documents_from_input(&input)?;
        Ok(json!({ "summary": self.call(&documents).await? }))
    }
}
/// Extracts a document list from a runnable input of the shape
/// `{"documents": [...]}`, where each element is either a plain string
/// (wrapped via `Document::new`) or a serialized `Document` object.
///
/// # Errors
/// Returns `CognisError::TypeMismatch` when the input lacks a `documents`
/// array or contains an element that is neither a string nor an object, and
/// propagates serde errors for malformed `Document` objects.
fn parse_documents_from_input(input: &Value) -> Result<Vec<Document>> {
    let arr = input
        .get("documents")
        .and_then(|v| v.as_array())
        .ok_or_else(|| CognisError::TypeMismatch {
            expected: "JSON object with 'documents' array".into(),
            // `Value::to_string()` renders the JSON directly; no format! needed.
            got: input.to_string(),
        })?;
    let mut docs = Vec::with_capacity(arr.len());
    for item in arr {
        match item {
            Value::String(s) => {
                docs.push(Document::new(s.as_str()));
            }
            Value::Object(_) => {
                let doc: Document = serde_json::from_value(item.clone())?;
                docs.push(doc);
            }
            _ => {
                return Err(CognisError::TypeMismatch {
                    expected: "string or Document object".into(),
                    got: item.to_string(),
                });
            }
        }
    }
    Ok(docs)
}
/// Greedily groups `summaries` into chunks whose joined byte length (using a
/// "\n\n" separator between elements) does not exceed `max_length`.
///
/// A single summary longer than `max_length` still gets a chunk of its own, so
/// the function always makes progress. Returns an empty vec for empty input.
fn split_into_chunks(summaries: &[String], max_length: usize) -> Vec<Vec<String>> {
    // Byte length of the "\n\n" separator applied when a chunk is joined.
    const SEP_LEN: usize = 2;
    let mut chunks: Vec<Vec<String>> = Vec::new();
    let mut current_chunk: Vec<String> = Vec::new();
    let mut current_length: usize = 0;
    for summary in summaries {
        // Flush the current chunk if appending this summary would overflow it.
        if !current_chunk.is_empty() && current_length + SEP_LEN + summary.len() > max_length {
            chunks.push(std::mem::take(&mut current_chunk));
            current_length = 0;
        }
        if !current_chunk.is_empty() {
            current_length += SEP_LEN;
        }
        current_length += summary.len();
        current_chunk.push(summary.clone());
    }
    if !current_chunk.is_empty() {
        chunks.push(current_chunk);
    }
    chunks
}
#[cfg(test)]
mod tests {
    //! Unit tests for the stuff, map-reduce, and refine summarization chains,
    //! driven by fake chat models (scripted or echoing) from `cognis_core`.
    use super::*;
    use cognis_core::language_models::fake::{FakeListChatModel, ParrotFakeChatModel};

    /// Builds a scripted fake model that replies with `responses` in order.
    fn fake_model(responses: Vec<&str>) -> Arc<dyn BaseChatModel> {
        Arc::new(FakeListChatModel::new(
            responses.into_iter().map(String::from).collect(),
        ))
    }

    fn make_doc(content: &str) -> Document {
        Document::new(content)
    }

    #[tokio::test]
    async fn test_stuff_basic_summarization() {
        let llm = fake_model(vec!["This is a summary."]);
        let chain = StuffSummarizationChain::new(llm);
        let docs = vec![make_doc("Some content to summarize.")];
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "This is a summary.");
    }

    #[tokio::test]
    async fn test_stuff_multiple_documents_joined() {
        let llm: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let chain = StuffSummarizationChain::new(llm);
        let docs = vec![make_doc("Doc A"), make_doc("Doc B"), make_doc("Doc C")];
        let result = chain.call(&docs).await.unwrap();
        assert!(result.contains("Doc A"));
        assert!(result.contains("Doc B"));
        assert!(result.contains("Doc C"));
    }

    #[tokio::test]
    async fn test_stuff_custom_prompt() {
        let llm: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let chain = StuffSummarizationChain::new(llm).with_prompt("CUSTOM SUMMARIZE: {text} END");
        let docs = vec![make_doc("hello world")];
        let result = chain.call(&docs).await.unwrap();
        assert!(result.contains("CUSTOM SUMMARIZE:"));
        assert!(result.contains("hello world"));
        assert!(result.contains("END"));
    }

    #[tokio::test]
    async fn test_stuff_empty_documents() {
        let llm = fake_model(vec!["empty summary"]);
        let chain = StuffSummarizationChain::new(llm);
        let docs: Vec<Document> = vec![];
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "empty summary");
    }

    #[tokio::test]
    async fn test_stuff_implements_runnable() {
        let llm = fake_model(vec!["runnable summary"]);
        let chain = StuffSummarizationChain::new(llm);
        let runnable: &dyn Runnable = &chain;
        assert_eq!(runnable.name(), "StuffSummarizationChain");
        let input = json!({ "documents": ["doc content"] });
        let result = runnable.invoke(input, None).await.unwrap();
        assert_eq!(result["summary"], "runnable summary");
    }

    #[tokio::test]
    async fn test_map_reduce_with_multiple_documents() {
        let llm = fake_model(vec![
            "summary of doc 1",
            "summary of doc 2",
            "summary of doc 3",
            "final combined summary",
        ]);
        let chain = MapReduceSummarizationChain::new(llm);
        let docs = vec![
            make_doc("First document content"),
            make_doc("Second document content"),
            make_doc("Third document content"),
        ];
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "final combined summary");
    }

    #[tokio::test]
    async fn test_map_reduce_produces_map_then_reduce_outputs() {
        // Removed a redundant `llm.clone()` — the handle is not used again.
        let llm = fake_model(vec!["mapped-A", "mapped-B", "reduced-final"]);
        let chain = MapReduceSummarizationChain::new(llm);
        let docs = vec![make_doc("Doc A"), make_doc("Doc B")];
        let map_results = chain.map(&docs).await.unwrap();
        assert_eq!(map_results.len(), 2);
        assert_eq!(map_results[0], "mapped-A");
        assert_eq!(map_results[1], "mapped-B");
        let llm2 = fake_model(vec!["mapped-X", "mapped-Y", "final-reduced"]);
        let chain2 = MapReduceSummarizationChain::new(llm2);
        let result = chain2.call(&docs).await.unwrap();
        assert_eq!(result, "final-reduced");
    }

    #[tokio::test]
    async fn test_map_reduce_single_document() {
        let llm = fake_model(vec!["single-map", "single-reduce"]);
        let chain = MapReduceSummarizationChain::new(llm);
        let docs = vec![make_doc("Only document")];
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "single-reduce");
    }

    #[tokio::test]
    async fn test_map_reduce_empty_documents() {
        let llm = fake_model(vec!["reduce-of-nothing"]);
        let chain = MapReduceSummarizationChain::new(llm);
        let docs: Vec<Document> = vec![];
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "reduce-of-nothing");
    }

    #[tokio::test]
    async fn test_map_reduce_implements_runnable() {
        let llm = fake_model(vec!["m1", "m2", "reduced"]);
        let chain = MapReduceSummarizationChain::new(llm);
        let runnable: &dyn Runnable = &chain;
        assert_eq!(runnable.name(), "MapReduceSummarizationChain");
        let input = json!({ "documents": ["doc1", "doc2"] });
        let result = runnable.invoke(input, None).await.unwrap();
        assert_eq!(result["summary"], "reduced");
    }

    #[tokio::test]
    async fn test_map_reduce_large_document_set() {
        // Removed an unused `Vec<&str>` of borrowed responses that the
        // original built and never read (dead code / compiler warning).
        // Ten map responses followed by the final reduce response.
        let mut all_responses: Vec<String> =
            (0..10).map(|i| format!("summary-{}", i)).collect();
        all_responses.push("grand-summary".to_string());
        let llm = Arc::new(FakeListChatModel::new(all_responses));
        let chain = MapReduceSummarizationChain::new(llm);
        let docs: Vec<Document> = (0..10)
            .map(|i| make_doc(&format!("Document number {} with some content.", i)))
            .collect();
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "grand-summary");
    }

    #[tokio::test]
    async fn test_map_reduce_custom_prompts() {
        let llm: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let chain = MapReduceSummarizationChain::new(llm)
            .with_map_prompt("MAP: {text} ENDMAP")
            .with_reduce_prompt("REDUCE: {summaries} ENDREDUCE");
        let docs = vec![make_doc("test content")];
        let result = chain.call(&docs).await.unwrap();
        assert!(result.contains("REDUCE:"));
        assert!(result.contains("ENDREDUCE"));
    }

    #[tokio::test]
    async fn test_refine_iterates_through_documents() {
        let llm = fake_model(vec!["initial-summary", "refined-once", "refined-twice"]);
        let chain = RefineSummarizationChain::new(llm);
        let docs = vec![make_doc("Doc 1"), make_doc("Doc 2"), make_doc("Doc 3")];
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "refined-twice");
    }

    #[tokio::test]
    async fn test_refine_single_document_no_refinement() {
        let llm = fake_model(vec!["only-summary"]);
        let chain = RefineSummarizationChain::new(llm);
        let docs = vec![make_doc("Single document")];
        let result = chain.call(&docs).await.unwrap();
        assert_eq!(result, "only-summary");
    }

    #[tokio::test]
    async fn test_refine_empty_documents_returns_error() {
        let llm = fake_model(vec!["unused"]);
        let chain = RefineSummarizationChain::new(llm);
        let result = chain.call(&[]).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("at least one document"));
    }

    #[tokio::test]
    async fn test_refine_implements_runnable() {
        let llm = fake_model(vec!["summary-result"]);
        let chain = RefineSummarizationChain::new(llm);
        let runnable: &dyn Runnable = &chain;
        assert_eq!(runnable.name(), "RefineSummarizationChain");
        let input = json!({ "documents": ["content"] });
        let result = runnable.invoke(input, None).await.unwrap();
        assert_eq!(result["summary"], "summary-result");
    }

    #[tokio::test]
    async fn test_refine_custom_prompts() {
        let llm: Arc<dyn BaseChatModel> = Arc::new(ParrotFakeChatModel::new());
        let chain = RefineSummarizationChain::new(llm)
            .with_initial_prompt("INIT: {text}")
            .with_refine_prompt("REFINE: {existing_summary} + {text}");
        let docs = vec![make_doc("alpha"), make_doc("beta")];
        let result = chain.call(&docs).await.unwrap();
        assert!(result.contains("REFINE:"));
        assert!(result.contains("beta"));
    }

    #[tokio::test]
    async fn test_runnable_with_document_objects_input() {
        let llm = fake_model(vec!["parsed-summary"]);
        let chain = StuffSummarizationChain::new(llm);
        let input = json!({
            "documents": [
                { "page_content": "Document from JSON object" }
            ]
        });
        let result = chain.invoke(input, None).await.unwrap();
        assert_eq!(result["summary"], "parsed-summary");
    }

    #[tokio::test]
    async fn test_runnable_invalid_input() {
        let llm = fake_model(vec!["unused"]);
        let chain = StuffSummarizationChain::new(llm);
        let result = chain.invoke(json!({"text": "hello"}), None).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_split_into_chunks() {
        let summaries: Vec<String> = vec![
            "short".to_string(),
            "also short".to_string(),
            "another one".to_string(),
        ];
        let chunks = split_into_chunks(&summaries, 20);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].len(), 2);
        assert_eq!(chunks[1].len(), 1);
    }
}