1use crate::{
4 error::StorageResult,
5 mailbox::{Mailbox, MailboxStats, MessagePriority, MessageRecord, MessageStatus},
6};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use rusqlite::{Connection, params};
10use std::{
11 collections::HashMap,
12 path::{Path, PathBuf},
13 sync::{Arc, Mutex},
14};
15use uuid::Uuid;
16
/// Settings used when opening the SQLite database that backs a [`SqliteMailbox`].
#[derive(Debug, Clone)]
pub struct SqliteConfig {
    /// Location of the SQLite database file passed to `Connection::open`.
    pub database_path: PathBuf,
    /// When `true`, `PRAGMA journal_mode = WAL;` is executed right after opening.
    pub enable_wal: bool,
}

impl Default for SqliteConfig {
    /// Relative path `mailbox.db` with WAL journaling turned on.
    fn default() -> Self {
        let database_path = PathBuf::from("mailbox.db");
        let enable_wal = true;
        Self {
            database_path,
            enable_wal,
        }
    }
}
34
35struct SqliteConnection {
37 conn: Mutex<Connection>,
38}
39
40impl SqliteConnection {
41 fn new(config: &SqliteConfig) -> StorageResult<Self> {
42 let conn = Connection::open(&config.database_path)?;
43 if config.enable_wal {
44 conn.execute_batch("PRAGMA journal_mode = WAL;")?;
45 }
46 Self::create_tables(&conn)?;
47 Ok(Self {
48 conn: Mutex::new(conn),
49 })
50 }
51
52 fn create_tables(conn: &Connection) -> StorageResult<()> {
53 conn.execute_batch(
54 r#"
55 CREATE TABLE IF NOT EXISTS messages (
56 id TEXT PRIMARY KEY,
57 from_actr_id BLOB NOT NULL, -- ActrId Protobuf bytes (所有消息必有 sender)
58 payload BLOB NOT NULL,
59 priority INTEGER NOT NULL,
60 status INTEGER NOT NULL DEFAULT 0, -- 0: Queued, 1: Inflight
61 created_at TEXT NOT NULL
62 );
63 CREATE INDEX IF NOT EXISTS idx_messages_priority_status ON messages(priority DESC, status, created_at ASC);
64 "#,
65 )?;
66 Ok(())
67 }
68}
69
70pub struct SqliteMailbox {
72 connection: Arc<SqliteConnection>,
73}
74
75impl SqliteMailbox {
76 pub async fn new<P: AsRef<Path>>(database_path: P) -> StorageResult<Self> {
77 let config = SqliteConfig {
78 database_path: database_path.as_ref().to_path_buf(),
79 ..Default::default()
80 };
81 Self::with_config(config).await
82 }
83
84 pub async fn with_config(config: SqliteConfig) -> StorageResult<Self> {
85 let connection = Arc::new(SqliteConnection::new(&config)?);
86 Ok(Self { connection })
87 }
88}
89
90const DEFAULT_BATCH_SIZE: u32 = 32;
91
92#[async_trait]
93impl Mailbox for SqliteMailbox {
94 async fn enqueue(
95 &self,
96 from: Vec<u8>,
97 payload: Vec<u8>,
98 priority: MessagePriority,
99 ) -> StorageResult<Uuid> {
100 let id = Uuid::new_v4();
101
102 let conn = self.connection.conn.lock().unwrap();
104 conn.execute(
105 "INSERT INTO messages (id, from_actr_id, payload, priority, status, created_at) VALUES (?1, ?2, ?3, ?4, 0, ?5)",
106 params![
107 id.to_string(),
108 from,
109 payload,
110 priority as i64,
111 Utc::now().to_rfc3339(),
112 ],
113 )?;
114 Ok(id)
115 }
116
117 async fn dequeue(&self) -> StorageResult<Vec<MessageRecord>> {
118 let conn = self.connection.conn.lock().unwrap();
119 let mut stmt = conn.prepare(
120 r#"
121 UPDATE messages
122 SET status = 1 -- Inflight
123 WHERE id IN (
124 SELECT id FROM messages
125 WHERE status = 0 -- Queued
126 ORDER BY priority DESC, created_at ASC
127 LIMIT ?1
128 )
129 RETURNING id, from_actr_id, payload, priority, created_at, status;
130 "#,
131 )?;
132
133 let mut messages = stmt
134 .query_map(params![DEFAULT_BATCH_SIZE], |row| {
135 let from: Vec<u8> = row.get(1)?;
137
138 let priority_val: i64 = row.get(3)?;
139 Ok(MessageRecord {
140 id: Uuid::parse_str(&row.get::<_, String>(0)?).unwrap(),
141 from,
142 payload: row.get(2)?,
143 priority: if priority_val == 1 {
144 MessagePriority::High
145 } else {
146 MessagePriority::Normal
147 },
148 created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
149 .unwrap()
150 .with_timezone(&Utc),
151 status: if row.get::<_, i64>(5)? == 1 {
152 MessageStatus::Inflight
153 } else {
154 MessageStatus::Queued
155 },
156 })
157 })?
158 .collect::<Result<Vec<_>, _>>()?;
159
160 messages.sort_unstable_by(|a, b| {
163 b.priority
164 .cmp(&a.priority)
165 .then_with(|| a.created_at.cmp(&b.created_at))
166 });
167
168 Ok(messages)
169 }
170
171 async fn ack(&self, message_id: Uuid) -> StorageResult<()> {
172 let conn = self.connection.conn.lock().unwrap();
173 conn.execute(
174 "DELETE FROM messages WHERE id = ?1",
175 params![message_id.to_string()],
176 )?;
177 Ok(())
178 }
179
180 async fn status(&self) -> StorageResult<MailboxStats> {
181 let conn = self.connection.conn.lock().unwrap();
182 let queued_messages: u64 = conn.query_row(
183 "SELECT COUNT(*) FROM messages WHERE status = 0",
184 [],
185 |row| row.get(0),
186 )?;
187 let inflight_messages: u64 = conn.query_row(
188 "SELECT COUNT(*) FROM messages WHERE status = 1",
189 [],
190 |row| row.get(0),
191 )?;
192
193 let mut queued_by_priority = HashMap::new();
194 let mut stmt = conn.prepare(
195 "SELECT priority, COUNT(*) FROM messages WHERE status = 0 GROUP BY priority",
196 )?;
197 let rows = stmt.query_map([], |row| {
198 let priority_val: i64 = row.get(0)?;
199 let count: u64 = row.get(1)?;
200 Ok((priority_val, count))
201 })?;
202
203 for row in rows {
204 let (priority_val, count) = row?;
205 let priority = if priority_val == 1 {
206 MessagePriority::High
207 } else {
208 MessagePriority::Normal
209 };
210 queued_by_priority.insert(priority, count);
211 }
212
213 Ok(MailboxStats {
214 queued_messages,
215 inflight_messages,
216 queued_by_priority,
217 })
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::prost::Message as ProstMessage;
    use actr_protocol::{ActrId, ActrType, Realm};
    use tempfile::{tempdir, TempDir};

    /// Creates a mailbox backed by a database in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the mailbox and must be kept alive
    /// for the whole test: dropping it deletes the directory (and the database files
    /// inside it) while the connection is still open.
    async fn setup_mailbox() -> (TempDir, SqliteMailbox) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mailbox = SqliteMailbox::new(&path).await.unwrap();
        (dir, mailbox)
    }

    /// Protobuf-encoded `ActrId` used as the sender of every test message.
    fn dummy_actr_id_bytes() -> Vec<u8> {
        let actr_id = ActrId {
            realm: Realm { realm_id: 1 },
            serial_number: 1000,
            r#type: ActrType {
                manufacturer: "test".to_string(),
                name: "TestActor".to_string(),
            },
        };
        let mut buf = Vec::new();
        actr_id.encode(&mut buf).unwrap();
        buf
    }

    #[tokio::test]
    async fn test_enqueue_dequeue_ack_workflow() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        let payload = b"hello".to_vec();
        let msg_id = mailbox
            .enqueue(from.clone(), payload.clone(), MessagePriority::Normal)
            .await
            .unwrap();

        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].id, msg_id);
        assert_eq!(messages[0].from, from);
        assert_eq!(messages[0].payload, payload);
        assert_eq!(messages[0].status, MessageStatus::Inflight);

        // Already-claimed messages must not be handed out again.
        let messages_again = mailbox.dequeue().await.unwrap();
        assert!(messages_again.is_empty());

        mailbox.ack(msg_id).await.unwrap();

        let stats = mailbox.status().await.unwrap();
        assert_eq!(stats.queued_messages, 0);
        assert_eq!(stats.inflight_messages, 0);
    }

    #[tokio::test]
    async fn test_priority_order() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        let normal_id = mailbox
            .enqueue(from.clone(), b"normal".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        let high_id = mailbox
            .enqueue(from.clone(), b"high".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        let messages = mailbox.dequeue().await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].id, high_id);
        assert_eq!(messages[1].id, normal_id);
    }

    #[tokio::test]
    async fn test_dequeue_respects_batch_size() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        let batch = DEFAULT_BATCH_SIZE as usize;
        for i in 0..=batch {
            mailbox
                .enqueue(from.clone(), i.to_le_bytes().to_vec(), MessagePriority::Normal)
                .await
                .unwrap();
        }

        assert_eq!(mailbox.dequeue().await.unwrap().len(), batch);
        assert_eq!(mailbox.dequeue().await.unwrap().len(), 1);
        assert!(mailbox.dequeue().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_status_tracking() {
        let (_dir, mailbox) = setup_mailbox().await;

        let from = dummy_actr_id_bytes();
        mailbox
            .enqueue(from.clone(), b"msg1".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg2".to_vec(), MessagePriority::Normal)
            .await
            .unwrap();
        mailbox
            .enqueue(from.clone(), b"msg3".to_vec(), MessagePriority::High)
            .await
            .unwrap();

        let initial_stats = mailbox.status().await.unwrap();
        assert_eq!(initial_stats.queued_messages, 3);
        assert_eq!(initial_stats.inflight_messages, 0);
        assert_eq!(
            initial_stats
                .queued_by_priority
                .get(&MessagePriority::Normal),
            Some(&2)
        );
        assert_eq!(
            initial_stats.queued_by_priority.get(&MessagePriority::High),
            Some(&1)
        );

        let dequeued = mailbox.dequeue().await.unwrap();
        assert_eq!(dequeued.len(), 3);

        let after_dequeue_stats = mailbox.status().await.unwrap();
        assert_eq!(after_dequeue_stats.queued_messages, 0);
        assert_eq!(after_dequeue_stats.inflight_messages, 3);

        mailbox.ack(dequeued[0].id).await.unwrap();

        let final_stats = mailbox.status().await.unwrap();
        assert_eq!(final_stats.queued_messages, 0);
        assert_eq!(final_stats.inflight_messages, 2);
    }
}