1use std::sync::Arc;
2use std::time::Duration;
3
4use deadpool_redis::Pool;
5use tokio::sync::broadcast;
6
7use crate::config::StreamConfig;
8use crate::error::{StreamError, StreamResult};
9use crate::handler::StreamHandler;
10use crate::parse::{
11 extract_field, parse_claimed_messages, parse_pending_entries, parse_stream_response,
12};
13
14pub struct StreamConsumer<H: StreamHandler> {
22 pool: Pool,
23 config: StreamConfig,
24 handler: Arc<H>,
25}
26
27impl<H: StreamHandler> StreamConsumer<H> {
28 pub fn new(pool: Pool, config: StreamConfig, handler: H) -> Self {
30 Self {
31 pool,
32 config,
33 handler: Arc::new(handler),
34 }
35 }
36
37 pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) -> StreamResult<()> {
45 self.ensure_consumer_group().await?;
46
47 let mut reclaim_interval = tokio::time::interval(self.config.reclaim_interval);
48 reclaim_interval.tick().await; loop {
51 let messages = loop {
53 tokio::select! {
54 _ = shutdown_rx.recv() => {
55 return Ok(());
56 }
57 _ = reclaim_interval.tick() => {
58 drop(self.reclaim_pending_messages().await);
59 }
60 result = self.read_from_stream() => {
61 match result {
62 Ok(msgs) if msgs.is_empty() => continue,
63 Ok(msgs) => break msgs,
64 Err(_) => {
65 tokio::time::sleep(Duration::from_secs(1)).await;
66 continue;
67 }
68 }
69 }
70 }
71 };
72
73 self.process_and_ack(&messages).await;
75 }
76 }
77
78 async fn ensure_consumer_group(&self) -> StreamResult<()> {
80 let mut conn = self.pool.get().await?;
81 let result: Result<String, redis::RedisError> = redis::cmd("XGROUP")
82 .arg("CREATE")
83 .arg(&self.config.stream_key)
84 .arg(&self.config.consumer_group)
85 .arg(&self.config.group_start_id)
86 .arg("MKSTREAM")
87 .query_async(&mut *conn)
88 .await;
89
90 match result {
91 Ok(_) => {}
92 Err(e) if e.to_string().contains("BUSYGROUP") => {}
93 Err(e) => return Err(StreamError::Redis(e)),
94 }
95 Ok(())
96 }
97
98 async fn read_from_stream(&self) -> StreamResult<Vec<(String, Vec<(String, String)>)>> {
100 let mut conn = self.pool.get().await?;
101 let result: redis::Value = redis::cmd("XREADGROUP")
102 .arg("GROUP")
103 .arg(&self.config.consumer_group)
104 .arg(&self.config.consumer_name)
105 .arg("COUNT")
106 .arg(self.config.batch_size)
107 .arg("BLOCK")
108 .arg(self.config.block_ms)
109 .arg("STREAMS")
110 .arg(&self.config.stream_key)
111 .arg(">")
112 .query_async(&mut *conn)
113 .await?;
114
115 Ok(parse_stream_response(&result))
116 }
117
118 async fn process_and_ack(&self, messages: &[(String, Vec<(String, String)>)]) {
120 for (msg_id, fields) in messages {
121 let data = match extract_field(fields, &self.config.data_field) {
122 Some(d) => d,
123 None => {
124 drop(self.ack_message(msg_id).await);
125 continue;
126 }
127 };
128
129 match self.handler.handle_message(msg_id, data).await {
130 Ok(()) => {
131 drop(self.ack_message(msg_id).await);
132 }
133 Err(_) => {
134 }
136 }
137 }
138 }
139
140 async fn ack_message(&self, msg_id: &str) -> StreamResult<()> {
142 let mut conn = self.pool.get().await?;
143 let _: i64 = redis::cmd("XACK")
144 .arg(&self.config.stream_key)
145 .arg(&self.config.consumer_group)
146 .arg(msg_id)
147 .query_async(&mut *conn)
148 .await?;
149 Ok(())
150 }
151
152 async fn reclaim_pending_messages(&self) -> StreamResult<()> {
159 let mut conn = self.pool.get().await?;
160
161 let pending: redis::Value = redis::cmd("XPENDING")
163 .arg(&self.config.stream_key)
164 .arg(&self.config.consumer_group)
165 .arg("-")
166 .arg("+")
167 .arg(self.config.batch_size)
168 .query_async(&mut *conn)
169 .await?;
170
171 let entries = parse_pending_entries(&pending);
172 if entries.is_empty() {
173 return Ok(());
174 }
175
176 for (msg_id, _consumer, idle_ms, delivery_count) in &entries {
177 if *idle_ms < self.config.min_idle_ms {
178 continue;
179 }
180
181 if *delivery_count > self.config.max_retries {
182 let data = self.read_message_data(&mut conn, msg_id).await;
184 self.handler
185 .on_dead_letter(msg_id, data.as_deref().unwrap_or(""))
186 .await;
187 drop(self.move_to_dead_letter_with_conn(&mut conn, msg_id).await);
188 continue;
189 }
190
191 let claimed: redis::Value = redis::cmd("XCLAIM")
193 .arg(&self.config.stream_key)
194 .arg(&self.config.consumer_group)
195 .arg(&self.config.consumer_name)
196 .arg(self.config.min_idle_ms)
197 .arg(msg_id)
198 .query_async(&mut *conn)
199 .await?;
200
201 let claimed_messages = parse_claimed_messages(&claimed);
202 for (claimed_id, fields) in &claimed_messages {
203 let data = match extract_field(fields, &self.config.data_field) {
204 Some(d) => d,
205 None => {
206 self.ack_with_conn(&mut conn, claimed_id).await?;
207 continue;
208 }
209 };
210
211 match self.handler.handle_message(claimed_id, data).await {
212 Ok(()) => {
213 drop(self.ack_with_conn(&mut conn, claimed_id).await);
214 }
215 Err(_) => {
216 }
218 }
219 }
220 }
221
222 Ok(())
223 }
224
225 async fn ack_with_conn(
227 &self,
228 conn: &mut deadpool_redis::Connection,
229 msg_id: &str,
230 ) -> StreamResult<()> {
231 let _: i64 = redis::cmd("XACK")
232 .arg(&self.config.stream_key)
233 .arg(&self.config.consumer_group)
234 .arg(msg_id)
235 .query_async(&mut *conn)
236 .await?;
237 Ok(())
238 }
239
240 async fn read_message_data(
242 &self,
243 conn: &mut deadpool_redis::Connection,
244 msg_id: &str,
245 ) -> Option<String> {
246 let value: redis::Value = redis::cmd("XRANGE")
247 .arg(&self.config.stream_key)
248 .arg(msg_id)
249 .arg(msg_id)
250 .query_async(&mut *conn)
251 .await
252 .ok()?;
253
254 let messages = parse_claimed_messages(&value);
255 let (_id, fields) = messages.first()?;
256 extract_field(fields, "data").map(|s| s.to_owned())
257 }
258
259 async fn move_to_dead_letter_with_conn(
264 &self,
265 conn: &mut deadpool_redis::Connection,
266 msg_id: &str,
267 ) -> StreamResult<()> {
268 let original: redis::Value = redis::cmd("XRANGE")
270 .arg(&self.config.stream_key)
271 .arg(msg_id)
272 .arg(msg_id)
273 .query_async(&mut *conn)
274 .await?;
275
276 let messages = parse_claimed_messages(&original);
278 if let Some((_id, fields)) = messages.first() {
279 let mut cmd = redis::cmd("XADD");
281 cmd.arg(&self.config.dead_letter_key)
282 .arg("MAXLEN")
283 .arg("~")
284 .arg(10000_i64)
285 .arg("*");
286
287 cmd.arg("original_id").arg(msg_id);
289
290 for (key, value) in fields {
291 cmd.arg(key).arg(value);
292 }
293
294 let _dead_letter_id: String = cmd.query_async(&mut *conn).await?;
295 }
296
297 self.ack_with_conn(conn, msg_id).await?;
299 Ok(())
300 }
301}
302
#[cfg(test)]
mod tests {
    use crate::config::StreamConfig;

