use std::sync::Arc;
use std::time::Duration;
use deadpool_redis::Pool;
use tokio::sync::broadcast;
use crate::config::StreamConfig;
use crate::error::{StreamError, StreamResult};
use crate::handler::StreamHandler;
use crate::parse::{
extract_data_field, parse_claimed_messages, parse_pending_entries, parse_stream_response,
};
/// Consumes entries from a Redis stream through a consumer group,
/// dispatching each entry to a [`StreamHandler`] and acknowledging
/// (`XACK`) only the ones the handler processes successfully.
pub struct StreamConsumer<H: StreamHandler> {
// Connection pool; every Redis command checks a connection out of it.
pool: Pool,
// Stream/group/consumer names plus batch, block, retry and reclaim tuning.
config: StreamConfig,
// Arc-wrapped so the consumer owns a shareable handle to the handler.
handler: Arc<H>,
}
impl<H: StreamHandler> StreamConsumer<H> {
pub fn new(pool: Pool, config: StreamConfig, handler: H) -> Self {
Self {
pool,
config,
handler: Arc::new(handler),
}
}
pub async fn run(&self, mut shutdown_rx: broadcast::Receiver<()>) -> StreamResult<()> {
self.ensure_consumer_group().await?;
let mut reclaim_interval = tokio::time::interval(self.config.reclaim_interval);
reclaim_interval.tick().await;
loop {
let messages = loop {
tokio::select! {
_ = shutdown_rx.recv() => {
return Ok(());
}
_ = reclaim_interval.tick() => {
drop(self.reclaim_pending_messages().await);
}
result = self.read_from_stream() => {
match result {
Ok(msgs) if msgs.is_empty() => continue,
Ok(msgs) => break msgs,
Err(_) => {
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
}
}
}
};
self.process_and_ack(&messages).await;
}
}
async fn ensure_consumer_group(&self) -> StreamResult<()> {
let mut conn = self.pool.get().await?;
let result: Result<String, redis::RedisError> = redis::cmd("XGROUP")
.arg("CREATE")
.arg(&self.config.stream_key)
.arg(&self.config.consumer_group)
.arg(&self.config.group_start_id)
.arg("MKSTREAM")
.query_async(&mut *conn)
.await;
match result {
Ok(_) => {}
Err(e) if e.to_string().contains("BUSYGROUP") => {}
Err(e) => return Err(StreamError::Redis(e)),
}
Ok(())
}
async fn read_from_stream(&self) -> StreamResult<Vec<(String, Vec<(String, String)>)>> {
let mut conn = self.pool.get().await?;
let result: redis::Value = redis::cmd("XREADGROUP")
.arg("GROUP")
.arg(&self.config.consumer_group)
.arg(&self.config.consumer_name)
.arg("COUNT")
.arg(self.config.batch_size)
.arg("BLOCK")
.arg(self.config.block_ms)
.arg("STREAMS")
.arg(&self.config.stream_key)
.arg(">")
.query_async(&mut *conn)
.await?;
Ok(parse_stream_response(&result))
}
async fn process_and_ack(&self, messages: &[(String, Vec<(String, String)>)]) {
for (msg_id, fields) in messages {
let data = match extract_data_field(fields) {
Some(d) => d,
None => {
drop(self.ack_message(msg_id).await);
continue;
}
};
match self.handler.handle_message(msg_id, data).await {
Ok(()) => {
drop(self.ack_message(msg_id).await);
}
Err(_) => {
}
}
}
}
async fn ack_message(&self, msg_id: &str) -> StreamResult<()> {
let mut conn = self.pool.get().await?;
let _: i64 = redis::cmd("XACK")
.arg(&self.config.stream_key)
.arg(&self.config.consumer_group)
.arg(msg_id)
.query_async(&mut *conn)
.await?;
Ok(())
}
async fn reclaim_pending_messages(&self) -> StreamResult<()> {
let mut conn = self.pool.get().await?;
let pending: redis::Value = redis::cmd("XPENDING")
.arg(&self.config.stream_key)
.arg(&self.config.consumer_group)
.arg("-")
.arg("+")
.arg(self.config.batch_size)
.query_async(&mut *conn)
.await?;
let entries = parse_pending_entries(&pending);
if entries.is_empty() {
return Ok(());
}
for (msg_id, _consumer, idle_ms, delivery_count) in &entries {
if *idle_ms < self.config.min_idle_ms {
continue;
}
if *delivery_count > self.config.max_retries {
drop(self.move_to_dead_letter_with_conn(&mut conn, msg_id).await);
continue;
}
let claimed: redis::Value = redis::cmd("XCLAIM")
.arg(&self.config.stream_key)
.arg(&self.config.consumer_group)
.arg(&self.config.consumer_name)
.arg(self.config.min_idle_ms)
.arg(msg_id)
.query_async(&mut *conn)
.await?;
let claimed_messages = parse_claimed_messages(&claimed);
for (claimed_id, fields) in &claimed_messages {
let data = match extract_data_field(fields) {
Some(d) => d,
None => {
self.ack_with_conn(&mut conn, claimed_id).await?;
continue;
}
};
match self.handler.handle_message(claimed_id, data).await {
Ok(()) => {
drop(self.ack_with_conn(&mut conn, claimed_id).await);
}
Err(_) => {
}
}
}
}
Ok(())
}
async fn ack_with_conn(
&self,
conn: &mut deadpool_redis::Connection,
msg_id: &str,
) -> StreamResult<()> {
let _: i64 = redis::cmd("XACK")
.arg(&self.config.stream_key)
.arg(&self.config.consumer_group)
.arg(msg_id)
.query_async(&mut *conn)
.await?;
Ok(())
}
async fn move_to_dead_letter_with_conn(
&self,
conn: &mut deadpool_redis::Connection,
msg_id: &str,
) -> StreamResult<()> {
let original: redis::Value = redis::cmd("XRANGE")
.arg(&self.config.stream_key)
.arg(msg_id)
.arg(msg_id)
.query_async(&mut *conn)
.await?;
let messages = parse_claimed_messages(&original);
if let Some((_id, fields)) = messages.first() {
let mut cmd = redis::cmd("XADD");
cmd.arg(&self.config.dead_letter_key)
.arg("MAXLEN")
.arg("~")
.arg(10000_i64)
.arg("*");
cmd.arg("original_id").arg(msg_id);
for (key, value) in fields {
cmd.arg(key).arg(value);
}
let _dead_letter_id: String = cmd.query_async(&mut *conn).await?;
}
self.ack_with_conn(conn, msg_id).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use crate::config::StreamConfig;

    /// A freshly constructed config carries the documented defaults.
    #[test]
    fn test_consumer_config_defaults() {
        let config = StreamConfig::new(
            "test:stream",
            "test:stream:dead_letter",
            "test-group",
            "worker-1",
        );
        assert_eq!(10, config.batch_size);
        assert_eq!(5000, config.block_ms);
        assert_eq!(5, config.max_retries);
        assert_eq!(60_000, config.min_idle_ms);
        assert_eq!("$", config.group_start_id);
    }

    /// Every `with_*` builder method overrides its corresponding field.
    #[test]
    fn test_consumer_config_builder() {
        let base = StreamConfig::new(
            "settlement:deposits",
            "settlement:deposits:dead_letter",
            "settlement-service",
            "worker-1",
        );
        let config = base
            .with_min_idle_ms(180_000)
            .with_max_retries(10)
            .with_group_start_id("0")
            .with_batch_size(20)
            .with_block_ms(3000);
        assert_eq!(180_000, config.min_idle_ms);
        assert_eq!(10, config.max_retries);
        assert_eq!("0", config.group_start_id);
        assert_eq!(20, config.batch_size);
        assert_eq!(3000, config.block_ms);
    }
}