use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::documents::Document;
use cognis_core::embeddings::Embeddings;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message};
use cognis_core::retrievers::BaseRetriever;
/// Strategy for shrinking a set of retrieved documents with respect to a query.
///
/// Implementations may drop irrelevant documents entirely and/or rewrite the
/// content of the documents they keep.
#[async_trait]
pub trait DocumentCompressor: Send + Sync {
    /// Compress `documents` in the context of `query`, returning the surviving
    /// (possibly rewritten) documents.
    async fn compress_documents(
        &self,
        documents: &[Document],
        query: &str,
    ) -> Result<Vec<Document>>;

    /// Human-readable name identifying this compressor implementation.
    fn name(&self) -> &str;
}
/// Default extraction prompt for [`LLMCompressor`]. The `{query}` and
/// `{document}` placeholders are substituted at call time; the model is told
/// to answer exactly "NO_RELEVANT_CONTENT" for irrelevant documents so the
/// compressor can filter them out.
const DEFAULT_PROMPT_TEMPLATE: &str = "Given the following question and document, extract only the parts of the document that are relevant to answering the question. If the document is not relevant at all, respond with exactly \"NO_RELEVANT_CONTENT\".\n\nQuestion: {query}\n\nDocument:\n{document}\n\nRelevant parts:";
/// [`DocumentCompressor`] that asks a chat model to extract only the parts of
/// each document relevant to the query, dropping documents the model deems
/// irrelevant.
pub struct LLMCompressor {
    // Chat model used to perform the per-document extraction.
    model: Arc<dyn BaseChatModel>,
    // Prompt with `{query}` / `{document}` placeholders; defaults to
    // DEFAULT_PROMPT_TEMPLATE.
    prompt_template: String,
}
impl LLMCompressor {
    /// Build a compressor around `model` using the default extraction prompt.
    pub fn new(model: Arc<dyn BaseChatModel>) -> Self {
        let prompt_template = DEFAULT_PROMPT_TEMPLATE.to_string();
        Self { model, prompt_template }
    }

    /// Replace the prompt template. The template should contain the `{query}`
    /// and `{document}` placeholders.
    pub fn with_prompt_template(mut self, template: impl Into<String>) -> Self {
        let template = template.into();
        self.prompt_template = template;
        self
    }

    /// Substitute `query` and `document` into the template placeholders.
    fn format_prompt(&self, query: &str, document: &str) -> String {
        let with_query = self.prompt_template.replace("{query}", query);
        with_query.replace("{document}", document)
    }
}
#[async_trait]
impl DocumentCompressor for LLMCompressor {
    fn name(&self) -> &str {
        "LLMCompressor"
    }

    /// Run the extraction prompt once per document, sequentially. A document
    /// survives only if the model returned non-empty text other than the
    /// sentinel "NO_RELEVANT_CONTENT"; its content is replaced by the
    /// extracted text while metadata is preserved.
    async fn compress_documents(
        &self,
        documents: &[Document],
        query: &str,
    ) -> Result<Vec<Document>> {
        let mut kept = Vec::new();
        for original in documents {
            let prompt = self.format_prompt(query, &original.page_content);
            let request = [Message::Human(HumanMessage::new(prompt))];
            let reply = self.model.invoke_messages(&request, None).await?;
            let reply_text = reply.base.content.text();
            let extracted = reply_text.trim();
            // Skip documents the model judged irrelevant (or answered blank).
            if extracted.is_empty() || extracted == "NO_RELEVANT_CONTENT" {
                continue;
            }
            let mut doc = original.clone();
            doc.page_content = extracted.to_string();
            kept.push(doc);
        }
        Ok(kept)
    }
}
/// [`DocumentCompressor`] that drops documents whose embedding is not similar
/// enough to the query embedding, without rewriting document content.
pub struct EmbeddingsFilter {
    // Embedding model used for both the query and the documents.
    embeddings: Arc<dyn Embeddings>,
    // Minimum cosine similarity a document must reach to be kept (default 0.0).
    similarity_threshold: f32,
    // Optional cap on the number of documents returned (most similar first).
    top_k: Option<usize>,
}
impl EmbeddingsFilter {
    /// Create a filter that keeps every document: threshold 0.0, no top-k cap.
    pub fn new(embeddings: Arc<dyn Embeddings>) -> Self {
        Self {
            embeddings,
            similarity_threshold: 0.0,
            top_k: None,
        }
    }

    /// Keep only documents whose cosine similarity to the query is at least
    /// `threshold`.
    pub fn with_similarity_threshold(self, threshold: f32) -> Self {
        Self {
            similarity_threshold: threshold,
            ..self
        }
    }

