Skip to main content

roboticus_db/
cache.rs

1use crate::embeddings::{blob_to_embedding, embedding_to_blob};
2use crate::{Database, DbResultExt};
3use roboticus_core::Result;
4
/// A single persisted cache entry, mirroring one row of the `semantic_cache` table
/// (the row's `id` is kept separately by the functions in this module).
///
/// `PartialEq` is derived alongside `Debug` and `Clone` so that entries can be
/// compared directly, e.g. when checking a save/load round trip. `Eq` cannot be
/// derived because the embedding holds `f32` values.
#[derive(Debug, Clone, PartialEq)]
pub struct PersistedCacheEntry {
    /// Value stored in the `prompt_hash` column.
    pub prompt_hash: String,
    /// Cached response text (`response` column).
    pub response: String,
    /// Model name stored with the response (`model` column).
    pub model: String,
    /// Tokens-saved counter stored with the entry (`tokens_saved` column).
    pub tokens_saved: u32,
    /// Hit counter for the entry (`hit_count` column); presumably the number of
    /// times it was served — confirm with the in-memory cache that updates it.
    pub hit_count: u32,
    /// Optional embedding vector. It is serialized to the `embedding` blob column,
    /// and a NULL or empty blob is read back as `None`.
    pub embedding: Option<Vec<f32>>,
    /// Creation timestamp as text, e.g. `2025-01-01T00:00:00`. It is compared
    /// against the current time when aging out rows that have no `expires_at`.
    pub created_at: String,
    /// Optional expiry timestamp as text; `None` (SQL NULL) means no explicit expiry.
    pub expires_at: Option<String>,
}
17
18/// Save a cache entry to the `semantic_cache` table.
19pub fn save_cache_entry(db: &Database, id: &str, entry: &PersistedCacheEntry) -> Result<()> {
20    let embedding_blob = entry.embedding.as_ref().map(|e| embedding_to_blob(e));
21
22    let conn = db.conn();
23    conn.execute(
24        "INSERT OR REPLACE INTO semantic_cache \
25         (id, prompt_hash, embedding, response, model, tokens_saved, hit_count, created_at, expires_at) \
26         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
27        rusqlite::params![
28            id,
29            entry.prompt_hash,
30            embedding_blob,
31            entry.response,
32            entry.model,
33            entry.tokens_saved,
34            entry.hit_count,
35            entry.created_at,
36            entry.expires_at,
37        ],
38    )
39    .db_err()?;
40
41    Ok(())
42}
43
44/// Load all non-expired cache entries from the database.
45pub fn load_cache_entries(db: &Database) -> Result<Vec<(String, PersistedCacheEntry)>> {
46    let conn = db.conn();
47    let mut stmt = conn
48        .prepare(
49            "SELECT id, prompt_hash, embedding, response, model, tokens_saved, hit_count, \
50             created_at, expires_at \
51             FROM semantic_cache \
52             WHERE expires_at IS NULL OR expires_at > datetime('now')",
53        )
54        .db_err()?;
55
56    let rows = stmt
57        .query_map([], |row| {
58            let id: String = row.get(0)?;
59            let prompt_hash: String = row.get(1)?;
60            let blob: Option<Vec<u8>> = row.get(2)?;
61            let response: String = row.get(3)?;
62            let model: String = row.get(4)?;
63            let tokens_saved: u32 = row.get(5)?;
64            let hit_count: u32 = row.get(6)?;
65            let created_at: String = row.get(7)?;
66            let expires_at: Option<String> = row.get(8)?;
67
68            let embedding = blob.and_then(|b| {
69                if b.is_empty() {
70                    None
71                } else {
72                    Some(blob_to_embedding(&b))
73                }
74            });
75
76            Ok((
77                id,
78                PersistedCacheEntry {
79                    prompt_hash,
80                    response,
81                    model,
82                    tokens_saved,
83                    hit_count,
84                    embedding,
85                    created_at,
86                    expires_at,
87                },
88            ))
89        })
90        .db_err()?;
91
92    rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
93}
94
95/// Maximum age (days) for cache entries that lack an explicit `expires_at`.
96/// Prevents NULL-expiry rows from accumulating indefinitely.
97const NULL_EXPIRY_MAX_AGE_DAYS: u32 = 7;
98
99/// Remove expired entries from the semantic_cache table.
100///
101/// Evicts rows where:
102/// 1. `expires_at` has passed, OR
103/// 2. `expires_at IS NULL` and the row is older than `NULL_EXPIRY_MAX_AGE_DAYS` (7 days).
104pub fn evict_expired_cache(db: &Database) -> Result<usize> {
105    let conn = db.conn();
106    let deleted = conn
107        .execute(
108            &format!(
109                "DELETE FROM semantic_cache WHERE \
110                 (expires_at IS NOT NULL AND expires_at <= datetime('now')) \
111                 OR (expires_at IS NULL AND created_at <= datetime('now', '-{NULL_EXPIRY_MAX_AGE_DAYS} days'))"
112            ),
113            [],
114        )
115        .db_err()?;
116    Ok(deleted)
117}
118
119/// Count of cached entries.
120pub fn cache_count(db: &Database) -> Result<usize> {
121    let conn = db.conn();
122    let count: usize = conn
123        .query_row("SELECT COUNT(*) FROM semantic_cache", [], |row| row.get(0))
124        .db_err()?;
125    Ok(count)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Each test gets its own private in-memory database.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Builds an entry with model `"m"`, zero counters and no embedding.
    /// Individual tests override fields with struct-update syntax.
    fn sample(
        prompt_hash: &str,
        response: &str,
        created_at: &str,
        expires_at: Option<&str>,
    ) -> PersistedCacheEntry {
        PersistedCacheEntry {
            prompt_hash: prompt_hash.to_string(),
            response: response.to_string(),
            model: "m".to_string(),
            tokens_saved: 0,
            hit_count: 0,
            embedding: None,
            created_at: created_at.to_string(),
            expires_at: expires_at.map(str::to_string),
        }
    }

    #[test]
    fn save_and_load_roundtrip() {
        let db = test_db();
        let original = PersistedCacheEntry {
            model: "test-model".to_string(),
            tokens_saved: 50,
            hit_count: 3,
            embedding: Some(vec![0.1, 0.2, 0.3]),
            ..sample(
                "abc123",
                "Hello world",
                "2025-01-01T00:00:00",
                Some("2030-12-31T23:59:59"),
            )
        };

        save_cache_entry(&db, "cache-1", &original).unwrap();

        let loaded = load_cache_entries(&db).unwrap();
        assert_eq!(loaded.len(), 1);
        let (id, got) = &loaded[0];
        assert_eq!(id, "cache-1");
        assert_eq!(got.prompt_hash, "abc123");
        assert_eq!(got.response, "Hello world");
        assert_eq!(got.tokens_saved, 50);
        assert_eq!(got.hit_count, 3);
        let embedding = got
            .embedding
            .as_ref()
            .expect("embedding should survive the round trip");
        assert_eq!(embedding.len(), 3);
    }

    #[test]
    fn save_without_embedding() {
        let db = test_db();
        let bare = PersistedCacheEntry {
            model: "test-model".to_string(),
            tokens_saved: 10,
            ..sample("def456", "No embedding", "2025-01-01T00:00:00", None)
        };

        save_cache_entry(&db, "cache-2", &bare).unwrap();

        let loaded = load_cache_entries(&db).unwrap();
        assert_eq!(loaded.len(), 1);
        let (_, got) = &loaded[0];
        assert!(got.embedding.is_none());
        assert!(got.expires_at.is_none());
    }

    #[test]
    fn evict_expired() {
        let db = test_db();
        let expired = sample(
            "expired",
            "old",
            "2020-01-01T00:00:00",
            Some("2020-01-02T00:00:00"),
        );
        let fresh = sample(
            "fresh",
            "new",
            "2025-01-01T00:00:00",
            Some("2030-12-31T23:59:59"),
        );

        save_cache_entry(&db, "c1", &expired).unwrap();
        save_cache_entry(&db, "c2", &fresh).unwrap();

        assert_eq!(evict_expired_cache(&db).unwrap(), 1);
        assert_eq!(cache_count(&db).unwrap(), 1);
    }

    #[test]
    fn evict_null_expiry_after_max_age() {
        let db = test_db();
        // No explicit expiry and created long ago: past the max age, so evicted.
        let old_null = sample("old_null", "ancient", "2020-01-01T00:00:00", None);
        // No explicit expiry and created "in the future": never old enough, so kept.
        let recent_null = sample("recent_null", "fresh", "2099-01-01T00:00:00", None);

        save_cache_entry(&db, "c1", &old_null).unwrap();
        save_cache_entry(&db, "c2", &recent_null).unwrap();
        assert_eq!(cache_count(&db).unwrap(), 2);

        assert_eq!(evict_expired_cache(&db).unwrap(), 1);
        assert_eq!(cache_count(&db).unwrap(), 1);

        let remaining = load_cache_entries(&db).unwrap();
        assert_eq!(remaining[0].1.prompt_hash, "recent_null");
    }

    #[test]
    fn cache_count_empty() {
        let db = test_db();
        assert_eq!(cache_count(&db).unwrap(), 0);
    }

    #[test]
    fn replace_existing_entry() {
        let db = test_db();
        let first = PersistedCacheEntry {
            tokens_saved: 10,
            hit_count: 1,
            ..sample("hash", "first", "2025-01-01T00:00:00", None)
        };
        let second = PersistedCacheEntry {
            tokens_saved: 20,
            hit_count: 5,
            ..sample("hash", "second", "2025-01-02T00:00:00", None)
        };

        // Same id twice: the second save must overwrite, not duplicate.
        save_cache_entry(&db, "c1", &first).unwrap();
        save_cache_entry(&db, "c1", &second).unwrap();

        assert_eq!(cache_count(&db).unwrap(), 1);
        let loaded = load_cache_entries(&db).unwrap();
        assert_eq!(loaded[0].1.response, "second");
    }
}