use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use bytes::Bytes;
use fred::clients::Pool;
use fred::interfaces::{KeysInterface, ListInterface, SortedSetsInterface};
use ruststream::AckError;
use crate::error::RedisError;
/// Maximum number of due claims `sweep_orphans` fetches (and processes) per call.
const SWEEP_BATCH: i64 = 128;
/// Length of the per-claim salt prefixed to every sorted-set member: a
/// big-endian `u32` process id followed by a big-endian `u64` claim counter.
/// Derived from those field sizes rather than hard-coded so it stays in step
/// with the bytes `claim_member` actually writes.
const SALT_LEN: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u64>();
/// Settings for recording in-flight claims in a Redis sorted set, so that
/// claims which are never forgotten (presumably because their consumer died)
/// can later be moved from the processing list back onto the main list by
/// `sweep_orphans`.
#[derive(Debug, Clone)]
pub(crate) struct RecoveryConfig {
/// Key of the sorted set that holds one salted member per outstanding claim,
/// scored by the claim time in epoch milliseconds.
pub(crate) zset_key: String,
/// How long a claim must have been recorded before a sweep treats it as
/// orphaned; members scored at or before `now - min_idle` are swept.
pub(crate) min_idle: Duration,
/// When set, `record_claim` re-applies this expiry (PEXPIRE) to `zset_key`
/// after every claim it records; `None` leaves the key without an expiry.
pub(crate) ttl: Option<Duration>,
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`; an (absurdly) far-future clock
/// saturates at `u64::MAX`.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis().try_into().unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
/// Converts an epoch-millisecond timestamp into a sorted-set score.
///
/// Sorted-set scores are `f64`; the cast is exact while `ms` stays below
/// 2^53, which holds for epoch-millisecond timestamps.
#[allow(
clippy::cast_precision_loss,
reason = "epoch-ms < 2^53 is exact in f64"
)]
fn as_score(ms: u64) -> f64 {
ms as f64
}
/// Whole milliseconds in `d`, saturating at `u64::MAX` when they do not fit.
fn millis(d: Duration) -> u64 {
    d.as_millis().try_into().unwrap_or(u64::MAX)
}
/// Converts a TTL into the signed millisecond count PEXPIRE expects.
///
/// Saturates at `i64::MAX`. Anything under one millisecond (including zero)
/// is raised to `1`, so the key gets a tiny positive expiry.
fn ttl_millis(ttl: Duration) -> i64 {
    let whole_ms: i64 = ttl.as_millis().try_into().unwrap_or(i64::MAX);
    std::cmp::max(whole_ms, 1)
}
/// Builds the sorted-set member recorded for one claim of `value`.
///
/// Layout: big-endian process id (`u32`), then a big-endian per-process claim
/// counter (`u64`), then the raw value bytes. The prefix keeps repeated claims
/// of an equal value from the same process distinct inside the sorted set.
fn claim_member(value: &[u8]) -> Vec<u8> {
    static CLAIM_SEQ: AtomicU64 = AtomicU64::new(0);
    let pid = std::process::id().to_be_bytes();
    // Relaxed is enough: only the uniqueness of each fetch_add result matters.
    let seq = CLAIM_SEQ.fetch_add(1, Ordering::Relaxed).to_be_bytes();
    [&pid[..], &seq[..], value].concat()
}
fn value_from_member(member: &[u8]) -> Option<&[u8]> {
member.get(SALT_LEN..)
}
/// Wraps a `fred` client error as an [`AckError::Broker`].
fn broker_err(err: fred::error::Error) -> AckError {
    // The box is coerced to the variant's payload type at the call below.
    let source = Box::new(err);
    AckError::Broker(source)
}
/// Records a claim of `value` in the recovery sorted set.
///
/// Adds a freshly salted member (see `claim_member`) to `cfg.zset_key`, scored
/// with the current epoch-millisecond time. When `cfg.ttl` is set, it then
/// re-applies that expiry to the whole key. Returns the member that was added.
/// Presumably the caller keeps it and passes it to `forget` once the claim is
/// settled (confirm against callers).
///
/// # Errors
///
/// Returns the `RedisError` produced by `RedisError::stream` if ZADD or PEXPIRE
/// fails. If only PEXPIRE fails, the member has already been added.
pub(crate) async fn record_claim(
    pool: &Pool,
    cfg: &RecoveryConfig,
    value: &[u8],
) -> Result<Vec<u8>, RedisError> {
    let member = claim_member(value);
    let key = cfg.zset_key.as_str();
    let scored = (as_score(now_ms()), member.clone());
    let _: i64 = pool
        .zadd(key, None, None, false, false, scored)
        .await
        .map_err(RedisError::stream)?;
    let Some(ttl) = cfg.ttl else {
        return Ok(member);
    };
    let _: i64 = pool
        .pexpire(key, ttl_millis(ttl), None)
        .await
        .map_err(RedisError::stream)?;
    Ok(member)
}
/// Removes a claim `member` from the recovery sorted set stored at `zset_key`.
///
/// Succeeds whether or not the member was present; the ZREM count is discarded.
///
/// # Errors
///
/// Returns [`AckError::Broker`] wrapping the client error if the ZREM fails.
pub(crate) async fn forget(pool: &Pool, zset_key: &str, member: &[u8]) -> Result<(), AckError> {
    pool.zrem(zset_key, member.to_vec())
        .await
        .map(|_removed: i64| ())
        .map_err(broker_err)
}
/// Moves claims that have sat in the recovery set for at least `cfg.min_idle`
/// from `processing_key` back onto `main_key`.
///
/// Fetches up to `SWEEP_BATCH` members of `cfg.zset_key` whose score lies in
/// `[0, now - min_idle]`. For each one:
/// * the value after the salt is LREM'd (count 1) from `processing_key`;
/// * if exactly one element was removed, the value is LPUSHed onto `main_key`;
/// * the member is then ZREM'd from the sorted set. This happens even when the
///   member is too short to carry a value, or when the value was no longer
///   present in the processing list.
///
/// # Errors
///
/// Stops at the first failing Redis call and returns the `RedisError` from
/// `RedisError::stream`. Members not yet reached stay in the set for a later
/// sweep.
///
/// NOTE(review): the LREM → LPUSH pair is not atomic. If LPUSH fails after
/// LREM removed the value, this function leaves the value in neither list.
/// A later sweep then finds nothing to LREM and just drops the member.
pub(crate) async fn sweep_orphans(
    pool: &Pool,
    cfg: &RecoveryConfig,
    main_key: &str,
    processing_key: &str,
) -> Result<(), RedisError> {
    let zset_key = cfg.zset_key.as_str();
    let idle_ms = millis(cfg.min_idle);
    let cutoff = as_score(now_ms().saturating_sub(idle_ms));
    let due: Vec<Bytes> = pool
        .zrangebyscore(zset_key, 0.0, cutoff, false, Some((0, SWEEP_BATCH)))
        .await
        .map_err(RedisError::stream)?;
    for member in due {
        if let Some(value) = value_from_member(&member).map(<[u8]>::to_vec) {
            let removed: i64 = pool
                .lrem(processing_key, 1, value.clone())
                .await
                .map_err(RedisError::stream)?;
            // Re-queue only if this sweep actually took the value out of processing.
            if removed == 1 {
                let _: i64 = pool
                    .lpush(main_key, value)
                    .await
                    .map_err(RedisError::stream)?;
            }
        }
        let _: i64 = pool
            .zrem(zset_key, member.to_vec())
            .await
            .map_err(RedisError::stream)?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn member_carries_value_after_the_salt() {
        let member = claim_member(b"job-payload");
        assert_eq!(member.len(), SALT_LEN + b"job-payload".len());
        assert_eq!(value_from_member(&member), Some(b"job-payload".as_slice()));
    }
    #[test]
    fn member_salt_starts_with_the_process_id() {
        let member = claim_member(b"x");
        assert_eq!(&member[..4], &std::process::id().to_be_bytes()[..]);
    }
    #[test]
    fn equal_values_get_distinct_members() {
        let a = claim_member(b"dup");
        let b = claim_member(b"dup");
        assert_ne!(
            a, b,
            "the per-claim salt must keep equal values from colliding"
        );
        assert_eq!(value_from_member(&a), value_from_member(&b));
    }
    #[test]
    fn short_member_has_no_value() {
        assert_eq!(value_from_member(&[0u8; SALT_LEN - 1]), None);
    }
    #[test]
    fn salt_only_member_has_empty_value() {
        assert_eq!(value_from_member(&[0u8; SALT_LEN]), Some(&b""[..]));
    }
    #[test]
    fn ttl_millis_clamps_sub_millisecond_to_one() {
        assert_eq!(ttl_millis(Duration::from_secs(30)), 30_000);
        assert_eq!(ttl_millis(Duration::ZERO), 1);
        // A positive but sub-millisecond TTL must also clamp, not truncate to 0.
        assert_eq!(ttl_millis(Duration::from_micros(500)), 1);
    }
    #[test]
    fn ttl_millis_saturates_at_i64_max() {
        assert_eq!(ttl_millis(Duration::MAX), i64::MAX);
    }
}