use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use crate::ConsumerOptions;
use crate::backend::ConsumerOptionsInner;
use crate::backend::consumer::ConsumerImpl;
use crate::consumer_supervisor::{SupervisorOutcome, drive_fifo_until_timeout};
use crate::error::{Result, ShoveError};
use crate::handler::MessageHandler;
use crate::markers::Redis;
use crate::metadata::MessageMetadata;
use crate::metrics;
use crate::outcome::Outcome;
use crate::retry::Backoff;
use crate::topic::{SequencedTopic, Topic};
use crate::topology::{HoldQueue, QueueTopology};
use super::client::{RedisClient, RedisConnection};
use super::constants::{
BLOCK_MS, PAYLOAD_FIELD, X_DEATH_COUNT, X_DEATH_REASON, X_MESSAGE_ID, X_ORIGINAL_QUEUE,
X_RETRY_COUNT, X_SEQUENCE_KEY,
};
use super::requeue::{HoldEntry, enqueue_hold, spawn_requeuer};
use super::topology::RedisTopologyDeclarer;
/// Consumer for the Redis backend, reading stream entries via `XREADGROUP`.
///
/// The struct only wraps a [`RedisClient`]; the client is cloned into every
/// consume loop and shard task started by the `run*` methods.
#[derive(Clone)]
pub struct RedisConsumer {
// Client handle cloned into each consume loop / shard task.
client: RedisClient,
}
impl RedisConsumer {
    /// Wraps `client` in a consumer handle.
    pub fn new(client: RedisClient) -> Self {
        Self { client }
    }

    /// Builds a fresh consumer name of the form `{HOSTNAME}-{uuid-v4}`.
    ///
    /// The host part falls back to `unknown` when the `HOSTNAME` environment
    /// variable cannot be read.
    fn consumer_name() -> String {
        let host = match std::env::var("HOSTNAME") {
            Ok(name) => name,
            Err(_) => "unknown".to_string(),
        };
        format!("{host}-{}", uuid::Uuid::new_v4())
    }

