1use anyhow::{Context, Result};
4use chrono::{DateTime, Duration, Utc};
5use rusqlite::{Connection, OpenFlags, OptionalExtension, params};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::path::Path;
9
10use crate::expiry::{CacheExpiryWindow, is_valid, now_rfc3339, parse_rfc3339_utc};
11use crate::stats::CacheStats;
12
13#[derive(Debug)]
30pub struct ApiCache {
31 inner: ApiCacheInner,
32}
33
34#[derive(Debug)]
44struct ApiCacheInner {
45 conn: Connection,
46 default_ttl: Duration,
47 #[allow(dead_code)]
48 max_size_bytes: Option<u64>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq)]
53pub struct CacheInspection {
54 pub stats: CacheStats,
55 pub oldest_cached_at: Option<String>,
56 pub newest_cached_at: Option<String>,
57}
58
/// Outcome of `ApiCache::lookup`, which (unlike `get`) also surfaces expired entries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheLookup<T> {
    /// An entry exists and its expiry time has not passed.
    Fresh(T),
    /// An entry exists but has expired; the value is still handed back so the caller can
    /// decide whether it is usable.
    Stale(T),
    /// Nothing is stored under the key.
    Miss,
}
69
70impl ApiCache {
71 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
73 let conn = Connection::open(path).context("open cache database")?;
74
75 conn.execute(
76 "CREATE TABLE IF NOT EXISTS cache_entries (
77 key TEXT PRIMARY KEY,
78 data TEXT NOT NULL,
79 cached_at TEXT NOT NULL,
80 expires_at TEXT NOT NULL
81 )",
82 [],
83 )?;
84
85 conn.execute(
86 "CREATE INDEX IF NOT EXISTS idx_expires ON cache_entries(expires_at)",
87 [],
88 )?;
89
90 Ok(Self {
91 inner: ApiCacheInner {
92 conn,
93 default_ttl: Duration::hours(24),
94 max_size_bytes: None,
95 },
96 })
97 }
98
99 pub fn open_read_only(path: impl AsRef<Path>) -> Result<Self> {
101 let conn = Connection::open_with_flags(path, OpenFlags::SQLITE_OPEN_READ_ONLY)
102 .context("open cache database read-only")?;
103
104 Ok(Self {
105 inner: ApiCacheInner {
106 conn,
107 default_ttl: Duration::hours(24),
108 max_size_bytes: None,
109 },
110 })
111 }
112
113 pub fn open_in_memory() -> Result<Self> {
115 let conn = Connection::open_in_memory().context("open in-memory cache")?;
116
117 conn.execute(
118 "CREATE TABLE cache_entries (
119 key TEXT PRIMARY KEY,
120 data TEXT NOT NULL,
121 cached_at TEXT NOT NULL,
122 expires_at TEXT NOT NULL
123 )",
124 [],
125 )?;
126
127 Ok(Self {
128 inner: ApiCacheInner {
129 conn,
130 default_ttl: Duration::hours(24),
131 max_size_bytes: None,
132 },
133 })
134 }
135
136 pub fn with_ttl(mut self, ttl: Duration) -> Self {
138 self.inner.default_ttl = ttl;
139 self
140 }
141
142 pub fn with_max_size(mut self, max_size_bytes: u64) -> Self {
144 self.inner.max_size_bytes = Some(max_size_bytes);
145 self
146 }
147
148 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
150 let now = now_rfc3339();
151
152 let row: Option<String> = self
153 .inner
154 .conn
155 .query_row(
156 "SELECT data FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
157 params![key, now],
158 |row| row.get(0),
159 )
160 .optional()?;
161
162 match row {
163 Some(data) => {
164 let value: T = serde_json::from_str(&data)
165 .with_context(|| format!("deserialize cached value for key: {key}"))?;
166 Ok(Some(value))
167 }
168 None => Ok(None),
169 }
170 }
171
172 pub fn lookup<T: DeserializeOwned>(&self, key: &str) -> Result<CacheLookup<T>> {
174 let now = Utc::now();
175
176 let row: Option<(String, String)> = self
177 .inner
178 .conn
179 .query_row(
180 "SELECT data, expires_at FROM cache_entries WHERE key = ?1",
181 params![key],
182 |row| Ok((row.get(0)?, row.get(1)?)),
183 )
184 .optional()?;
185
186 let Some((data, expires_at)) = row else {
187 return Ok(CacheLookup::Miss);
188 };
189
190 let value: T = serde_json::from_str(&data)
191 .with_context(|| format!("deserialize cached value for key: {key}"))?;
192 let expires_at = parse_rfc3339_utc(&expires_at)
193 .with_context(|| format!("parse cached expiry for key: {key}"))?;
194
195 if is_valid(expires_at, now) {
196 Ok(CacheLookup::Fresh(value))
197 } else {
198 Ok(CacheLookup::Stale(value))
199 }
200 }
201
202 pub fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
204 self.set_with_ttl(key, value, self.inner.default_ttl)
205 }
206
207 pub fn set_with_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
209 let window = CacheExpiryWindow::from_now(ttl);
210 let data = serde_json::to_string(value)
211 .with_context(|| format!("serialize value for key: {key}"))?;
212
213 self.inner.conn.execute(
214 "INSERT OR REPLACE INTO cache_entries (key, data, cached_at, expires_at) VALUES (?1, ?2, ?3, ?4)",
215 params![
216 key,
217 data,
218 window.cached_at_rfc3339(),
219 window.expires_at_rfc3339(),
220 ],
221 )?;
222
223 Ok(())
224 }
225
226 pub fn contains(&self, key: &str) -> Result<bool> {
228 let now = now_rfc3339();
229
230 let count: i64 = self.inner.conn.query_row(
231 "SELECT COUNT(*) FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
232 params![key, now],
233 |row| row.get(0),
234 )?;
235
236 Ok(count > 0)
237 }
238
239 pub fn cleanup_expired(&self) -> Result<usize> {
241 let now = now_rfc3339();
242
243 let deleted = self.inner.conn.execute(
244 "DELETE FROM cache_entries WHERE expires_at <= ?1",
245 params![now],
246 )?;
247
248 Ok(deleted)
249 }
250
251 pub fn count_older_than(&self, cutoff: DateTime<Utc>) -> Result<usize> {
253 let cutoff = cutoff.to_rfc3339();
254 let count: i64 = self.inner.conn.query_row(
255 "SELECT COUNT(*) FROM cache_entries WHERE cached_at < ?1",
256 params![cutoff],
257 |row| row.get(0),
258 )?;
259 Ok(count.max(0) as usize)
260 }
261
262 pub fn cleanup_older_than(&self, cutoff: DateTime<Utc>) -> Result<usize> {
264 let cutoff = cutoff.to_rfc3339();
265 let deleted = self.inner.conn.execute(
266 "DELETE FROM cache_entries WHERE cached_at < ?1",
267 params![cutoff],
268 )?;
269 Ok(deleted)
270 }
271
272 pub fn clear(&self) -> Result<()> {
274 self.inner.conn.execute("DELETE FROM cache_entries", [])?;
275 Ok(())
276 }
277
278 pub fn stats(&self) -> Result<CacheStats> {
280 let now = now_rfc3339();
281
282 let total: i64 =
283 self.inner
284 .conn
285 .query_row("SELECT COUNT(*) FROM cache_entries", [], |row| row.get(0))?;
286
287 let expired: i64 = self.inner.conn.query_row(
288 "SELECT COUNT(*) FROM cache_entries WHERE expires_at <= ?1",
289 params![now],
290 |row| row.get(0),
291 )?;
292
293 let size_bytes: i64 = self.inner.conn.query_row(
294 "SELECT SUM(LENGTH(data)) FROM cache_entries",
295 [],
296 |row| Ok(row.get::<_, Option<i64>>(0).unwrap_or(Some(0)).unwrap_or(0)),
297 )?;
298
299 Ok(CacheStats::from_raw_counts(total, expired, size_bytes))
300 }
301
302 pub fn inspect(&self) -> Result<CacheInspection> {
304 let stats = self.stats()?;
305 let oldest_cached_at =
306 self.inner
307 .conn
308 .query_row("SELECT MIN(cached_at) FROM cache_entries", [], |row| {
309 row.get::<_, Option<String>>(0)
310 })?;
311 let newest_cached_at =
312 self.inner
313 .conn
314 .query_row("SELECT MAX(cached_at) FROM cache_entries", [], |row| {
315 row.get::<_, Option<String>>(0)
316 })?;
317 Ok(CacheInspection {
318 stats,
319 oldest_cached_at,
320 newest_cached_at,
321 })
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CacheKey;

    /// Small serde round-trip payload shared by the tests below.
    #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Clone)]
    struct TestData {
        name: String,
        count: u32,
    }

    /// The payload most tests store.
    fn sample() -> TestData {
        TestData {
            name: "test".to_string(),
            count: 42,
        }
    }

    #[test]
    fn cache_basic_operations() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample();

        let result: Option<TestData> = cache.get("key1").unwrap();
        assert!(result.is_none());

        cache.set("key1", &data).unwrap();

        let result: Option<TestData> = cache.get("key1").unwrap();
        assert_eq!(result, Some(data));
    }

    /// The TTL configured via `with_ttl` must govern `set`. Positive and negative TTLs give
    /// a deterministic check, instead of sleeping past a short TTL on the wall clock.
    #[test]
    fn cache_ttl_expiration() {
        let data = sample();

        let long_lived = ApiCache::open_in_memory()
            .unwrap()
            .with_ttl(Duration::hours(1));
        long_lived.set("key1", &data).unwrap();
        let result: Option<TestData> = long_lived.get("key1").unwrap();
        assert_eq!(result, Some(data.clone()));

        let already_expired = ApiCache::open_in_memory()
            .unwrap()
            .with_ttl(Duration::seconds(-1));
        already_expired.set("key1", &data).unwrap();
        let result: Option<TestData> = already_expired.get("key1").unwrap();
        assert!(result.is_none());
        assert_eq!(
            already_expired.lookup::<TestData>("key1").unwrap(),
            CacheLookup::Stale(data)
        );
    }

    #[test]
    fn lookup_distinguishes_fresh_stale_and_miss() -> Result<()> {
        let cache = ApiCache::open_in_memory()?;

        let fresh = TestData {
            name: "fresh".to_string(),
            count: 1,
        };
        let stale = TestData {
            name: "stale".to_string(),
            count: 2,
        };

        cache.set_with_ttl("fresh", &fresh, Duration::seconds(60))?;
        cache.set_with_ttl("stale", &stale, Duration::seconds(-1))?;

        assert_eq!(
            cache.lookup::<TestData>("fresh")?,
            CacheLookup::Fresh(fresh)
        );
        assert_eq!(
            cache.lookup::<TestData>("stale")?,
            CacheLookup::Stale(stale)
        );
        assert_eq!(cache.lookup::<TestData>("missing")?, CacheLookup::Miss);

        let filtered: Option<TestData> = cache.get("stale")?;
        assert!(
            filtered.is_none(),
            "ApiCache::get should continue filtering expired rows"
        );
        Ok(())
    }

    #[test]
    fn cache_stats() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample();

        cache.set("key1", &data).unwrap();
        cache.set("key2", &data).unwrap();

        let stats = cache.stats().unwrap();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.valid_entries, 2);
        assert_eq!(stats.expired_entries, 0);
    }

    #[test]
    fn cache_inspect_reports_timestamp_bounds() {
        let cache = ApiCache::open_in_memory().unwrap();

        cache.set("key1", &"one").unwrap();
        cache.set("key2", &"two").unwrap();

        let inspection = cache.inspect().unwrap();
        assert_eq!(inspection.stats.total_entries, 2);
        let oldest = inspection.oldest_cached_at.expect("oldest bound present");
        let newest = inspection.newest_cached_at.expect("newest bound present");
        assert!(oldest <= newest);
    }

    #[test]
    fn cache_inspect_on_empty_cache_has_no_bounds() {
        let cache = ApiCache::open_in_memory().unwrap();

        let inspection = cache.inspect().unwrap();
        assert_eq!(inspection.stats.total_entries, 0);
        assert!(inspection.oldest_cached_at.is_none());
        assert!(inspection.newest_cached_at.is_none());
    }

    #[test]
    fn cache_cleanup() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample();

        cache
            .set_with_ttl("key1", &data, Duration::seconds(-1))
            .unwrap();
        // A fresh entry must survive the expiry sweep.
        cache.set("key2", &data).unwrap();

        let deleted = cache.cleanup_expired().unwrap();
        assert_eq!(deleted, 1);

        let stats = cache.stats().unwrap();
        assert_eq!(stats.expired_entries, 0);
        assert_eq!(stats.total_entries, 1);
        assert!(cache.contains("key2").unwrap());
    }

    #[test]
    fn cache_clear() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample();

        cache.set("key1", &data).unwrap();
        cache.set("key2", &data).unwrap();

        cache.clear().unwrap();

        let stats = cache.stats().unwrap();
        assert_eq!(stats.total_entries, 0);
    }

    #[test]
    fn cache_cleanup_older_than_removes_matching_entries() {
        let cache = ApiCache::open_in_memory().unwrap();

        cache.set("old1", &"one").unwrap();
        cache.set("old2", &"two").unwrap();

        let cutoff = Utc::now() + Duration::seconds(1);
        assert_eq!(cache.count_older_than(cutoff).unwrap(), 2);
        assert_eq!(cache.cleanup_older_than(cutoff).unwrap(), 2);
        assert!(cache.stats().unwrap().is_empty());
    }

    #[test]
    fn cache_contains() {
        let cache = ApiCache::open_in_memory().unwrap();
        let data = sample();

        assert!(!cache.contains("key1").unwrap());

        cache.set("key1", &data).unwrap();
        assert!(cache.contains("key1").unwrap());

        // Expired rows are not reported as present, mirroring `get`.
        cache
            .set_with_ttl("expired", &data, Duration::seconds(-1))
            .unwrap();
        assert!(!cache.contains("expired").unwrap());
    }

    #[test]
    fn cache_key_reexport_matches_contract() {
        let details = CacheKey::pr_details("https://api.github.com/repos/o/r/pulls/1");
        let reviews = CacheKey::pr_reviews("https://api.github.com/repos/o/r/pulls/1", 2);
        let notes = CacheKey::mr_notes(12, 34, 1);

        assert_eq!(
            details,
            "pr:details:https://api.github.com/repos/o/r/pulls/1"
        );
        assert_eq!(
            reviews,
            "pr:reviews:https://api.github.com/repos/o/r/pulls/1:page2"
        );
        assert_eq!(notes, "gitlab:mr:notes:project12:mr34:page1");
    }

    #[test]
    fn cache_stats_reexport_matches_contract() {
        let stats = CacheStats::from_raw_counts(5, 2, 2 * 1024 * 1024 + 77);
        assert_eq!(stats.total_entries, 5);
        assert_eq!(stats.expired_entries, 2);
        assert_eq!(stats.valid_entries, 3);
        assert_eq!(stats.cache_size_mb, 2);
    }
}