// pyra-streams 0.4.2
//
// Redis Stream consumer infrastructure for Pyra services
// Documentation
use std::sync::Arc;
use std::time::Duration;

use deadpool_redis::Pool;
use tokio::sync::broadcast;

use crate::config::StreamConfig;
use crate::error::{StreamError, StreamResult};
use crate::handler::StreamHandler;
use crate::parse::{
    extract_data_field, parse_claimed_messages, parse_pending_entries, parse_stream_response,
};

/// A Redis Stream consumer that reads messages, dispatches them to a handler,
/// and manages ACKs, retries (via XCLAIM), and dead-lettering.
///
/// The consumer uses XREADGROUP with consumer groups for reliable, at-least-once
/// delivery. Messages that fail processing are left pending and automatically
/// reclaimed after `min_idle_ms` via periodic XCLAIM. Messages exceeding
/// `max_retries` are moved to a dead-letter stream.
pub struct StreamConsumer<H: StreamHandler> {
    /// Connection pool; a connection is checked out for each Redis operation
    /// (or once per reclaim pass).
    pool: Pool,
    /// Stream key, group/consumer names, batch and blocking parameters, retry
    /// limits, and the dead-letter stream key.
    config: StreamConfig,
    /// Handler invoked once for every message that carries a data field.
    handler: Arc<H>,
}

impl<H: StreamHandler> StreamConsumer<H> {
    /// Build a consumer that reads through `pool` according to `config` and
    /// hands every message to `handler`.
    ///
    /// The handler is moved behind an [`Arc`]. No Redis command is issued here.
    pub fn new(pool: Pool, config: StreamConfig, handler: H) -> Self {
        let handler = Arc::new(handler);
        Self { pool, config, handler }
    }

