use std::collections::BTreeMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
use tokio::sync::oneshot;
/// Process-wide intercept-mode flag; starts out `false` (off).
static INTERCEPT_MODE: AtomicBool = AtomicBool::new(false);
/// Process-wide store, initialised lazily on the first `global_store` call.
static INTERCEPT_STORE: OnceLock<InterceptStore> = OnceLock::new();
/// Reports whether intercept mode is currently switched on.
///
/// Reads the global flag with `Acquire` ordering, matching the `Release`
/// writes done by `toggle_intercept_mode` and `set_intercept_mode`.
#[must_use]
pub fn intercept_mode_enabled() -> bool {
INTERCEPT_MODE.load(Ordering::Acquire)
}
/// Serialises intercept-mode transitions: `toggle_intercept_mode` and
/// `set_intercept_mode` hold this lock across the flag update and the drain
/// that may follow it. The `()` payload carries no state of its own.
static MODE_TRANSITION: Mutex<()> = Mutex::new(());
/// Flips intercept mode and returns the new state (`true` means it is now on).
///
/// The flip and the follow-up drain run under `MODE_TRANSITION`, so two
/// concurrent transitions cannot interleave. When the flip turns the mode
/// off, every request currently held in the global store is released.
///
/// NOTE(review): only entries present in the store when the drain runs are
/// released; a request registered afterwards is not touched by this call.
pub fn toggle_intercept_mode() -> bool {
    // A poisoned lock only means another transition panicked; the guard
    // protects no data, so recover it and continue.
    let _transition = MODE_TRANSITION
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let was_on = INTERCEPT_MODE.fetch_xor(true, Ordering::Release);
    if was_on {
        // on -> off: the number of released requests is not needed here.
        let _ = global_store().drain_release();
    }
    !was_on
}
/// Forces intercept mode to `on`.
///
/// Runs under `MODE_TRANSITION` like `toggle_intercept_mode`. The global store
/// is drained (every held request released) only on a real on-to-off
/// transition; setting the mode to the value it already has changes nothing.
pub fn set_intercept_mode(on: bool) {
    // Poisoning is ignored: the guard protects no data of its own.
    let _transition = MODE_TRANSITION
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let was_on = INTERCEPT_MODE.swap(on, Ordering::Release);
    if let (true, false) = (was_on, on) {
        // on -> off: the number of released requests is not needed here.
        let _ = global_store().drain_release();
    }
}
/// Returns the process-wide [`InterceptStore`], creating it on the first call.
///
/// Initialisation goes through the `INTERCEPT_STORE` `OnceLock`, so every
/// caller receives a reference to the same instance.
pub fn global_store() -> &'static InterceptStore {
INTERCEPT_STORE.get_or_init(InterceptStore::new)
}
/// Verdict delivered to a held request through its oneshot channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterceptDecision {
/// Sent by `InterceptStore::drain_release` to every held request.
/// Presumably lets the request proceed — confirm at the receiving side.
Release,
/// Presumably tells the waiting side to abort the request; the handling
/// lives outside this file — confirm at the receiving side.
Kill,
}
/// Descriptive record of one held request, as returned by
/// `InterceptStore::snapshot`.
#[derive(Debug, Clone)]
pub struct PendingIntercept {
/// Identifier issued by `InterceptStore::register`; never 0.
pub id: u64,
/// Host string supplied by the caller of `register`.
pub host: String,
/// Method string supplied by the caller of `register`.
pub method: String,
/// Path string supplied by the caller of `register`.
pub path: String,
/// Moment the entry was registered (`Instant::now()` at registration time).
pub since: Instant,
}
/// Handle to the shared table of held requests.
///
/// The state lives behind `Arc<Mutex<..>>`, so a clone is another handle to
/// the same table rather than an independent copy.
#[derive(Debug, Default, Clone)]
pub struct InterceptStore {
inner: Arc<Mutex<InterceptInner>>,
}
/// Mutex-protected state behind an `InterceptStore`.
#[derive(Debug, Default)]
struct InterceptInner {
/// Sending half of each held request's decision channel, keyed by id.
senders: BTreeMap<u64, oneshot::Sender<InterceptDecision>>,
/// Descriptive record for each held request, keyed by the same id.
pending: BTreeMap<u64, PendingIntercept>,
/// Most recently issued id; starts at 0 and is advanced by `register`
/// before each use.
next_id: u64,
}
/// Thirty-second duration constant.
///
/// NOTE(review): not referenced anywhere in this file; presumably the maximum
/// time a caller waits on a registered receiver — confirm at the call sites.
pub const INTERCEPT_TIMEOUT: Duration = Duration::from_secs(30);
impl InterceptStore {
pub fn new() -> Self {
Self::default()
}
pub fn register(
&self,
host: impl Into<String>,
method: impl Into<String>,
path: impl Into<String>,
) -> (u64, oneshot::Receiver<InterceptDecision>) {
let (tx, rx) = oneshot::channel();
let mut inner = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let dead: Vec<u64> = inner
.senders
.iter()
.filter(|(_, tx)| tx.is_closed())
.map(|(id, _)| *id)
.collect();
for id in dead {
inner.senders.remove(&id);
inner.pending.remove(&id);
}
inner.next_id = inner.next_id.wrapping_add(1);
if inner.next_id == 0 {
inner.next_id = 1;
}
let id = inner.next_id;
inner.senders.insert(id, tx);
inner.pending.insert(
id,
PendingIntercept {
id,
host: host.into(),
method: method.into(),
path: path.into(),
since: Instant::now(),
},
);
(id, rx)
}
pub fn gc_dead_senders(&self) -> usize {
let mut inner = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let dead: Vec<u64> = inner
.senders
.iter()
.filter(|(_, tx)| tx.is_closed())
.map(|(id, _)| *id)
.collect();
let n = dead.len();
for id in dead {
inner.senders.remove(&id);
inner.pending.remove(&id);
}
n
}
pub fn resolve(&self, id: u64, decision: InterceptDecision) -> bool {
let mut inner = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
inner.pending.remove(&id);
if let Some(tx) = inner.senders.remove(&id) {
let _ = tx.send(decision);
true
} else {
false
}
}
pub fn cancel(&self, id: u64) -> bool {
let mut inner = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let removed_pending = inner.pending.remove(&id).is_some();
let removed_sender = inner.senders.remove(&id).is_some();
removed_pending || removed_sender
}
pub fn drain_release(&self) -> usize {
let mut inner = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let ids: Vec<u64> = inner.senders.keys().copied().collect();
let mut released = 0;
for id in ids {
if let Some(tx) = inner.senders.remove(&id) {
inner.pending.remove(&id);
let _ = tx.send(InterceptDecision::Release);
released += 1;
}
}
released
}
pub fn snapshot(&self) -> Vec<PendingIntercept> {
let inner = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
inner.pending.values().cloned().collect()
}
pub fn pending_count(&self) -> usize {
let inner = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
inner.pending.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a private store so tests never share state through the global one.
    fn store() -> InterceptStore {
        InterceptStore::new()
    }

    #[tokio::test]
    async fn register_then_release_unblocks_with_release() {
        let s = store();
        let (id, rx) = s.register("h", "GET", "/");
        let s2 = s.clone();
        // Resolve from another task after a short delay so the receiver is
        // genuinely waiting when the decision arrives.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            s2.resolve(id, InterceptDecision::Release);
        });
        let decision = rx.await.expect("rx");
        assert_eq!(decision, InterceptDecision::Release);
        assert_eq!(s.pending_count(), 0, "pending must drain after resolve");
    }

    #[tokio::test]
    async fn register_then_kill_unblocks_with_kill() {
        let s = store();
        let (id, rx) = s.register("h", "POST", "/admin");
        let s2 = s.clone();
        tokio::spawn(async move {
            s2.resolve(id, InterceptDecision::Kill);
        });
        assert_eq!(rx.await.unwrap(), InterceptDecision::Kill);
    }

    #[tokio::test]
    async fn snapshot_shows_pending_until_resolved() {
        let s = store();
        // Keep both receivers alive so neither entry is swept as dead.
        let (id1, _r1) = s.register("a.com", "GET", "/x");
        let (id2, _r2) = s.register("b.com", "POST", "/y");
        let snap = s.snapshot();
        assert_eq!(snap.len(), 2);
        assert!(snap.iter().any(|p| p.id == id1 && p.host == "a.com"));
        assert!(snap.iter().any(|p| p.id == id2 && p.host == "b.com"));

        // The "until resolved" half: a resolved entry must leave the snapshot
        // while the untouched one stays.
        assert!(s.resolve(id1, InterceptDecision::Release));
        let snap = s.snapshot();
        assert_eq!(snap.len(), 1, "resolved entry must leave the snapshot");
        assert!(snap.iter().all(|p| p.id == id2 && p.host == "b.com"));
    }

    #[tokio::test]
    async fn drain_release_unblocks_every_pending() {
        let s = store();
        let (_, rx1) = s.register("a", "GET", "/");
        let (_, rx2) = s.register("b", "GET", "/");
        let n = s.drain_release();
        assert_eq!(n, 2);
        assert_eq!(rx1.await.unwrap(), InterceptDecision::Release);
        assert_eq!(rx2.await.unwrap(), InterceptDecision::Release);
        assert_eq!(s.pending_count(), 0);
    }

    #[tokio::test]
    async fn resolve_unknown_id_is_idempotent_no_op() {
        let s = store();
        let acted = s.resolve(999, InterceptDecision::Release);
        assert!(!acted, "resolve of unknown id must report it didn't fire");
    }

    #[tokio::test]
    async fn resolve_twice_only_fires_once() {
        let s = store();
        let (id, rx) = s.register("h", "GET", "/");
        assert!(s.resolve(id, InterceptDecision::Release));
        assert!(
            !s.resolve(id, InterceptDecision::Kill),
            "second resolve must no-op"
        );
        // The first decision is the one the receiver observes.
        assert_eq!(rx.await.unwrap(), InterceptDecision::Release);
    }

    #[tokio::test]
    async fn timeout_default_release_via_select() {
        // The store never resolves an entry by itself: without an explicit
        // resolve/drain the receiver stays pending, so the timeout must be
        // applied by whoever awaits the receiver.
        let s = store();
        let (_id, rx) = s.register("h", "GET", "/");
        let result = tokio::time::timeout(Duration::from_millis(50), rx).await;
        assert!(result.is_err(), "rx must NOT complete on its own");
    }

    #[tokio::test]
    async fn ids_are_monotonic_per_register() {
        let s = store();
        // The receivers are dropped immediately; the dead entries get swept by
        // the next `register`, which must not disturb id allocation.
        let (id1, _) = s.register("a", "GET", "/");
        let (id2, _) = s.register("a", "GET", "/");
        let (id3, _) = s.register("a", "GET", "/");
        assert_eq!(id2, id1 + 1);
        assert_eq!(id3, id2 + 1);
    }

    #[test]
    fn id_zero_is_reserved_and_resolve_cancel_return_false() {
        let s = store();
        assert!(!s.resolve(0, InterceptDecision::Release));
        assert!(!s.cancel(0));
        let (id, _rx) = s.register("h", "GET", "/");
        assert_eq!(id, 1, "first id must be 1 (0 is reserved)");
        assert!(!s.resolve(0, InterceptDecision::Release));
        assert!(!s.cancel(0));
    }

    #[test]
    fn id_wraparound_skips_zero() {
        let s = store();
        {
            // Fast-forward the counter so the next two registrations straddle
            // the u64 wrap-around.
            let mut inner = s.inner.lock().unwrap();
            inner.next_id = u64::MAX - 1;
        }
        let (id1, _rx1) = s.register("h", "GET", "/");
        assert_eq!(id1, u64::MAX, "pre-wraparound id must be u64::MAX");
        let (id2, _rx2) = s.register("h", "GET", "/");
        assert_eq!(id2, 1, "post-wraparound id must skip 0 and return 1");
        assert_ne!(id2, 0, "id=0 must never be issued");
    }

    #[test]
    fn cancel_removes_from_both_maps() {
        let s = store();
        let (id, _rx) = s.register("h", "GET", "/path");
        assert_eq!(s.pending_count(), 1);
        let removed = s.cancel(id);
        assert!(removed, "cancel must return true for a valid id");
        assert_eq!(s.pending_count(), 0, "cancel must drain from pending map");
        assert!(!s.cancel(id), "second cancel returns false (already gone)");
    }

    #[test]
    fn gc_dead_senders_removes_disconnected_rx() {
        let s = store();
        let (id, rx) = s.register("h", "GET", "/");
        // Dropping the receiver closes the channel, which marks the entry dead.
        drop(rx);
        let removed = s.gc_dead_senders();
        assert_eq!(removed, 1, "exactly one dead sender must be GCd");
        assert_eq!(s.pending_count(), 0);
        assert!(!s.cancel(id));
    }

    #[test]
    fn resolve_zero_id_never_matches_real_intercept() {
        let s = store();
        {
            // The next increment wraps to 0, which `register` must skip.
            let mut inner = s.inner.lock().unwrap();
            inner.next_id = u64::MAX;
        }
        let (id, _rx) = s.register("h", "GET", "/");
        assert_eq!(id, 1);
        assert!(!s.resolve(0, InterceptDecision::Kill));
        assert_eq!(
            s.pending_count(),
            1,
            "id=1 must still be pending after resolve(0)"
        );
    }
}