1use anyhow::{Context, Result};
4use chrono::Duration;
5use rusqlite::{Connection, OptionalExtension, params};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::path::Path;
9
10use crate::expiry::{CacheExpiryWindow, now_rfc3339};
11use crate::stats::CacheStats;
12
/// A key/value cache for serialized API responses backed by a SQLite
/// connection.
///
/// Entries live in a `cache_entries` table together with the time they were
/// cached and the time they expire.
#[derive(Debug)]
pub struct ApiCache {
    /// Connection to the SQLite database holding the `cache_entries` table.
    conn: Connection,
    /// Time-to-live applied by `set` when the caller does not pass one.
    default_ttl: Duration,
    // Recorded by `with_max_size` but not read by any method in this file,
    // hence the lint suppression.
    #[allow(dead_code)]
    max_size_bytes: Option<u64>,
}
21
22impl ApiCache {
23 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
25 let conn = Connection::open(path).context("open cache database")?;
26
27 conn.execute(
28 "CREATE TABLE IF NOT EXISTS cache_entries (
29 key TEXT PRIMARY KEY,
30 data TEXT NOT NULL,
31 cached_at TEXT NOT NULL,
32 expires_at TEXT NOT NULL
33 )",
34 [],
35 )?;
36
37 conn.execute(
38 "CREATE INDEX IF NOT EXISTS idx_expires ON cache_entries(expires_at)",
39 [],
40 )?;
41
42 Ok(Self {
43 conn,
44 default_ttl: Duration::hours(24),
45 max_size_bytes: None,
46 })
47 }
48
49 pub fn open_in_memory() -> Result<Self> {
51 let conn = Connection::open_in_memory().context("open in-memory cache")?;
52
53 conn.execute(
54 "CREATE TABLE cache_entries (
55 key TEXT PRIMARY KEY,
56 data TEXT NOT NULL,
57 cached_at TEXT NOT NULL,
58 expires_at TEXT NOT NULL
59 )",
60 [],
61 )?;
62
63 Ok(Self {
64 conn,
65 default_ttl: Duration::hours(24),
66 max_size_bytes: None,
67 })
68 }
69
70 pub fn with_ttl(mut self, ttl: Duration) -> Self {
72 self.default_ttl = ttl;
73 self
74 }
75
76 pub fn with_max_size(mut self, max_size_bytes: u64) -> Self {
78 self.max_size_bytes = Some(max_size_bytes);
79 self
80 }
81
82 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
84 let now = now_rfc3339();
85
86 let row: Option<String> = self
87 .conn
88 .query_row(
89 "SELECT data FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
90 params![key, now],
91 |row| row.get(0),
92 )
93 .optional()?;
94
95 match row {
96 Some(data) => {
97 let value: T = serde_json::from_str(&data)
98 .with_context(|| format!("deserialize cached value for key: {key}"))?;
99 Ok(Some(value))
100 }
101 None => Ok(None),
102 }
103 }
104
105 pub fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
107 self.set_with_ttl(key, value, self.default_ttl)
108 }
109
110 pub fn set_with_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
112 let window = CacheExpiryWindow::from_now(ttl);
113 let data = serde_json::to_string(value)
114 .with_context(|| format!("serialize value for key: {key}"))?;
115
116 self.conn.execute(
117 "INSERT OR REPLACE INTO cache_entries (key, data, cached_at, expires_at) VALUES (?1, ?2, ?3, ?4)",
118 params![
119 key,
120 data,
121 window.cached_at_rfc3339(),
122 window.expires_at_rfc3339(),
123 ],
124 )?;
125
126 Ok(())
127 }
128
129 pub fn contains(&self, key: &str) -> Result<bool> {
131 let now = now_rfc3339();
132
133 let count: i64 = self.conn.query_row(
134 "SELECT COUNT(*) FROM cache_entries WHERE key = ?1 AND expires_at > ?2",
135 params![key, now],
136 |row| row.get(0),
137 )?;
138
139 Ok(count > 0)
140 }
141
142 pub fn cleanup_expired(&self) -> Result<usize> {
144 let now = now_rfc3339();
145
146 let deleted = self.conn.execute(
147 "DELETE FROM cache_entries WHERE expires_at <= ?1",
148 params![now],
149 )?;
150
151 Ok(deleted)
152 }
153
154 pub fn clear(&self) -> Result<()> {
156 self.conn.execute("DELETE FROM cache_entries", [])?;
157 Ok(())
158 }
159
160 pub fn stats(&self) -> Result<CacheStats> {
162 let now = now_rfc3339();
163
164 let total: i64 = self
165 .conn
166 .query_row("SELECT COUNT(*) FROM cache_entries", [], |row| row.get(0))?;
167
168 let expired: i64 = self.conn.query_row(
169 "SELECT COUNT(*) FROM cache_entries WHERE expires_at <= ?1",
170 params![now],
171 |row| row.get(0),
172 )?;
173
174 let size_bytes: i64 =
175 self.conn
176 .query_row("SELECT SUM(LENGTH(data)) FROM cache_entries", [], |row| {
177 Ok(row.get::<_, Option<i64>>(0).unwrap_or(Some(0)).unwrap_or(0))
178 })?;
179
180 Ok(CacheStats::from_raw_counts(total, expired, size_bytes))
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CacheKey;

    /// Small serializable payload stored as the cached value in these tests.
    #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Clone)]
    struct TestData {
        name: String,
        count: u32,
    }

    /// Builds the fixture value every test stores.
    fn sample() -> TestData {
        TestData {
            name: "test".to_string(),
            count: 42,
        }
    }

    /// Opens a fresh in-memory cache, panicking if that fails.
    fn memory_cache() -> ApiCache {
        ApiCache::open_in_memory().unwrap()
    }

    /// A missing key reads as `None`; after `set` the same value reads back.
    #[test]
    fn cache_basic_operations() {
        let cache = memory_cache();
        let payload = sample();

        let before: Option<TestData> = cache.get("key1").unwrap();
        assert!(before.is_none());

        cache.set("key1", &payload).unwrap();

        let after: Option<TestData> = cache.get("key1").unwrap();
        assert_eq!(after, Some(payload));
    }

    /// An entry stored with a one-second default TTL is readable at first and
    /// gone after waiting past the TTL.
    #[test]
    fn cache_ttl_expiration() {
        let cache = memory_cache().with_ttl(Duration::seconds(1));
        let payload = sample();

        cache.set("key1", &payload).unwrap();

        let fresh: Option<TestData> = cache.get("key1").unwrap();
        assert_eq!(fresh, Some(payload.clone()));

        // Wait slightly longer than the TTL so the entry has expired.
        std::thread::sleep(std::time::Duration::from_millis(1100));

        let stale: Option<TestData> = cache.get("key1").unwrap();
        assert!(stale.is_none());
    }

    /// Two live entries are counted as total and valid, none as expired.
    #[test]
    fn cache_stats() {
        let cache = memory_cache();
        let payload = sample();

        for key in ["key1", "key2"] {
            cache.set(key, &payload).unwrap();
        }

        let stats = cache.stats().unwrap();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.valid_entries, 2);
        assert_eq!(stats.expired_entries, 0);
    }

    /// An entry written with a negative TTL is removed by `cleanup_expired`.
    #[test]
    fn cache_cleanup() {
        let cache = memory_cache();
        let payload = sample();

        cache
            .set_with_ttl("key1", &payload, Duration::seconds(-1))
            .unwrap();

        let deleted = cache.cleanup_expired().unwrap();
        assert_eq!(deleted, 1);

        let stats = cache.stats().unwrap();
        assert_eq!(stats.expired_entries, 0);
    }

    /// `clear` leaves the cache with no entries.
    #[test]
    fn cache_clear() {
        let cache = memory_cache();
        let payload = sample();

        for key in ["key1", "key2"] {
            cache.set(key, &payload).unwrap();
        }

        cache.clear().unwrap();

        let stats = cache.stats().unwrap();
        assert_eq!(stats.total_entries, 0);
    }

    /// `contains` is false before a key is stored and true afterwards.
    #[test]
    fn cache_contains() {
        let cache = memory_cache();
        let payload = sample();

        assert!(!cache.contains("key1").unwrap());

        cache.set("key1", &payload).unwrap();
        assert!(cache.contains("key1").unwrap());
    }

    /// The re-exported `CacheKey` builders produce the expected key strings.
    #[test]
    fn cache_key_reexport_matches_contract() {
        let pr_url = "https://api.github.com/repos/o/r/pulls/1";

        let details = CacheKey::pr_details(pr_url);
        let reviews = CacheKey::pr_reviews(pr_url, 2);
        let notes = CacheKey::mr_notes(12, 34, 1);

        assert_eq!(
            details,
            "pr:details:https://api.github.com/repos/o/r/pulls/1"
        );
        assert_eq!(
            reviews,
            "pr:reviews:https://api.github.com/repos/o/r/pulls/1:page2"
        );
        assert_eq!(notes, "gitlab:mr:notes:project12:mr34:page1");
    }

    /// `CacheStats::from_raw_counts` derives the valid count and whole-MB size
    /// from the raw inputs.
    #[test]
    fn cache_stats_reexport_matches_contract() {
        let stats = CacheStats::from_raw_counts(5, 2, 2 * 1024 * 1024 + 77);

        assert_eq!(stats.total_entries, 5);
        assert_eq!(stats.expired_entries, 2);
        assert_eq!(stats.valid_entries, 3);
        assert_eq!(stats.cache_size_mb, 2);
    }
}