1use crate::embeddings::{blob_to_embedding, embedding_to_blob};
2use crate::{Database, DbResultExt};
3use roboticus_core::Result;
4
/// One semantic-cache record in the shape used for persistence.
///
/// [`save_cache_entry`] writes these fields into the `semantic_cache` table and
/// [`load_cache_entries`] rebuilds them from its rows. The row `id` is not a
/// field here; both functions carry it alongside the entry.
#[derive(Clone, Debug)]
pub struct PersistedCacheEntry {
    /// Value of the `prompt_hash` column (presumably a digest of the prompt;
    /// confirm with the code that produces it).
    pub prompt_hash: String,
    /// Cached response text (`response` column).
    pub response: String,
    /// Model name stored in the `model` column.
    pub model: String,
    /// Value of the `tokens_saved` column.
    pub tokens_saved: u32,
    /// Value of the `hit_count` column.
    pub hit_count: u32,
    /// Optional embedding vector. It is serialized with `embedding_to_blob`
    /// on save; on load a `NULL` or empty blob becomes `None`.
    pub embedding: Option<Vec<f32>>,
    /// Creation timestamp as text. `evict_expired_cache` compares it with
    /// SQLite's `datetime('now', ...)` for rows that have no expiry.
    pub created_at: String,
    /// Expiry timestamp as text, or `None` (stored as `NULL`) when the entry
    /// has no explicit expiry.
    pub expires_at: Option<String>,
}
17
18pub fn save_cache_entry(db: &Database, id: &str, entry: &PersistedCacheEntry) -> Result<()> {
20 let embedding_blob = entry.embedding.as_ref().map(|e| embedding_to_blob(e));
21
22 let conn = db.conn();
23 conn.execute(
24 "INSERT OR REPLACE INTO semantic_cache \
25 (id, prompt_hash, embedding, response, model, tokens_saved, hit_count, created_at, expires_at) \
26 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
27 rusqlite::params![
28 id,
29 entry.prompt_hash,
30 embedding_blob,
31 entry.response,
32 entry.model,
33 entry.tokens_saved,
34 entry.hit_count,
35 entry.created_at,
36 entry.expires_at,
37 ],
38 )
39 .db_err()?;
40
41 Ok(())
42}
43
44pub fn load_cache_entries(db: &Database) -> Result<Vec<(String, PersistedCacheEntry)>> {
46 let conn = db.conn();
47 let mut stmt = conn
48 .prepare(
49 "SELECT id, prompt_hash, embedding, response, model, tokens_saved, hit_count, \
50 created_at, expires_at \
51 FROM semantic_cache \
52 WHERE expires_at IS NULL OR expires_at > datetime('now')",
53 )
54 .db_err()?;
55
56 let rows = stmt
57 .query_map([], |row| {
58 let id: String = row.get(0)?;
59 let prompt_hash: String = row.get(1)?;
60 let blob: Option<Vec<u8>> = row.get(2)?;
61 let response: String = row.get(3)?;
62 let model: String = row.get(4)?;
63 let tokens_saved: u32 = row.get(5)?;
64 let hit_count: u32 = row.get(6)?;
65 let created_at: String = row.get(7)?;
66 let expires_at: Option<String> = row.get(8)?;
67
68 let embedding = blob.and_then(|b| {
69 if b.is_empty() {
70 None
71 } else {
72 Some(blob_to_embedding(&b))
73 }
74 });
75
76 Ok((
77 id,
78 PersistedCacheEntry {
79 prompt_hash,
80 response,
81 model,
82 tokens_saved,
83 hit_count,
84 embedding,
85 created_at,
86 expires_at,
87 },
88 ))
89 })
90 .db_err()?;
91
92 rows.collect::<std::result::Result<Vec<_>, _>>().db_err()
93}
94
95const NULL_EXPIRY_MAX_AGE_DAYS: u32 = 7;
98
99pub fn evict_expired_cache(db: &Database) -> Result<usize> {
105 let conn = db.conn();
106 let deleted = conn
107 .execute(
108 &format!(
109 "DELETE FROM semantic_cache WHERE \
110 (expires_at IS NOT NULL AND expires_at <= datetime('now')) \
111 OR (expires_at IS NULL AND created_at <= datetime('now', '-{NULL_EXPIRY_MAX_AGE_DAYS} days'))"
112 ),
113 [],
114 )
115 .db_err()?;
116 Ok(deleted)
117}
118
119pub fn cache_count(db: &Database) -> Result<usize> {
121 let conn = db.conn();
122 let count: usize = conn
123 .query_row("SELECT COUNT(*) FROM semantic_cache", [], |row| row.get(0))
124 .db_err()?;
125 Ok(count)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Expiry used for "still valid" fixtures. It is far enough ahead that the
    /// tests do not start failing once a nearer date (previously 2030) passes.
    const FAR_FUTURE: &str = "2099-12-31T23:59:59";

    /// Opens a fresh in-memory database for one test.
    fn test_db() -> Database {
        Database::new(":memory:").unwrap()
    }

    /// Builds a minimal entry (model `"m"`, zero counters, no embedding);
    /// tests override individual fields with struct-update syntax.
    fn make_entry(
        prompt_hash: &str,
        response: &str,
        created_at: &str,
        expires_at: Option<&str>,
    ) -> PersistedCacheEntry {
        PersistedCacheEntry {
            prompt_hash: prompt_hash.into(),
            response: response.into(),
            model: "m".into(),
            tokens_saved: 0,
            hit_count: 0,
            embedding: None,
            created_at: created_at.into(),
            expires_at: expires_at.map(str::to_owned),
        }
    }

    #[test]
    fn save_and_load_roundtrip() {
        let db = test_db();

        let entry = PersistedCacheEntry {
            model: "test-model".into(),
            tokens_saved: 50,
            hit_count: 3,
            embedding: Some(vec![0.1, 0.2, 0.3]),
            ..make_entry("abc123", "Hello world", "2025-01-01T00:00:00", Some(FAR_FUTURE))
        };

        save_cache_entry(&db, "cache-1", &entry).unwrap();

        let loaded = load_cache_entries(&db).unwrap();
        assert_eq!(loaded.len(), 1);

        let (id, stored) = &loaded[0];
        assert_eq!(id, "cache-1");
        assert_eq!(stored.prompt_hash, "abc123");
        assert_eq!(stored.response, "Hello world");
        assert_eq!(stored.model, "test-model");
        assert_eq!(stored.tokens_saved, 50);
        assert_eq!(stored.hit_count, 3);
        assert_eq!(stored.created_at, "2025-01-01T00:00:00");
        assert_eq!(stored.expires_at.as_deref(), Some(FAR_FUTURE));
        assert!(stored.embedding.is_some());
        assert_eq!(stored.embedding.as_ref().unwrap().len(), 3);
    }

    #[test]
    fn save_without_embedding() {
        let db = test_db();

        let entry = PersistedCacheEntry {
            model: "test-model".into(),
            tokens_saved: 10,
            ..make_entry("def456", "No embedding", "2025-01-01T00:00:00", None)
        };

        save_cache_entry(&db, "cache-2", &entry).unwrap();

        let loaded = load_cache_entries(&db).unwrap();
        assert_eq!(loaded.len(), 1);
        assert!(loaded[0].1.embedding.is_none());
        assert!(loaded[0].1.expires_at.is_none());
    }

    #[test]
    fn evict_expired() {
        let db = test_db();

        let expired = make_entry(
            "expired",
            "old",
            "2020-01-01T00:00:00",
            Some("2020-01-02T00:00:00"),
        );
        let fresh = make_entry("fresh", "new", "2025-01-01T00:00:00", Some(FAR_FUTURE));

        save_cache_entry(&db, "c1", &expired).unwrap();
        save_cache_entry(&db, "c2", &fresh).unwrap();

        let evicted = evict_expired_cache(&db).unwrap();
        assert_eq!(evicted, 1);
        assert_eq!(cache_count(&db).unwrap(), 1);
    }

    #[test]
    fn evict_null_expiry_after_max_age() {
        let db = test_db();

        // No expiry and created long ago: past the NULL-expiry age limit.
        let old_null = make_entry("old_null", "ancient", "2020-01-01T00:00:00", None);
        // No expiry but created "in the future": always inside the age limit.
        let recent_null = make_entry("recent_null", "fresh", "2099-01-01T00:00:00", None);

        save_cache_entry(&db, "c1", &old_null).unwrap();
        save_cache_entry(&db, "c2", &recent_null).unwrap();
        assert_eq!(cache_count(&db).unwrap(), 2);

        let evicted = evict_expired_cache(&db).unwrap();
        assert_eq!(evicted, 1);
        assert_eq!(cache_count(&db).unwrap(), 1);

        let remaining = load_cache_entries(&db).unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].1.prompt_hash, "recent_null");
    }

    #[test]
    fn cache_count_empty() {
        let db = test_db();
        assert_eq!(cache_count(&db).unwrap(), 0);
    }

    #[test]
    fn replace_existing_entry() {
        let db = test_db();

        let first = PersistedCacheEntry {
            tokens_saved: 10,
            hit_count: 1,
            ..make_entry("hash", "first", "2025-01-01T00:00:00", None)
        };
        let second = PersistedCacheEntry {
            tokens_saved: 20,
            hit_count: 5,
            ..make_entry("hash", "second", "2025-01-02T00:00:00", None)
        };

        // Saving twice under the same id must overwrite, not duplicate.
        save_cache_entry(&db, "c1", &first).unwrap();
        save_cache_entry(&db, "c1", &second).unwrap();

        assert_eq!(cache_count(&db).unwrap(), 1);
        let loaded = load_cache_entries(&db).unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].1.response, "second");
    }
}