Skip to main content

pyra_streams/
consumer.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use deadpool_redis::Pool;
5use tokio::sync::broadcast;
6
7use crate::config::StreamConfig;
8use crate::error::{StreamError, StreamResult};
9use crate::handler::StreamHandler;
10use crate::parse::{
11    extract_data_field, parse_claimed_messages, parse_pending_entries, parse_stream_response,
12};
13
14/// A Redis Stream consumer that reads messages, dispatches them to a handler,
15/// and manages ACKs, retries (via XCLAIM), and dead-lettering.
16///
17/// The consumer uses XREADGROUP with consumer groups for reliable, at-least-once
18/// delivery. Messages that fail processing are left pending and automatically
19/// reclaimed after `min_idle_ms` via periodic XCLAIM. Messages exceeding
20/// `max_retries` are moved to a dead-letter stream.
21pub struct StreamConsumer<H: StreamHandler> {
22    pool: Pool,
23    config: StreamConfig,
24    handler: Arc<H>,
25}
26
27impl<H: StreamHandler> StreamConsumer<H> {
28    /// Create a new stream consumer.
29    pub fn new(pool: Pool, config: StreamConfig, handler: H) -> Self {
30        Self {
31            pool,
32            config,
33            handler: Arc::new(handler),
34        }
35    }
36
37    /// Run the consumer loop until a shutdown signal is received.
38    ///
39    /// This method:
40    /// 1. Ensures the consumer group exists (creates it if not).
41    /// 2. Reads messages via XREADGROUP in a cancellation-safe inner loop.
42    /// 3. Processes messages outside `select!` to prevent partial execution.
43    /// 4. Periodically reclaims idle pending messages via XCLAIM.
44    pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) -> StreamResult<()> {
45        self.ensure_consumer_group().await?;
46
47        let mut reclaim_interval = tokio::time::interval(self.config.reclaim_interval);
48        reclaim_interval.tick().await; // skip first immediate tick
49
50        loop {
51            // Inner loop: only cancellation-safe operations (XREADGROUP, timer tick).
52            let messages = loop {
53                tokio::select! {
54                    _ = shutdown_rx.recv() => {
55                        return Ok(());
56                    }
57                    _ = reclaim_interval.tick() => {
58                        drop(self.reclaim_pending_messages().await);
59                    }
60                    result = self.read_from_stream() => {
61                        match result {
62                            Ok(msgs) if msgs.is_empty() => continue,
63                            Ok(msgs) => break msgs,
64                            Err(_) => {
65                                tokio::time::sleep(Duration::from_secs(1)).await;
66                                continue;
67                            }
68                        }
69                    }
70                }
71            };
72
73            // Process OUTSIDE select! — cannot be cancelled mid-processing.
74            self.process_and_ack(&messages).await;
75        }
76    }
77
78    /// Create the consumer group if it doesn't already exist.
79    async fn ensure_consumer_group(&self) -> StreamResult<()> {
80        let mut conn = self.pool.get().await?;
81        let result: Result<String, redis::RedisError> = redis::cmd("XGROUP")
82            .arg("CREATE")
83            .arg(&self.config.stream_key)
84            .arg(&self.config.consumer_group)
85            .arg(&self.config.group_start_id)
86            .arg("MKSTREAM")
87            .query_async(&mut *conn)
88            .await;
89
90        match result {
91            Ok(_) => {}
92            Err(e) if e.to_string().contains("BUSYGROUP") => {}
93            Err(e) => return Err(StreamError::Redis(e)),
94        }
95        Ok(())
96    }
97
98    /// Read a batch of messages from the stream using XREADGROUP.
99    async fn read_from_stream(&self) -> StreamResult<Vec<(String, Vec<(String, String)>)>> {
100        let mut conn = self.pool.get().await?;
101        let result: redis::Value = redis::cmd("XREADGROUP")
102            .arg("GROUP")
103            .arg(&self.config.consumer_group)
104            .arg(&self.config.consumer_name)
105            .arg("COUNT")
106            .arg(self.config.batch_size)
107            .arg("BLOCK")
108            .arg(self.config.block_ms)
109            .arg("STREAMS")
110            .arg(&self.config.stream_key)
111            .arg(">")
112            .query_async(&mut *conn)
113            .await?;
114
115        Ok(parse_stream_response(&result))
116    }
117
118    /// Process a batch of messages and ACK successful ones.
119    async fn process_and_ack(&self, messages: &[(String, Vec<(String, String)>)]) {
120        for (msg_id, fields) in messages {
121            let data = match extract_data_field(fields) {
122                Some(d) => d,
123                None => {
124                    drop(self.ack_message(msg_id).await);
125                    continue;
126                }
127            };
128
129            match self.handler.handle_message(msg_id, data).await {
130                Ok(()) => {
131                    drop(self.ack_message(msg_id).await);
132                }
133                Err(_) => {
134                    // Message stays pending and will be retried via XCLAIM
135                }
136            }
137        }
138    }
139
140    /// Acknowledge a message in the consumer group.
141    async fn ack_message(&self, msg_id: &str) -> StreamResult<()> {
142        let mut conn = self.pool.get().await?;
143        let _: i64 = redis::cmd("XACK")
144            .arg(&self.config.stream_key)
145            .arg(&self.config.consumer_group)
146            .arg(msg_id)
147            .query_async(&mut *conn)
148            .await?;
149        Ok(())
150    }
151
152    /// Reclaim idle pending messages via XPENDING + XCLAIM.
153    ///
154    /// Messages exceeding `max_retries` are moved to the dead-letter stream.
155    /// Others are reclaimed and reprocessed.
156    ///
157    /// Uses a single connection for all Redis operations to avoid pool churn.
158    async fn reclaim_pending_messages(&self) -> StreamResult<()> {
159        let mut conn = self.pool.get().await?;
160
161        // Get pending messages for this consumer group
162        let pending: redis::Value = redis::cmd("XPENDING")
163            .arg(&self.config.stream_key)
164            .arg(&self.config.consumer_group)
165            .arg("-")
166            .arg("+")
167            .arg(self.config.batch_size)
168            .query_async(&mut *conn)
169            .await?;
170
171        let entries = parse_pending_entries(&pending);
172        if entries.is_empty() {
173            return Ok(());
174        }
175
176        for (msg_id, _consumer, idle_ms, delivery_count) in &entries {
177            if *idle_ms < self.config.min_idle_ms {
178                continue;
179            }
180
181            if *delivery_count > self.config.max_retries {
182                drop(self.move_to_dead_letter_with_conn(&mut conn, msg_id).await);
183                continue;
184            }
185
186            // XCLAIM the message
187            let claimed: redis::Value = redis::cmd("XCLAIM")
188                .arg(&self.config.stream_key)
189                .arg(&self.config.consumer_group)
190                .arg(&self.config.consumer_name)
191                .arg(self.config.min_idle_ms)
192                .arg(msg_id)
193                .query_async(&mut *conn)
194                .await?;
195
196            let claimed_messages = parse_claimed_messages(&claimed);
197            for (claimed_id, fields) in &claimed_messages {
198                let data = match extract_data_field(fields) {
199                    Some(d) => d,
200                    None => {
201                        self.ack_with_conn(&mut conn, claimed_id).await?;
202                        continue;
203                    }
204                };
205
206                match self.handler.handle_message(claimed_id, data).await {
207                    Ok(()) => {
208                        drop(self.ack_with_conn(&mut conn, claimed_id).await);
209                    }
210                    Err(_) => {
211                        // Message stays pending and will be retried
212                    }
213                }
214            }
215        }
216
217        Ok(())
218    }
219
220    /// ACK a message using an existing connection.
221    async fn ack_with_conn(
222        &self,
223        conn: &mut deadpool_redis::Connection,
224        msg_id: &str,
225    ) -> StreamResult<()> {
226        let _: i64 = redis::cmd("XACK")
227            .arg(&self.config.stream_key)
228            .arg(&self.config.consumer_group)
229            .arg(msg_id)
230            .query_async(&mut *conn)
231            .await?;
232        Ok(())
233    }
234
235    /// Move a message to the dead-letter stream using an existing connection.
236    ///
237    /// Reads the original message via XRANGE, writes it to the dead-letter stream,
238    /// then ACKs the original.
239    async fn move_to_dead_letter_with_conn(
240        &self,
241        conn: &mut deadpool_redis::Connection,
242        msg_id: &str,
243    ) -> StreamResult<()> {
244        // Read the original message
245        let original: redis::Value = redis::cmd("XRANGE")
246            .arg(&self.config.stream_key)
247            .arg(msg_id)
248            .arg(msg_id)
249            .query_async(&mut *conn)
250            .await?;
251
252        // Extract fields from original message
253        let messages = parse_claimed_messages(&original);
254        if let Some((_id, fields)) = messages.first() {
255            // Write to dead-letter stream with original fields
256            let mut cmd = redis::cmd("XADD");
257            cmd.arg(&self.config.dead_letter_key)
258                .arg("MAXLEN")
259                .arg("~")
260                .arg(10000_i64)
261                .arg("*");
262
263            // Add original message ID as metadata
264            cmd.arg("original_id").arg(msg_id);
265
266            for (key, value) in fields {
267                cmd.arg(key).arg(value);
268            }
269
270            let _dead_letter_id: String = cmd.query_async(&mut *conn).await?;
271        }
272
273        // ACK the original message
274        self.ack_with_conn(conn, msg_id).await?;
275        Ok(())
276    }
277}
278
#[cfg(test)]
mod tests {
    use crate::config::StreamConfig;

