synaptic_embeddings/
cached.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use sha2::{Digest, Sha256};
5use synaptic_core::{Store, SynapticError};
6
7use crate::Embeddings;
8
9pub struct CacheBackedEmbeddings {
20 inner: Arc<dyn Embeddings>,
21 store: Arc<dyn Store>,
22 namespace: String,
23}
24
25impl CacheBackedEmbeddings {
26 pub fn new(
33 inner: Arc<dyn Embeddings>,
34 store: Arc<dyn Store>,
35 namespace: impl Into<String>,
36 ) -> Self {
37 Self {
38 inner,
39 store,
40 namespace: namespace.into(),
41 }
42 }
43
44 fn store_namespace(&self) -> Vec<String> {
46 vec!["embedding_cache".to_string(), self.namespace.clone()]
47 }
48
49 fn hash_key(text: &str) -> String {
51 let mut hasher = Sha256::new();
52 hasher.update(text.as_bytes());
53 format!("{:x}", hasher.finalize())
54 }
55}
56
57#[async_trait]
58impl Embeddings for CacheBackedEmbeddings {
59 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
60 let ns = self.store_namespace();
61 let ns_refs: Vec<&str> = ns.iter().map(|s| s.as_str()).collect();
62
63 let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
65 let mut uncached_indices: Vec<usize> = Vec::new();
66 let mut uncached_texts: Vec<&str> = Vec::new();
67
68 for (i, text) in texts.iter().enumerate() {
69 let key = Self::hash_key(text);
70 if let Some(item) = self.store.get(&ns_refs, &key).await? {
71 let embedding: Vec<f32> = serde_json::from_value(item.value)
73 .map_err(|e| SynapticError::Store(format!("cache deserialize error: {e}")))?;
74 results.push(Some(embedding));
75 } else {
76 results.push(None);
77 uncached_indices.push(i);
78 uncached_texts.push(text);
79 }
80 }
81
82 if !uncached_texts.is_empty() {
84 let new_embeddings = self.inner.embed_documents(&uncached_texts).await?;
85
86 for (idx, embedding) in uncached_indices.iter().zip(new_embeddings.into_iter()) {
88 let key = Self::hash_key(texts[*idx]);
89 let value = serde_json::to_value(&embedding)
90 .map_err(|e| SynapticError::Store(format!("cache serialize error: {e}")))?;
91 self.store.put(&ns_refs, &key, value).await?;
92 results[*idx] = Some(embedding);
93 }
94 }
95
96 Ok(results.into_iter().map(|r| r.unwrap()).collect())
98 }
99
100 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
101 let ns = self.store_namespace();
102 let ns_refs: Vec<&str> = ns.iter().map(|s| s.as_str()).collect();
103 let key = Self::hash_key(text);
104
105 if let Some(item) = self.store.get(&ns_refs, &key).await? {
107 let embedding: Vec<f32> = serde_json::from_value(item.value)
108 .map_err(|e| SynapticError::Store(format!("cache deserialize error: {e}")))?;
109 return Ok(embedding);
110 }
111
112 let embedding = self.inner.embed_query(text).await?;
114
115 let value = serde_json::to_value(&embedding)
117 .map_err(|e| SynapticError::Store(format!("cache serialize error: {e}")))?;
118 self.store.put(&ns_refs, &key, value).await?;
119
120 Ok(embedding)
121 }
122}