use std::time::{Duration, SystemTime, UNIX_EPOCH};
use bytes::Bytes;
use fred::clients::Pool;
use fred::interfaces::{KeysInterface, SortedSetsInterface, StreamsInterface};
use ruststream::runtime::RETRY_COUNT_HEADER;
use ruststream::{AckError, Headers};
use crate::convert::fields_for_publish;
use crate::envelope::{frame, unframe};
use crate::error::RedisError;
/// Maximum number of due entries fetched from the delay ZSET by a single
/// `sweep_due` call (passed as the `LIMIT` count to `ZRANGEBYSCORE`).
const SWEEP_BATCH: i64 = 128;
/// Strategy for holding messages whose redelivery has been deferred.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum DelayedRetry {
    /// Park deferred messages in a Redis sorted set, scored by their due time in
    /// milliseconds since the Unix epoch, and re-publish them to the stream once due.
    DurableZset {
        /// Redis key of the sorted set that holds scheduled entries.
        key: String,
        /// When `Some`, the sorted-set key's expiry is refreshed with `PEXPIRE` to
        /// this duration every time an entry is scheduled (sub-millisecond values
        /// are treated as 1 ms). `None` skips the `PEXPIRE` call entirely.
        ///
        /// Note that the expiry applies to the whole key: if nothing new is
        /// scheduled within `ttl`, still-pending entries disappear with it.
        ttl: Option<Duration>,
    },
}
/// Crate-internal, resolved form of a [`DelayedRetry`] strategy.
#[derive(Debug, Clone)]
pub(crate) struct DelayConfig {
    /// Redis key of the sorted set holding scheduled entries.
    zset_key: String,
    /// Optional expiry re-applied to `zset_key` after each schedule.
    ttl: Option<Duration>,
}