    /// Cap the result at the `k` most similar documents.
    pub fn with_top_k(self, k: usize) -> Self {
        Self {
            top_k: Some(k),
            ..self
        }
    }
}
/// Cosine similarity between two vectors. Returns 0.0 for mismatched lengths,
/// empty inputs, or when either vector has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Accumulate dot product and squared norms in a single pass.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
#[async_trait]
impl DocumentCompressor for EmbeddingsFilter {
    fn name(&self) -> &str {
        "EmbeddingsFilter"
    }

    /// Embed the query and every document, score each document by cosine
    /// similarity, discard those below the threshold, and return the rest in
    /// descending similarity order, optionally truncated to `top_k`.
    async fn compress_documents(
        &self,
        documents: &[Document],
        query: &str,
    ) -> Result<Vec<Document>> {
        if documents.is_empty() {
            return Ok(Vec::new());
        }
        let query_vec = self.embeddings.embed_query(query).await?;
        let texts: Vec<String> = documents.iter().map(|d| d.page_content.clone()).collect();
        let doc_vecs = self.embeddings.embed_documents(texts).await?;

        // Score and filter in one pass; keep (index, similarity) pairs.
        let mut ranked = Vec::with_capacity(doc_vecs.len());
        for (idx, vec) in doc_vecs.iter().enumerate() {
            let score = cosine_similarity(&query_vec, vec);
            if score >= self.similarity_threshold {
                ranked.push((idx, score));
            }
        }
        // Highest similarity first; incomparable (NaN) scores are treated as
        // equal so the sort cannot panic.
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        if let Some(k) = self.top_k {
            ranked.truncate(k);
        }
        let result = ranked
            .into_iter()
            .map(|(idx, _)| documents[idx].clone())
            .collect();
        Ok(result)
    }
}
/// Retriever that wraps another retriever and post-processes its results with
/// a [`DocumentCompressor`] before returning them.
pub struct ContextualCompressionRetriever {
    // Underlying retriever that produces the raw candidate documents.
    base_retriever: Box<dyn BaseRetriever>,
    // Compressor applied to the raw results for each query.
    compressor: Box<dyn DocumentCompressor>,
}
impl ContextualCompressionRetriever {
    /// Wrap `base_retriever` so its results are post-processed by `compressor`.
    pub fn new(
        base_retriever: Box<dyn BaseRetriever>,
        compressor: Box<dyn DocumentCompressor>,
    ) -> Self {
        Self {
            base_retriever,
            compressor,
        }
    }

    /// Swap in a different underlying retriever.
    pub fn with_base_retriever(self, retriever: Box<dyn BaseRetriever>) -> Self {
        Self {
            base_retriever: retriever,
            ..self
        }
    }

    /// Swap in a different compressor.
    pub fn with_compressor(self, compressor: Box<dyn DocumentCompressor>) -> Self {
        Self { compressor, ..self }
    }
}
#[async_trait]
impl BaseRetriever for ContextualCompressionRetriever {
    /// Fetch documents from the wrapped retriever, then compress them against
    /// the same query before returning.
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
        let docs = self.base_retriever.get_relevant_documents(query).await?;
        self.compressor.compress_documents(&docs, query).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::messages::AIMessage;
    use cognis_core::outputs::{ChatGeneration, ChatResult};
    use std::collections::HashMap;
    use std::sync::Mutex;

    // Retriever stub that always returns a fixed list of documents, ignoring
    // the query.
    struct MockRetriever {
        docs: Vec<Document>,
    }

    impl MockRetriever {
        fn new(contents: &[&str]) -> Self {
            Self {
                docs: contents.iter().map(|c| Document::new(*c)).collect(),
            }
        }
    }

    #[async_trait]
    impl BaseRetriever for MockRetriever {
        async fn get_relevant_documents(&self, _query: &str) -> Result<Vec<Document>> {
            Ok(self.docs.clone())
        }
    }

    // Chat model stub that pops canned responses in order; once exhausted it
    // answers with the "irrelevant" sentinel so further documents are dropped.
    struct FakeChatModel {
        responses: Mutex<Vec<String>>,
    }

