use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use crate::backend::ConsumerOptionsInner;
use crate::backend::consumer::ConsumerImpl;
use crate::error::{Result, ShoveError};
use crate::handler::MessageHandler;
use crate::metadata::MessageMetadata;
use crate::metrics;
use crate::outcome::Outcome;
use crate::retry::Backoff;
use crate::topic::{SequencedTopic, Topic};
use crate::topology::{HoldQueue, QueueTopology};
use redis::streams::StreamAutoClaimReply;
use super::client::{RedisClient, RedisConnection};
use super::constants::{
AUTOCLAIM_COUNT, BLOCK_MS, PAYLOAD_FIELD, X_DEATH_COUNT, X_DEATH_REASON, X_MESSAGE_ID,
X_ORIGINAL_QUEUE, X_RETRY_COUNT, X_SEQUENCE_KEY,
};
use super::requeue::{HoldEntry, enqueue_hold, spawn_requeuer};
use super::topology::RedisTopologyDeclarer;
/// Redis Streams implementation of [`ConsumerImpl`].
///
/// Holds only a [`RedisClient`]; each consume loop clones it and opens its own dedicated
/// connection from that clone.
#[derive(Clone)]
pub struct RedisConsumer {
    client: RedisClient,
}

impl RedisConsumer {
    /// Wraps an existing [`RedisClient`].
    pub fn new(client: RedisClient) -> Self {
        RedisConsumer { client }
    }

