1use crate::{
4 error::StorageResult,
5 mailbox::{
6 Mailbox, MailboxDepthObserver, MailboxStats, MessagePriority, MessageRecord, MessageStatus,
7 },
8};
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use rusqlite::{Connection, params};
12use std::{
13 collections::HashMap,
14 path::{Path, PathBuf},
15 sync::{Arc, Mutex},
16};
17use uuid::Uuid;
18
/// Settings used when opening the SQLite database that backs the mailbox.
#[derive(Debug, Clone)]
pub struct SqliteConfig {
    /// Location of the SQLite database file.
    pub database_path: PathBuf,
    /// When `true`, `PRAGMA journal_mode = WAL` is issued right after the
    /// connection is opened.
    pub enable_wal: bool,
}

impl Default for SqliteConfig {
    /// Defaults to a `mailbox.db` file (relative path) with WAL enabled.
    fn default() -> Self {
        let database_path: PathBuf = "mailbox.db".into();
        Self {
            database_path,
            enable_wal: true,
        }
    }
}
36
37struct SqliteConnection {
39 conn: Mutex<Connection>,
40}
41
42impl SqliteConnection {
43 fn new(config: &SqliteConfig) -> StorageResult<Self> {
44 let conn = Connection::open(&config.database_path)?;
45 if config.enable_wal {
46 conn.execute_batch("PRAGMA journal_mode = WAL;")?;
47 }
48 Self::create_tables(&conn)?;
49 Ok(Self {
50 conn: Mutex::new(conn),
51 })
52 }
53
54 fn create_tables(conn: &Connection) -> StorageResult<()> {
55 conn.execute_batch(
56 r#"
57 CREATE TABLE IF NOT EXISTS messages (
58 id TEXT PRIMARY KEY,
59 from_actr_id BLOB NOT NULL, -- ActrId Protobuf bytes (all messages must have a sender)
60 payload BLOB NOT NULL,
61 priority INTEGER NOT NULL,
62 status INTEGER NOT NULL DEFAULT 0, -- 0: Queued, 1: Inflight
63 created_at TEXT NOT NULL
64 );
65 CREATE INDEX IF NOT EXISTS idx_messages_priority_status ON messages(priority DESC, status, created_at ASC);
66 "#,
67 )?;
68 Ok(())
69 }
70}
71
72pub struct SqliteMailbox {
74 connection: Arc<SqliteConnection>,
75 depth_observer: Arc<Mutex<Option<Arc<dyn MailboxDepthObserver>>>>,
80}
81
82impl SqliteMailbox {
83 pub async fn new<P: AsRef<Path>>(database_path: P) -> StorageResult<Self> {
84 let config = SqliteConfig {
85 database_path: database_path.as_ref().to_path_buf(),
86 ..Default::default()
87 };
88 Self::with_config(config).await
89 }
90
91 pub async fn with_config(config: SqliteConfig) -> StorageResult<Self> {
92 let connection = Arc::new(SqliteConnection::new(&config)?);
93 Ok(Self {
94 connection,
95 depth_observer: Arc::new(Mutex::new(None)),
96 })
97 }
98
99 fn current_depth_observer(&self) -> Option<Arc<dyn MailboxDepthObserver>> {
102 self.depth_observer
103 .lock()
104 .expect("depth_observer mutex poisoned")
105 .clone()
106 }
107}
108
/// Maximum number of queued messages claimed by a single `dequeue` call
/// (bound to the `LIMIT` of the dequeue query).
const DEFAULT_BATCH_SIZE: u32 = 32;
110
111#[async_trait]
112impl Mailbox for SqliteMailbox {
113 async fn enqueue(
114 &self,
115 from: Vec<u8>,
116 payload: Vec<u8>,
117 priority: MessagePriority,
118 ) -> StorageResult<Uuid> {
119 let id = Uuid::new_v4();
120 let observer = self.current_depth_observer();
121
122 let depth = {
129 let conn = self.connection.conn.lock().unwrap();
130 conn.execute(
131 "INSERT INTO messages (id, from_actr_id, payload, priority, status, created_at) VALUES (?1, ?2, ?3, ?4, 0, ?5)",
132 params![
133 id.to_string(),
134 from,
135 payload,
136 priority as i64,
137 Utc::now().to_rfc3339(),
138 ],
139 )?;
140 if observer.is_some() {
141 let queued: i64 = conn.query_row(
142 "SELECT COUNT(*) FROM messages WHERE status = 0",
143 [],
144 |row| row.get(0),
145 )?;
146 Some(queued.max(0) as usize)
147 } else {
148 None
149 }
150 };
151
152 if let (Some(observer), Some(queued)) = (observer, depth) {
153 observer.on_depth_change(queued);
154 }
155
156 Ok(id)
157 }
158
159 async fn dequeue(&self) -> StorageResult<Vec<MessageRecord>> {
160 let conn = self.connection.conn.lock().unwrap();
161 let mut stmt = conn.prepare(
162 r#"
163 UPDATE messages
164 SET status = 1 -- Inflight
165 WHERE id IN (
166 SELECT id FROM messages
167 WHERE status = 0 -- Queued
168 ORDER BY priority DESC, created_at ASC
169 LIMIT ?1
170 )
171 RETURNING id, from_actr_id, payload, priority, created_at, status;
172 "#,
173 )?;
174
175 let mut messages = stmt
176 .query_map(params![DEFAULT_BATCH_SIZE], |row| {
177 let from: Vec<u8> = row.get(1)?;
179
180 let priority_val: i64 = row.get(3)?;
181 let id_str: String = row.get(0)?;
182 let id = Uuid::parse_str(&id_str).map_err(|e| {
183 rusqlite::Error::FromSqlConversionFailure(
184 0,
185 rusqlite::types::Type::Text,
186 Box::new(e),
187 )
188 })?;
189 let created_at_str: String = row.get(4)?;
190 let created_at = DateTime::parse_from_rfc3339(&created_at_str)
191 .map_err(|e| {
192 rusqlite::Error::FromSqlConversionFailure(
193 4,
194 rusqlite::types::Type::Text,
195 Box::new(e),
196 )
197 })?
198 .with_timezone(&Utc);
199 Ok(MessageRecord {
200 id,
201 from,
202 payload: row.get(2)?,
203 priority: if priority_val == 1 {
204 MessagePriority::High
205 } else {
206 MessagePriority::Normal
207 },
208 created_at,
209 status: if row.get::<_, i64>(5)? == 1 {
210 MessageStatus::Inflight
211 } else {
212 MessageStatus::Queued
213 },
214 })
215 })?
216 .collect::<Result<Vec<_>, _>>()?;
217
218 messages.sort_unstable_by(|a, b| {
221 b.priority
222 .cmp(&a.priority)
223 .then_with(|| a.created_at.cmp(&b.created_at))
224 });
225
226 Ok(messages)
227 }
228
229 async fn ack(&self, message_id: Uuid) -> StorageResult<()> {
230 let conn = self.connection.conn.lock().unwrap();
231 conn.execute(
232 "DELETE FROM messages WHERE id = ?1",
233 params![message_id.to_string()],
234 )?;
235 Ok(())
236 }
237
238 async fn status(&self) -> StorageResult<MailboxStats> {
239 let conn = self.connection.conn.lock().unwrap();
240 let queued_messages: u64 = conn.query_row(
241 "SELECT COUNT(*) FROM messages WHERE status = 0",
242 [],
243 |row| row.get(0),
244 )?;
245 let inflight_messages: u64 = conn.query_row(
246 "SELECT COUNT(*) FROM messages WHERE status = 1",
247 [],
248 |row| row.get(0),
249 )?;
250
251 let mut queued_by_priority = HashMap::new();
252 let mut stmt = conn.prepare(
253 "SELECT priority, COUNT(*) FROM messages WHERE status = 0 GROUP BY priority",
254 )?;
255 let rows = stmt.query_map([], |row| {
256 let priority_val: i64 = row.get(0)?;
257 let count: u64 = row.get(1)?;
258 Ok((priority_val, count))
259 })?;
260
261 for row in rows {
262 let (priority_val, count) = row?;
263 let priority = if priority_val == 1 {
264 MessagePriority::High
265 } else {
266 MessagePriority::Normal
267 };
268 queued_by_priority.insert(priority, count);
269 }
270
271 Ok(MailboxStats {
272 queued_messages,
273 inflight_messages,
274 queued_by_priority,
275 })
276 }
277
278 fn set_depth_observer(&self, observer: Arc<dyn MailboxDepthObserver>) -> bool {
279 let mut guard = self
280 .depth_observer
281 .lock()
282 .expect("depth_observer mutex poisoned");
283 *guard = Some(observer);
284 true
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::prost::Message as ProstMessage;
    use actr_protocol::{ActrId, ActrType, Realm};
    use tempfile::{TempDir, tempdir};

    /// Opens a mailbox backed by a database inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the mailbox because dropping
    /// it deletes the directory; callers must keep it alive for as long as the
    /// mailbox is in use. It comes first in the tuple so that, when bound as
    /// `let (_dir, mailbox) = ...`, the mailbox is dropped before the
    /// directory is removed.
    async fn setup_mailbox() -> (TempDir, SqliteMailbox) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mailbox = SqliteMailbox::new(&path).await.unwrap();
        (dir, mailbox)
    }

    /// Encodes a fixed `ActrId` to Protobuf bytes for use as a sender id.
    fn dummy_actr_id_bytes() -> Vec<u8> {
        let actr_id = ActrId {
            realm: Realm { realm_id: 1 },
            serial_number: 1000,
            r#type: ActrType {
                manufacturer: "test".to_string(),
                name: "TestActor".to_string(),
                version: "1.0.0".to_string(),
            },
        };
        let mut buf = Vec::new();
        actr_id.encode(&mut buf).unwrap();
        buf
    }

    #[tokio::test]
    async fn test_enqueue_dequeue_ack_workflow() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        let payload = b"hello".to_vec();
        let msg_id = mailbox
            .enqueue(from.clone(), payload.clone(), MessagePriority::Normal)
            .await
            .unwrap();

        // The first dequeue claims the message and flips it to inflight.
        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].id, msg_id);
        assert_eq!(messages[0].from, from);
        assert_eq!(messages[0].payload, payload);
        assert_eq!(messages[0].status, MessageStatus::Inflight);

        // An inflight message must not be handed out a second time.
        let messages_again = mailbox.dequeue().await.unwrap();
        assert!(messages_again.is_empty());

        mailbox.ack(msg_id).await.unwrap();

        // After the ack nothing is left in either state.
        let stats = mailbox.status().await.unwrap();
        assert_eq!(stats.queued_messages, 0);
        assert_eq!(stats.inflight_messages, 0);
    }

    #[tokio::test]
    async fn test_priority_order() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        let normal_id = mailbox
            .enqueue(from.clone(), b"normal".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        let high_id = mailbox
            .enqueue(from.clone(), b"high".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        // High priority comes first even though it was enqueued last.
        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].id, high_id);
        assert_eq!(messages[1].id, normal_id);
    }

    #[tokio::test]
    async fn test_depth_observer_fires_on_enqueue() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Records the most recent depth and how many notifications arrived.
        struct CountingObserver {
            latest_depth: Arc<AtomicUsize>,
            calls: Arc<AtomicUsize>,
        }
        impl MailboxDepthObserver for CountingObserver {
            fn on_depth_change(&self, queued_messages: usize) {
                self.latest_depth.store(queued_messages, Ordering::SeqCst);
                self.calls.fetch_add(1, Ordering::SeqCst);
            }
        }

        let (_dir, mailbox) = setup_mailbox().await;
        let latest = Arc::new(AtomicUsize::new(0));
        let calls = Arc::new(AtomicUsize::new(0));
        let installed = mailbox.set_depth_observer(Arc::new(CountingObserver {
            latest_depth: latest.clone(),
            calls: calls.clone(),
        }));
        assert!(installed, "SQLite backend must support push notifications");

        let from = dummy_actr_id_bytes();
        mailbox
            .enqueue(from.clone(), b"a".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"b".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"c".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        assert_eq!(
            calls.load(Ordering::SeqCst),
            3,
            "observer must fire once per enqueue"
        );
        assert_eq!(
            latest.load(Ordering::SeqCst),
            3,
            "final depth must reflect all three queued messages"
        );
    }

    #[tokio::test]
    async fn test_status_tracking() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        mailbox
            .enqueue(from.clone(), b"msg1".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg2".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg3".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        let initial_stats = mailbox.status().await.unwrap();
        assert_eq!(initial_stats.queued_messages, 3);
        assert_eq!(initial_stats.inflight_messages, 0);
        assert_eq!(
            initial_stats
                .queued_by_priority
                .get(&MessagePriority::Normal),
            Some(&2)
        );
        assert_eq!(
            initial_stats.queued_by_priority.get(&MessagePriority::High),
            Some(&1)
        );

        // Dequeuing moves every message from queued to inflight.
        let dequeued = mailbox.dequeue().await.unwrap();
        assert_eq!(dequeued.len(), 3);

        let after_dequeue_stats = mailbox.status().await.unwrap();
        assert_eq!(after_dequeue_stats.queued_messages, 0);
        assert_eq!(after_dequeue_stats.inflight_messages, 3);

        // Acknowledging one message removes exactly one inflight entry.
        mailbox.ack(dequeued[0].id).await.unwrap();

        let final_stats = mailbox.status().await.unwrap();
        assert_eq!(final_stats.queued_messages, 0);
        assert_eq!(final_stats.inflight_messages, 2);
    }
}