    impl FakeChatModel {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses: Mutex::new(responses),
            }
        }
    }

    #[async_trait]
    impl BaseChatModel for FakeChatModel {
        async fn _generate(
            &self,
            _messages: &[Message],
            _stop: Option<&[String]>,
        ) -> Result<ChatResult> {
            // Scope the lock so the guard is released before constructing the
            // result.
            let response = {
                let mut resps = self.responses.lock().unwrap();
                if resps.is_empty() {
                    "NO_RELEVANT_CONTENT".to_string()
                } else {
                    resps.remove(0)
                }
            };
            let ai_msg = AIMessage::new(&response);
            Ok(ChatResult {
                generations: vec![ChatGeneration::new(ai_msg)],
                llm_output: None,
            })
        }

        fn llm_type(&self) -> &str {
            "fake"
        }
    }

    // Embeddings stub returning pre-baked vectors regardless of input text.
    struct FakeEmbeddings {
        query_embedding: Vec<f32>,
        doc_embeddings: Vec<Vec<f32>>,
    }

    #[async_trait]
    impl Embeddings for FakeEmbeddings {
        async fn embed_documents(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            Ok(self.doc_embeddings.clone())
        }

        async fn embed_query(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(self.query_embedding.clone())
        }
    }

    // LLMCompressor replaces each document's content with the model's reply.
    #[tokio::test]
    async fn test_llm_compressor_extracts_relevant_text() {
        let model = Arc::new(FakeChatModel::new(vec![
            "relevant part A".into(),
            "relevant part B".into(),
        ]));
        let compressor = LLMCompressor::new(model);
        let docs = vec![Document::new("full doc A"), Document::new("full doc B")];
        let result = compressor.compress_documents(&docs, "query").await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].page_content, "relevant part A");
        assert_eq!(result[1].page_content, "relevant part B");
    }

    // Documents answered with the sentinel are dropped; the rest keep their
    // relative order.
    #[tokio::test]
    async fn test_llm_compressor_filters_irrelevant_documents() {
        let model = Arc::new(FakeChatModel::new(vec![
            "relevant content".into(),
            "NO_RELEVANT_CONTENT".into(),
            "also relevant".into(),
        ]));
        let compressor = LLMCompressor::new(model);
        let docs = vec![
            Document::new("doc 1"),
            Document::new("doc 2"),
            Document::new("doc 3"),
        ];
        let result = compressor.compress_documents(&docs, "query").await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].page_content, "relevant content");
        assert_eq!(result[1].page_content, "also relevant");
    }

    // Compression rewrites page_content but must not touch metadata.
    #[tokio::test]
    async fn test_llm_compressor_preserves_metadata() {
        let model = Arc::new(FakeChatModel::new(vec!["compressed".into()]));
        let compressor = LLMCompressor::new(model);
        let mut metadata = HashMap::new();
        metadata.insert("source".to_string(), serde_json::json!("test.pdf"));
        let doc = Document::new("original content").with_metadata(metadata.clone());
        let result = compressor
            .compress_documents(&[doc], "query")
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].page_content, "compressed");
        assert_eq!(result[0].metadata, metadata);
    }

    // With threshold 0.5: identical (sim 1.0) and partial (sim ~0.707) pass,
    // orthogonal (sim 0.0) is dropped; results come back sorted by similarity.
    #[tokio::test]
    async fn test_embeddings_filter_with_threshold() {
        let embeddings = Arc::new(FakeEmbeddings {
            query_embedding: vec![1.0, 0.0, 0.0],
            doc_embeddings: vec![
                vec![1.0, 0.0, 0.0],
                vec![0.0, 1.0, 0.0],
                vec![0.707, 0.707, 0.0],
            ],
        });
        let filter = EmbeddingsFilter::new(embeddings).with_similarity_threshold(0.5);
        let docs = vec![
            Document::new("identical"),
            Document::new("orthogonal"),
            Document::new("partial"),
        ];
        let result = filter.compress_documents(&docs, "query").await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].page_content, "identical");
        assert_eq!(result[1].page_content, "partial");
    }

    // top_k keeps only the k most similar documents (here: doc_a then doc_c).
    #[tokio::test]
    async fn test_embeddings_filter_with_top_k() {
        let embeddings = Arc::new(FakeEmbeddings {
            query_embedding: vec![1.0, 0.0],
            doc_embeddings: vec![
                vec![1.0, 0.0],
                vec![0.5, 0.5],
                vec![0.9, 0.1],
                vec![0.0, 1.0],
            ],
        });
        let filter = EmbeddingsFilter::new(embeddings).with_top_k(2);
        let docs = vec![
            Document::new("doc_a"),
            Document::new("doc_b"),
            Document::new("doc_c"),
            Document::new("doc_d"),
        ];
        let result = filter.compress_documents(&docs, "query").await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].page_content, "doc_a");
        assert_eq!(result[1].page_content, "doc_c");
    }

    // End-to-end: retrieval output is piped through the LLM compressor.
    #[tokio::test]
    async fn test_contextual_compression_retriever_with_llm() {
        let base = Box::new(MockRetriever::new(&["full document 1", "full document 2"]));
        let model = Arc::new(FakeChatModel::new(vec![
            "compressed 1".into(),
            "compressed 2".into(),
        ]));
        let compressor = Box::new(LLMCompressor::new(model));
        let retriever = ContextualCompressionRetriever::new(base, compressor);
        let docs = retriever.get_relevant_documents("query").await.unwrap();
        assert_eq!(docs.len(), 2);
        assert_eq!(docs[0].page_content, "compressed 1");
        assert_eq!(docs[1].page_content, "compressed 2");
    }

    // End-to-end: retrieval output is piped through the embeddings filter.
    #[tokio::test]
    async fn test_contextual_compression_retriever_with_embeddings_filter() {
        let base = Box::new(MockRetriever::new(&["relevant", "irrelevant"]));
        let embeddings = Arc::new(FakeEmbeddings {
            query_embedding: vec![1.0, 0.0],
            doc_embeddings: vec![
                vec![0.9, 0.1],
                vec![0.0, 1.0],
            ],
        });
        let filter = Box::new(EmbeddingsFilter::new(embeddings).with_similarity_threshold(0.5));
        let retriever = ContextualCompressionRetriever::new(base, filter);
        let docs = retriever.get_relevant_documents("query").await.unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].page_content, "relevant");
    }

    // Compressing an empty slice yields an empty result without calling the
    // model.
    #[tokio::test]
    async fn test_empty_document_list() {
        let model = Arc::new(FakeChatModel::new(vec![]));
        let compressor = LLMCompressor::new(model);
        let result = compressor.compress_documents(&[], "query").await.unwrap();
        assert!(result.is_empty());
    }

    // Every document answered with the sentinel is removed.
    #[tokio::test]
    async fn test_all_documents_filtered_out() {
        let model = Arc::new(FakeChatModel::new(vec![
            "NO_RELEVANT_CONTENT".into(),
            "NO_RELEVANT_CONTENT".into(),
        ]));
        let compressor = LLMCompressor::new(model);
        let docs = vec![Document::new("doc 1"), Document::new("doc 2")];
        let result = compressor.compress_documents(&docs, "query").await.unwrap();
        assert!(result.is_empty());
    }

    // with_prompt_template swaps the template; format_prompt substitutes both
    // placeholders.
    #[tokio::test]
    async fn test_custom_prompt_template() {
        let model = Arc::new(FakeChatModel::new(vec!["answer".into()]));
        let custom_template = "Q: {query}\nDoc: {document}\nExtract:";
        let compressor = LLMCompressor::new(model).with_prompt_template(custom_template);
        let formatted = compressor.format_prompt("test query", "test doc");
        assert_eq!(formatted, "Q: test query\nDoc: test doc\nExtract:");
        let docs = vec![Document::new("test doc")];
        let result = compressor
            .compress_documents(&docs, "test query")
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].page_content, "answer");
    }

    // Builder methods replace both the retriever and compressor after
    // construction.
    #[tokio::test]
    async fn test_builder_pattern() {
        let base = Box::new(MockRetriever::new(&["doc"]));
        let model = Arc::new(FakeChatModel::new(vec!["compressed".into()]));
        let compressor = Box::new(LLMCompressor::new(model));
        let new_base = Box::new(MockRetriever::new(&["new doc"]));
        let new_model = Arc::new(FakeChatModel::new(vec!["new compressed".into()]));
        let new_compressor = Box::new(LLMCompressor::new(new_model));
        let retriever = ContextualCompressionRetriever::new(base, compressor)
            .with_base_retriever(new_base)
            .with_compressor(new_compressor);
        let docs = retriever.get_relevant_documents("query").await.unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].page_content, "new compressed");
    }

    // The embeddings filter short-circuits on an empty input slice.
    #[tokio::test]
    async fn test_embeddings_filter_empty_documents() {
        let embeddings = Arc::new(FakeEmbeddings {
            query_embedding: vec![1.0, 0.0],
            doc_embeddings: vec![],
        });
        let filter = EmbeddingsFilter::new(embeddings);
        let result = filter.compress_documents(&[], "query").await.unwrap();
        assert!(result.is_empty());
    }

    // Parallel, orthogonal, anti-parallel, and empty inputs for the helper.
    #[tokio::test]
    async fn test_cosine_similarity_function() {
        assert!((cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!((cosine_similarity(&[1.0, 0.0], &[0.0, 1.0])).abs() < 1e-6);
        assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }
}