1use std::num::NonZeroU64;
9use std::path::Path;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12use std::sync::Mutex;
13
14use anyhow::Result;
15use rusqlite::Connection;
16
17use crate::types::SnapshotId;
18
/// Persistent mapping from string keys to snapshot ids, stored in SQLite.
///
/// Every entry carries an `accessed_at` timestamp (whole seconds since the
/// Unix epoch). When inserting pushes the entry count above `capacity`, the
/// entries with the oldest `accessed_at` are evicted.
#[derive(derive_more::Debug)]
pub struct ImageRegistry {
    /// Open database handle. The `Mutex` serializes all access so the registry
    /// can be driven through `&self`; the handle is omitted from `Debug` output.
    #[debug(skip)]
    conn: Mutex<Connection>,
    /// Maximum number of entries retained before eviction removes the oldest.
    capacity: NonZeroU64,
}
34
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`; a value too large for `i64`
/// saturates to `i64::MAX`.
fn epoch_secs() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let whole_secs = since_epoch.as_secs();
    i64::try_from(whole_secs).unwrap_or(i64::MAX)
}
44
45impl ImageRegistry {
46 pub fn open(path: &Path, capacity: NonZeroU64) -> Result<Self> {
57 if let Some(parent) = path.parent() {
58 std::fs::create_dir_all(parent)?;
59 }
60
61 let conn = Connection::open(path)?;
62
63 conn.execute_batch(
64 "PRAGMA journal_mode = WAL;
65 PRAGMA synchronous = NORMAL;",
66 )?;
67
68 conn.execute_batch(
69 "CREATE TABLE IF NOT EXISTS snapshots (
70 key TEXT PRIMARY KEY,
71 snapshot_id TEXT NOT NULL,
72 accessed_at INTEGER NOT NULL
73 );",
74 )?;
75
76 Ok(Self {
77 conn: Mutex::new(conn),
78 capacity,
79 })
80 }
81
82 #[must_use]
86 pub fn get(&self, key: &str) -> Option<SnapshotId> {
87 let now = epoch_secs();
88 let conn = self.conn.lock().ok()?;
89
90 let snapshot: Option<String> = conn
91 .query_row(
92 "SELECT snapshot_id FROM snapshots WHERE key = ?1",
93 [key],
94 |row| row.get(0),
95 )
96 .ok();
97
98 if snapshot.is_some() {
99 let _ = conn.execute(
100 "UPDATE snapshots SET accessed_at = ?1 WHERE key = ?2",
101 rusqlite::params![now, key],
102 );
103 }
104
105 drop(conn);
106 snapshot.map(SnapshotId::new)
107 }
108
109 pub fn put(&self, key: &str, snapshot: &SnapshotId) -> Vec<SnapshotId> {
115 let now = epoch_secs();
116
117 let Ok(conn) = self.conn.lock() else {
118 return Vec::new();
119 };
120
121 let snapshot_id: &str = snapshot.as_ref();
123 let _result = conn.execute(
124 "INSERT OR REPLACE INTO snapshots (key, snapshot_id, accessed_at)
125 VALUES (?1, ?2, ?3)",
126 rusqlite::params![key, snapshot_id, now],
127 );
128
129 drop(conn);
130 self.evict_overflow()
131 }
132
133 #[must_use]
138 pub fn invalidate(&self, key: &str) -> Option<SnapshotId> {
139 let conn = self.conn.lock().ok()?;
140
141 let snapshot: Option<String> = conn
142 .query_row(
143 "SELECT snapshot_id FROM snapshots WHERE key = ?1",
144 [key],
145 |row| row.get(0),
146 )
147 .ok();
148
149 if snapshot.is_some() {
150 let _ = conn.execute("DELETE FROM snapshots WHERE key = ?1", [key]);
151 }
152
153 drop(conn);
154 snapshot.map(SnapshotId::new)
155 }
156
157 #[must_use]
163 pub fn all_snapshot_ids(&self) -> Vec<SnapshotId> {
164 let Ok(conn) = self.conn.lock() else {
165 return Vec::new();
166 };
167 let Ok(mut stmt) = conn.prepare("SELECT snapshot_id FROM snapshots") else {
168 return Vec::new();
169 };
170 stmt.query_map([], |row| row.get::<_, String>(0).map(SnapshotId::new))
171 .ok()
172 .map(|rows| rows.filter_map(Result::ok).collect())
173 .unwrap_or_default()
174 }
175
176 #[must_use]
178 pub fn len(&self) -> u64 {
179 let Ok(conn) = self.conn.lock() else {
180 return 0;
181 };
182 conn.query_row("SELECT COUNT(*) FROM snapshots", [], |row| {
183 row.get::<_, i64>(0)
184 })
185 .unwrap_or(0)
186 .try_into()
187 .unwrap_or(0)
188 }
189
190 #[must_use]
192 pub fn is_empty(&self) -> bool {
193 self.len() == 0
194 }
195
196 fn evict_overflow(&self) -> Vec<SnapshotId> {
199 let count = self.len();
200 let capacity = self.capacity.get();
201 if count <= capacity {
202 return Vec::new();
203 }
204
205 let overflow = count - capacity;
206
207 let Ok(conn) = self.conn.lock() else {
208 return Vec::new();
209 };
210
211 let Ok(mut stmt) =
212 conn.prepare("SELECT snapshot_id FROM snapshots ORDER BY accessed_at ASC LIMIT ?1")
213 else {
214 return Vec::new();
215 };
216
217 let evicted: Vec<SnapshotId> = stmt
218 .query_map([overflow], |row| {
219 row.get::<_, String>(0).map(SnapshotId::new)
220 })
221 .ok()
222 .map(|rows| rows.filter_map(Result::ok).collect())
223 .unwrap_or_default();
224
225 drop(stmt);
227
228 let _deleted = conn.execute(
230 "DELETE FROM snapshots WHERE key IN (
231 SELECT key FROM snapshots ORDER BY accessed_at ASC LIMIT ?1
232 )",
233 [overflow],
234 );
235
236 evicted
237 }
238}
239
#[cfg(test)]
// Tests may panic freely on setup failures; that is the failure signal.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Opens a registry in a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the registry is
    /// used; dropping it deletes the directory and the database inside it.
    fn open_temp(capacity: u64) -> (ImageRegistry, tempfile::TempDir) {
        let dir = tempfile::tempdir().expect("failed to create temp dir");
        let db_path = dir.path().join("registry.db");
        let capacity = NonZeroU64::new(capacity).expect("capacity must be non-zero");
        let registry = ImageRegistry::open(&db_path, capacity).expect("failed to open registry");
        (registry, dir)
    }

    #[test]
    fn get_returns_none_for_unknown_key() {
        let (reg, _dir) = open_temp(10);
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn put_then_get_returns_snapshot() {
        let (reg, _dir) = open_temp(10);
        let snap = SnapshotId::new("snap-abc");
        let evicted = reg.put("my-key", &snap);
        assert!(evicted.is_empty());

        let got = reg.get("my-key");
        assert_eq!(got, Some(SnapshotId::new("snap-abc")));
    }

    #[test]
    fn get_updates_access_time() {
        let (reg, _dir) = open_temp(2);

        // Access times have one-second resolution, so sleep between steps to
        // give every entry a distinct `accessed_at`.
        reg.put("a", &SnapshotId::new("snap-a"));

        std::thread::sleep(std::time::Duration::from_secs(1));

        reg.put("b", &SnapshotId::new("snap-b"));

        std::thread::sleep(std::time::Duration::from_secs(1));
        // Touching "a" makes "b" the least recently used entry.
        let _ = reg.get("a");

        std::thread::sleep(std::time::Duration::from_secs(1));
        let evicted = reg.put("c", &SnapshotId::new("snap-c"));

        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0], SnapshotId::new("snap-b"));

        assert!(reg.get("a").is_some());
        assert!(reg.get("b").is_none());
    }

    #[test]
    fn eviction_returns_overflow_entries() {
        let (reg, _dir) = open_temp(2);

        reg.put("x", &SnapshotId::new("snap-x"));
        std::thread::sleep(std::time::Duration::from_secs(1));
        reg.put("y", &SnapshotId::new("snap-y"));
        std::thread::sleep(std::time::Duration::from_secs(1));

        let evicted = reg.put("z", &SnapshotId::new("snap-z"));

        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0], SnapshotId::new("snap-x"));
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn evicted_snapshots_are_no_longer_registered() {
        // No sleeps: both entries may share a timestamp. Whatever is reported
        // as evicted must be exactly what left the registry.
        let (reg, _dir) = open_temp(1);

        reg.put("first", &SnapshotId::new("snap-1"));
        let evicted = reg.put("second", &SnapshotId::new("snap-2"));

        assert_eq!(evicted.len(), 1);
        assert_eq!(reg.len(), 1);
        let remaining = reg.all_snapshot_ids();
        for id in &evicted {
            assert!(!remaining.contains(id));
        }
    }

    #[test]
    fn put_replaces_existing_key() {
        let (reg, _dir) = open_temp(10);

        reg.put("k", &SnapshotId::new("old"));
        let evicted = reg.put("k", &SnapshotId::new("new"));

        assert!(evicted.is_empty());
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.get("k"), Some(SnapshotId::new("new")));
    }

    #[test]
    fn survives_reopen() {
        let dir = tempfile::tempdir().expect("failed to create temp dir");
        let db_path = dir.path().join("registry.db");

        let capacity = NonZeroU64::new(10).expect("capacity must be non-zero");

        {
            let reg = ImageRegistry::open(&db_path, capacity).expect("open");
            reg.put("persistent", &SnapshotId::new("snap-persist"));
            assert_eq!(reg.len(), 1);
        }

        let reg2 = ImageRegistry::open(&db_path, capacity).expect("reopen");
        assert_eq!(reg2.len(), 1);
        let got = reg2.get("persistent");
        assert_eq!(got, Some(SnapshotId::new("snap-persist")));
    }

    #[test]
    fn all_snapshot_ids_returns_every_entry() {
        let (reg, _dir) = open_temp(10);
        assert!(reg.all_snapshot_ids().is_empty());

        reg.put("k1", &SnapshotId::new("forever-a"));
        reg.put("k2", &SnapshotId::new("forever-b"));

        let mut ids: Vec<String> = reg
            .all_snapshot_ids()
            .into_iter()
            .map(|s| s.to_string())
            .collect();
        ids.sort();
        assert_eq!(ids, vec!["forever-a".to_string(), "forever-b".to_string()]);
    }

    #[test]
    fn invalidate_returns_removed_snapshot() {
        let (reg, _dir) = open_temp(10);
        let snap = SnapshotId::new("snap-rm");
        reg.put("to-remove", &snap);

        let removed = reg.invalidate("to-remove");
        assert_eq!(removed, Some(SnapshotId::new("snap-rm")));
        assert!(reg.get("to-remove").is_none());
        assert_eq!(reg.len(), 0);

        // A second invalidation of the same key finds nothing.
        let removed2 = reg.invalidate("to-remove");
        assert!(removed2.is_none());
    }
}