Skip to main content

alien_bindings/providers/queue/
local.rs

1use crate::error::{ErrorData, Result};
2use crate::traits::{
3    Binding, MessagePayload, Queue, QueueMessage, MAX_BATCH_SIZE, MAX_MESSAGE_BYTES,
4};
5use alien_core::bindings::LocalQueueBinding;
6use alien_error::{AlienError, Context, IntoAlienError};
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13
/// Lease length, in seconds, granted to a message handed out by `receive`.
///
/// Until the lease runs out the message sits in the `in_flight` tree and is not
/// returned by further `receive` calls; once it expires without an `ack`, the
/// message is put back in the queue for redelivery.
const LEASE_DURATION_SECS: i64 = 30;
15
16/// Local disk-persisted queue implementation using sled embedded database.
17///
18/// This provides a persistent, thread-safe, disk-based message queue that implements
19/// all Queue trait features including send, receive with visibility timeout, and ack.
20/// Messages survive process restarts.
21#[derive(Debug)]
22pub struct LocalQueue {
23    db: Arc<Mutex<sled::Db>>,
24    data_dir: PathBuf,
25}
26
27/// Stored message format that avoids serde issues with `MessagePayload`'s internal tagging.
28/// We store the payload as a raw JSON value and a discriminator tag.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct StoredMessage {
31    /// "json" or "text"
32    payload_type: String,
33    /// The raw payload content (JSON value for json type, string for text type)
34    payload_data: serde_json::Value,
35    enqueued_at: DateTime<Utc>,
36}
37
38impl StoredMessage {
39    fn from_payload(payload: MessagePayload) -> Self {
40        let (payload_type, payload_data) = match payload {
41            MessagePayload::Json(v) => ("json".to_string(), v),
42            MessagePayload::Text(s) => ("text".to_string(), serde_json::Value::String(s)),
43        };
44        Self {
45            payload_type,
46            payload_data,
47            enqueued_at: Utc::now(),
48        }
49    }
50
51    fn into_payload(self) -> MessagePayload {
52        match self.payload_type.as_str() {
53            "json" => MessagePayload::Json(self.payload_data),
54            _ => match self.payload_data {
55                serde_json::Value::String(s) => MessagePayload::Text(s),
56                other => MessagePayload::Text(other.to_string()),
57            },
58        }
59    }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63struct InFlightMessage {
64    /// The sequence key in the messages tree (big-endian u64 bytes)
65    seq_bytes: Vec<u8>,
66    message: StoredMessage,
67    leased_until: DateTime<Utc>,
68}
69
70impl LocalQueue {
71    /// Create a new local queue store with the given data directory.
72    pub async fn new(data_dir: PathBuf) -> Result<Self> {
73        tracing::debug!(data_dir = %data_dir.display(), "Opening LocalQueue database");
74
75        if let Some(parent) = data_dir.parent() {
76            tokio::fs::create_dir_all(parent)
77                .await
78                .into_alien_error()
79                .context(ErrorData::LocalFilesystemError {
80                    path: parent.to_string_lossy().to_string(),
81                    operation: "create_dir_all".to_string(),
82                })?;
83        }
84
85        let db =
86            sled::open(&data_dir)
87                .into_alien_error()
88                .context(ErrorData::BindingSetupFailed {
89                    binding_type: "local queue".to_string(),
90                    reason: format!("Failed to open sled database at: {:?}", data_dir),
91                })?;
92
93        tracing::debug!(data_dir = %data_dir.display(), "LocalQueue database opened successfully");
94
95        Ok(Self {
96            db: Arc::new(Mutex::new(db)),
97            data_dir,
98        })
99    }
100
101    /// Create a LocalQueue from a LocalQueueBinding.
102    pub async fn from_binding(binding: LocalQueueBinding) -> Result<Self> {
103        let queue_path = binding
104            .queue_path
105            .into_value("queue", "queue_path")
106            .context(ErrorData::BindingConfigInvalid {
107                binding_name: "queue".to_string(),
108                reason: "Failed to resolve queue_path from binding".to_string(),
109            })?;
110
111        Self::new(PathBuf::from(queue_path)).await
112    }
113
114    /// Reclaim expired in-flight messages back to the messages tree.
115    fn reclaim_expired_leases(db: &sled::Db) -> Result<()> {
116        let in_flight_tree = db.open_tree("in_flight").into_alien_error().context(
117            ErrorData::QueueOperationFailed {
118                operation: "open in_flight tree".to_string(),
119                reason: "Failed to open in_flight tree".to_string(),
120            },
121        )?;
122
123        let messages_tree = db.open_tree("messages").into_alien_error().context(
124            ErrorData::QueueOperationFailed {
125                operation: "open messages tree".to_string(),
126                reason: "Failed to open messages tree".to_string(),
127            },
128        )?;
129
130        let now = Utc::now();
131        let mut expired_handles = Vec::new();
132
133        for result in in_flight_tree.iter() {
134            let (handle_bytes, value_bytes) =
135                result
136                    .into_alien_error()
137                    .context(ErrorData::QueueOperationFailed {
138                        operation: "scan in_flight".to_string(),
139                        reason: "Failed to iterate in-flight messages".to_string(),
140                    })?;
141
142            if let Ok(in_flight) = serde_json::from_slice::<InFlightMessage>(&value_bytes) {
143                if now >= in_flight.leased_until {
144                    // Re-enqueue the message with its original sequence key
145                    let stored_bytes = serde_json::to_vec(&in_flight.message)
146                        .into_alien_error()
147                        .context(ErrorData::QueueOperationFailed {
148                            operation: "serialize reclaimed message".to_string(),
149                            reason: "Failed to serialize message".to_string(),
150                        })?;
151
152                    messages_tree
153                        .insert(&in_flight.seq_bytes, stored_bytes)
154                        .into_alien_error()
155                        .context(ErrorData::QueueOperationFailed {
156                            operation: "re-enqueue expired message".to_string(),
157                            reason: "Failed to re-enqueue expired message".to_string(),
158                        })?;
159
160                    expired_handles.push(handle_bytes);
161                }
162            }
163        }
164
165        for handle in expired_handles {
166            let _ = in_flight_tree.remove(&handle);
167        }
168
169        Ok(())
170    }
171
172    fn serialize_message(message: &StoredMessage) -> Result<Vec<u8>> {
173        serde_json::to_vec(message)
174            .into_alien_error()
175            .context(ErrorData::QueueOperationFailed {
176                operation: "serialize message".to_string(),
177                reason: "Failed to serialize message to JSON".to_string(),
178            })
179    }
180
181    fn message_size(payload: &MessagePayload) -> Result<usize> {
182        match payload {
183            MessagePayload::Json(v) => serde_json::to_string(v)
184                .map(|s| s.len())
185                .into_alien_error()
186                .context(ErrorData::QueueOperationFailed {
187                    operation: "measure message size".to_string(),
188                    reason: "Failed to serialize JSON payload".to_string(),
189                }),
190            MessagePayload::Text(s) => Ok(s.len()),
191        }
192    }
193}
194
195impl Binding for LocalQueue {}
196
197#[async_trait]
198impl Queue for LocalQueue {
199    async fn send(&self, _queue: &str, message: MessagePayload) -> Result<()> {
200        let size = Self::message_size(&message)?;
201        if size > MAX_MESSAGE_BYTES {
202            return Err(AlienError::new(ErrorData::BindingSetupFailed {
203                binding_type: "queue.local".to_string(),
204                reason: format!(
205                    "Message size {} bytes exceeds limit of {} bytes",
206                    size, MAX_MESSAGE_BYTES
207                ),
208            }));
209        }
210
211        let stored = StoredMessage::from_payload(message);
212        let serialized = Self::serialize_message(&stored)?;
213
214        let db = self.db.lock().await;
215        let messages_tree = db.open_tree("messages").into_alien_error().context(
216            ErrorData::QueueOperationFailed {
217                operation: "open messages tree".to_string(),
218                reason: "Failed to open messages tree".to_string(),
219            },
220        )?;
221
222        // Use generate_id for monotonically increasing sequence numbers
223        let seq = db
224            .generate_id()
225            .into_alien_error()
226            .context(ErrorData::QueueOperationFailed {
227                operation: "generate sequence".to_string(),
228                reason: "Failed to generate message sequence number".to_string(),
229            })?;
230        let seq_key = seq.to_be_bytes();
231
232        messages_tree
233            .insert(seq_key, serialized)
234            .into_alien_error()
235            .context(ErrorData::QueueOperationFailed {
236                operation: "send".to_string(),
237                reason: "Failed to insert message".to_string(),
238            })?;
239
240        messages_tree
241            .flush_async()
242            .await
243            .into_alien_error()
244            .context(ErrorData::QueueOperationFailed {
245                operation: "flush".to_string(),
246                reason: "Failed to flush message to disk".to_string(),
247            })?;
248
249        Ok(())
250    }
251
252    async fn receive(&self, _queue: &str, max_messages: usize) -> Result<Vec<QueueMessage>> {
253        if max_messages == 0 || max_messages > MAX_BATCH_SIZE {
254            return Err(AlienError::new(ErrorData::BindingSetupFailed {
255                binding_type: "queue.local".to_string(),
256                reason: format!(
257                    "Batch size {} is invalid. Must be between 1 and {}",
258                    max_messages, MAX_BATCH_SIZE
259                ),
260            }));
261        }
262
263        let db = self.db.lock().await;
264
265        // Reclaim expired leases first
266        Self::reclaim_expired_leases(&db)?;
267
268        let messages_tree = db.open_tree("messages").into_alien_error().context(
269            ErrorData::QueueOperationFailed {
270                operation: "open messages tree".to_string(),
271                reason: "Failed to open messages tree".to_string(),
272            },
273        )?;
274
275        let in_flight_tree = db.open_tree("in_flight").into_alien_error().context(
276            ErrorData::QueueOperationFailed {
277                operation: "open in_flight tree".to_string(),
278                reason: "Failed to open in_flight tree".to_string(),
279            },
280        )?;
281
282        let now = Utc::now();
283        let leased_until = now + chrono::Duration::seconds(LEASE_DURATION_SECS);
284        let mut result = Vec::new();
285
286        // Pop messages from the front (lowest sequence number)
287        for item in messages_tree.iter() {
288            if result.len() >= max_messages {
289                break;
290            }
291
292            let (seq_key, value_bytes) =
293                item.into_alien_error()
294                    .context(ErrorData::QueueOperationFailed {
295                        operation: "receive".to_string(),
296                        reason: "Failed to iterate messages".to_string(),
297                    })?;
298
299            let stored: StoredMessage = match serde_json::from_slice(&value_bytes) {
300                Ok(m) => m,
301                Err(_) => continue, // Skip corrupted messages
302            };
303
304            // Generate a receipt handle
305            let receipt_handle = uuid::Uuid::new_v4().to_string();
306
307            // Move to in-flight
308            let in_flight = InFlightMessage {
309                seq_bytes: seq_key.to_vec(),
310                message: stored.clone(),
311                leased_until,
312            };
313            let in_flight_bytes = serde_json::to_vec(&in_flight).into_alien_error().context(
314                ErrorData::QueueOperationFailed {
315                    operation: "serialize in_flight".to_string(),
316                    reason: "Failed to serialize in-flight message".to_string(),
317                },
318            )?;
319
320            in_flight_tree
321                .insert(receipt_handle.as_bytes(), in_flight_bytes)
322                .into_alien_error()
323                .context(ErrorData::QueueOperationFailed {
324                    operation: "move to in_flight".to_string(),
325                    reason: "Failed to move message to in-flight".to_string(),
326                })?;
327
328            // Remove from messages
329            messages_tree.remove(&seq_key).into_alien_error().context(
330                ErrorData::QueueOperationFailed {
331                    operation: "remove from messages".to_string(),
332                    reason: "Failed to remove message from queue".to_string(),
333                },
334            )?;
335
336            result.push(QueueMessage {
337                payload: stored.into_payload(),
338                receipt_handle,
339            });
340        }
341
342        // Flush both trees
343        messages_tree
344            .flush_async()
345            .await
346            .into_alien_error()
347            .context(ErrorData::QueueOperationFailed {
348                operation: "flush".to_string(),
349                reason: "Failed to flush messages tree".to_string(),
350            })?;
351        in_flight_tree
352            .flush_async()
353            .await
354            .into_alien_error()
355            .context(ErrorData::QueueOperationFailed {
356                operation: "flush".to_string(),
357                reason: "Failed to flush in_flight tree".to_string(),
358            })?;
359
360        Ok(result)
361    }
362
363    async fn ack(&self, _queue: &str, receipt_handle: &str) -> Result<()> {
364        let db = self.db.lock().await;
365        let in_flight_tree = db.open_tree("in_flight").into_alien_error().context(
366            ErrorData::QueueOperationFailed {
367                operation: "open in_flight tree".to_string(),
368                reason: "Failed to open in_flight tree".to_string(),
369            },
370        )?;
371
372        // Remove the message (idempotent - missing key is OK)
373        in_flight_tree
374            .remove(receipt_handle.as_bytes())
375            .into_alien_error()
376            .context(ErrorData::QueueOperationFailed {
377                operation: "ack".to_string(),
378                reason: "Failed to acknowledge message".to_string(),
379            })?;
380
381        in_flight_tree
382            .flush_async()
383            .await
384            .into_alien_error()
385            .context(ErrorData::QueueOperationFailed {
386                operation: "flush".to_string(),
387                reason: "Failed to flush acknowledgment".to_string(),
388            })?;
389
390        Ok(())
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Renders a received payload as a string for easy comparison.
    fn payload_text(msg: &QueueMessage) -> String {
        match &msg.payload {
            MessagePayload::Text(s) => s.clone(),
            MessagePayload::Json(v) => v.to_string(),
        }
    }

    /// Opens a fresh queue in a temporary directory. The `TempDir` must be kept
    /// alive for as long as the queue is used.
    async fn create_test_queue() -> (LocalQueue, TempDir) {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let queue = LocalQueue::new(temp_dir.path().join("queue.db"))
            .await
            .expect("Failed to create LocalQueue");
        (queue, temp_dir)
    }

    #[tokio::test]
    async fn test_send_and_receive() {
        let (queue, _temp_dir) = create_test_queue().await;

        queue
            .send("q", MessagePayload::Text("hello".to_string()))
            .await
            .unwrap();
        queue
            .send("q", MessagePayload::Text("world".to_string()))
            .await
            .unwrap();

        let msgs = queue.receive("q", 10).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(payload_text(&msgs[0]), "hello");
        assert_eq!(payload_text(&msgs[1]), "world");
    }

    #[tokio::test]
    async fn test_receive_empty_queue() {
        let (queue, _temp_dir) = create_test_queue().await;

        let msgs = queue.receive("q", 10).await.unwrap();
        assert!(msgs.is_empty());
    }

    #[tokio::test]
    async fn test_ack_removes_message() {
        let (queue, _temp_dir) = create_test_queue().await;

        queue
            .send("q", MessagePayload::Text("msg".to_string()))
            .await
            .unwrap();

        let msgs = queue.receive("q", 1).await.unwrap();
        assert_eq!(msgs.len(), 1);

        // Ack the message
        queue.ack("q", &msgs[0].receipt_handle).await.unwrap();

        // No messages should be available (acked, not expired)
        let msgs = queue.receive("q", 10).await.unwrap();
        assert!(msgs.is_empty());
    }

    #[tokio::test]
    async fn test_ack_idempotent() {
        let (queue, _temp_dir) = create_test_queue().await;

        // Acking a non-existent receipt handle should succeed
        queue.ack("q", "non-existent-handle").await.unwrap();
    }

    #[tokio::test]
    async fn test_in_flight_message_not_redelivered_before_lease_expiry() {
        let (queue, _temp_dir) = create_test_queue().await;

        queue
            .send("q", MessagePayload::Text("leased".to_string()))
            .await
            .unwrap();

        let first = queue.receive("q", 1).await.unwrap();
        assert_eq!(first.len(), 1);

        // Not acked, but the lease is still live, so it must stay hidden.
        let second = queue.receive("q", 10).await.unwrap();
        assert!(second.is_empty());
    }

    #[tokio::test]
    async fn test_expired_lease_is_redelivered() {
        let (queue, _temp_dir) = create_test_queue().await;

        queue
            .send("q", MessagePayload::Text("retry".to_string()))
            .await
            .unwrap();

        let first = queue.receive("q", 1).await.unwrap();
        assert_eq!(first.len(), 1);

        // Force the lease to expire by rewriting the in-flight record in place.
        {
            let db = queue.db.lock().await;
            let in_flight_tree = db.open_tree("in_flight").unwrap();
            let key = first[0].receipt_handle.as_bytes();
            let bytes = in_flight_tree
                .get(key)
                .unwrap()
                .expect("received message should have an in-flight record");
            let mut in_flight: InFlightMessage = serde_json::from_slice(&bytes).unwrap();
            in_flight.leased_until = Utc::now() - chrono::Duration::seconds(1);
            in_flight_tree
                .insert(key, serde_json::to_vec(&in_flight).unwrap())
                .unwrap();
        }

        let second = queue.receive("q", 1).await.unwrap();
        assert_eq!(second.len(), 1);
        assert_eq!(payload_text(&second[0]), "retry");
        assert_ne!(second[0].receipt_handle, first[0].receipt_handle);

        // Acking the stale handle is a no-op; the redelivered lease stays intact.
        queue.ack("q", &first[0].receipt_handle).await.unwrap();
        queue.ack("q", &second[0].receipt_handle).await.unwrap();
        assert!(queue.receive("q", 10).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_receive_respects_max_messages() {
        let (queue, _temp_dir) = create_test_queue().await;

        for i in 0..5 {
            queue
                .send("q", MessagePayload::Text(format!("msg-{}", i)))
                .await
                .unwrap();
        }

        let msgs = queue.receive("q", 2).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(payload_text(&msgs[0]), "msg-0");
        assert_eq!(payload_text(&msgs[1]), "msg-1");
    }

    #[tokio::test]
    async fn test_json_payload() {
        let (queue, _temp_dir) = create_test_queue().await;

        let payload = serde_json::json!({"key": "value", "num": 42});
        queue
            .send("q", MessagePayload::Json(payload.clone()))
            .await
            .unwrap();

        let msgs = queue.receive("q", 1).await.unwrap();
        assert_eq!(msgs.len(), 1);
        match &msgs[0].payload {
            MessagePayload::Json(v) => assert_eq!(v, &payload),
            _ => panic!("Expected JSON payload"),
        }
    }

    #[tokio::test]
    async fn test_message_size_validation() {
        let (queue, _temp_dir) = create_test_queue().await;

        let large = "x".repeat(MAX_MESSAGE_BYTES + 1);
        let result = queue.send("q", MessagePayload::Text(large)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_batch_size_validation() {
        let (queue, _temp_dir) = create_test_queue().await;

        assert!(queue.receive("q", 0).await.is_err());
        assert!(queue.receive("q", MAX_BATCH_SIZE + 1).await.is_err());
    }

    #[tokio::test]
    async fn test_persistence_across_reopens() {
        let temp_dir = tempfile::tempdir().expect("Failed to create temp dir");
        let db_path = temp_dir.path().join("queue.db");

        // Send a message and drop the queue
        {
            let queue = LocalQueue::new(db_path.clone()).await.unwrap();
            queue
                .send("q", MessagePayload::Text("persistent".to_string()))
                .await
                .unwrap();
        }

        // Reopen and verify message persists
        {
            let queue = LocalQueue::new(db_path).await.unwrap();
            let msgs = queue.receive("q", 1).await.unwrap();
            assert_eq!(msgs.len(), 1);
            assert_eq!(payload_text(&msgs[0]), "persistent");
        }
    }

    #[tokio::test]
    async fn test_fifo_ordering() {
        let (queue, _temp_dir) = create_test_queue().await;

        for i in 0..10 {
            queue
                .send("q", MessagePayload::Text(format!("{}", i)))
                .await
                .unwrap();
        }

        let msgs = queue.receive("q", 10).await.unwrap();
        // Without this, a short batch would make the ordering loop pass vacuously.
        assert_eq!(msgs.len(), 10);
        for (i, msg) in msgs.iter().enumerate() {
            assert_eq!(payload_text(msg), format!("{}", i));
        }
    }
}