use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;
/// Information captured by the first successful cancellation of a call.
#[derive(Clone, Debug)]
pub struct CancellationDetails {
    /// Reason string passed to the `cancel` call that won.
    pub reason: String,
    /// Flag passed to the `cancel` call that won; how it is acted upon is up to
    /// whoever reads these details (not visible in this module).
    pub inject_reminder: bool,
}
/// State shared (via `Arc`) by every clone of a [`Handle`].
#[derive(Debug)]
struct Inner {
    /// Flipped to `true` by the first successful `Handle::cancel`; never reset.
    cancelled: AtomicBool,
    /// Flipped to `true` by `Handle::mark_completed`; never reset.
    completed: AtomicBool,
    /// Details recorded by the successful cancel; `None` until then.
    details: Mutex<Option<CancellationDetails>>,
    /// Wakes tasks parked in `Handle::cancelled`.
    notify: Notify,
    /// Wakes tasks parked in `Handle::completed`.
    completion_notify: Notify,
}
/// Handle to one registered tool call.
///
/// Every clone shares the same cancellation/completion state through an `Arc`,
/// so cancelling through any clone is observed by all of them.
#[derive(Clone, Debug)]
pub struct Handle {
    /// Session the call was registered under.
    pub session_id: String,
    /// Call identifier the call was registered under (never empty).
    pub call_id: String,
    /// Tool name recorded at registration; echoed back in `CancelOutcome::tool_name`.
    pub tool_name: String,
    inner: Arc<Inner>,
}
impl Handle {
    /// Builds a handle in its initial state: not cancelled, not completed, no details.
    fn new(session_id: String, call_id: String, tool_name: String) -> Self {
        Self {
            session_id,
            call_id,
            tool_name,
            inner: Arc::new(Inner {
                cancelled: AtomicBool::new(false),
                completed: AtomicBool::new(false),
                details: Mutex::new(None),
                notify: Notify::new(),
                completion_notify: Notify::new(),
            }),
        }
    }

