use std::collections::HashMap;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use serde_json::Value;
/// Strategy for merging a slice of documents into a single string.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
pub trait DocumentCombineStrategy: Send + Sync {
/// Combines `documents` into one string.
///
/// `separator` is supplied by the caller; of the implementations visible in
/// this file, `StuffStrategy` joins the documents with it and
/// `CollapseStrategy` ignores it.
fn combine(&self, documents: &[Document], separator: &str) -> Result<String>;
}
/// Combine strategy that concatenates every document into one string,
/// optionally rendering each document through a template first.
pub struct StuffStrategy {
/// Per-document template; when `None`, each document's `page_content` is
/// used unchanged.
document_prompt: Option<String>,
}
impl StuffStrategy {
    /// Creates a strategy without a per-document template, so documents are
    /// combined using their raw `page_content`.
    pub fn new() -> Self {
        let document_prompt = None;
        Self { document_prompt }
    }

    /// Creates a strategy that renders every document through `prompt`
    /// before the documents are joined.
    pub fn with_prompt(prompt: impl Into<String>) -> Self {
        let template: String = prompt.into();
        Self {
            document_prompt: Some(template),
        }
    }
}
impl Default for StuffStrategy {
    /// Returns a strategy with no per-document template configured.
    fn default() -> Self {
        Self {
            document_prompt: None,
        }
    }
}
impl DocumentCombineStrategy for StuffStrategy {
    /// Renders each document (through the configured template when there is
    /// one, otherwise as its plain `page_content`) and joins the pieces with
    /// `separator`. This implementation never returns an error.
    fn combine(&self, documents: &[Document], separator: &str) -> Result<String> {
        let mut rendered: Vec<String> = Vec::with_capacity(documents.len());
        for document in documents {
            let piece = if let Some(template) = self.document_prompt.as_deref() {
                DocumentFormatter::format_document(document, template)
            } else {
                document.page_content.clone()
            };
            rendered.push(piece);
        }
        Ok(rendered.join(separator))
    }
}
/// Boxed, thread-safe callback that turns a group of documents into a single
/// summary string, or fails with an error.
pub type SummaryFn = Box<dyn Fn(&[Document]) -> Result<String> + Send + Sync>;
/// Combine strategy that repeatedly summarizes groups of documents until a
/// single result remains.
pub struct CollapseStrategy {
/// Upper bound on how many documents are handed to `summary_fn` at once.
max_docs_per_group: usize,
/// Callback used to summarize each group of documents.
summary_fn: SummaryFn,
}
impl CollapseStrategy {
    /// Creates a collapse strategy that summarizes at most
    /// `max_docs_per_group` documents per `summary_fn` call.
    ///
    /// Values below 2 are raised to 2.
    pub fn new(max_docs_per_group: usize, summary_fn: SummaryFn) -> Self {
        let max_docs_per_group = std::cmp::max(2, max_docs_per_group);
        Self {
            max_docs_per_group,
            summary_fn,
        }
    }
}
impl CollapseStrategy {
    /// Summarizes `docs` in consecutive groups of at most
    /// `max_docs_per_group`, producing one new document per group.
    fn collapse_once(&self, docs: &[Document]) -> Result<Vec<Document>> {
        let mut collapsed = Vec::new();
        for group in docs.chunks(self.max_docs_per_group) {
            let summary = (self.summary_fn)(group)?;
            collapsed.push(Document::new(summary));
        }
        Ok(collapsed)
    }
}

impl DocumentCombineStrategy for CollapseStrategy {
    /// Reduces `documents` to one string by repeated summarization.
    ///
    /// * Empty input yields an empty string, and a single document yields its
    ///   `page_content` unchanged; `summary_fn` is not called in either case.
    /// * Otherwise documents are summarized in groups of at most
    ///   `max_docs_per_group`, each summary becoming a new document, until a
    ///   single final `summary_fn` call can cover everything that is left.
    ///
    /// The separator argument is ignored: the summary function decides how
    /// documents are merged.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `summary_fn`.
    fn combine(&self, documents: &[Document], _separator: &str) -> Result<String> {
        match documents {
            [] => return Ok(String::new()),
            [only] => return Ok(only.page_content.clone()),
            _ => {}
        }

        // Small inputs need exactly one summary call; no copy of the input is made.
        if documents.len() <= self.max_docs_per_group {
            return (self.summary_fn)(documents);
        }

        let mut current_docs = self.collapse_once(documents)?;
        while current_docs.len() > self.max_docs_per_group {
            current_docs = self.collapse_once(&current_docs)?;
        }

        // A round only runs on more than `max_docs_per_group` (>= 2) documents,
        // so it always leaves at least two summaries to merge here.
        (self.summary_fn)(&current_docs)
    }
}
/// Namespace for template-based document rendering helpers.
pub struct DocumentFormatter;

