Skip to main content

embedcache_core/
lib.rs

1//! Pure-Rust core for `embedcache`: a content-addressed embedding cache
2//! backed by an embedded [redb](https://crates.io/crates/redb) KV store.
3//!
4//! - **Key.** A 32-byte blake3 hash of `(model_name, 0x00, text)`. The null
5//!   separator prevents `(model="a", text="bc")` from colliding with
6//!   `(model="ab", text="c")`.
7//! - **Value.** `[u64 inserted_at_secs LE][u32 dim LE][dim x f32 LE]`.
8//! - **TTL.** Optional. Evaluated on `get`; expired entries are returned as
9//!   `None` and removed by `purge_expired`.
10
11#![deny(unsafe_code)]
12#![warn(missing_docs)]
13#![warn(rust_2018_idioms)]
14
15use std::path::Path;
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
19use serde::{Deserialize, Serialize};
20use thiserror::Error;
21
// The single redb table used by the cache, named "embeddings". Keys are the
// 32-byte hashes produced by `Cache::key`; values are the raw bytes produced
// by `encode_entry` (timestamp + dimension header followed by the floats).
const TABLE: TableDefinition<'_, &[u8; 32], Vec<u8>> = TableDefinition::new("embeddings");
23
/// Crate-wide result alias: every fallible operation in this crate reports
/// failures as a [`CacheError`].
pub type Result<T> = std::result::Result<T, CacheError>;
26
/// All errors surfaced by `embedcache-core`.
#[derive(Error, Debug)]
pub enum CacheError {
    /// Failure inside the redb store. The payload is the `to_string()`
    /// rendering of the underlying redb error; the original error value is
    /// not retained.
    #[error("redb error: {0}")]
    Redb(String),
    /// I/O failure, e.g. while creating the parent directory of the cache
    /// file.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// A stored value could not be decoded: it is shorter than the fixed
    /// 12-byte header, or its total length disagrees with the dimension
    /// declared in that header.
    #[error("malformed entry: {0}")]
    Malformed(String),
    /// Caller supplied an invalid configuration (for example a TTL of zero).
    #[error("invalid config: {0}")]
    InvalidConfig(String),
}
43
44// redb has half a dozen error types. Collapse them into one variant; the
45// string is what the caller will print anyway.
46macro_rules! redb_from {
47    ($($t:ty),+ $(,)?) => {$(
48        impl From<$t> for CacheError {
49            fn from(e: $t) -> Self { Self::Redb(e.to_string()) }
50        }
51    )+};
52}
53redb_from!(
54    redb::Error,
55    redb::DatabaseError,
56    redb::TransactionError,
57    redb::TableError,
58    redb::StorageError,
59    redb::CommitError,
60);
61
/// On-disk content-addressed embedding cache.
///
/// Construct one with [`Cache::open`] or [`Cache::open_with_ttl`].
pub struct Cache {
    // Handle to the redb database that stores the entries.
    db: Database,
    // Entry lifetime in seconds; `None` means entries never expire.
    // `open_with_ttl` rejects `Some(0)`.
    ttl_secs: Option<u64>,
}
67
/// Cache stats returned by [`Cache::stats`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries currently stored (expired entries that have not
    /// been purged yet are included).
    pub entries: u64,
    /// Total raw value bytes (excluding redb overhead). This is the sum of
    /// the stored value lengths, so each entry's 12-byte header is included.
    pub value_bytes: u64,
    /// File size of the database on disk in bytes.
    ///
    /// NOTE: `Cache::disk_size` is currently a stub that always returns 0,
    /// so values produced by [`Cache::stats`] always carry 0 here.
    pub disk_bytes: u64,
}
78
79impl Cache {
80    /// Open or create a cache at `path` with no TTL.
81    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
82        Self::open_with_ttl(path, None)
83    }
84
85    /// Open or create a cache at `path` with an optional TTL in seconds.
86    pub fn open_with_ttl<P: AsRef<Path>>(path: P, ttl_secs: Option<u64>) -> Result<Self> {
87        if let Some(ttl) = ttl_secs {
88            if ttl == 0 {
89                return Err(CacheError::InvalidConfig(
90                    "ttl_secs must be > 0 (or None for no expiry)".into(),
91                ));
92            }
93        }
94        if let Some(parent) = path.as_ref().parent() {
95            if !parent.as_os_str().is_empty() {
96                std::fs::create_dir_all(parent)?;
97            }
98        }
99        let db = Database::create(path.as_ref())?;
100        // Create the table if it doesn't exist.
101        let txn = db.begin_write()?;
102        {
103            let _t = txn.open_table(TABLE)?;
104        }
105        txn.commit()?;
106        Ok(Self { db, ttl_secs })
107    }
108
109    /// 32-byte content-addressed key for `(model, text)`.
110    pub fn key(model: &str, text: &str) -> [u8; 32] {
111        let mut hasher = blake3::Hasher::new();
112        hasher.update(model.as_bytes());
113        hasher.update(&[0u8]);
114        hasher.update(text.as_bytes());
115        *hasher.finalize().as_bytes()
116    }
117
118    /// Look up a vector. Returns `None` if absent or expired.
119    pub fn get(&self, model: &str, text: &str) -> Result<Option<Vec<f32>>> {
120        let key = Self::key(model, text);
121        let now = unix_now();
122        let txn = self.db.begin_read()?;
123        let table = txn.open_table(TABLE)?;
124        let Some(stored) = table.get(&key)? else {
125            return Ok(None);
126        };
127        let bytes = stored.value();
128        let (inserted_at, vec) = decode_entry(&bytes)?;
129        if let Some(ttl) = self.ttl_secs {
130            // `>=` so "ttl=N seconds" means the entry is dead after N seconds
131            // have elapsed. With `>` you'd see N+1 seconds of life, which is
132            // surprising at second-granularity timestamps.
133            if now.saturating_sub(inserted_at) >= ttl {
134                return Ok(None);
135            }
136        }
137        Ok(Some(vec))
138    }
139
140    /// Insert or overwrite a vector for `(model, text)`.
141    pub fn put(&self, model: &str, text: &str, vector: &[f32]) -> Result<()> {
142        let key = Self::key(model, text);
143        let bytes = encode_entry(unix_now(), vector);
144        let txn = self.db.begin_write()?;
145        {
146            let mut table = txn.open_table(TABLE)?;
147            table.insert(&key, bytes)?;
148        }
149        txn.commit()?;
150        Ok(())
151    }
152
153    /// Remove a single entry. Returns `true` if the key was present.
154    pub fn remove(&self, model: &str, text: &str) -> Result<bool> {
155        let key = Self::key(model, text);
156        let txn = self.db.begin_write()?;
157        let removed = {
158            let mut table = txn.open_table(TABLE)?;
159            // Bind the AccessGuard so its borrow of `table` ends before the
160            // block returns; otherwise the temporary outlives the table.
161            let prev = table.remove(&key)?;
162            prev.is_some()
163        };
164        txn.commit()?;
165        Ok(removed)
166    }
167
168    /// Remove every entry. Returns the number of entries removed.
169    pub fn clear(&self) -> Result<u64> {
170        let txn = self.db.begin_write()?;
171        let removed = {
172            let mut table = txn.open_table(TABLE)?;
173            let keys: Vec<[u8; 32]> = table
174                .iter()?
175                .filter_map(|r| r.ok().map(|(k, _)| *k.value()))
176                .collect();
177            for k in &keys {
178                let _ = table.remove(k)?;
179            }
180            keys.len() as u64
181        };
182        txn.commit()?;
183        Ok(removed)
184    }
185
186    /// Remove every entry whose `inserted_at + ttl < now`. Returns the count.
187    /// No-op when the cache has no TTL.
188    pub fn purge_expired(&self) -> Result<u64> {
189        let Some(ttl) = self.ttl_secs else {
190            return Ok(0);
191        };
192        let now = unix_now();
193        let txn = self.db.begin_write()?;
194        let removed = {
195            let mut table = txn.open_table(TABLE)?;
196            let mut victims: Vec<[u8; 32]> = Vec::new();
197            for entry in table.iter()? {
198                let (k, v) = entry?;
199                let bytes = v.value();
200                if bytes.len() < 8 {
201                    continue;
202                }
203                let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
204                if now.saturating_sub(inserted) >= ttl {
205                    victims.push(*k.value());
206                }
207            }
208            for k in &victims {
209                table.remove(k)?;
210            }
211            victims.len() as u64
212        };
213        txn.commit()?;
214        Ok(removed)
215    }
216
217    /// Evict oldest entries until the total stored value bytes are
218    /// `<= target_bytes`. Returns the count removed.
219    pub fn purge_to_size(&self, target_bytes: u64) -> Result<u64> {
220        let txn = self.db.begin_write()?;
221        let removed = {
222            let mut table = txn.open_table(TABLE)?;
223            // Collect (inserted_at, key, size_bytes), sort oldest-first.
224            let mut all: Vec<(u64, [u8; 32], u64)> = Vec::new();
225            let mut total: u64 = 0;
226            for entry in table.iter()? {
227                let (k, v) = entry?;
228                let bytes = v.value();
229                if bytes.len() < 8 {
230                    continue;
231                }
232                let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
233                let size = bytes.len() as u64;
234                total += size;
235                all.push((inserted, *k.value(), size));
236            }
237            if total <= target_bytes {
238                return Ok(0);
239            }
240            all.sort_by_key(|(t, _, _)| *t);
241            let mut removed = 0u64;
242            for (_, k, size) in &all {
243                if total <= target_bytes {
244                    break;
245                }
246                table.remove(k)?;
247                total = total.saturating_sub(*size);
248                removed += 1;
249            }
250            removed
251        };
252        txn.commit()?;
253        Ok(removed)
254    }
255
256    /// Counts and bytes for the cache.
257    pub fn stats(&self) -> Result<CacheStats> {
258        let txn = self.db.begin_read()?;
259        let table = txn.open_table(TABLE)?;
260        let entries = table.len()?;
261        let mut value_bytes = 0u64;
262        for entry in table.iter()? {
263            let (_, v) = entry?;
264            value_bytes += v.value().len() as u64;
265        }
266        let disk_bytes = self.disk_size();
267        Ok(CacheStats {
268            entries,
269            value_bytes,
270            disk_bytes,
271        })
272    }
273
274    /// Number of entries.
275    pub fn len(&self) -> Result<u64> {
276        let txn = self.db.begin_read()?;
277        let table = txn.open_table(TABLE)?;
278        Ok(table.len()?)
279    }
280
281    /// True if no entries are stored.
282    pub fn is_empty(&self) -> Result<bool> {
283        Ok(self.len()? == 0)
284    }
285
286    fn disk_size(&self) -> u64 {
287        // redb does not expose the file path back to us; the caller knows
288        // the path. For stats we return 0 if we cannot infer it, which is
289        // honest rather than approximate.
290        0
291    }
292}
293
// Current wall-clock time as whole seconds since the Unix epoch. A clock set
// before the epoch yields 0 instead of an error.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
300
// Serialises one cache value as `[u64 inserted_at LE][u32 dim LE][dim x f32 LE]`.
//
// `dim` is `vec.len()` cast to `u32`; a slice longer than `u32::MAX` elements
// would have its length truncated by that cast.
fn encode_entry(inserted_at: u64, vec: &[f32]) -> Vec<u8> {
    let dim = vec.len() as u32;
    let mut buf = Vec::with_capacity(8 + 4 + vec.len() * 4);
    buf.extend_from_slice(&inserted_at.to_le_bytes());
    buf.extend_from_slice(&dim.to_le_bytes());
    buf.extend(vec.iter().flat_map(|x| x.to_le_bytes()));
    buf
}
311
312fn decode_entry(bytes: &[u8]) -> Result<(u64, Vec<f32>)> {
313    if bytes.len() < 12 {
314        return Err(CacheError::Malformed("entry shorter than header".into()));
315    }
316    let inserted = u64::from_le_bytes(bytes[0..8].try_into().unwrap());
317    let dim = u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
318    let expected = 12 + dim * 4;
319    if bytes.len() != expected {
320        return Err(CacheError::Malformed(format!(
321            "entry length {}, expected {}",
322            bytes.len(),
323            expected
324        )));
325    }
326    let mut vec = Vec::with_capacity(dim);
327    for i in 0..dim {
328        let off = 12 + i * 4;
329        vec.push(f32::from_le_bytes(bytes[off..off + 4].try_into().unwrap()));
330    }
331    Ok((inserted, vec))
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    // Returns a fresh temp directory plus a database path inside it. The
    // `TempDir` must be kept alive by the caller; dropping it deletes the
    // directory.
    fn tempdb() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cache.redb");
        (dir, path)
    }

    #[test]
    fn key_changes_with_model_or_text() {
        let a = Cache::key("m1", "hello");
        let b = Cache::key("m2", "hello");
        let c = Cache::key("m1", "world");
        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_eq!(a, Cache::key("m1", "hello"));
    }

    #[test]
    fn key_separator_blocks_concatenation_collision() {
        // Without a separator, ("a", "bc") and ("ab", "c") would collide.
        let a = Cache::key("a", "bc");
        let b = Cache::key("ab", "c");
        assert_ne!(a, b);
    }

    #[test]
    fn put_then_get_round_trips() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        let v = vec![0.1, 0.2, 0.3];
        cache.put("m", "hello", &v).unwrap();
        assert_eq!(cache.get("m", "hello").unwrap(), Some(v));
    }

    #[test]
    fn get_missing_returns_none() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        assert_eq!(cache.get("m", "nope").unwrap(), None);
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0, 2.0]).unwrap();
        cache.put("m", "k", &[3.0, 4.0, 5.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![3.0, 4.0, 5.0]));
    }

    #[test]
    fn remove_returns_true_when_present() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert!(cache.remove("m", "k").unwrap());
        assert!(!cache.remove("m", "k").unwrap());
    }

    #[test]
    fn clear_removes_all() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        assert_eq!(cache.len().unwrap(), 10);
        assert_eq!(cache.clear().unwrap(), 10);
        assert_eq!(cache.len().unwrap(), 0);
        assert!(cache.is_empty().unwrap());
    }

    #[test]
    fn purge_to_size_evicts_oldest() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        // Each entry: 8 + 4 + 4 = 16 bytes value.
        for i in 0..10 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        // Timestamps have one-second resolution and the test does not sleep,
        // so the ten entries normally tie on `inserted_at`. Which of them get
        // evicted is therefore not asserted; only the size bound is.
        let removed = cache.purge_to_size(32).unwrap();
        assert!(removed > 0, "expected at least 1 eviction");
        let stats = cache.stats().unwrap();
        assert!(stats.value_bytes <= 32, "value_bytes={}", stats.value_bytes);
    }

    #[test]
    fn purge_to_size_is_noop_when_already_under_target() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "a", &[1.0]).unwrap();
        cache.put("m", "b", &[2.0]).unwrap();
        // Two 16-byte values fit a 32-byte budget exactly.
        assert_eq!(cache.purge_to_size(32).unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 2);
    }

    #[test]
    fn stats_counts_entries_and_value_bytes() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        for i in 0..3 {
            cache.put("m", &format!("k{i}"), &[i as f32]).unwrap();
        }
        let stats = cache.stats().unwrap();
        assert_eq!(stats.entries, 3);
        assert_eq!(stats.value_bytes, 3 * 16);
    }

    #[test]
    fn ttl_zero_rejected() {
        let (_dir, path) = tempdb();
        let err = Cache::open_with_ttl(&path, Some(0));
        assert!(matches!(err, Err(CacheError::InvalidConfig(_))));
    }

    #[test]
    fn unexpired_entry_is_readable_and_survives_purge() {
        let (_dir, path) = tempdb();
        // An hour is far longer than the test runs, so nothing expires.
        let cache = Cache::open_with_ttl(&path, Some(3600)).unwrap();
        cache.put("m", "k", &[1.0, 2.0]).unwrap();
        assert_eq!(cache.get("m", "k").unwrap(), Some(vec![1.0, 2.0]));
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn purge_expired_without_ttl_is_noop() {
        let (_dir, path) = tempdb();
        let cache = Cache::open(&path).unwrap();
        cache.put("m", "k", &[1.0]).unwrap();
        assert_eq!(cache.purge_expired().unwrap(), 0);
        assert_eq!(cache.len().unwrap(), 1);
    }

    #[test]
    fn malformed_entry_rejected() {
        // decode_entry directly, since we cannot write a malformed entry
        // through the public API.
        let bad = vec![0u8; 5];
        let r = decode_entry(&bad);
        assert!(r.is_err());
    }

    #[test]
    fn length_mismatch_rejected() {
        let good = encode_entry(1, &[1.0, 2.0]);
        // Truncated payload.
        assert!(decode_entry(&good[..good.len() - 1]).is_err());
        // Trailing garbage.
        let mut long = good.clone();
        long.push(0);
        assert!(decode_entry(&long).is_err());
    }

    #[test]
    fn encode_decode_round_trip() {
        let v = vec![1.0_f32, -2.5, 3.125, f32::MIN, f32::MAX];
        let bytes = encode_entry(123, &v);
        let (t, decoded) = decode_entry(&bytes).unwrap();
        assert_eq!(t, 123);
        assert_eq!(decoded, v);
    }

    #[test]
    fn empty_vector_round_trips() {
        let bytes = encode_entry(7, &[]);
        assert_eq!(bytes.len(), 12);
        let (t, decoded) = decode_entry(&bytes).unwrap();
        assert_eq!(t, 7);
        assert!(decoded.is_empty());
    }
}