//! noether_engine/index/cache.rs — file-backed embedding cache.

1#![warn(clippy::unwrap_used)]
2#![cfg_attr(test, allow(clippy::unwrap_used))]
3
use super::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
9
/// Wraps an EmbeddingProvider with a file-backed cache.
/// Embeddings are keyed by SHA-256 of the input text.
pub struct CachedEmbeddingProvider {
    /// Underlying provider consulted on cache misses.
    inner: Box<dyn EmbeddingProvider>,
    /// In-memory cache: lowercase hex SHA-256 of the text -> its embedding.
    cache: HashMap<String, Embedding>,
    /// On-disk location of the JSON cache file.
    path: PathBuf,
    /// True once `cache` holds entries not yet written to `path`.
    /// NOTE(review): `flush` takes `&self` and therefore never clears this
    /// flag, so a dirty provider re-serializes the full cache on each flush.
    dirty: bool,
}
18
/// Serialized (JSON) shape of the whole cache file: a flat list of entries.
#[derive(Serialize, Deserialize)]
struct CacheFile {
    entries: Vec<CacheEntry>,
}
23
/// One persisted embedding, keyed by the hex SHA-256 of its source text.
#[derive(Serialize, Deserialize)]
struct CacheEntry {
    text_hash: String,
    embedding: Embedding,
}
29
30impl CachedEmbeddingProvider {
31    pub fn new(inner: Box<dyn EmbeddingProvider>, path: impl Into<PathBuf>) -> Self {
32        let path = path.into();
33        let cache = if path.exists() {
34            std::fs::read_to_string(&path)
35                .ok()
36                .and_then(|content| {
37                    if content.trim().is_empty() {
38                        return None;
39                    }
40                    serde_json::from_str::<CacheFile>(&content).ok()
41                })
42                .map(|f| {
43                    f.entries
44                        .into_iter()
45                        .map(|e| (e.text_hash, e.embedding))
46                        .collect()
47                })
48                .unwrap_or_default()
49        } else {
50            HashMap::new()
51        };
52        Self {
53            inner,
54            cache,
55            path,
56            dirty: false,
57        }
58    }
59
60    fn text_hash(text: &str) -> String {
61        hex::encode(Sha256::digest(text.as_bytes()))
62    }
63
64    /// Flush cache to disk if dirty.
65    pub fn flush(&self) {
66        if !self.dirty {
67            return;
68        }
69        if let Some(parent) = self.path.parent() {
70            let _ = std::fs::create_dir_all(parent);
71        }
72        let file = CacheFile {
73            entries: self
74                .cache
75                .iter()
76                .map(|(h, e)| CacheEntry {
77                    text_hash: h.clone(),
78                    embedding: e.clone(),
79                })
80                .collect(),
81        };
82        if let Ok(json) = serde_json::to_string(&file) {
83            let _ = std::fs::write(&self.path, json);
84        }
85    }
86}
87
impl Drop for CachedEmbeddingProvider {
    fn drop(&mut self) {
        // Best-effort persistence: push any cached-but-unwritten embeddings
        // to disk when the provider goes out of scope. `flush` swallows all
        // I/O errors, so this can never panic during unwinding.
        self.flush();
    }
}
93
94impl EmbeddingProvider for CachedEmbeddingProvider {
95    fn dimensions(&self) -> usize {
96        self.inner.dimensions()
97    }
98
99    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
100        let hash = Self::text_hash(text);
101        if let Some(cached) = self.cache.get(&hash) {
102            return Ok(cached.clone());
103        }
104        // Cache miss — compute and store
105        // We need interior mutability here since the trait requires &self
106        // Use unsafe or switch to RefCell. For simplicity, call inner and
107        // let the caller handle caching via embed_and_cache.
108        self.inner.embed(text)
109    }
110
111    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
112        texts.iter().map(|t| self.embed(t)).collect()
113    }
114}
115
116impl CachedEmbeddingProvider {
117    /// Embed with caching — stores result in cache.
118    pub fn embed_cached(&mut self, text: &str) -> Result<Embedding, EmbeddingError> {
119        let hash = Self::text_hash(text);
120        if let Some(cached) = self.cache.get(&hash) {
121            return Ok(cached.clone());
122        }
123        let embedding = self.inner.embed(text)?;
124        self.cache.insert(hash, embedding.clone());
125        self.dirty = true;
126        Ok(embedding)
127    }
128
129    /// Embed many texts at once, calling `inner.embed_batch` on cache
130    /// misses. Cache hits are served from memory. Misses are sent in chunks
131    /// of `chunk_size` to keep individual requests under typical provider
132    /// payload limits.
133    ///
134    /// Two robustness properties matter when a remote provider is rate-limited:
135    ///
136    /// - **Progressive caching.** Each successful batch is committed to the
137    ///   in-memory cache *and* flushed to disk immediately. If the next
138    ///   batch trips a 429, the function still returns Err — but the partial
139    ///   work done so far is durable, so the next process restart picks up
140    ///   exactly where the crash left off.
141    /// - **Inter-batch pacing.** Between batch calls we sleep
142    ///   `inter_batch_delay`. With Mistral's free-tier 1 req/s ceiling, a
143    ///   ~1100 ms sleep keeps us comfortably under the limit; paid tiers
144    ///   can pass `Duration::ZERO` to skip pacing.
145    ///
146    /// Order of results matches order of `texts`.
147    pub fn embed_batch_cached_paced(
148        &mut self,
149        texts: &[&str],
150        chunk_size: usize,
151        inter_batch_delay: std::time::Duration,
152    ) -> Result<Vec<Embedding>, EmbeddingError> {
153        if texts.is_empty() {
154            return Ok(Vec::new());
155        }
156
157        let hashes: Vec<String> = texts.iter().map(|t| Self::text_hash(t)).collect();
158        let mut miss_indices: Vec<usize> = Vec::new();
159        let mut miss_texts: Vec<&str> = Vec::new();
160        for (i, h) in hashes.iter().enumerate() {
161            if !self.cache.contains_key(h) {
162                miss_indices.push(i);
163                miss_texts.push(texts[i]);
164            }
165        }
166
167        if !miss_texts.is_empty() {
168            let chunk = chunk_size.max(1);
169            let mut consumed = 0usize;
170            for (b, slice) in miss_texts.chunks(chunk).enumerate() {
171                if b > 0 && !inter_batch_delay.is_zero() {
172                    std::thread::sleep(inter_batch_delay);
173                }
174                let part = self.inner.embed_batch(slice)?;
175                // A well-behaved provider returns exactly one embedding per
176                // input text. A misbehaving one (truncated response, rate-
177                // limit short-read, etc.) would desync `consumed` against
178                // `miss_indices` and leave some cache slots unfilled — the
179                // final `cache.get(h).expect(..)` used to panic here. Bail
180                // as a typed provider error instead.
181                if part.len() != slice.len() {
182                    return Err(EmbeddingError::Provider(format!(
183                        "embed_batch returned {} embeddings for {} inputs",
184                        part.len(),
185                        slice.len()
186                    )));
187                }
188                for (j, emb) in part.into_iter().enumerate() {
189                    let idx = miss_indices[consumed + j];
190                    self.cache.insert(hashes[idx].clone(), emb);
191                }
192                consumed += slice.len();
193                self.dirty = true;
194                self.flush();
195            }
196        }
197
198        let mut out = Vec::with_capacity(hashes.len());
199        for h in &hashes {
200            match self.cache.get(h).cloned() {
201                Some(e) => out.push(e),
202                None => {
203                    return Err(EmbeddingError::Provider(
204                        "embedding cache missing an entry after batch fill; provider or cache \
205                         layer returned inconsistent results"
206                            .to_string(),
207                    ));
208                }
209            }
210        }
211        Ok(out)
212    }
213
214    /// Backward-compatible wrapper: no inter-batch sleep, single final flush.
215    /// Prefer `embed_batch_cached_paced` for rate-limited providers.
216    pub fn embed_batch_cached(
217        &mut self,
218        texts: &[&str],
219        chunk_size: usize,
220    ) -> Result<Vec<Embedding>, EmbeddingError> {
221        self.embed_batch_cached_paced(texts, chunk_size, std::time::Duration::ZERO)
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider that returns fewer embeddings than requested — simulates a
    /// misbehaving remote that truncates responses.
    struct ShortBatchProvider;

    impl EmbeddingProvider for ShortBatchProvider {
        fn dimensions(&self) -> usize {
            4
        }
        fn embed(&self, _text: &str) -> Result<Embedding, EmbeddingError> {
            Ok(vec![0.0; 4])
        }
        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
            // Deliberately drop the last entry.
            Ok(texts
                .iter()
                .take(texts.len().saturating_sub(1))
                .map(|_| vec![0.0; 4])
                .collect())
        }
    }

    #[test]
    fn short_batch_becomes_provider_error_not_panic() {
        // Per-process unique path: a fixed name would collide when two test
        // runs execute concurrently, or pick up a stale file left behind by
        // a previous crashed run.
        let tmp = std::env::temp_dir().join(format!(
            "noether-cache-short-batch-test-{}.json",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&tmp);
        {
            let mut cp = CachedEmbeddingProvider::new(Box::new(ShortBatchProvider), tmp.clone());
            let texts = ["a", "b", "c"];
            let r = cp.embed_batch_cached(&texts, 8);
            assert!(
                matches!(r, Err(EmbeddingError::Provider(ref m)) if m.contains("embed_batch returned")),
                "expected Provider error, got: {r:?}"
            );
        } // Drop runs here (flush); clean up anything it may have written.
        let _ = std::fs::remove_file(&tmp);
    }
}
262}