impl DocumentFormatter {
    /// Renders `doc` through `template`.
    ///
    /// Supported placeholders:
    ///
    /// * `{page_content}` is replaced by the document's `page_content`.
    /// * `{metadata.KEY}` is replaced by the metadata value stored under `KEY`
    ///   (everything up to the next `}`). String values are inserted without
    ///   quotes, other JSON values use their JSON text, and a missing key
    ///   renders as an empty string.
    ///
    /// The template is scanned exactly once and substituted text is never
    /// re-scanned, so placeholder-looking text inside the document's content
    /// or metadata is emitted verbatim. Any other `{` and an unterminated
    /// `{metadata.` are copied through unchanged.
    pub fn format_document(doc: &Document, template: &str) -> String {
        const PAGE_CONTENT: &str = "{page_content}";
        const METADATA_PREFIX: &str = "{metadata.";

        let mut output = String::with_capacity(template.len() + doc.page_content.len());
        let mut remaining = template;
        while let Some(start) = remaining.find('{') {
            output.push_str(&remaining[..start]);
            let candidate = &remaining[start..];
            if let Some(rest) = candidate.strip_prefix(PAGE_CONTENT) {
                output.push_str(&doc.page_content);
                remaining = rest;
            } else if let Some(after_prefix) = candidate.strip_prefix(METADATA_PREFIX) {
                match after_prefix.split_once('}') {
                    Some((key, rest)) => {
                        match doc.metadata.get(key) {
                            Some(Value::String(text)) => output.push_str(text),
                            Some(other) => output.push_str(&other.to_string()),
                            None => {}
                        }
                        remaining = rest;
                    }
                    None => {
                        // Unterminated placeholder: keep the rest of the template as-is.
                        output.push_str(candidate);
                        remaining = "";
                    }
                }
            } else {
                // A brace that starts no known placeholder is literal text.
                // `{` is a single byte, so slicing at 1 stays on a char boundary.
                output.push('{');
                remaining = &candidate[1..];
            }
        }
        output.push_str(remaining);
        output
    }

    /// Renders every document in `docs` through `template` (see
    /// [`DocumentFormatter::format_document`]) and joins the results with
    /// `separator`.
    pub fn format_documents(docs: &[Document], template: &str, separator: &str) -> String {
        let mut output = String::new();
        for (index, doc) in docs.iter().enumerate() {
            if index > 0 {
                output.push_str(separator);
            }
            output.push_str(&Self::format_document(doc, template));
        }
        output
    }
}
/// Chain that combines a list of documents into one string by delegating to a
/// [`DocumentCombineStrategy`].
pub struct StuffDocumentsChain {
/// Strategy that performs the actual combination.
combine_strategy: Box<dyn DocumentCombineStrategy>,
/// Name of the input slot; exposed through `input_key()` and not read by
/// `invoke` in this file.
input_key: String,
/// Key under which `invoke_as_map` stores the combined text.
output_key: String,
/// Separator handed to the strategy on every `invoke`.
document_separator: String,
/// Per-document template recorded on the chain; exposed through
/// `document_prompt()`.
document_prompt: Option<String>,
}
impl StuffDocumentsChain {
    /// Starts a builder pre-populated with the default keys and separator.
    pub fn builder() -> StuffDocumentsChainBuilder {
        Default::default()
    }

    /// Wraps `strategy` in a chain using the default input key, output key
    /// and separator, with no document prompt recorded.
    pub fn new(strategy: Box<dyn DocumentCombineStrategy>) -> Self {
        let input_key = String::from("input_documents");
        let output_key = String::from("output_text");
        let document_separator = String::from("\n\n");
        Self {
            combine_strategy: strategy,
            input_key,
            output_key,
            document_separator,
            document_prompt: None,
        }
    }

    /// Combines `documents` with the configured strategy and separator.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the strategy.
    pub fn invoke(&self, documents: &[Document]) -> Result<String> {
        let separator = self.document_separator.as_str();
        self.combine_strategy.combine(documents, separator)
    }

    /// Like [`StuffDocumentsChain::invoke`], but returns a map holding the
    /// combined text under the configured output key.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the strategy.
    pub fn invoke_as_map(&self, documents: &[Document]) -> Result<HashMap<String, String>> {
        self.invoke(documents)
            .map(|text| std::iter::once((self.output_key.clone(), text)).collect())
    }

