synaptic_retrieval/
compression.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::SynapseError;
5use synaptic_embeddings::Embeddings;
6
7use crate::{Document, Retriever};
8
/// Post-processing step applied to a batch of documents for a given query.
///
/// Implementations take ownership of the documents and return the set that
/// should be passed on; what "compression" means (filtering, rewriting, ...)
/// is up to the implementor.
#[async_trait]
pub trait DocumentCompressor: Send + Sync {
    /// Processes `documents` with respect to `query`.
    ///
    /// # Errors
    ///
    /// Returns a [`SynapseError`] if the implementation fails.
    async fn compress_documents(
        &self,
        documents: Vec<Document>,
        query: &str,
    ) -> Result<Vec<Document>, SynapseError>;
}
19
20pub struct EmbeddingsFilter {
23 embeddings: Arc<dyn Embeddings>,
24 threshold: f32,
25}
26
27impl EmbeddingsFilter {
28 pub fn new(embeddings: Arc<dyn Embeddings>, threshold: f32) -> Self {
30 Self {
31 embeddings,
32 threshold,
33 }
34 }
35
36 pub fn with_default_threshold(embeddings: Arc<dyn Embeddings>) -> Self {
38 Self::new(embeddings, 0.75)
39 }
40}
41
/// Computes the cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector has zero magnitude.
///
/// The dot product and squared norms are accumulated in `f64`: squaring and
/// summing in `f32` overflows to infinity for large components (yielding
/// `NaN`) and underflows to zero for very small ones (misreporting a non-zero
/// vector as zero-length).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass over both vectors: (a.b, |a|^2, |b|^2).
    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }

    // Take the square roots separately so the denominator stays well inside
    // the f64 range. The quotient is (up to rounding) within [-1, 1], so
    // narrowing back to f32 only rounds; it cannot overflow.
    (dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt())) as f32
}
58
59#[async_trait]
60impl DocumentCompressor for EmbeddingsFilter {
61 async fn compress_documents(
62 &self,
63 documents: Vec<Document>,
64 query: &str,
65 ) -> Result<Vec<Document>, SynapseError> {
66 if documents.is_empty() {
67 return Ok(vec![]);
68 }
69
70 let query_embedding = self.embeddings.embed_query(query).await?;
72
73 let doc_texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
75 let doc_embeddings = self.embeddings.embed_documents(&doc_texts).await?;
76
77 let filtered = documents
79 .into_iter()
80 .zip(doc_embeddings.iter())
81 .filter(|(_, doc_emb)| cosine_similarity(&query_embedding, doc_emb) >= self.threshold)
82 .map(|(doc, _)| doc)
83 .collect();
84
85 Ok(filtered)
86 }
87}
88
89pub struct ContextualCompressionRetriever {
92 base: Arc<dyn Retriever>,
93 compressor: Arc<dyn DocumentCompressor>,
94}
95
96impl ContextualCompressionRetriever {
97 pub fn new(base: Arc<dyn Retriever>, compressor: Arc<dyn DocumentCompressor>) -> Self {
99 Self { base, compressor }
100 }
101}
102
103#[async_trait]
104impl Retriever for ContextualCompressionRetriever {
105 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
106 let docs = self.base.retrieve(query, top_k).await?;
108
109 let compressed = self.compressor.compress_documents(docs, query).await?;
111
112 Ok(compressed)
113 }
114}