// noether_engine/index/cache.rs
#![warn(clippy::unwrap_used)]
#![cfg_attr(test, allow(clippy::unwrap_used))]

use super::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
/// Caching decorator around an [`EmbeddingProvider`]: serves repeat requests
/// from an in-memory map and persists the map to a JSON file at `path`.
pub struct CachedEmbeddingProvider {
    /// The wrapped provider that performs the real embedding work.
    inner: Box<dyn EmbeddingProvider>,
    /// In-memory cache keyed by the hex-encoded SHA-256 of the input text.
    cache: HashMap<String, Embedding>,
    /// Location of the persisted JSON cache file.
    path: PathBuf,
    /// True when `cache` holds entries not yet written to `path`.
    dirty: bool,
}
/// On-disk JSON layout of the persisted cache: a flat list of entries.
#[derive(Serialize, Deserialize)]
struct CacheFile {
    entries: Vec<CacheEntry>,
}
/// One persisted cache record: the hex SHA-256 of the source text and the
/// embedding computed for it.
#[derive(Serialize, Deserialize)]
struct CacheEntry {
    text_hash: String,
    embedding: Embedding,
}
30impl CachedEmbeddingProvider {
31 pub fn new(inner: Box<dyn EmbeddingProvider>, path: impl Into<PathBuf>) -> Self {
32 let path = path.into();
33 let cache = if path.exists() {
34 std::fs::read_to_string(&path)
35 .ok()
36 .and_then(|content| {
37 if content.trim().is_empty() {
38 return None;
39 }
40 serde_json::from_str::<CacheFile>(&content).ok()
41 })
42 .map(|f| {
43 f.entries
44 .into_iter()
45 .map(|e| (e.text_hash, e.embedding))
46 .collect()
47 })
48 .unwrap_or_default()
49 } else {
50 HashMap::new()
51 };
52 Self {
53 inner,
54 cache,
55 path,
56 dirty: false,
57 }
58 }
59
60 fn text_hash(text: &str) -> String {
61 hex::encode(Sha256::digest(text.as_bytes()))
62 }
63
64 pub fn flush(&self) {
66 if !self.dirty {
67 return;
68 }
69 if let Some(parent) = self.path.parent() {
70 let _ = std::fs::create_dir_all(parent);
71 }
72 let file = CacheFile {
73 entries: self
74 .cache
75 .iter()
76 .map(|(h, e)| CacheEntry {
77 text_hash: h.clone(),
78 embedding: e.clone(),
79 })
80 .collect(),
81 };
82 if let Ok(json) = serde_json::to_string(&file) {
83 let _ = std::fs::write(&self.path, json);
84 }
85 }
86}
impl Drop for CachedEmbeddingProvider {
    fn drop(&mut self) {
        // Best-effort persistence: write any unsaved entries when the
        // provider goes out of scope. Errors are swallowed inside `flush`.
        self.flush();
    }
}
94impl EmbeddingProvider for CachedEmbeddingProvider {
95 fn dimensions(&self) -> usize {
96 self.inner.dimensions()
97 }
98
99 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
100 let hash = Self::text_hash(text);
101 if let Some(cached) = self.cache.get(&hash) {
102 return Ok(cached.clone());
103 }
104 self.inner.embed(text)
109 }
110
111 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
112 texts.iter().map(|t| self.embed(t)).collect()
113 }
114}
116impl CachedEmbeddingProvider {
117 pub fn embed_cached(&mut self, text: &str) -> Result<Embedding, EmbeddingError> {
119 let hash = Self::text_hash(text);
120 if let Some(cached) = self.cache.get(&hash) {
121 return Ok(cached.clone());
122 }
123 let embedding = self.inner.embed(text)?;
124 self.cache.insert(hash, embedding.clone());
125 self.dirty = true;
126 Ok(embedding)
127 }
128
129 pub fn embed_batch_cached_paced(
148 &mut self,
149 texts: &[&str],
150 chunk_size: usize,
151 inter_batch_delay: std::time::Duration,
152 ) -> Result<Vec<Embedding>, EmbeddingError> {
153 if texts.is_empty() {
154 return Ok(Vec::new());
155 }
156
157 let hashes: Vec<String> = texts.iter().map(|t| Self::text_hash(t)).collect();
158 let mut miss_indices: Vec<usize> = Vec::new();
159 let mut miss_texts: Vec<&str> = Vec::new();
160 for (i, h) in hashes.iter().enumerate() {
161 if !self.cache.contains_key(h) {
162 miss_indices.push(i);
163 miss_texts.push(texts[i]);
164 }
165 }
166
167 if !miss_texts.is_empty() {
168 let chunk = chunk_size.max(1);
169 let mut consumed = 0usize;
170 for (b, slice) in miss_texts.chunks(chunk).enumerate() {
171 if b > 0 && !inter_batch_delay.is_zero() {
172 std::thread::sleep(inter_batch_delay);
173 }
174 let part = self.inner.embed_batch(slice)?;
175 if part.len() != slice.len() {
182 return Err(EmbeddingError::Provider(format!(
183 "embed_batch returned {} embeddings for {} inputs",
184 part.len(),
185 slice.len()
186 )));
187 }
188 for (j, emb) in part.into_iter().enumerate() {
189 let idx = miss_indices[consumed + j];
190 self.cache.insert(hashes[idx].clone(), emb);
191 }
192 consumed += slice.len();
193 self.dirty = true;
194 self.flush();
195 }
196 }
197
198 let mut out = Vec::with_capacity(hashes.len());
199 for h in &hashes {
200 match self.cache.get(h).cloned() {
201 Some(e) => out.push(e),
202 None => {
203 return Err(EmbeddingError::Provider(
204 "embedding cache missing an entry after batch fill; provider or cache \
205 layer returned inconsistent results"
206 .to_string(),
207 ));
208 }
209 }
210 }
211 Ok(out)
212 }
213
214 pub fn embed_batch_cached(
217 &mut self,
218 texts: &[&str],
219 chunk_size: usize,
220 ) -> Result<Vec<Embedding>, EmbeddingError> {
221 self.embed_batch_cached_paced(texts, chunk_size, std::time::Duration::ZERO)
222 }
223}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider that deliberately returns one fewer embedding than requested,
    /// simulating a misbehaving backend.
    struct ShortBatchProvider;

    impl EmbeddingProvider for ShortBatchProvider {
        fn dimensions(&self) -> usize {
            4
        }
        fn embed(&self, _text: &str) -> Result<Embedding, EmbeddingError> {
            Ok(vec![0.0; 4])
        }
        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
            // Drop the last embedding to produce a short batch.
            Ok(texts
                .iter()
                .take(texts.len().saturating_sub(1))
                .map(|_| vec![0.0; 4])
                .collect())
        }
    }

    #[test]
    fn short_batch_becomes_provider_error_not_panic() {
        // A short batch from the inner provider must surface as a Provider
        // error rather than an out-of-bounds panic in the cache-fill loop.
        let tmp = std::env::temp_dir().join("noether-cache-short-batch-test.json");
        let _ = std::fs::remove_file(&tmp);
        let mut cp = CachedEmbeddingProvider::new(Box::new(ShortBatchProvider), tmp);
        let texts = ["a", "b", "c"];
        let r = cp.embed_batch_cached(&texts, 8);
        assert!(
            matches!(r, Err(EmbeddingError::Provider(ref m)) if m.contains("embed_batch returned")),
            "expected Provider error, got: {r:?}"
        );
    }
}