1use async_trait::async_trait;
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use kojin_core::broker::Broker;
7use kojin_core::error::{KojinError, TaskResult};
8use kojin_core::message::TaskMessage;
9use kojin_core::task_id::TaskId;
10
11use crate::config::RedisConfig;
12use crate::keys::KeyBuilder;
13
14fn broker_err(e: impl std::fmt::Display) -> KojinError {
15 KojinError::Broker(e.to_string())
16}
17
18#[derive(Clone)]
20pub struct RedisBroker {
21 pool: Pool,
22 keys: KeyBuilder,
23 worker_id: String,
24 dedup_ttl: u64,
25}
26
27impl RedisBroker {
28 pub async fn new(config: RedisConfig) -> TaskResult<Self> {
30 let cfg = Config::from_url(&config.url);
31 let pool = cfg
32 .builder()
33 .map_err(broker_err)?
34 .max_size(config.pool_size)
35 .runtime(Runtime::Tokio1)
36 .build()
37 .map_err(broker_err)?;
38
39 let _conn = pool.get().await.map_err(broker_err)?;
41
42 let worker_id = uuid::Uuid::now_v7().to_string();
43
44 Ok(Self {
45 pool,
46 keys: KeyBuilder::new(config.key_prefix),
47 worker_id,
48 dedup_ttl: config.dedup_ttl,
49 })
50 }
51
52 async fn conn(&self) -> TaskResult<deadpool_redis::Connection> {
54 self.pool.get().await.map_err(broker_err)
55 }
56
57 pub async fn poll_scheduled(&self) -> TaskResult<usize> {
59 let mut conn = self.conn().await?;
60 let now = chrono::Utc::now().timestamp();
61 let prefix = self.keys.scheduled().replace(":scheduled", "");
62
63 let script = redis::Script::new(crate::scripts::POLL_SCHEDULED_SCRIPT);
64 let count: usize = script
65 .key(self.keys.scheduled())
66 .arg(now)
67 .arg(prefix)
68 .invoke_async(&mut *conn)
69 .await
70 .map_err(broker_err)?;
71
72 Ok(count)
73 }
74}
75
76#[async_trait]
77impl Broker for RedisBroker {
78 async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
79 let mut conn = self.conn().await?;
80
81 if let Some(ref dedup_key) = message.dedup_key {
83 let redis_key = self.keys.dedup(dedup_key);
84 let set: bool = redis::cmd("SET")
85 .arg(&redis_key)
86 .arg(1)
87 .arg("NX")
88 .arg("EX")
89 .arg(self.dedup_ttl)
90 .query_async(&mut *conn)
91 .await
92 .unwrap_or(false);
93
94 if !set {
95 tracing::debug!(dedup_key = %dedup_key, "duplicate task filtered by Redis SET NX");
96 return Ok(());
97 }
98 }
99
100 let queue_key = self.keys.queue(&message.queue);
101 let serialized = serde_json::to_string(&message)?;
102
103 conn.lpush::<_, _, ()>(&queue_key, &serialized)
104 .await
105 .map_err(broker_err)?;
106
107 Ok(())
108 }
109
110 async fn dequeue(
111 &self,
112 queues: &[String],
113 timeout: Duration,
114 ) -> TaskResult<Option<TaskMessage>> {
115 let mut conn = self.conn().await?;
116 let processing_key = self.keys.processing(&self.worker_id);
117
118 for queue_name in queues {
120 let queue_key = self.keys.queue(queue_name);
121 let timeout_secs = timeout.as_secs_f64();
122
123 let result: Option<String> = redis::cmd("BLMOVE")
124 .arg(&queue_key)
125 .arg(&processing_key)
126 .arg("RIGHT")
127 .arg("LEFT")
128 .arg(timeout_secs)
129 .query_async(&mut *conn)
130 .await
131 .map_err(broker_err)?;
132
133 if let Some(data) = result {
134 let message: TaskMessage = serde_json::from_str(&data)?;
135 return Ok(Some(message));
136 }
137 }
138
139 Ok(None)
140 }
141
142 async fn ack(&self, id: &TaskId) -> TaskResult<()> {
143 let mut conn = self.conn().await?;
144 let processing_key = self.keys.processing(&self.worker_id);
145 let id_str = id.to_string();
146
147 let items: Vec<String> = conn
149 .lrange(&processing_key, 0, -1)
150 .await
151 .map_err(broker_err)?;
152
153 for item in items {
154 if item.contains(&id_str) {
155 conn.lrem::<_, _, ()>(&processing_key, 1, &item)
156 .await
157 .map_err(broker_err)?;
158 break;
159 }
160 }
161
162 Ok(())
163 }
164
165 async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
166 self.ack(&message.id).await?;
167 self.enqueue(message).await
168 }
169
170 async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
171 let mut conn = self.conn().await?;
172 self.ack(&message.id).await?;
173
174 let dlq_key = self.keys.dlq(&message.queue);
175 let serialized = serde_json::to_string(&message)?;
176 conn.lpush::<_, _, ()>(&dlq_key, &serialized)
177 .await
178 .map_err(broker_err)?;
179
180 Ok(())
181 }
182
183 async fn schedule(
184 &self,
185 message: TaskMessage,
186 eta: chrono::DateTime<chrono::Utc>,
187 ) -> TaskResult<()> {
188 let mut conn = self.conn().await?;
189 let scheduled_key = self.keys.scheduled();
190 let serialized = serde_json::to_string(&message)?;
191 let score = eta.timestamp() as f64;
192
193 conn.zadd::<_, _, _, ()>(&scheduled_key, &serialized, score)
194 .await
195 .map_err(broker_err)?;
196
197 Ok(())
198 }
199
200 async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
201 let mut conn = self.conn().await?;
202 let queue_key = self.keys.queue(queue);
203 let len: usize = conn.llen(&queue_key).await.map_err(broker_err)?;
204 Ok(len)
205 }
206
207 async fn dlq_len(&self, queue: &str) -> TaskResult<usize> {
208 let mut conn = self.conn().await?;
209 let dlq_key = self.keys.dlq(queue);
210 let len: usize = conn.llen(&dlq_key).await.map_err(broker_err)?;
211 Ok(len)
212 }
213
214 async fn list_queues(&self) -> TaskResult<Vec<String>> {
215 let mut conn = self.conn().await?;
216 let pattern = format!("{}:queue:*", self.keys.prefix());
217 let keys: Vec<String> = redis::cmd("KEYS")
218 .arg(&pattern)
219 .query_async(&mut *conn)
220 .await
221 .map_err(broker_err)?;
222 let prefix = format!("{}:queue:", self.keys.prefix());
223 Ok(keys
224 .into_iter()
225 .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
226 .collect())
227 }
228
229 async fn dlq_messages(
230 &self,
231 queue: &str,
232 offset: usize,
233 limit: usize,
234 ) -> TaskResult<Vec<TaskMessage>> {
235 let mut conn = self.conn().await?;
236 let dlq_key = self.keys.dlq(queue);
237 let end = offset + limit - 1;
238 let items: Vec<String> = conn
239 .lrange(&dlq_key, offset as isize, end as isize)
240 .await
241 .map_err(broker_err)?;
242 items
243 .into_iter()
244 .map(|s| serde_json::from_str(&s).map_err(KojinError::from))
245 .collect()
246 }
247}
248
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::redis::Redis;

    /// Starts a Redis 7 container and connects a broker to it with the `test` key prefix.
    ///
    /// The container handle is returned alongside the broker so each test can keep it
    /// bound for as long as the broker is in use.
    async fn setup_broker() -> (RedisBroker, testcontainers::ContainerAsync<Redis>) {
        let container = Redis::default().with_tag("7").start().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();
        let config = RedisConfig::new(format!("redis://127.0.0.1:{port}")).with_prefix("test");
        let broker = RedisBroker::new(config).await.unwrap();
        (broker, container)
    }

    /// Length of the broker's own processing list.
    ///
    /// The connection is scoped to this call so no pooled connection is held while
    /// the test invokes further broker methods.
    async fn processing_len(broker: &RedisBroker) -> usize {
        let mut conn = broker.conn().await.unwrap();
        conn.llen(broker.keys.processing(&broker.worker_id))
            .await
            .unwrap()
    }

    /// A message pushed with `enqueue` comes back from `dequeue`.
    #[tokio::test]
    async fn enqueue_dequeue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({"key": "value"}));
        broker.enqueue(msg.clone()).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .expect("a message was enqueued");
        assert_eq!(dequeued.task_name, "test_task");
    }

    /// `ack` clears the in-flight entry; `nack` clears it and puts the message back
    /// on its queue.
    #[tokio::test]
    async fn ack_and_nack() {
        let (broker, _container) = setup_broker().await;
        let queues = vec!["default".to_string()];

        // --- ack ---
        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        // Dequeuing parks the message on the processing list ...
        assert_eq!(processing_len(&broker).await, 1);

        broker.ack(&dequeued.id).await.unwrap();
        // ... and acking is what removes it again.
        assert_eq!(processing_len(&broker).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        // --- nack ---
        let msg = TaskMessage::new("retry_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.nack(dequeued).await.unwrap();
        assert_eq!(processing_len(&broker).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);

        let redelivered = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .expect("nacked message should be delivered again");
        assert_eq!(redelivered.task_name, "retry_task");
    }

    /// `dead_letter` stores the message on the queue's dead-letter list.
    #[tokio::test]
    async fn dead_letter_queue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.dead_letter(dequeued).await.unwrap();

        // Inspect Redis directly instead of going through `dlq_len`.
        let mut conn = broker.conn().await.unwrap();
        let dlq_len: usize = conn.llen(broker.keys.dlq("default")).await.unwrap();
        assert_eq!(dlq_len, 1);
    }

    /// `queue_len` reflects the number of enqueued messages.
    #[tokio::test]
    async fn queue_len_tracking() {
        let (broker, _container) = setup_broker().await;

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        for payload in [1, 2] {
            broker
                .enqueue(TaskMessage::new("t", "default", serde_json::json!(payload)))
                .await
                .unwrap();
        }

        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }

    /// `dlq_len` reports dead-lettered messages through the trait method.
    #[tokio::test]
    async fn dlq_len_via_trait() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.dead_letter(dequeued).await.unwrap();
        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    /// `list_queues` includes every queue that has received a message.
    #[tokio::test]
    async fn list_queues_returns_known() {
        let (broker, _container) = setup_broker().await;

        broker
            .enqueue(TaskMessage::new("t", "emails", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(2)))
            .await
            .unwrap();

        let mut queues = broker.list_queues().await.unwrap();
        queues.sort();
        assert!(queues.contains(&"emails".to_string()));
        assert!(queues.contains(&"default".to_string()));
    }

    /// `dlq_messages` returns the dead-lettered messages themselves.
    #[tokio::test]
    async fn dlq_messages_returns_content() {
        let (broker, _container) = setup_broker().await;
        let queues = vec!["default".to_string()];

        // Push each task through the full enqueue -> dequeue -> dead-letter path.
        for name in ["task_a", "task_b"] {
            let msg = TaskMessage::new(name, "default", serde_json::json!({}));
            broker.enqueue(msg).await.unwrap();
            let dequeued = broker
                .dequeue(&queues, Duration::from_secs(1))
                .await
                .unwrap()
                .unwrap();
            broker.dead_letter(dequeued).await.unwrap();
        }

        let messages = broker.dlq_messages("default", 0, 10).await.unwrap();
        assert_eq!(messages.len(), 2);
        let names: Vec<&str> = messages.iter().map(|m| m.task_name.as_str()).collect();
        assert!(names.contains(&"task_a"));
        assert!(names.contains(&"task_b"));
    }

    /// A task scheduled in the past is moved onto its queue by `poll_scheduled`.
    #[tokio::test]
    async fn schedule_and_poll() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("scheduled_task", "default", serde_json::json!({}));
        let past = chrono::Utc::now() - chrono::Duration::seconds(10);
        broker.schedule(msg, past).await.unwrap();

        let count = broker.poll_scheduled().await.unwrap();
        assert_eq!(count, 1);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
    }
}