//! noether_engine/index/cache.rs — file-backed embedding cache.
1use super::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::path::PathBuf;
6
/// Wraps an EmbeddingProvider with a file-backed cache.
/// Embeddings are keyed by SHA-256 of the input text.
pub struct CachedEmbeddingProvider {
    // Wrapped provider that actually computes embeddings on cache misses.
    inner: Box<dyn EmbeddingProvider>,
    // In-memory cache: hex-encoded SHA-256 of the text -> its embedding.
    cache: HashMap<String, Embedding>,
    // Location of the JSON cache file on disk.
    path: PathBuf,
    // True when `cache` holds entries not yet persisted to `path`.
    dirty: bool,
}
15
/// On-disk JSON representation of the whole cache file.
#[derive(Serialize, Deserialize)]
struct CacheFile {
    entries: Vec<CacheEntry>,
}
20
/// One serialized (text hash, embedding) cache pair.
/// NOTE: field names are part of the on-disk JSON format — do not rename.
#[derive(Serialize, Deserialize)]
struct CacheEntry {
    // Hex-encoded SHA-256 of the embedded text (see `text_hash`).
    text_hash: String,
    embedding: Embedding,
}
26
27impl CachedEmbeddingProvider {
28    pub fn new(inner: Box<dyn EmbeddingProvider>, path: impl Into<PathBuf>) -> Self {
29        let path = path.into();
30        let cache = if path.exists() {
31            std::fs::read_to_string(&path)
32                .ok()
33                .and_then(|content| {
34                    if content.trim().is_empty() {
35                        return None;
36                    }
37                    serde_json::from_str::<CacheFile>(&content).ok()
38                })
39                .map(|f| {
40                    f.entries
41                        .into_iter()
42                        .map(|e| (e.text_hash, e.embedding))
43                        .collect()
44                })
45                .unwrap_or_default()
46        } else {
47            HashMap::new()
48        };
49        Self {
50            inner,
51            cache,
52            path,
53            dirty: false,
54        }
55    }
56
57    fn text_hash(text: &str) -> String {
58        hex::encode(Sha256::digest(text.as_bytes()))
59    }
60
61    /// Flush cache to disk if dirty.
62    pub fn flush(&self) {
63        if !self.dirty {
64            return;
65        }
66        if let Some(parent) = self.path.parent() {
67            let _ = std::fs::create_dir_all(parent);
68        }
69        let file = CacheFile {
70            entries: self
71                .cache
72                .iter()
73                .map(|(h, e)| CacheEntry {
74                    text_hash: h.clone(),
75                    embedding: e.clone(),
76                })
77                .collect(),
78        };
79        if let Ok(json) = serde_json::to_string(&file) {
80            let _ = std::fs::write(&self.path, json);
81        }
82    }
83}
84
// Best-effort persistence: flush any unsaved entries when the provider is
// dropped. `flush` swallows I/O errors, which is the only option here —
// `Drop` has no way to propagate them.
impl Drop for CachedEmbeddingProvider {
    fn drop(&mut self) {
        self.flush();
    }
}
90
91impl EmbeddingProvider for CachedEmbeddingProvider {
92    fn dimensions(&self) -> usize {
93        self.inner.dimensions()
94    }
95
96    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
97        let hash = Self::text_hash(text);
98        if let Some(cached) = self.cache.get(&hash) {
99            return Ok(cached.clone());
100        }
101        // Cache miss — compute and store
102        // We need interior mutability here since the trait requires &self
103        // Use unsafe or switch to RefCell. For simplicity, call inner and
104        // let the caller handle caching via embed_and_cache.
105        self.inner.embed(text)
106    }
107
108    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
109        texts.iter().map(|t| self.embed(t)).collect()
110    }
111}
112
113impl CachedEmbeddingProvider {
114    /// Embed with caching — stores result in cache.
115    pub fn embed_cached(&mut self, text: &str) -> Result<Embedding, EmbeddingError> {
116        let hash = Self::text_hash(text);
117        if let Some(cached) = self.cache.get(&hash) {
118            return Ok(cached.clone());
119        }
120        let embedding = self.inner.embed(text)?;
121        self.cache.insert(hash, embedding.clone());
122        self.dirty = true;
123        Ok(embedding)
124    }
125
126    /// Embed many texts at once, calling `inner.embed_batch` on cache
127    /// misses. Cache hits are served from memory. Misses are sent in chunks
128    /// of `chunk_size` to keep individual requests under typical provider
129    /// payload limits.
130    ///
131    /// Two robustness properties matter when a remote provider is rate-limited:
132    ///
133    /// - **Progressive caching.** Each successful batch is committed to the
134    ///   in-memory cache *and* flushed to disk immediately. If the next
135    ///   batch trips a 429, the function still returns Err — but the partial
136    ///   work done so far is durable, so the next process restart picks up
137    ///   exactly where the crash left off.
138    /// - **Inter-batch pacing.** Between batch calls we sleep
139    ///   `inter_batch_delay`. With Mistral's free-tier 1 req/s ceiling, a
140    ///   ~1100 ms sleep keeps us comfortably under the limit; paid tiers
141    ///   can pass `Duration::ZERO` to skip pacing.
142    ///
143    /// Order of results matches order of `texts`.
144    pub fn embed_batch_cached_paced(
145        &mut self,
146        texts: &[&str],
147        chunk_size: usize,
148        inter_batch_delay: std::time::Duration,
149    ) -> Result<Vec<Embedding>, EmbeddingError> {
150        if texts.is_empty() {
151            return Ok(Vec::new());
152        }
153
154        let hashes: Vec<String> = texts.iter().map(|t| Self::text_hash(t)).collect();
155        let mut miss_indices: Vec<usize> = Vec::new();
156        let mut miss_texts: Vec<&str> = Vec::new();
157        for (i, h) in hashes.iter().enumerate() {
158            if !self.cache.contains_key(h) {
159                miss_indices.push(i);
160                miss_texts.push(texts[i]);
161            }
162        }
163
164        if !miss_texts.is_empty() {
165            let chunk = chunk_size.max(1);
166            let mut consumed = 0usize;
167            for (b, slice) in miss_texts.chunks(chunk).enumerate() {
168                if b > 0 && !inter_batch_delay.is_zero() {
169                    std::thread::sleep(inter_batch_delay);
170                }
171                let part = self.inner.embed_batch(slice)?;
172                for (j, emb) in part.into_iter().enumerate() {
173                    let idx = miss_indices[consumed + j];
174                    self.cache.insert(hashes[idx].clone(), emb);
175                }
176                consumed += slice.len();
177                self.dirty = true;
178                self.flush();
179            }
180        }
181
182        Ok(hashes
183            .iter()
184            .map(|h| self.cache.get(h).cloned().expect("just inserted"))
185            .collect())
186    }
187
188    /// Backward-compatible wrapper: no inter-batch sleep, single final flush.
189    /// Prefer `embed_batch_cached_paced` for rate-limited providers.
190    pub fn embed_batch_cached(
191        &mut self,
192        texts: &[&str],
193        chunk_size: usize,
194    ) -> Result<Vec<Embedding>, EmbeddingError> {
195        self.embed_batch_cached_paced(texts, chunk_size, std::time::Duration::ZERO)
196    }
197}