Skip to main content

actr_runtime_mailbox/
sqlite.rs

1//! SQLite storage backend implementation
2
3use crate::{
4    error::StorageResult,
5    mailbox::{
6        Mailbox, MailboxDepthObserver, MailboxStats, MessagePriority, MessageRecord, MessageStatus,
7    },
8};
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use rusqlite::{Connection, params};
12use std::{
13    collections::HashMap,
14    path::{Path, PathBuf},
15    sync::{Arc, Mutex},
16};
17use uuid::Uuid;
18
/// Settings used when opening the SQLite-backed mailbox database.
#[derive(Debug, Clone)]
pub struct SqliteConfig {
    /// Location of the SQLite database file on disk.
    pub database_path: PathBuf,
    /// When `true`, `PRAGMA journal_mode = WAL` is issued right after the
    /// database is opened.
    pub enable_wal: bool,
}
27
28impl Default for SqliteConfig {
29    fn default() -> Self {
30        Self {
31            database_path: PathBuf::from("mailbox.db"),
32            enable_wal: true,
33        }
34    }
35}
36
37/// SQLite connection wrapper
38struct SqliteConnection {
39    conn: Mutex<Connection>,
40}
41
42impl SqliteConnection {
43    fn new(config: &SqliteConfig) -> StorageResult<Self> {
44        let conn = Connection::open(&config.database_path)?;
45        if config.enable_wal {
46            conn.execute_batch("PRAGMA journal_mode = WAL;")?;
47        }
48        Self::create_tables(&conn)?;
49        Ok(Self {
50            conn: Mutex::new(conn),
51        })
52    }
53
54    fn create_tables(conn: &Connection) -> StorageResult<()> {
55        conn.execute_batch(
56            r#"
57            CREATE TABLE IF NOT EXISTS messages (
58                id TEXT PRIMARY KEY,
59                from_actr_id BLOB NOT NULL,  -- ActrId Protobuf bytes (all messages must have a sender)
60                payload BLOB NOT NULL,
61                priority INTEGER NOT NULL,
62                status INTEGER NOT NULL DEFAULT 0, -- 0: Queued, 1: Inflight
63                created_at TEXT NOT NULL
64            );
65            CREATE INDEX IF NOT EXISTS idx_messages_priority_status ON messages(priority DESC, status, created_at ASC);
66            "#,
67        )?;
68        Ok(())
69    }
70}
71
72/// SQLite mailbox implementation
73pub struct SqliteMailbox {
74    connection: Arc<SqliteConnection>,
75    /// Optional depth observer invoked after every successful
76    /// [`SqliteMailbox::enqueue`] with the post-enqueue queued-message
77    /// count. `None` means no observer is installed (the caller is
78    /// expected to poll via [`Mailbox::status`] instead).
79    depth_observer: Arc<Mutex<Option<Arc<dyn MailboxDepthObserver>>>>,
80}
81
82impl SqliteMailbox {
83    pub async fn new<P: AsRef<Path>>(database_path: P) -> StorageResult<Self> {
84        let config = SqliteConfig {
85            database_path: database_path.as_ref().to_path_buf(),
86            ..Default::default()
87        };
88        Self::with_config(config).await
89    }
90
91    pub async fn with_config(config: SqliteConfig) -> StorageResult<Self> {
92        let connection = Arc::new(SqliteConnection::new(&config)?);
93        Ok(Self {
94            connection,
95            depth_observer: Arc::new(Mutex::new(None)),
96        })
97    }
98
99    /// Cheap read of the currently-installed depth observer, used from
100    /// the enqueue hot path. Returns `None` if no observer is installed.
101    fn current_depth_observer(&self) -> Option<Arc<dyn MailboxDepthObserver>> {
102        self.depth_observer
103            .lock()
104            .expect("depth_observer mutex poisoned")
105            .clone()
106    }
107}
108
109const DEFAULT_BATCH_SIZE: u32 = 32;
110
111#[async_trait]
112impl Mailbox for SqliteMailbox {
113    async fn enqueue(
114        &self,
115        from: Vec<u8>,
116        payload: Vec<u8>,
117        priority: MessagePriority,
118    ) -> StorageResult<Uuid> {
119        let id = Uuid::new_v4();
120        let observer = self.current_depth_observer();
121
122        // `from` is already Protobuf bytes, store directly.
123        //
124        // When a depth observer is installed, compute the post-enqueue
125        // `queued_messages` count while we still hold the connection
126        // Mutex — this keeps the observer notification monotonic with
127        // respect to concurrent `ack`s and avoids a second round-trip.
128        let depth = {
129            let conn = self.connection.conn.lock().unwrap();
130            conn.execute(
131                "INSERT INTO messages (id, from_actr_id, payload, priority, status, created_at) VALUES (?1, ?2, ?3, ?4, 0, ?5)",
132                params![
133                    id.to_string(),
134                    from,
135                    payload,
136                    priority as i64,
137                    Utc::now().to_rfc3339(),
138                ],
139            )?;
140            if observer.is_some() {
141                let queued: i64 = conn.query_row(
142                    "SELECT COUNT(*) FROM messages WHERE status = 0",
143                    [],
144                    |row| row.get(0),
145                )?;
146                Some(queued.max(0) as usize)
147            } else {
148                None
149            }
150        };
151
152        if let (Some(observer), Some(queued)) = (observer, depth) {
153            observer.on_depth_change(queued);
154        }
155
156        Ok(id)
157    }
158
159    async fn dequeue(&self) -> StorageResult<Vec<MessageRecord>> {
160        let conn = self.connection.conn.lock().unwrap();
161        let mut stmt = conn.prepare(
162            r#"
163            UPDATE messages
164            SET status = 1 -- Inflight
165            WHERE id IN (
166                SELECT id FROM messages
167                WHERE status = 0 -- Queued
168                ORDER BY priority DESC, created_at ASC
169                LIMIT ?1
170            )
171            RETURNING id, from_actr_id, payload, priority, created_at, status;
172            "#,
173        )?;
174
175        let mut messages = stmt
176            .query_map(params![DEFAULT_BATCH_SIZE], |row| {
177                // Return from_actr_id as raw bytes without deserializing
178                let from: Vec<u8> = row.get(1)?;
179
180                let priority_val: i64 = row.get(3)?;
181                let id_str: String = row.get(0)?;
182                let id = Uuid::parse_str(&id_str).map_err(|e| {
183                    rusqlite::Error::FromSqlConversionFailure(
184                        0,
185                        rusqlite::types::Type::Text,
186                        Box::new(e),
187                    )
188                })?;
189                let created_at_str: String = row.get(4)?;
190                let created_at = DateTime::parse_from_rfc3339(&created_at_str)
191                    .map_err(|e| {
192                        rusqlite::Error::FromSqlConversionFailure(
193                            4,
194                            rusqlite::types::Type::Text,
195                            Box::new(e),
196                        )
197                    })?
198                    .with_timezone(&Utc);
199                Ok(MessageRecord {
200                    id,
201                    from,
202                    payload: row.get(2)?,
203                    priority: if priority_val == 1 {
204                        MessagePriority::High
205                    } else {
206                        MessagePriority::Normal
207                    },
208                    created_at,
209                    status: if row.get::<_, i64>(5)? == 1 {
210                        MessageStatus::Inflight
211                    } else {
212                        MessageStatus::Queued
213                    },
214                })
215            })?
216            .collect::<Result<Vec<_>, _>>()?;
217
218        // The order of rows from a RETURNING clause is not guaranteed.
219        // We must sort in memory to ensure priority is respected.
220        messages.sort_unstable_by(|a, b| {
221            b.priority
222                .cmp(&a.priority)
223                .then_with(|| a.created_at.cmp(&b.created_at))
224        });
225
226        Ok(messages)
227    }
228
229    async fn ack(&self, message_id: Uuid) -> StorageResult<()> {
230        let conn = self.connection.conn.lock().unwrap();
231        conn.execute(
232            "DELETE FROM messages WHERE id = ?1",
233            params![message_id.to_string()],
234        )?;
235        Ok(())
236    }
237
238    async fn status(&self) -> StorageResult<MailboxStats> {
239        let conn = self.connection.conn.lock().unwrap();
240        let queued_messages: u64 = conn.query_row(
241            "SELECT COUNT(*) FROM messages WHERE status = 0",
242            [],
243            |row| row.get(0),
244        )?;
245        let inflight_messages: u64 = conn.query_row(
246            "SELECT COUNT(*) FROM messages WHERE status = 1",
247            [],
248            |row| row.get(0),
249        )?;
250
251        let mut queued_by_priority = HashMap::new();
252        let mut stmt = conn.prepare(
253            "SELECT priority, COUNT(*) FROM messages WHERE status = 0 GROUP BY priority",
254        )?;
255        let rows = stmt.query_map([], |row| {
256            let priority_val: i64 = row.get(0)?;
257            let count: u64 = row.get(1)?;
258            Ok((priority_val, count))
259        })?;
260
261        for row in rows {
262            let (priority_val, count) = row?;
263            let priority = if priority_val == 1 {
264                MessagePriority::High
265            } else {
266                MessagePriority::Normal
267            };
268            queued_by_priority.insert(priority, count);
269        }
270
271        Ok(MailboxStats {
272            queued_messages,
273            inflight_messages,
274            queued_by_priority,
275        })
276    }
277
278    fn set_depth_observer(&self, observer: Arc<dyn MailboxDepthObserver>) -> bool {
279        let mut guard = self
280            .depth_observer
281            .lock()
282            .expect("depth_observer mutex poisoned");
283        *guard = Some(observer);
284        true
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::prost::Message as ProstMessage;
    use actr_protocol::{ActrId, ActrType, Realm};
    use tempfile::{TempDir, tempdir};

    /// Creates a mailbox backed by a database inside a fresh temporary
    /// directory.
    ///
    /// The [`TempDir`] guard is returned alongside the mailbox and must be
    /// kept alive for the duration of the test: dropping it deletes the
    /// directory, and with it the database file the mailbox still has open.
    async fn setup_mailbox() -> (SqliteMailbox, TempDir) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mailbox = SqliteMailbox::new(&path).await.unwrap();
        (mailbox, dir)
    }

    /// Builds a Protobuf-encoded `ActrId` to use as a message sender.
    fn dummy_actr_id_bytes() -> Vec<u8> {
        let actr_id = ActrId {
            realm: Realm { realm_id: 1 },
            serial_number: 1000,
            r#type: ActrType {
                manufacturer: "test".to_string(),
                name: "TestActor".to_string(),
                version: "1.0.0".to_string(),
            },
        };
        let mut buf = Vec::new();
        actr_id.encode(&mut buf).unwrap();
        buf
    }

    #[tokio::test]
    async fn test_enqueue_dequeue_ack_workflow() {
        let (mailbox, _dir) = setup_mailbox().await;

        // 1. Enqueue
        let from = dummy_actr_id_bytes();
        let payload = b"hello".to_vec();
        let msg_id = mailbox
            .enqueue(from.clone(), payload.clone(), MessagePriority::Normal)
            .await
            .unwrap();

        // 2. Dequeue
        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].id, msg_id);
        assert_eq!(messages[0].from, from);
        assert_eq!(messages[0].payload, payload);
        assert_eq!(messages[0].status, MessageStatus::Inflight);

        // 3. Dequeue again, should be empty
        let messages_again = mailbox.dequeue().await.unwrap();
        assert!(messages_again.is_empty());

        // 4. Ack
        mailbox.ack(msg_id).await.unwrap();

        // 5. Check status, should be empty
        let stats = mailbox.status().await.unwrap();
        assert_eq!(stats.queued_messages, 0);
        assert_eq!(stats.inflight_messages, 0);
    }

    #[tokio::test]
    async fn test_priority_order() {
        let (mailbox, _dir) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        let normal_id = mailbox
            .enqueue(from.clone(), b"normal".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        let high_id = mailbox
            .enqueue(from.clone(), b"high".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        // Dequeue should return both messages, with the high priority one first.
        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].id, high_id); // High priority first
        assert_eq!(messages[1].id, normal_id); // Normal priority second
    }

    #[tokio::test]
    async fn test_depth_observer_fires_on_enqueue() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Records the most recent depth and the number of notifications.
        struct CountingObserver {
            latest_depth: Arc<AtomicUsize>,
            calls: Arc<AtomicUsize>,
        }
        impl MailboxDepthObserver for CountingObserver {
            fn on_depth_change(&self, queued_messages: usize) {
                self.latest_depth.store(queued_messages, Ordering::SeqCst);
                self.calls.fetch_add(1, Ordering::SeqCst);
            }
        }

        let (mailbox, _dir) = setup_mailbox().await;
        let latest = Arc::new(AtomicUsize::new(0));
        let calls = Arc::new(AtomicUsize::new(0));
        let installed = mailbox.set_depth_observer(Arc::new(CountingObserver {
            latest_depth: latest.clone(),
            calls: calls.clone(),
        }));
        assert!(installed, "SQLite backend must support push notifications");

        let from = dummy_actr_id_bytes();
        mailbox
            .enqueue(from.clone(), b"a".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"b".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"c".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        assert_eq!(
            calls.load(Ordering::SeqCst),
            3,
            "observer must fire once per enqueue"
        );
        assert_eq!(
            latest.load(Ordering::SeqCst),
            3,
            "final depth must reflect all three queued messages"
        );
    }

    #[tokio::test]
    async fn test_status_tracking() {
        let (mailbox, _dir) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        mailbox
            .enqueue(from.clone(), b"msg1".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg2".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg3".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        let initial_stats = mailbox.status().await.unwrap();
        assert_eq!(initial_stats.queued_messages, 3);
        assert_eq!(initial_stats.inflight_messages, 0);
        assert_eq!(
            initial_stats
                .queued_by_priority
                .get(&MessagePriority::Normal),
            Some(&2)
        );
        assert_eq!(
            initial_stats.queued_by_priority.get(&MessagePriority::High),
            Some(&1)
        );

        // Dequeue all available messages (since 3 < DEFAULT_BATCH_SIZE)
        let dequeued = mailbox.dequeue().await.unwrap();
        assert_eq!(dequeued.len(), 3);

        let after_dequeue_stats = mailbox.status().await.unwrap();
        assert_eq!(after_dequeue_stats.queued_messages, 0);
        assert_eq!(after_dequeue_stats.inflight_messages, 3);

        // Ack the first message (which should be the high priority one)
        mailbox.ack(dequeued[0].id).await.unwrap();

        let final_stats = mailbox.status().await.unwrap();
        assert_eq!(final_stats.queued_messages, 0);
        assert_eq!(final_stats.inflight_messages, 2);
    }
}
467}