Skip to main content

pyra_streams/
consumer.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use deadpool_redis::Pool;
5use tokio::sync::broadcast;
6
7use crate::config::StreamConfig;
8use crate::error::{StreamError, StreamResult};
9use crate::handler::StreamHandler;
10use crate::parse::{
11    extract_field, parse_claimed_messages, parse_pending_entries, parse_stream_response,
12};
13
14/// A Redis Stream consumer that reads messages, dispatches them to a handler,
15/// and manages ACKs, retries (via XCLAIM), and dead-lettering.
16///
17/// The consumer uses XREADGROUP with consumer groups for reliable, at-least-once
18/// delivery. Messages that fail processing are left pending and automatically
19/// reclaimed after `min_idle_ms` via periodic XCLAIM. Messages exceeding
20/// `max_retries` are moved to a dead-letter stream.
21pub struct StreamConsumer<H: StreamHandler> {
22    pool: Pool,
23    config: StreamConfig,
24    handler: Arc<H>,
25}
26
27impl<H: StreamHandler> StreamConsumer<H> {
28    /// Create a new stream consumer.
29    pub fn new(pool: Pool, config: StreamConfig, handler: H) -> Self {
30        Self {
31            pool,
32            config,
33            handler: Arc::new(handler),
34        }
35    }
36
37    /// Run the consumer loop until a shutdown signal is received.
38    ///
39    /// This method:
40    /// 1. Ensures the consumer group exists (creates it if not).
41    /// 2. Reads messages via XREADGROUP in a cancellation-safe inner loop.
42    /// 3. Processes messages outside `select!` to prevent partial execution.
43    /// 4. Periodically reclaims idle pending messages via XCLAIM.
44    pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) -> StreamResult<()> {
45        self.ensure_consumer_group().await?;
46
47        let mut reclaim_interval = tokio::time::interval(self.config.reclaim_interval);
48        reclaim_interval.tick().await; // skip first immediate tick
49
50        loop {
51            // Inner loop: only cancellation-safe operations (XREADGROUP, timer tick).
52            let messages = loop {
53                tokio::select! {
54                    _ = shutdown_rx.recv() => {
55                        return Ok(());
56                    }
57                    _ = reclaim_interval.tick() => {
58                        drop(self.reclaim_pending_messages().await);
59                    }
60                    result = self.read_from_stream() => {
61                        match result {
62                            Ok(msgs) if msgs.is_empty() => continue,
63                            Ok(msgs) => break msgs,
64                            Err(_) => {
65                                tokio::time::sleep(Duration::from_secs(1)).await;
66                                continue;
67                            }
68                        }
69                    }
70                }
71            };
72
73            // Process OUTSIDE select! — cannot be cancelled mid-processing.
74            self.process_and_ack(&messages).await;
75        }
76    }
77
78    /// Create the consumer group if it doesn't already exist.
79    async fn ensure_consumer_group(&self) -> StreamResult<()> {
80        let mut conn = self.pool.get().await?;
81        let result: Result<String, redis::RedisError> = redis::cmd("XGROUP")
82            .arg("CREATE")
83            .arg(&self.config.stream_key)
84            .arg(&self.config.consumer_group)
85            .arg(&self.config.group_start_id)
86            .arg("MKSTREAM")
87            .query_async(&mut *conn)
88            .await;
89
90        match result {
91            Ok(_) => {}
92            Err(e) if e.to_string().contains("BUSYGROUP") => {}
93            Err(e) => return Err(StreamError::Redis(e)),
94        }
95        Ok(())
96    }
97
98    /// Read a batch of messages from the stream using XREADGROUP.
99    async fn read_from_stream(&self) -> StreamResult<Vec<(String, Vec<(String, String)>)>> {
100        let mut conn = self.pool.get().await?;
101        let result: redis::Value = redis::cmd("XREADGROUP")
102            .arg("GROUP")
103            .arg(&self.config.consumer_group)
104            .arg(&self.config.consumer_name)
105            .arg("COUNT")
106            .arg(self.config.batch_size)
107            .arg("BLOCK")
108            .arg(self.config.block_ms)
109            .arg("STREAMS")
110            .arg(&self.config.stream_key)
111            .arg(">")
112            .query_async(&mut *conn)
113            .await?;
114
115        Ok(parse_stream_response(&result))
116    }
117
118    /// Process a batch of messages and ACK successful ones.
119    async fn process_and_ack(&self, messages: &[(String, Vec<(String, String)>)]) {
120        for (msg_id, fields) in messages {
121            let data = match extract_field(fields, &self.config.data_field) {
122                Some(d) => d,
123                None => {
124                    drop(self.ack_message(msg_id).await);
125                    continue;
126                }
127            };
128
129            match self.handler.handle_message(msg_id, data).await {
130                Ok(()) => {
131                    drop(self.ack_message(msg_id).await);
132                }
133                Err(_) => {
134                    // Message stays pending and will be retried via XCLAIM
135                }
136            }
137        }
138    }
139
140    /// Acknowledge a message in the consumer group.
141    async fn ack_message(&self, msg_id: &str) -> StreamResult<()> {
142        let mut conn = self.pool.get().await?;
143        let _: i64 = redis::cmd("XACK")
144            .arg(&self.config.stream_key)
145            .arg(&self.config.consumer_group)
146            .arg(msg_id)
147            .query_async(&mut *conn)
148            .await?;
149        Ok(())
150    }
151
152    /// Reclaim idle pending messages via XPENDING + XCLAIM.
153    ///
154    /// Messages exceeding `max_retries` are moved to the dead-letter stream.
155    /// Others are reclaimed and reprocessed.
156    ///
157    /// Uses a single connection for all Redis operations to avoid pool churn.
158    async fn reclaim_pending_messages(&self) -> StreamResult<()> {
159        let mut conn = self.pool.get().await?;
160
161        // Get pending messages for this consumer group
162        let pending: redis::Value = redis::cmd("XPENDING")
163            .arg(&self.config.stream_key)
164            .arg(&self.config.consumer_group)
165            .arg("-")
166            .arg("+")
167            .arg(self.config.batch_size)
168            .query_async(&mut *conn)
169            .await?;
170
171        let entries = parse_pending_entries(&pending);
172        if entries.is_empty() {
173            return Ok(());
174        }
175
176        for (msg_id, _consumer, idle_ms, delivery_count) in &entries {
177            if *idle_ms < self.config.min_idle_ms {
178                continue;
179            }
180
181            if *delivery_count > self.config.max_retries {
182                // Read message data for the hook before dead-lettering
183                let data = self.read_message_data(&mut conn, msg_id).await;
184                self.handler
185                    .on_dead_letter(msg_id, data.as_deref().unwrap_or(""))
186                    .await;
187                drop(self.move_to_dead_letter_with_conn(&mut conn, msg_id).await);
188                continue;
189            }
190
191            // XCLAIM the message
192            let claimed: redis::Value = redis::cmd("XCLAIM")
193                .arg(&self.config.stream_key)
194                .arg(&self.config.consumer_group)
195                .arg(&self.config.consumer_name)
196                .arg(self.config.min_idle_ms)
197                .arg(msg_id)
198                .query_async(&mut *conn)
199                .await?;
200
201            let claimed_messages = parse_claimed_messages(&claimed);
202            for (claimed_id, fields) in &claimed_messages {
203                let data = match extract_field(fields, &self.config.data_field) {
204                    Some(d) => d,
205                    None => {
206                        self.ack_with_conn(&mut conn, claimed_id).await?;
207                        continue;
208                    }
209                };
210
211                match self.handler.handle_message(claimed_id, data).await {
212                    Ok(()) => {
213                        drop(self.ack_with_conn(&mut conn, claimed_id).await);
214                    }
215                    Err(_) => {
216                        // Message stays pending and will be retried
217                    }
218                }
219            }
220        }
221
222        Ok(())
223    }
224
225    /// ACK a message using an existing connection.
226    async fn ack_with_conn(
227        &self,
228        conn: &mut deadpool_redis::Connection,
229        msg_id: &str,
230    ) -> StreamResult<()> {
231        let _: i64 = redis::cmd("XACK")
232            .arg(&self.config.stream_key)
233            .arg(&self.config.consumer_group)
234            .arg(msg_id)
235            .query_async(&mut *conn)
236            .await?;
237        Ok(())
238    }
239
240    /// Read the "data" field from a message by ID.
241    async fn read_message_data(
242        &self,
243        conn: &mut deadpool_redis::Connection,
244        msg_id: &str,
245    ) -> Option<String> {
246        let value: redis::Value = redis::cmd("XRANGE")
247            .arg(&self.config.stream_key)
248            .arg(msg_id)
249            .arg(msg_id)
250            .query_async(&mut *conn)
251            .await
252            .ok()?;
253
254        let messages = parse_claimed_messages(&value);
255        let (_id, fields) = messages.first()?;
256        extract_field(fields, "data").map(|s| s.to_owned())
257    }
258
259    /// Move a message to the dead-letter stream using an existing connection.
260    ///
261    /// Reads the original message via XRANGE, writes it to the dead-letter stream,
262    /// then ACKs the original.
263    async fn move_to_dead_letter_with_conn(
264        &self,
265        conn: &mut deadpool_redis::Connection,
266        msg_id: &str,
267    ) -> StreamResult<()> {
268        // Read the original message
269        let original: redis::Value = redis::cmd("XRANGE")
270            .arg(&self.config.stream_key)
271            .arg(msg_id)
272            .arg(msg_id)
273            .query_async(&mut *conn)
274            .await?;
275
276        // Extract fields from original message
277        let messages = parse_claimed_messages(&original);
278        if let Some((_id, fields)) = messages.first() {
279            // Write to dead-letter stream with original fields
280            let mut cmd = redis::cmd("XADD");
281            cmd.arg(&self.config.dead_letter_key)
282                .arg("MAXLEN")
283                .arg("~")
284                .arg(10000_i64)
285                .arg("*");
286
287            // Add original message ID as metadata
288            cmd.arg("original_id").arg(msg_id);
289
290            for (key, value) in fields {
291                cmd.arg(key).arg(value);
292            }
293
294            let _dead_letter_id: String = cmd.query_async(&mut *conn).await?;
295        }
296
297        // ACK the original message
298        self.ack_with_conn(conn, msg_id).await?;
299        Ok(())
300    }
301}
302
#[cfg(test)]
mod tests {
    use crate::config::StreamConfig;