    /// Run the consumer loop until a shutdown signal is received.
    ///
    /// This method:
    /// 1. Ensures the consumer group exists (creates it if not).
    /// 2. Reads messages via XREADGROUP in an inner `select!` loop that also
    ///    watches the shutdown channel and the reclaim timer.
    /// 3. Processes messages outside `select!` to prevent partial execution.
    /// 4. Periodically reclaims idle pending messages via XCLAIM.
    ///
    /// The loop ends with `Ok(())` as soon as `shutdown_rx.recv()` resolves,
    /// regardless of whether it yields a value or a receive error.
    ///
    /// # Errors
    ///
    /// Returns an error only if the consumer group cannot be ensured at startup.
    /// Read failures are retried after a one-second pause, and the result of a
    /// reclaim pass is discarded.
    ///
    /// # Panics
    ///
    /// `tokio::time::interval` panics when given a zero period, so
    /// `config.reclaim_interval` must be non-zero.
    pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) -> StreamResult<()> {
        self.ensure_consumer_group().await?;

        let mut reclaim_interval = tokio::time::interval(self.config.reclaim_interval);
        // A long batch can overshoot several timer periods. The default `Burst`
        // policy would then fire all missed ticks back-to-back and run one
        // reclaim pass per tick; `Delay` schedules the next tick a full period
        // after the late one instead.
        reclaim_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        reclaim_interval.tick().await; // skip first immediate tick

        loop {
            // Inner loop: only the futures in the branch heads (shutdown recv,
            // timer tick, XREADGROUP) can be dropped by `select!`. Branch bodies
            // run after a branch has won and are not raced against the others.
            let messages = loop {
                tokio::select! {
                    _ = shutdown_rx.recv() => {
                        return Ok(());
                    }
                    _ = reclaim_interval.tick() => {
                        // Best-effort: a failed pass is retried on the next tick.
                        drop(self.reclaim_pending_messages().await);
                    }
                    result = self.read_from_stream() => {
                        match result {
                            Ok(msgs) if msgs.is_empty() => continue,
                            Ok(msgs) => break msgs,
                            Err(_) => {
                                // Back off briefly before retrying the read.
                                tokio::time::sleep(Duration::from_secs(1)).await;
                                continue;
                            }
                        }
                    }
                }
            };

            // Process OUTSIDE select! — cannot be cancelled mid-processing.
            self.process_and_ack(&messages).await;
        }
    }

    /// Issue `XGROUP CREATE … MKSTREAM` for the configured stream and group.
    ///
    /// An error whose text contains `BUSYGROUP` is treated as success. Any other
    /// Redis error is returned as [`StreamError::Redis`].
    async fn ensure_consumer_group(&self) -> StreamResult<()> {
        let mut conn = self.pool.get().await?;

        let mut create = redis::cmd("XGROUP");
        create
            .arg("CREATE")
            .arg(&self.config.stream_key)
            .arg(&self.config.consumer_group)
            .arg(&self.config.group_start_id)
            .arg("MKSTREAM");

        let outcome: Result<String, redis::RedisError> = create.query_async(&mut *conn).await;
        if let Err(err) = outcome {
            // Only a BUSYGROUP reply is tolerated; everything else is fatal.
            if !err.to_string().contains("BUSYGROUP") {
                return Err(StreamError::Redis(err));
            }
        }
        Ok(())
    }

    /// Fetch up to `batch_size` never-delivered messages (`>`) for this consumer
    /// with XREADGROUP, blocking for at most `block_ms`.
    ///
    /// The raw reply is converted by `parse_stream_response` into
    /// `(message id, fields)` pairs.
    async fn read_from_stream(&self) -> StreamResult<Vec<(String, Vec<(String, String)>)>> {
        let cfg = &self.config;
        let mut conn = self.pool.get().await?;

        let mut read = redis::cmd("XREADGROUP");
        read.arg("GROUP")
            .arg(&cfg.consumer_group)
            .arg(&cfg.consumer_name)
            .arg("COUNT")
            .arg(cfg.batch_size)
            .arg("BLOCK")
            .arg(cfg.block_ms)
            .arg("STREAMS")
            .arg(&cfg.stream_key)
            .arg(">");

        let reply: redis::Value = read.query_async(&mut *conn).await?;
        Ok(parse_stream_response(&reply))
    }

    /// Hand each message of the batch to the handler and ACK where appropriate.
    ///
    /// A message is ACKed when it has no data field (nothing to process) or when
    /// the handler succeeds. If the handler fails, the message is left pending
    /// so a later XCLAIM pass can retry it. ACK failures are ignored.
    async fn process_and_ack(&self, messages: &[(String, Vec<(String, String)>)]) {
        for (msg_id, fields) in messages {
            let should_ack = match extract_data_field(fields) {
                // No payload: acknowledge so the entry does not stay pending.
                None => true,
                Some(payload) => self.handler.handle_message(msg_id, payload).await.is_ok(),
            };

            if should_ack {
                drop(self.ack_message(msg_id).await);
            }
        }
    }

    /// Acknowledge a message in the consumer group.
    async fn ack_message(&self, msg_id: &str) -> StreamResult<()> {
        let mut conn = self.pool.get().await?;
        let _: i64 = redis::cmd("XACK")
            .arg(&self.config.stream_key)
            .arg(&self.config.consumer_group)
            .arg(msg_id)
            .query_async(&mut *conn)
            .await?;
        Ok(())
    }

    /// Reclaim idle pending messages via XPENDING + XCLAIM.
    ///
    /// Looks at up to `batch_size` pending entries of the consumer group.
    /// Entries idle for less than `min_idle_ms` are skipped. Entries delivered
    /// more than `max_retries` times are moved to the dead-letter stream. The
    /// rest are claimed for this consumer and passed to the handler again.
    ///
    /// Uses a single connection for all Redis operations to avoid pool churn.
    ///
    /// # Errors
    ///
    /// Fails if no connection can be obtained or if XPENDING / XCLAIM fails.
    /// ACK and dead-letter failures are best-effort: the affected entry stays
    /// pending and is picked up again on a later pass.
    async fn reclaim_pending_messages(&self) -> StreamResult<()> {
        let mut conn = self.pool.get().await?;

        // Get pending messages for this consumer group
        let pending: redis::Value = redis::cmd("XPENDING")
            .arg(&self.config.stream_key)
            .arg(&self.config.consumer_group)
            .arg("-")
            .arg("+")
            .arg(self.config.batch_size)
            .query_async(&mut *conn)
            .await?;

        let entries = parse_pending_entries(&pending);

        for (msg_id, _consumer, idle_ms, delivery_count) in &entries {
            // Not idle long enough yet: its current owner may still be working on it.
            if *idle_ms < self.config.min_idle_ms {
                continue;
            }

            // Retry budget exhausted: park the message instead of retrying again.
            if *delivery_count > self.config.max_retries {
                drop(self.move_to_dead_letter_with_conn(&mut conn, msg_id).await);
                continue;
            }

            // XCLAIM the message; the min-idle argument guards against a
            // concurrent claim that happened after the XPENDING snapshot.
            let claimed: redis::Value = redis::cmd("XCLAIM")
                .arg(&self.config.stream_key)
                .arg(&self.config.consumer_group)
                .arg(&self.config.consumer_name)
                .arg(self.config.min_idle_ms)
                .arg(msg_id)
                .query_async(&mut *conn)
                .await?;

            let claimed_messages = parse_claimed_messages(&claimed);
            for (claimed_id, fields) in &claimed_messages {
                let data = match extract_data_field(fields) {
                    Some(d) => d,
                    None => {
                        // Best-effort like every other ACK here: a failed ACK
                        // must not abort the remaining entries of this pass.
                        drop(self.ack_with_conn(&mut conn, claimed_id).await);
                        continue;
                    }
                };

                match self.handler.handle_message(claimed_id, data).await {
                    Ok(()) => {
                        drop(self.ack_with_conn(&mut conn, claimed_id).await);
                    }
                    Err(_) => {
                        // Message stays pending and will be retried
                    }
                }
            }
        }

        Ok(())
    }

    /// Send XACK for `msg_id` over a connection the caller already holds.
    ///
    /// The integer reply is read and discarded; only a Redis error is reported.
    async fn ack_with_conn(
        &self,
        conn: &mut deadpool_redis::Connection,
        msg_id: &str,
    ) -> StreamResult<()> {
        let mut xack = redis::cmd("XACK");
        xack.arg(&self.config.stream_key)
            .arg(&self.config.consumer_group)
            .arg(msg_id);

        let _acked: i64 = xack.query_async(&mut *conn).await?;
        Ok(())
    }

    /// Copy a message into the dead-letter stream and ACK the original, all on
    /// the caller's connection.
    ///
    /// The entry is looked up with XRANGE. If it is found, its fields are
    /// re-added to the dead-letter stream (approximately capped with `MAXLEN ~`)
    /// together with an `original_id` field. The original is ACKed whether or
    /// not the entry was found.
    async fn move_to_dead_letter_with_conn(
        &self,
        conn: &mut deadpool_redis::Connection,
        msg_id: &str,
    ) -> StreamResult<()> {
        /// Approximate cap passed to `XADD … MAXLEN ~`.
        const DEAD_LETTER_MAXLEN: i64 = 10000;

        // Fetch exactly this entry: start and end of the range are the same id.
        let source: redis::Value = redis::cmd("XRANGE")
            .arg(&self.config.stream_key)
            .arg(msg_id)
            .arg(msg_id)
            .query_async(&mut *conn)
            .await?;

        let found = parse_claimed_messages(&source);
        if let Some((_, fields)) = found.first() {
            let mut xadd = redis::cmd("XADD");
            xadd.arg(&self.config.dead_letter_key)
                .arg("MAXLEN")
                .arg("~")
                .arg(DEAD_LETTER_MAXLEN)
                .arg("*")
                // Keep a pointer back to the entry in the source stream.
                .arg("original_id")
                .arg(msg_id);

            for (name, value) in fields {
                xadd.arg(name).arg(value);
            }

            let _dead_letter_id: String = xadd.query_async(&mut *conn).await?;
        }

        // Remove the original from the pending list.
        self.ack_with_conn(conn, msg_id).await
    }
}

