kojin_core/
memory_broker.rs1use async_trait::async_trait;
2use std::collections::{HashMap, VecDeque};
3use std::sync::Arc;
4use tokio::sync::{Mutex, Notify};
5
6use crate::broker::Broker;
7use crate::error::TaskResult;
8use crate::message::TaskMessage;
9use crate::task_id::TaskId;
10
11#[derive(Clone)]
13pub struct MemoryBroker {
14 inner: Arc<MemoryBrokerInner>,
15}
16
17struct MemoryBrokerInner {
18 queues: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
19 dlq: Mutex<HashMap<String, VecDeque<TaskMessage>>>,
20 processing: Mutex<HashMap<TaskId, TaskMessage>>,
21 notify: Notify,
22}
23
24impl MemoryBroker {
25 pub fn new() -> Self {
26 Self {
27 inner: Arc::new(MemoryBrokerInner {
28 queues: Mutex::new(HashMap::new()),
29 dlq: Mutex::new(HashMap::new()),
30 processing: Mutex::new(HashMap::new()),
31 notify: Notify::new(),
32 }),
33 }
34 }
35
36 pub async fn dlq_len(&self, queue: &str) -> usize {
38 let dlq = self.inner.dlq.lock().await;
39 dlq.get(queue).map_or(0, |q| q.len())
40 }
41}
42
43impl Default for MemoryBroker {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49#[async_trait]
50impl Broker for MemoryBroker {
51 async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
52 let mut queues = self.inner.queues.lock().await;
53 queues
54 .entry(message.queue.clone())
55 .or_default()
56 .push_back(message);
57 self.inner.notify.notify_one();
58 Ok(())
59 }
60
61 async fn dequeue(
62 &self,
63 queues: &[String],
64 timeout: std::time::Duration,
65 ) -> TaskResult<Option<TaskMessage>> {
66 let deadline = tokio::time::Instant::now() + timeout;
67
68 loop {
69 {
71 let mut q = self.inner.queues.lock().await;
72 for queue_name in queues {
73 if let Some(queue) = q.get_mut(queue_name) {
74 if let Some(msg) = queue.pop_front() {
75 self.inner
77 .processing
78 .lock()
79 .await
80 .insert(msg.id, msg.clone());
81 return Ok(Some(msg));
82 }
83 }
84 }
85 }
86
87 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
89 if remaining.is_zero() {
90 return Ok(None);
91 }
92
93 tokio::select! {
94 _ = self.inner.notify.notified() => continue,
95 _ = tokio::time::sleep(remaining) => return Ok(None),
96 }
97 }
98 }
99
100 async fn ack(&self, id: &TaskId) -> TaskResult<()> {
101 self.inner.processing.lock().await.remove(id);
102 Ok(())
103 }
104
105 async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
106 self.inner.processing.lock().await.remove(&message.id);
107 self.enqueue(message).await
109 }
110
111 async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
112 self.inner.processing.lock().await.remove(&message.id);
113 let dlq_name = message.queue.clone();
114 let mut dlq = self.inner.dlq.lock().await;
115 dlq.entry(dlq_name).or_default().push_back(message);
116 Ok(())
117 }
118
119 async fn schedule(
120 &self,
121 message: TaskMessage,
122 _eta: chrono::DateTime<chrono::Utc>,
123 ) -> TaskResult<()> {
124 self.enqueue(message).await
126 }
127
128 async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
129 let queues = self.inner.queues.lock().await;
130 Ok(queues.get(queue).map_or(0, |q| q.len()))
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Dequeues from `queues` (in priority order) with a 1s timeout, panicking
    /// if nothing arrives in time.
    async fn dequeue_one(broker: &MemoryBroker, queues: &[&str]) -> TaskMessage {
        let queues: Vec<String> = queues.iter().map(|q| q.to_string()).collect();
        broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .expect("expected a message before the dequeue timeout")
    }

    #[tokio::test]
    async fn enqueue_dequeue_fifo() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("task2", "default", serde_json::json!(2)))
            .await
            .unwrap();

        let out1 = dequeue_one(&broker, &["default"]).await;
        let out2 = dequeue_one(&broker, &["default"]).await;

        assert_eq!(out1.task_name, "task1");
        assert_eq!(out2.task_name, "task2");
    }

    #[tokio::test]
    async fn dequeue_respects_queue_order() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("low_task", "low", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("high_task", "high", serde_json::json!(2)))
            .await
            .unwrap();

        // "high" is listed first, so it wins even though "low" was filled first.
        let first = dequeue_one(&broker, &["high", "low"]).await;
        let second = dequeue_one(&broker, &["high", "low"]).await;

        assert_eq!(first.task_name, "high_task");
        assert_eq!(second.task_name, "low_task");
    }

    #[tokio::test]
    async fn dequeue_timeout_returns_none() {
        let broker = MemoryBroker::new();
        let queues = vec!["default".to_string()];
        let result = broker
            .dequeue(&queues, Duration::from_millis(50))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn dequeue_wakes_on_later_enqueue() {
        let broker = MemoryBroker::new();
        let producer = broker.clone();
        let handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            producer
                .enqueue(TaskMessage::new("late", "default", serde_json::json!(1)))
                .await
                .unwrap();
        });

        // The consumer starts blocked on an empty queue and must be woken.
        let out = dequeue_one(&broker, &["default"]).await;
        assert_eq!(out.task_name, "late");
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn ack_removes_from_processing() {
        let broker = MemoryBroker::new();
        let msg = TaskMessage::new("task1", "default", serde_json::json!(1));
        let id = msg.id;
        broker.enqueue(msg).await.unwrap();

        let _out = dequeue_one(&broker, &["default"]).await;
        assert!(broker.inner.processing.lock().await.contains_key(&id));

        broker.ack(&id).await.unwrap();
        assert!(!broker.inner.processing.lock().await.contains_key(&id));
    }

    #[tokio::test]
    async fn nack_requeues() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let out = dequeue_one(&broker, &["default"]).await;
        let id = out.id;
        broker.nack(out).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
        assert!(!broker.inner.processing.lock().await.contains_key(&id));
    }

    #[tokio::test]
    async fn dead_letter() {
        let broker = MemoryBroker::new();
        broker
            .enqueue(TaskMessage::new("task1", "default", serde_json::json!(1)))
            .await
            .unwrap();

        let out = dequeue_one(&broker, &["default"]).await;
        let id = out.id;
        broker.dead_letter(out).await.unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);
        assert_eq!(broker.dlq_len("default").await, 1);
        assert!(!broker.inner.processing.lock().await.contains_key(&id));
    }

    #[tokio::test]
    async fn queue_len() {
        let broker = MemoryBroker::new();
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(2)))
            .await
            .unwrap();
        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }
}