    /// Returns `true` once [`Handle::cancel`] has succeeded on this handle or any clone of it.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancelled.load(Ordering::SeqCst)
    }

    /// Returns `true` once the call has been marked completed.
    pub fn is_completed(&self) -> bool {
        self.inner.completed.load(Ordering::SeqCst)
    }

    /// Marks the call completed and wakes every task awaiting [`Handle::completed`].
    ///
    /// Only the first call flips the flag and notifies; later calls are no-ops.
    fn mark_completed(&self) {
        if !self.inner.completed.swap(true, Ordering::SeqCst) {
            self.inner.completion_notify.notify_waiters();
        }
    }

    /// Requests cancellation of this call.
    ///
    /// Returns `true` if this invocation performed the cancellation. Returns `false` if
    /// the handle was already cancelled. In that case the details from the first
    /// cancellation are kept and `reason` / `inject_reminder` are discarded.
    ///
    /// The details are stored *before* the cancelled flag becomes visible. A caller that
    /// observes `is_cancelled() == true` therefore always finds `details()` populated.
    pub fn cancel(&self, reason: impl Into<String>, inject_reminder: bool) -> bool {
        let reason = reason.into();
        let mut details = self
            .inner
            .details
            .lock()
            .unwrap_or_else(|err| err.into_inner());
        // Check-and-set happens under the details lock, which serializes concurrent
        // cancels: exactly one of them sees `false` here and wins.
        if self.inner.cancelled.load(Ordering::SeqCst) {
            return false;
        }
        *details = Some(CancellationDetails {
            reason,
            inject_reminder,
        });
        // Publish the flag only after the details are written, and while still holding
        // the lock. A reader that sees the flag and then locks `details` cannot find `None`.
        self.inner.cancelled.store(true, Ordering::SeqCst);
        drop(details);
        // Notify outside the lock; woken waiters observe the flag already set.
        self.inner.notify.notify_waiters();
        true
    }

    /// Returns a copy of the details recorded by the successful cancel.
    ///
    /// Returns `None` if the handle has not been cancelled. A poisoned lock is
    /// recovered rather than propagated.
    pub fn details(&self) -> Option<CancellationDetails> {
        self.inner
            .details
            .lock()
            .unwrap_or_else(|err| err.into_inner())
            .clone()
    }

    /// Returns the `reason` from [`Handle::details`], if the handle has been cancelled.
    pub fn reason(&self) -> Option<String> {
        self.details().map(|d| d.reason)
    }

    /// Resolves once the handle is cancelled; returns immediately if it already is.
    pub async fn cancelled(&self) {
        if self.is_cancelled() {
            return;
        }
        // Create the `Notified` future *before* re-checking the flag. tokio guarantees it
        // receives `notify_waiters` wakeups from the moment it is created, so a cancel that
        // lands between the re-check and the `.await` is not lost.
        let notified = self.inner.notify.notified();
        if self.is_cancelled() {
            return;
        }
        notified.await;
    }

    /// Resolves once the call is marked completed; returns immediately if it already is.
    pub async fn completed(&self) {
        if self.is_completed() {
            return;
        }
        // Same create-then-recheck pattern as `cancelled` to avoid a lost wakeup.
        let notified = self.inner.completion_notify.notified();
        if self.is_completed() {
            return;
        }
        notified.await;
    }
}
thread_local! {
    /// Map from `(session_id, call_id)` to the live handle for that call.
    ///
    /// This is a thread-local, so entries are visible only on the thread that inserted them.
    /// NOTE(review): on a multi-threaded runtime, a task that registers, migrates and later
    /// looks up (or drops its guard) on another thread will not see its entry. Confirm that
    /// callers only use the registry from a single thread.
    static REGISTRY: RefCell<HashMap<(String, String), Handle>> = RefCell::default();
}
/// Registration token returned by [`register`].
///
/// Dropping it marks the call as completed and removes its registry entry. Keep it
/// alive for as long as the call should stay registered and cancellable.
#[must_use = "dropping the guard immediately marks the call completed and unregisters it"]
pub struct Guard {
    session_id: String,
    call_id: String,
    handle: Handle,
}
impl Drop for Guard {
fn drop(&mut self) {
if self.call_id.is_empty() {
return;
}
self.handle.mark_completed();
REGISTRY.with(|registry| {
registry
.borrow_mut()
.remove(&(self.session_id.clone(), self.call_id.clone()));
});
}
}
/// Registers a cancellable tool call under `(session_id, call_id)` in this thread's registry.
///
/// Returns `None` if `call_id` is empty. Otherwise it returns the [`Handle`] plus a
/// [`Guard`] that marks the call completed and unregisters it when dropped. Any existing
/// entry under the same key is replaced.
pub fn register(
    session_id: impl Into<String>,
    call_id: impl Into<String>,
    tool_name: impl Into<String>,
) -> Option<(Handle, Guard)> {
    let session_id: String = session_id.into();
    let call_id: String = call_id.into();
    let tool_name: String = tool_name.into();
    if call_id.is_empty() {
        return None;
    }

    let key = (session_id, call_id);
    let handle = Handle::new(key.0.clone(), key.1.clone(), tool_name);
    let guard = Guard {
        session_id: key.0.clone(),
        call_id: key.1.clone(),
        handle: handle.clone(),
    };
    let for_registry = handle.clone();
    REGISTRY.with(move |registry| {
        registry.borrow_mut().insert(key, for_registry);
    });

    Some((handle, guard))
}
/// Returns a clone of the handle registered under `(session_id, call_id)` on this thread.
///
/// Returns `None` if no such registration exists.
pub fn lookup(session_id: &str, call_id: &str) -> Option<Handle> {
    let key = (session_id.to_owned(), call_id.to_owned());
    REGISTRY.with(|registry| registry.borrow().get(&key).cloned())
}
/// Collects clones of every handle on this thread registered under `session_id`.
///
/// The order follows the registry's hash-map iteration order, i.e. it is unspecified.
pub fn list_for_session(session_id: &str) -> Vec<Handle> {
    REGISTRY.with(|registry| {
        let map = registry.borrow();
        let mut matching = Vec::new();
        for handle in map.values() {
            if handle.session_id == session_id {
                matching.push(handle.clone());
            }
        }
        matching
    })
}
/// What a call to the free function `cancel` did.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CancelStatus {
    /// The call was found and this request cancelled it.
    Cancelled,
    /// The call was found but had already been cancelled.
    AlreadyCancelled,
    /// No registration matched the session/call pair.
    NotFound,
}
impl CancelStatus {
    /// Fixed snake_case label for the status.
    pub fn as_str(self) -> &'static str {
        match self {
            CancelStatus::NotFound => "not_found",
            CancelStatus::AlreadyCancelled => "already_cancelled",
            CancelStatus::Cancelled => "cancelled",
        }
    }
}
/// Result of the free function `cancel`: what happened, plus the matched call if there was one.
#[derive(Clone, Debug)]
pub struct CancelOutcome {
    /// Newly cancelled, already cancelled, or not found.
    pub status: CancelStatus,
    /// Tool name of the matched call; `None` when `status` is `NotFound`.
    pub tool_name: Option<String>,
    /// The matched handle; `None` when `status` is `NotFound`.
    pub handle: Option<Handle>,
}
/// Looks up `(session_id, call_id)` on this thread and requests cancellation of it.
///
/// Returns `NotFound` when nothing is registered under that key. Otherwise returns
/// `Cancelled` if this request performed the cancellation, or `AlreadyCancelled` if an
/// earlier request did (the earlier reason is kept).
pub fn cancel(
    session_id: &str,
    call_id: &str,
    reason: impl Into<String>,
    inject_reminder: bool,
) -> CancelOutcome {
    match lookup(session_id, call_id) {
        None => CancelOutcome {
            status: CancelStatus::NotFound,
            tool_name: None,
            handle: None,
        },
        Some(handle) => {
            let newly_cancelled = handle.cancel(reason, inject_reminder);
            CancelOutcome {
                status: if newly_cancelled {
                    CancelStatus::Cancelled
                } else {
                    CancelStatus::AlreadyCancelled
                },
                tool_name: Some(handle.tool_name.clone()),
                handle: Some(handle),
            }
        }
    }
}
/// Empties the calling thread's registry so a test starts from a clean slate.
#[cfg(test)]
pub fn clear_registry_for_test() {
    REGISTRY.with(|registry| {
        let mut map = registry.borrow_mut();
        map.clear();
    });
}
// Unit tests for the cancellation registry. Each test starts by clearing the calling
// thread's registry so leftover entries cannot affect its assertions.
#[cfg(test)]
mod tests {
    use super::*;
    // Cancelling a handle resolves its `cancelled()` future and records the reason.
    #[tokio::test]
    async fn handle_resolves_cancelled_future() {
        clear_registry_for_test();
        let (handle, _guard) = register("sess_1", "call_1", "bash").expect("registered");
        assert!(!handle.is_cancelled());
        let cancelled = handle.cancel("user requested stop", false);
        assert!(cancelled);
        handle.cancelled().await;
        assert!(handle.is_cancelled());
        assert_eq!(handle.reason().as_deref(), Some("user requested stop"));
    }
    // Cancelling an unknown (session, call) pair reports NotFound with no tool name.
    #[tokio::test]
    async fn cancel_returns_not_found_when_missing() {
        clear_registry_for_test();
        let outcome = cancel("sess_unknown", "call_unknown", "irrelevant", false);
        assert_eq!(outcome.status, CancelStatus::NotFound);
        assert_eq!(outcome.tool_name, None);
    }
    // A second cancel of the same call reports AlreadyCancelled; both report the tool name.
    #[tokio::test]
    async fn cancel_is_idempotent() {
        clear_registry_for_test();
        let (_handle, _guard) = register("sess", "call_2", "shell").expect("registered");
        let first = cancel("sess", "call_2", "first", false);
        let second = cancel("sess", "call_2", "second", false);
        assert_eq!(first.status, CancelStatus::Cancelled);
        assert_eq!(second.status, CancelStatus::AlreadyCancelled);
        assert_eq!(first.tool_name.as_deref(), Some("shell"));
        assert_eq!(second.tool_name.as_deref(), Some("shell"));
    }
    // The registration is visible while the guard lives and gone once it is dropped
    // at the end of the inner scope.
    #[tokio::test]
    async fn guard_unregisters_on_drop() {
        clear_registry_for_test();
        {
            let _registration = register("sess", "call_g", "tool").expect("registered");
            assert!(lookup("sess", "call_g").is_some());
        }
        assert!(lookup("sess", "call_g").is_none());
    }
    // A task already parked in `cancelled()` is woken by a later `cancel`.
    // `spawn_local` requires a `LocalSet`, driven here on a current-thread runtime.
    #[test]
    fn cancelled_wakes_pending_waiter() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("rt");
        let local = tokio::task::LocalSet::new();
        local.block_on(&rt, async {
            clear_registry_for_test();
            let (handle, _guard) = register("sess", "call_w", "tool").expect("registered");
            let waiter_handle = handle.clone();
            let task = tokio::task::spawn_local(async move { waiter_handle.cancelled().await });
            // Yield so the spawned waiter runs up to its `.await` and parks before the
            // cancel fires; this exercises the wake-up path, not the early-return path.
            tokio::task::yield_now().await;
            handle.cancel("stopping", false);
            // A lost wakeup would hang; the timeout turns that into a test failure.
            tokio::time::timeout(std::time::Duration::from_secs(1), task)
                .await
                .expect("task should resolve quickly")
                .expect("task did not panic");
        });
    }
    // Only handles registered under the requested session are listed, even when the
    // same call id exists in another session. Sorting makes the hash-map order irrelevant.
    #[tokio::test]
    async fn list_for_session_filters_by_session() {
        clear_registry_for_test();
        let _r1 = register("sess_a", "call_x", "tool").expect("registered");
        let _r2 = register("sess_a", "call_y", "tool").expect("registered");
        let _r3 = register("sess_b", "call_x", "tool").expect("registered");
        let mut ids: Vec<String> = list_for_session("sess_a")
            .into_iter()
            .map(|h| h.call_id)
            .collect();
        ids.sort();
        assert_eq!(ids, vec!["call_x".to_string(), "call_y".to_string()]);
    }
}