use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use tokio::sync::Notify;
/// How long a pre-cancel that no waiter has registered for is kept before the
/// registry's garbage collection may drop it.
pub const ORPHAN_TTL: Duration = Duration::from_secs(2 * 60);
/// Process-wide source of cancellation tokens. It starts at 1 so that `0` stays
/// reserved as the "not cancellable" token.
static NEXT_CANCEL_TOKEN: AtomicU64 = AtomicU64::new(1);
/// Shared `Notify` handed out for token `0` ("not cancellable").
///
/// The registry itself never signals it, so awaiting it stays pending. Every caller
/// receives the same `Arc`. A caller that notified it would wake other token-0 waiters,
/// or leave a permit behind for them, so it must be treated as wait-only.
fn never_firing_notify() -> Arc<Notify> {
    static SHARED: OnceLock<Arc<Notify>> = OnceLock::new();
    Arc::clone(SHARED.get_or_init(|| Arc::new(Notify::new())))
}
/// Minimum spacing between two garbage-collection sweeps of a registry.
const GC_INTERVAL: Duration = Duration::from_millis(1_000);
/// Hands out a fresh, process-wide unique cancellation token.
///
/// The counter starts at 1, so in practice `0` is never issued. That leaves it free
/// to mean "not cancellable", which is what the `token == 0` guards in `CancelRegistry`
/// rely on.
pub(crate) fn next_token() -> u64 {
    // Only uniqueness matters: no other memory is published through this counter,
    // so `Relaxed` ordering is sufficient.
    let counter = &NEXT_CANCEL_TOKEN;
    counter.fetch_add(1, Ordering::Relaxed)
}
/// Cancellation state tracked for a single token.
struct CancelEntry {
    /// `Notify` created by `register_notify`; `Some` once a waiter has registered.
    notify: Option<Arc<Notify>>,
    /// Whether `cancel` has been called for this token.
    pre_cancelled: bool,
    /// Set by `cancel` (only if not already set) and cleared by `register_notify`.
    /// Garbage collection uses it to age out entries that never received a `Notify`.
    marked_at: Option<Instant>,
}

impl CancelEntry {
    /// A blank entry: not cancelled, no waiter, no timestamp.
    fn new() -> Self {
        CancelEntry {
            notify: None,
            marked_at: None,
            pre_cancelled: false,
        }
    }
}
/// Tracks cancellation state per token, so a cancel can arrive either before or
/// after the cancellable operation registers interest in it.
pub struct CancelRegistry {
    /// All registry state, guarded by a single lock.
    entries: Mutex<RegistryInner>,
}

