//! Persistent on-disk cache for embedding vectors, backed by redb.
//!
//! Each entry maps a BLAKE3 hash of `(model, text)` to an `f32` vector along
//! with the Unix timestamp (seconds) at which it was inserted. The timestamp
//! enables optional TTL expiry and oldest-first, size-based eviction.
#![deny(unsafe_code)]
#![warn(missing_docs)]
#![warn(rust_2018_idioms)]
14
15use std::path::Path;
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22const TABLE: TableDefinition<'_, &[u8; 32], Vec<u8>> = TableDefinition::new("embeddings");
23
24pub type Result<T> = std::result::Result<T, CacheError>;
26
27#[derive(Error, Debug)]
29pub enum CacheError {
30 #[error("redb error: {0}")]
32 Redb(String),
33 #[error("io error: {0}")]
35 Io(#[from] std::io::Error),
36 #[error("malformed entry: {0}")]
38 Malformed(String),
39 #[error("invalid config: {0}")]
41 InvalidConfig(String),
42}
43
44macro_rules! redb_from {
47 ($($t:ty),+ $(,)?) => {$(
48 impl From<$t> for CacheError {
49 fn from(e: $t) -> Self { Self::Redb(e.to_string()) }
50 }
51 )+};
52}
53redb_from!(
54 redb::Error,
55 redb::DatabaseError,
56 redb::TransactionError,
57 redb::TableError,
58 redb::StorageError,
59 redb::CommitError,
60);
61
62pub struct Cache {
64 db: Database,
65 ttl_secs: Option<u64>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70pub struct CacheStats {
71 pub entries: u64,
73 pub value_bytes: u64,
75 pub disk_bytes: u64,
77}
78
79impl Cache {
80 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
82 Self::open_with_ttl(path, None)
83 }
84
85 pub fn open_with_ttl<P: AsRef<Path>>(path: P, ttl_secs: Option<u64>) -> Result<Self> {
87 if let Some(ttl) = ttl_secs {
88 if ttl == 0 {
89 return Err(CacheError::InvalidConfig(
90 "ttl_secs must be > 0 (or None for no expiry)".into(),
91 ));
92 }
93 }
94 if let Some(parent) = path.as_ref().parent() {
95 if !parent.as_os_str().is_empty() {
96 std::fs::create_dir_all(parent)?;
97 }
98 }
99 let db = Database::create(path.as_ref())?;
100 let txn = db.begin_write()?;
102 {
103 let _t = txn.open_table(TABLE)?;
104 }
105 txn.commit()?;
106 Ok(Self { db, ttl_secs })
107 }
108
109 pub fn key(model: &str, text: &str) -> [u8; 32] {
111 let mut hasher = blake3::Hasher::new();
112 hasher.update(model.as_bytes());
113 hasher.update(&[0u8]);
114 hasher.update(text.as_bytes());
115 *hasher.finalize().as_bytes()
116 }
117
118 pub fn get(&self, model: &str, text: &str) -> Result<Option<Vec<f32>>> {
120 let key = Self::key(model, text);
121 let now = unix_now();
122 let txn = self.db.begin_read()?;
123 let table = txn.open_table(TABLE)?;
124 let Some(stored) = table.get(&key)? else {
125 return Ok(None);
126 };
127 let bytes = stored.value();
128 let (inserted_at, vec) = decode_entry(&bytes)?;
129 if let Some(ttl) = self.ttl_secs {
130 if now.saturating_sub(inserted_at) >= ttl {
134 return Ok(None);
135 }
136 }
137 Ok(Some(vec))
138 }
139
140 pub fn put(&self, model: &str, text: &str, vector: &[f32]) -> Result<()> {
142 let key = Self::key(model, text);
143 let bytes = encode_entry(unix_now(), vector);
144 let txn = self.db.begin_write()?;
145 {
146 let mut table = txn.open_table(TABLE)?;
147 table.insert(&key, bytes)?;
148 }
149 txn.commit()?;
150 Ok(())
151 }
152
153 pub fn remove(&self, model: &str, text: &str) -> Result<bool> {
155 let key = Self::key(model, text);
156 let txn = self.db.begin_write()?;
157 let removed = {
158 let mut table = txn.open_table(TABLE)?;
159 let prev = table.remove(&key)?;
162 prev.is_some()
163 };
164 txn.commit()?;
165 Ok(removed)
166 }
167
168 pub fn clear(&self) -> Result<u64> {
170 let txn = self.db.begin_write()?;
171 let removed = {
172 let mut table = txn.open_table(TABLE)?;
173 let keys: Vec<[u8; 32]> = table
174 .iter()?
175 .filter_map(|r| r.ok().map(|(k, _)| *k.value()))
176 .collect();
177 for k in &keys {
178 let _ = table.remove(k)?;
179 }
180 keys.len() as u64
181 };
182 txn.commit()?;
183 Ok(removed)
184 }
185
186 pub fn purge_expired(&self) -> Result<u64> {
189 let Some(ttl) = self.ttl_secs else {
190 return Ok(0);
191 };
192 let now = unix_now();
193 let txn = self.db.begin_write()?;
194 let removed = {
195 let mut table = txn.open_table(TABLE)?;
196 let mut victims: Vec<[u8; 32]> = Vec::new();
197 for entry in table.iter()? {
198 let (k, v) = entry?;
199 let bytes = v.value();
200 if bytes.len() < 8 {
201 continue;
202 }
203 let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
204 if now.saturating_sub(inserted) >= ttl {
205 victims.push(*k.value());
206 }
207 }
208 for k in &victims {
209 table.remove(k)?;
210 }
211 victims.len() as u64
212 };
213 txn.commit()?;
214 Ok(removed)
215 }
216
217 pub fn purge_to_size(&self, target_bytes: u64) -> Result<u64> {
220 let txn = self.db.begin_write()?;
221 let removed = {
222 let mut table = txn.open_table(TABLE)?;
223 let mut all: Vec<(u64, [u8; 32], u64)> = Vec::new();
225 let mut total: u64 = 0;
226 for entry in table.iter()? {
227 let (k, v) = entry?;
228 let bytes = v.value();
229 if bytes.len() < 8 {
230 continue;
231 }
232 let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
233 let size = bytes.len() as u64;
234 total += size;
235 all.push((inserted, *k.value(), size));
236 }
237 if total <= target_bytes {
238 return Ok(0);
239 }
240 all.sort_by_key(|(t, _, _)| *t);
241 let mut removed = 0u64;
242 for (_, k, size) in &all {
243 if total <= target_bytes {
244 break;
245 }
246 table.remove(k)?;
247 total = total.saturating_sub(*size);
248 removed += 1;
249 }
250 removed
251 };
252 txn.commit()?;
253 Ok(removed)
254 }
255
256 pub fn stats(&self) -> Result<CacheStats> {
258 let txn = self.db.begin_read()?;
259 let table = txn.open_table(TABLE)?;
260 let entries = table.len()?;
261 let mut value_bytes = 0u64;
262 for entry in table.iter()? {
263 let (_, v) = entry?;
264 value_bytes += v.value().len() as u64;
265 }
266 let disk_bytes = self.disk_size();
267 Ok(CacheStats {
268 entries,
269 value_bytes,
270 disk_bytes,
271 })
272 }
273
274 pub fn len(&self) -> Result<u64> {
276 let txn = self.db.begin_read()?;
277 let table = txn.open_table(TABLE)?;
278 Ok(table.len()?)
279 }
280
281 pub fn is_empty(&self) -> Result<bool> {
283 Ok(self.len()? == 0)
284 }
285
286 fn disk_size(&self) -> u64 {
287 0
291 }
292}
293
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
300
/// Serialises an entry as `[inserted_at: u64 LE][dim: u32 LE][dim x f32 LE]`.
///
/// The dimension is written with an `as u32` cast, so a vector longer than
/// `u32::MAX` would get a truncated header. The length check in
/// `decode_entry` would then reject the stored bytes.
fn encode_entry(inserted_at: u64, vec: &[f32]) -> Vec<u8> {
    const HEADER_LEN: usize = 8 + 4;
    let mut buf = Vec::with_capacity(HEADER_LEN + 4 * vec.len());
    buf.extend_from_slice(&inserted_at.to_le_bytes());
    buf.extend_from_slice(&(vec.len() as u32).to_le_bytes());
    buf.extend(vec.iter().flat_map(|component| component.to_le_bytes()));
    buf
}
311
312fn decode_entry(bytes: &[u8]) -> Result<(u64, Vec<f32>)> {
313 if bytes.len() < 12 {
314 return Err(CacheError::Malformed("entry shorter than header".into()));
315 }
316 let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
317 let dim = u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
318 let expected = 12 + dim * 4;
319 if bytes.len() != expected {
320 return Err(CacheError::Malformed(format!(
321 "entry length {}, expected {}",
322 bytes.len(),
323 expected
324 )));
325 }
326 let mut vec = Vec::with_capacity(dim);
327 for i in 0..dim {
328 let off = 12 + i * 4;
329 vec.push(f32::from_le_bytes(bytes[off..off + 4].try_into().unwrap()));
330 }
331 Ok((inserted, vec))
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh temp dir (the caller keeps it alive) and a DB path inside it.
    fn tempdb() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cache.redb");
        (dir, path)
    }

    #[test]
    fn key_changes_with_model_or_text() {
        let a = Cache::key("m1", "hello");
        let b = Cache::key("m2", "hello");
        let c = Cache::key("m1", "world");
        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_eq!(a, Cache::key("m1", "hello"));
    }

    #[test]
    fn key_separator_blocks_concatenation_collision() {
        // Without the 0x00 separator both pairs would hash the bytes "abc".
        let a = Cache::key("a", "bc");
        let b = Cache::key("ab", "c");
        assert_ne!(a, b);
    }

    #[test]
    fn put_then_get_round_trips() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        let v = vec![0.1, 0.2, 0.3];
        cache.put("m", "hello", &v).unwrap();
        assert_eq!(cache.get("m", "hello").unwrap(), Some(v));
    }

    #[test]
    fn get_missing_returns_none() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        assert_eq!(cache.get("m", "nope").unwrap(), None);
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0, 2.0]).unwrap();
        cache.put("m", "k", &[3.0, 4.0, 5.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![3.0, 4.0, 5.0]));
    }

    #[test]
    fn remove_returns_true_when_present() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert!(cache.remove("m", "k").unwrap());
        assert!(!cache.remove("m", "k").unwrap());
    }

    #[test]
    fn clear_removes_all() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        assert_eq!(cache.len().unwrap(), 10);
        cache.clear().unwrap();
        assert_eq!(cache.len().unwrap(), 0);
    }

    #[test]
    fn clear_returns_removed_count() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..3 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        assert_eq!(cache.clear().unwrap(), 3);
        assert!(cache.is_empty().unwrap());
        assert_eq!(cache.clear().unwrap(), 0);
    }

    #[test]
    fn stats_reports_entries_and_value_bytes() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        // Each value is a 12-byte header plus 4 bytes per component.
        cache.put("m", "a", &[1.0, 2.0]).unwrap(); // 20 bytes
        cache.put("m", "b", &[3.0]).unwrap(); // 16 bytes
        let stats = cache.stats().unwrap();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.value_bytes, 36);
    }

    #[test]
    fn purge_to_size_evicts_oldest() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        let removed = cache.purge_to_size(32).unwrap();
        assert!(removed > 0, "expected at least 1 eviction");
        let stats = cache.stats().unwrap();
        assert!(stats.value_bytes <= 32, "value_bytes={}", stats.value_bytes);
    }

    #[test]
    fn purge_to_size_is_noop_under_target() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "a", &[1.0]).unwrap();
        cache.put("m", "b", &[2.0]).unwrap();
        assert_eq!(cache.purge_to_size(u64::MAX).unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 2);
    }

    #[test]
    fn purge_expired_without_ttl_is_noop() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn fresh_entries_survive_ttl() {
        let (_dir, path) = tempdb();
        let cache = Cache::open_with_ttl(&path, Some(3600)).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![1.0]));
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn ttl_zero_rejected() {
        let (_dir, path) = tempdb();
        let err = Cache::open_with_ttl(&path, Some(0));
        assert!(matches!(err, Err(CacheError::InvalidConfig(_))));
    }

    #[test]
    fn malformed_entry_rejected() {
        // Shorter than the 12-byte header.
        let bad = vec![0u8; 5];
        let r = decode_entry(&bad);
        assert!(r.is_err());
    }

    #[test]
    fn length_mismatch_rejected() {
        let mut bytes = encode_entry(1, &[1.0, 2.0]);
        bytes.pop();
        assert!(decode_entry(&bytes).is_err());
    }

    #[test]
    fn encode_decode_round_trip() {
        let v = vec![1.0_f32, -2.5, 3.125, f32::MIN, f32::MAX];
        let bytes = encode_entry(123, &v);
        let (t, decoded) = decode_entry(&bytes).unwrap();
        assert_eq!(t, 123);
        assert_eq!(decoded, v);
    }

    #[test]
    fn encode_decode_empty_vector() {
        let bytes = encode_entry(7, &[]);
        let (t, decoded) = decode_entry(&bytes).unwrap();
        assert_eq!(t, 7);
        assert!(decoded.is_empty());
    }
}