//! `synaptic_embeddings/cached.rs` — cache-backed embeddings wrapper.

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use sha2::{Digest, Sha256};
5use synaptic_core::{Store, SynapticError};
6
7use crate::Embeddings;
8
9/// An embeddings wrapper that caches results in a [`Store`] backend.
10///
11/// Previously computed embeddings are stored in the provided [`Store`] keyed
12/// by the SHA-256 hash of the input text. On subsequent calls, cached
13/// embeddings are returned directly, and only uncached texts are sent to the
14/// inner embeddings provider.
15///
16/// This aligns with LangChain Python's `CacheBackedEmbeddings` architecture,
17/// allowing any `Store` implementation (in-memory, SQLite, PostgreSQL, Redis,
18/// etc.) to serve as the caching backend.
19pub struct CacheBackedEmbeddings {
20    inner: Arc<dyn Embeddings>,
21    store: Arc<dyn Store>,
22    namespace: String,
23}
24
25impl CacheBackedEmbeddings {
26    /// Create a new cached embeddings wrapper.
27    ///
28    /// - `inner` — the underlying embeddings provider to delegate to on cache misses.
29    /// - `store` — the [`Store`] backend for persisting cached embeddings.
30    /// - `namespace` — a logical namespace within the store (combined with
31    ///   `"embedding_cache"` as the prefix).
32    pub fn new(
33        inner: Arc<dyn Embeddings>,
34        store: Arc<dyn Store>,
35        namespace: impl Into<String>,
36    ) -> Self {
37        Self {
38            inner,
39            store,
40            namespace: namespace.into(),
41        }
42    }
43
44    /// Build the store namespace for this cache instance.
45    fn store_namespace(&self) -> Vec<String> {
46        vec!["embedding_cache".to_string(), self.namespace.clone()]
47    }
48
49    /// Compute the SHA-256 hash of a text, returned as a hex string.
50    fn hash_key(text: &str) -> String {
51        let mut hasher = Sha256::new();
52        hasher.update(text.as_bytes());
53        format!("{:x}", hasher.finalize())
54    }
55}
56
57#[async_trait]
58impl Embeddings for CacheBackedEmbeddings {
59    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
60        let ns = self.store_namespace();
61        let ns_refs: Vec<&str> = ns.iter().map(|s| s.as_str()).collect();
62
63        // Check cache for each text
64        let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
65        let mut uncached_indices: Vec<usize> = Vec::new();
66        let mut uncached_texts: Vec<&str> = Vec::new();
67
68        for (i, text) in texts.iter().enumerate() {
69            let key = Self::hash_key(text);
70            if let Some(item) = self.store.get(&ns_refs, &key).await? {
71                // Deserialize the cached embedding
72                let embedding: Vec<f32> = serde_json::from_value(item.value)
73                    .map_err(|e| SynapticError::Store(format!("cache deserialize error: {e}")))?;
74                results.push(Some(embedding));
75            } else {
76                results.push(None);
77                uncached_indices.push(i);
78                uncached_texts.push(text);
79            }
80        }
81
82        // Embed uncached texts
83        if !uncached_texts.is_empty() {
84            let new_embeddings = self.inner.embed_documents(&uncached_texts).await?;
85
86            // Store new embeddings in cache
87            for (idx, embedding) in uncached_indices.iter().zip(new_embeddings.into_iter()) {
88                let key = Self::hash_key(texts[*idx]);
89                let value = serde_json::to_value(&embedding)
90                    .map_err(|e| SynapticError::Store(format!("cache serialize error: {e}")))?;
91                self.store.put(&ns_refs, &key, value).await?;
92                results[*idx] = Some(embedding);
93            }
94        }
95
96        // All results should now be Some
97        Ok(results.into_iter().map(|r| r.unwrap()).collect())
98    }
99
100    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
101        let ns = self.store_namespace();
102        let ns_refs: Vec<&str> = ns.iter().map(|s| s.as_str()).collect();
103        let key = Self::hash_key(text);
104
105        // Check cache
106        if let Some(item) = self.store.get(&ns_refs, &key).await? {
107            let embedding: Vec<f32> = serde_json::from_value(item.value)
108                .map_err(|e| SynapticError::Store(format!("cache deserialize error: {e}")))?;
109            return Ok(embedding);
110        }
111
112        // Cache miss: compute embedding
113        let embedding = self.inner.embed_query(text).await?;
114
115        // Store in cache
116        let value = serde_json::to_value(&embedding)
117            .map_err(|e| SynapticError::Store(format!("cache serialize error: {e}")))?;
118        self.store.put(&ns_refs, &key, value).await?;
119
120        Ok(embedding)
121    }
122}