Skip to main content

kojin_redis/
broker.rs

1use async_trait::async_trait;
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use kojin_core::broker::Broker;
7use kojin_core::error::{KojinError, TaskResult};
8use kojin_core::message::TaskMessage;
9use kojin_core::task_id::TaskId;
10
11use crate::config::RedisConfig;
12use crate::keys::KeyBuilder;
13
14fn broker_err(e: impl std::fmt::Display) -> KojinError {
15    KojinError::Broker(e.to_string())
16}
17
18/// Redis-backed message broker.
19#[derive(Clone)]
20pub struct RedisBroker {
21    pool: Pool,
22    keys: KeyBuilder,
23    worker_id: String,
24    dedup_ttl: u64,
25}
26
27impl RedisBroker {
28    /// Create a new Redis broker from config.
29    pub async fn new(config: RedisConfig) -> TaskResult<Self> {
30        let cfg = Config::from_url(&config.url);
31        let pool = cfg
32            .builder()
33            .map_err(broker_err)?
34            .max_size(config.pool_size)
35            .runtime(Runtime::Tokio1)
36            .build()
37            .map_err(broker_err)?;
38
39        // Verify connection
40        let _conn = pool.get().await.map_err(broker_err)?;
41
42        let worker_id = uuid::Uuid::now_v7().to_string();
43
44        Ok(Self {
45            pool,
46            keys: KeyBuilder::new(config.key_prefix),
47            worker_id,
48            dedup_ttl: config.dedup_ttl,
49        })
50    }
51
52    /// Get a connection from the pool.
53    async fn conn(&self) -> TaskResult<deadpool_redis::Connection> {
54        self.pool.get().await.map_err(broker_err)
55    }
56
57    /// Poll the scheduled set and move due items to their queues.
58    pub async fn poll_scheduled(&self) -> TaskResult<usize> {
59        let mut conn = self.conn().await?;
60        let now = chrono::Utc::now().timestamp();
61        let prefix = self.keys.scheduled().replace(":scheduled", "");
62
63        let script = redis::Script::new(crate::scripts::POLL_SCHEDULED_SCRIPT);
64        let count: usize = script
65            .key(self.keys.scheduled())
66            .arg(now)
67            .arg(prefix)
68            .invoke_async(&mut *conn)
69            .await
70            .map_err(broker_err)?;
71
72        Ok(count)
73    }
74}
75
76#[async_trait]
77impl Broker for RedisBroker {
78    async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
79        let mut conn = self.conn().await?;
80
81        // Deduplication check via SET NX with TTL
82        if let Some(ref dedup_key) = message.dedup_key {
83            let redis_key = self.keys.dedup(dedup_key);
84            let set: bool = redis::cmd("SET")
85                .arg(&redis_key)
86                .arg(1)
87                .arg("NX")
88                .arg("EX")
89                .arg(self.dedup_ttl)
90                .query_async(&mut *conn)
91                .await
92                .unwrap_or(false);
93
94            if !set {
95                tracing::debug!(dedup_key = %dedup_key, "duplicate task filtered by Redis SET NX");
96                return Ok(());
97            }
98        }
99
100        let queue_key = self.keys.queue(&message.queue);
101        let serialized = serde_json::to_string(&message)?;
102
103        conn.lpush::<_, _, ()>(&queue_key, &serialized)
104            .await
105            .map_err(broker_err)?;
106
107        Ok(())
108    }
109
110    async fn dequeue(
111        &self,
112        queues: &[String],
113        timeout: Duration,
114    ) -> TaskResult<Option<TaskMessage>> {
115        let mut conn = self.conn().await?;
116        let processing_key = self.keys.processing(&self.worker_id);
117
118        // Try BLMOVE from each queue in order
119        for queue_name in queues {
120            let queue_key = self.keys.queue(queue_name);
121            let timeout_secs = timeout.as_secs_f64();
122
123            let result: Option<String> = redis::cmd("BLMOVE")
124                .arg(&queue_key)
125                .arg(&processing_key)
126                .arg("RIGHT")
127                .arg("LEFT")
128                .arg(timeout_secs)
129                .query_async(&mut *conn)
130                .await
131                .map_err(broker_err)?;
132
133            if let Some(data) = result {
134                let message: TaskMessage = serde_json::from_str(&data)?;
135                return Ok(Some(message));
136            }
137        }
138
139        Ok(None)
140    }
141
142    async fn ack(&self, id: &TaskId) -> TaskResult<()> {
143        let mut conn = self.conn().await?;
144        let processing_key = self.keys.processing(&self.worker_id);
145        let id_str = id.to_string();
146
147        // Remove from processing list by scanning for the task ID in serialized messages
148        let items: Vec<String> = conn
149            .lrange(&processing_key, 0, -1)
150            .await
151            .map_err(broker_err)?;
152
153        for item in items {
154            if item.contains(&id_str) {
155                conn.lrem::<_, _, ()>(&processing_key, 1, &item)
156                    .await
157                    .map_err(broker_err)?;
158                break;
159            }
160        }
161
162        Ok(())
163    }
164
165    async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
166        self.ack(&message.id).await?;
167        self.enqueue(message).await
168    }
169
170    async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
171        let mut conn = self.conn().await?;
172        self.ack(&message.id).await?;
173
174        let dlq_key = self.keys.dlq(&message.queue);
175        let serialized = serde_json::to_string(&message)?;
176        conn.lpush::<_, _, ()>(&dlq_key, &serialized)
177            .await
178            .map_err(broker_err)?;
179
180        Ok(())
181    }
182
183    async fn schedule(
184        &self,
185        message: TaskMessage,
186        eta: chrono::DateTime<chrono::Utc>,
187    ) -> TaskResult<()> {
188        let mut conn = self.conn().await?;
189        let scheduled_key = self.keys.scheduled();
190        let serialized = serde_json::to_string(&message)?;
191        let score = eta.timestamp() as f64;
192
193        conn.zadd::<_, _, _, ()>(&scheduled_key, &serialized, score)
194            .await
195            .map_err(broker_err)?;
196
197        Ok(())
198    }
199
200    async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
201        let mut conn = self.conn().await?;
202        let queue_key = self.keys.queue(queue);
203        let len: usize = conn.llen(&queue_key).await.map_err(broker_err)?;
204        Ok(len)
205    }
206
207    async fn dlq_len(&self, queue: &str) -> TaskResult<usize> {
208        let mut conn = self.conn().await?;
209        let dlq_key = self.keys.dlq(queue);
210        let len: usize = conn.llen(&dlq_key).await.map_err(broker_err)?;
211        Ok(len)
212    }
213
214    async fn list_queues(&self) -> TaskResult<Vec<String>> {
215        let mut conn = self.conn().await?;
216        let pattern = format!("{}:queue:*", self.keys.prefix());
217        let keys: Vec<String> = redis::cmd("KEYS")
218            .arg(&pattern)
219            .query_async(&mut *conn)
220            .await
221            .map_err(broker_err)?;
222        let prefix = format!("{}:queue:", self.keys.prefix());
223        Ok(keys
224            .into_iter()
225            .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
226            .collect())
227    }
228
229    async fn dlq_messages(
230        &self,
231        queue: &str,
232        offset: usize,
233        limit: usize,
234    ) -> TaskResult<Vec<TaskMessage>> {
235        let mut conn = self.conn().await?;
236        let dlq_key = self.keys.dlq(queue);
237        let end = offset + limit - 1;
238        let items: Vec<String> = conn
239            .lrange(&dlq_key, offset as isize, end as isize)
240            .await
241            .map_err(broker_err)?;
242        items
243            .into_iter()
244            .map(|s| serde_json::from_str(&s).map_err(KojinError::from))
245            .collect()
246    }
247}
248
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::redis::Redis;

    /// Start a Redis 7 container and connect a broker (prefix "test") to it.
    ///
    /// The container handle is returned so the caller can keep it alive for the
    /// whole test.
    async fn setup_broker() -> (RedisBroker, testcontainers::ContainerAsync<Redis>) {
        let container = Redis::default().with_tag("7").start().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();
        let config = RedisConfig::new(format!("redis://127.0.0.1:{port}")).with_prefix("test");
        let broker = RedisBroker::new(config).await.unwrap();
        (broker, container)
    }

    /// Length of the broker's own processing (in-flight) list.
    async fn processing_len(broker: &RedisBroker) -> usize {
        let mut conn = broker.conn().await.unwrap();
        conn.llen(broker.keys.processing(&broker.worker_id))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn enqueue_dequeue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({"key": "value"}));
        broker.enqueue(msg.clone()).await.unwrap();

        let queues = vec!["default".to_string()];
        let result = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap();
        assert!(result.is_some());
        let dequeued = result.unwrap();
        assert_eq!(dequeued.task_name, "test_task");
    }

    #[tokio::test]
    async fn ack_and_nack() {
        let (broker, _container) = setup_broker().await;
        let queues = vec!["default".to_string()];

        // ack: the in-flight copy leaves the processing list; nothing is re-queued.
        broker
            .enqueue(TaskMessage::new("acked", "default", serde_json::json!({})))
            .await
            .unwrap();
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(processing_len(&broker).await, 1);

        broker.ack(&dequeued.id).await.unwrap();
        assert_eq!(processing_len(&broker).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        // nack: the task leaves the processing list and is redelivered.
        broker
            .enqueue(TaskMessage::new("nacked", "default", serde_json::json!({})))
            .await
            .unwrap();
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.nack(dequeued).await.unwrap();
        assert_eq!(processing_len(&broker).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);

        let redelivered = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(redelivered.task_name, "nacked");
    }

    #[tokio::test]
    async fn dequeue_times_out_on_empty_queue() {
        let (broker, _container) = setup_broker().await;
        let queues = vec!["default".to_string()];

        let result = broker
            .dequeue(&queues, Duration::from_millis(100))
            .await
            .unwrap();
        assert!(result.is_none());
        assert_eq!(processing_len(&broker).await, 0);
    }

    #[tokio::test]
    async fn dequeue_falls_through_to_later_queue() {
        let (broker, _container) = setup_broker().await;

        broker
            .enqueue(TaskMessage::new("low_task", "low", serde_json::json!({})))
            .await
            .unwrap();

        let queues = vec!["high".to_string(), "low".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(dequeued.task_name, "low_task");
        assert_eq!(broker.queue_len("low").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn dead_letter_queue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.dead_letter(dequeued).await.unwrap();

        let mut conn = broker.conn().await.unwrap();
        let dlq_len: usize = conn.llen(broker.keys.dlq("default")).await.unwrap();
        assert_eq!(dlq_len, 1);
        assert_eq!(processing_len(&broker).await, 0);
    }

    #[tokio::test]
    async fn queue_len_tracking() {
        let (broker, _container) = setup_broker().await;

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(2)))
            .await
            .unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn dlq_len_via_trait() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.dead_letter(dequeued).await.unwrap();
        assert_eq!(broker.dlq_len("default").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn list_queues_returns_known() {
        let (broker, _container) = setup_broker().await;

        broker
            .enqueue(TaskMessage::new("t", "emails", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(2)))
            .await
            .unwrap();

        let queues = broker.list_queues().await.unwrap();
        assert!(queues.contains(&"emails".to_string()));
        assert!(queues.contains(&"default".to_string()));
    }

    #[tokio::test]
    async fn dlq_messages_returns_content() {
        let (broker, _container) = setup_broker().await;
        let queues = vec!["default".to_string()];

        // Dead-letter 2 messages
        for name in ["task_a", "task_b"] {
            let msg = TaskMessage::new(name, "default", serde_json::json!({}));
            broker.enqueue(msg).await.unwrap();
            let dequeued = broker
                .dequeue(&queues, Duration::from_secs(1))
                .await
                .unwrap()
                .unwrap();
            broker.dead_letter(dequeued).await.unwrap();
        }

        let messages = broker.dlq_messages("default", 0, 10).await.unwrap();
        assert_eq!(messages.len(), 2);
        let names: Vec<&str> = messages.iter().map(|m| m.task_name.as_str()).collect();
        assert!(names.contains(&"task_a"));
        assert!(names.contains(&"task_b"));

        // Paging: offset/limit select an inclusive window of the DLQ list.
        assert_eq!(broker.dlq_messages("default", 0, 1).await.unwrap().len(), 1);
        assert_eq!(broker.dlq_messages("default", 1, 10).await.unwrap().len(), 1);
        assert_eq!(broker.dlq_messages("default", 2, 10).await.unwrap().len(), 0);
    }

    #[tokio::test]
    async fn schedule_and_poll() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("scheduled_task", "default", serde_json::json!({}));
        let past = chrono::Utc::now() - chrono::Duration::seconds(10);
        broker.schedule(msg, past).await.unwrap();

        let count = broker.poll_scheduled().await.unwrap();
        assert_eq!(count, 1);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
    }
}