kojin_core/
memory_broker.rs1use async_trait::async_trait;
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use tokio::sync::{Mutex, Notify};
5
6use crate::broker::Broker;
7use crate::error::TaskResult;
8use crate::message::TaskMessage;
9use crate::task_id::TaskId;
10
11#[derive(Clone)]
13pub struct MemoryBroker {
14 inner: Arc<MemoryBrokerInner>,
15}
16
17struct MemoryBrokerInner {
18 queues: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
19 dlq: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
20 processing: Mutex<HashMap<TaskId, TaskMessage>>,
21 notify: Notify,
22}
23
24impl MemoryBroker {
25 pub fn new() -> Self {
26 Self {
27 inner: Arc::new(MemoryBrokerInner {
28 queues: Mutex::new(HashMap::new()),
29 dlq: Mutex::new(HashMap::new()),
30 processing: Mutex::new(HashMap::new()),
31 notify: Notify::new(),
32 }),
33 }
34 }
35}
36
37impl Default for MemoryBroker {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43#[async_trait]
44impl Broker for MemoryBroker {
45 async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
46 let mut queues = self.inner.queues.lock().await;
47 queues
48 .entry(message.queue.clone())
49 .or_default()
50 .push_back(message);
51 self.inner.notify.notify_one();
52 Ok(())
53 }
54
55 async fn dequeue(
56 &self,
57 queues: &[String],
58 timeout: std::time::Duration,
59 ) -> TaskResult<Option<TaskMessage>> {
60 let deadline = tokio::time::Instant::now() + timeout;
61
62 loop {
63 {
65 let mut q = self.inner.queues.lock().await;
66 for queue_name in queues {
67 if let Some(queue) = q.get_mut(queue_name) {
68 if let Some(msg) = queue.pop_front() {
69 self.inner
71 .processing
72 .lock()
73 .await
74 .insert(msg.id, msg.clone());
75 return Ok(Some(msg));
76 }
77 }
78 }
79 }
80
81 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
83 if remaining.is_zero() {
84 return Ok(None);
85 }
86
87 tokio::select! {
88 _ = self.inner.notify.notified() => continue,
89 _ = tokio::time::sleep(remaining) => return Ok(None),
90 }
91 }
92 }
93
94 async fn ack(&self, id: &TaskId) -> TaskResult<()> {
95 self.inner.processing.lock().await.remove(id);
96 Ok(())
97 }
98
99 async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
100 self.inner.processing.lock().await.remove(&message.id);
101 self.enqueue(message).await
103 }
104
105 async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
106 self.inner.processing.lock().await.remove(&message.id);
107 let dlq_name = message.queue.clone();
108 let mut dlq = self.inner.dlq.lock().await;
109 dlq.entry(dlq_name).or_default().push_back(message);
110 Ok(())
111 }
112
113 async fn schedule(
114 &self,
115 message: TaskMessage,
116 _eta: chrono::DateTime<chrono::Utc>,
117 ) -> TaskResult<()> {
118 self.enqueue(message).await
120 }
121
122 async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
123 let queues = self.inner.queues.lock().await;
124 Ok(queues.get(queue).map_or(0, |q| q.len()))
125 }
126
127 async fn dlq_len(&self, queue: &str) -> TaskResult<usize> {
128 let dlq = self.inner.dlq.lock().await;
129 Ok(dlq.get(queue).map_or(0, |q| q.len()))
130 }
131
132 async fn list_queues(&self) -> TaskResult<Vec<String>> {
133 let queues = self.inner.queues.lock().await;
134 Ok(queues.keys().cloned().collect())
135 }
136
137 async fn dlq_messages(
138 &self,
139 queue: &str,
140 offset: usize,
141 limit: usize,
142 ) -> TaskResult<Vec<TaskMessage>> {
143 let dlq = self.inner.dlq.lock().await;
144 Ok(dlq
145 .get(queue)
146 .map(|q| q.iter().skip(offset).take(limit).cloned().collect())
147 .unwrap_or_default())
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Dequeues the next message from the `default` queue, panicking if none
    /// arrives within one second.
    async fn take_default(broker: &MemoryBroker) -> TaskMessage {
        let names = vec!["default".to_string()];
        broker
            .dequeue(&names, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test]
    async fn enqueue_dequeue_fifo() {
        let broker = MemoryBroker::new();
        let first = TaskMessage::new("task1", "default", serde_json::json!(1));
        let second = TaskMessage::new("task2", "default", serde_json::json!(2));

        broker.enqueue(first).await.unwrap();
        broker.enqueue(second).await.unwrap();

        let head = take_default(&broker).await;
        let next = take_default(&broker).await;

        assert_eq!(head.task_name, "task1");
        assert_eq!(next.task_name, "task2");
    }

    #[tokio::test]
    async fn dequeue_timeout_returns_none() {
        let broker = MemoryBroker::new();
        let names = vec!["default".to_string()];
        let got = broker
            .dequeue(&names, Duration::from_millis(50))
            .await
            .unwrap();
        assert!(got.is_none());
    }

    #[tokio::test]
    async fn ack_removes_from_processing() {
        let broker = MemoryBroker::new();
        let message = TaskMessage::new("task1", "default", serde_json::json!(1));
        let task_id = message.id;
        broker.enqueue(message).await.unwrap();

        let _in_flight = take_default(&broker).await;
        assert!(broker.inner.processing.lock().await.contains_key(&task_id));

        broker.ack(&task_id).await.unwrap();
        assert!(!broker.inner.processing.lock().await.contains_key(&task_id));
    }

    #[tokio::test]
    async fn nack_requeues() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let in_flight = take_default(&broker).await;
        broker.nack(in_flight).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn dead_letter() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let in_flight = take_default(&broker).await;
        broker.dead_letter(in_flight).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);
        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn queue_len() {
        let broker = MemoryBroker::new();
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        let a = TaskMessage::new("t", "default", serde_json::json!(1));
        let b = TaskMessage::new("t", "default", serde_json::json!(2));
        broker.enqueue(a).await.unwrap();
        broker.enqueue(b).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn list_queues_returns_known_queues() {
        let broker = MemoryBroker::new();
        let email = TaskMessage::new("t", "emails", serde_json::json!(1));
        let notice = TaskMessage::new("t", "notifications", serde_json::json!(2));
        broker.enqueue(email).await.unwrap();
        broker.enqueue(notice).await.unwrap();

        let mut known = broker.list_queues().await.unwrap();
        known.sort();
        assert_eq!(known, vec!["emails", "notifications"]);
    }

    #[tokio::test]
    async fn dlq_len_via_trait() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let in_flight = take_default(&broker).await;
        broker.dead_letter(in_flight).await.unwrap();

        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn dlq_messages_pagination() {
        let broker = MemoryBroker::new();

        // Push five messages straight through to the dead-letter queue.
        for payload in 0..5 {
            broker
                .enqueue(TaskMessage::new("task", "default", serde_json::json!(payload)))
                .await
                .unwrap();
            let in_flight = take_default(&broker).await;
            broker.dead_letter(in_flight).await.unwrap();
        }

        let first_page = broker.dlq_messages("default", 0, 3).await.unwrap();
        assert_eq!(first_page.len(), 3);

        let second_page = broker.dlq_messages("default", 3, 3).await.unwrap();
        assert_eq!(second_page.len(), 2);
    }

    #[tokio::test]
    async fn dlq_messages_empty() {
        let broker = MemoryBroker::new();
        let found = broker.dlq_messages("nonexistent", 0, 10).await.unwrap();
        assert!(found.is_empty());
    }
}