//! Redis-backed implementation of the kojin `Broker` trait (`kojin_redis/broker.rs`).

1use async_trait::async_trait;
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use kojin_core::broker::Broker;
7use kojin_core::error::{KojinError, TaskResult};
8use kojin_core::message::TaskMessage;
9use kojin_core::task_id::TaskId;
10
11use crate::config::RedisConfig;
12use crate::keys::KeyBuilder;
13
14fn broker_err(e: impl std::fmt::Display) -> KojinError {
15    KojinError::Broker(e.to_string())
16}
17
18/// Redis-backed message broker.
19#[derive(Clone)]
20pub struct RedisBroker {
21    pool: Pool,
22    keys: KeyBuilder,
23    worker_id: String,
24}
25
26impl RedisBroker {
27    /// Create a new Redis broker from config.
28    pub async fn new(config: RedisConfig) -> TaskResult<Self> {
29        let cfg = Config::from_url(&config.url);
30        let pool = cfg
31            .builder()
32            .map_err(broker_err)?
33            .max_size(config.pool_size)
34            .runtime(Runtime::Tokio1)
35            .build()
36            .map_err(broker_err)?;
37
38        // Verify connection
39        let _conn = pool.get().await.map_err(broker_err)?;
40
41        let worker_id = uuid::Uuid::now_v7().to_string();
42
43        Ok(Self {
44            pool,
45            keys: KeyBuilder::new(config.key_prefix),
46            worker_id,
47        })
48    }
49
50    /// Get a connection from the pool.
51    async fn conn(&self) -> TaskResult<deadpool_redis::Connection> {
52        self.pool.get().await.map_err(broker_err)
53    }
54
55    /// Poll the scheduled set and move due items to their queues.
56    pub async fn poll_scheduled(&self) -> TaskResult<usize> {
57        let mut conn = self.conn().await?;
58        let now = chrono::Utc::now().timestamp();
59        let prefix = self.keys.scheduled().replace(":scheduled", "");
60
61        let script = redis::Script::new(crate::scripts::POLL_SCHEDULED_SCRIPT);
62        let count: usize = script
63            .key(self.keys.scheduled())
64            .arg(now)
65            .arg(prefix)
66            .invoke_async(&mut *conn)
67            .await
68            .map_err(broker_err)?;
69
70        Ok(count)
71    }
72}
73
74#[async_trait]
75impl Broker for RedisBroker {
76    async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
77        let mut conn = self.conn().await?;
78        let queue_key = self.keys.queue(&message.queue);
79        let serialized = serde_json::to_string(&message)?;
80
81        conn.lpush::<_, _, ()>(&queue_key, &serialized)
82            .await
83            .map_err(broker_err)?;
84
85        Ok(())
86    }
87
88    async fn dequeue(
89        &self,
90        queues: &[String],
91        timeout: Duration,
92    ) -> TaskResult<Option<TaskMessage>> {
93        let mut conn = self.conn().await?;
94        let processing_key = self.keys.processing(&self.worker_id);
95
96        // Try BLMOVE from each queue in order
97        for queue_name in queues {
98            let queue_key = self.keys.queue(queue_name);
99            let timeout_secs = timeout.as_secs_f64();
100
101            let result: Option<String> = redis::cmd("BLMOVE")
102                .arg(&queue_key)
103                .arg(&processing_key)
104                .arg("RIGHT")
105                .arg("LEFT")
106                .arg(timeout_secs)
107                .query_async(&mut *conn)
108                .await
109                .map_err(broker_err)?;
110
111            if let Some(data) = result {
112                let message: TaskMessage = serde_json::from_str(&data)?;
113                return Ok(Some(message));
114            }
115        }
116
117        Ok(None)
118    }
119
120    async fn ack(&self, id: &TaskId) -> TaskResult<()> {
121        let mut conn = self.conn().await?;
122        let processing_key = self.keys.processing(&self.worker_id);
123        let id_str = id.to_string();
124
125        // Remove from processing list by scanning for the task ID in serialized messages
126        let items: Vec<String> = conn
127            .lrange(&processing_key, 0, -1)
128            .await
129            .map_err(broker_err)?;
130
131        for item in items {
132            if item.contains(&id_str) {
133                conn.lrem::<_, _, ()>(&processing_key, 1, &item)
134                    .await
135                    .map_err(broker_err)?;
136                break;
137            }
138        }
139
140        Ok(())
141    }
142
143    async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
144        self.ack(&message.id).await?;
145        self.enqueue(message).await
146    }
147
148    async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
149        let mut conn = self.conn().await?;
150        self.ack(&message.id).await?;
151
152        let dlq_key = self.keys.dlq(&message.queue);
153        let serialized = serde_json::to_string(&message)?;
154        conn.lpush::<_, _, ()>(&dlq_key, &serialized)
155            .await
156            .map_err(broker_err)?;
157
158        Ok(())
159    }
160
161    async fn schedule(
162        &self,
163        message: TaskMessage,
164        eta: chrono::DateTime<chrono::Utc>,
165    ) -> TaskResult<()> {
166        let mut conn = self.conn().await?;
167        let scheduled_key = self.keys.scheduled();
168        let serialized = serde_json::to_string(&message)?;
169        let score = eta.timestamp() as f64;
170
171        conn.zadd::<_, _, _, ()>(&scheduled_key, &serialized, score)
172            .await
173            .map_err(broker_err)?;
174
175        Ok(())
176    }
177
178    async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
179        let mut conn = self.conn().await?;
180        let queue_key = self.keys.queue(queue);
181        let len: usize = conn.llen(&queue_key).await.map_err(broker_err)?;
182        Ok(len)
183    }
184}
185
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::redis::Redis;

    /// Start a throwaway Redis 7 container and connect a broker to it.
    ///
    /// The container handle is returned so the caller keeps it alive for the test.
    async fn setup_broker() -> (RedisBroker, testcontainers::ContainerAsync<Redis>) {
        let container = Redis::default().with_tag("7").start().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();
        let config = RedisConfig::new(format!("redis://127.0.0.1:{port}")).with_prefix("test");
        let broker = RedisBroker::new(config).await.unwrap();
        (broker, container)
    }

    /// Length of this broker's processing (in-flight) list.
    async fn processing_len(broker: &RedisBroker) -> usize {
        let mut conn = broker.conn().await.unwrap();
        conn.llen(broker.keys.processing(&broker.worker_id))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn enqueue_dequeue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({"key": "value"}));
        broker.enqueue(msg.clone()).await.unwrap();

        let queues = vec!["default".to_string()];
        let result = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap();
        assert!(result.is_some());
        let dequeued = result.unwrap();
        assert_eq!(dequeued.task_name, "test_task");
        assert_eq!(processing_len(&broker).await, 1);
    }

    #[tokio::test]
    async fn ack_removes_from_processing() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(processing_len(&broker).await, 1);

        broker.ack(&dequeued.id).await.unwrap();
        assert_eq!(processing_len(&broker).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn nack_requeues_message() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        let id = dequeued.id.to_string();

        broker.nack(dequeued).await.unwrap();
        assert_eq!(processing_len(&broker).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);

        let redelivered = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(redelivered.id.to_string(), id);
    }

    #[tokio::test]
    async fn dead_letter_queue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.dead_letter(dequeued).await.unwrap();

        let mut conn = broker.conn().await.unwrap();
        let dlq_len: usize = conn.llen(broker.keys.dlq("default")).await.unwrap();
        assert_eq!(dlq_len, 1);
        assert_eq!(processing_len(&broker).await, 0);
    }

    #[tokio::test]
    async fn queue_len_tracking() {
        let (broker, _container) = setup_broker().await;

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(1)))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("t", "default", serde_json::json!(2)))
            .await
            .unwrap();

        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn dequeue_respects_queue_order() {
        let (broker, _container) = setup_broker().await;

        broker
            .enqueue(TaskMessage::new("low_task", "low", serde_json::json!({})))
            .await
            .unwrap();
        broker
            .enqueue(TaskMessage::new("high_task", "high", serde_json::json!({})))
            .await
            .unwrap();

        let queues = vec!["high".to_string(), "low".to_string()];
        let first = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(first.task_name, "high_task");
    }

    #[tokio::test]
    async fn dequeue_falls_through_to_later_queue() {
        let (broker, _container) = setup_broker().await;

        broker
            .enqueue(TaskMessage::new("low_task", "low", serde_json::json!({})))
            .await
            .unwrap();

        let queues = vec!["high".to_string(), "low".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(dequeued.task_name, "low_task");
        assert_eq!(processing_len(&broker).await, 1);
    }

    #[tokio::test]
    async fn dequeue_empty_queues_times_out() {
        let (broker, _container) = setup_broker().await;

        let queues = vec!["a".to_string(), "b".to_string()];
        let result = broker
            .dequeue(&queues, Duration::from_millis(100))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn schedule_and_poll() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("scheduled_task", "default", serde_json::json!({}));
        let past = chrono::Utc::now() - chrono::Duration::seconds(10);
        broker.schedule(msg, past).await.unwrap();

        let count = broker.poll_scheduled().await.unwrap();
        assert_eq!(count, 1);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
    }
}