Skip to main content

kojin_core/
memory_broker.rs

1use async_trait::async_trait;
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use tokio::sync::{Mutex, Notify};
5
6use crate::broker::Broker;
7use crate::error::TaskResult;
8use crate::message::TaskMessage;
9use crate::task_id::TaskId;
10
11/// In-memory broker for testing and development.
12#[derive(Clone)]
13pub struct MemoryBroker {
14    inner: Arc<MemoryBrokerInner>,
15}
16
17struct MemoryBrokerInner {
18    queues: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
19    dlq: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
20    processing: Mutex<HashMap<TaskId, TaskMessage>>,
21    notify: Notify,
22}
23
24impl MemoryBroker {
25    pub fn new() -> Self {
26        Self {
27            inner: Arc::new(MemoryBrokerInner {
28                queues: Mutex::new(HashMap::new()),
29                dlq: Mutex::new(HashMap::new()),
30                processing: Mutex::new(HashMap::new()),
31                notify: Notify::new(),
32            }),
33        }
34    }
35}
36
37impl Default for MemoryBroker {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43#[async_trait]
44impl Broker for MemoryBroker {
45    async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
46        let mut queues = self.inner.queues.lock().await;
47        queues
48            .entry(message.queue.clone())
49            .or_default()
50            .push_back(message);
51        self.inner.notify.notify_one();
52        Ok(())
53    }
54
55    async fn dequeue(
56        &self,
57        queues: &[String],
58        timeout: std::time::Duration,
59    ) -> TaskResult<Option<TaskMessage>> {
60        let deadline = tokio::time::Instant::now() + timeout;
61
62        loop {
63            // Try to pop from any of the requested queues
64            {
65                let mut q = self.inner.queues.lock().await;
66                for queue_name in queues {
67                    if let Some(queue) = q.get_mut(queue_name) {
68                        if let Some(msg) = queue.pop_front() {
69                            // Track in processing
70                            self.inner
71                                .processing
72                                .lock()
73                                .await
74                                .insert(msg.id, msg.clone());
75                            return Ok(Some(msg));
76                        }
77                    }
78                }
79            }
80
81            // Wait for notification or timeout
82            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
83            if remaining.is_zero() {
84                return Ok(None);
85            }
86
87            tokio::select! {
88                _ = self.inner.notify.notified() => continue,
89                _ = tokio::time::sleep(remaining) => return Ok(None),
90            }
91        }
92    }
93
94    async fn ack(&self, id: &TaskId) -> TaskResult<()> {
95        self.inner.processing.lock().await.remove(id);
96        Ok(())
97    }
98
99    async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
100        self.inner.processing.lock().await.remove(&message.id);
101        // Re-enqueue
102        self.enqueue(message).await
103    }
104
105    async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
106        self.inner.processing.lock().await.remove(&message.id);
107        let dlq_name = message.queue.clone();
108        let mut dlq = self.inner.dlq.lock().await;
109        dlq.entry(dlq_name).or_default().push_back(message);
110        Ok(())
111    }
112
113    async fn schedule(
114        &self,
115        message: TaskMessage,
116        _eta: chrono::DateTime<chrono::Utc>,
117    ) -> TaskResult<()> {
118        // For MemoryBroker, just enqueue immediately (no scheduled queue support)
119        self.enqueue(message).await
120    }
121
122    async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
123        let queues = self.inner.queues.lock().await;
124        Ok(queues.get(queue).map_or(0, |q| q.len()))
125    }
126
127    async fn dlq_len(&self, queue: &str) -> TaskResult<usize> {
128        let dlq = self.inner.dlq.lock().await;
129        Ok(dlq.get(queue).map_or(0, |q| q.len()))
130    }
131
132    async fn list_queues(&self) -> TaskResult<Vec<String>> {
133        let queues = self.inner.queues.lock().await;
134        Ok(queues.keys().cloned().collect())
135    }
136
137    async fn dlq_messages(
138        &self,
139        queue: &str,
140        offset: usize,
141        limit: usize,
142    ) -> TaskResult<Vec<TaskMessage>> {
143        let dlq = self.inner.dlq.lock().await;
144        Ok(dlq
145            .get(queue)
146            .map(|q| q.iter().skip(offset).take(limit).cloned().collect())
147            .unwrap_or_default())
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Subscription list containing only the `"default"` queue.
    fn default_queues() -> Vec<String> {
        vec!["default".to_string()]
    }

    /// Dequeues from `queues` with a one-second timeout, panicking if the call
    /// errors or nothing arrives.
    async fn take_one(broker: &MemoryBroker, queues: &[String]) -> TaskMessage {
        broker
            .dequeue(queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test]
    async fn enqueue_dequeue_fifo() {
        let broker = MemoryBroker::new();
        for (name, payload) in [("task1", 1), ("task2", 2)] {
            broker
                .enqueue(TaskMessage::new(name, "default", serde_json::json!(payload)))
                .await
                .unwrap();
        }

        let subscribed = default_queues();
        let first = take_one(&broker, &subscribed).await;
        let second = take_one(&broker, &subscribed).await;

        assert_eq!(first.task_name, "task1");
        assert_eq!(second.task_name, "task2");
    }

    #[tokio::test]
    async fn dequeue_timeout_returns_none() {
        let broker = MemoryBroker::new();
        let subscribed = default_queues();
        let outcome = broker
            .dequeue(&subscribed, Duration::from_millis(50))
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn ack_removes_from_processing() {
        let broker = MemoryBroker::new();
        let message = TaskMessage::new("task1", "default", serde_json::json!(1));
        let id = message.id;
        broker.enqueue(message).await.unwrap();

        let _in_flight = take_one(&broker, &default_queues()).await;
        assert!(broker.inner.processing.lock().await.contains_key(&id));

        broker.ack(&id).await.unwrap();
        assert!(!broker.inner.processing.lock().await.contains_key(&id));
    }

    #[tokio::test]
    async fn nack_requeues() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let taken = take_one(&broker, &default_queues()).await;
        broker.nack(taken).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn dead_letter() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let taken = take_one(&broker, &default_queues()).await;
        broker.dead_letter(taken).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);
        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn queue_len() {
        let broker = MemoryBroker::new();
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        for payload in [1, 2] {
            broker
                .enqueue(TaskMessage::new("t", "default", serde_json::json!(payload)))
                .await
                .unwrap();
        }
        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn list_queues_returns_known_queues() {
        let broker = MemoryBroker::new();
        for (queue, payload) in [("emails", 1), ("notifications", 2)] {
            broker
                .enqueue(TaskMessage::new("t", queue, serde_json::json!(payload)))
                .await
                .unwrap();
        }

        let mut names = broker.list_queues().await.unwrap();
        names.sort();
        assert_eq!(names, vec!["emails", "notifications"]);
    }

    #[tokio::test]
    async fn dlq_len_via_trait() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let taken = take_one(&broker, &default_queues()).await;
        broker.dead_letter(taken).await.unwrap();

        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn dlq_messages_pagination() {
        let broker = MemoryBroker::new();
        let subscribed = default_queues();

        // Push five messages straight through to the dead-letter list.
        for i in 0..5 {
            broker
                .enqueue(TaskMessage::new("task", "default", serde_json::json!(i)))
                .await
                .unwrap();
            let taken = take_one(&broker, &subscribed).await;
            broker.dead_letter(taken).await.unwrap();
        }

        let first_page = broker.dlq_messages("default", 0, 3).await.unwrap();
        assert_eq!(first_page.len(), 3);

        let second_page = broker.dlq_messages("default", 3, 3).await.unwrap();
        assert_eq!(second_page.len(), 2);
    }

    #[tokio::test]
    async fn dlq_messages_empty() {
        let broker = MemoryBroker::new();
        let page = broker.dlq_messages("nonexistent", 0, 10).await.unwrap();
        assert!(page.is_empty());
    }
}