Skip to main content

synaptic_retrieval/
compression.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::SynapticError;
5use synaptic_embeddings::Embeddings;
6
7use crate::{Document, Retriever};
8
9/// Trait for compressing or filtering a set of documents based on a query.
10#[async_trait]
11pub trait DocumentCompressor: Send + Sync {
12    /// Compress or filter documents based on relevance to the query.
13    async fn compress_documents(
14        &self,
15        documents: Vec<Document>,
16        query: &str,
17    ) -> Result<Vec<Document>, SynapticError>;
18}
19
20/// Filters documents based on cosine similarity between the query embedding
21/// and document content embeddings.
22pub struct EmbeddingsFilter {
23    embeddings: Arc<dyn Embeddings>,
24    threshold: f32,
25}
26
27impl EmbeddingsFilter {
28    /// Create a new EmbeddingsFilter with the given embeddings provider and similarity threshold.
29    pub fn new(embeddings: Arc<dyn Embeddings>, threshold: f32) -> Self {
30        Self {
31            embeddings,
32            threshold,
33        }
34    }
35
36    /// Create a new EmbeddingsFilter with the default threshold of 0.75.
37    pub fn with_default_threshold(embeddings: Arc<dyn Embeddings>) -> Self {
38        Self::new(embeddings, 0.75)
39    }
40}
41
/// Compute the cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when the slices differ in
/// length, are empty, or either has zero magnitude, so such pairs never pass a
/// positive similarity threshold.
///
/// Sums are accumulated in `f64`: squaring `f32` components can overflow to
/// infinity (turning the result into `NaN`) or underflow to zero for very large
/// or very small values, and long embedding vectors lose precision when summed
/// in `f32`. The result is clamped because rounding can push it marginally
/// past ±1.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms together.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Take each square root separately to keep the denominator in range.
    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // A NaN (from NaN/infinite inputs) passes through `clamp` unchanged.
    similarity.clamp(-1.0, 1.0) as f32
}
58
59#[async_trait]
60impl DocumentCompressor for EmbeddingsFilter {
61    async fn compress_documents(
62        &self,
63        documents: Vec<Document>,
64        query: &str,
65    ) -> Result<Vec<Document>, SynapticError> {
66        if documents.is_empty() {
67            return Ok(vec![]);
68        }
69
70        // Embed the query
71        let query_embedding = self.embeddings.embed_query(query).await?;
72
73        // Embed all document contents
74        let doc_texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
75        let doc_embeddings = self.embeddings.embed_documents(&doc_texts).await?;
76
77        // Filter documents by cosine similarity threshold
78        let filtered = documents
79            .into_iter()
80            .zip(doc_embeddings.iter())
81            .filter(|(_, doc_emb)| cosine_similarity(&query_embedding, doc_emb) >= self.threshold)
82            .map(|(doc, _)| doc)
83            .collect();
84
85        Ok(filtered)
86    }
87}
88
89/// A retriever that retrieves documents from a base retriever and then
90/// compresses/filters them using a DocumentCompressor.
91pub struct ContextualCompressionRetriever {
92    base: Arc<dyn Retriever>,
93    compressor: Arc<dyn DocumentCompressor>,
94}
95
96impl ContextualCompressionRetriever {
97    /// Create a new ContextualCompressionRetriever.
98    pub fn new(base: Arc<dyn Retriever>, compressor: Arc<dyn DocumentCompressor>) -> Self {
99        Self { base, compressor }
100    }
101}
102
103#[async_trait]
104impl Retriever for ContextualCompressionRetriever {
105    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapticError> {
106        // First, retrieve from the base retriever
107        let docs = self.base.retrieve(query, top_k).await?;
108
109        // Then compress/filter the results
110        let compressed = self.compressor.compress_documents(docs, query).await?;
111
112        Ok(compressed)
113    }
114}