/// Lock-protected interior of [`CancelRegistry`].
struct RegistryInner {
    /// When garbage collection last ran; used to rate-limit it.
    last_gc: Instant,
    /// Per-token cancellation state, keyed by token.
    entries: HashMap<u64, CancelEntry>,
}
impl Default for CancelRegistry {
fn default() -> Self {
Self::new()
}
}
impl CancelRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            entries: Mutex::new(RegistryInner {
                entries: HashMap::new(),
                last_gc: Instant::now(),
            }),
        }
    }

    /// Allocates a fresh cancellation token.
    ///
    /// Tokens come from a process-wide counter, so they are unique across every
    /// registry, not just this one. No entry is created until the token is passed to
    /// [`cancel`](Self::cancel) or [`register_notify`](Self::register_notify).
    pub fn reserve_token(&self) -> u64 {
        next_token()
    }

    /// Requests cancellation of `token`.
    ///
    /// If a waiter already registered via [`register_notify`](Self::register_notify),
    /// its `Notify` is signalled with `notify_one`. That wakes one waiting task, or
    /// stores a single permit if none is waiting yet, so with several tasks awaiting
    /// the same token's `Notify` only one is woken per call.
    ///
    /// Otherwise the request is recorded as a pre-cancel, which a later
    /// `register_notify` for the same token will observe. A pre-cancel that no waiter
    /// ever claims is garbage-collected once it is older than [`ORPHAN_TTL`]. Calling
    /// this after [`release`](Self::release) is harmless: it only leaves such an
    /// orphan behind.
    ///
    /// Token `0` means "not cancellable" and is ignored.
    pub fn cancel(&self, token: u64) {
        if token == 0 {
            return;
        }
        let notify = {
            let mut inner = self.entries.lock();
            Self::maybe_gc(&mut inner);
            let entry = inner.entries.entry(token).or_insert_with(CancelEntry::new);
            entry.pre_cancelled = true;
            // Keep the first cancel time so repeated cancels don't extend an
            // orphan's lifetime.
            if entry.marked_at.is_none() {
                entry.marked_at = Some(Instant::now());
            }
            entry.notify.clone()
        };
        // Signal outside the lock to keep the critical section short. The cloned
        // `Arc` keeps the `Notify` alive even if `release` removes the entry first.
        if let Some(notify) = notify {
            notify.notify_one();
        }
    }

    /// Returns the `Notify` that is signalled when `token` is cancelled.
    ///
    /// Repeated calls for the same token return the same `Arc`. If the token was
    /// already cancelled, `notify_one` is called before returning, so the caller's
    /// next `notified().await` completes immediately.
    ///
    /// Registering clears the orphan timestamp. An entry that holds a `Notify` is
    /// never garbage-collected, so call [`release`](Self::release) when the operation
    /// ends, or the entry is kept for the registry's lifetime.
    ///
    /// Token `0` yields a process-wide shared `Notify` that the registry never
    /// signals. Treat it as wait-only, because every token-0 caller shares it.
    pub fn register_notify(&self, token: u64) -> Arc<Notify> {
        if token == 0 {
            return never_firing_notify();
        }
        let (notify, was_precancelled) = {
            let mut inner = self.entries.lock();
            Self::maybe_gc(&mut inner);
            let entry = inner.entries.entry(token).or_insert_with(CancelEntry::new);
            let notify = entry
                .notify
                .get_or_insert_with(|| Arc::new(Notify::new()))
                .clone();
            let was_precancelled = entry.pre_cancelled;
            entry.marked_at = None;
            (notify, was_precancelled)
        };
        // Arm the permit outside the lock. A concurrent `cancel` may also notify;
        // `Notify` holds at most one permit, so a double signal is harmless.
        if was_precancelled {
            notify.notify_one();
        }
        notify
    }

    /// Forgets all state for `token`. Calling it again, or for an unknown token, is
    /// a no-op.
    ///
    /// It also gives the rate-limited GC a chance to run, so orphaned pre-cancels are
    /// reclaimed even when releases are the only traffic. Token `0` is ignored.
    pub fn release(&self, token: u64) {
        if token == 0 {
            return;
        }
        let mut inner = self.entries.lock();
        inner.entries.remove(&token);
        Self::maybe_gc(&mut inner);
    }

    /// Runs [`Self::gc`] at most once per [`GC_INTERVAL`], so it is cheap to call on
    /// every mutation. It takes the lock-guarded state by `&mut`, so callers must
    /// hold the registry lock.
    fn maybe_gc(inner: &mut RegistryInner) {
        let now = Instant::now();
        if now.duration_since(inner.last_gc) < GC_INTERVAL {
            return;
        }
        inner.last_gc = now;
        Self::gc(&mut inner.entries, now);
    }

    /// Drops orphaned pre-cancel entries whose `marked_at` is at least
    /// [`ORPHAN_TTL`] before `now`.
    ///
    /// Entries holding a `Notify` belong to a registered waiter and are always kept;
    /// only `release` removes them. Entries with neither a `Notify` nor a timestamp
    /// are also kept, though no public method creates such an entry.
    fn gc(entries: &mut HashMap<u64, CancelEntry>, now: Instant) {
        entries.retain(|_, entry| {
            if entry.notify.is_some() {
                return true;
            }
            match entry.marked_at {
                Some(t) => now.duration_since(t) < ORPHAN_TTL,
                None => true,
            }
        });
    }

    /// Number of live entries: registered waiters plus not-yet-collected pre-cancels.
    pub fn len(&self) -> usize {
        self.entries.lock().entries.len()
    }

    /// Returns `true` when the registry currently holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.lock().entries.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an `Instant` that lies `ago` in the past, or `None` when the monotonic
    /// clock is too close to its origin (e.g. on a freshly booted host) to go back
    /// that far. Plain `Instant - Duration` would panic in that case.
    fn instant_ago(ago: Duration) -> Option<Instant> {
        Instant::now().checked_sub(ago)
    }

    #[test]
    fn cancel_zero_token_is_noop() {
        let reg = CancelRegistry::new();
        reg.cancel(0);
        assert_eq!(reg.len(), 0, "cancel(0) must not create an entry");
    }

    #[tokio::test]
    async fn register_zero_token_returns_never_firing_notify() {
        let reg = CancelRegistry::new();
        let notify = reg.register_notify(0);
        assert_eq!(reg.len(), 0);
        assert!(
            tokio::time::timeout(Duration::from_millis(20), notify.notified())
                .await
                .is_err(),
            "the token-0 Notify must never fire"
        );
    }

    #[test]
    fn release_zero_token_is_noop() {
        let reg = CancelRegistry::new();
        reg.release(0);
        assert_eq!(reg.len(), 0);
    }

    #[tokio::test]
    async fn cancel_then_register_pre_arms_notify() {
        let reg = CancelRegistry::new();
        let token = reg.reserve_token();
        reg.cancel(token);
        let notify = reg.register_notify(token);
        tokio::time::timeout(Duration::from_millis(100), notify.notified())
            .await
            .expect("pre-armed Notify must fire immediately");
    }

    #[tokio::test]
    async fn register_then_cancel_wakes_waiter() {
        let reg = Arc::new(CancelRegistry::new());
        let token = reg.reserve_token();
        let notify = reg.register_notify(token);
        let canceller = Arc::clone(&reg);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            canceller.cancel(token);
        });
        tokio::time::timeout(Duration::from_millis(500), notify.notified())
            .await
            .expect("register-then-cancel must wake the waiter");
    }

    #[test]
    fn release_removes_entry() {
        let reg = CancelRegistry::new();
        let token = reg.reserve_token();
        let _notify = reg.register_notify(token);
        assert_eq!(reg.len(), 1);
        reg.release(token);
        assert_eq!(reg.len(), 0);
        reg.release(token);
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn cancel_after_release_is_safe() {
        let reg = CancelRegistry::new();
        let token = reg.reserve_token();
        let _notify = reg.register_notify(token);
        reg.release(token);
        reg.cancel(token);
        // The late cancel leaves exactly one orphaned pre-cancel behind. It is
        // brand new, so GC cannot have reclaimed it yet.
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn next_token_is_monotonic_and_nonzero() {
        let a = next_token();
        let b = next_token();
        let c = next_token();
        assert!(a >= 1, "tokens start at 1, not 0");
        assert!(b > a);
        assert!(c > b);
    }

    #[test]
    fn zero_token_returns_shared_never_firing_notify() {
        let reg = CancelRegistry::new();
        let a = reg.register_notify(0);
        let b = reg.register_notify(0);
        assert!(
            Arc::ptr_eq(&a, &b),
            "both no-cancel registrations must hand back the same Arc<Notify>"
        );
        let c = never_firing_notify();
        assert!(Arc::ptr_eq(&a, &c));
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn gc_rate_limited_across_burst() {
        let Some(stale_at) = instant_ago(ORPHAN_TTL * 2) else {
            // Host clock too young to fabricate a stale timestamp; nothing to check.
            return;
        };
        let reg = CancelRegistry::new();
        {
            let mut inner = reg.entries.lock();
            inner.last_gc = Instant::now();
        }
        let stale = next_token();
        {
            let mut inner = reg.entries.lock();
            let entry = inner.entries.entry(stale).or_insert_with(CancelEntry::new);
            entry.pre_cancelled = true;
            entry.marked_at = Some(stale_at);
        }
        let _ = reg.register_notify(next_token());
        assert!(
            reg.entries.lock().entries.contains_key(&stale),
            "stale entry survives because gc is rate-limited"
        );
    }

    #[test]
    fn gc_collects_stale_orphans_but_keeps_registered_waiters() {
        let (Some(stale_at), Some(gc_due_at)) =
            (instant_ago(ORPHAN_TTL * 2), instant_ago(GC_INTERVAL * 2))
        else {
            // Host clock too young to fabricate past timestamps; nothing to check.
            return;
        };
        let reg = CancelRegistry::new();
        let waiting = next_token();
        let _notify = reg.register_notify(waiting);
        let stale = next_token();
        {
            let mut inner = reg.entries.lock();
            inner.last_gc = gc_due_at;
            let entry = inner.entries.entry(stale).or_insert_with(CancelEntry::new);
            entry.pre_cancelled = true;
            entry.marked_at = Some(stale_at);
            // Age the registered entry too: holding a Notify must pin it regardless.
            inner
                .entries
                .get_mut(&waiting)
                .expect("register_notify created the entry")
                .marked_at = Some(stale_at);
        }
        let fresh = next_token();
        reg.cancel(fresh);
        let inner = reg.entries.lock();
        assert!(
            !inner.entries.contains_key(&stale),
            "stale orphan must be collected once GC is due"
        );
        assert!(
            inner.entries.contains_key(&waiting),
            "entries with a registered Notify are never collected"
        );
        assert!(
            inner.entries.contains_key(&fresh),
            "a fresh pre-cancel must survive"
        );
    }

    #[tokio::test]
    async fn pre_arm_works_with_lock_released_notify() {
        let reg = CancelRegistry::new();
        let token = reg.reserve_token();
        reg.cancel(token);
        let notify = reg.register_notify(token);
        tokio::time::timeout(Duration::from_millis(100), notify.notified())
            .await
            .expect("pre-armed Notify fires even with lock-narrowed register");
    }
}