//! Semantic (embedding-similarity) cache for LLM chat responses.
//!
//! Source path: `synaptic_cache/semantic.rs`.

use std::sync::Arc;

use async_trait::async_trait;
use synaptic_core::{ChatResponse, SynapticError};
use synaptic_embeddings::Embeddings;
use tokio::sync::RwLock;

use crate::LlmCache;

10struct SemanticEntry {
11    embedding: Vec<f32>,
12    response: ChatResponse,
13}
14
15/// Cache that uses embedding similarity to match semantically equivalent queries.
16///
17/// When a cache lookup is performed, the key is embedded and compared against all
18/// stored entries using cosine similarity. If any entry exceeds the similarity
19/// threshold, its cached response is returned.
20pub struct SemanticCache {
21    embeddings: Arc<dyn Embeddings>,
22    entries: RwLock<Vec<SemanticEntry>>,
23    similarity_threshold: f32,
24}
25
26impl SemanticCache {
27    /// Create a new semantic cache with the given embeddings provider and similarity threshold.
28    ///
29    /// The threshold should be between 0.0 and 1.0. A typical value is 0.95, meaning
30    /// only very similar queries will match.
31    pub fn new(embeddings: Arc<dyn Embeddings>, similarity_threshold: f32) -> Self {
32        Self {
33            embeddings,
34            entries: RwLock::new(Vec::new()),
35            similarity_threshold,
36        }
37    }
38}
39
40#[async_trait]
41impl LlmCache for SemanticCache {
42    async fn get(&self, key: &str) -> Result<Option<ChatResponse>, SynapticError> {
43        let query_embedding =
44            self.embeddings.embed_query(key).await.map_err(|e| {
45                SynapticError::Cache(format!("embedding error during cache get: {e}"))
46            })?;
47
48        let entries = self.entries.read().await;
49        let mut best_score = f32::NEG_INFINITY;
50        let mut best_response = None;
51
52        for entry in entries.iter() {
53            let score = cosine_similarity(&query_embedding, &entry.embedding);
54            if score >= self.similarity_threshold && score > best_score {
55                best_score = score;
56                best_response = Some(entry.response.clone());
57            }
58        }
59
60        Ok(best_response)
61    }
62
63    async fn put(&self, key: &str, response: &ChatResponse) -> Result<(), SynapticError> {
64        let embedding =
65            self.embeddings.embed_query(key).await.map_err(|e| {
66                SynapticError::Cache(format!("embedding error during cache put: {e}"))
67            })?;
68
69        let mut entries = self.entries.write().await;
70        entries.push(SemanticEntry {
71            embedding,
72            response: response.clone(),
73        });
74
75        Ok(())
76    }
77
78    async fn clear(&self) -> Result<(), SynapticError> {
79        let mut entries = self.entries.write().await;
80        entries.clear();
81        Ok(())
82    }
83}
84
/// Compute the cosine similarity between two vectors.
///
/// The result lies in `[-1.0, 1.0]`:
/// - `1.0` for vectors pointing the same way,
/// - `0.0` for orthogonal vectors,
/// - `-1.0` for opposite vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero
/// magnitude, since the similarity is undefined there. A `NaN` component makes
/// the result `NaN`.
///
/// Products and sums are accumulated in `f64`. In `f32`, squaring components
/// larger than about `1.8e19` overflows to infinity, which turns the result into
/// `NaN`. Squaring components smaller than about `3.7e-23` underflows to zero,
/// so non-zero vectors would be treated as zero-magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass computing the dot product and both squared norms.
    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt());
    // Rounding can push the quotient marginally outside the mathematical range.
    // `clamp` leaves NaN as NaN.
    (similarity as f32).clamp(-1.0, 1.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing floating-point similarity scores.
    const EPS: f32 = 1e-6;

    #[test]
    fn cosine_similarity_identical_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < EPS);
    }

    #[test]
    fn cosine_similarity_empty_vectors() {
        let a: Vec<f32> = vec![];
        let b: Vec<f32> = vec![];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [-1.0_f32, -2.0, -3.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_similarity_is_scale_invariant() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [2.0_f32, 4.0, 6.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_similarity_mismatched_lengths_is_zero() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [1.0_f32, 2.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_similarity_zero_magnitude_is_zero() {
        let zero = [0.0_f32, 0.0, 0.0];
        let other = [1.0_f32, 2.0, 3.0];
        assert_eq!(cosine_similarity(&zero, &other), 0.0);
        assert_eq!(cosine_similarity(&other, &zero), 0.0);
    }
}