    /// A config built with only the four names carries the expected defaults.
    #[test]
    fn test_consumer_config_defaults() {
        let defaults = StreamConfig::new(
            "test:stream",
            "test:stream:dead_letter",
            "test-group",
            "worker-1",
        );

        assert_eq!(defaults.batch_size, 10);
        assert_eq!(defaults.block_ms, 5000);
        assert_eq!(defaults.max_retries, 5);
        assert_eq!(defaults.min_idle_ms, 60_000);
        assert_eq!(defaults.group_start_id, "$");
    }

    /// Every `with_*` builder call overrides the matching field.
    #[test]
    fn test_consumer_config_builder() {
        let base = StreamConfig::new(
            "settlement:deposits",
            "settlement:deposits:dead_letter",
            "settlement-service",
            "worker-1",
        );
        let tuned = base
            .with_min_idle_ms(180_000)
            .with_max_retries(10)
            .with_group_start_id("0")
            .with_batch_size(20)
            .with_block_ms(3000);

        assert_eq!(tuned.min_idle_ms, 180_000);
        assert_eq!(tuned.max_retries, 10);
        assert_eq!(tuned.group_start_id, "0");
        assert_eq!(tuned.batch_size, 20);
        assert_eq!(tuned.block_ms, 3000);
    }
}