1use async_trait::async_trait;
2use deadpool_redis::{Config, Pool, Runtime};
3use redis::AsyncCommands;
4use std::time::Duration;
5
6use kojin_core::broker::Broker;
7use kojin_core::error::{KojinError, TaskResult};
8use kojin_core::message::TaskMessage;
9use kojin_core::task_id::TaskId;
10
11use crate::config::RedisConfig;
12use crate::keys::KeyBuilder;
13
14fn broker_err(e: impl std::fmt::Display) -> KojinError {
15 KojinError::Broker(e.to_string())
16}
17
/// Redis-backed implementation of the [`Broker`] trait.
///
/// `Clone` is derived, so a clone carries the same `worker_id` as the value it
/// was cloned from and therefore addresses the same processing list.
#[derive(Clone)]
pub struct RedisBroker {
    /// `deadpool_redis` pool from which every Redis command checks out its connection.
    pool: Pool,
    /// Builds the Redis key names used here: queues, processing lists,
    /// dead-letter queues and the scheduled set.
    keys: KeyBuilder,
    /// UUID v7 string generated in `new`; used to name this broker's processing list.
    worker_id: String,
}
25
26impl RedisBroker {
27 pub async fn new(config: RedisConfig) -> TaskResult<Self> {
29 let cfg = Config::from_url(&config.url);
30 let pool = cfg
31 .builder()
32 .map_err(broker_err)?
33 .max_size(config.pool_size)
34 .runtime(Runtime::Tokio1)
35 .build()
36 .map_err(broker_err)?;
37
38 let _conn = pool.get().await.map_err(broker_err)?;
40
41 let worker_id = uuid::Uuid::now_v7().to_string();
42
43 Ok(Self {
44 pool,
45 keys: KeyBuilder::new(config.key_prefix),
46 worker_id,
47 })
48 }
49
50 async fn conn(&self) -> TaskResult<deadpool_redis::Connection> {
52 self.pool.get().await.map_err(broker_err)
53 }
54
55 pub async fn poll_scheduled(&self) -> TaskResult<usize> {
57 let mut conn = self.conn().await?;
58 let now = chrono::Utc::now().timestamp();
59 let prefix = self.keys.scheduled().replace(":scheduled", "");
60
61 let script = redis::Script::new(crate::scripts::POLL_SCHEDULED_SCRIPT);
62 let count: usize = script
63 .key(self.keys.scheduled())
64 .arg(now)
65 .arg(prefix)
66 .invoke_async(&mut *conn)
67 .await
68 .map_err(broker_err)?;
69
70 Ok(count)
71 }
72}
73
74#[async_trait]
75impl Broker for RedisBroker {
76 async fn enqueue(&self, message: TaskMessage) -> TaskResult<()> {
77 let mut conn = self.conn().await?;
78 let queue_key = self.keys.queue(&message.queue);
79 let serialized = serde_json::to_string(&message)?;
80
81 conn.lpush::<_, _, ()>(&queue_key, &serialized)
82 .await
83 .map_err(broker_err)?;
84
85 Ok(())
86 }
87
88 async fn dequeue(
89 &self,
90 queues: &[String],
91 timeout: Duration,
92 ) -> TaskResult<Option<TaskMessage>> {
93 let mut conn = self.conn().await?;
94 let processing_key = self.keys.processing(&self.worker_id);
95
96 for queue_name in queues {
98 let queue_key = self.keys.queue(queue_name);
99 let timeout_secs = timeout.as_secs_f64();
100
101 let result: Option<String> = redis::cmd("BLMOVE")
102 .arg(&queue_key)
103 .arg(&processing_key)
104 .arg("RIGHT")
105 .arg("LEFT")
106 .arg(timeout_secs)
107 .query_async(&mut *conn)
108 .await
109 .map_err(broker_err)?;
110
111 if let Some(data) = result {
112 let message: TaskMessage = serde_json::from_str(&data)?;
113 return Ok(Some(message));
114 }
115 }
116
117 Ok(None)
118 }
119
120 async fn ack(&self, id: &TaskId) -> TaskResult<()> {
121 let mut conn = self.conn().await?;
122 let processing_key = self.keys.processing(&self.worker_id);
123 let id_str = id.to_string();
124
125 let items: Vec<String> = conn
127 .lrange(&processing_key, 0, -1)
128 .await
129 .map_err(broker_err)?;
130
131 for item in items {
132 if item.contains(&id_str) {
133 conn.lrem::<_, _, ()>(&processing_key, 1, &item)
134 .await
135 .map_err(broker_err)?;
136 break;
137 }
138 }
139
140 Ok(())
141 }
142
143 async fn nack(&self, message: TaskMessage) -> TaskResult<()> {
144 self.ack(&message.id).await?;
145 self.enqueue(message).await
146 }
147
148 async fn dead_letter(&self, message: TaskMessage) -> TaskResult<()> {
149 let mut conn = self.conn().await?;
150 self.ack(&message.id).await?;
151
152 let dlq_key = self.keys.dlq(&message.queue);
153 let serialized = serde_json::to_string(&message)?;
154 conn.lpush::<_, _, ()>(&dlq_key, &serialized)
155 .await
156 .map_err(broker_err)?;
157
158 Ok(())
159 }
160
161 async fn schedule(
162 &self,
163 message: TaskMessage,
164 eta: chrono::DateTime<chrono::Utc>,
165 ) -> TaskResult<()> {
166 let mut conn = self.conn().await?;
167 let scheduled_key = self.keys.scheduled();
168 let serialized = serde_json::to_string(&message)?;
169 let score = eta.timestamp() as f64;
170
171 conn.zadd::<_, _, _, ()>(&scheduled_key, &serialized, score)
172 .await
173 .map_err(broker_err)?;
174
175 Ok(())
176 }
177
178 async fn queue_len(&self, queue: &str) -> TaskResult<usize> {
179 let mut conn = self.conn().await?;
180 let queue_key = self.keys.queue(queue);
181 let len: usize = conn.llen(&queue_key).await.map_err(broker_err)?;
182 Ok(len)
183 }
184}
185
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use testcontainers::{ImageExt, runners::AsyncRunner};
    use testcontainers_modules::redis::Redis;

    /// Starts a Redis 7 container and returns a broker connected to it.
    /// The container handle must be kept alive for the duration of the test.
    async fn setup_broker() -> (RedisBroker, testcontainers::ContainerAsync<Redis>) {
        let container = Redis::default().with_tag("7").start().await.unwrap();
        let port = container.get_host_port_ipv4(6379).await.unwrap();
        let config = RedisConfig::new(format!("redis://127.0.0.1:{port}")).with_prefix("test");
        let broker = RedisBroker::new(config).await.unwrap();
        (broker, container)
    }

    /// Returns the length of the Redis list at `key`.
    ///
    /// The connection is scoped to this call so the test never holds a pool
    /// slot while invoking broker methods that need one themselves.
    async fn list_len(broker: &RedisBroker, key: &str) -> usize {
        let mut conn = broker.conn().await.unwrap();
        let len: usize = conn.llen(key).await.unwrap();
        len
    }

    #[tokio::test]
    async fn enqueue_dequeue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({"key": "value"}));
        broker.enqueue(msg.clone()).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .expect("a message was enqueued, so dequeue must return it");
        assert_eq!(dequeued.task_name, "test_task");
    }

    #[tokio::test]
    async fn ack_and_nack() {
        let (broker, _container) = setup_broker().await;
        let queues = vec!["default".to_string()];
        let processing_key = broker.keys.processing(&broker.worker_id);

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        // Dequeued messages are tracked in the processing list until acked.
        assert_eq!(list_len(&broker, &processing_key).await, 1);
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        // nack: leaves processing and goes back onto the queue.
        broker.nack(dequeued).await.unwrap();
        assert_eq!(list_len(&broker, &processing_key).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);

        // ack: leaves processing and is not re-queued.
        let redelivered = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(redelivered.task_name, "test_task");
        assert_eq!(list_len(&broker, &processing_key).await, 1);

        broker.ack(&redelivered.id).await.unwrap();
        assert_eq!(list_len(&broker, &processing_key).await, 0);
        assert_eq!(broker.queue_len("default").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn dead_letter_queue() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("test_task", "default", serde_json::json!({}));
        broker.enqueue(msg).await.unwrap();

        let queues = vec!["default".to_string()];
        let dequeued = broker
            .dequeue(&queues, Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();

        broker.dead_letter(dequeued).await.unwrap();

        let dlq_key = broker.keys.dlq("default");
        let processing_key = broker.keys.processing(&broker.worker_id);
        assert_eq!(list_len(&broker, &dlq_key).await, 1);
        // The dead-lettered message must no longer be tracked as in flight.
        assert_eq!(list_len(&broker, &processing_key).await, 0);
    }

    #[tokio::test]
    async fn queue_len_tracking() {
        let (broker, _container) = setup_broker().await;

        assert_eq!(broker.queue_len("default").await.unwrap(), 0);

        for payload in [1, 2] {
            broker
                .enqueue(TaskMessage::new("t", "default", serde_json::json!(payload)))
                .await
                .unwrap();
        }

        assert_eq!(broker.queue_len("default").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn schedule_and_poll() {
        let (broker, _container) = setup_broker().await;

        let msg = TaskMessage::new("scheduled_task", "default", serde_json::json!({}));
        // An ETA in the past makes the message due on the very next poll.
        let past = chrono::Utc::now() - chrono::Duration::seconds(10);
        broker.schedule(msg, past).await.unwrap();

        let count = broker.poll_scheduled().await.unwrap();
        assert_eq!(count, 1);
        assert_eq!(broker.queue_len("default").await.unwrap(), 1);
    }
}