1use anyhow::{Context, Result};
7use rusqlite::Connection;
8use std::path::Path;
9use std::sync::Mutex;
10
11use super::types::StoredEntry;
12use super::unix_timestamp_millis;
13
/// Suggested capacity (`max_entries`) for an entry repository.
///
/// NOTE(review): presumably the value callers pass to `EntryRepository::new` when they have
/// no specific limit — confirm against call sites.
// Nothing in this module reads the constant, so `dead_code` is allowed to keep it available
// for external callers without a warning.
#[allow(dead_code)]
pub const DEFAULT_MAX_ENTRIES: usize = 10_000;
17
18pub struct EntryRepository {
22 connection: Mutex<Connection>,
23 max_entries: usize,
24}
25
26impl EntryRepository {
27 pub fn new(path: Option<&Path>, max_entries: usize) -> Result<Self> {
36 let connection = match path {
37 Some(path) => Connection::open(path).context("Failed to open SQLite database file")?,
38 None => {
39 Connection::open_in_memory().context("Failed to open in-memory SQLite database")?
40 }
41 };
42
43 if path.is_some() {
45 connection
46 .execute_batch("PRAGMA journal_mode=WAL;")
47 .context("Failed to enable WAL mode")?;
48 }
49
50 Self::create_schema(&connection)?;
51
52 Ok(Self {
53 connection: Mutex::new(connection),
54 max_entries,
55 })
56 }
57
58 fn create_schema(connection: &Connection) -> Result<()> {
60 connection
61 .execute_batch(
62 r#"
63 CREATE TABLE IF NOT EXISTS entries (
64 id TEXT PRIMARY KEY,
65 message_body BLOB,
66 content_type TEXT,
67 acked INTEGER DEFAULT 0,
68 expires_at INTEGER NOT NULL,
69 created_at INTEGER NOT NULL
70 );
71 CREATE INDEX IF NOT EXISTS idx_expires_at ON entries(expires_at);
72 CREATE INDEX IF NOT EXISTS idx_created_at ON entries(created_at);
73 "#,
74 )
75 .context("Failed to create database schema")?;
76 Ok(())
77 }
78
79 pub fn insert(
90 &self,
91 id: &str,
92 body: &[u8],
93 content_type: Option<&str>,
94 expires_at: i64,
95 ) -> Result<()> {
96 let connection = self.connection.lock().expect("Mutex poisoned");
97 let created_at = unix_timestamp_millis();
98
99 let count: usize = connection
101 .query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))
102 .context("Failed to count entries")?;
103
104 if count >= self.max_entries {
105 let oldest_id: Option<String> = connection
107 .query_row(
108 "SELECT id FROM entries ORDER BY created_at ASC LIMIT 1",
109 [],
110 |row| row.get(0),
111 )
112 .ok();
113
114 if let Some(ref oldest) = oldest_id {
118 if oldest != id {
119 connection
120 .execute("DELETE FROM entries WHERE id = ?1", [oldest])
121 .context("Failed to delete oldest entry for LRU eviction")?;
122 }
123 }
124 }
125
126 connection
129 .execute(
130 "INSERT OR REPLACE INTO entries (id, message_body, content_type, acked, expires_at, created_at) VALUES (?1, ?2, ?3, 0, ?4, ?5)",
131 rusqlite::params![id, body, content_type, expires_at, created_at],
132 )
133 .context("Failed to insert entry")?;
134
135 Ok(())
136 }
137
138 pub fn get(&self, id: &str) -> Result<Option<StoredEntry>> {
142 let connection = self.connection.lock().expect("Mutex poisoned");
143
144 let result = connection.query_row(
145 "SELECT message_body, content_type, acked, expires_at FROM entries WHERE id = ?1",
146 [id],
147 |row| {
148 Ok(StoredEntry {
149 message_body: row.get(0)?,
150 content_type: row.get(1)?,
151 acked: row.get::<_, i64>(2)? != 0,
152 expires_at: row.get(3)?,
153 })
154 },
155 );
156
157 match result {
158 Ok(entry) => Ok(Some(entry)),
159 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
160 Err(error) => Err(error).context("Failed to get entry"),
161 }
162 }
163
164 pub fn ack(&self, id: &str) -> Result<bool> {
168 let connection = self.connection.lock().expect("Mutex poisoned");
169
170 let rows_affected = connection
171 .execute(
172 "UPDATE entries SET acked = 1, message_body = NULL WHERE id = ?1",
173 [id],
174 )
175 .context("Failed to acknowledge entry")?;
176
177 Ok(rows_affected > 0)
178 }
179
180 #[allow(dead_code)]
184 pub fn delete(&self, id: &str) -> Result<bool> {
185 let connection = self.connection.lock().expect("Mutex poisoned");
186
187 let rows_affected = connection
188 .execute("DELETE FROM entries WHERE id = ?1", [id])
189 .context("Failed to delete entry")?;
190
191 Ok(rows_affected > 0)
192 }
193
194 pub fn cleanup_expired(&self) -> Result<usize> {
198 let connection = self.connection.lock().expect("Mutex poisoned");
199 let current_time = unix_timestamp_millis();
200
201 let rows_deleted = connection
202 .execute("DELETE FROM entries WHERE expires_at < ?1", [current_time])
203 .context("Failed to cleanup expired entries")?;
204
205 Ok(rows_deleted)
206 }
207
208 #[allow(dead_code)]
210 pub fn count(&self) -> Result<usize> {
211 let connection = self.connection.lock().expect("Mutex poisoned");
212
213 let count: usize = connection
214 .query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))
215 .context("Failed to count entries")?;
216
217 Ok(count)
218 }
219}
220
221#[cfg(test)]
222mod tests {
223 use super::*;
224
225 fn create_test_repository() -> EntryRepository {
226 EntryRepository::new(None, 100).expect("Failed to create test repository")
227 }
228
229 #[test]
230 fn test_insert_and_get() {
231 let repository = create_test_repository();
232 let expires_at = unix_timestamp_millis() + 3_600_000;
233
234 repository
235 .insert("test-id", b"test body", Some("text/plain"), expires_at)
236 .expect("Failed to insert");
237
238 let entry = repository.get("test-id").expect("Failed to get").unwrap();
239 assert_eq!(entry.message_body, Some(b"test body".to_vec()));
240 assert_eq!(entry.content_type, Some("text/plain".to_string()));
241 assert!(!entry.acked);
242 assert_eq!(entry.expires_at, expires_at);
243 }
244
245 #[test]
246 fn test_get_nonexistent() {
247 let repository = create_test_repository();
248
249 let entry = repository.get("nonexistent").expect("Failed to get");
250 assert!(entry.is_none());
251 }
252
253 #[test]
254 fn test_ack() {
255 let repository = create_test_repository();
256 let expires_at = unix_timestamp_millis() + 3_600_000;
257
258 repository
259 .insert("test-id", b"test body", Some("text/plain"), expires_at)
260 .expect("Failed to insert");
261
262 let was_acked = repository.ack("test-id").expect("Failed to ack");
263 assert!(was_acked);
264
265 let entry = repository.get("test-id").expect("Failed to get").unwrap();
266 assert!(entry.acked);
267 assert!(entry.message_body.is_none());
268 }
269
270 #[test]
271 fn test_ack_nonexistent() {
272 let repository = create_test_repository();
273
274 let was_acked = repository.ack("nonexistent").expect("Failed to ack");
275 assert!(!was_acked);
276 }
277
278 #[test]
279 fn test_delete() {
280 let repository = create_test_repository();
281 let expires_at = unix_timestamp_millis() + 3_600_000;
282
283 repository
284 .insert("test-id", b"test body", None, expires_at)
285 .expect("Failed to insert");
286
287 let was_deleted = repository.delete("test-id").expect("Failed to delete");
288 assert!(was_deleted);
289
290 let entry = repository.get("test-id").expect("Failed to get");
291 assert!(entry.is_none());
292 }
293
294 #[test]
295 fn test_delete_nonexistent() {
296 let repository = create_test_repository();
297
298 let was_deleted = repository.delete("nonexistent").expect("Failed to delete");
299 assert!(!was_deleted);
300 }
301
302 #[test]
303 fn test_cleanup_expired() {
304 let repository = create_test_repository();
305 let past = unix_timestamp_millis() - 3_600_000;
306 let future = unix_timestamp_millis() + 3_600_000;
307
308 repository
309 .insert("expired", b"old", None, past)
310 .expect("Failed to insert");
311 repository
312 .insert("valid", b"new", None, future)
313 .expect("Failed to insert");
314
315 let deleted_count = repository.cleanup_expired().expect("Failed to cleanup");
316 assert_eq!(deleted_count, 1);
317
318 assert!(repository.get("expired").expect("Failed to get").is_none());
319 assert!(repository.get("valid").expect("Failed to get").is_some());
320 }
321
322 #[test]
323 fn test_count() {
324 let repository = create_test_repository();
325 let expires_at = unix_timestamp_millis() + 3_600_000;
326
327 assert_eq!(repository.count().expect("Failed to count"), 0);
328
329 repository
330 .insert("id1", b"body1", None, expires_at)
331 .expect("Failed to insert");
332 assert_eq!(repository.count().expect("Failed to count"), 1);
333
334 repository
335 .insert("id2", b"body2", None, expires_at)
336 .expect("Failed to insert");
337 assert_eq!(repository.count().expect("Failed to count"), 2);
338 }
339
340 #[test]
341 fn test_lru_eviction() {
342 let repository = EntryRepository::new(None, 3).expect("Failed to create repository");
343 let expires_at = unix_timestamp_millis() + 3_600_000;
344
345 repository
346 .insert("id1", b"body1", None, expires_at)
347 .unwrap();
348 repository
349 .insert("id2", b"body2", None, expires_at)
350 .unwrap();
351 repository
352 .insert("id3", b"body3", None, expires_at)
353 .unwrap();
354
355 assert_eq!(repository.count().unwrap(), 3);
356
357 repository
359 .insert("id4", b"body4", None, expires_at)
360 .unwrap();
361
362 assert_eq!(repository.count().unwrap(), 3);
363 assert!(repository.get("id1").unwrap().is_none());
364 assert!(repository.get("id2").unwrap().is_some());
365 assert!(repository.get("id3").unwrap().is_some());
366 assert!(repository.get("id4").unwrap().is_some());
367 }
368
369 #[test]
370 fn test_insert_or_replace() {
371 let repository = create_test_repository();
372 let expires_at = unix_timestamp_millis() + 3_600_000;
373
374 repository
375 .insert("test-id", b"original", Some("text/plain"), expires_at)
376 .expect("Failed to insert");
377
378 repository
379 .insert("test-id", b"replaced", Some("application/json"), expires_at)
380 .expect("Failed to replace");
381
382 let entry = repository.get("test-id").expect("Failed to get").unwrap();
383 assert_eq!(entry.message_body, Some(b"replaced".to_vec()));
384 assert_eq!(entry.content_type, Some("application/json".to_string()));
385 assert_eq!(repository.count().unwrap(), 1);
386 }
387
388 #[test]
389 fn test_lru_eviction_does_not_evict_updated_entry() {
390 let repository = EntryRepository::new(None, 3).expect("Failed to create repository");
392 let expires_at = unix_timestamp_millis() + 3_600_000;
393
394 repository
396 .insert("A", b"original-A", None, expires_at)
397 .unwrap();
398 std::thread::sleep(std::time::Duration::from_millis(1)); repository.insert("B", b"body-B", None, expires_at).unwrap();
400 std::thread::sleep(std::time::Duration::from_millis(1));
401 repository.insert("C", b"body-C", None, expires_at).unwrap();
402
403 assert_eq!(repository.count().unwrap(), 3);
404
405 repository
407 .insert("A", b"updated-A", None, expires_at)
408 .unwrap();
409
410 assert_eq!(repository.count().unwrap(), 3);
412
413 let entry_a = repository
414 .get("A")
415 .expect("Failed to get A")
416 .expect("A should exist");
417 assert_eq!(entry_a.message_body, Some(b"updated-A".to_vec()));
418
419 assert!(
420 repository.get("B").unwrap().is_some(),
421 "B should still exist"
422 );
423 assert!(
424 repository.get("C").unwrap().is_some(),
425 "C should still exist"
426 );
427 }
428}