use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use super::client::RedisClient;
use super::reaper::spawn_maintenance;
use crate::consumer::DEFAULT_HANDLER_TIMEOUT;
/// Registry key. `acquire` builds it as `(client.instance_id(), stream name, group name)`.
type Key = (usize, String, String);
/// Handler timeout contributed by one guard. `None` means the guard has no timeout, which
/// `sidecar_timing` turns into "no min-idle threshold".
type Policy = Option<Duration>;
/// Callback that starts a maintenance task for a cancellation token and an effective policy.
/// It is a shared trait object, so the unit tests can pass a recording stub instead of the
/// real spawner built in `acquire`.
type Spawner = Arc<dyn Fn(CancellationToken, Policy) + Send + Sync>;
/// Registry record for one `Key`: the policies of all live guards plus the handle needed to
/// stop and restart the task spawned for that key.
struct Entry {
/// One element per live guard; pushed in `acquire_with`, removed when the guard drops.
policies: Vec<Policy>,
/// Token handed to the most recently spawned task; cancelled on respawn and when the last
/// guard goes away.
shutdown: CancellationToken,
/// Callback used to (re)start the task with a fresh token and the current effective policy.
spawner: Spawner,
}
impl Entry {
    /// Folds the policies of all live guards into the one policy the task should run with.
    ///
    /// A single guard without a timeout (`None`) forces the result to `None`; otherwise the
    /// longest timeout wins. An empty policy list also yields `None`.
    fn effective(&self) -> Policy {
        let mut longest: Policy = None;
        for policy in &self.policies {
            // `?` leaves with `None` as soon as one guard has no timeout.
            let timeout = (*policy)?;
            longest = match longest {
                Some(current) if current >= timeout => Some(current),
                _ => Some(timeout),
            };
        }
        longest
    }

    /// Cancels the currently running task and starts a replacement under a fresh token,
    /// passing it the current effective policy.
    fn respawn(&mut self) {
        let retired = std::mem::replace(&mut self.shutdown, CancellationToken::new());
        retired.cancel();
        (self.spawner)(self.shutdown.clone(), self.effective());
    }
}
/// Process-wide table of live maintenance entries, created lazily on first access.
static REGISTRY: OnceLock<Mutex<HashMap<Key, Entry>>> = OnceLock::new();

/// Returns the shared registry, initialising it to an empty map on the first call.
fn registry() -> &'static Mutex<HashMap<Key, Entry>> {
    REGISTRY.get_or_init(Mutex::default)
}
/// Locks the registry and returns the guard.
///
/// A poisoned mutex is not treated as fatal: the guard is recovered from the poison error,
/// so a panic in one lock holder does not make the registry unusable for everyone else.
fn lock() -> std::sync::MutexGuard<'static, HashMap<Key, Entry>> {
    match registry().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Handle returned by `acquire`; it keeps one policy registered under `key` for as long as it
