actr_mailbox/
sqlite.rs

1//! SQLite 存储后端实现
2
3use crate::{
4    error::StorageResult,
5    mailbox::{Mailbox, MailboxStats, MessagePriority, MessageRecord, MessageStatus},
6};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use rusqlite::{Connection, params};
10use std::{
11    collections::HashMap,
12    path::{Path, PathBuf},
13    sync::{Arc, Mutex},
14};
15use uuid::Uuid;
16
/// Configuration for the SQLite-backed mailbox.
#[derive(Debug, Clone)]
pub struct SqliteConfig {
    /// Path of the SQLite database file to open.
    pub database_path: PathBuf,
    /// Whether to switch the database to write-ahead logging (`journal_mode = WAL`).
    pub enable_wal: bool,
}

impl Default for SqliteConfig {
    /// `mailbox.db` (a relative path, resolved against the working directory) with WAL on.
    fn default() -> Self {
        let database_path = PathBuf::from("mailbox.db");
        let enable_wal = true;
        Self {
            database_path,
            enable_wal,
        }
    }
}
34
35/// SQLite 连接包装器
36struct SqliteConnection {
37    conn: Mutex<Connection>,
38}
39
40impl SqliteConnection {
41    fn new(config: &SqliteConfig) -> StorageResult<Self> {
42        let conn = Connection::open(&config.database_path)?;
43        if config.enable_wal {
44            conn.execute_batch("PRAGMA journal_mode = WAL;")?;
45        }
46        Self::create_tables(&conn)?;
47        Ok(Self {
48            conn: Mutex::new(conn),
49        })
50    }
51
52    fn create_tables(conn: &Connection) -> StorageResult<()> {
53        conn.execute_batch(
54            r#"
55            CREATE TABLE IF NOT EXISTS messages (
56                id TEXT PRIMARY KEY,
57                from_actr_id BLOB NOT NULL,  -- ActrId Protobuf bytes (所有消息必有 sender)
58                payload BLOB NOT NULL,
59                priority INTEGER NOT NULL,
60                status INTEGER NOT NULL DEFAULT 0, -- 0: Queued, 1: Inflight
61                created_at TEXT NOT NULL
62            );
63            CREATE INDEX IF NOT EXISTS idx_messages_priority_status ON messages(priority DESC, status, created_at ASC);
64            "#,
65        )?;
66        Ok(())
67    }
68}
69
70/// SQLite 邮箱实现
71pub struct SqliteMailbox {
72    connection: Arc<SqliteConnection>,
73}
74
75impl SqliteMailbox {
76    pub async fn new<P: AsRef<Path>>(database_path: P) -> StorageResult<Self> {
77        let config = SqliteConfig {
78            database_path: database_path.as_ref().to_path_buf(),
79            ..Default::default()
80        };
81        Self::with_config(config).await
82    }
83
84    pub async fn with_config(config: SqliteConfig) -> StorageResult<Self> {
85        let connection = Arc::new(SqliteConnection::new(&config)?);
86        Ok(Self { connection })
87    }
88}
89
90const DEFAULT_BATCH_SIZE: u32 = 32;
91
92#[async_trait]
93impl Mailbox for SqliteMailbox {
94    async fn enqueue(
95        &self,
96        from: Vec<u8>,
97        payload: Vec<u8>,
98        priority: MessagePriority,
99    ) -> StorageResult<Uuid> {
100        let id = Uuid::new_v4();
101
102        // from 已经是 Protobuf bytes,直接存储
103        let conn = self.connection.conn.lock().unwrap();
104        conn.execute(
105            "INSERT INTO messages (id, from_actr_id, payload, priority, status, created_at) VALUES (?1, ?2, ?3, ?4, 0, ?5)",
106            params![
107                id.to_string(),
108                from,
109                payload,
110                priority as i64,
111                Utc::now().to_rfc3339(),
112            ],
113        )?;
114        Ok(id)
115    }
116
117    async fn dequeue(&self) -> StorageResult<Vec<MessageRecord>> {
118        let conn = self.connection.conn.lock().unwrap();
119        let mut stmt = conn.prepare(
120            r#"
121            UPDATE messages
122            SET status = 1 -- Inflight
123            WHERE id IN (
124                SELECT id FROM messages
125                WHERE status = 0 -- Queued
126                ORDER BY priority DESC, created_at ASC
127                LIMIT ?1
128            )
129            RETURNING id, from_actr_id, payload, priority, created_at, status;
130            "#,
131        )?;
132
133        let mut messages = stmt
134            .query_map(params![DEFAULT_BATCH_SIZE], |row| {
135                // from_actr_id 直接返回 bytes,不反序列化
136                let from: Vec<u8> = row.get(1)?;
137
138                let priority_val: i64 = row.get(3)?;
139                Ok(MessageRecord {
140                    id: Uuid::parse_str(&row.get::<_, String>(0)?).unwrap(),
141                    from,
142                    payload: row.get(2)?,
143                    priority: if priority_val == 1 {
144                        MessagePriority::High
145                    } else {
146                        MessagePriority::Normal
147                    },
148                    created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
149                        .unwrap()
150                        .with_timezone(&Utc),
151                    status: if row.get::<_, i64>(5)? == 1 {
152                        MessageStatus::Inflight
153                    } else {
154                        MessageStatus::Queued
155                    },
156                })
157            })?
158            .collect::<Result<Vec<_>, _>>()?;
159
160        // The order of rows from a RETURNING clause is not guaranteed.
161        // We must sort in memory to ensure priority is respected.
162        messages.sort_unstable_by(|a, b| {
163            b.priority
164                .cmp(&a.priority)
165                .then_with(|| a.created_at.cmp(&b.created_at))
166        });
167
168        Ok(messages)
169    }
170
171    async fn ack(&self, message_id: Uuid) -> StorageResult<()> {
172        let conn = self.connection.conn.lock().unwrap();
173        conn.execute(
174            "DELETE FROM messages WHERE id = ?1",
175            params![message_id.to_string()],
176        )?;
177        Ok(())
178    }
179
180    async fn status(&self) -> StorageResult<MailboxStats> {
181        let conn = self.connection.conn.lock().unwrap();
182        let queued_messages: u64 = conn.query_row(
183            "SELECT COUNT(*) FROM messages WHERE status = 0",
184            [],
185            |row| row.get(0),
186        )?;
187        let inflight_messages: u64 = conn.query_row(
188            "SELECT COUNT(*) FROM messages WHERE status = 1",
189            [],
190            |row| row.get(0),
191        )?;
192
193        let mut queued_by_priority = HashMap::new();
194        let mut stmt = conn.prepare(
195            "SELECT priority, COUNT(*) FROM messages WHERE status = 0 GROUP BY priority",
196        )?;
197        let rows = stmt.query_map([], |row| {
198            let priority_val: i64 = row.get(0)?;
199            let count: u64 = row.get(1)?;
200            Ok((priority_val, count))
201        })?;
202
203        for row in rows {
204            let (priority_val, count) = row?;
205            let priority = if priority_val == 1 {
206                MessagePriority::High
207            } else {
208                MessagePriority::Normal
209            };
210            queued_by_priority.insert(priority, count);
211        }
212
213        Ok(MailboxStats {
214            queued_messages,
215            inflight_messages,
216            queued_by_priority,
217        })
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::prost::Message as ProstMessage;
    use actr_protocol::{ActrId, ActrType, Realm};
    use tempfile::{TempDir, tempdir};

    /// Creates a mailbox backed by a fresh database file inside a temporary directory.
    ///
    /// The `TempDir` guard must be returned. Dropping it deletes the directory and the
    /// database in it, which previously happened while the connection was still open.
    /// Bind the result as `let (_dir, mailbox) = ...`. `_dir` (not `_`) keeps the
    /// directory alive until the end of the test. Locals drop in reverse declaration
    /// order, so the mailbox closes before the directory is removed.
    async fn setup_mailbox() -> (TempDir, SqliteMailbox) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mailbox = SqliteMailbox::new(&path).await.unwrap();
        (dir, mailbox)
    }

    /// Protobuf-encoded `ActrId` used as the sender of every test message.
    fn dummy_actr_id_bytes() -> Vec<u8> {
        let actr_id = ActrId {
            realm: Realm { realm_id: 1 },
            serial_number: 1000,
            r#type: ActrType {
                manufacturer: "test".to_string(),
                name: "TestActor".to_string(),
            },
        };
        let mut buf = Vec::new();
        actr_id.encode(&mut buf).unwrap();
        buf
    }

    #[tokio::test]
    async fn test_enqueue_dequeue_ack_workflow() {
        let (_dir, mailbox) = setup_mailbox().await;

        // 1. Enqueue
        let from = dummy_actr_id_bytes();
        let payload = b"hello".to_vec();
        let msg_id = mailbox
            .enqueue(from.clone(), payload.clone(), MessagePriority::Normal)
            .await
            .unwrap();

        // 2. Dequeue
        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].id, msg_id);
        assert_eq!(messages[0].from, from);
        assert_eq!(messages[0].payload, payload);
        assert_eq!(messages[0].status, MessageStatus::Inflight);

        // 3. Dequeue again, should be empty
        let messages_again = mailbox.dequeue().await.unwrap();
        assert!(messages_again.is_empty());

        // 4. Ack
        mailbox.ack(msg_id).await.unwrap();

        // 5. Check status, should be empty
        let stats = mailbox.status().await.unwrap();
        assert_eq!(stats.queued_messages, 0);
        assert_eq!(stats.inflight_messages, 0);
    }

    #[tokio::test]
    async fn test_priority_order() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        let normal_id = mailbox
            .enqueue(from.clone(), b"normal".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        let high_id = mailbox
            .enqueue(from.clone(), b"high".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        // Dequeue should return both messages, with the high priority one first.
        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].id, high_id); // High priority first
        assert_eq!(messages[1].id, normal_id); // Normal priority second
    }

    #[tokio::test]
    async fn test_dequeue_respects_batch_size() {
        let (_dir, mailbox) = setup_mailbox().await;

        // Enqueue one more message than a single dequeue may claim.
        let from = dummy_actr_id_bytes();
        for i in 0..=DEFAULT_BATCH_SIZE {
            mailbox
                .enqueue(from.clone(), i.to_le_bytes().to_vec(), MessagePriority::Normal)
                .await
                .unwrap();
        }

        let first = mailbox.dequeue().await.unwrap();
        assert_eq!(first.len(), DEFAULT_BATCH_SIZE as usize);
        let second = mailbox.dequeue().await.unwrap();
        assert_eq!(second.len(), 1);
        assert!(mailbox.dequeue().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_status_tracking() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        mailbox
            .enqueue(from.clone(), b"msg1".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg2".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg3".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        let initial_stats = mailbox.status().await.unwrap();
        assert_eq!(initial_stats.queued_messages, 3);
        assert_eq!(initial_stats.inflight_messages, 0);
        assert_eq!(
            initial_stats
                .queued_by_priority
                .get(&MessagePriority::Normal),
            Some(&2)
        );
        assert_eq!(
            initial_stats.queued_by_priority.get(&MessagePriority::High),
            Some(&1)
        );

        // Dequeue all available messages (since 3 < DEFAULT_BATCH_SIZE)
        let dequeued = mailbox.dequeue().await.unwrap();
        assert_eq!(dequeued.len(), 3);
        assert!(matches!(dequeued[0].priority, MessagePriority::High));

        let after_dequeue_stats = mailbox.status().await.unwrap();
        assert_eq!(after_dequeue_stats.queued_messages, 0);
        assert_eq!(after_dequeue_stats.inflight_messages, 3);

        // Ack the first message (the high priority one, asserted above)
        mailbox.ack(dequeued[0].id).await.unwrap();

        let final_stats = mailbox.status().await.unwrap();
        assert_eq!(final_stats.queued_messages, 0);
        assert_eq!(final_stats.inflight_messages, 2);
    }
}