impl DelayConfig {
    /// Builds the internal configuration from the public retry strategy.
    pub(crate) fn from_retry(retry: &DelayedRetry) -> Self {
        // Exhaustive match so a new strategy variant forces a decision here.
        let (zset_key, ttl) = match retry {
            DelayedRetry::DurableZset { key, ttl } => (key.to_owned(), *ttl),
        };
        Self { zset_key, ttl }
    }
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch, and saturates at
/// `u64::MAX` if the millisecond count does not fit in a `u64`.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
/// Whole milliseconds in `delay`, saturating at `u64::MAX`.
///
/// Sub-millisecond remainders are truncated (e.g. 999 µs becomes 0).
fn delay_millis(delay: Duration) -> u64 {
    let millis: u128 = delay.as_millis();
    millis.try_into().unwrap_or(u64::MAX)
}
/// Converts an epoch-millisecond timestamp into a sorted-set score.
///
/// Values up to 2^53 convert exactly. Larger inputs, such as a fire time that
/// saturated at `u64::MAX`, are rounded to the nearest representable `f64`.
#[allow(
    clippy::cast_precision_loss,
    reason = "epoch-ms < 2^53 is exact in f64"
)]
fn as_score(ms: u64) -> f64 {
    ms as f64
}
/// Converts a TTL into the millisecond argument for `PEXPIRE`.
///
/// The result is clamped to at least 1 ms, because a zero TTL would expire the
/// key immediately. Values beyond `i64::MAX` ms saturate.
fn ttl_millis(ttl: Duration) -> i64 {
    match i64::try_from(ttl.as_millis()) {
        Ok(0) => 1,
        Ok(millis) => millis,
        Err(_) => i64::MAX,
    }
}
/// Serializes a ZSET member as `[u32 BE id length][id bytes][framed body]`.
///
/// The id is there only to keep otherwise-identical payloads from collapsing into
/// one sorted-set member. `decode_member` skips it when reading the entry back.
fn encode_member(id: &str, payload: &[u8], headers: &Headers) -> Vec<u8> {
    let framed = frame(None, payload, headers);
    let id_bytes = id.as_bytes();
    // Length prefix saturates rather than panicking on absurdly long ids.
    let prefix = u32::try_from(id_bytes.len())
        .unwrap_or(u32::MAX)
        .to_be_bytes();

    let mut out = Vec::with_capacity(prefix.len() + id_bytes.len() + framed.len());
    out.extend_from_slice(&prefix);
    out.extend_from_slice(id_bytes);
    out.extend_from_slice(&framed);
    out
}
/// Parses a member produced by `encode_member`, returning its payload and headers.
///
/// The salted id is skipped. Returns `None` if the 4-byte length prefix is
/// missing or the declared id length runs past the end of `member`.
fn decode_member(member: &[u8]) -> Option<(Bytes, Headers)> {
    let (len_prefix, rest) = member.split_first_chunk::<4>()?;
    let id_len = u32::from_be_bytes(*len_prefix) as usize;
    let framed = rest.get(id_len..)?;
    Some(unframe(None, framed))
}
/// Returns the retry count to stamp on the next delivery attempt.
///
/// Reads `RETRY_COUNT_HEADER` from `headers`. A missing or non-numeric value
/// counts as 0 prior attempts, so the result is 1.
///
/// The increment saturates at `u64::MAX`. The header arrives with the message, so
/// a corrupted or hostile value of `u64::MAX` must not overflow. A plain `+ 1`
/// would panic in debug builds and wrap to 0 in release, resetting the count.
fn next_retry_count(headers: &Headers) -> u64 {
    headers
        .get_str(RETRY_COUNT_HEADER)
        .and_then(|raw| raw.parse::<u64>().ok())
        .unwrap_or(0)
        .saturating_add(1)
}
/// Wraps a `fred` client error as a broker-level acknowledgement error.
fn broker_err(err: fred::error::Error) -> AckError {
    let boxed = Box::new(err);
    AckError::Broker(boxed)
}
/// Defers a message by storing it in the delay ZSET, scored at "now + `delay`"
/// in epoch milliseconds, so a later `sweep_due` re-publishes it.
///
/// The stored headers are a copy of `headers` with `RETRY_COUNT_HEADER` set to
/// the next retry count. `id` is embedded only to keep equal payloads distinct
/// as ZSET members; it is not part of what gets re-published.
///
/// If `cfg.ttl` is set, the ZSET key's expiry is refreshed after the insert.
/// NOTE(review): `ZADD` and `PEXPIRE` are separate round-trips. If the expiry
/// call fails, the entry is already stored but `Err` is still returned. Also, an
/// entry whose delay exceeds the TTL can vanish with the key before it is due if
/// nothing else is scheduled in the meantime. Confirm this is intended.
///
/// # Errors
/// Returns `AckError::Broker` if either Redis command fails.
pub(crate) async fn schedule(
    pool: &Pool,
    cfg: &DelayConfig,
    id: &str,
    payload: &[u8],
    headers: &Headers,
    delay: Duration,
) -> Result<(), AckError> {
    let key = cfg.zset_key.as_str();
    let score = as_score(now_ms().saturating_add(delay_millis(delay)));

    let mut stamped = headers.clone();
    stamped.insert(RETRY_COUNT_HEADER, next_retry_count(headers).to_string());
    let member = encode_member(id, payload, &stamped);

    let _added: i64 = pool
        .zadd(key, None, None, false, false, (score, member))
        .await
        .map_err(broker_err)?;

    let Some(ttl) = cfg.ttl else {
        return Ok(());
    };
    let _expiry_set: i64 = pool
        .pexpire(key, ttl_millis(ttl), None)
        .await
        .map_err(broker_err)?;
    Ok(())
}
pub(crate) async fn sweep_due(
pool: &Pool,
cfg: &DelayConfig,
stream_key: &str,
) -> Result<(), RedisError> {
let now = as_score(now_ms());
let due: Vec<Bytes> = pool
.zrangebyscore(
cfg.zset_key.as_str(),
0.0,
now,
false,
Some((0, SWEEP_BATCH)),
)
.await
.map_err(RedisError::stream)?;
for member in due {
let removed: i64 = pool
.zrem(cfg.zset_key.as_str(), member.clone())
.await
.map_err(RedisError::stream)?;
if removed != 1 {
continue;
}
let Some((payload, headers)) = decode_member(&member) else {
continue;
};
let fields = fields_for_publish(&payload, &headers);
let _: String = pool
.xadd(stream_key, false, None::<()>, "*", fields)
.await
.map_err(RedisError::stream)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding must preserve the payload and every header.
    #[test]
    fn member_round_trips_payload_and_headers() {
        let mut original = Headers::new();
        original.insert("content-type", "application/json");
        original.insert(RETRY_COUNT_HEADER, "2");

        let encoded = encode_member("1700000000000-0", b"{}", &original);
        let (body, restored) = decode_member(&encoded).expect("decodes");

        assert_eq!(body.as_ref(), b"{}");
        assert_eq!(restored.content_type(), Some("application/json"));
        assert_eq!(restored.get_str(RETRY_COUNT_HEADER), Some("2"));
    }

    /// The id prefix is what keeps identical payloads distinct inside the ZSET.
    #[test]
    fn distinct_ids_yield_distinct_members_for_equal_payloads() {
        let empty = Headers::new();
        let first = encode_member("1-0", b"dup", &empty);
        let second = encode_member("2-0", b"dup", &empty);
        assert_ne!(
            first, second,
            "the id salt must keep equal payloads from colliding in the ZSET"
        );
    }

    /// Missing or garbage counters restart at 1; numeric ones are incremented.
    #[test]
    fn next_retry_count_starts_at_one_and_increments() {
        let mut hdrs = Headers::new();
        let initial = next_retry_count(&hdrs);
        assert_eq!(initial, 1);

        hdrs.insert(RETRY_COUNT_HEADER, "4");
        let bumped = next_retry_count(&hdrs);
        assert_eq!(bumped, 5);

        hdrs.insert(RETRY_COUNT_HEADER, "not-a-number");
        let reset = next_retry_count(&hdrs);
        assert_eq!(reset, 1);
    }

    /// Whole-second TTLs convert exactly; a zero TTL is bumped to 1 ms.
    #[test]
    fn ttl_millis_clamps_sub_millisecond_to_one() {
        let thirty_seconds = Duration::from_secs(30);
        assert_eq!(ttl_millis(thirty_seconds), 30_000);
        assert_eq!(ttl_millis(Duration::ZERO), 1);
    }
}