    /// Runs the concurrent consume loop for topic `T` on its main queue.
    ///
    /// When the topology declares hold queues, a requeuer task is spawned for
    /// them first and aborted again once the consume loop has finished. The
    /// result of the consume loop is returned unchanged.
    pub(super) async fn run_concurrent<T, H>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptionsInner,
    ) -> Result<()>
    where
        T: Topic,
        H: MessageHandler<T> + 'static,
        H::Context: 'static,
    {
        let topology = T::topology();
        let hold_queues = topology.hold_queues();

        let mut hold_names: Vec<String> = Vec::new();
        for hq in hold_queues.iter() {
            hold_names.push(hq.name().to_owned());
        }

        // Only start the requeuer when there is at least one hold queue to drain.
        let requeuer = if hold_names.is_empty() {
            None
        } else {
            Some(spawn_requeuer(
                self.client.clone(),
                hold_names,
                options.shutdown.clone(),
            ))
        };

        let outcome = run_stream_loop_concurrent::<T, H>(
            self.client.clone(),
            Arc::new(handler),
            Arc::new(ctx),
            options,
            topology,
            topology.queue(),
            hold_queues,
        )
        .await;

        if let Some(task) = requeuer {
            task.abort();
        }
        outcome
    }
}
impl RedisConsumer {
    /// Consumes topic `T` with `handler`, delegating to [`ConsumerImpl::run`].
    pub async fn run<T, H>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptions<Redis>,
    ) -> Result<()>
    where
        T: Topic,
        H: MessageHandler<T>,
    {
        let inner = options.into_inner();
        <Self as ConsumerImpl>::run::<T, H>(self, handler, ctx, inner).await
    }

    /// Consumes the sequenced topic `T`, delegating to [`ConsumerImpl::run_fifo`].
    pub async fn run_fifo<T, H>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptions<Redis>,
    ) -> Result<()>
    where
        T: SequencedTopic,
        H: MessageHandler<T>,
    {
        let inner = options.into_inner();
        <Self as ConsumerImpl>::run_fifo::<T, H>(self, handler, ctx, inner).await
    }

    /// Spawns the FIFO shard tasks for `T` and hands them to
    /// `drive_fifo_until_timeout` together with the shutdown token, `signal`
    /// and `drain_timeout`.
    ///
    /// If the shards cannot be spawned, the error is logged and an outcome
    /// with `errors: 1`, `panics: 0`, `timed_out: false` is returned.
    pub async fn run_fifo_until_timeout<T, H, S>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptions<Redis>,
        signal: S,
        drain_timeout: Duration,
    ) -> SupervisorOutcome
    where
        T: SequencedTopic,
        H: MessageHandler<T>,
        S: Future<Output = ()> + Send + 'static,
    {
        let inner = options.into_inner();
        // Keep a handle on the token: `inner` is moved into the shard spawner.
        let shutdown = inner.shutdown.clone();
        let spawned =
            <Self as ConsumerImpl>::spawn_fifo_shards::<T, H>(self, handler, ctx, inner).await;
        match spawned {
            Ok(handles) => {
                drive_fifo_until_timeout(handles, shutdown, signal, drain_timeout).await
            }
            Err(e) => {
                tracing::error!(error = %e, "run_fifo_until_timeout: shard spawn failed");
                SupervisorOutcome {
                    errors: 1,
                    panics: 0,
                    timed_out: false,
                }
            }
        }
    }

    /// Consumes the dead-letter queue of topic `T` using freshly constructed
    /// (`ConsumerOptions::new()`) options, delegating to [`ConsumerImpl::run_dlq`].
    pub async fn run_dlq<T, H>(&self, handler: H, ctx: H::Context) -> Result<()>
    where
        T: Topic,
        H: MessageHandler<T>,
    {
        <Self as ConsumerImpl>::run_dlq::<T, H>(
            self,
            handler,
            ctx,
            crate::ConsumerOptions::<crate::Redis>::new().into_inner(),
        )
        .await
    }
}
impl ConsumerImpl for RedisConsumer {
    /// Consumes the main queue of topic `T` sequentially.
    fn run<T, H>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptionsInner,
    ) -> impl Future<Output = Result<()>> + Send
    where
        T: Topic,
        H: MessageHandler<T>,
    {
        let client = self.client.clone();
        async move {
            let topology = T::topology();
            run_stream_loop::<T, H>(client, handler, ctx, options, topology, topology.queue())
                .await
        }
    }

    /// Spawns one task per routing shard and waits for all of them.
    ///
    /// Shard failures and panics are logged; once spawning succeeded the
    /// returned future always resolves to `Ok(())`.
    fn run_fifo<T, H>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptionsInner,
    ) -> impl Future<Output = Result<()>> + Send
    where
        T: SequencedTopic,
        H: MessageHandler<T>,
    {
        let consumer = self.clone();
        async move {
            let shards = consumer
                .spawn_fifo_shards::<T, H>(handler, ctx, options)
                .await?;
            for shard in shards {
                match shard.await {
                    Err(e) => tracing::error!("sequenced shard task panicked: {e}"),
                    Ok(Err(e)) => tracing::error!("sequenced shard task failed: {e}"),
                    Ok(Ok(())) => {}
                }
            }
            Ok(())
        }
    }

    /// Consumes the dead-letter queue of topic `T`.
    ///
    /// Fails with `ShoveError::Topology` when the topology has no DLQ.
    fn run_dlq<T, H>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptionsInner,
    ) -> impl Future<Output = Result<()>> + Send
    where
        T: Topic,
        H: MessageHandler<T>,
    {
        let client = self.client.clone();
        async move {
            let topology = T::topology();
            let Some(dlq_name) = topology.dlq() else {
                return Err(ShoveError::Topology(format!(
                    "run_dlq called on topic {} without DLQ",
                    topology.queue()
                )));
            };
            run_stream_loop::<T, H>(client, handler, ctx, options, topology, dlq_name).await
        }
    }

    /// Spawns one consume task per routing shard of the sequenced topic `T`
    /// and returns their join handles.
    ///
    /// Each task starts a requeuer for its shard's hold queues (if any), runs
    /// the sequential stream loop on the shard stream, and aborts the requeuer
    /// when the loop ends. Fails with `ShoveError::Topology` when the topology
    /// has no sequencing configuration.
    fn spawn_fifo_shards<T, H>(
        &self,
        handler: H,
        ctx: H::Context,
        options: ConsumerOptionsInner,
    ) -> impl Future<Output = Result<Vec<tokio::task::JoinHandle<Result<()>>>>> + Send
    where
        T: SequencedTopic,
        H: MessageHandler<T>,
    {
        let client = self.client.clone();
        async move {
            let topology = T::topology();
            let Some(seq) = topology.sequencing() else {
                return Err(ShoveError::Topology(format!(
                    "spawn_fifo_shards called on topic {} without sequencing config",
                    topology.queue()
                )));
            };

            // Handler and context are shared by every shard task.
            let handler = Arc::new(handler);
            let ctx = Arc::new(ctx);

            let handles: Vec<tokio::task::JoinHandle<Result<()>>> = (0..seq.routing_shards())
                .map(|shard_idx| {
                    let stream_name =
                        RedisTopologyDeclarer::shard_stream_name(topology.queue(), shard_idx);
                    let shard_hold_queues = topology.shard_hold_queue_names(shard_idx);
                    let shard_client = client.clone();
                    let shard_handler = Arc::clone(&handler);
                    let shard_ctx = Arc::clone(&ctx);
                    let shard_options = options.clone();

                    tokio::spawn(async move {
                        let mut hold_names: Vec<String> = Vec::new();
                        for hq in shard_hold_queues.iter() {
                            hold_names.push(hq.name().to_owned());
                        }

                        let requeuer = if hold_names.is_empty() {
                            None
                        } else {
                            Some(spawn_requeuer(
                                shard_client.clone(),
                                hold_names,
                                shard_options.shutdown.clone(),
                            ))
                        };

                        let outcome = run_stream_loop_arc::<T, H>(
                            shard_client,
                            shard_handler,
                            shard_ctx,
                            shard_options,
                            topology,
                            &stream_name,
                            &shard_hold_queues,
                        )
                        .await;

                        if let Some(task) = requeuer {
                            task.abort();
                        }
                        outcome
                    })
                })
                .collect();

            Ok(handles)
        }
    }
}
/// Repeatedly runs `f` until it succeeds, fails permanently, or shutdown is requested.
///
/// * `f` returning `Ok(())` ends the loop with `Ok(())`.
/// * A non-retryable error is returned as-is.
/// * A retryable error while `shutdown` is already cancelled ends the loop with `Ok(())`.
/// * Otherwise the attempt counter is incremented; once it reaches
///   `max_reconnect_attempts` (when set) a `ShoveError::Connection` is returned.
/// * Else the loop sleeps for the next backoff delay — waking early with
///   `Ok(())` if `shutdown` is cancelled — and calls `f` again.
///
/// `stream` is only used for log fields and error messages.
async fn run_with_reconnect<F, Fut>(
    shutdown: &CancellationToken,
    stream: &str,
    max_reconnect_attempts: Option<u32>,
    mut f: F,
) -> Result<()>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<()>>,
{
    let mut backoff = Backoff::default();
    let mut attempts = 0u32;

    loop {
        let Err(e) = f().await else {
            return Ok(());
        };

        if !e.is_retryable() {
            return Err(e);
        }
        if shutdown.is_cancelled() {
            return Ok(());
        }

        attempts += 1;
        match max_reconnect_attempts {
            Some(max) if attempts >= max => {
                tracing::error!(
                    stream,
                    attempts,
                    error = %e,
                    "max reconnect attempts reached, giving up"
                );
                return Err(ShoveError::Connection(format!(
                    "consumer on '{stream}' exhausted {max} reconnect attempt(s): {e}"
                )));
            }
            _ => {}
        }

        let delay = backoff.next().expect("backoff is infinite");
        tracing::warn!(
            stream,
            attempt = attempts,
            ?max_reconnect_attempts,
            error = %e,
            "consumer error, reconnecting in {delay:?}"
        );

        // Sleep out the backoff, but give up immediately on shutdown.
        tokio::select! {
            _ = tokio::time::sleep(delay) => {}
            _ = shutdown.cancelled() => return Ok(()),
        }
    }
}
/// Runs the sequential consume loop on `stream` with an owned handler and context.
///
/// Starts a requeuer task for the topology's hold queues (when there are
/// any), wraps `handler` and `ctx` in `Arc`s for `run_stream_loop_arc`, and
/// aborts the requeuer once the loop has returned. The loop's result is
/// passed through unchanged.
async fn run_stream_loop<T, H>(
    client: RedisClient,
    handler: H,
    ctx: H::Context,
    options: ConsumerOptionsInner,
    topology: &'static QueueTopology,
    stream: &str,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let hold_queues = topology.hold_queues();

    let mut hold_names: Vec<String> = Vec::new();
    for hq in hold_queues.iter() {
        hold_names.push(hq.name().to_owned());
    }

    let requeuer = if hold_names.is_empty() {
        None
    } else {
        Some(spawn_requeuer(
            client.clone(),
            hold_names,
            options.shutdown.clone(),
        ))
    };

    let outcome = run_stream_loop_arc::<T, H>(
        client,
        Arc::new(handler),
        Arc::new(ctx),
        options,
        topology,
        stream,
        hold_queues,
    )
    .await;

    if let Some(task) = requeuer {
        task.abort();
    }
    outcome
}
/// Sequential consume loop for `stream`, used by `run_stream_loop` and by the FIFO shard tasks.
///
/// The body runs under `run_with_reconnect`: every (re)connection attempt generates a new
/// consumer name, opens a dedicated connection and then loops on
/// `XREADGROUP GROUP <group> <consumer> COUNT <prefetch> BLOCK <BLOCK_MS> STREAMS <stream> >`.
/// Entries of a batch are handled one after another and the handler is awaited inline.
///
/// Per entry:
/// * no payload field: the entry is XACKed and skipped;
/// * payload longer than `options.max_message_size`, or codec decode failure: `route_to_dlq`;
/// * handler exceeds `options.handler_timeout`: the entry is neither acked nor routed here;
/// * otherwise the handler's outcome is passed to `route_outcome`.
///
/// Returns `Ok(())` once `options.shutdown` is cancelled. Errors from opening the connection,
/// from `XREADGROUP`, `route_to_dlq` or `route_outcome` leave the inner loop and are handed to
/// `run_with_reconnect`.
async fn run_stream_loop_arc<T, H>(
client: RedisClient,
handler: Arc<H>,
ctx: Arc<H::Context>,
options: ConsumerOptionsInner,
topology: &'static QueueTopology,
stream: &str,
hold_queues: &[HoldQueue],
) -> Result<()>
where
T: Topic,
H: MessageHandler<T>,
{
let group = client.group().to_owned();
let shutdown = options.shutdown.clone();
// Metric labels: the topology's queue name and the optional consumer group from the options.
let topic_name = topology.queue();
let consumer_group = options.consumer_group.as_deref();
let topic_arc: Arc<str> = Arc::from(topic_name);
let group_arc: Option<Arc<str>> = consumer_group.map(Arc::from);
// Always request at least one entry per read, even when prefetch_count is 0.
let prefetch = options.prefetch_count.max(1) as usize;
run_with_reconnect(&shutdown, stream, options.max_reconnect_attempts, || {
// This closure runs once per connection attempt; everything the async block needs is cloned
// here so the block can own it.
let client = client.clone();
let handler = Arc::clone(&handler);
let ctx = Arc::clone(&ctx);
let options = options.clone();
let group = group.clone();
let consumer = RedisConsumer::consumer_name();
tracing::debug!(
consumer,
stream,
"new consumer name registered; previous name left as stale entry in group until XGROUP DELCONSUMER is called"
);
let topic_arc = Arc::clone(&topic_arc);
let group_arc = group_arc.clone();
let shutdown = shutdown.clone();
async move {
let mut conn = client.dedicated_conn().await?;
loop {
if shutdown.is_cancelled() {
return Ok(());
}
let mut xreadgroup_cmd = redis::cmd("XREADGROUP");
xreadgroup_cmd
.arg("GROUP")
.arg(&group)
.arg(&consumer)
.arg("COUNT")
.arg(prefetch)
.arg("BLOCK")
.arg(BLOCK_MS)
.arg("STREAMS")
.arg(stream)
.arg(">");
let xreadgroup_fut = conn.query(&mut xreadgroup_cmd);
// `biased` makes select! poll the shutdown branch first, so a cancellation wins over a
// reply that is ready at the same time.
let raw_reply: redis::Value = tokio::select! {
biased;
_ = shutdown.cancelled() => return Ok(()),
result = xreadgroup_fut => match result {
Ok(v) => v,
Err(e) => {
if e.to_string().contains("NOGROUP") {
tracing::warn!(
stream,
error = %e,
"consumer group does not exist — topology may not be declared yet; will retry"
);
return Err(ShoveError::Connection(format!(
"consumer group does not exist on stream '{stream}': {e}"
)));
}
tracing::warn!(error = %e, stream, "XREADGROUP failed");
return Err(e);
}
}
};
let entries = parse_xreadgroup_reply(raw_reply, prefetch);
for (entry_id, fields_vec) in entries {
// `fields` keeps the internal keys; user headers are handed to the handler via metadata.
let (mut fields, user_headers) = partition_entry_fields(fields_vec);
let payload_raw = match fields.remove(PAYLOAD_FIELD) {
Some(s) => s,
None => {
tracing::warn!(entry_id, "missing payload field — acking and skipping");
if let Err(e) = xack(&mut conn, stream, &group, &entry_id).await {
tracing::warn!(entry_id, error = %e, "XACK failed after skipping corrupt entry");
metrics::record_backend_error(metrics::BackendLabel::Redis, metrics::BackendErrorKind::Ack);
}
continue;
}
};
// A missing or unparsable retry counter is treated as 0.
let retry_count = fields
.get(X_RETRY_COUNT)
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(0);
if let Some(max) = options.max_message_size
&& payload_raw.len() > max
{
tracing::warn!(
entry_id,
size = payload_raw.len(),
limit = max,
"message exceeds size limit — sending to DLQ"
);
metrics::record_failed(
topic_name,
consumer_group,
metrics::FailReason::Oversize,
);
// Put the payload back so the forwarded entry still carries it.
fields.insert(PAYLOAD_FIELD.to_owned(), payload_raw);
route_to_dlq(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
"oversize",
retry_count,
)
.await?;
continue;
}
let msg: T::Message = match <T::Codec as crate::Codec<T::Message>>::decode(
payload_raw.as_bytes(),
) {
Ok(m) => m,
Err(e) => {
tracing::warn!(
error = %e,
entry_id,
"deserialization failed — sending to DLQ"
);
metrics::record_failed(
topic_name,
consumer_group,
metrics::FailReason::Deserialize,
);
fields.insert(PAYLOAD_FIELD.to_owned(), payload_raw);
route_to_dlq(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
"deserialize",
retry_count,
)
.await?;
continue;
}
};
// Prefer the message id stored in the entry; fall back to the stream entry id.
let delivery_id = fields
.get(X_MESSAGE_ID)
.cloned()
.unwrap_or_else(|| entry_id.clone());
let meta = MessageMetadata {
retry_count,
delivery_id,
redelivered: retry_count > 0,
headers: Arc::new(user_headers),
};
// The shared `processing` flag is true while the handler for this entry is running.
options
.processing
.store(true, std::sync::atomic::Ordering::Release);
let handler_clone = Arc::clone(&handler);
let ctx_clone = Arc::clone(&ctx);
let _inflight =
metrics::InflightGuard::new(topic_arc.clone(), group_arc.clone());
let start = std::time::Instant::now();
// `None` means the handler hit `handler_timeout` and produced no outcome.
let outcome_opt = match options.handler_timeout {
Some(timeout_dur) => {
match tokio::time::timeout(
timeout_dur,
handler_clone.handle(msg, meta, &ctx_clone),
)
.await
{
Ok(o) => Some(o),
Err(_) => {
tracing::warn!(
entry_id,
timeout = ?timeout_dur,
"handler timed out — leaving in PEL for XAUTOCLAIM"
);
metrics::record_failed(
&topic_arc,
group_arc.as_deref(),
metrics::FailReason::Timeout,
);
None
}
}
}
None => Some(handler_clone.handle(msg, meta, &ctx_clone).await),
};
let elapsed = start.elapsed().as_secs_f64();
// Timed-out entries are skipped without XACK or routing.
let Some(outcome) = outcome_opt else {
options
.processing
.store(false, std::sync::atomic::Ordering::Release);
continue;
};
metrics::record_consumed(&topic_arc, group_arc.as_deref(), &outcome);
metrics::record_processing_duration(
&topic_arc,
group_arc.as_deref(),
&outcome,
elapsed,
);
options
.processing
.store(false, std::sync::atomic::Ordering::Release);
// Re-attach the payload so retry / hold / DLQ routing forwards the original body.
fields.insert(PAYLOAD_FIELD.to_owned(), payload_raw);
route_outcome(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
outcome,
retry_count,
options.max_retries,
hold_queues,
)
.await?;
}
}
}
})
.await
}
/// Concurrent variant of the consume loop: entries are read and decoded on the read loop, and
/// each handler invocation runs in its own spawned task.
///
/// A semaphore with `prefetch` permits bounds the number of handler tasks: one permit is acquired
/// before each spawn and dropped by the task once the handler has returned or timed out. When
/// shutdown is observed, the loop acquires all `prefetch` permits before returning `Ok(())`,
/// i.e. it waits until every outstanding permit has been released.
///
/// Reads, skip-acks and DLQ routing for oversize / undecodable entries use a dedicated
/// connection; spawned tasks route their outcome over clones of a multiplexed connection.
/// A routing failure inside a task is logged and not propagated.
///
/// NOTE(review): the permit is dropped before `route_outcome` runs, so the shutdown wait can
/// complete while a task is still routing/acking its outcome — confirm this is intended.
#[allow(clippy::too_many_arguments)]
async fn run_stream_loop_concurrent<T, H>(
client: RedisClient,
handler: Arc<H>,
ctx: Arc<H::Context>,
options: ConsumerOptionsInner,
topology: &'static QueueTopology,
stream: &str,
hold_queues: &'static [HoldQueue],
) -> Result<()>
where
T: Topic,
H: MessageHandler<T> + 'static,
H::Context: 'static,
{
use tokio::sync::Semaphore;
let group = client.group().to_owned();
let shutdown = options.shutdown.clone();
let topic_name = topology.queue();
let consumer_group = options.consumer_group.as_deref();
let topic_arc: Arc<str> = Arc::from(topic_name);
let group_arc: Option<Arc<str>> = consumer_group.map(Arc::from);
// Always request at least one entry per read; the same value sizes the semaphore.
let prefetch = options.prefetch_count.max(1) as usize;
let semaphore = Arc::new(Semaphore::new(prefetch));
// Copy the option values needed inside spawned tasks so the tasks do not borrow `options`.
let max_retries = options.max_retries;
let max_message_size = options.max_message_size;
let handler_timeout = options.handler_timeout;
let processing = options.processing.clone();
run_with_reconnect(&shutdown, stream, options.max_reconnect_attempts, || {
// Runs once per connection attempt; a new consumer name is generated each time.
let client = client.clone();
let handler = Arc::clone(&handler);
let ctx = Arc::clone(&ctx);
let consumer = RedisConsumer::consumer_name();
let topic_arc = Arc::clone(&topic_arc);
let group_arc = group_arc.clone();
let shutdown = shutdown.clone();
let semaphore = Arc::clone(&semaphore);
let processing = Arc::clone(&processing);
let group = group.clone();
async move {
let mut conn = client.dedicated_conn().await?;
let outcome_conn = client.multiplexed_conn().await?;
loop {
if shutdown.is_cancelled() {
// Wait until all permits are free again before returning.
let _ = semaphore.acquire_many(prefetch as u32).await;
return Ok(());
}
let mut xreadgroup_cmd = redis::cmd("XREADGROUP");
xreadgroup_cmd
.arg("GROUP")
.arg(&group)
.arg(&consumer)
.arg("COUNT")
.arg(prefetch)
.arg("BLOCK")
.arg(BLOCK_MS)
.arg("STREAMS")
.arg(stream)
.arg(">");
let xreadgroup_fut = conn.query(&mut xreadgroup_cmd);
// `biased` makes select! poll the shutdown branch first.
let raw_reply: redis::Value = tokio::select! {
biased;
_ = shutdown.cancelled() => {
let _ = semaphore.acquire_many(prefetch as u32).await;
return Ok(());
}
result = xreadgroup_fut => match result {
Ok(v) => v,
Err(e) => {
if e.to_string().contains("NOGROUP") {
tracing::warn!(
stream,
error = %e,
"consumer group does not exist — topology may not be declared yet; will retry"
);
return Err(ShoveError::Connection(format!(
"consumer group does not exist on stream '{stream}': {e}"
)));
}
tracing::warn!(error = %e, stream, "XREADGROUP failed");
return Err(e);
}
}
};
let entries = parse_xreadgroup_reply(raw_reply, prefetch);
for (entry_id, fields_vec) in entries {
// `fields` keeps the internal keys; user headers are handed to the handler via metadata.
let (mut fields, user_headers) = partition_entry_fields(fields_vec);
let payload_raw = match fields.remove(PAYLOAD_FIELD) {
Some(s) => s,
None => {
tracing::warn!(entry_id, "missing payload field — acking and skipping");
if let Err(e) = xack(&mut conn, stream, &group, &entry_id).await {
tracing::warn!(entry_id, error = %e, "XACK failed after skipping corrupt entry");
metrics::record_backend_error(metrics::BackendLabel::Redis, metrics::BackendErrorKind::Ack);
}
continue;
}
};
// A missing or unparsable retry counter is treated as 0.
let retry_count = fields
.get(X_RETRY_COUNT)
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(0);
if let Some(max) = max_message_size
&& payload_raw.len() > max
{
tracing::warn!(
entry_id,
size = payload_raw.len(),
limit = max,
"message exceeds size limit — sending to DLQ"
);
metrics::record_failed(
topic_name,
consumer_group,
metrics::FailReason::Oversize,
);
// Put the payload back so the forwarded entry still carries it.
fields.insert(PAYLOAD_FIELD.to_owned(), payload_raw);
route_to_dlq(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
"oversize",
retry_count,
)
.await?;
continue;
}
let msg: T::Message = match <T::Codec as crate::Codec<T::Message>>::decode(
payload_raw.as_bytes(),
) {
Ok(m) => m,
Err(e) => {
tracing::warn!(
error = %e,
entry_id,
"deserialization failed — sending to DLQ"
);
metrics::record_failed(
topic_name,
consumer_group,
metrics::FailReason::Deserialize,
);
fields.insert(PAYLOAD_FIELD.to_owned(), payload_raw);
route_to_dlq(
&mut conn,
topology,
stream,
&group,
&entry_id,
&fields,
"deserialize",
retry_count,
)
.await?;
continue;
}
};
// Prefer the message id stored in the entry; fall back to the stream entry id.
let delivery_id = fields
.get(X_MESSAGE_ID)
.cloned()
.unwrap_or_else(|| entry_id.clone());
let meta = MessageMetadata {
retry_count,
delivery_id,
redelivered: retry_count > 0,
headers: Arc::new(user_headers),
};
// Wait for a free slot before spawning; this is what caps the number of handler tasks.
let permit = match semaphore.clone().acquire_owned().await {
Ok(p) => p,
Err(_) => {
return Err(ShoveError::Connection(
"concurrent consumer semaphore closed".to_string(),
));
}
};
processing.store(true, std::sync::atomic::Ordering::Release);
// Everything below is moved into the spawned task, hence the owned clones.
let task_handler = Arc::clone(&handler);
let task_ctx = Arc::clone(&ctx);
let mut task_conn = outcome_conn.clone();
let task_topic = Arc::clone(&topic_arc);
let task_group_metric = group_arc.clone();
let task_group = group.clone();
let task_stream = stream.to_owned();
let task_processing = Arc::clone(&processing);
let task_semaphore = Arc::clone(&semaphore);
// Re-attach the payload so retry / hold / DLQ routing forwards the original body.
fields.insert(PAYLOAD_FIELD.to_owned(), payload_raw);
// The join handle is not kept; completion is tracked only through the semaphore permit.
tokio::spawn(async move {
let _inflight =
metrics::InflightGuard::new(task_topic.clone(), task_group_metric.clone());
let start = std::time::Instant::now();
// `None` means the handler hit `handler_timeout` and produced no outcome.
let outcome_opt = match handler_timeout {
Some(timeout_dur) => {
match tokio::time::timeout(
timeout_dur,
task_handler.handle(msg, meta, &task_ctx),
)
.await
{
Ok(o) => Some(o),
Err(_) => {
tracing::warn!(
entry_id,
timeout = ?timeout_dur,
"handler timed out — leaving in PEL for XAUTOCLAIM"
);
metrics::record_failed(
&task_topic,
task_group_metric.as_deref(),
metrics::FailReason::Timeout,
);
None
}
}
}
None => Some(task_handler.handle(msg, meta, &task_ctx).await),
};
let elapsed = start.elapsed().as_secs_f64();
// Free the slot as soon as the handler is done (before the outcome is routed).
drop(permit);
// Clear the `processing` flag when every permit is back in the semaphore.
if task_semaphore.available_permits() == prefetch {
task_processing
.store(false, std::sync::atomic::Ordering::Release);
}
// Timed-out entries are left without XACK or routing.
let Some(outcome) = outcome_opt else { return };
metrics::record_consumed(
&task_topic,
task_group_metric.as_deref(),
&outcome,
);
metrics::record_processing_duration(
&task_topic,
task_group_metric.as_deref(),
&outcome,
elapsed,
);
if let Err(e) = route_outcome(
&mut task_conn,
topology,
&task_stream,
&task_group,
&entry_id,
&fields,
outcome,
retry_count,
max_retries,
hold_queues,
)
.await
{
tracing::warn!(
error = %e,
entry_id,
"outcome routing failed; message left in PEL"
);
}
});
}
}
}
})
.await
}
/// Applies a handler `outcome` to the stream entry `entry_id`.
///
/// * `Ack` — XACK the entry.
/// * `Retry` — increment the retry counter; when the new count reaches
///   `max_retries` the entry goes to `route_to_dlq` with reason
///   `"max-retries"`. Otherwise it is re-added to `stream` immediately (no
///   hold queues) or parked in the hold queue chosen by `hold_level`.
/// * `Reject` — `route_to_dlq` with reason `"rejected"`.
/// * `Defer` — like a retry without touching the counter: immediate
///   re-add when there are no hold queues, else the first hold queue.
///
/// `fields` are the entry fields that get forwarded; `retry_count` is the
/// counter the entry was delivered with.
///
/// # Errors
///
/// Only failures of `route_to_dlq` are returned. XACK failures are logged and
/// counted via `metrics::record_backend_error`.
#[allow(clippy::too_many_arguments)]
async fn route_outcome(
    conn: &mut RedisConnection,
    topology: &'static QueueTopology,
    stream: &str,
    group: &str,
    entry_id: &str,
    fields: &HashMap<String, String>,
    outcome: Outcome,
    retry_count: u32,
    max_retries: u32,
    hold_queues: &[HoldQueue],
) -> Result<()> {
    match outcome {
        Outcome::Ack => {
            if let Err(e) = xack(conn, stream, group, entry_id).await {
                tracing::warn!(stream, entry_id, error = %e, "XACK failed on Ack");
                metrics::record_backend_error(
                    metrics::BackendLabel::Redis,
                    metrics::BackendErrorKind::Ack,
                );
            }
        }
        Outcome::Retry => {
            // `retry_count` is parsed from a stream field, so it can be any u32.
            // Saturate instead of overflowing: a plain `+ 1` at u32::MAX panics in
            // overflow-checked builds and wraps to 0 otherwise, which would reset
            // the retry budget instead of dead-lettering the message.
            let new_retry = retry_count.saturating_add(1);
            if new_retry >= max_retries {
                route_to_dlq(
                    conn,
                    topology,
                    stream,
                    group,
                    entry_id,
                    fields,
                    "max-retries",
                    new_retry,
                )
                .await?;
            } else if hold_queues.is_empty() {
                tracing::warn!(
                    stream,
                    entry_id,
                    "Retry but no hold queues — re-queueing immediately"
                );
                // Re-add first, ack second: the original entry stays pending
                // until the copy has been attempted.
                requeue_to_stream(conn, stream, fields, new_retry).await;
                if let Err(e) = xack(conn, stream, group, entry_id).await {
                    tracing::warn!(stream, entry_id, error = %e, "XACK failed after immediate requeue");
                    metrics::record_backend_error(
                        metrics::BackendLabel::Redis,
                        metrics::BackendErrorKind::Ack,
                    );
                }
            } else if let Some(level) = hold_level(new_retry, hold_queues) {
                let hq = &hold_queues[level];
                route_to_hold(
                    conn,
                    stream,
                    group,
                    entry_id,
                    fields,
                    hq.name(),
                    hq.delay(),
                    new_retry,
                )
                .await;
            }
        }
        Outcome::Reject => {
            route_to_dlq(
                conn,
                topology,
                stream,
                group,
                entry_id,
                fields,
                "rejected",
                retry_count,
            )
            .await?;
        }
        Outcome::Defer => {
            if hold_queues.is_empty() {
                tracing::warn!(
                    stream,
                    entry_id,
                    "Defer but no hold queues — re-queueing immediately"
                );
                requeue_to_stream(conn, stream, fields, retry_count).await;
                if let Err(e) = xack(conn, stream, group, entry_id).await {
                    tracing::warn!(stream, entry_id, error = %e, "XACK failed after defer requeue");
                    metrics::record_backend_error(
                        metrics::BackendLabel::Redis,
                        metrics::BackendErrorKind::Ack,
                    );
                }
            } else {
                // Deferred messages always use the first hold queue and keep
                // their retry counter.
                let hq = &hold_queues[0];
                route_to_hold(
                    conn,
                    stream,
                    group,
                    entry_id,
                    fields,
                    hq.name(),
                    hq.delay(),
                    retry_count,
                )
                .await;
            }
        }
    }
    Ok(())
}
/// Parks an entry in the hold queue `hold_name` for `delay`, then XACKs the original.
///
/// The forwarded fields are a copy of `fields` with any existing retry
/// counter replaced by `new_retry_count`. If `enqueue_hold` fails, the failure
/// is logged and the original entry is not acknowledged. An XACK failure is
/// logged and counted as a backend error.
#[allow(clippy::too_many_arguments)]
async fn route_to_hold(
    conn: &mut RedisConnection,
    stream: &str,
    group: &str,
    entry_id: &str,
    fields: &HashMap<String, String>,
    hold_name: &str,
    delay: Duration,
    new_retry_count: u32,
) {
    // Copy everything except the old retry counter, then append the new one.
    let mut hold_fields: Vec<(String, String)> = Vec::new();
    for (name, value) in fields {
        if name.as_str() == X_RETRY_COUNT {
            continue;
        }
        hold_fields.push((name.clone(), value.clone()));
    }
    hold_fields.push((X_RETRY_COUNT.into(), new_retry_count.to_string()));

    let entry = HoldEntry {
        stream: stream.to_owned(),
        fields: hold_fields,
    };

    let enqueued = enqueue_hold(conn, hold_name, entry, delay).await;
    if let Err(e) = enqueued {
        tracing::warn!(error = %e, hold_name, "enqueue_hold failed — message may be lost");
        return;
    }

    let acked = xack(conn, stream, group, entry_id).await;
    if let Err(e) = acked {
        tracing::warn!(stream, entry_id, error = %e, "XACK failed after enqueue_hold");
        metrics::record_backend_error(metrics::BackendLabel::Redis, metrics::BackendErrorKind::Ack);
    }
}
/// Moves an entry to the topology's dead-letter queue and XACKs the original.
///
/// The DLQ entry carries all of `fields` plus the death reason, the death
/// count and the originating stream name. When the topology has no DLQ the
/// entry is simply acknowledged (discarded) and `Ok(())` is returned.
///
/// # Errors
///
/// Returns `ShoveError::Connection` when the XADD to the DLQ fails; in that
/// case the original entry is not acknowledged. XACK failures are logged and
/// counted as backend errors.
#[allow(clippy::too_many_arguments)]
async fn route_to_dlq(
    conn: &mut RedisConnection,
    topology: &'static QueueTopology,
    stream: &str,
    group: &str,
    entry_id: &str,
    fields: &HashMap<String, String>,
    reason: &str,
    death_count: u32,
) -> Result<()> {
    let Some(dlq) = topology.dlq() else {
        tracing::warn!(stream, entry_id, reason, "no DLQ configured — discarding");
        if let Err(e) = xack(conn, stream, group, entry_id).await {
            tracing::warn!(stream, entry_id, error = %e, "XACK failed while discarding (no DLQ)");
            metrics::record_backend_error(
                metrics::BackendLabel::Redis,
                metrics::BackendErrorKind::Ack,
            );
        }
        return Ok(());
    };

    // XADD <dlq> * <fields...> + three death-metadata pairs = 2 * len + 9 args.
    let arg_count = fields.len() * 2 + 9;
    let mut cmd = redis::Cmd::with_capacity(arg_count, arg_count * 16);
    cmd.arg("XADD").arg(dlq).arg("*");
    for (name, value) in fields {
        cmd.arg(name.as_str()).arg(value.as_str());
    }
    cmd.arg(X_DEATH_REASON)
        .arg(reason)
        .arg(X_DEATH_COUNT)
        .arg(death_count.to_string())
        .arg(X_ORIGINAL_QUEUE)
        .arg(stream);

    if let Err(e) = conn.query::<redis::Value>(&mut cmd).await {
        tracing::warn!(error = %e, dlq, "XADD to DLQ failed — message stays in PEL");
        return Err(ShoveError::Connection(format!("XADD to DLQ failed: {e}")));
    }

    if let Err(e) = xack(conn, stream, group, entry_id).await {
        tracing::warn!(stream, entry_id, error = %e, "XACK failed after DLQ enqueue");
        metrics::record_backend_error(metrics::BackendLabel::Redis, metrics::BackendErrorKind::Ack);
    }
    Ok(())
}
/// Re-adds an entry to `stream` with its retry counter set to `retry_count`.
///
/// All of `fields` except an existing retry counter are copied; the counter
/// is appended last. A failing XADD is logged and otherwise ignored — the
/// function has no return value.
async fn requeue_to_stream(
    conn: &mut RedisConnection,
    stream: &str,
    fields: &HashMap<String, String>,
    retry_count: u32,
) {
    let arg_count = fields.len() * 2 + 4;
    let mut cmd = redis::Cmd::with_capacity(arg_count, arg_count * 16);
    cmd.arg("XADD").arg(stream).arg("*");

    for (name, value) in fields {
        // The stale counter is dropped here and replaced below.
        if name.as_str() == X_RETRY_COUNT {
            continue;
        }
        cmd.arg(name.as_str()).arg(value.as_str());
    }
    cmd.arg(X_RETRY_COUNT).arg(retry_count.to_string());

    let added = conn.query::<redis::Value>(&mut cmd).await;
    if let Err(e) = added {
        tracing::warn!(error = %e, stream, "XADD on immediate requeue failed — message may be lost");
    }
}
/// Sends `XACK <stream> <group> <entry_id>`.
///
/// The integer reply is discarded; a failing command is reported as
/// `ShoveError::Connection`.
async fn xack(conn: &mut RedisConnection, stream: &str, group: &str, entry_id: &str) -> Result<()> {
    let mut cmd = redis::cmd("XACK");
    cmd.arg(stream).arg(group).arg(entry_id);
    match conn.query::<i64>(&mut cmd).await {
        Ok(_) => Ok(()),
        Err(e) => Err(ShoveError::Connection(format!("XACK failed: {e}"))),
    }
}
/// Flattens an `XREADGROUP` reply into `(entry_id, [(field, value), ...])` tuples.
///
/// Parsing is tolerant: anything that does not have the expected shape is
/// skipped rather than reported.
///
/// * A reply that is not an array (including `Nil`) yields an empty vector.
/// * Stream items and entries that are not arrays of at least two elements,
///   or whose id / entry list / field list has the wrong type, are skipped.
/// * Entry ids and field names must be valid UTF-8; field values are decoded
///   lossily and a `Nil` value becomes an empty string.
/// * Within one entry, field parsing stops at the first malformed or
///   dangling element; the pairs collected so far are kept.
///
/// `capacity_hint` only pre-sizes the result vector.
pub(super) fn parse_xreadgroup_reply(
    value: redis::Value,
    capacity_hint: usize,
) -> Vec<(String, Vec<(String, String)>)> {
    // Strict decoding for entry ids and field names.
    fn strict_text(raw: &redis::Value) -> Option<String> {
        match raw {
            redis::Value::BulkString(bytes) => match std::str::from_utf8(bytes) {
                Ok(text) => Some(text.to_owned()),
                Err(_) => None,
            },
            redis::Value::SimpleString(text) => Some(text.clone()),
            _ => None,
        }
    }

    // Lenient decoding for field values.
    fn lossy_text(raw: &redis::Value) -> Option<String> {
        match raw {
            redis::Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).into_owned()),
            redis::Value::SimpleString(text) => Some(text.clone()),
            redis::Value::Nil => Some(String::new()),
            _ => None,
        }
    }

    let redis::Value::Array(streams) = value else {
        return Vec::new();
    };

    let mut parsed = Vec::with_capacity(capacity_hint);
    for stream_item in streams {
        let redis::Value::Array(stream_pair) = stream_item else {
            continue;
        };
        // [stream name, entry list, ...]
        let [_, redis::Value::Array(entry_list), ..] = &stream_pair[..] else {
            continue;
        };

        for entry_item in entry_list {
            let redis::Value::Array(entry_pair) = entry_item else {
                continue;
            };
            // [entry id, flat field list, ...]
            let [raw_id, redis::Value::Array(field_list), ..] = &entry_pair[..] else {
                continue;
            };
            let Some(entry_id) = strict_text(raw_id) else {
                continue;
            };

            let mut fields: Vec<(String, String)> = Vec::new();
            let mut cursor = field_list.iter();
            while let (Some(raw_key), Some(raw_val)) = (cursor.next(), cursor.next()) {
                let Some(key) = strict_text(raw_key) else {
                    break;
                };
                let Some(val) = lossy_text(raw_val) else {
                    break;
                };
                fields.push((key, val));
            }
            parsed.push((entry_id, fields));
        }
    }
    parsed
}
/// Field names that `partition_entry_fields` sorts into the internal map;
/// every other field of a stream entry is classified as a user header.
const INTERNAL_KEYS: &[&str] = &[
PAYLOAD_FIELD,
X_RETRY_COUNT,
X_SEQUENCE_KEY,
X_MESSAGE_ID,
X_DEATH_REASON,
X_DEATH_COUNT,
X_ORIGINAL_QUEUE,
];
/// Splits the raw fields of a stream entry into `(internal, user)` maps.
///
/// A field lands in the first map when its name is listed in
/// `INTERNAL_KEYS`, otherwise in the second. If a name occurs more than once,
/// the last value wins.
fn partition_entry_fields(
    fields_vec: Vec<(String, String)>,
) -> (HashMap<String, String>, HashMap<String, String>) {
    let mut internal: HashMap<String, String> = HashMap::with_capacity(INTERNAL_KEYS.len());
    let mut user: HashMap<String, String> = HashMap::new();

    for (name, value) in fields_vec {
        let is_internal = INTERNAL_KEYS.iter().any(|known| *known == name.as_str());
        let bucket = if is_internal { &mut internal } else { &mut user };
        bucket.insert(name, value);
    }

    (internal, user)
}
/// Picks the index of the hold queue to use for `retry_count`.
///
/// The retry count is used directly as an index and clamped to the last
/// queue. Returns `None` when `hold_queues` is empty.
pub(super) fn hold_level<T>(retry_count: u32, hold_queues: &[T]) -> Option<usize> {
    match hold_queues.len() {
        0 => None,
        count => {
            let last = count - 1;
            let wanted = retry_count as usize;
            Some(if wanted < last { wanted } else { last })
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn retry_count_routing_to_hold_level() {
let hold_queues = vec!["orders-hold-5s", "orders-hold-30s"];
assert_eq!(hold_level(0, &hold_queues), Some(0));
assert_eq!(hold_level(1, &hold_queues), Some(1));
assert_eq!(hold_level(2, &hold_queues), Some(1)); }
#[test]
fn hold_level_empty_returns_none() {
assert_eq!(hold_level(0, &[""]), Some(0));
let empty: Vec<&str> = vec![];
assert_eq!(hold_level(0, &empty), None);
}
#[test]
fn parse_xreadgroup_nil_returns_empty() {
let result = parse_xreadgroup_reply(redis::Value::Nil, 0);
assert!(result.is_empty());
}
#[test]
fn parse_xreadgroup_empty_array_returns_empty() {
let result = parse_xreadgroup_reply(redis::Value::Array(vec![]), 0);
assert!(result.is_empty());
}
#[test]
fn parse_xreadgroup_valid_entry() {
let entry = redis::Value::Array(vec![
redis::Value::BulkString(b"1234-0".to_vec()),
redis::Value::Array(vec![
redis::Value::BulkString(b"payload".to_vec()),
redis::Value::BulkString(b"{}".to_vec()),
redis::Value::BulkString(b"x-retry-count".to_vec()),
redis::Value::BulkString(b"0".to_vec()),
]),
]);
let stream = redis::Value::Array(vec![
redis::Value::BulkString(b"mystream".to_vec()),
redis::Value::Array(vec![entry]),
]);
let reply = redis::Value::Array(vec![stream]);
let result = parse_xreadgroup_reply(reply, 0);
assert_eq!(result.len(), 1);
assert_eq!(result[0].0, "1234-0");
assert_eq!(result[0].1.len(), 2);
assert_eq!(result[0].1[0], ("payload".to_string(), "{}".to_string()));
assert_eq!(
result[0].1[1],
("x-retry-count".to_string(), "0".to_string())
);
}
#[test]
fn parse_xreadgroup_simple_string_id() {
let entry = redis::Value::Array(vec![
redis::Value::SimpleString("9999-1".to_string()),
redis::Value::Array(vec![
redis::Value::BulkString(b"payload".to_vec()),
redis::Value::BulkString(b"hello".to_vec()),
]),
]);
let stream = redis::Value::Array(vec![
redis::Value::BulkString(b"s".to_vec()),
redis::Value::Array(vec![entry]),
]);
let result = parse_xreadgroup_reply(redis::Value::Array(vec![stream]), 0);
assert_eq!(result.len(), 1);
assert_eq!(result[0].0, "9999-1");
}
#[test]
fn parse_xreadgroup_nil_field_value_becomes_empty_string() {
let entry = redis::Value::Array(vec![
redis::Value::BulkString(b"1-0".to_vec()),
redis::Value::Array(vec![
redis::Value::BulkString(b"payload".to_vec()),
redis::Value::Nil,
]),
]);
let stream = redis::Value::Array(vec![
redis::Value::BulkString(b"s".to_vec()),
redis::Value::Array(vec![entry]),
]);
let result = parse_xreadgroup_reply(redis::Value::Array(vec![stream]), 0);
assert_eq!(result.len(), 1);
assert_eq!(result[0].1[0], ("payload".to_string(), String::new()));
}
#[test]
fn parse_xreadgroup_odd_field_count_stops_at_last_key() {
let entry = redis::Value::Array(vec![
redis::Value::BulkString(b"2-0".to_vec()),
redis::Value::Array(vec![
redis::Value::BulkString(b"payload".to_vec()),
redis::Value::BulkString(b"{}".to_vec()),
redis::Value::BulkString(b"dangling-key".to_vec()),
]),
]);
let stream = redis::Value::Array(vec![
redis::Value::BulkString(b"s".to_vec()),
redis::Value::Array(vec![entry]),
]);
let result = parse_xreadgroup_reply(redis::Value::Array(vec![stream]), 0);
assert_eq!(result.len(), 1);
assert_eq!(result[0].1.len(), 1);
assert_eq!(result[0].1[0].0, "payload");
}
#[test]
fn parse_xreadgroup_wrong_root_type_returns_empty() {
let result = parse_xreadgroup_reply(redis::Value::Int(0), 0);
assert!(result.is_empty());
}
#[test]
fn partition_entry_fields_separates_user_headers() {
let fields_vec = vec![
(PAYLOAD_FIELD.to_string(), "data".to_string()),
(X_RETRY_COUNT.to_string(), "2".to_string()),
(X_SEQUENCE_KEY.to_string(), "acct-1".to_string()),
("x-custom".to_string(), "val".to_string()),
];
let (internal, user) = partition_entry_fields(fields_vec);
assert_eq!(user.len(), 1);
assert_eq!(user.get("x-custom").map(String::as_str), Some("val"));
assert!(internal.contains_key(PAYLOAD_FIELD));
assert!(internal.contains_key(X_RETRY_COUNT));
assert!(internal.contains_key(X_SEQUENCE_KEY));
}
#[test]
fn partition_entry_fields_all_internal_keys_go_to_internal() {
let fields_vec = vec![
(PAYLOAD_FIELD.to_string(), "data".to_string()),
(X_RETRY_COUNT.to_string(), "2".to_string()),
(X_SEQUENCE_KEY.to_string(), "acct-1".to_string()),
(X_MESSAGE_ID.to_string(), "msg-abc".to_string()),
(X_DEATH_REASON.to_string(), "max-retries".to_string()),
(X_DEATH_COUNT.to_string(), "5".to_string()),
(X_ORIGINAL_QUEUE.to_string(), "orders".to_string()),
("x-custom".to_string(), "val".to_string()),
];
let (internal, user) = partition_entry_fields(fields_vec);
assert_eq!(user.len(), 1);
assert_eq!(user.get("x-custom").map(String::as_str), Some("val"));
for key in INTERNAL_KEYS {
assert!(
!user.contains_key(*key),
"internal key {key:?} leaked into user headers"
);
assert!(
internal.contains_key(*key),
"internal key {key:?} missing from internal map"
);
}
}
/// With no input fields, both returned maps must be empty.
#[test]
fn partition_entry_fields_empty_input_returns_empty_maps() {
    let no_fields = vec![];
    let (internal, user) = partition_entry_fields(no_fields);
    assert!(internal.is_empty());
    assert!(user.is_empty());
}
/// Two consecutive calls to `RedisConsumer::consumer_name` must not return the
/// same string.
#[test]
fn consumer_name_is_unique() {
    let first = RedisConsumer::consumer_name();
    let second = RedisConsumer::consumer_name();
    assert_ne!(first, second, "consumer names must be unique per call");
}
/// A top-level item that is an integer instead of a stream array must yield no
/// entries.
#[test]
fn parse_xreadgroup_non_array_stream_item_skipped() {
    use redis::Value;

    // The only top-level item is an integer, not an array.
    let top_level_items = vec![Value::Int(42)];
    let parsed = parse_xreadgroup_reply(Value::Array(top_level_items), 0);
    assert!(parsed.is_empty());
}
/// A stream array holding a single element (just a name) must yield no entries.
#[test]
fn parse_xreadgroup_stream_pair_too_short_skipped() {
    use redis::Value;

    let lone_name = Value::BulkString(b"only-one".to_vec());
    let truncated_stream = Value::Array(vec![lone_name]);
    let parsed = parse_xreadgroup_reply(Value::Array(vec![truncated_stream]), 0);
    assert!(parsed.is_empty());
}
/// A stream whose second element is an integer rather than an array of entries
/// must yield no entries.
#[test]
fn parse_xreadgroup_non_array_entry_list_skipped() {
    use redis::Value;

    let stream_name = Value::BulkString(b"mystream".to_vec());
    let bogus_entry_list = Value::Int(99);
    let stream = Value::Array(vec![stream_name, bogus_entry_list]);
    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);
    assert!(parsed.is_empty());
}
/// An entry array that carries only an ID and no field list must yield no
/// entries.
#[test]
fn parse_xreadgroup_entry_pair_too_short_skipped() {
    use redis::Value;

    let id_only_entry = Value::Array(vec![Value::BulkString(b"1-0".to_vec())]);
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(vec![id_only_entry]),
    ]);
    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);
    assert!(parsed.is_empty());
}
/// An entry whose ID slot holds an integer must yield no entries.
#[test]
fn parse_xreadgroup_int_entry_id_skipped() {
    use redis::Value;

    // The ID is an integer; the field list is present but empty.
    let entry = Value::Array(vec![Value::Int(12345), Value::Array(vec![])]);
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(vec![entry]),
    ]);
    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);
    assert!(parsed.is_empty());
}
/// An entry whose field-list slot holds an integer instead of an array must
/// yield no entries.
#[test]
fn parse_xreadgroup_non_array_field_list_skipped() {
    use redis::Value;

    let entry = Value::Array(vec![Value::BulkString(b"1-0".to_vec()), Value::Int(0)]);
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(vec![entry]),
    ]);
    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);
    assert!(parsed.is_empty());
}
/// A field key delivered as a `SimpleString` is accepted and returned as text
/// alongside its value.
#[test]
fn parse_xreadgroup_simple_string_field_key() {
    use redis::Value;

    let fields = Value::Array(vec![
        Value::SimpleString("myfieldkey".to_string()),
        Value::BulkString(b"myvalue".to_vec()),
    ]);
    let entry = Value::Array(vec![Value::BulkString(b"1-0".to_vec()), fields]);
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(vec![entry]),
    ]);

    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);

    assert_eq!(parsed.len(), 1);
    let only_entry = &parsed[0];
    assert_eq!(only_entry.1.len(), 1);
    assert_eq!(only_entry.1[0].0, "myfieldkey");
    assert_eq!(only_entry.1[0].1, "myvalue");
}
/// An integer in field-key position: the entry is still returned, but only the
/// pair that precedes the integer is expected in its field list.
#[test]
fn parse_xreadgroup_int_field_key_breaks_loop() {
    use redis::Value;

    let fields = Value::Array(vec![
        // Well-formed pair.
        Value::BulkString(b"good-key".to_vec()),
        Value::BulkString(b"good-val".to_vec()),
        // Integer where a key is expected, followed by a trailing value.
        Value::Int(42),
        Value::BulkString(b"after-break".to_vec()),
    ]);
    let entry = Value::Array(vec![Value::BulkString(b"1-0".to_vec()), fields]);
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(vec![entry]),
    ]);

    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);

    assert_eq!(parsed.len(), 1);
    let only_entry = &parsed[0];
    assert_eq!(only_entry.1.len(), 1);
    assert_eq!(only_entry.1[0].0, "good-key");
}
/// A field value delivered as a `SimpleString` is accepted and returned as text.
#[test]
fn parse_xreadgroup_simple_string_field_value() {
    use redis::Value;

    let fields = Value::Array(vec![
        Value::BulkString(b"key".to_vec()),
        Value::SimpleString("simplevalue".to_string()),
    ]);
    let entry = Value::Array(vec![Value::BulkString(b"1-0".to_vec()), fields]);
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(vec![entry]),
    ]);

    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);

    assert_eq!(parsed.len(), 1);
    assert_eq!(parsed[0].1[0].1, "simplevalue");
}
/// An integer in field-value position: the entry is still returned, but only
/// the complete pair before it (`k1`) is expected in its field list.
#[test]
fn parse_xreadgroup_int_field_value_breaks_loop() {
    use redis::Value;

    let fields = Value::Array(vec![
        // Well-formed pair.
        Value::BulkString(b"k1".to_vec()),
        Value::BulkString(b"v1".to_vec()),
        // Second key is fine, but its value is an integer.
        Value::BulkString(b"k2".to_vec()),
        Value::Int(99),
    ]);
    let entry = Value::Array(vec![Value::BulkString(b"1-0".to_vec()), fields]);
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(vec![entry]),
    ]);

    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);

    assert_eq!(parsed.len(), 1);
    let only_entry = &parsed[0];
    assert_eq!(only_entry.1.len(), 1);
    assert_eq!(only_entry.1[0].0, "k1");
}
/// Three streams with one entry each must come back as one flat list of three
/// entries, in reply order.
#[test]
fn parse_xreadgroup_multiple_streams_merged_flat() {
    use redis::Value;

    /// Builds a stream named `name` containing one entry `id` with the single
    /// field `payload = val`.
    fn single_entry_stream(name: &str, id: &str, val: &str) -> Value {
        let fields = Value::Array(vec![
            Value::BulkString(b"payload".to_vec()),
            Value::BulkString(val.as_bytes().to_vec()),
        ]);
        let entry = Value::Array(vec![Value::BulkString(id.as_bytes().to_vec()), fields]);
        Value::Array(vec![
            Value::BulkString(name.as_bytes().to_vec()),
            Value::Array(vec![entry]),
        ])
    }

    let specs = [
        ("stream-a", "1-0", "msg-a"),
        ("stream-b", "2-0", "msg-b"),
        ("stream-c", "3-0", "msg-c"),
    ];
    let streams: Vec<Value> = specs
        .iter()
        .map(|&(name, id, val)| single_entry_stream(name, id, val))
        .collect();

    let parsed = parse_xreadgroup_reply(Value::Array(streams), 0);

    assert_eq!(parsed.len(), 3);
    for (position, expected_id) in ["1-0", "2-0", "3-0"].iter().enumerate() {
        assert_eq!(parsed[position].0, *expected_id);
    }
}
/// With exactly one hold queue, `hold_level` must return `Some(0)` for every
/// first argument tried here, from `0` up to `u32::MAX`.
#[test]
fn hold_level_single_element_always_returns_zero() {
    let queues = vec!["only-queue"];
    let only_level = Some(0);

    assert_eq!(hold_level(0, &queues), only_level);
    assert_eq!(hold_level(1, &queues), only_level);
    assert_eq!(hold_level(100, &queues), only_level);
    assert_eq!(hold_level(u32::MAX, &queues), only_level);
}
/// The `NOGROUP` marker must still be visible after the raw text is wrapped in
/// `ShoveError::Connection` and rendered with `to_string`.
#[test]
fn nogroup_error_string_is_detected() {
    let raw = "NOGROUP No such consumer group 'grp' for key name 'stream'";
    assert!(raw.contains("NOGROUP"));

    let rendered = ShoveError::Connection(raw.to_string()).to_string();
    assert!(rendered.contains("NOGROUP"));
}
/// A `ShoveError::Connection` carrying a NOGROUP message must report itself as
/// retryable.
#[test]
fn nogroup_error_is_retryable() {
    let description = "consumer group does not exist on stream 'foo': NOGROUP ...";
    let err = ShoveError::Connection(description.into());

    let retryable = err.is_retryable();
    assert!(
        retryable,
        "NOGROUP error must be retryable so consumers survive Redis restart"
    );
}
/// Pins the variant used for NOGROUP failures: `Connection`, never `Topology`.
///
/// NOTE(review): the error is constructed inside this test, so it documents the
/// intended variant rather than exercising the code that produces it.
#[test]
fn nogroup_error_is_connection_not_topology() {
    let description = "consumer group does not exist on stream 'foo': NOGROUP ...";
    let err = ShoveError::Connection(description.into());

    let is_connection = matches!(err, ShoveError::Connection(_));
    let is_topology = matches!(err, ShoveError::Topology(_));

    assert!(
        is_connection,
        "NOGROUP must be ShoveError::Connection, not Topology"
    );
    assert!(!is_topology, "NOGROUP must not be ShoveError::Topology");
}
/// Checks that the "exhausted reconnect attempts" wording contains the stream
/// name, the attempt limit and the root cause.
///
/// NOTE(review): the message is formatted locally in this test, so it pins the
/// expected wording only; it does not call the code that emits the error.
#[test]
fn exhausted_reconnect_error_message_format() {
    let (stream, cause) = ("orders", "connection refused");
    let max: u32 = 3;

    let rendered = format!("consumer on '{stream}' exhausted {max} reconnect attempt(s): {cause}");
    let limit_text = max.to_string();

    assert!(rendered.contains(stream), "stream name must appear in error");
    assert!(
        rendered.contains(&limit_text),
        "attempt count must appear in error"
    );
    assert!(rendered.contains(cause), "root cause must appear in error");
}
/// With a limit of `Some(2)` and a closure that always fails with a
/// `Connection` error, `run_with_reconnect` must return an error naming the
/// stream after invoking the closure exactly twice.
#[tokio::test]
async fn run_with_reconnect_stops_when_limit_reached() {
    use tokio_util::sync::CancellationToken;

    let token = CancellationToken::new();
    let mut attempts = 0u32;

    let outcome = run_with_reconnect(&token, "test-stream", Some(2), || {
        attempts += 1;
        async { Err(ShoveError::Connection("transient".into())) }
    })
    .await;

    assert!(
        outcome.is_err(),
        "must propagate error after exhausting attempts"
    );

    let msg = outcome.unwrap_err().to_string();
    let names_stream = msg.contains("test-stream");
    assert!(names_stream, "error must name the stream; got: {msg}");

    assert_eq!(
        attempts, 2,
        "must attempt exactly max times before giving up"
    );
}
/// With no attempt limit (`None`), two `Connection` failures followed by a
/// success must end in `Ok` after exactly three closure invocations.
#[tokio::test]
async fn run_with_reconnect_unlimited_can_succeed_after_retries() {
    use tokio_util::sync::CancellationToken;

    let token = CancellationToken::new();
    let mut attempts = 0u32;

    let outcome = run_with_reconnect(&token, "test-stream", None, || {
        attempts += 1;
        // Decide now, while the counter holds this invocation's value.
        let still_failing = attempts < 3;
        async move {
            if still_failing {
                Err(ShoveError::Connection("transient".into()))
            } else {
                Ok(())
            }
        }
    })
    .await;

    assert!(outcome.is_ok(), "must succeed once the closure returns Ok");
    assert_eq!(attempts, 3);
}
/// A `Topology` error must come straight back to the caller: the closure runs
/// once and is not retried, even with no attempt limit.
#[tokio::test]
async fn run_with_reconnect_non_retryable_error_propagates_immediately() {
    use tokio_util::sync::CancellationToken;

    let token = CancellationToken::new();
    let mut attempts = 0u32;

    let outcome = run_with_reconnect(&token, "test-stream", None, || {
        attempts += 1;
        async { Err(ShoveError::Topology("bad topology".into())) }
    })
    .await;

    assert!(outcome.is_err());
    assert_eq!(
        attempts, 1,
        "non-retryable error must not trigger reconnect"
    );
}
/// A background task cancels the token after yielding once, while the closure
/// keeps failing with a `Connection` error. The call must return `Ok`, and the
/// closure must have been invoked exactly once.
#[tokio::test(start_paused = true)]
async fn run_with_reconnect_shutdown_during_sleep_returns_ok() {
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio_util::sync::CancellationToken;

    let token = CancellationToken::new();
    let attempts = Arc::new(AtomicU32::new(0));

    // Handles for the background canceller and for the retry closure.
    let cancel_handle = token.clone();
    let attempts_in_closure = Arc::clone(&attempts);

    // Must be spawned before the call so the cancel lands while it is running.
    tokio::spawn(async move {
        tokio::task::yield_now().await;
        cancel_handle.cancel();
    });

    let outcome = run_with_reconnect(&token, "test-stream", None, || {
        attempts_in_closure.fetch_add(1, Ordering::SeqCst);
        async { Err(ShoveError::Connection("transient".into())) }
    })
    .await;

    assert!(
        outcome.is_ok(),
        "shutdown during backoff sleep must short-circuit to Ok"
    );

    let observed = attempts.load(Ordering::SeqCst);
    assert_eq!(
        observed, 1,
        "closure must not be re-invoked after cancellation"
    );
}
/// The token is already cancelled before the call. The closure must still run
/// once, and its retryable `Connection` error must be turned into `Ok` rather
/// than retried.
#[tokio::test]
async fn run_with_reconnect_shutdown_between_error_and_sleep_returns_ok() {
    use tokio_util::sync::CancellationToken;

    let token = CancellationToken::new();
    token.cancel();

    let mut attempts = 0u32;
    let outcome = run_with_reconnect(&token, "test-stream", None, || {
        attempts += 1;
        async { Err(ShoveError::Connection("transient".into())) }
    })
    .await;

    assert!(
        outcome.is_ok(),
        "cancellation observed after a retryable error must yield Ok"
    );
    assert_eq!(
        attempts, 1,
        "closure runs exactly once before the cancellation check"
    );
}
/// Of two entries in one stream, the first has an ID that is not valid UTF-8.
/// Only the second (`2-0`) is expected in the result.
#[test]
fn parse_xreadgroup_non_utf8_entry_id_skipped() {
    use redis::Value;

    /// Wraps raw bytes in a bulk string.
    fn bulk(bytes: &[u8]) -> Value {
        Value::BulkString(bytes.to_vec())
    }
    /// Builds an entry with the given raw ID and one `payload` field.
    fn entry_with_payload(raw_id: &[u8], payload: &[u8]) -> Value {
        Value::Array(vec![
            bulk(raw_id),
            Value::Array(vec![bulk(b"payload"), bulk(payload)]),
        ])
    }

    let invalid_id_entry = entry_with_payload(&[0xff, 0xfe, 0xfd], b"x");
    let valid_entry = entry_with_payload(b"2-0", b"y");
    let stream = Value::Array(vec![
        bulk(b"mystream"),
        Value::Array(vec![invalid_id_entry, valid_entry]),
    ]);

    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);

    assert_eq!(
        parsed.len(),
        1,
        "non-UTF-8 entry ID must be skipped, leaving only the good entry"
    );
    assert_eq!(parsed[0].0, "2-0");
}
/// A field key that is not valid UTF-8: the entry is still returned, but only
/// the pair that precedes the bad key is expected in its field list.
#[test]
fn parse_xreadgroup_non_utf8_field_key_breaks_loop() {
    use redis::Value;

    /// Wraps raw bytes in a bulk string.
    fn bulk(bytes: &[u8]) -> Value {
        Value::BulkString(bytes.to_vec())
    }

    let fields = Value::Array(vec![
        // Well-formed pair.
        bulk(b"good-key"),
        bulk(b"good-val"),
        // Key bytes are not valid UTF-8.
        bulk(&[0xff, 0xfe]),
        bulk(b"never-reached"),
    ]);
    let entry = Value::Array(vec![bulk(b"1-0"), fields]);
    let stream = Value::Array(vec![bulk(b"mystream"), Value::Array(vec![entry])]);

    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);

    assert_eq!(parsed.len(), 1);
    let only_entry = &parsed[0];
    assert_eq!(
        only_entry.1.len(),
        1,
        "only the pair before the bad key survives"
    );
    assert_eq!(only_entry.1[0].0, "good-key");
}
/// One stream carrying three entries must produce three results in reply
/// order, each keeping its own ID and payload value.
#[test]
fn parse_xreadgroup_multiple_entries_within_single_stream() {
    use redis::Value;

    /// Builds an entry `id` with the single field `payload = val`.
    fn payload_entry(id: &str, val: &str) -> Value {
        let fields = Value::Array(vec![
            Value::BulkString(b"payload".to_vec()),
            Value::BulkString(val.as_bytes().to_vec()),
        ]);
        Value::Array(vec![Value::BulkString(id.as_bytes().to_vec()), fields])
    }

    let expected = [("1-0", "a"), ("2-0", "b"), ("3-0", "c")];
    let entries: Vec<Value> = expected
        .iter()
        .map(|&(id, val)| payload_entry(id, val))
        .collect();
    let stream = Value::Array(vec![
        Value::BulkString(b"mystream".to_vec()),
        Value::Array(entries),
    ]);

    let parsed = parse_xreadgroup_reply(Value::Array(vec![stream]), 0);

    assert_eq!(parsed.len(), 3);
    // IDs first, then payload values, matching the original assertion order.
    for (position, &(id, _)) in expected.iter().enumerate() {
        assert_eq!(parsed[position].0, id);
    }
    for (position, &(_, val)) in expected.iter().enumerate() {
        assert_eq!(parsed[position].1[0].1, val);
    }
}
}