    /// Returns the configured input key.
    pub fn input_key(&self) -> &str {
        self.input_key.as_str()
    }

    /// Returns the key used by `invoke_as_map` for the combined text.
    pub fn output_key(&self) -> &str {
        self.output_key.as_str()
    }

    /// Returns the separator passed to the combine strategy.
    pub fn document_separator(&self) -> &str {
        self.document_separator.as_str()
    }

    /// Returns the document prompt recorded on this chain, if any.
    pub fn document_prompt(&self) -> Option<&str> {
        self.document_prompt.as_deref()
    }
}
/// Builder for [`StuffDocumentsChain`].
pub struct StuffDocumentsChainBuilder {
/// Explicit strategy; when `None`, `build` creates a `StuffStrategy`.
combine_strategy: Option<Box<dyn DocumentCombineStrategy>>,
/// Input key copied onto the built chain.
input_key: String,
/// Output key copied onto the built chain.
output_key: String,
/// Separator copied onto the built chain.
document_separator: String,
/// Optional per-document template; used by `build` for the fallback
/// `StuffStrategy` and copied onto the built chain.
document_prompt: Option<String>,
}
impl Default for StuffDocumentsChainBuilder {
    /// Returns a builder with no explicit strategy, no document prompt, the
    /// keys `input_documents` / `output_text`, and a blank-line separator.
    fn default() -> Self {
        let input_key = String::from("input_documents");
        let output_key = String::from("output_text");
        let document_separator = String::from("\n\n");
        Self {
            combine_strategy: None,
            input_key,
            output_key,
            document_separator,
            document_prompt: None,
        }
    }
}
impl StuffDocumentsChainBuilder {
    /// Uses `strategy` instead of the fallback `StuffStrategy`.
    pub fn strategy(self, strategy: Box<dyn DocumentCombineStrategy>) -> Self {
        Self {
            combine_strategy: Some(strategy),
            ..self
        }
    }

    /// Sets the input key stored on the chain.
    pub fn input_key(self, key: impl Into<String>) -> Self {
        Self {
            input_key: key.into(),
            ..self
        }
    }

    /// Sets the key under which `invoke_as_map` stores the combined text.
    pub fn output_key(self, key: impl Into<String>) -> Self {
        Self {
            output_key: key.into(),
            ..self
        }
    }

    /// Sets the separator passed to the combine strategy.
    pub fn document_separator(self, sep: impl Into<String>) -> Self {
        Self {
            document_separator: sep.into(),
            ..self
        }
    }

