1use anyhow::{Context, Result};
4use chrono::{DateTime, Duration, Utc};
5use rusqlite::{Connection, OpenFlags, OptionalExtension, params};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::path::Path;
9
10use crate::expiry::{CacheExpiryWindow, now_rfc3339};
11use crate::stats::CacheStats;
12
13#[derive(Debug)]
15pub struct ApiCache {
16 conn: Connection,
17 default_ttl: Duration,
18 #[allow(dead_code)]
19 max_size_bytes: Option<u64>,
20}
21
22#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct CacheInspection {
25 pub stats: CacheStats,
26 pub oldest_cached_at: Option<String>,
27 pub newest_cached_at: Option<String>,
28}
29
30impl ApiCache {
31 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
33 let conn = Connection::open(path).context("open cache database")?;
34
35 conn.execute(
36 "CREATE TABLE IF NOT EXISTS cache_entries (
37 key TEXT PRIMARY KEY,
38 data TEXT NOT NULL,
39 cached_at TEXT NOT NULL,
40 expires_at TEXT NOT NULL
41 )",
42 [],
43 )?;
44
45 conn.execute(
46 "CREATE INDEX IF NOT EXISTS idx_expires ON cache_entries(expires_at)",
47 [],
48 )?;
49
50 Ok(Self {
51 conn,
52 default_ttl: Duration::hours(24),
53 max_size_bytes: None,
54 })
55 }
56
57 pub fn open_read_only(path: impl AsRef<Path>) -> Result<Self> {
59 let conn = Connection::open_with_flags(path, OpenFlags::SQLITE_OPEN_READ_ONLY)
60 .context("open cache database read-only")?;
61
62 Ok(Self {
63 conn,
64 default_ttl: Duration::hours(24),
65 max_size_bytes: None,
66 })
67 }
68
69 pub fn open_in_memory() -> Result<Self> {
71 let conn = Connection::open_in_memory().context("open in-memory cache")?;
72
73 conn.execute(
74 "CREATE TABLE cache_entries (
75 key TEXT PRIMARY KEY,
76 data TEXT NOT NULL,
77 cached_at TEXT NOT NULL,
78 expires_at TEXT NOT NULL
79 )",
80 [],
81 )?;
82
83 Ok(Self {
84 conn,
85 default_ttl: Duration::hours(24),
86 max_size_bytes: None,
87 })
88 }
89
90 pub fn with_ttl(mut self, ttl: Duration) -> Self {
92 self.default_ttl = ttl;
93 self
94 }
95
96 pub fn with_max_size(mut self, max_size_bytes: u64) -> Self {
98 self.max_size_bytes = Some(max_size_bytes);
99 self
100 }
101
102 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
104 let now = now_rfc3339();
105
106 let row: Option<String> = self
107 .conn
108 .query_row(
109 "SELECT data FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
110 params![key, now],
111 |row| row.get(0),
112 )
113 .optional()?;
114
115 match row {
116 Some(data) => {
117 let value: T = serde_json::from_str(&data)
118 .with_context(|| format!("deserialize cached value for key: {key}"))?;
119 Ok(Some(value))
120 }
121 None => Ok(None),
122 }
123 }
124
125 pub fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
127 self.set_with_ttl(key, value, self.default_ttl)
128 }
129
130 pub fn set_with_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
132 let window = CacheExpiryWindow::from_now(ttl);
133 let data = serde_json::to_string(value)
134 .with_context(|| format!("serialize value for key: {key}"))?;
135
136 self.conn.execute(
137 "INSERT OR REPLACE INTO cache_entries (key, data, cached_at, expires_at) VALUES (?1, ?2, ?3, ?4)",
138 params![
139 key,
140 data,
141 window.cached_at_rfc3339(),
142 window.expires_at_rfc3339(),
143 ],
144 )?;
145
146 Ok(())
147 }
148
149 pub fn contains(&self, key: &str) -> Result<bool> {
151 let now = now_rfc3339();
152
153 let count: i64 = self.conn.query_row(
154 "SELECT COUNT(*) FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
155 params![key, now],
156 |row| row.get(0),
157 )?;
158
159 Ok(count > 0)
160 }
161
162 pub fn cleanup_expired(&self) -> Result<usize> {
164 let now = now_rfc3339();
165
166 let deleted = self.conn.execute(
167 "DELETE FROM cache_entries WHERE expires_at <= ?1",
168 params![now],
169 )?;
170
171 Ok(deleted)
172 }
173
174 pub fn count_older_than(&self, cutoff: DateTime<Utc>) -> Result<usize> {
176 let cutoff = cutoff.to_rfc3339();
177 let count: i64 = self.conn.query_row(
178 "SELECT COUNT(*) FROM cache_entries WHERE cached_at < ?1",
179 params![cutoff],
180 |row| row.get(0),
181 )?;
182 Ok(count.max(0) as usize)
183 }
184
185 pub fn cleanup_older_than(&self, cutoff: DateTime<Utc>) -> Result<usize> {
187 let cutoff = cutoff.to_rfc3339();
188 let deleted = self.conn.execute(
189 "DELETE FROM cache_entries WHERE cached_at < ?1",
190 params![cutoff],
191 )?;
192 Ok(deleted)
193 }
194
195 pub fn clear(&self) -> Result<()> {
197 self.conn.execute("DELETE FROM cache_entries", [])?;
198 Ok(())
199 }
200
201 pub fn stats(&self) -> Result<CacheStats> {
203 let now = now_rfc3339();
204
205 let total: i64 = self
206 .conn
207 .query_row("SELECT COUNT(*) FROM cache_entries", [], |row| row.get(0))?;
208
209 let expired: i64 = self.conn.query_row(
210 "SELECT COUNT(*) FROM cache_entries WHERE expires_at <= ?1",
211 params![now],
212 |row| row.get(0),
213 )?;
214
215 let size_bytes: i64 =
216 self.conn
217 .query_row("SELECT SUM(LENGTH(data)) FROM cache_entries", [], |row| {
218 Ok(row.get::<_, Option<i64>>(0).unwrap_or(Some(0)).unwrap_or(0))
219 })?;
220
221 Ok(CacheStats::from_raw_counts(total, expired, size_bytes))
222 }
223
224 pub fn inspect(&self) -> Result<CacheInspection> {
226 let stats = self.stats()?;
227 let oldest_cached_at =
228 self.conn
229 .query_row("SELECT MIN(cached_at) FROM cache_entries", [], |row| {
230 row.get::<_, Option<String>>(0)
231 })?;
232 let newest_cached_at =
233 self.conn
234 .query_row("SELECT MAX(cached_at) FROM cache_entries", [], |row| {
235 row.get::<_, Option<String>>(0)
236 })?;
237 Ok(CacheInspection {
238 stats,
239 oldest_cached_at,
240 newest_cached_at,
241 })
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CacheKey;

    #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Clone)]
    struct TestData {
        name: String,
        count: u32,
    }

    /// Fresh in-memory cache using the default TTL.
    fn memory_cache() -> ApiCache {
        ApiCache::open_in_memory().unwrap()
    }

    /// Payload shared by the tests that store structured data.
    fn sample() -> TestData {
        TestData {
            name: String::from("test"),
            count: 42,
        }
    }

    #[test]
    fn cache_basic_operations() {
        let cache = memory_cache();

        let before: Option<TestData> = cache.get("key1").unwrap();
        assert!(before.is_none());

        cache.set("key1", &sample()).unwrap();

        let after: Option<TestData> = cache.get("key1").unwrap();
        assert_eq!(after, Some(sample()));
    }

    #[test]
    fn cache_ttl_expiration() {
        let cache = memory_cache().with_ttl(Duration::seconds(1));
        let payload = sample();

        cache.set("key1", &payload).unwrap();
        assert_eq!(cache.get::<TestData>("key1").unwrap(), Some(payload));

        // Sleep past the one-second TTL so the entry is no longer readable.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        assert!(cache.get::<TestData>("key1").unwrap().is_none());
    }

    #[test]
    fn cache_stats() {
        let cache = memory_cache();
        let payload = sample();
        for key in ["key1", "key2"] {
            cache.set(key, &payload).unwrap();
        }

        let stats = cache.stats().unwrap();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.valid_entries, 2);
        assert_eq!(stats.expired_entries, 0);
    }

    #[test]
    fn cache_inspect_reports_timestamp_bounds() {
        let cache = memory_cache();
        cache.set("key1", &"one").unwrap();
        cache.set("key2", &"two").unwrap();

        let CacheInspection {
            stats,
            oldest_cached_at,
            newest_cached_at,
        } = cache.inspect().unwrap();
        assert_eq!(stats.total_entries, 2);
        assert!(oldest_cached_at.is_some());
        assert!(newest_cached_at.is_some());
    }

    #[test]
    fn cache_cleanup() {
        let cache = memory_cache();
        cache
            .set_with_ttl("key1", &sample(), Duration::seconds(-1))
            .unwrap();

        assert_eq!(cache.cleanup_expired().unwrap(), 1);
        assert_eq!(cache.stats().unwrap().expired_entries, 0);
    }

    #[test]
    fn cache_clear() {
        let cache = memory_cache();
        let payload = sample();
        cache.set("key1", &payload).unwrap();
        cache.set("key2", &payload).unwrap();

        cache.clear().unwrap();
        assert_eq!(cache.stats().unwrap().total_entries, 0);
    }

    #[test]
    fn cache_cleanup_older_than_removes_matching_entries() {
        let cache = memory_cache();
        cache.set("old1", &"one").unwrap();
        cache.set("old2", &"two").unwrap();

        // A cutoff slightly in the future makes both fresh entries "older".
        let cutoff = Utc::now() + Duration::seconds(1);
        assert_eq!(cache.count_older_than(cutoff).unwrap(), 2);
        assert_eq!(cache.cleanup_older_than(cutoff).unwrap(), 2);
        assert!(cache.stats().unwrap().is_empty());
    }

    #[test]
    fn cache_contains() {
        let cache = memory_cache();
        assert!(!cache.contains("key1").unwrap());

        cache.set("key1", &sample()).unwrap();
        assert!(cache.contains("key1").unwrap());
    }

    #[test]
    fn cache_key_reexport_matches_contract() {
        let url = "https://api.github.com/repos/o/r/pulls/1";

        assert_eq!(
            CacheKey::pr_details(url),
            "pr:details:https://api.github.com/repos/o/r/pulls/1"
        );
        assert_eq!(
            CacheKey::pr_reviews(url, 2),
            "pr:reviews:https://api.github.com/repos/o/r/pulls/1:page2"
        );
        assert_eq!(
            CacheKey::mr_notes(12, 34, 1),
            "gitlab:mr:notes:project12:mr34:page1"
        );
    }

    #[test]
    fn cache_stats_reexport_matches_contract() {
        // Two full MiB plus a remainder should round down to 2 MB.
        let stats = CacheStats::from_raw_counts(5, 2, 2 * 1024 * 1024 + 77);

        assert_eq!(stats.total_entries, 5);
        assert_eq!(stats.expired_entries, 2);
        assert_eq!(stats.valid_entries, 3);
        assert_eq!(stats.cache_size_mb, 2);
    }
}