/// is alive.
///
/// Dropping the guard removes that policy from the registry entry. When it was the last
/// policy, the entry's cancellation token is cancelled and the entry is removed; when the
/// effective policy changes, the task is respawned. Discarding the guard right away
/// therefore undoes the acquisition, which is why the type is `#[must_use]`.
#[must_use = "dropping the guard immediately releases its hold on the maintenance task"]
pub(super) struct MaintenanceGuard {
    /// Registry key this guard is registered under.
    key: Key,
    /// The policy this guard contributed; one matching element is removed on drop.
    policy: Policy,
}
impl Drop for MaintenanceGuard {
    /// Withdraws this guard's policy from its registry entry.
    ///
    /// * No entry for the key: nothing to do.
    /// * This was the last policy: cancel the task's token and remove the entry.
    /// * Other guards remain and the effective policy changed: respawn the task.
    /// * Otherwise the running task is left untouched.
    fn drop(&mut self) {
        let mut table = lock();
        let Some(entry) = table.get_mut(&self.key) else {
            return;
        };

        // Snapshot the policy before removal so a change can be detected afterwards.
        let previous = entry.effective();
        let own_slot = entry.policies.iter().position(|held| *held == self.policy);
        if let Some(index) = own_slot {
            // Order within `policies` is irrelevant, so the O(1) removal is fine.
            entry.policies.swap_remove(index);
        }

        if entry.policies.is_empty() {
            entry.shutdown.cancel();
            table.remove(&self.key);
            return;
        }

        if entry.effective() != previous {
            entry.respawn();
        }
    }
}
/// Translates an effective handler timeout into `(interval, min-idle threshold in ms)`, the
/// two timing values `acquire` forwards to `spawn_maintenance`.
///
/// * `Some(timeout)`: the threshold is the timeout in whole milliseconds, and the interval is
///   that same value floored at 30 seconds.
/// * `None`: the interval is `DEFAULT_HANDLER_TIMEOUT` and there is no threshold.
///
/// Timeouts too large to express as `u64` milliseconds saturate at `u64::MAX` rather than
/// wrapping around to a small (possibly zero) threshold.
fn sidecar_timing(handler_timeout: Option<Duration>) -> (Duration, Option<u64>) {
    /// Lower bound for the interval, in milliseconds.
    const MIN_INTERVAL_MS: u64 = 30_000;

    match handler_timeout {
        Some(timeout) => {
            // `as_millis` yields `u128`; clamp first so the narrowing cast is lossless.
            let min_idle_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
            let interval = Duration::from_millis(min_idle_ms.max(MIN_INTERVAL_MS));
            (interval, Some(min_idle_ms))
        }
        None => (DEFAULT_HANDLER_TIMEOUT, None),
    }
}
pub(super) fn acquire(
client: &RedisClient,
stream: &str,
handler_timeout: Option<Duration>,
) -> MaintenanceGuard {
let key = (
client.instance_id(),
stream.to_owned(),
client.group().to_owned(),
);
let client = client.clone();
let stream = stream.to_owned();
let spawner: Spawner = Arc::new(move |shutdown, policy| {
let (interval, min_idle_ms) = sidecar_timing(policy);
spawn_maintenance(
client.clone(),
vec![stream.clone()],
client.group().to_owned(),
interval,
min_idle_ms,
shutdown,
);
});
acquire_with(key, handler_timeout, spawner)
}
/// Core of `acquire`, with the spawner injected so it can be exercised without a real client.
///
/// * Key already registered: `policy` is appended to the entry and the task is respawned
///   only if the effective policy changed. The entry keeps the spawner it was created with;
///   the one passed here is dropped.
/// * Key not registered: `spawner` is invoked once with a fresh token and `policy`, and a
///   new entry holding that token and spawner is inserted.
///
/// The registry lock is held for the whole call, including the spawner invocation.
fn acquire_with(key: Key, policy: Policy, spawner: Spawner) -> MaintenanceGuard {
    let mut table = lock();

    if let Some(existing) = table.get_mut(&key) {
        let previous = existing.effective();
        existing.policies.push(policy);
        let policy_changed = existing.effective() != previous;
        if policy_changed {
            existing.respawn();
        }
    } else {
        let shutdown = CancellationToken::new();
        spawner(shutdown.clone(), policy);
        let fresh = Entry {
            policies: vec![policy],
            shutdown,
            spawner,
        };
        table.insert(key.clone(), fresh);
    }

    MaintenanceGuard { key, policy }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the timing helper and the registry's reference-counting behaviour.
    //!
    //! The registry is a process-wide static and tests may run in parallel, so every test
    //! uses its own key(s) via `key(n)`.

    use super::*;
    use std::sync::Mutex as StdMutex;

    /// Every `(token, policy)` pair the recording spawner has been invoked with, in order.
    type SpawnLog = Arc<StdMutex<Vec<(CancellationToken, Policy)>>>;

    /// Builds a registry key for test case `n`; counting down from `usize::MAX` keeps the
    /// instance id component away from small values.
    fn key(n: usize) -> Key {
        (usize::MAX - n, format!("stream-{n}"), format!("group-{n}"))
    }

    /// Returns a spawner that records its arguments instead of starting a task, together
    /// with the log it writes to.
    fn recorder() -> (SpawnLog, Spawner) {
        let log: SpawnLog = Arc::new(StdMutex::new(Vec::new()));
        let l = Arc::clone(&log);
        let spawner: Spawner = Arc::new(move |token, policy| {
            l.lock().unwrap().push((token, policy));
        });
        (log, spawner)
    }

    #[test]
    fn timing_with_timeout_floors_interval_and_sets_min_idle() {
        assert_eq!(
            sidecar_timing(Some(Duration::from_secs(5))),
            (Duration::from_secs(30), Some(5_000)),
            "short timeouts keep the 30s sweep floor but gate reclaim at the timeout"
        );
        assert_eq!(
            sidecar_timing(Some(Duration::from_secs(45))),
            (Duration::from_secs(45), Some(45_000)),
            "long timeouts stretch the sweep interval to match"
        );
    }

    #[test]
    fn timing_without_timeout_disables_reclaim() {
        // Note: this pins `DEFAULT_HANDLER_TIMEOUT` to 30 seconds.
        assert_eq!(
            sidecar_timing(None),
            (Duration::from_secs(30), None),
            "no handler deadline means no XAUTOCLAIM deadline"
        );
    }

    #[test]
    fn second_acquire_for_same_key_does_not_spawn() {
        let (log, spawner) = recorder();
        let g1 = acquire_with(key(1), Some(Duration::from_secs(30)), spawner.clone());
        let g2 = acquire_with(key(1), Some(Duration::from_secs(30)), spawner);
        assert_eq!(log.lock().unwrap().len(), 1, "one sidecar per key");
        drop(g1);
        drop(g2);
    }

    #[test]
    fn distinct_keys_spawn_independently() {
        let (log, spawner) = recorder();
        let g1 = acquire_with(key(2), Some(Duration::from_secs(30)), spawner.clone());
        let g2 = acquire_with(key(3), Some(Duration::from_secs(30)), spawner);
        assert_eq!(log.lock().unwrap().len(), 2);
        drop(g1);
        drop(g2);
    }

    #[test]
    fn last_guard_cancels_and_next_acquire_respawns() {
        let (log, spawner) = recorder();
        let g1 = acquire_with(key(4), Some(Duration::from_secs(30)), spawner.clone());
        let g2 = acquire_with(key(4), Some(Duration::from_secs(30)), spawner.clone());
        assert_eq!(log.lock().unwrap().len(), 1);
        drop(g1);
        assert!(
            !log.lock().unwrap()[0].0.is_cancelled(),
            "sidecar must survive while a guard remains"
        );
        drop(g2);
        assert!(
            log.lock().unwrap()[0].0.is_cancelled(),
            "dropping the last guard must cancel the sidecar"
        );
        let g3 = acquire_with(key(4), Some(Duration::from_secs(30)), spawner);
        assert_eq!(
            log.lock().unwrap().len(),
            2,
            "fresh acquire after teardown respawns"
        );
        drop(g3);
        assert!(log.lock().unwrap()[1].0.is_cancelled());
    }

    #[test]
    fn no_timeout_guard_downgrades_sidecar_to_trim_only() {
        let (log, spawner) = recorder();
        let g1 = acquire_with(key(5), Some(Duration::from_secs(30)), spawner.clone());
        assert_eq!(
            log.lock().unwrap().last().unwrap().1,
            Some(Duration::from_secs(30))
        );
        let g2 = acquire_with(key(5), None, spawner);
        {
            let entries = log.lock().unwrap();
            assert_eq!(entries.len(), 2, "policy change must respawn the sidecar");
            assert!(entries[0].0.is_cancelled(), "old sidecar must be cancelled");
            assert_eq!(entries[1].1, None, "effective policy must be trim-only");
        }
        drop(g2);
        {
            let entries = log.lock().unwrap();
            assert_eq!(entries.len(), 3);
            assert!(entries[1].0.is_cancelled());
            assert_eq!(entries[2].1, Some(Duration::from_secs(30)));
        }
        drop(g1);
        assert!(log.lock().unwrap()[2].0.is_cancelled());
    }

    #[test]
    fn longest_timeout_wins_among_timeout_guards() {
        let (log, spawner) = recorder();
        let g1 = acquire_with(key(6), Some(Duration::from_secs(30)), spawner.clone());
        let g2 = acquire_with(key(6), Some(Duration::from_secs(120)), spawner.clone());
        {
            let entries = log.lock().unwrap();
            assert_eq!(entries.len(), 2, "longer timeout must respawn the sidecar");
            assert_eq!(entries[1].1, Some(Duration::from_secs(120)));
        }
        let g3 = acquire_with(key(6), Some(Duration::from_secs(5)), spawner);
        assert_eq!(
            log.lock().unwrap().len(),
            2,
            "shorter timeout must not respawn or lower the deadline"
        );
        drop(g2);
        {
            let entries = log.lock().unwrap();
            assert_eq!(entries.len(), 3, "dropping the 120s guard recomputes");
            assert_eq!(entries[2].1, Some(Duration::from_secs(30)));
        }
        drop(g1);
        drop(g3);
        {
            let entries = log.lock().unwrap();
            assert!(entries[2].0.is_cancelled());
            // Dropping g1 respawns at 5s, so entry 2 is cancelled by that respawn alone;
            // only the newest token proves that the final guard's drop tore the sidecar down.
            assert!(
                entries.last().unwrap().0.is_cancelled(),
                "dropping the last guard must cancel the most recent sidecar"
            );
        }
    }
}