// noether_engine/index/cache.rs
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;

use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

use super::embedding::{Embedding, EmbeddingError, EmbeddingProvider};

7pub struct CachedEmbeddingProvider {
10 inner: Box<dyn EmbeddingProvider>,
11 cache: HashMap<String, Embedding>,
12 path: PathBuf,
13 dirty: bool,
14}
15
16#[derive(Serialize, Deserialize)]
17struct CacheFile {
18 entries: Vec<CacheEntry>,
19}
20
21#[derive(Serialize, Deserialize)]
22struct CacheEntry {
23 text_hash: String,
24 embedding: Embedding,
25}
26
27impl CachedEmbeddingProvider {
28 pub fn new(inner: Box<dyn EmbeddingProvider>, path: impl Into<PathBuf>) -> Self {
29 let path = path.into();
30 let cache = if path.exists() {
31 std::fs::read_to_string(&path)
32 .ok()
33 .and_then(|content| {
34 if content.trim().is_empty() {
35 return None;
36 }
37 serde_json::from_str::<CacheFile>(&content).ok()
38 })
39 .map(|f| {
40 f.entries
41 .into_iter()
42 .map(|e| (e.text_hash, e.embedding))
43 .collect()
44 })
45 .unwrap_or_default()
46 } else {
47 HashMap::new()
48 };
49 Self {
50 inner,
51 cache,
52 path,
53 dirty: false,
54 }
55 }
56
57 fn text_hash(text: &str) -> String {
58 hex::encode(Sha256::digest(text.as_bytes()))
59 }
60
61 pub fn flush(&self) {
63 if !self.dirty {
64 return;
65 }
66 if let Some(parent) = self.path.parent() {
67 let _ = std::fs::create_dir_all(parent);
68 }
69 let file = CacheFile {
70 entries: self
71 .cache
72 .iter()
73 .map(|(h, e)| CacheEntry {
74 text_hash: h.clone(),
75 embedding: e.clone(),
76 })
77 .collect(),
78 };
79 if let Ok(json) = serde_json::to_string(&file) {
80 let _ = std::fs::write(&self.path, json);
81 }
82 }
83}
84
85impl Drop for CachedEmbeddingProvider {
86 fn drop(&mut self) {
87 self.flush();
88 }
89}
90
91impl EmbeddingProvider for CachedEmbeddingProvider {
92 fn dimensions(&self) -> usize {
93 self.inner.dimensions()
94 }
95
96 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
97 let hash = Self::text_hash(text);
98 if let Some(cached) = self.cache.get(&hash) {
99 return Ok(cached.clone());
100 }
101 self.inner.embed(text)
106 }
107
108 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
109 texts.iter().map(|t| self.embed(t)).collect()
110 }
111}
112
113impl CachedEmbeddingProvider {
114 pub fn embed_cached(&mut self, text: &str) -> Result<Embedding, EmbeddingError> {
116 let hash = Self::text_hash(text);
117 if let Some(cached) = self.cache.get(&hash) {
118 return Ok(cached.clone());
119 }
120 let embedding = self.inner.embed(text)?;
121 self.cache.insert(hash, embedding.clone());
122 self.dirty = true;
123 Ok(embedding)
124 }
125
126 pub fn embed_batch_cached_paced(
145 &mut self,
146 texts: &[&str],
147 chunk_size: usize,
148 inter_batch_delay: std::time::Duration,
149 ) -> Result<Vec<Embedding>, EmbeddingError> {
150 if texts.is_empty() {
151 return Ok(Vec::new());
152 }
153
154 let hashes: Vec<String> = texts.iter().map(|t| Self::text_hash(t)).collect();
155 let mut miss_indices: Vec<usize> = Vec::new();
156 let mut miss_texts: Vec<&str> = Vec::new();
157 for (i, h) in hashes.iter().enumerate() {
158 if !self.cache.contains_key(h) {
159 miss_indices.push(i);
160 miss_texts.push(texts[i]);
161 }
162 }
163
164 if !miss_texts.is_empty() {
165 let chunk = chunk_size.max(1);
166 let mut consumed = 0usize;
167 for (b, slice) in miss_texts.chunks(chunk).enumerate() {
168 if b > 0 && !inter_batch_delay.is_zero() {
169 std::thread::sleep(inter_batch_delay);
170 }
171 let part = self.inner.embed_batch(slice)?;
172 for (j, emb) in part.into_iter().enumerate() {
173 let idx = miss_indices[consumed + j];
174 self.cache.insert(hashes[idx].clone(), emb);
175 }
176 consumed += slice.len();
177 self.dirty = true;
178 self.flush();
179 }
180 }
181
182 Ok(hashes
183 .iter()
184 .map(|h| self.cache.get(h).cloned().expect("just inserted"))
185 .collect())
186 }
187
188 pub fn embed_batch_cached(
191 &mut self,
192 texts: &[&str],
193 chunk_size: usize,
194 ) -> Result<Vec<Embedding>, EmbeddingError> {
195 self.embed_batch_cached_paced(texts, chunk_size, std::time::Duration::ZERO)
196 }
197}