1#![deny(unsafe_code)]
12#![warn(missing_docs)]
13#![warn(rust_2018_idioms)]
14
15use std::path::Path;
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
/// The single redb table holding all cache entries: a 32-byte key (see `Cache::key`)
/// mapped to an encoded entry (see `encode_entry` / `decode_entry`).
const TABLE: TableDefinition<'_, &[u8; 32], Vec<u8>> = TableDefinition::new("embeddings");
23
24pub type Result<T> = std::result::Result<T, CacheError>;
26
27#[derive(Error, Debug)]
29pub enum CacheError {
30 #[error("redb error: {0}")]
32 Redb(String),
33 #[error("io error: {0}")]
35 Io(#[from] std::io::Error),
36 #[error("malformed entry: {0}")]
38 Malformed(String),
39 #[error("invalid config: {0}")]
41 InvalidConfig(String),
42}
43
44macro_rules! redb_from {
47 ($($t:ty),+ $(,)?) => {$(
48 impl From<$t> for CacheError {
49 fn from(e: $t) -> Self { Self::Redb(e.to_string()) }
50 }
51 )+};
52}
53redb_from!(
54 redb::Error,
55 redb::DatabaseError,
56 redb::TransactionError,
57 redb::TableError,
58 redb::StorageError,
59 redb::CommitError,
60);
61
62pub struct Cache {
64 db: Database,
65 ttl_secs: Option<u64>,
66 path: std::path::PathBuf,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
74pub struct CacheStats {
75 pub entries: u64,
77 pub value_bytes: u64,
79 pub disk_bytes: u64,
81}
82
83impl Cache {
84 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
86 Self::open_with_ttl(path, None)
87 }
88
89 pub fn open_with_ttl<P: AsRef<Path>>(path: P, ttl_secs: Option<u64>) -> Result<Self> {
91 if let Some(ttl) = ttl_secs {
92 if ttl == 0 {
93 return Err(CacheError::InvalidConfig(
94 "ttl_secs must be > 0 (or None for no expiry)".into(),
95 ));
96 }
97 }
98 if let Some(parent) = path.as_ref().parent() {
99 if !parent.as_os_str().is_empty() {
100 std::fs::create_dir_all(parent)?;
101 }
102 }
103 let db = Database::create(path.as_ref())?;
104 let txn = db.begin_write()?;
106 {
107 let _t = txn.open_table(TABLE)?;
108 }
109 txn.commit()?;
110 Ok(Self {
111 db,
112 ttl_secs,
113 path: path.as_ref().to_path_buf(),
114 })
115 }
116
117 pub fn key(model: &str, text: &str) -> [u8; 32] {
119 let mut hasher = blake3::Hasher::new();
120 hasher.update(model.as_bytes());
121 hasher.update(&[0u8]);
122 hasher.update(text.as_bytes());
123 *hasher.finalize().as_bytes()
124 }
125
126 pub fn get(&self, model: &str, text: &str) -> Result<Option<Vec<f32>>> {
128 let key = Self::key(model, text);
129 let now = unix_now();
130 let txn = self.db.begin_read()?;
131 let table = txn.open_table(TABLE)?;
132 let Some(stored) = table.get(&key)? else {
133 return Ok(None);
134 };
135 let bytes = stored.value();
136 let (inserted_at, vec) = decode_entry(&bytes)?;
137 if let Some(ttl) = self.ttl_secs {
138 if now.saturating_sub(inserted_at) >= ttl {
142 return Ok(None);
143 }
144 }
145 Ok(Some(vec))
146 }
147
148 pub fn put(&self, model: &str, text: &str, vector: &[f32]) -> Result<()> {
150 let key = Self::key(model, text);
151 let bytes = encode_entry(unix_now(), vector);
152 let txn = self.db.begin_write()?;
153 {
154 let mut table = txn.open_table(TABLE)?;
155 table.insert(&key, bytes)?;
156 }
157 txn.commit()?;
158 Ok(())
159 }
160
161 pub fn remove(&self, model: &str, text: &str) -> Result<bool> {
163 let key = Self::key(model, text);
164 let txn = self.db.begin_write()?;
165 let removed = {
166 let mut table = txn.open_table(TABLE)?;
167 let prev = table.remove(&key)?;
170 prev.is_some()
171 };
172 txn.commit()?;
173 Ok(removed)
174 }
175
176 pub fn clear(&self) -> Result<u64> {
178 let txn = self.db.begin_write()?;
179 let removed = {
180 let mut table = txn.open_table(TABLE)?;
181 let keys: Vec<[u8; 32]> = table
182 .iter()?
183 .filter_map(|r| r.ok().map(|(k, _)| *k.value()))
184 .collect();
185 for k in &keys {
186 let _ = table.remove(k)?;
187 }
188 keys.len() as u64
189 };
190 txn.commit()?;
191 Ok(removed)
192 }
193
194 pub fn purge_expired(&self) -> Result<u64> {
197 let Some(ttl) = self.ttl_secs else {
198 return Ok(0);
199 };
200 let now = unix_now();
201 let txn = self.db.begin_write()?;
202 let removed = {
203 let mut table = txn.open_table(TABLE)?;
204 let mut victims: Vec<[u8; 32]> = Vec::new();
205 for entry in table.iter()? {
206 let (k, v) = entry?;
207 let bytes = v.value();
208 if bytes.len() < 8 {
209 continue;
210 }
211 let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
212 if now.saturating_sub(inserted) >= ttl {
213 victims.push(*k.value());
214 }
215 }
216 for k in &victims {
217 table.remove(k)?;
218 }
219 victims.len() as u64
220 };
221 txn.commit()?;
222 Ok(removed)
223 }
224
225 pub fn purge_to_size(&self, target_bytes: u64) -> Result<u64> {
228 let txn = self.db.begin_write()?;
229 let removed = {
230 let mut table = txn.open_table(TABLE)?;
231 let mut all: Vec<(u64, [u8; 32], u64)> = Vec::new();
233 let mut total: u64 = 0;
234 for entry in table.iter()? {
235 let (k, v) = entry?;
236 let bytes = v.value();
237 if bytes.len() < 8 {
238 continue;
239 }
240 let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
241 let size = bytes.len() as u64;
242 total += size;
243 all.push((inserted, *k.value(), size));
244 }
245 if total <= target_bytes {
246 return Ok(0);
247 }
248 all.sort_by_key(|(t, _, _)| *t);
249 let mut removed = 0u64;
250 for (_, k, size) in &all {
251 if total <= target_bytes {
252 break;
253 }
254 table.remove(k)?;
255 total = total.saturating_sub(*size);
256 removed += 1;
257 }
258 removed
259 };
260 txn.commit()?;
261 Ok(removed)
262 }
263
264 pub fn stats(&self) -> Result<CacheStats> {
266 let txn = self.db.begin_read()?;
267 let table = txn.open_table(TABLE)?;
268 let entries = table.len()?;
269 let mut value_bytes = 0u64;
270 for entry in table.iter()? {
271 let (_, v) = entry?;
272 value_bytes += v.value().len() as u64;
273 }
274 let disk_bytes = self.disk_size();
275 Ok(CacheStats {
276 entries,
277 value_bytes,
278 disk_bytes,
279 })
280 }
281
282 pub fn len(&self) -> Result<u64> {
284 let txn = self.db.begin_read()?;
285 let table = txn.open_table(TABLE)?;
286 Ok(table.len()?)
287 }
288
289 pub fn is_empty(&self) -> Result<bool> {
291 Ok(self.len()? == 0)
292 }
293
294 fn disk_size(&self) -> u64 {
295 std::fs::metadata(&self.path).map(|m| m.len()).unwrap_or(0)
300 }
301
302 pub fn path(&self) -> &std::path::Path {
304 &self.path
305 }
306}
307
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 instead of an error.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
314
/// Serialises an entry as `inserted_at: u64 LE`, `dim: u32 LE`, then each
/// component as `f32 LE`, for a total of `12 + 4 * vec.len()` bytes.
///
/// NOTE: `dim` is written with an `as u32` cast, so a vector longer than
/// `u32::MAX` would get a truncated header.
fn encode_entry(inserted_at: u64, vec: &[f32]) -> Vec<u8> {
    const HEADER_LEN: usize = 8 + 4;
    let mut buf = Vec::with_capacity(HEADER_LEN + 4 * vec.len());
    buf.extend_from_slice(&inserted_at.to_le_bytes());
    buf.extend_from_slice(&(vec.len() as u32).to_le_bytes());
    buf.extend(vec.iter().flat_map(|component| component.to_le_bytes()));
    buf
}
325
326fn decode_entry(bytes: &[u8]) -> Result<(u64, Vec<f32>)> {
327 if bytes.len() < 12 {
328 return Err(CacheError::Malformed("entry shorter than header".into()));
329 }
330 let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
331 let dim = u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
332 let expected = 12 + dim * 4;
333 if bytes.len() != expected {
334 return Err(CacheError::Malformed(format!(
335 "entry length {}, expected {}",
336 bytes.len(),
337 expected
338 )));
339 }
340 let mut vec = Vec::with_capacity(dim);
341 for i in 0..dim {
342 let off = 12 + i * 4;
343 vec.push(f32::from_le_bytes(bytes[off..off + 4].try_into().unwrap()));
344 }
345 Ok((inserted, vec))
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh temp directory (kept alive by the caller) and a db path inside it.
    fn tempdb() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cache.redb");
        (dir, path)
    }

    #[test]
    fn key_changes_with_model_or_text() {
        let a = Cache::key("m1", "hello");
        let b = Cache::key("m2", "hello");
        let c = Cache::key("m1", "world");
        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_eq!(a, Cache::key("m1", "hello"));
    }

    #[test]
    fn key_separator_blocks_concatenation_collision() {
        let a = Cache::key("a", "bc");
        let b = Cache::key("ab", "c");
        assert_ne!(a, b);
    }

    #[test]
    fn put_then_get_round_trips() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        let v = vec![0.1, 0.2, 0.3];
        cache.put("m", "hello", &v).unwrap();
        assert_eq!(cache.get("m", "hello").unwrap(), Some(v));
    }

    #[test]
    fn get_missing_returns_none() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        assert_eq!(cache.get("m", "nope").unwrap(), None);
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0, 2.0]).unwrap();
        cache.put("m", "k", &[3.0, 4.0, 5.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![3.0, 4.0, 5.0]));
    }

    #[test]
    fn remove_returns_true_when_present() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert!(cache.remove("m", "k").unwrap());
        assert!(!cache.remove("m", "k").unwrap());
    }

    #[test]
    fn clear_removes_all() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        assert_eq!(cache.len().unwrap(), 10);
        cache.clear().unwrap();
        assert_eq!(cache.len().unwrap(), 0);
    }

    #[test]
    fn clear_reports_number_removed() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..3 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        assert_eq!(cache.clear().unwrap(), 3);
        assert!(cache.is_empty().unwrap());
        assert_eq!(cache.clear().unwrap(), 0);
    }

    #[test]
    fn purge_to_size_evicts_oldest() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        let removed = cache.purge_to_size(32).unwrap();
        assert!(removed > 0, "expected at least 1 eviction");
        let stats = cache.stats().unwrap();
        assert!(stats.value_bytes <= 32, "value_bytes={}", stats.value_bytes);
    }

    #[test]
    fn purge_to_size_under_target_is_noop() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "a", &[1.0]).unwrap();
        cache.put("m", "b", &[2.0]).unwrap();
        assert_eq!(cache.purge_to_size(u64::MAX).unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 2);
    }

    #[test]
    fn purge_expired_without_ttl_is_noop() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn fresh_entry_visible_with_ttl() {
        let (_dir, path) = tempdb();
        let cache = Cache::open_with_ttl(&path, Some(3600)).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![1.0]));
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn ttl_zero_rejected() {
        let (_dir, path) = tempdb();
        let err = Cache::open_with_ttl(&path, Some(0));
        assert!(err.is_err());
    }

    #[test]
    fn stats_on_empty_cache() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        let s = cache.stats().unwrap();
        assert_eq!(s.entries, 0);
        assert_eq!(s.value_bytes, 0);
        assert!(cache.is_empty().unwrap());
    }

    #[test]
    fn disk_bytes_reflects_real_file_size() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0_f32, 2.0, 3.0]).unwrap();
        let s = cache.stats().unwrap();
        assert!(s.disk_bytes > 0, "disk_bytes should be > 0 after writes");
    }

    #[test]
    fn path_accessor_returns_open_path() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        assert_eq!(cache.path(), path.as_path());
    }

    #[test]
    fn malformed_entry_rejected() {
        let bad = vec![0u8; 5];
        let r = decode_entry(&bad);
        assert!(r.is_err());
    }

    #[test]
    fn decode_rejects_length_mismatch() {
        let mut bytes = encode_entry(1, &[1.0, 2.0]);
        bytes.push(0);
        assert!(decode_entry(&bytes).is_err(), "trailing byte must be rejected");
        bytes.truncate(bytes.len() - 2);
        assert!(decode_entry(&bytes).is_err(), "truncated payload must be rejected");
    }

    #[test]
    fn decode_rejects_oversized_dim() {
        let mut bytes = 0u64.to_le_bytes().to_vec();
        bytes.extend_from_slice(&u32::MAX.to_le_bytes());
        assert!(decode_entry(&bytes).is_err());
    }

    #[test]
    fn encode_decode_round_trip() {
        let v = vec![1.0_f32, -2.5, 3.125, f32::MIN, f32::MAX];
        let bytes = encode_entry(123, &v);
        let (t, decoded) = decode_entry(&bytes).unwrap();
        assert_eq!(t, 123);
        assert_eq!(decoded, v);
    }
}