Skip to main content

embedcache_core/
lib.rs

//! Pure-Rust core for `embedcache`: a content-addressed embedding cache
//! backed by an embedded [redb](https://crates.io/crates/redb) KV store.
//!
//! - **Key.** A 32-byte blake3 hash of `(model_name, 0x00, text)`. The null
//!   separator prevents `(model="a", text="bc")` from colliding with
//!   `(model="ab", text="c")`, provided model names never contain a NUL byte.
//! - **Value.** `[u64 inserted_at_secs LE][u32 dim LE][dim x f32 LE]`.
//! - **TTL.** Optional, in whole seconds. Evaluated on `get`: an entry whose
//!   age is at least the TTL is reported as `None`, but it stays on disk
//!   until `purge_expired` (or another removal) deletes it.

11#![deny(unsafe_code)]
12#![warn(missing_docs)]
13#![warn(rust_2018_idioms)]
14
15use std::path::Path;
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
22const TABLE: TableDefinition<'_, &[u8; 32], Vec<u8>> = TableDefinition::new("embeddings");
23
24/// Crate-wide result alias.
25pub type Result<T> = std::result::Result<T, CacheError>;
26
27/// All errors surfaced by `embedcache-core`.
28#[derive(Error, Debug)]
29pub enum CacheError {
30    /// Failure inside the redb store.
31    #[error("redb error: {0}")]
32    Redb(String),
33    /// I/O failure opening the cache directory or file.
34    #[error("io error: {0}")]
35    Io(#[from] std::io::Error),
36    /// A stored value is shorter than its declared header.
37    #[error("malformed entry: {0}")]
38    Malformed(String),
39    /// Caller supplied an invalid configuration.
40    #[error("invalid config: {0}")]
41    InvalidConfig(String),
42}
43
44// redb has half a dozen error types. Collapse them into one variant; the
45// string is what the caller will print anyway.
46macro_rules! redb_from {
47    ($($t:ty),+ $(,)?) => {$(
48        impl From<$t> for CacheError {
49            fn from(e: $t) -> Self { Self::Redb(e.to_string()) }
50        }
51    )+};
52}
53redb_from!(
54    redb::Error,
55    redb::DatabaseError,
56    redb::TransactionError,
57    redb::TableError,
58    redb::StorageError,
59    redb::CommitError,
60);
61
62/// On-disk content-addressed embedding cache.
63pub struct Cache {
64    db: Database,
65    ttl_secs: Option<u64>,
66    /// Path to the underlying redb file. Stashed at open time so `stats`
67    /// can report `disk_bytes` accurately (redb itself does not surface
68    /// the path back to callers).
69    path: std::path::PathBuf,
70}
71
72/// Cache stats returned by [`Cache::stats`].
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
74pub struct CacheStats {
75    /// Number of entries currently stored.
76    pub entries: u64,
77    /// Total raw value bytes (excluding redb overhead).
78    pub value_bytes: u64,
79    /// File size of the database on disk in bytes.
80    pub disk_bytes: u64,
81}
82
83impl Cache {
84    /// Open or create a cache at `path` with no TTL.
85    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
86        Self::open_with_ttl(path, None)
87    }
88
89    /// Open or create a cache at `path` with an optional TTL in seconds.
90    pub fn open_with_ttl<P: AsRef<Path>>(path: P, ttl_secs: Option<u64>) -> Result<Self> {
91        if let Some(ttl) = ttl_secs {
92            if ttl == 0 {
93                return Err(CacheError::InvalidConfig(
94                    "ttl_secs must be > 0 (or None for no expiry)".into(),
95                ));
96            }
97        }
98        if let Some(parent) = path.as_ref().parent() {
99            if !parent.as_os_str().is_empty() {
100                std::fs::create_dir_all(parent)?;
101            }
102        }
103        let db = Database::create(path.as_ref())?;
104        // Create the table if it doesn't exist.
105        let txn = db.begin_write()?;
106        {
107            let _t = txn.open_table(TABLE)?;
108        }
109        txn.commit()?;
110        Ok(Self {
111            db,
112            ttl_secs,
113            path: path.as_ref().to_path_buf(),
114        })
115    }
116
117    /// 32-byte content-addressed key for `(model, text)`.
118    pub fn key(model: &str, text: &str) -> [u8; 32] {
119        let mut hasher = blake3::Hasher::new();
120        hasher.update(model.as_bytes());
121        hasher.update(&[0u8]);
122        hasher.update(text.as_bytes());
123        *hasher.finalize().as_bytes()
124    }
125
126    /// Look up a vector. Returns `None` if absent or expired.
127    pub fn get(&self, model: &str, text: &str) -> Result<Option<Vec<f32>>> {
128        let key = Self::key(model, text);
129        let now = unix_now();
130        let txn = self.db.begin_read()?;
131        let table = txn.open_table(TABLE)?;
132        let Some(stored) = table.get(&key)? else {
133            return Ok(None);
134        };
135        let bytes = stored.value();
136        let (inserted_at, vec) = decode_entry(&bytes)?;
137        if let Some(ttl) = self.ttl_secs {
138            // `>=` so "ttl=N seconds" means the entry is dead after N seconds
139            // have elapsed. With `>` you'd see N+1 seconds of life, which is
140            // surprising at second-granularity timestamps.
141            if now.saturating_sub(inserted_at) >= ttl {
142                return Ok(None);
143            }
144        }
145        Ok(Some(vec))
146    }
147
148    /// Insert or overwrite a vector for `(model, text)`.
149    pub fn put(&self, model: &str, text: &str, vector: &[f32]) -> Result<()> {
150        let key = Self::key(model, text);
151        let bytes = encode_entry(unix_now(), vector);
152        let txn = self.db.begin_write()?;
153        {
154            let mut table = txn.open_table(TABLE)?;
155            table.insert(&key, bytes)?;
156        }
157        txn.commit()?;
158        Ok(())
159    }
160
161    /// Remove a single entry. Returns `true` if the key was present.
162    pub fn remove(&self, model: &str, text: &str) -> Result<bool> {
163        let key = Self::key(model, text);
164        let txn = self.db.begin_write()?;
165        let removed = {
166            let mut table = txn.open_table(TABLE)?;
167            // Bind the AccessGuard so its borrow of `table` ends before the
168            // block returns; otherwise the temporary outlives the table.
169            let prev = table.remove(&key)?;
170            prev.is_some()
171        };
172        txn.commit()?;
173        Ok(removed)
174    }
175
176    /// Remove every entry. Returns the number of entries removed.
177    pub fn clear(&self) -> Result<u64> {
178        let txn = self.db.begin_write()?;
179        let removed = {
180            let mut table = txn.open_table(TABLE)?;
181            let keys: Vec<[u8; 32]> = table
182                .iter()?
183                .filter_map(|r| r.ok().map(|(k, _)| *k.value()))
184                .collect();
185            for k in &keys {
186                let _ = table.remove(k)?;
187            }
188            keys.len() as u64
189        };
190        txn.commit()?;
191        Ok(removed)
192    }
193
194    /// Remove every entry whose `inserted_at + ttl < now`. Returns the count.
195    /// No-op when the cache has no TTL.
196    pub fn purge_expired(&self) -> Result<u64> {
197        let Some(ttl) = self.ttl_secs else {
198            return Ok(0);
199        };
200        let now = unix_now();
201        let txn = self.db.begin_write()?;
202        let removed = {
203            let mut table = txn.open_table(TABLE)?;
204            let mut victims: Vec<[u8; 32]> = Vec::new();
205            for entry in table.iter()? {
206                let (k, v) = entry?;
207                let bytes = v.value();
208                if bytes.len() < 8 {
209                    continue;
210                }
211                let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
212                if now.saturating_sub(inserted) >= ttl {
213                    victims.push(*k.value());
214                }
215            }
216            for k in &victims {
217                table.remove(k)?;
218            }
219            victims.len() as u64
220        };
221        txn.commit()?;
222        Ok(removed)
223    }
224
225    /// Evict oldest entries until the total stored value bytes are
226    /// `<= target_bytes`. Returns the count removed.
227    pub fn purge_to_size(&self, target_bytes: u64) -> Result<u64> {
228        let txn = self.db.begin_write()?;
229        let removed = {
230            let mut table = txn.open_table(TABLE)?;
231            // Collect (inserted_at, key, size_bytes), sort oldest-first.
232            let mut all: Vec<(u64, [u8; 32], u64)> = Vec::new();
233            let mut total: u64 = 0;
234            for entry in table.iter()? {
235                let (k, v) = entry?;
236                let bytes = v.value();
237                if bytes.len() < 8 {
238                    continue;
239                }
240                let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
241                let size = bytes.len() as u64;
242                total += size;
243                all.push((inserted, *k.value(), size));
244            }
245            if total <= target_bytes {
246                return Ok(0);
247            }
248            all.sort_by_key(|(t, _, _)| *t);
249            let mut removed = 0u64;
250            for (_, k, size) in &all {
251                if total <= target_bytes {
252                    break;
253                }
254                table.remove(k)?;
255                total = total.saturating_sub(*size);
256                removed += 1;
257            }
258            removed
259        };
260        txn.commit()?;
261        Ok(removed)
262    }
263
264    /// Counts and bytes for the cache.
265    pub fn stats(&self) -> Result<CacheStats> {
266        let txn = self.db.begin_read()?;
267        let table = txn.open_table(TABLE)?;
268        let entries = table.len()?;
269        let mut value_bytes = 0u64;
270        for entry in table.iter()? {
271            let (_, v) = entry?;
272            value_bytes += v.value().len() as u64;
273        }
274        let disk_bytes = self.disk_size();
275        Ok(CacheStats {
276            entries,
277            value_bytes,
278            disk_bytes,
279        })
280    }
281
282    /// Number of entries.
283    pub fn len(&self) -> Result<u64> {
284        let txn = self.db.begin_read()?;
285        let table = txn.open_table(TABLE)?;
286        Ok(table.len()?)
287    }
288
289    /// True if no entries are stored.
290    pub fn is_empty(&self) -> Result<bool> {
291        Ok(self.len()? == 0)
292    }
293
294    fn disk_size(&self) -> u64 {
295        // We stash the path at open time so we can `metadata()` it for an
296        // honest byte count. Returns 0 if the file vanished (e.g. another
297        // process unlinked it) — the cache itself is still usable since
298        // redb keeps an open fd.
299        std::fs::metadata(&self.path).map(|m| m.len()).unwrap_or(0)
300    }
301
302    /// Path to the underlying database file.
303    pub fn path(&self) -> &std::path::Path {
304        &self.path
305    }
306}
307
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 instead of an error.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
314
/// Serialize one entry as `[u64 inserted_at LE][u32 dim LE][dim x f32 LE]`.
///
/// `dim` is `vec.len()` narrowed with `as u32`, so a vector longer than
/// `u32::MAX` elements would record a truncated dimension.
fn encode_entry(inserted_at: u64, vec: &[f32]) -> Vec<u8> {
    let header_len = 8 + 4;
    let mut buf = Vec::with_capacity(header_len + 4 * vec.len());
    buf.extend_from_slice(&inserted_at.to_le_bytes());
    buf.extend_from_slice(&(vec.len() as u32).to_le_bytes());
    buf.extend(vec.iter().flat_map(|component| component.to_le_bytes()));
    buf
}
325
326fn decode_entry(bytes: &[u8]) -> Result<(u64, Vec<f32>)> {
327    if bytes.len() < 12 {
328        return Err(CacheError::Malformed("entry shorter than header".into()));
329    }
330    let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
331    let dim = u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
332    let expected = 12 + dim * 4;
333    if bytes.len() != expected {
334        return Err(CacheError::Malformed(format!(
335            "entry length {}, expected {}",
336            bytes.len(),
337            expected
338        )));
339    }
340    let mut vec = Vec::with_capacity(dim);
341    for i in 0..dim {
342        let off = 12 + i * 4;
343        vec.push(f32::from_le_bytes(bytes[off..off + 4].try_into().unwrap()));
344    }
345    Ok((inserted, vec))
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh temp directory plus a path for a not-yet-created cache file in
    /// it. Keep the returned `TempDir` alive while the path is in use.
    fn tempdb() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cache.redb");
        (dir, path)
    }

    #[test]
    fn key_changes_with_model_or_text() {
        let a = Cache::key("m1", "hello");
        let b = Cache::key("m2", "hello");
        let c = Cache::key("m1", "world");
        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_eq!(a, Cache::key("m1", "hello"));
    }

    #[test]
    fn key_separator_blocks_concatenation_collision() {
        // Without a separator, ("a", "bc") and ("ab", "c") would collide.
        let a = Cache::key("a", "bc");
        let b = Cache::key("ab", "c");
        assert_ne!(a, b);
    }

    #[test]
    fn put_then_get_round_trips() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        let v = vec![0.1, 0.2, 0.3];
        cache.put("m", "hello", &v).unwrap();
        assert_eq!(cache.get("m", "hello").unwrap(), Some(v));
    }

    #[test]
    fn get_missing_returns_none() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        assert_eq!(cache.get("m", "nope").unwrap(), None);
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0, 2.0]).unwrap();
        cache.put("m", "k", &[3.0, 4.0, 5.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![3.0, 4.0, 5.0]));
    }

    #[test]
    fn remove_returns_true_when_present() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert!(cache.remove("m", "k").unwrap());
        assert!(!cache.remove("m", "k").unwrap());
    }

    #[test]
    fn clear_removes_all() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        assert_eq!(cache.len().unwrap(), 10);
        cache.clear().unwrap();
        assert_eq!(cache.len().unwrap(), 0);
    }

    #[test]
    fn clear_returns_number_removed() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..3 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        assert_eq!(cache.clear().unwrap(), 3);
        assert_eq!(cache.clear().unwrap(), 0);
        assert!(cache.is_empty().unwrap());
    }

    #[test]
    fn purge_to_size_evicts_oldest() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        // Each entry: 8 + 4 + 4 = 16 value bytes, so 160 bytes in total.
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        // `inserted_at` has whole-second resolution and this test does not
        // sleep, so timestamps usually tie and which keys survive is
        // unspecified. Only the eviction count and the size bound are checked.
        let removed = cache.purge_to_size(32).unwrap();
        assert_eq!(removed, 8, "160 bytes down to <= 32 needs 8 evictions");
        let stats = cache.stats().unwrap();
        assert!(stats.value_bytes <= 32, "value_bytes={}", stats.value_bytes);
    }

    #[test]
    fn purge_to_size_noop_when_under_target() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.purge_to_size(1_000).unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn purge_expired_is_noop_without_ttl() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn ttl_entry_visible_before_expiry() {
        let (_dir, path) = tempdb();
        let cache = Cache::open_with_ttl(&path, Some(3600)).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![1.0]));
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn ttl_zero_rejected() {
        let (_dir, path) = tempdb();
        let err = Cache::open_with_ttl(&path, Some(0));
        assert!(err.is_err());
    }

    #[test]
    fn stats_sums_value_bytes() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "a", &[1.0]).unwrap(); // 12 + 4 = 16 bytes
        cache.put("m", "b", &[1.0, 2.0]).unwrap(); // 12 + 8 = 20 bytes
        let s = cache.stats().unwrap();
        assert_eq!(s.entries, 2);
        assert_eq!(s.value_bytes, 36);
    }

    #[test]
    fn disk_bytes_reflects_real_file_size() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0_f32, 2.0, 3.0]).unwrap();
        let s = cache.stats().unwrap();
        assert!(s.disk_bytes > 0, "disk_bytes should be > 0 after writes");
    }

    #[test]
    fn path_accessor_returns_open_path() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        assert_eq!(cache.path(), path.as_path());
    }

    #[test]
    fn malformed_entry_rejected() {
        // decode_entry directly, since we cannot write a malformed entry
        // through the public API.
        let bad = vec![0u8; 5];
        let r = decode_entry(&bad);
        assert!(r.is_err());
    }

    #[test]
    fn encode_decode_round_trip() {
        let v = vec![1.0_f32, -2.5, 3.125, f32::MIN, f32::MAX];
        let bytes = encode_entry(123, &v);
        let (t, decoded) = decode_entry(&bytes).unwrap();
        assert_eq!(t, 123);
        assert_eq!(decoded, v);
    }
}