Skip to main content

kojin_core/
memory_broker.rs

1use async_trait::async_trait;
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use tokio::sync::{Mutex, Notify};
5
6use crate::broker::Broker;
7use crate::error::TaskResult;
8use crate::message::TaskMessage;
9use crate::task_id::TaskId;
10
11/// In-memory broker for testing and development.
12#[derive(Clone)]
13pub struct MemoryBroker {
14    inner: Arc<MemoryBrokerInner>,
15}
16
17struct MemoryBrokerInner {
18    queues: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
19    dlq: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
20    processing: Mutex<HashMap<TaskId, TaskMessage>>,
21    notify: Notify,
22}
23
24impl MemoryBroker {
25    pub fn new() -> Self {
26        Self {
27            inner: Arc::new(MemoryBrokerInner {
28                queues: Mutex::new(HashMap::new()),
29                dlq: Mutex::new(HashMap::new()),
30                processing: Mutex::new(HashMap::new()),
31                notify: Notify::new(),
32            }),
33        }
34    }
35
36    /// Get the dead-letter queue contents for testing.
37    pub async fn dlq_len(&self, queue: &str) -> usize {
38        let dlq = self.inner.dlq.lock().await;
39        dlq.get(queue).map_or(0, |q| q.len())
40    }
41}
42
43impl Default for MemoryBroker {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49#[async_trait]
50impl Broker for MemoryBroker {
51    async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
52        let mut queues = self.inner.queues.lock().await;
53        queues
54            .entry(message.queue.clone())
55            .or_default()
56            .push_back(message);
57        self.inner.notify.notify_one();
58        Ok(())
59    }
60
61    async fn dequeue(
62        &self,
63        queues: &[String],
64        timeout: std::time::Duration,
65    ) -> TaskResult<Option<TaskMessage>> {
66        let deadline = tokio::time::Instant::now() + timeout;
67
68        loop {
69            // Try to pop from any of the requested queues
70            {
71                let mut q = self.inner.queues.lock().await;
72                for queue_name in queues {
73                    if let Some(queue) = q.get_mut(queue_name) {
74                        if let Some(msg) = queue.pop_front() {
75                            // Track in processing
76                            self.inner
77                                .processing
78                                .lock()
79                                .await
80                                .insert(msg.id, msg.clone());
81                            return Ok(Some(msg));
82                        }
83                    }
84                }
85            }
86
87            // Wait for notification or timeout
88            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
89            if remaining.is_zero() {
90                return Ok(None);
91            }
92
93            tokio::select! {
94                _ = self.inner.notify.notified() => continue,
95                _ = tokio::time::sleep(remaining) => return Ok(None),
96            }
97        }
98    }
99
100    async fn ack(&self, id: &TaskId) -> TaskResult<()> {
101        self.inner.processing.lock().await.remove(id);
102        Ok(())
103    }
104
105    async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
106        self.inner.processing.lock().await.remove(&message.id);
107        // Re-enqueue
108        self.enqueue(message).await
109    }
110
111    async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
112        self.inner.processing.lock().await.remove(&message.id);
113        let dlq_name = message.queue.clone();
114        let mut dlq = self.inner.dlq.lock().await;
115        dlq.entry(dlq_name).or_default().push_back(message);
116        Ok(())
117    }
118
119    async fn schedule(
120        &self,
121        message: TaskMessage,
122        _eta: chrono::DateTime<chrono::Utc>,
123    ) -> TaskResult<()> {
124        // For MemoryBroker, just enqueue immediately (no scheduled queue support)
125        self.enqueue(message).await
126    }
127
128    async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
129        let queues = self.inner.queues.lock().await;
130        Ok(queues.get(queue).map_or(0, |q| q.len()))
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds the owned queue-name list that `Broker::dequeue` takes.
    fn queue_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// Enqueues a task named `task` on `queue` with a fixed payload.
    async fn enqueue_task(broker: &MemoryBroker, task: &str, queue: &str) {
        broker
            .enqueue(TaskMessage::new(task, queue, serde_json::json!(1)))
            .await
            .unwrap();
    }

    /// Dequeues one message, failing the test if none arrives within a second.
    async fn dequeue_one(broker: &MemoryBroker, queues: &[String]) -> TaskMessage {
        broker
            .dequeue(queues, Duration::from_secs(1))
            .await
            .unwrap()
            .expect("expected a message before the dequeue timeout")
    }

    #[tokio::test]
    async fn enqueue_dequeue_fifo() {
        let broker = MemoryBroker::new();
        enqueue_task(&broker, "task1", "default").await;
        enqueue_task(&broker, "task2", "default").await;

        let queues = queue_list(&["default"]);
        assert_eq!(dequeue_one(&broker, &queues).await.task_name, "task1");
        assert_eq!(dequeue_one(&broker, &queues).await.task_name, "task2");
    }

    #[tokio::test]
    async fn dequeue_prefers_earlier_queues() {
        let broker = MemoryBroker::new();
        enqueue_task(&broker, "low", "low").await;
        enqueue_task(&broker, "high", "high").await;

        let queues = queue_list(&["high", "low"]);
        assert_eq!(dequeue_one(&broker, &queues).await.task_name, "high");
        assert_eq!(dequeue_one(&broker, &queues).await.task_name, "low");
    }

    #[tokio::test]
    async fn dequeue_timeout_returns_none() {
        let broker = MemoryBroker::new();
        let result = broker
            .dequeue(&queue_list(&["default"]), Duration::from_millis(50))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dequeue_wakes_on_enqueue() {
        let broker = MemoryBroker::new();
        let producer = broker.clone();
        let handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            enqueue_task(&producer, "late", "default").await;
        });

        let out = dequeue_one(&broker, &queue_list(&["default"])).await;
        handle.await.unwrap();
        assert_eq!(out.task_name, "late");
    }

    #[tokio::test]
    async fn ack_removes_from_processing() {
        let broker = MemoryBroker::new();
        let msg = TaskMessage::new("task1", "default", serde_json::json!(1));
        let id = msg.id;
        broker.enqueue(msg).await.unwrap();

        dequeue_one(&broker, &queue_list(&["default"])).await;
        assert!(broker.inner.processing.lock().await.contains_key(&id));

        broker.ack(&id).await.unwrap();
        assert!(!broker.inner.processing.lock().await.contains_key(&id));
    }

    #[tokio::test]
    async fn nack_requeues() {
        let broker = MemoryBroker::new();
        enqueue_task(&broker, "task1", "default").await;

        let out = dequeue_one(&broker, &queue_list(&["default"])).await;
        broker.nack(out).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
        assert!(broker.inner.processing.lock().await.is_empty());
    }

    #[tokio::test]
    async fn dead_letter() {
        let broker = MemoryBroker::new();
        enqueue_task(&broker, "task1", "default").await;

        let out = dequeue_one(&broker, &queue_list(&["default"])).await;
        broker.dead_letter(out).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);
        assert_eq!(broker.dlq_len("default").await, 1);
        assert_eq!(broker.dlq_len("other").await, 0);
        assert!(broker.inner.processing.lock().await.is_empty());
    }

    #[tokio::test]
    async fn schedule_enqueues_immediately() {
        let broker = MemoryBroker::new();
        let msg = TaskMessage::new("later", "default", serde_json::json!(1));
        broker.schedule(msg, chrono::Utc::now()).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn queue_len() {
        let broker = MemoryBroker::new();
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(2)))
            .await
            .unwrap();
        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }
}