    /// `StreamConfig::new` applies the expected defaults for batch size, blocking,
    /// retries, idle threshold and group start id.
    #[test]
    fn test_consumer_config_defaults() {
        let (stream_key, dead_letter_key, group, consumer) = (
            "test:stream",
            "test:stream:dead_letter",
            "test-group",
            "worker-1",
        );
        let config = StreamConfig::new(stream_key, dead_letter_key, group, consumer);

        assert_eq!(config.batch_size, 10);
        assert_eq!(config.block_ms, 5000);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.min_idle_ms, 60_000);
        assert_eq!(config.group_start_id, "$");
    }

    /// The `with_*` builder methods override the corresponding settings.
    #[test]
    fn test_consumer_config_builder() {
        let base = StreamConfig::new(
            "settlement:deposits",
            "settlement:deposits:dead_letter",
            "settlement-service",
            "worker-1",
        );
        let config = base
            .with_min_idle_ms(180_000)
            .with_max_retries(10)
            .with_group_start_id("0")
            .with_batch_size(20)
            .with_block_ms(3000);

        assert_eq!(config.min_idle_ms, 180_000);
        assert_eq!(config.max_retries, 10);
        assert_eq!(config.group_start_id, "0");
        assert_eq!(config.batch_size, 20);
        assert_eq!(config.block_ms, 3000);
    }
}