    /// Builds a consumer-group member name of the form `<HOSTNAME>-<uuid v4>`.
    ///
    /// Falls back to `unknown` when `HOSTNAME` cannot be read. The random suffix makes every
    /// call return a distinct name, so each reconnect registers a new group member.
    fn consumer_name() -> String {
        let host = match std::env::var("HOSTNAME") {
            Ok(name) => name,
            Err(_) => String::from("unknown"),
        };
        format!("{}-{}", host, uuid::Uuid::new_v4())
    }
}
impl ConsumerImpl for RedisConsumer {
fn run<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptionsInner,
) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
let client = self.client.clone();
async move {
let topology = T::topology();
let stream = topology.queue();
run_stream_loop::<T, H>(client, handler, ctx, options, topology, stream).await
}
}
fn run_fifo<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptionsInner,
) -> impl Future<Output = Result<()>> + Send
where
T: SequencedTopic,
H: MessageHandler<T>,
{
let consumer = self.clone();
async move {
let handles = consumer
.spawn_fifo_shards::<T, H>(handler, ctx, options)
.await?;
for handle in handles {
match handle.await {
Ok(Ok(())) => {}
Ok(Err(e)) => tracing::error!("sequenced shard task failed: {e}"),
Err(e) => tracing::error!("sequenced shard task panicked: {e}"),
}
}
Ok(())
}
}
fn run_dlq<T, H>(&self, handler: H, ctx: H::Context) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
let client = self.client.clone();
async move {
let topology = T::topology();
let dlq_name = topology.dlq().ok_or_else(|| {
ShoveError::Topology(format!(
"run_dlq called on topic {} without DLQ",
topology.queue()
))
})?;
let shutdown = CancellationToken::new();
let options = ConsumerOptionsInner::defaults_with_shutdown(shutdown);
run_stream_loop::<T, H>(client, handler, ctx, options, topology, dlq_name).await
}
}
fn spawn_fifo_shards<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptionsInner,
) -> impl Future<Output = Result<Vec<tokio::task::JoinHandle<Result<()>>>>> + Send
where
T: SequencedTopic,
H: MessageHandler<T>,
{
let client = self.client.clone();
async move {
let topology = T::topology();
let seq = topology.sequencing().ok_or_else(|| {
ShoveError::Topology(format!(
"spawn_fifo_shards called on topic {} without sequencing config",
topology.queue()
))
})?;
let n_shards = seq.routing_shards();
let mut handles: Vec<tokio::task::JoinHandle<Result<()>>> =
Vec::with_capacity(n_shards as usize);
let handler = Arc::new(handler);
let ctx = Arc::new(ctx);
for shard_idx in 0..n_shards {
let stream_name =
RedisTopologyDeclarer::shard_stream_name(topology.queue(), shard_idx);
let shard_hold_queues = topology.shard_hold_queue_names(shard_idx);
let client = client.clone();
let handler = Arc::clone(&handler);
let ctx = Arc::clone(&ctx);
let options = options.clone();
handles.push(tokio::spawn(async move {
let hold_names: Vec<String> = shard_hold_queues
.iter()
.map(|hq| hq.name().to_owned())
.collect();
let shutdown = options.shutdown.clone();
let requeue_handle = if !hold_names.is_empty() {
Some(spawn_requeuer(client.clone(), hold_names, shutdown.clone()))
} else {
None
};
let result = run_stream_loop_arc::<T, H>(
client,
handler,
ctx,
options,
topology,
&stream_name,
&shard_hold_queues,
)
.await;
if let Some(h) = requeue_handle {
h.abort();
}
result
}));
}
Ok(handles)
}
}
}
/// Runs `f` until it succeeds, calling it again after each retryable failure.
///
/// Every failed call counts toward `max_reconnect_attempts` (`None` means unlimited). When
/// the count reaches the limit, a [`ShoveError::Connection`] naming `stream` is returned.
/// Between calls the task sleeps for the next [`Backoff`] delay and wakes early if `shutdown`
/// fires.
///
/// Returns `Ok(())` when `f` succeeds, or when `shutdown` is (or becomes) cancelled after a
/// retryable failure. Non-retryable errors are returned as-is, without a retry.
///
/// NOTE(review): the attempt counter and the backoff are created once, outside the loop, and
/// never reset. Failures therefore accumulate over the whole lifetime of the loop, even when
/// `f` ran healthily for a long time between them.
async fn run_with_reconnect<F, Fut>(
    shutdown: &CancellationToken,
    stream: &str,
    max_reconnect_attempts: Option<u32>,
    mut f: F,
) -> Result<()>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<()>>,
{
    let mut backoff = Backoff::default();
    let mut attempts: u32 = 0;
    loop {
        let Err(e) = f().await else {
            return Ok(());
        };
        if !e.is_retryable() {
            return Err(e);
        }
        if shutdown.is_cancelled() {
            return Ok(());
        }
        attempts += 1;
        if let Some(max) = max_reconnect_attempts.filter(|&limit| attempts >= limit) {
            tracing::error!(
                stream,
                attempts,
                error = %e,
                "max reconnect attempts reached, giving up"
            );
            return Err(ShoveError::Connection(format!(
                "consumer on '{stream}' exhausted {max} reconnect attempt(s): {e}"
            )));
        }
        let delay = backoff.next().expect("backoff is infinite");
        tracing::warn!(
            stream,
            attempt = attempts,
            ?max_reconnect_attempts,
            error = %e,
            "consumer error, reconnecting in {delay:?}"
        );
        tokio::select! {
            _ = tokio::time::sleep(delay) => {}
            _ = shutdown.cancelled() => return Ok(()),
        }
    }
}
/// Convenience wrapper around [`run_stream_loop_arc`] for a single owned handler and context.
///
/// When the topology declares hold queues, a requeuer task is spawned for them
/// ([`spawn_requeuer`]). It is aborted once the consume loop returns, whatever the result.
async fn run_stream_loop<T, H>(
    client: RedisClient,
    handler: H,
    ctx: H::Context,
    options: ConsumerOptionsInner,
    topology: &'static QueueTopology,
    stream: &str,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let hold_queues = topology.hold_queues();
    let requeuer = {
        let names: Vec<String> = hold_queues
            .iter()
            .map(|hq| hq.name().to_owned())
            .collect();
        (!names.is_empty())
            .then(|| spawn_requeuer(client.clone(), names, options.shutdown.clone()))
    };
    let outcome = run_stream_loop_arc::<T, H>(
        client,
        Arc::new(handler),
        Arc::new(ctx),
        options,
        topology,
        stream,
        hold_queues,
    )
    .await;
    if let Some(task) = requeuer {
        task.abort();
    }
    outcome
}
async fn run_stream_loop_arc<T, H>(
client: RedisClient,
handler: Arc<H>,
ctx: Arc<H::Context>,
options: ConsumerOptionsInner,
topology: &'static QueueTopology,
stream: &str,
hold_queues: &[HoldQueue],
) -> Result<()>
where
T: Topic,
H: MessageHandler<T>,
{
let group = client.group().to_owned();
let shutdown = options.shutdown.clone();
let topic_name = topology.queue();
let consumer_group = options.consumer_group.as_deref();
let topic_arc: Arc<str> = Arc::from(topic_name);
let group_arc: Option<Arc<str>> = consumer_group.map(Arc::from);
let idle_ms = options
.handler_timeout
.unwrap_or(Duration::from_secs(30))
.as_millis() as u64;
let prefetch = options.prefetch_count.max(1) as usize;
let autoclaim_interval = Duration::from_millis(idle_ms.max(30_000));
run_with_reconnect(&shutdown, stream, options.max_reconnect_attempts, || {
let client = client.clone();
let handler = Arc::clone(&handler);
let ctx = Arc::clone(&ctx);
let options = options.clone();
let group = group.clone();
let consumer = RedisConsumer::consumer_name();
tracing::debug!(
consumer,
stream,
"new consumer name registered; previous name left as stale entry in group until XGROUP DELCONSUMER is called"
);
let topic_arc = Arc::clone(&topic_arc);
let group_arc = group_arc.clone();
let shutdown = shutdown.clone();
async move {
let mut conn = client.dedicated_conn().await?;
let _ = autoclaim_all(&mut conn, stream, &group, &consumer, idle_ms).await;
let mut last_autoclaim = std::time::Instant::now();
loop {
if shutdown.is_cancelled() {
return Ok(());
}
let mut xreadgroup_cmd = redis::cmd("XREADGROUP");
xreadgroup_cmd
.arg("GROUP")
.arg(&group)
.arg(&consumer)
.arg("COUNT")
.arg(prefetch)
.arg("BLOCK")
.arg(BLOCK_MS)
.arg("STREAMS")
.arg(stream)
.arg(">");
let xreadgroup_fut = conn.query(&mut xreadgroup_cmd);
let raw_reply: redis::Value = tokio::select! {
biased;
_ = shutdown.cancelled() => return Ok(()),
result = xreadgroup_fut => match result {
Ok(v) => v,
Err(e) => {
if e.to_string().contains("NOGROUP") {
tracing::warn!(
stream,
error = %e,
"consumer group does not exist — topology may not be declared yet; will retry"
);
return Err(ShoveError::Connection(format!(
"consumer group does not exist on stream '{stream}': {e}"
)));
}
tracing::warn!(error = %e, stream, "XREADGROUP failed");
return Err(e);
}
}
};
let entries = parse_xreadgroup_reply(raw_reply);
for (entry_id, fields_vec) in entries {
let fields: HashMap<String, String> = fields_vec.into_iter().collect();
let payload_raw = match fields.get(PAYLOAD_FIELD) {
Some(s) => s.clone(),
None => {
tracing::warn!(entry_id, "missing payload field — acking and skipping");
if let Err(e) = xack(&mut conn, stream, &group, &entry_id).await {
tracing::warn!(entry_id, error = %e, "XACK failed after skipping corrupt entry");
metrics::record_backend_error(metrics::BackendLabel::Redis, metrics::BackendErrorKind::Ack);
}
continue;
}
};
let retry_count = fields
.get(X_RETRY_COUNT)
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(0);
if let Some(max) = options.max_message_size
&& payload_raw.len() > max
{
tracing::warn!(
entry_id,
size = payload_raw.len(),
limit = max,
"message exceeds size limit — sending to DLQ"
);
metrics::record_failed(
topic_name,
consumer_group,
metrics::FailReason::Oversize,
);
route_to_dlq(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
"oversize",
retry_count,
)
.await?;
continue;
}
let msg: T::Message = match serde_json::from_str(&payload_raw) {
Ok(m) => m,
Err(e) => {
tracing::warn!(
error = %e,
entry_id,
"deserialization failed — sending to DLQ"
);
metrics::record_failed(
topic_name,
consumer_group,
metrics::FailReason::Deserialize,
);
route_to_dlq(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
"deserialize",
retry_count,
)
.await?;
continue;
}
};
let delivery_id = fields
.get(X_MESSAGE_ID)
.cloned()
.unwrap_or_else(|| entry_id.clone());
let meta = MessageMetadata {
retry_count,
delivery_id,
redelivered: retry_count > 0,
headers: build_headers(&fields),
};
options
.processing
.store(true, std::sync::atomic::Ordering::Release);
let handler_clone = Arc::clone(&handler);
let ctx_clone = Arc::clone(&ctx);
let _inflight =
metrics::InflightGuard::new(topic_arc.clone(), group_arc.clone());
let start = std::time::Instant::now();
let outcome_opt = match options.handler_timeout {
Some(timeout_dur) => {
match tokio::time::timeout(
timeout_dur,
handler_clone.handle(msg, meta, &ctx_clone),
)
.await
{
Ok(o) => Some(o),
Err(_) => {
tracing::warn!(
entry_id,
timeout = ?timeout_dur,
"handler timed out — leaving in PEL for XAUTOCLAIM"
);
metrics::record_failed(
&topic_arc,
group_arc.as_deref(),
metrics::FailReason::Timeout,
);
None
}
}
}
None => Some(handler_clone.handle(msg, meta, &ctx_clone).await),
};
let elapsed = start.elapsed().as_secs_f64();
let Some(outcome) = outcome_opt else {
options
.processing
.store(false, std::sync::atomic::Ordering::Release);
continue;
};
metrics::record_consumed(&topic_arc, group_arc.as_deref(), &outcome);
metrics::record_processing_duration(
&topic_arc,
group_arc.as_deref(),
&outcome,
elapsed,
);
options
.processing
.store(false, std::sync::atomic::Ordering::Release);
route_outcome(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
outcome,
retry_count,
options.max_retries,
hold_queues,
)
.await?;
}
if last_autoclaim.elapsed() >= autoclaim_interval {
let _ = autoclaim_all(&mut conn, stream, &group, &consumer, idle_ms).await;
last_autoclaim = std::time::Instant::now();
}
}
}
})
.await
}
#[allow(clippy::too_many_arguments)]
async fn route_outcome(
conn: &mut RedisConnection,
topology: &'static QueueTopology,
stream: &str,
group: &str,
entry_id: &str,
fields: &HashMap<String, String>,
outcome: Outcome,
retry_count: u32,
max_retries: u32,
hold_queues: &[HoldQueue],
) -> Result<()> {
match outcome {
Outcome::Ack => {
if let Err(e) = xack(conn, stream, group, entry_id).await {
tracing::warn!(stream, entry_id, error = %e, "XACK failed on Ack");
metrics::record_backend_error(
metrics::BackendLabel::Redis,
metrics::BackendErrorKind::Ack,
);
}
}
Outcome::Retry => {
let new_retry = retry_count + 1;
if new_retry >= max_retries {
route_to_dlq(
conn,
topology,
stream,
group,
entry_id,
fields,
"max-retries",
new_retry,
)
.await?;
} else if hold_queues.is_empty() {
tracing::warn!(
stream,
entry_id,
"Retry but no hold queues — re-queueing immediately"
);
requeue_to_stream(conn, stream, fields, new_retry).await;
if let Err(e) = xack(conn, stream, group, entry_id).await {
tracing::warn!(stream, entry_id, error = %e, "XACK failed after immediate requeue");
metrics::record_backend_error(
metrics::BackendLabel::Redis,
metrics::BackendErrorKind::Ack,
);
}
} else if let Some(level) = hold_level(new_retry, hold_queues) {
let hq = &hold_queues[level];
route_to_hold(
conn,
stream,
group,
entry_id,
fields,
hq.name(),
hq.delay(),
new_retry,
)
.await;
}
}
Outcome::Reject => {
route_to_dlq(
conn,
topology,
stream,
group,
entry_id,
fields,
"rejected",
retry_count,
)
.await?;
}
Outcome::Defer => {
if hold_queues.is_empty() {
tracing::warn!(
stream,
entry_id,
"Defer but no hold queues — re-queueing immediately"
);
requeue_to_stream(conn, stream, fields, retry_count).await;
if let Err(e) = xack(conn, stream, group, entry_id).await {
tracing::warn!(stream, entry_id, error = %e, "XACK failed after defer requeue");
metrics::record_backend_error(
metrics::BackendLabel::Redis,
metrics::BackendErrorKind::Ack,
);
}
} else {
let hq = &hold_queues[0];
route_to_hold(
conn,
stream,
group,
entry_id,
fields,
hq.name(),
hq.delay(),
retry_count,
)
.await;
}
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn route_to_hold(
conn: &mut RedisConnection,
stream: &str,
group: &str,
entry_id: &str,
fields: &HashMap<String, String>,
hold_name: &str,
delay: Duration,
new_retry_count: u32,
) {
let mut hold_fields: Vec<(String, String)> = fields
.iter()
.filter(|(k, _)| k.as_str() != X_RETRY_COUNT)
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
hold_fields.push((X_RETRY_COUNT.into(), new_retry_count.to_string()));
let entry = HoldEntry {
stream: stream.to_owned(),
fields: hold_fields,
};
if let Err(e) = enqueue_hold(conn, hold_name, entry, delay).await {
tracing::warn!(error = %e, hold_name, "enqueue_hold failed — message may be lost");
return;
}
if let Err(e) = xack(conn, stream, group, entry_id).await {
tracing::warn!(stream, entry_id, error = %e, "XACK failed after enqueue_hold");
metrics::record_backend_error(metrics::BackendLabel::Redis, metrics::BackendErrorKind::Ack);
}
}
#[allow(clippy::too_many_arguments)]
async fn route_to_dlq(
conn: &mut RedisConnection,
topology: &'static QueueTopology,
stream: &str,
group: &str,
entry_id: &str,
fields: &HashMap<String, String>,
reason: &str,
death_count: u32,
) -> Result<()> {
let dlq = match topology.dlq() {
Some(d) => d,
None => {
tracing::warn!(stream, entry_id, reason, "no DLQ configured — discarding");
if let Err(e) = xack(conn, stream, group, entry_id).await {
tracing::warn!(stream, entry_id, error = %e, "XACK failed while discarding (no DLQ)");
metrics::record_backend_error(
metrics::BackendLabel::Redis,
metrics::BackendErrorKind::Ack,
);
}
return Ok(());
}
};
let mut cmd = redis::cmd("XADD");
cmd.arg(dlq).arg("*");
for (k, v) in fields {
cmd.arg(k.as_str()).arg(v.as_str());
}
cmd.arg(X_DEATH_REASON).arg(reason);
cmd.arg(X_DEATH_COUNT).arg(death_count.to_string());
cmd.arg(X_ORIGINAL_QUEUE).arg(stream);
conn.query::<redis::Value>(&mut cmd).await.map_err(|e| {
tracing::warn!(error = %e, dlq, "XADD to DLQ failed — message stays in PEL");
ShoveError::Connection(format!("XADD to DLQ failed: {e}"))
})?;
if let Err(e) = xack(conn, stream, group, entry_id).await {
tracing::warn!(stream, entry_id, error = %e, "XACK failed after DLQ enqueue");
metrics::record_backend_error(metrics::BackendLabel::Redis, metrics::BackendErrorKind::Ack);
}
Ok(())
}
async fn requeue_to_stream(
conn: &mut RedisConnection,
stream: &str,
fields: &HashMap<String, String>,
retry_count: u32,
) {
let mut cmd = redis::cmd("XADD");
cmd.arg(stream).arg("*");
for (k, v) in fields {
if k.as_str() != X_RETRY_COUNT {
cmd.arg(k.as_str()).arg(v.as_str());
}
}
cmd.arg(X_RETRY_COUNT).arg(retry_count.to_string());
if let Err(e) = conn.query::<redis::Value>(&mut cmd).await {
tracing::warn!(error = %e, stream, "XADD on immediate requeue failed — message may be lost");
}
}
/// Acknowledges `entry_id` for consumer group `group` on `stream` (`XACK`).
///
/// The integer reply (how many entries were acknowledged) is discarded. Any query failure is
/// mapped to [`ShoveError::Connection`].
async fn xack(conn: &mut RedisConnection, stream: &str, group: &str, entry_id: &str) -> Result<()> {
    let mut cmd = redis::cmd("XACK");
    cmd.arg(stream).arg(group).arg(entry_id);
    match conn.query::<i64>(&mut cmd).await {
        Ok(_acked) => Ok(()),
        Err(e) => Err(ShoveError::Connection(format!("XACK failed: {e}"))),
    }
}
/// Claims for `consumer` every entry in `group`'s PEL on `stream` that has been idle for at
/// least `min_idle_ms`.
///
/// It pages through `XAUTOCLAIM`, `AUTOCLAIM_COUNT` entries per call, until the returned
/// cursor is `0-0` (or empty). The claimed entries in each reply are discarded here: they end
/// up in `consumer`'s PEL and must be fetched separately (e.g. with an `XREADGROUP` history
/// read) to be processed.
///
/// # Errors
/// [`ShoveError::Connection`] on the first failed call. Pages claimed before the failure stay
/// claimed.
async fn autoclaim_all(
    conn: &mut RedisConnection,
    stream: &str,
    group: &str,
    consumer: &str,
    min_idle_ms: u64,
) -> Result<()> {
    let mut cursor = String::from("0-0");
    loop {
        let mut cmd = redis::cmd("XAUTOCLAIM");
        cmd.arg(stream)
            .arg(group)
            .arg(consumer)
            .arg(min_idle_ms)
            .arg(&cursor)
            .arg("COUNT")
            .arg(AUTOCLAIM_COUNT);
        let reply: StreamAutoClaimReply = conn
            .query(&mut cmd)
            .await
            .map_err(|e| ShoveError::Connection(format!("XAUTOCLAIM failed: {e}")))?;
        let next = reply.next_stream_id;
        if next.is_empty() || next == "0-0" {
            return Ok(());
        }
        cursor = next;
    }
}
/// Flattens a raw `XREADGROUP` reply into `(entry_id, [(field, value), ...])` pairs, in reply
/// order across all streams.
///
/// Both protocol shapes are accepted: the RESP2 array of `[stream_name, entries]` pairs and
/// the RESP3 map of `stream_name => entries`. A `Nil` reply or any other root type yields an
/// empty list.
///
/// Malformed pieces are skipped rather than reported:
/// - stream items that are not `[name, entries]`, or whose entries are not an array;
/// - entries that are not `[id, fields]`, whose id is not a UTF-8 string, or whose field list
///   is not an array;
/// - inside a field list, parsing stops at the first key that is not a UTF-8 string, or at the
///   first value that is not a string or `Nil`. Fields before that point are kept.
///
/// Values are decoded lossily, and `Nil` values become `""`.
pub(super) fn parse_xreadgroup_reply(value: redis::Value) -> Vec<(String, Vec<(String, String)>)> {
    // Normalize both protocol shapes into the per-stream entry lists.
    let per_stream_entries: Vec<redis::Value> = match value {
        redis::Value::Array(items) => items
            .into_iter()
            .filter_map(|item| match item {
                redis::Value::Array(pair) if pair.len() >= 2 => pair.into_iter().nth(1),
                _ => None,
            })
            .collect(),
        redis::Value::Map(pairs) => pairs.into_iter().map(|(_name, entries)| entries).collect(),
        _ => return Vec::new(),
    };
    let mut result = Vec::new();
    for entries in per_stream_entries {
        let redis::Value::Array(entry_list) = entries else {
            continue;
        };
        result.extend(entry_list.into_iter().filter_map(parse_stream_entry));
    }
    result
}

/// Parses one `[id, [field, value, ...]]` stream entry. Returns `None` when the entry shape,
/// the id, or the field list is unusable.
fn parse_stream_entry(entry: redis::Value) -> Option<(String, Vec<(String, String)>)> {
    let redis::Value::Array(pair) = entry else {
        return None;
    };
    if pair.len() < 2 {
        return None;
    }
    let mut parts = pair.into_iter();
    let id = match parts.next()? {
        redis::Value::BulkString(bytes) => String::from_utf8(bytes).ok()?,
        redis::Value::SimpleString(s) => s,
        _ => return None,
    };
    let redis::Value::Array(field_list) = parts.next()? else {
        return None;
    };
    Some((id, parse_field_pairs(field_list)))
}

/// Pairs up a flat `[k1, v1, k2, v2, ...]` field list, stopping at the first malformed key or
/// value.
fn parse_field_pairs(field_list: Vec<redis::Value>) -> Vec<(String, String)> {
    let mut fields = Vec::with_capacity(field_list.len() / 2);
    let mut iter = field_list.into_iter();
    loop {
        let key = match iter.next() {
            Some(redis::Value::BulkString(bytes)) => match String::from_utf8(bytes) {
                Ok(s) => s,
                Err(_) => break,
            },
            Some(redis::Value::SimpleString(s)) => s,
            _ => break,
        };
        let val = match iter.next() {
            Some(redis::Value::BulkString(bytes)) => String::from_utf8_lossy(&bytes).into_owned(),
            Some(redis::Value::SimpleString(s)) => s,
            Some(redis::Value::Nil) => String::new(),
            _ => break,
        };
        fields.push((key, val));
    }
    fields
}
/// Returns the user-visible headers of a stream entry.
///
/// That is every field except the payload and the transport's own bookkeeping fields: retry
/// count, sequence key, message id, and the DLQ death annotations.
fn build_headers(fields: &HashMap<String, String>) -> HashMap<String, String> {
    const INTERNAL_FIELDS: [&str; 7] = [
        PAYLOAD_FIELD,
        X_RETRY_COUNT,
        X_SEQUENCE_KEY,
        X_MESSAGE_ID,
        X_DEATH_REASON,
        X_DEATH_COUNT,
        X_ORIGINAL_QUEUE,
    ];
    let mut headers = HashMap::new();
    for (name, value) in fields {
        if INTERNAL_FIELDS.contains(&name.as_str()) {
            continue;
        }
        headers.insert(name.clone(), value.clone());
    }
    headers
}
/// Picks the hold-queue index for a message that has already been retried `retry_count` times.
///
/// Each retry moves one level deeper, saturating at the last queue. Returns `None` when there
/// are no hold queues.
pub(super) fn hold_level<T>(retry_count: u32, hold_queues: &[T]) -> Option<usize> {
    let last = hold_queues.len().checked_sub(1)?;
    // u32 -> usize is a widening conversion on 32/64-bit targets.
    Some(last.min(retry_count as usize))
}
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers in this module (hold-level selection, XREADGROUP reply
    // parsing, header filtering, consumer naming) and for `run_with_reconnect`'s retry policy.
    // No test here opens a Redis connection.
    use super::*;
    #[test]
    fn retry_count_routing_to_hold_level() {
        let hold_queues = vec!["orders-hold-5s", "orders-hold-30s"];
        assert_eq!(hold_level(0, &hold_queues), Some(0));
        assert_eq!(hold_level(1, &hold_queues), Some(1));
        // Counts past the last level saturate at the last queue.
        assert_eq!(hold_level(2, &hold_queues), Some(1));
    }
    #[test]
    fn hold_level_empty_returns_none() {
        // A single queue (even with an empty name) is still a valid level 0.
        assert_eq!(hold_level(0, &[""]), Some(0));
        let empty: Vec<&str> = vec![];
        assert_eq!(hold_level(0, &empty), None);
    }
    #[test]
    fn parse_xreadgroup_nil_returns_empty() {
        // A Nil reply (e.g. BLOCK timeout) yields no entries.
        let result = parse_xreadgroup_reply(redis::Value::Nil);
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_empty_array_returns_empty() {
        let result = parse_xreadgroup_reply(redis::Value::Array(vec![]));
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_valid_entry() {
        // Well-formed RESP2 shape: [[stream, [[id, [k, v, k, v]]]]].
        let entry = redis::Value::Array(vec![
            redis::Value::BulkString(b"1234-0".to_vec()),
            redis::Value::Array(vec![
                redis::Value::BulkString(b"payload".to_vec()),
                redis::Value::BulkString(b"{}".to_vec()),
                redis::Value::BulkString(b"x-retry-count".to_vec()),
                redis::Value::BulkString(b"0".to_vec()),
            ]),
        ]);
        let stream = redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![entry]),
        ]);
        let reply = redis::Value::Array(vec![stream]);
        let result = parse_xreadgroup_reply(reply);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "1234-0");
        assert_eq!(result[0].1.len(), 2);
        // Field order from the reply is preserved.
        assert_eq!(result[0].1[0], ("payload".to_string(), "{}".to_string()));
        assert_eq!(
            result[0].1[1],
            ("x-retry-count".to_string(), "0".to_string())
        );
    }
    #[test]
    fn parse_xreadgroup_simple_string_id() {
        // Entry ids may also arrive as simple strings.
        let entry = redis::Value::Array(vec![
            redis::Value::SimpleString("9999-1".to_string()),
            redis::Value::Array(vec![
                redis::Value::BulkString(b"payload".to_vec()),
                redis::Value::BulkString(b"hello".to_vec()),
            ]),
        ]);
        let stream = redis::Value::Array(vec![
            redis::Value::BulkString(b"s".to_vec()),
            redis::Value::Array(vec![entry]),
        ]);
        let result = parse_xreadgroup_reply(redis::Value::Array(vec![stream]));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "9999-1");
    }
    #[test]
    fn parse_xreadgroup_nil_field_value_becomes_empty_string() {
        let entry = redis::Value::Array(vec![
            redis::Value::BulkString(b"1-0".to_vec()),
            redis::Value::Array(vec![
                redis::Value::BulkString(b"payload".to_vec()),
                redis::Value::Nil,
            ]),
        ]);
        let stream = redis::Value::Array(vec![
            redis::Value::BulkString(b"s".to_vec()),
            redis::Value::Array(vec![entry]),
        ]);
        let result = parse_xreadgroup_reply(redis::Value::Array(vec![stream]));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1[0], ("payload".to_string(), String::new()));
    }
    #[test]
    fn parse_xreadgroup_odd_field_count_stops_at_last_key() {
        // A trailing key without a value is dropped; complete pairs are kept.
        let entry = redis::Value::Array(vec![
            redis::Value::BulkString(b"2-0".to_vec()),
            redis::Value::Array(vec![
                redis::Value::BulkString(b"payload".to_vec()),
                redis::Value::BulkString(b"{}".to_vec()),
                redis::Value::BulkString(b"dangling-key".to_vec()),
            ]),
        ]);
        let stream = redis::Value::Array(vec![
            redis::Value::BulkString(b"s".to_vec()),
            redis::Value::Array(vec![entry]),
        ]);
        let result = parse_xreadgroup_reply(redis::Value::Array(vec![stream]));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1.len(), 1);
        assert_eq!(result[0].1[0].0, "payload");
    }
    #[test]
    fn parse_xreadgroup_wrong_root_type_returns_empty() {
        let result = parse_xreadgroup_reply(redis::Value::Int(0));
        assert!(result.is_empty());
    }
    #[test]
    fn build_headers_excludes_internal_fields() {
        let mut fields = std::collections::HashMap::new();
        fields.insert(PAYLOAD_FIELD.to_string(), "data".to_string());
        fields.insert(X_RETRY_COUNT.to_string(), "2".to_string());
        fields.insert(X_SEQUENCE_KEY.to_string(), "acct-1".to_string());
        fields.insert("x-custom".to_string(), "val".to_string());
        let headers = build_headers(&fields);
        // Only the user-defined field survives.
        assert_eq!(headers.len(), 1);
        assert_eq!(headers.get("x-custom").map(String::as_str), Some("val"));
    }
    #[test]
    fn build_headers_excludes_all_internal_fields() {
        let mut fields = std::collections::HashMap::new();
        fields.insert(PAYLOAD_FIELD.to_string(), "data".to_string());
        fields.insert(X_RETRY_COUNT.to_string(), "2".to_string());
        fields.insert(X_SEQUENCE_KEY.to_string(), "acct-1".to_string());
        fields.insert(X_MESSAGE_ID.to_string(), "msg-abc".to_string());
        fields.insert(X_DEATH_REASON.to_string(), "max-retries".to_string());
        fields.insert(X_DEATH_COUNT.to_string(), "5".to_string());
        fields.insert(X_ORIGINAL_QUEUE.to_string(), "orders".to_string());
        fields.insert("x-custom".to_string(), "val".to_string());
        let headers = build_headers(&fields);
        assert_eq!(headers.len(), 1);
        assert_eq!(headers.get("x-custom").map(String::as_str), Some("val"));
        assert!(!headers.contains_key(X_MESSAGE_ID));
        assert!(!headers.contains_key(X_DEATH_REASON));
        assert!(!headers.contains_key(X_DEATH_COUNT));
        assert!(!headers.contains_key(X_ORIGINAL_QUEUE));
    }
    #[test]
    fn build_headers_empty_input_returns_empty() {
        let fields = std::collections::HashMap::new();
        let headers = build_headers(&fields);
        assert!(headers.is_empty());
    }
    #[test]
    fn consumer_name_is_unique() {
        // The UUID suffix makes every generated name distinct.
        let a = RedisConsumer::consumer_name();
        let b = RedisConsumer::consumer_name();
        assert_ne!(a, b, "consumer names must be unique per call");
    }
    #[test]
    fn parse_xreadgroup_non_array_stream_item_skipped() {
        let reply = redis::Value::Array(vec![
            // A stream item must be a `[name, entries]` array; an integer is skipped.
            redis::Value::Int(42),
        ]);
        let result = parse_xreadgroup_reply(reply);
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_stream_pair_too_short_skipped() {
        // A stream item with only a name (no entry list) is skipped.
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![redis::Value::BulkString(
            b"only-one".to_vec(),
        )])]);
        let result = parse_xreadgroup_reply(reply);
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_non_array_entry_list_skipped() {
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            // The entry list must be an array.
            redis::Value::Int(99),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_entry_pair_too_short_skipped() {
        // An entry with an id but no field list is skipped.
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![
                redis::Value::Array(vec![redis::Value::BulkString(b"1-0".to_vec())]),
            ]),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_int_entry_id_skipped() {
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![redis::Value::Array(vec![
                // The entry id must be a string.
                redis::Value::Int(12345),
                redis::Value::Array(vec![]),
            ])]),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_non_array_field_list_skipped() {
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![redis::Value::Array(vec![
                redis::Value::BulkString(b"1-0".to_vec()),
                // The field list must be an array.
                redis::Value::Int(0),
            ])]),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert!(result.is_empty());
    }
    #[test]
    fn parse_xreadgroup_simple_string_field_key() {
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![redis::Value::Array(vec![
                redis::Value::BulkString(b"1-0".to_vec()),
                redis::Value::Array(vec![
                    redis::Value::SimpleString("myfieldkey".to_string()),
                    redis::Value::BulkString(b"myvalue".to_vec()),
                ]),
            ])]),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1.len(), 1);
        assert_eq!(result[0].1[0].0, "myfieldkey");
        assert_eq!(result[0].1[0].1, "myvalue");
    }
    #[test]
    fn parse_xreadgroup_int_field_key_breaks_loop() {
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![redis::Value::Array(vec![
                redis::Value::BulkString(b"1-0".to_vec()),
                redis::Value::Array(vec![
                    redis::Value::BulkString(b"good-key".to_vec()),
                    redis::Value::BulkString(b"good-val".to_vec()),
                    // A non-string key stops field parsing; everything after it is dropped.
                    redis::Value::Int(42),
                    redis::Value::BulkString(b"after-break".to_vec()),
                ]),
            ])]),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1.len(), 1);
        assert_eq!(result[0].1[0].0, "good-key");
    }
    #[test]
    fn parse_xreadgroup_simple_string_field_value() {
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![redis::Value::Array(vec![
                redis::Value::BulkString(b"1-0".to_vec()),
                redis::Value::Array(vec![
                    redis::Value::BulkString(b"key".to_vec()),
                    redis::Value::SimpleString("simplevalue".to_string()),
                ]),
            ])]),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1[0].1, "simplevalue");
    }
    #[test]
    fn parse_xreadgroup_int_field_value_breaks_loop() {
        let reply = redis::Value::Array(vec![redis::Value::Array(vec![
            redis::Value::BulkString(b"mystream".to_vec()),
            redis::Value::Array(vec![redis::Value::Array(vec![
                redis::Value::BulkString(b"1-0".to_vec()),
                redis::Value::Array(vec![
                    redis::Value::BulkString(b"k1".to_vec()),
                    redis::Value::BulkString(b"v1".to_vec()),
                    redis::Value::BulkString(b"k2".to_vec()),
                    // A non-string value stops parsing and drops the `k2` pair.
                    redis::Value::Int(99),
                ]),
            ])]),
        ])]);
        let result = parse_xreadgroup_reply(reply);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1.len(), 1);
        assert_eq!(result[0].1[0].0, "k1");
    }
    #[test]
    fn parse_xreadgroup_multiple_streams_merged_flat() {
        // Builds a one-entry stream item `[name, [[id, [payload, val]]]]`.
        fn make_stream(name: &str, id: &str, val: &str) -> redis::Value {
            redis::Value::Array(vec![
                redis::Value::BulkString(name.as_bytes().to_vec()),
                redis::Value::Array(vec![redis::Value::Array(vec![
                    redis::Value::BulkString(id.as_bytes().to_vec()),
                    redis::Value::Array(vec![
                        redis::Value::BulkString(b"payload".to_vec()),
                        redis::Value::BulkString(val.as_bytes().to_vec()),
                    ]),
                ])]),
            ])
        }
        let reply = redis::Value::Array(vec![
            make_stream("stream-a", "1-0", "msg-a"),
            make_stream("stream-b", "2-0", "msg-b"),
            make_stream("stream-c", "3-0", "msg-c"),
        ]);
        let result = parse_xreadgroup_reply(reply);
        // Entries from all streams are flattened in reply order.
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].0, "1-0");
        assert_eq!(result[1].0, "2-0");
        assert_eq!(result[2].0, "3-0");
    }
    #[test]
    fn hold_level_single_element_always_returns_zero() {
        let single = vec!["only-queue"];
        assert_eq!(hold_level(0, &single), Some(0));
        assert_eq!(hold_level(1, &single), Some(0));
        assert_eq!(hold_level(100, &single), Some(0));
        assert_eq!(hold_level(u32::MAX, &single), Some(0));
    }
    // The NOGROUP tests below pin the error-string and variant conventions the consume loop
    // relies on (`contains("NOGROUP")`, `ShoveError::Connection` being retryable). They do not
    // drive the loop itself.
    #[test]
    fn nogroup_error_string_is_detected() {
        let err_str = "NOGROUP No such consumer group 'grp' for key name 'stream'";
        assert!(err_str.contains("NOGROUP"));
        let err = ShoveError::Connection(err_str.to_string());
        assert!(err.to_string().contains("NOGROUP"));
    }
    #[test]
    fn nogroup_error_is_retryable() {
        let err = ShoveError::Connection(
            "consumer group does not exist on stream 'foo': NOGROUP ...".into(),
        );
        assert!(
            err.is_retryable(),
            "NOGROUP error must be retryable so consumers survive Redis restart"
        );
    }
    #[test]
    fn nogroup_error_is_connection_not_topology() {
        let err = ShoveError::Connection(
            "consumer group does not exist on stream 'foo': NOGROUP ...".into(),
        );
        assert!(
            matches!(err, ShoveError::Connection(_)),
            "NOGROUP must be ShoveError::Connection, not Topology"
        );
        assert!(
            !matches!(err, ShoveError::Topology(_)),
            "NOGROUP must not be ShoveError::Topology"
        );
    }
    #[test]
    fn exhausted_reconnect_error_message_format() {
        // Mirrors the format string used by `run_with_reconnect`; the real message is
        // checked end-to-end in `run_with_reconnect_stops_when_limit_reached`.
        let stream = "orders";
        let max: u32 = 3;
        let cause = "connection refused";
        let msg = format!("consumer on '{stream}' exhausted {max} reconnect attempt(s): {cause}");
        assert!(msg.contains(stream), "stream name must appear in error");
        assert!(
            msg.contains(&max.to_string()),
            "attempt count must appear in error"
        );
        assert!(msg.contains(cause), "root cause must appear in error");
    }
    // The `run_with_reconnect` tests below really sleep for the `Backoff` delay between failed
    // attempts.
    #[tokio::test]
    async fn run_with_reconnect_stops_when_limit_reached() {
        use tokio_util::sync::CancellationToken;
        let shutdown = CancellationToken::new();
        let mut calls = 0u32;
        let result = run_with_reconnect(&shutdown, "test-stream", Some(2), || {
            calls += 1;
            async { Err(ShoveError::Connection("transient".into())) }
        })
        .await;
        assert!(
            result.is_err(),
            "must propagate error after exhausting attempts"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("test-stream"),
            "error must name the stream; got: {msg}"
        );
        // `Some(2)` means two calls in total, not two reconnects after the first call.
        assert_eq!(calls, 2, "must attempt exactly max times before giving up");
    }
    #[tokio::test]
    async fn run_with_reconnect_unlimited_can_succeed_after_retries() {
        use tokio_util::sync::CancellationToken;
        let shutdown = CancellationToken::new();
        let mut calls = 0u32;
        let result = run_with_reconnect(&shutdown, "test-stream", None, || {
            calls += 1;
            async move {
                if calls < 3 {
                    Err(ShoveError::Connection("transient".into()))
                } else {
                    Ok(())
                }
            }
        })
        .await;
        assert!(result.is_ok(), "must succeed once the closure returns Ok");
        assert_eq!(calls, 3);
    }
    #[tokio::test]
    async fn run_with_reconnect_non_retryable_error_propagates_immediately() {
        use tokio_util::sync::CancellationToken;
        let shutdown = CancellationToken::new();
        let mut calls = 0u32;
        let result = run_with_reconnect(&shutdown, "test-stream", None, || {
            calls += 1;
            async { Err(ShoveError::Topology("bad topology".into())) }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(calls, 1, "non-retryable error must not trigger reconnect");
    }
}