#[cfg(test)]
mod tests {
    use crate::config::StreamConfig;

    /// A config built from only the four names carries the default tuning values.
    #[test]
    fn test_consumer_config_defaults() {
        let defaults = StreamConfig::new(
            "test:stream",
            "test:stream:dead_letter",
            "test-group",
            "worker-1",
        );

        assert_eq!(defaults.batch_size, 10);
        assert_eq!(defaults.block_ms, 5000);
        assert_eq!(defaults.max_retries, 5);
        assert_eq!(defaults.min_idle_ms, 60_000);
        assert_eq!(defaults.group_start_id, "$");
    }

    /// Every `with_*` builder call overrides exactly the field it names.
    #[test]
    fn test_consumer_config_builder() {
        let base = StreamConfig::new(
            "settlement:deposits",
            "settlement:deposits:dead_letter",
            "settlement-service",
            "worker-1",
        );
        let tuned = base
            .with_min_idle_ms(180_000)
            .with_max_retries(10)
            .with_group_start_id("0")
            .with_batch_size(20)
            .with_block_ms(3000);

        assert_eq!(tuned.min_idle_ms, 180_000);
        assert_eq!(tuned.max_retries, 10);
        assert_eq!(tuned.group_start_id, "0");
        assert_eq!(tuned.batch_size, 20);
        assert_eq!(tuned.block_ms, 3000);
    }
}