    /// `StreamConfig::new` on its own must produce the expected default tuning.
    #[test]
    fn test_consumer_config_defaults() {
        let cfg = StreamConfig::new(
            "test:stream",
            "test:stream:dead_letter",
            "test-group",
            "worker-1",
        );

        assert_eq!(cfg.batch_size, 10);
        assert_eq!(cfg.block_ms, 5000);
        assert_eq!(cfg.max_retries, 5);
        assert_eq!(cfg.min_idle_ms, 60_000);
        assert_eq!(cfg.group_start_id, "$");
    }

    /// Every `with_*` builder call must override its corresponding default.
    #[test]
    fn test_consumer_config_builder() {
        let base = StreamConfig::new(
            "settlement:deposits",
            "settlement:deposits:dead_letter",
            "settlement-service",
            "worker-1",
        );
        let cfg = base
            .with_min_idle_ms(180_000)
            .with_max_retries(10)
            .with_group_start_id("0")
            .with_batch_size(20)
            .with_block_ms(3000);

        assert_eq!(cfg.min_idle_ms, 180_000);
        assert_eq!(cfg.max_retries, 10);
        assert_eq!(cfg.group_start_id, "0");
        assert_eq!(cfg.batch_size, 20);
        assert_eq!(cfg.block_ms, 3000);
    }
}