use std::collections::{HashMap, HashSet};
use serde_json::Value;
use cognis_core::documents::Document;
use cognis_core::error::Result;
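/// Lowercases each whitespace-separated token, trims non-alphanumeric
/// characters from its edges, and drops tokens that end up empty. Shared by
/// the keyword-based compressors below.
fn normalize_words(text: &str) -> Vec<String> {
    text.split_whitespace()
        .map(|w| {
            w.to_lowercase()
                .trim_matches(|c: char| !c.is_alphanumeric())
                .to_string()
        })
        .filter(|w| !w.is_empty())
        .collect()
}
/// A post-retrieval compression step: given the retrieved documents and the
/// originating query, returns a smaller or shortened set of documents.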
pub trait DocumentCompressor: Send + Sync {
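    /// Returns the compressed view of `documents` for `query`; implementations
    /// may drop documents entirely or rewrite their `page_content`.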
fn compress(&self, documents: &[Document], query: &str) -> Result<Vec<Document>>;
}
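/// Truncates each document's `page_content` to at most `max_length` characters.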
pub struct LengthCompressor {
max_length: usize,
}
impl LengthCompressor {
pub fn new(max_length: usize) -> Self {
Self { max_length }
}
}
impl DocumentCompressor for LengthCompressor {
fn compress(&self, documents: &[Document], _query: &str) -> Result<Vec<Document>> {
let mut result = Vec::with_capacity(documents.len());
for doc in documents {
let mut compressed = doc.clone();
            // Compare character counts so the check matches the char-based
            // truncation below; `len()` measures bytes, not characters.
            if compressed.page_content.chars().count() > self.max_length {
let truncated: String = compressed
.page_content
.chars()
.take(self.max_length)
.collect();
compressed.page_content = truncated;
}
result.push(compressed);
}
Ok(result)
}
}
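/// Keeps only the sentences that mention at least one query keyword, falling
/// back to the first `min_sentences` sentences when nothing matches.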
pub struct SentenceExtractor {
min_sentences: usize,
}
impl SentenceExtractor {
pub fn new() -> Self {
Self { min_sentences: 1 }
}
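    /// Sets the minimum number of sentences to keep; when fewer sentences
    /// match the query, the first `n` sentences are returned instead.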
pub fn with_min_sentences(mut self, n: usize) -> Self {
self.min_sentences = n;
self
}
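    /// Naive splitter that breaks on '.', '!' and '?'. Abbreviations and
    /// decimal points will produce spurious splits.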
fn split_sentences(text: &str) -> Vec<String> {
let mut sentences = Vec::new();
let mut current = String::new();
for ch in text.chars() {
current.push(ch);
if ch == '.' || ch == '!' || ch == '?' {
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
sentences.push(trimmed);
}
current.clear();
}
}
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
sentences.push(trimmed);
}
sentences
}
    fn extract_keywords(query: &str) -> HashSet<String> {
        normalize_words(query).into_iter().collect()
    }
}
impl Default for SentenceExtractor {
fn default() -> Self {
Self::new()
}
}
impl DocumentCompressor for SentenceExtractor {
fn compress(&self, documents: &[Document], query: &str) -> Result<Vec<Document>> {
let keywords = Self::extract_keywords(query);
let mut result = Vec::with_capacity(documents.len());
for doc in documents {
let sentences = Self::split_sentences(&doc.page_content);
if sentences.is_empty() {
result.push(doc.clone());
continue;
}
let matching: Vec<&String> = sentences
.iter()
.filter(|s| {
let lower = s.to_lowercase();
keywords.iter().any(|kw| lower.contains(kw.as_str()))
})
.collect();
let selected = if matching.len() >= self.min_sentences {
matching.into_iter().cloned().collect::<Vec<_>>()
} else {
sentences
.iter()
.take(self.min_sentences)
.cloned()
.collect::<Vec<_>>()
};
let mut compressed = doc.clone();
compressed.page_content = selected.join(" ");
result.push(compressed);
}
Ok(result)
}
}
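/// Drops a document when its word-set Jaccard similarity to any already
/// accepted document reaches `similarity_threshold`.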
pub struct RedundancyFilter {
similarity_threshold: f64,
}
impl RedundancyFilter {
pub fn new(similarity_threshold: f64) -> Self {
Self {
similarity_threshold,
}
}
    fn jaccard_similarity(a: &HashSet<String>, b: &HashSet<String>) -> f64 {
        // Two empty word sets are treated as identical.
        if a.is_empty() && b.is_empty() {
            return 1.0;
        }
        // At least one set is non-empty here, so the union is never zero.
        let intersection = a.intersection(b).count() as f64;
        let union = a.union(b).count() as f64;
        intersection / union
    }
    fn word_set(text: &str) -> HashSet<String> {
        normalize_words(text).into_iter().collect()
    }
}
impl DocumentCompressor for RedundancyFilter {
fn compress(&self, documents: &[Document], _query: &str) -> Result<Vec<Document>> {
let mut accepted: Vec<(Document, HashSet<String>)> = Vec::new();
for doc in documents {
let word_set = Self::word_set(&doc.page_content);
let is_duplicate = accepted.iter().any(|(_, existing_set)| {
Self::jaccard_similarity(&word_set, existing_set) >= self.similarity_threshold
});
if !is_duplicate {
accepted.push((doc.clone(), word_set));
}
}
Ok(accepted.into_iter().map(|(doc, _)| doc).collect())
}
}
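/// Keeps documents whose fraction of matched query keywords is at least
/// `min_score`.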
pub struct RelevanceScorer {
min_score: f64,
}
impl RelevanceScorer {
pub fn new(min_score: f64) -> Self {
Self { min_score }
}
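    /// Fraction of query keywords that occur (as substrings) in the document.
    /// An empty keyword list scores 1.0, so empty queries keep everything.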
fn score(doc_text: &str, query_keywords: &[String]) -> f64 {
if query_keywords.is_empty() {
return 1.0;
}
let lower = doc_text.to_lowercase();
let matches = query_keywords
.iter()
.filter(|kw| lower.contains(kw.as_str()))
.count();
matches as f64 / query_keywords.len() as f64
}
}
impl DocumentCompressor for RelevanceScorer {
fn compress(&self, documents: &[Document], query: &str) -> Result<Vec<Document>> {
        let keywords = normalize_words(query);
let result = documents
.iter()
.filter(|doc| Self::score(&doc.page_content, &keywords) >= self.min_score)
.cloned()
.collect();
Ok(result)
}
}
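/// A single predicate evaluated against a document's metadata map.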
#[derive(Debug, Clone)]
enum MetadataCondition {
RequireField(String),
RequireValue(String, Value),
ExcludeValue(String, Value),
}
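/// Filters documents on metadata; every condition must hold (AND semantics).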
pub struct MetadataFilter {
conditions: Vec<MetadataCondition>,
}
impl MetadataFilter {
pub fn new() -> Self {
Self {
conditions: Vec::new(),
}
}
pub fn require_field(mut self, field: &str) -> Self {
self.conditions
.push(MetadataCondition::RequireField(field.to_string()));
self
}
pub fn require_value(mut self, field: &str, value: Value) -> Self {
self.conditions
.push(MetadataCondition::RequireValue(field.to_string(), value));
self
}
pub fn exclude_value(mut self, field: &str, value: Value) -> Self {
self.conditions
.push(MetadataCondition::ExcludeValue(field.to_string(), value));
self
}
fn satisfies(&self, metadata: &HashMap<String, Value>) -> bool {
self.conditions.iter().all(|cond| match cond {
MetadataCondition::RequireField(field) => metadata.contains_key(field),
MetadataCondition::RequireValue(field, value) => metadata.get(field) == Some(value),
MetadataCondition::ExcludeValue(field, value) => metadata.get(field) != Some(value),
})
}
}
impl Default for MetadataFilter {
fn default() -> Self {
Self::new()
}
}
impl DocumentCompressor for MetadataFilter {
fn compress(&self, documents: &[Document], _query: &str) -> Result<Vec<Document>> {
Ok(documents
.iter()
.filter(|doc| self.satisfies(&doc.metadata))
.cloned()
.collect())
}
}
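/// Chains compressors, feeding each stage's output into the next and
/// short-circuiting as soon as the document set becomes empty.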
pub struct CompressorPipeline {
compressors: Vec<Box<dyn DocumentCompressor>>,
}
impl CompressorPipeline {
pub fn new() -> Self {
Self {
compressors: Vec::new(),
}
}
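    /// Appends a compressor stage; stages run in insertion order.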
#[allow(clippy::should_implement_trait)]
pub fn add(mut self, compressor: Box<dyn DocumentCompressor>) -> Self {
self.compressors.push(compressor);
self
}
pub fn len(&self) -> usize {
self.compressors.len()
}
pub fn is_empty(&self) -> bool {
self.compressors.is_empty()
}
}
impl Default for CompressorPipeline {
fn default() -> Self {
Self::new()
}
}
impl DocumentCompressor for CompressorPipeline {
fn compress(&self, documents: &[Document], query: &str) -> Result<Vec<Document>> {
let mut docs = documents.to_vec();
for compressor in &self.compressors {
docs = compressor.compress(&docs, query)?;
if docs.is_empty() {
return Ok(docs);
}
}
Ok(docs)
}
}
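/// Pairs a fixed document set with a compressor so queries return an already
/// compressed view of the corpus.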
pub struct ContextualCompressionRetriever {
documents: Vec<Document>,
compressor: Box<dyn DocumentCompressor>,
}
impl ContextualCompressionRetriever {
pub fn new(documents: Vec<Document>, compressor: Box<dyn DocumentCompressor>) -> Self {
Self {
documents,
compressor,
}
}
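    /// Compresses the stored documents against `query`, then caps the result at `k`.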
pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<Document>> {
let mut docs = self.compressor.compress(&self.documents, query)?;
docs.truncate(k);
Ok(docs)
}
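    /// Compresses the stored documents against `query` without capping the result.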
pub fn retrieve_all(&self, query: &str) -> Result<Vec<Document>> {
self.compressor.compress(&self.documents, query)
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
fn doc(content: &str) -> Document {
Document::new(content)
}
fn doc_with_meta(content: &str, meta: Vec<(&str, Value)>) -> Document {
let metadata: HashMap<String, Value> =
meta.into_iter().map(|(k, v)| (k.to_string(), v)).collect();
Document::new(content).with_metadata(metadata)
}
#[test]
fn test_length_compressor_truncates_long_document() {
let compressor = LengthCompressor::new(10);
let docs = vec![doc("This is a long document that should be truncated")];
let result = compressor.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].page_content.len(), 10);
assert_eq!(result[0].page_content, "This is a ");
}
#[test]
fn test_length_compressor_keeps_short_document() {
let compressor = LengthCompressor::new(100);
let docs = vec![doc("Short")];
let result = compressor.compress(&docs, "query").unwrap();
assert_eq!(result[0].page_content, "Short");
}
#[test]
fn test_length_compressor_exact_length() {
let compressor = LengthCompressor::new(5);
let docs = vec![doc("Hello")];
let result = compressor.compress(&docs, "query").unwrap();
assert_eq!(result[0].page_content, "Hello");
}
#[test]
fn test_length_compressor_zero_length() {
let compressor = LengthCompressor::new(0);
let docs = vec![doc("Hello")];
let result = compressor.compress(&docs, "query").unwrap();
assert_eq!(result[0].page_content, "");
}
#[test]
fn test_length_compressor_empty_docs() {
let compressor = LengthCompressor::new(10);
let result = compressor.compress(&[], "query").unwrap();
assert!(result.is_empty());
}
#[test]
fn test_length_compressor_preserves_metadata() {
let compressor = LengthCompressor::new(5);
let docs = vec![doc_with_meta(
"Hello World",
vec![("source", json!("test.pdf"))],
)];
let result = compressor.compress(&docs, "query").unwrap();
assert_eq!(result[0].metadata.get("source"), Some(&json!("test.pdf")));
}
#[test]
fn test_sentence_extractor_matches_keywords() {
let extractor = SentenceExtractor::new();
let docs = vec![doc(
"The cat sat on the mat. The dog ran in the park. Rust is great.",
)];
let result = extractor.compress(&docs, "dog park").unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].page_content.contains("dog"));
assert!(result[0].page_content.contains("park"));
}
#[test]
fn test_sentence_extractor_no_matches_returns_min_sentences() {
let extractor = SentenceExtractor::new().with_min_sentences(2);
let docs = vec![doc("First sentence. Second sentence. Third sentence.")];
let result = extractor.compress(&docs, "nonexistent").unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].page_content.contains("First sentence."));
assert!(result[0].page_content.contains("Second sentence."));
}
#[test]
fn test_sentence_extractor_case_insensitive() {
let extractor = SentenceExtractor::new();
let docs = vec![doc("Rust is awesome. Python is nice.")];
let result = extractor.compress(&docs, "RUST").unwrap();
assert!(result[0].page_content.contains("Rust is awesome."));
}
#[test]
fn test_sentence_extractor_empty_query() {
let extractor = SentenceExtractor::new().with_min_sentences(1);
let docs = vec![doc("First. Second. Third.")];
let result = extractor.compress(&docs, "").unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].page_content.contains("First."));
}
#[test]
fn test_sentence_extractor_empty_docs() {
let extractor = SentenceExtractor::new();
let result = extractor.compress(&[], "query").unwrap();
assert!(result.is_empty());
}
#[test]
fn test_redundancy_filter_removes_duplicates() {
let filter = RedundancyFilter::new(0.8);
let docs = vec![
doc("the quick brown fox jumps over the lazy dog"),
doc("the quick brown fox jumps over the lazy dog"), doc("completely different content about something else"),
];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 2);
assert!(result[0].page_content.contains("fox"));
assert!(result[1].page_content.contains("different"));
}
#[test]
fn test_redundancy_filter_near_duplicates() {
let filter = RedundancyFilter::new(0.7);
let docs = vec![
doc("the quick brown fox jumps over the lazy dog"),
doc("the quick brown fox leaps over the lazy dog"), ];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 1);
}
#[test]
    fn test_redundancy_filter_max_threshold_keeps_near_duplicates() {
let filter = RedundancyFilter::new(1.0);
let docs = vec![doc("the quick brown fox"), doc("the quick brown fox jumps")];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 2);
}
#[test]
fn test_redundancy_filter_empty_docs() {
let filter = RedundancyFilter::new(0.8);
let result = filter.compress(&[], "query").unwrap();
assert!(result.is_empty());
}
#[test]
fn test_redundancy_filter_single_doc() {
let filter = RedundancyFilter::new(0.8);
let docs = vec![doc("only document")];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 1);
}
#[test]
fn test_relevance_scorer_filters_irrelevant() {
let scorer = RelevanceScorer::new(0.5);
let docs = vec![
doc("rust programming language is fast and safe"),
doc("cooking recipes for pasta and pizza"),
doc("rust compiler and borrow checker"),
];
let result = scorer.compress(&docs, "rust programming").unwrap();
assert_eq!(result.len(), 2);
assert!(result[0].page_content.contains("rust"));
assert!(result[1].page_content.contains("rust"));
}
#[test]
fn test_relevance_scorer_all_relevant() {
let scorer = RelevanceScorer::new(0.0);
let docs = vec![doc("anything"), doc("goes")];
let result = scorer.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 2);
}
#[test]
fn test_relevance_scorer_none_relevant() {
let scorer = RelevanceScorer::new(1.0);
let docs = vec![doc("no matching keywords here")];
let result = scorer.compress(&docs, "rust programming").unwrap();
assert!(result.is_empty());
}
#[test]
fn test_relevance_scorer_empty_query() {
let scorer = RelevanceScorer::new(0.5);
let docs = vec![doc("some document")];
let result = scorer.compress(&docs, "").unwrap();
assert_eq!(result.len(), 1);
}
#[test]
fn test_relevance_scorer_empty_docs() {
let scorer = RelevanceScorer::new(0.5);
let result = scorer.compress(&[], "query").unwrap();
assert!(result.is_empty());
}
#[test]
fn test_metadata_filter_require_field() {
let filter = MetadataFilter::new().require_field("source");
let docs = vec![
doc_with_meta("has source", vec![("source", json!("file.pdf"))]),
doc("no source"),
];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].page_content, "has source");
}
#[test]
fn test_metadata_filter_require_value() {
let filter = MetadataFilter::new().require_value("type", json!("article"));
let docs = vec![
doc_with_meta("article", vec![("type", json!("article"))]),
doc_with_meta("blog", vec![("type", json!("blog"))]),
doc("no type"),
];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].page_content, "article");
}
#[test]
fn test_metadata_filter_exclude_value() {
let filter = MetadataFilter::new().exclude_value("status", json!("draft"));
let docs = vec![
doc_with_meta("published", vec![("status", json!("published"))]),
doc_with_meta("draft", vec![("status", json!("draft"))]),
doc("no status"),
];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].page_content, "published");
assert_eq!(result[1].page_content, "no status");
}
#[test]
fn test_metadata_filter_combined_conditions() {
let filter = MetadataFilter::new()
.require_field("source")
.exclude_value("status", json!("draft"));
let docs = vec![
doc_with_meta(
"good",
vec![("source", json!("a")), ("status", json!("published"))],
),
doc_with_meta(
"draft",
vec![("source", json!("b")), ("status", json!("draft"))],
),
doc_with_meta("no source", vec![("status", json!("published"))]),
doc("nothing"),
];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].page_content, "good");
}
#[test]
fn test_metadata_filter_no_conditions() {
let filter = MetadataFilter::new();
let docs = vec![doc("a"), doc("b")];
let result = filter.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 2);
}
#[test]
fn test_metadata_filter_empty_docs() {
let filter = MetadataFilter::new().require_field("source");
let result = filter.compress(&[], "query").unwrap();
assert!(result.is_empty());
}
#[test]
fn test_pipeline_chains_compressors() {
let pipeline = CompressorPipeline::new()
.add(Box::new(RelevanceScorer::new(0.5)))
.add(Box::new(LengthCompressor::new(20)));
let docs = vec![
doc("rust programming language documentation"),
doc("cooking recipes for beginners"),
];
let result = pipeline.compress(&docs, "rust programming").unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].page_content.len() <= 20);
}
#[test]
fn test_pipeline_empty_pipeline() {
let pipeline = CompressorPipeline::new();
assert!(pipeline.is_empty());
assert_eq!(pipeline.len(), 0);
let docs = vec![doc("unchanged")];
let result = pipeline.compress(&docs, "query").unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].page_content, "unchanged");
}
#[test]
fn test_pipeline_short_circuits_on_empty() {
let pipeline = CompressorPipeline::new()
.add(Box::new(RelevanceScorer::new(1.0)))
.add(Box::new(LengthCompressor::new(5)));
let docs = vec![doc("no matching keywords at all")];
let result = pipeline.compress(&docs, "nonexistent").unwrap();
assert!(result.is_empty());
}
#[test]
fn test_pipeline_len() {
let pipeline = CompressorPipeline::new()
.add(Box::new(LengthCompressor::new(10)))
.add(Box::new(RedundancyFilter::new(0.8)));
assert_eq!(pipeline.len(), 2);
assert!(!pipeline.is_empty());
}
#[test]
fn test_retriever_retrieve_with_k() {
let docs = vec![
doc("rust is fast"),
doc("rust is safe"),
doc("python is dynamic"),
];
let compressor = Box::new(RelevanceScorer::new(0.5));
let retriever = ContextualCompressionRetriever::new(docs, compressor);
let result = retriever.retrieve("rust", 1).unwrap();
assert_eq!(result.len(), 1);
}
#[test]
fn test_retriever_retrieve_all() {
let docs = vec![
doc("rust is fast"),
doc("rust is safe"),
doc("python is dynamic"),
];
let compressor = Box::new(RelevanceScorer::new(0.5));
let retriever = ContextualCompressionRetriever::new(docs, compressor);
let result = retriever.retrieve_all("rust").unwrap();
assert_eq!(result.len(), 2);
}
#[test]
fn test_retriever_empty_docs() {
let compressor = Box::new(LengthCompressor::new(100));
let retriever = ContextualCompressionRetriever::new(vec![], compressor);
let result = retriever.retrieve("query", 5).unwrap();
assert!(result.is_empty());
}
#[test]
fn test_retriever_k_larger_than_results() {
let docs = vec![doc("only one")];
let compressor = Box::new(LengthCompressor::new(100));
let retriever = ContextualCompressionRetriever::new(docs, compressor);
let result = retriever.retrieve("query", 10).unwrap();
assert_eq!(result.len(), 1);
}
#[test]
fn test_retriever_end_to_end_with_pipeline() {
let docs = vec![
doc_with_meta(
"rust is a great programming language. It is fast.",
vec![("source", json!("docs"))],
),
doc_with_meta(
"cooking pasta is easy. Boil water first.",
vec![("source", json!("recipes"))],
),
doc("no metadata here"),
];
let pipeline = CompressorPipeline::new()
.add(Box::new(MetadataFilter::new().require_field("source")))
.add(Box::new(RelevanceScorer::new(0.5)))
.add(Box::new(LengthCompressor::new(30)));
let retriever = ContextualCompressionRetriever::new(docs, Box::new(pipeline));
let result = retriever.retrieve_all("rust programming").unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].page_content.len() <= 30);
}
#[test]
fn test_retriever_all_filtered_out() {
let docs = vec![doc("nothing relevant")];
let compressor = Box::new(RelevanceScorer::new(1.0));
let retriever = ContextualCompressionRetriever::new(docs, compressor);
let result = retriever.retrieve("nonexistent keywords here", 5).unwrap();
assert!(result.is_empty());
}
}