    /// Sets the per-document template.
    ///
    /// The template configures the fallback `StuffStrategy` created by
    /// `build`; when an explicit strategy was supplied it is only recorded on
    /// the chain and that strategy is used unchanged.
    pub fn document_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            document_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Finishes the builder.
    ///
    /// Without an explicit strategy, a `StuffStrategy` is created, using the
    /// document prompt as its template when one was set.
    pub fn build(self) -> StuffDocumentsChain {
        let Self {
            combine_strategy,
            input_key,
            output_key,
            document_separator,
            document_prompt,
        } = self;

        let combine_strategy: Box<dyn DocumentCombineStrategy> = match combine_strategy {
            Some(custom) => custom,
            None => {
                let fallback = match document_prompt.as_ref() {
                    Some(prompt) => StuffStrategy::with_prompt(prompt.clone()),
                    None => StuffStrategy::new(),
                };
                Box::new(fallback)
            }
        };

        StuffDocumentsChain {
            combine_strategy,
            input_key,
            output_key,
            document_separator,
            document_prompt,
        }
    }
}
/// Builds a [`StuffDocumentsChain`] that renders every document through
/// `prompt_template` and joins the results with the default separator.
///
/// The template is also recorded on the chain, so
/// [`StuffDocumentsChain::document_prompt`] reports the template that is
/// actually in use.
pub fn create_stuff_documents_chain(prompt_template: &str) -> StuffDocumentsChain {
    StuffDocumentsChain::builder()
        .strategy(Box::new(StuffStrategy::with_prompt(prompt_template)))
        // Keep the chain's accessor in sync with the strategy's template.
        .document_prompt(prompt_template)
        .build()
}
// Unit tests for the combine strategies, the document formatter, the chain
// builder and the factory function.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Builds a document that has only page content.
fn make_doc(content: &str) -> Document {
Document::new(content)
}
// Builds a document with page content and the given metadata map.
fn make_doc_with_metadata(content: &str, metadata: HashMap<String, Value>) -> Document {
Document::new(content).with_metadata(metadata)
}
// --- StuffStrategy ---
// Without a template, contents are joined with the separator.
#[test]
fn test_stuff_strategy_simple_concat() {
let strategy = StuffStrategy::new();
let docs = vec![make_doc("Hello"), make_doc("World")];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "Hello\n\nWorld");
}
// No documents produce an empty string.
#[test]
fn test_stuff_strategy_empty_docs() {
let strategy = StuffStrategy::new();
let docs: Vec<Document> = vec![];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "");
}
// A single document is returned without any separator.
#[test]
fn test_stuff_strategy_single_doc() {
let strategy = StuffStrategy::new();
let docs = vec![make_doc("Only one")];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "Only one");
}
// The caller-supplied separator is placed between documents.
#[test]
fn test_stuff_strategy_custom_separator() {
let strategy = StuffStrategy::new();
let docs = vec![make_doc("A"), make_doc("B"), make_doc("C")];
let result = strategy.combine(&docs, " | ").unwrap();
assert_eq!(result, "A | B | C");
}
// The template is applied to every document before joining.
#[test]
fn test_stuff_strategy_with_prompt() {
let strategy = StuffStrategy::with_prompt("Content: {page_content}");
let docs = vec![make_doc("Hello"), make_doc("World")];
let result = strategy.combine(&docs, "\n").unwrap();
assert_eq!(result, "Content: Hello\nContent: World");
}
// `{metadata.KEY}` placeholders are filled from the document's metadata.
#[test]
fn test_stuff_strategy_with_metadata_prompt() {
let strategy =
StuffStrategy::with_prompt("Content: {page_content}\nSource: {metadata.source}");
let mut meta = HashMap::new();
meta.insert("source".to_string(), json!("file.txt"));
let docs = vec![make_doc_with_metadata("Hello", meta)];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "Content: Hello\nSource: file.txt");
}
// --- CollapseStrategy ---
// Empty input yields an empty string rather than the summary callback's output.
#[test]
fn test_collapse_strategy_empty() {
let strategy =
CollapseStrategy::new(2, Box::new(|_docs: &[Document]| Ok("summary".to_string())));
let docs: Vec<Document> = vec![];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "");
}
// A single document is returned as-is, not summarized.
#[test]
fn test_collapse_strategy_single_doc() {
let strategy =
CollapseStrategy::new(2, Box::new(|_docs: &[Document]| Ok("summary".to_string())));
let docs = vec![make_doc("Only one")];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "Only one");
}
// Input that fits into one group is summarized with a single callback call.
#[test]
fn test_collapse_strategy_within_group_size() {
let strategy = CollapseStrategy::new(
3,
Box::new(|docs: &[Document]| {
let contents: Vec<&str> = docs.iter().map(|d| d.page_content.as_str()).collect();
Ok(format!("SUMMARY({})", contents.join(", ")))
}),
);
let docs = vec![make_doc("A"), make_doc("B"), make_doc("C")];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "SUMMARY(A, B, C)");
}
// Four documents with group size 2: two pair summaries, then a summary of those.
#[test]
fn test_collapse_strategy_hierarchical() {
let strategy = CollapseStrategy::new(
2,
Box::new(|docs: &[Document]| {
let contents: Vec<&str> = docs.iter().map(|d| d.page_content.as_str()).collect();
Ok(format!("[{}]", contents.join("+")))
}),
);
let docs = vec![make_doc("A"), make_doc("B"), make_doc("C"), make_doc("D")];
let result = strategy.combine(&docs, "\n\n").unwrap();
assert_eq!(result, "[[A+B]+[C+D]]");
}
// A requested group size of 1 is raised to 2 by the constructor.
#[test]
fn test_collapse_strategy_min_group_size() {
let strategy = CollapseStrategy::new(
1,
Box::new(|docs: &[Document]| {
let contents: Vec<&str> = docs.iter().map(|d| d.page_content.as_str()).collect();
Ok(contents.join("+"))
}),
);
assert_eq!(strategy.max_docs_per_group, 2);
}
// --- DocumentFormatter ---
// `{page_content}` is replaced by the document's content.
#[test]
fn test_format_document_page_content() {
let doc = make_doc("Hello world");
let result = DocumentFormatter::format_document(&doc, "Text: {page_content}");
assert_eq!(result, "Text: Hello world");
}
// String metadata is inserted without quotes; numbers use their JSON text.
#[test]
fn test_format_document_metadata_key() {
let mut meta = HashMap::new();
meta.insert("source".to_string(), json!("wiki.txt"));
meta.insert("page".to_string(), json!(42));
let doc = make_doc_with_metadata("Content here", meta);
let result = DocumentFormatter::format_document(
&doc,
"{page_content} from {metadata.source} page {metadata.page}",
);
assert_eq!(result, "Content here from wiki.txt page 42");
}
// A metadata key that is not present renders as an empty string.
#[test]
fn test_format_document_missing_metadata_key() {
let doc = make_doc("Content");
let result =
DocumentFormatter::format_document(&doc, "{page_content} source={metadata.source}");
assert_eq!(result, "Content source=");
}
// `format_documents` renders each document and joins with the separator.
#[test]
fn test_format_documents_multiple() {
let docs = vec![make_doc("A"), make_doc("B")];
let result = DocumentFormatter::format_documents(&docs, "Doc: {page_content}", " | ");
assert_eq!(result, "Doc: A | Doc: B");
}
// --- StuffDocumentsChain and builder ---
// A default-built chain exposes the default keys and separator.
#[test]
fn test_chain_default_builder() {
let chain = StuffDocumentsChain::builder().build();
assert_eq!(chain.input_key(), "input_documents");
assert_eq!(chain.output_key(), "output_text");
assert_eq!(chain.document_separator(), "\n\n");
}
// Keys set on the builder are visible on the built chain.
#[test]
fn test_chain_custom_keys() {
let chain = StuffDocumentsChain::builder()
.input_key("docs")
.output_key("result")
.build();
assert_eq!(chain.input_key(), "docs");
assert_eq!(chain.output_key(), "result");
}
// The default chain joins contents with a blank line.
#[test]
fn test_chain_invoke() {
let chain = StuffDocumentsChain::builder().build();
let docs = vec![make_doc("First"), make_doc("Second")];
let result = chain.invoke(&docs).unwrap();
assert_eq!(result, "First\n\nSecond");
}
// `invoke_as_map` stores the combined text under the output key.
#[test]
fn test_chain_invoke_as_map() {
let chain = StuffDocumentsChain::builder().build();
let docs = vec![make_doc("Alpha"), make_doc("Beta")];
let map = chain.invoke_as_map(&docs).unwrap();
assert_eq!(map.get("output_text").unwrap(), "Alpha\n\nBeta");
}
// A separator set on the builder is used when combining.
#[test]
fn test_chain_custom_separator() {
let chain = StuffDocumentsChain::builder()
.document_separator("---")
.build();
let docs = vec![make_doc("X"), make_doc("Y")];
let result = chain.invoke(&docs).unwrap();
assert_eq!(result, "X---Y");
}
// A document prompt set on the builder formats each document of the fallback strategy.
#[test]
fn test_chain_with_document_prompt() {
let chain = StuffDocumentsChain::builder()
.document_prompt(">> {page_content}")
.build();
let docs = vec![make_doc("Hello"), make_doc("World")];
let result = chain.invoke(&docs).unwrap();
assert_eq!(result, ">> Hello\n\n>> World");
}
// `StuffDocumentsChain::new` wraps an explicit strategy.
#[test]
fn test_chain_new_constructor() {
let chain = StuffDocumentsChain::new(Box::new(StuffStrategy::new()));
let docs = vec![make_doc("Test")];
let result = chain.invoke(&docs).unwrap();
assert_eq!(result, "Test");
}
// --- create_stuff_documents_chain ---
// The factory applies the template to each document and uses the default separator.
#[test]
fn test_factory_function() {
let chain = create_stuff_documents_chain("Content: {page_content}");
let docs = vec![make_doc("Hello"), make_doc("World")];
let result = chain.invoke(&docs).unwrap();
assert_eq!(result, "Content: Hello\n\nContent: World");
}
// Metadata placeholders are resolved per document.
#[test]
fn test_factory_with_metadata() {
let chain = create_stuff_documents_chain("{page_content} (source: {metadata.source})");
let mut meta = HashMap::new();
meta.insert("source".to_string(), json!("test.pdf"));
let docs = vec![
make_doc_with_metadata("Page 1 content", meta.clone()),
make_doc_with_metadata("Page 2 content", {
let mut m = HashMap::new();
m.insert("source".to_string(), json!("test2.pdf"));
m
}),
];
let result = chain.invoke(&docs).unwrap();
assert_eq!(
result,
"Page 1 content (source: test.pdf)\n\nPage 2 content (source: test2.pdf)"
);
}
}