Skip to main content

http_relay/http_relay/
persistence.rs

1//! SQLite-based persistence for relay entries.
2//!
3//! Provides durable storage with automatic LRU eviction when the entry count
4//! exceeds the configured maximum. Supports both file-based and in-memory SQLite.
5
6use anyhow::{Context, Result};
7use rusqlite::Connection;
8use std::path::Path;
9use std::sync::Mutex;
10
11use super::types::StoredEntry;
12use super::unix_timestamp_millis;
13
/// Default value for the repository's `max_entries` limit: the number of
/// stored entries at which inserts begin evicting the oldest ones.
// Nothing in this file references the constant, so the lint is silenced
// here rather than forcing callers to use it.
#[allow(dead_code)]
pub const DEFAULT_MAX_ENTRIES: usize = 10_000;
17
18/// Repository for persisting relay entries to SQLite.
19///
20/// Thread-safe via internal Mutex. Uses WAL mode for better concurrency.
21pub struct EntryRepository {
22    connection: Mutex<Connection>,
23    max_entries: usize,
24}
25
26impl EntryRepository {
27    /// Creates a new repository.
28    ///
29    /// # Arguments
30    /// * `path` - File path for SQLite database, or None for in-memory database
31    /// * `max_entries` - Maximum entries before LRU eviction (oldest by created_at)
32    ///
33    /// # Errors
34    /// Returns error if database connection or schema creation fails.
35    pub fn new(path: Option<&Path>, max_entries: usize) -> Result<Self> {
36        let connection = match path {
37            Some(path) => Connection::open(path).context("Failed to open SQLite database file")?,
38            None => {
39                Connection::open_in_memory().context("Failed to open in-memory SQLite database")?
40            }
41        };
42
43        // Enable WAL mode for better concurrency (only for file-based databases)
44        if path.is_some() {
45            connection
46                .execute_batch("PRAGMA journal_mode=WAL;")
47                .context("Failed to enable WAL mode")?;
48        }
49
50        Self::create_schema(&connection)?;
51
52        Ok(Self {
53            connection: Mutex::new(connection),
54            max_entries,
55        })
56    }
57
58    /// Creates the database schema if it doesn't exist.
59    fn create_schema(connection: &Connection) -> Result<()> {
60        connection
61            .execute_batch(
62                r#"
63                CREATE TABLE IF NOT EXISTS entries (
64                    id TEXT PRIMARY KEY,
65                    message_body BLOB,
66                    content_type TEXT,
67                    acked INTEGER DEFAULT 0,
68                    expires_at INTEGER NOT NULL,
69                    created_at INTEGER NOT NULL
70                );
71                CREATE INDEX IF NOT EXISTS idx_expires_at ON entries(expires_at);
72                CREATE INDEX IF NOT EXISTS idx_created_at ON entries(created_at);
73                "#,
74            )
75            .context("Failed to create database schema")?;
76        Ok(())
77    }
78
79    /// Inserts or replaces an entry in the database.
80    ///
81    /// If the entry count exceeds `max_entries`, the oldest entry (by created_at)
82    /// is deleted first to make room (LRU eviction).
83    ///
84    /// # Arguments
85    /// * `id` - Unique identifier for the entry
86    /// * `body` - Message body bytes
87    /// * `content_type` - Optional content type header
88    /// * `expires_at` - Unix timestamp when entry expires
89    pub fn insert(
90        &self,
91        id: &str,
92        body: &[u8],
93        content_type: Option<&str>,
94        expires_at: i64,
95    ) -> Result<()> {
96        let connection = self.connection.lock().expect("Mutex poisoned");
97        let created_at = unix_timestamp_millis();
98
99        // Check if we need to evict oldest entry (LRU eviction for disk overflow protection)
100        let count: usize = connection
101            .query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))
102            .context("Failed to count entries")?;
103
104        if count >= self.max_entries {
105            // Find the oldest entry's ID before evicting
106            let oldest_id: Option<String> = connection
107                .query_row(
108                    "SELECT id FROM entries ORDER BY created_at ASC LIMIT 1",
109                    [],
110                    |row| row.get(0),
111                )
112                .ok();
113
114            // Only evict if the oldest entry is NOT the one we're about to update.
115            // If we're updating the oldest entry, INSERT OR REPLACE will handle it
116            // without needing eviction (no net change in entry count).
117            if let Some(ref oldest) = oldest_id {
118                if oldest != id {
119                    connection
120                        .execute("DELETE FROM entries WHERE id = ?1", [oldest])
121                        .context("Failed to delete oldest entry for LRU eviction")?;
122                }
123            }
124        }
125
126        // Always reset acked=0: a new message requires a new acknowledgment from the consumer.
127        // Previous ACKs were for previous messages and don't carry over to new content.
128        connection
129            .execute(
130                "INSERT OR REPLACE INTO entries (id, message_body, content_type, acked, expires_at, created_at) VALUES (?1, ?2, ?3, 0, ?4, ?5)",
131                rusqlite::params![id, body, content_type, expires_at, created_at],
132            )
133            .context("Failed to insert entry")?;
134
135        Ok(())
136    }
137
138    /// Retrieves an entry by ID.
139    ///
140    /// Returns None if the entry doesn't exist.
141    pub fn get(&self, id: &str) -> Result<Option<StoredEntry>> {
142        let connection = self.connection.lock().expect("Mutex poisoned");
143
144        let result = connection.query_row(
145            "SELECT message_body, content_type, acked, expires_at FROM entries WHERE id = ?1",
146            [id],
147            |row| {
148                Ok(StoredEntry {
149                    message_body: row.get(0)?,
150                    content_type: row.get(1)?,
151                    acked: row.get::<_, i64>(2)? != 0,
152                    expires_at: row.get(3)?,
153                })
154            },
155        );
156
157        match result {
158            Ok(entry) => Ok(Some(entry)),
159            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
160            Err(error) => Err(error).context("Failed to get entry"),
161        }
162    }
163
164    /// Marks an entry as acknowledged and clears its message body.
165    ///
166    /// Returns true if an entry was updated, false if the entry didn't exist.
167    pub fn ack(&self, id: &str) -> Result<bool> {
168        let connection = self.connection.lock().expect("Mutex poisoned");
169
170        let rows_affected = connection
171            .execute(
172                "UPDATE entries SET acked = 1, message_body = NULL WHERE id = ?1",
173                [id],
174            )
175            .context("Failed to acknowledge entry")?;
176
177        Ok(rows_affected > 0)
178    }
179
180    /// Deletes an entry by ID.
181    ///
182    /// Returns true if an entry was deleted, false if the entry didn't exist.
183    #[allow(dead_code)]
184    pub fn delete(&self, id: &str) -> Result<bool> {
185        let connection = self.connection.lock().expect("Mutex poisoned");
186
187        let rows_affected = connection
188            .execute("DELETE FROM entries WHERE id = ?1", [id])
189            .context("Failed to delete entry")?;
190
191        Ok(rows_affected > 0)
192    }
193
194    /// Deletes all expired entries.
195    ///
196    /// Returns the number of entries deleted.
197    pub fn cleanup_expired(&self) -> Result<usize> {
198        let connection = self.connection.lock().expect("Mutex poisoned");
199        let current_time = unix_timestamp_millis();
200
201        let rows_deleted = connection
202            .execute("DELETE FROM entries WHERE expires_at < ?1", [current_time])
203            .context("Failed to cleanup expired entries")?;
204
205        Ok(rows_deleted)
206    }
207
208    /// Returns the total number of entries in the database.
209    #[allow(dead_code)]
210    pub fn count(&self) -> Result<usize> {
211        let connection = self.connection.lock().expect("Mutex poisoned");
212
213        let count: usize = connection
214            .query_row("SELECT COUNT(*) FROM entries", [], |row| row.get(0))
215            .context("Failed to count entries")?;
216
217        Ok(count)
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// One hour in milliseconds; used to build expiry times safely in the
    /// future or past relative to `unix_timestamp_millis()`.
    const ONE_HOUR_MS: i64 = 3_600_000;

    fn create_test_repository() -> EntryRepository {
        EntryRepository::new(None, 100).expect("Failed to create test repository")
    }

    /// Expiry timestamp one hour from now.
    fn future_expiry() -> i64 {
        unix_timestamp_millis() + ONE_HOUR_MS
    }

    /// Sleeps long enough that the next `unix_timestamp_millis()` differs, so
    /// eviction order by `created_at` does not depend on tie-breaking.
    fn pause_for_distinct_timestamp() {
        std::thread::sleep(std::time::Duration::from_millis(1));
    }

    #[test]
    fn test_insert_and_get() {
        let repository = create_test_repository();
        let expires_at = future_expiry();

        repository
            .insert("test-id", b"test body", Some("text/plain"), expires_at)
            .expect("Failed to insert");

        let entry = repository.get("test-id").expect("Failed to get").unwrap();
        assert_eq!(entry.message_body, Some(b"test body".to_vec()));
        assert_eq!(entry.content_type, Some("text/plain".to_string()));
        assert!(!entry.acked);
        assert_eq!(entry.expires_at, expires_at);
    }

    #[test]
    fn test_get_nonexistent() {
        let repository = create_test_repository();

        let entry = repository.get("nonexistent").expect("Failed to get");
        assert!(entry.is_none());
    }

    #[test]
    fn test_ack() {
        let repository = create_test_repository();

        repository
            .insert("test-id", b"test body", Some("text/plain"), future_expiry())
            .expect("Failed to insert");

        let was_acked = repository.ack("test-id").expect("Failed to ack");
        assert!(was_acked);

        let entry = repository.get("test-id").expect("Failed to get").unwrap();
        assert!(entry.acked);
        assert!(entry.message_body.is_none());
    }

    #[test]
    fn test_ack_nonexistent() {
        let repository = create_test_repository();

        let was_acked = repository.ack("nonexistent").expect("Failed to ack");
        assert!(!was_acked);
    }

    #[test]
    fn test_reinsert_after_ack_resets_ack() {
        let repository = create_test_repository();

        repository
            .insert("test-id", b"first", None, future_expiry())
            .expect("Failed to insert");
        assert!(repository.ack("test-id").expect("Failed to ack"));

        repository
            .insert("test-id", b"second", None, future_expiry())
            .expect("Failed to re-insert");

        let entry = repository.get("test-id").expect("Failed to get").unwrap();
        assert!(!entry.acked, "a new message must require a new ack");
        assert_eq!(entry.message_body, Some(b"second".to_vec()));
    }

    #[test]
    fn test_delete() {
        let repository = create_test_repository();

        repository
            .insert("test-id", b"test body", None, future_expiry())
            .expect("Failed to insert");

        let was_deleted = repository.delete("test-id").expect("Failed to delete");
        assert!(was_deleted);

        let entry = repository.get("test-id").expect("Failed to get");
        assert!(entry.is_none());
    }

    #[test]
    fn test_delete_nonexistent() {
        let repository = create_test_repository();

        let was_deleted = repository.delete("nonexistent").expect("Failed to delete");
        assert!(!was_deleted);
    }

    #[test]
    fn test_cleanup_expired() {
        let repository = create_test_repository();
        let past = unix_timestamp_millis() - ONE_HOUR_MS;
        let future = future_expiry();

        repository
            .insert("expired", b"old", None, past)
            .expect("Failed to insert");
        repository
            .insert("valid", b"new", None, future)
            .expect("Failed to insert");

        let deleted_count = repository.cleanup_expired().expect("Failed to cleanup");
        assert_eq!(deleted_count, 1);

        assert!(repository.get("expired").expect("Failed to get").is_none());
        assert!(repository.get("valid").expect("Failed to get").is_some());
    }

    #[test]
    fn test_count() {
        let repository = create_test_repository();
        let expires_at = future_expiry();

        assert_eq!(repository.count().expect("Failed to count"), 0);

        repository
            .insert("id1", b"body1", None, expires_at)
            .expect("Failed to insert");
        assert_eq!(repository.count().expect("Failed to count"), 1);

        repository
            .insert("id2", b"body2", None, expires_at)
            .expect("Failed to insert");
        assert_eq!(repository.count().expect("Failed to count"), 2);
    }

    #[test]
    fn test_lru_eviction() {
        let repository = EntryRepository::new(None, 3).expect("Failed to create repository");
        let expires_at = future_expiry();

        // Distinct timestamps make "id1 is the oldest" deterministic.
        repository.insert("id1", b"body1", None, expires_at).unwrap();
        pause_for_distinct_timestamp();
        repository.insert("id2", b"body2", None, expires_at).unwrap();
        pause_for_distinct_timestamp();
        repository.insert("id3", b"body3", None, expires_at).unwrap();

        assert_eq!(repository.count().unwrap(), 3);

        // Insert fourth entry, should evict id1 (oldest)
        pause_for_distinct_timestamp();
        repository.insert("id4", b"body4", None, expires_at).unwrap();

        assert_eq!(repository.count().unwrap(), 3);
        assert!(repository.get("id1").unwrap().is_none());
        assert!(repository.get("id2").unwrap().is_some());
        assert!(repository.get("id3").unwrap().is_some());
        assert!(repository.get("id4").unwrap().is_some());
    }

    #[test]
    fn test_insert_or_replace() {
        let repository = create_test_repository();
        let expires_at = future_expiry();

        repository
            .insert("test-id", b"original", Some("text/plain"), expires_at)
            .expect("Failed to insert");

        repository
            .insert("test-id", b"replaced", Some("application/json"), expires_at)
            .expect("Failed to replace");

        let entry = repository.get("test-id").expect("Failed to get").unwrap();
        assert_eq!(entry.message_body, Some(b"replaced".to_vec()));
        assert_eq!(entry.content_type, Some("application/json".to_string()));
        assert_eq!(repository.count().unwrap(), 1);
    }

    #[test]
    fn test_lru_eviction_does_not_evict_updated_entry() {
        // Regression test: updating the oldest entry at capacity should NOT evict it
        let repository = EntryRepository::new(None, 3).expect("Failed to create repository");
        let expires_at = future_expiry();

        // Fill repository to capacity: A is oldest, then B, then C
        repository
            .insert("A", b"original-A", None, expires_at)
            .unwrap();
        pause_for_distinct_timestamp();
        repository.insert("B", b"body-B", None, expires_at).unwrap();
        pause_for_distinct_timestamp();
        repository.insert("C", b"body-C", None, expires_at).unwrap();

        assert_eq!(repository.count().unwrap(), 3);

        // Update entry A (the oldest) with new data - should NOT evict A
        repository
            .insert("A", b"updated-A", None, expires_at)
            .unwrap();

        // All three entries should still exist
        assert_eq!(repository.count().unwrap(), 3);

        let entry_a = repository
            .get("A")
            .expect("Failed to get A")
            .expect("A should exist");
        assert_eq!(entry_a.message_body, Some(b"updated-A".to_vec()));

        assert!(
            repository.get("B").unwrap().is_some(),
            "B should still exist"
        );
        assert!(
            repository.get("C").unwrap().is_some(),
            "C should still exist"
        );
    }
}