use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use super::client::{RedisClient, RedisConnection};
use super::constants::{REQUEUE_BATCH_SIZE, REQUEUE_POLL_MS};
use super::topology::RedisTopologyDeclarer;
use crate::error::{Result, ShoveError};
use crate::metrics::{BackendErrorKind, BackendLabel, record_backend_error};
use crate::retry::Backoff;
/// Pause between two polling passes of the requeuer task spawned by
/// `spawn_requeuer`; built from `REQUEUE_POLL_MS`, read as milliseconds.
const POLL_INTERVAL: Duration = Duration::from_millis(REQUEUE_POLL_MS);
/// A message parked in a hold sorted set until it is due for redelivery.
///
/// `enqueue_hold` serializes the entry to JSON and stores that string as the
/// sorted-set member; `poll_hold_set` parses it back and replays it with `XADD`.
/// The serialized form is also what `ZREM` is called with, so the JSON string
/// itself is the identity of the entry inside the set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct HoldEntry {
/// Key of the stream the entry is `XADD`ed to once it is due.
pub stream: String,
/// Field/value pairs appended to the `XADD` command, in this order.
pub fields: Vec<(String, String)>,
}
/// Parks `entry` in the hold set belonging to `hold_queue_name`.
///
/// The entry is stored as its JSON encoding, scored with the wall-clock
/// millisecond timestamp (`now_ms() + delay`) at which it becomes due. A delay
/// whose millisecond count does not fit in `u64` is clamped to `u64::MAX`, and
/// the addition saturates instead of wrapping.
///
/// # Errors
///
/// Returns an error if the entry cannot be serialized or the `ZADD` fails.
pub(crate) async fn enqueue_hold(
    conn: &mut RedisConnection,
    hold_queue_name: &str,
    entry: HoldEntry,
    delay: Duration,
) -> Result<()> {
    let hold_set = RedisTopologyDeclarer::hold_set_name(hold_queue_name);

    let wait_ms = match u64::try_from(delay.as_millis()) {
        Ok(ms) => ms,
        Err(_) => u64::MAX,
    };
    let due_at_ms = now_ms().saturating_add(wait_ms);
    let member = serde_json::to_string(&entry)?;

    // Sorted-set scores are doubles, hence the cast of the timestamp.
    let mut zadd = redis::cmd("ZADD");
    zadd.arg(hold_set).arg(due_at_ms as f64).arg(&member);
    conn.query::<i64>(&mut zadd).await?;
    Ok(())
}
/// Spawns the background task that moves due hold entries back onto their
/// streams.
///
/// The task polls every hold set derived from `hold_queue_names` once per
/// [`POLL_INTERVAL`]. When polling a set fails, the rest of the pass is
/// abandoned and a fresh connection is acquired (with backoff) before the next
/// pass. The task ends when `shutdown` is cancelled, either while waiting for
/// the next pass or while trying to (re)connect.
///
/// Returns the handle of the spawned task.
pub(crate) fn spawn_requeuer(
    client: RedisClient,
    hold_queue_names: Vec<String>,
    shutdown: CancellationToken,
) -> tokio::task::JoinHandle<()> {
    let hold_set_keys: Vec<String> = hold_queue_names
        .iter()
        .map(|n| RedisTopologyDeclarer::hold_set_name(n))
        .collect();
    tokio::spawn(async move {
        let mut conn = match acquire_conn_with_retry(&client, &shutdown).await {
            Some(c) => c,
            None => return,
        };
        loop {
            let mut needs_reconnect = false;
            for (hold_queue_name, set_key) in hold_queue_names.iter().zip(hold_set_keys.iter()) {
                if let Err(e) = poll_hold_set(&mut conn, hold_queue_name, set_key).await {
                    tracing::warn!("requeuer: poll failed for {}: {}", hold_queue_name, e);
                    needs_reconnect = true;
                    break;
                }
            }
            if needs_reconnect {
                match acquire_conn_with_retry(&client, &shutdown).await {
                    Some(c) => conn = c,
                    None => break,
                }
            }
            // Always pass through this wait, even right after a reconnect.
            // Polling can keep failing while connecting keeps succeeding
            // (any ZRANGE error is reported as a poll failure); jumping
            // straight back to the top would then spin without pause and
            // never observe `shutdown`, because the reconnect helper only
            // checks the token after a failed connection attempt.
            tokio::select! {
                _ = shutdown.cancelled() => break,
                _ = tokio::time::sleep(POLL_INTERVAL) => {}
            }
        }
    })
}
/// Keeps trying to open a multiplexed connection until one succeeds or
/// `shutdown` is cancelled.
///
/// Failed attempts are spaced out by the delays produced by
/// `Backoff::default()`. Returns `None` if cancellation is observed after a
/// failed attempt or while waiting for the next one.
///
/// # Panics
///
/// Panics if the backoff iterator ever runs out of delays.
async fn acquire_conn_with_retry(
    client: &RedisClient,
    shutdown: &CancellationToken,
) -> Option<RedisConnection> {
    let mut backoff = Backoff::default();
    loop {
        let failure = match client.multiplexed_conn().await {
            Ok(conn) => return Some(conn),
            Err(err) => err,
        };
        if shutdown.is_cancelled() {
            return None;
        }

        let wait = backoff.next().expect("backoff is infinite");
        tracing::warn!(
            "requeuer: connection failed ({}), retrying in {:.1}s",
            failure,
            wait.as_secs_f64()
        );

        // Sleep out the backoff delay, but give up at once on cancellation.
        tokio::select! {
            _ = tokio::time::sleep(wait) => {}
            _ = shutdown.cancelled() => return None,
        }
    }
}
/// Current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reads earlier than the epoch and
/// `u64::MAX` when the millisecond count does not fit in a `u64`.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let millis = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis(),
        Err(_) => 0,
    };
    match u64::try_from(millis) {
        Ok(ms) => ms,
        Err(_) => u64::MAX,
    }
}
/// Runs one polling pass over a single hold set.
///
/// Fetches up to `REQUEUE_BATCH_SIZE` members of `set_key` whose score is at
/// most the current millisecond timestamp, then for each member:
///
/// * if it does not parse as a [`HoldEntry`], it is logged and removed from
///   the set;
/// * otherwise it is appended to its target stream with `XADD` and, only if
///   that succeeded, removed from the set with `ZREM`.
///
/// Failures on individual members are logged and do not abort the pass. A
/// member whose `XADD` fails stays in the set and is picked up again by a
/// later pass.
///
/// `hold_queue_name` is used for log messages only.
///
/// # Errors
///
/// Returns `ShoveError::Connection` when the initial `ZRANGE` fails; no other
/// failure is propagated.
async fn poll_hold_set(
    conn: &mut RedisConnection,
    hold_queue_name: &str,
    set_key: &str,
) -> Result<()> {
    let now = now_ms();
    let entries: Vec<String> = match conn
        .query(
            redis::cmd("ZRANGE")
                .arg(set_key)
                .arg(0f64)
                .arg(now as f64)
                .arg("BYSCORE")
                .arg("LIMIT")
                .arg(0i64)
                .arg(REQUEUE_BATCH_SIZE),
        )
        .await
    {
        Ok(entries) => entries,
        Err(e) => {
            tracing::error!(
                hold_queue = hold_queue_name,
                set_key = set_key,
                error = %e,
                "requeuer: ZRANGE failed, cannot poll hold set"
            );
            return Err(ShoveError::Connection(format!(
                "ZRANGE failed for hold set '{set_key}': {e}"
            )));
        }
    };
    for raw_json in entries {
        let entry: HoldEntry = match serde_json::from_str(&raw_json) {
            Ok(e) => e,
            Err(e) => {
                tracing::warn!(
                    "requeuer: corrupt hold entry in {} (removing): {}",
                    hold_queue_name,
                    e
                );
                // Removal stays best-effort, but a failure is logged rather
                // than dropped: otherwise the same corrupt member is reported
                // on every pass with no hint as to why it never goes away.
                if let Err(zrem_err) = conn
                    .query::<i64>(redis::cmd("ZREM").arg(set_key).arg(&raw_json))
                    .await
                {
                    tracing::warn!(
                        "requeuer: ZREM failed for corrupt entry in {}: {}",
                        hold_queue_name,
                        zrem_err
                    );
                }
                continue;
            }
        };
        let mut cmd = redis::cmd("XADD");
        cmd.arg(&entry.stream).arg("*");
        for (k, v) in &entry.fields {
            cmd.arg(k).arg(v);
        }
        match conn.query::<String>(&mut cmd).await {
            Ok(_) => {
                // The member is removed only after the stream append went
                // through; if this ZREM fails the member is still in the set
                // and will be appended again on a later pass.
                if let Err(e) = conn
                    .query::<i64>(redis::cmd("ZREM").arg(set_key).arg(&raw_json))
                    .await
                {
                    tracing::warn!(
                        "requeuer: ZREM failed for entry in {}: {}",
                        hold_queue_name,
                        e
                    );
                    record_backend_error(BackendLabel::Redis, BackendErrorKind::Ack);
                }
            }
            Err(e) => {
                tracing::warn!("requeuer: XADD failed for stream {}: {}", entry.stream, e);
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ShoveError;

    /// Encodes `entry` as JSON and decodes it again, panicking on serde errors.
    fn roundtrip(entry: &HoldEntry) -> HoldEntry {
        let encoded = serde_json::to_string(entry).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Builds an owned field pair from two string slices.
    fn pair(key: &str, value: &str) -> (String, String) {
        (key.to_owned(), value.to_owned())
    }

    #[test]
    fn hold_entry_roundtrips() {
        use super::super::constants::{PAYLOAD_FIELD, X_RETRY_COUNT};

        let fields: Vec<(String, String)> = vec![
            (PAYLOAD_FIELD.into(), "{}".into()),
            (X_RETRY_COUNT.into(), "1".into()),
        ];
        let decoded = roundtrip(&HoldEntry {
            stream: "orders".into(),
            fields: fields.clone(),
        });

        assert_eq!(decoded.stream, "orders");
        assert_eq!(decoded.fields[0], fields[0]);
        assert_eq!(decoded.fields[1], fields[1]);
    }

    #[test]
    fn now_ms_is_nonzero() {
        let ts = now_ms();
        assert!(ts > 0);
    }

    #[test]
    fn hold_entry_with_empty_fields_roundtrips() {
        let decoded = roundtrip(&HoldEntry {
            stream: "my-stream".into(),
            fields: Vec::new(),
        });

        assert_eq!(decoded.stream, "my-stream");
        assert!(decoded.fields.is_empty());
    }

    #[test]
    fn hold_entry_fields_order_preserved() {
        let expected = vec![pair("alpha", "1"), pair("beta", "2"), pair("gamma", "3")];
        let decoded = roundtrip(&HoldEntry {
            stream: "order-stream".into(),
            fields: expected.clone(),
        });

        // Indexing (not zipping) so a shortened field list still fails the test.
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(&decoded.fields[idx], want);
        }
    }

    #[test]
    fn now_ms_is_positive_and_recent() {
        let before = now_ms();
        let after = now_ms();

        assert!(after >= before);
        // 1_577_836_800_000 ms is 2020-01-01T00:00:00Z.
        assert!(
            before > 1_577_836_800_000u64,
            "timestamp too small: {before}"
        );
    }

    #[test]
    fn zrange_error_is_connection_variant() {
        let set_key = "orders-hold-5s:pending";
        let message = format!("ZRANGE failed for hold set '{set_key}': connection refused");
        let err = ShoveError::Connection(message);

        assert!(
            matches!(err, ShoveError::Connection(_)),
            "ZRANGE error must be ShoveError::Connection"
        );
    }

    #[test]
    fn zrange_error_message_contains_set_key_and_cause() {
        let (set_key, cause) = ("orders-hold-5s:pending", "connection timed out");
        let msg = format!("ZRANGE failed for hold set '{set_key}': {cause}");

        assert!(
            msg.contains(set_key),
            "error message must name the hold set; got: {msg}"
        );
        assert!(
            msg.contains(cause),
            "error message must preserve the original error; got: {msg}"
        );
    }

    #[test]
    fn backoff_is_infinite_for_retry_loop() {
        let produced = Backoff::default().take(500).count();

        assert_eq!(
            produced,
            500,
            "Backoff must never return None; the .expect() in acquire_conn_with_retry would panic"
        );
    }

    #[test]
    fn backoff_default_delay_stays_within_bounds() {
        let lower = std::time::Duration::from_millis(500);
        let upper = std::time::Duration::from_millis(45_000);

        Backoff::default().take(50).for_each(|delay| {
            assert!(
                delay >= lower,
                "delay {delay:?} is below the minimum expected bound"
            );
            assert!(
                delay <= upper,
                "delay {delay:?} exceeds the maximum expected bound"
            );
        });
    }
}