use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use tokio::sync::oneshot;
use tokio::task::{AbortHandle, JoinHandle};
use tokio_util::sync::CancellationToken;
use turul_a2a_types::Task;
use crate::server::obs::TARGET_SUPERVISOR_PANIC;
/// Registry key: a `(tenant, task_id)` pair identifying one in-flight task.
pub type InFlightKey = (String, String);
/// State tracked for one spawned task; stored in [`InFlightRegistry`] behind an `Arc`.
///
/// Every mutable field sits behind a `Mutex` or an atomic, so all methods take `&self`.
pub struct InFlightHandle {
/// Cancellation token for the task. Within this file it is cancelled by `poll_once`
/// when the supervisor lists the task as cancel-requested.
pub cancellation: CancellationToken,
/// One-shot sender for a `Task`; taken (left as `None`) by the winning
/// `fire_yielded` call.
yielded_tx: Mutex<Option<oneshot::Sender<Task>>>,
/// Flipped from `false` to `true` by the first `fire_yielded` call and never reset.
yielded_fired: AtomicBool,
/// Join handle of the spawned task; `take_spawned` moves it out, `set_spawned` replaces it.
spawned: Mutex<Option<JoinHandle<()>>>,
/// Abort handle obtained from the most recently stored join handle. `abort` uses this
/// field only, so it does not depend on `spawned` still being present.
abort_handle: Mutex<AbortHandle>,
}
impl InFlightHandle {
    /// Builds a handle around an already-spawned task.
    ///
    /// The abort handle is derived from `spawned` before the join handle itself is
    /// stored, and the "yielded" flag starts out unset.
    pub fn new(
        cancellation: CancellationToken,
        yielded_tx: oneshot::Sender<Task>,
        spawned: JoinHandle<()>,
    ) -> Self {
        let abort_handle = Mutex::new(spawned.abort_handle());
        Self {
            cancellation,
            yielded_tx: Mutex::new(Some(yielded_tx)),
            yielded_fired: AtomicBool::new(false),
            spawned: Mutex::new(Some(spawned)),
            abort_handle,
        }
    }

    /// Calls `abort()` on the currently stored abort handle.
    ///
    /// # Panics
    ///
    /// Panics if the `abort_handle` mutex is poisoned.
    pub fn abort(&self) {
        self.abort_handle
            .lock()
            .expect("abort_handle Mutex poisoned")
            .abort();
    }

    /// Delivers `task` through the one-shot sender, at most once per handle.
    ///
    /// Returns `true` for the single call that flips the "yielded" flag and `false`
    /// for every other call (in which case `task` is simply dropped). The result of
    /// the send itself is ignored, so `true` does not imply the receiver got the value.
    ///
    /// # Panics
    ///
    /// Panics if the `yielded_tx` mutex is poisoned.
    pub fn fire_yielded(&self, task: Task) -> bool {
        let won_race = self
            .yielded_fired
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_ok();
        if !won_race {
            return false;
        }

        // The lock is only held long enough to move the sender out of its slot.
        let pending = self
            .yielded_tx
            .lock()
            .expect("yielded_tx Mutex poisoned")
            .take();
        if let Some(tx) = pending {
            let _ = tx.send(task);
        }
        true
    }

    /// Reports whether `fire_yielded` has already succeeded on this handle.
    pub fn yielded_fired(&self) -> bool {
        self.yielded_fired.load(Ordering::Acquire)
    }

    /// Replaces the stored join handle and refreshes the abort handle to match it.
    ///
    /// The two slots are updated one after the other, each under its own lock; any
    /// previously stored join handle is dropped.
    ///
    /// # Panics
    ///
    /// Panics if either the `spawned` or the `abort_handle` mutex is poisoned.
    pub fn set_spawned(&self, spawned: JoinHandle<()>) {
        let replacement_abort = spawned.abort_handle();
        {
            let mut join_slot = self.spawned.lock().expect("spawned Mutex poisoned");
            *join_slot = Some(spawned);
        }
        let mut abort_slot = self
            .abort_handle
            .lock()
            .expect("abort_handle Mutex poisoned");
        *abort_slot = replacement_abort;
    }

    /// Moves the join handle out, leaving `None` behind; later calls return `None`
    /// until `set_spawned` stores a new one.
    ///
    /// # Panics
    ///
    /// Panics if the `spawned` mutex is poisoned.
    pub fn take_spawned(&self) -> Option<JoinHandle<()>> {
        let mut join_slot = self.spawned.lock().expect("spawned Mutex poisoned");
        join_slot.take()
    }
}
/// Map from [`InFlightKey`] to the shared [`InFlightHandle`] of the task registered
/// under that key, guarded by a single `RwLock`.
#[derive(Default)]
pub struct InFlightRegistry {
/// Registered handles, keyed by `(tenant, task_id)`.
map: RwLock<HashMap<InFlightKey, Arc<InFlightHandle>>>,
}
impl InFlightRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn try_insert(
&self,
key: InFlightKey,
handle: Arc<InFlightHandle>,
) -> Result<(), InsertCollision> {
use std::collections::hash_map::Entry;
let mut map = self.map.write().expect("InFlightRegistry RwLock poisoned");
match map.entry(key) {
Entry::Occupied(occ) => Err(InsertCollision {
key: occ.key().clone(),
existing: occ.get().clone(),
}),
Entry::Vacant(vac) => {
vac.insert(handle);
Ok(())
}
}
}
pub fn get(&self, key: &InFlightKey) -> Option<Arc<InFlightHandle>> {
self.map
.read()
.expect("InFlightRegistry RwLock poisoned")
.get(key)
.cloned()
}
pub fn remove_if_current(&self, key: &InFlightKey, handle: &Arc<InFlightHandle>) -> bool {
let mut map = self.map.write().expect("InFlightRegistry RwLock poisoned");
match map.get(key) {
Some(existing) if Arc::ptr_eq(existing, handle) => {
map.remove(key);
true
}
_ => false,
}
}
pub fn len(&self) -> usize {
self.map
.read()
.expect("InFlightRegistry RwLock poisoned")
.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub(crate) fn snapshot_by_tenant(&self) -> std::collections::HashMap<String, Vec<String>> {
let map = self.map.read().expect("InFlightRegistry RwLock poisoned");
let mut out: std::collections::HashMap<String, Vec<String>> =
std::collections::HashMap::new();
for (tenant, task_id) in map.keys() {
out.entry(tenant.clone()).or_default().push(task_id.clone());
}
out
}
pub(crate) fn get_handle(&self, key: &InFlightKey) -> Option<Arc<InFlightHandle>> {
self.get(key)
}
}
/// Error returned by `InFlightRegistry::try_insert` when the key is already occupied.
pub struct InsertCollision {
/// The key that was already present in the registry.
pub key: InFlightKey,
/// The handle that is registered under `key`.
pub existing: Arc<InFlightHandle>,
}
impl fmt::Debug for InsertCollision {
    /// Prints the key and the raw pointer of the existing handle rather than the
    /// handle's contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let existing_ptr = Arc::as_ptr(&self.existing);
        let mut builder = f.debug_struct("InsertCollision");
        builder.field("key", &self.key);
        builder.field("existing_ptr", &existing_ptr);
        builder.finish()
    }
}
impl fmt::Display for InsertCollision {
    /// Writes a one-line message naming the tenant and task id of the colliding key.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tenant = &self.key.0;
        let task_id = &self.key.1;
        f.write_fmt(format_args!(
            "in-flight registry collision for tenant={tenant}, task_id={task_id}; \
             a spawn for this key is already active"
        ))
    }
}
// Empty impl (no methods overridden): lets `InsertCollision` be used wherever a
// `std::error::Error` is expected.
impl std::error::Error for InsertCollision {}
/// Guard whose `Drop` impl aborts the tracked task and removes its registry entry
/// (only if that entry still points at the same handle).
pub struct SupervisorSentinel {
/// Registry the entry is removed from on drop.
registry: Arc<InFlightRegistry>,
/// Key of the entry to remove; also used for the tenant/task id log fields.
key: InFlightKey,
/// Handle that is aborted on drop and compared by pointer against the registry entry.
handle: Arc<InFlightHandle>,
}
impl SupervisorSentinel {
pub fn new(
registry: Arc<InFlightRegistry>,
key: InFlightKey,
handle: Arc<InFlightHandle>,
) -> Self {
Self {
registry,
key,
handle,
}
}
}
impl Drop for SupervisorSentinel {
    /// Aborts the tracked task, removes its registry entry if it is still the
    /// current one, and logs an error when the drop happens during a panic.
    fn drop(&mut self) {
        // Sample the panic state up front, before any cleanup runs.
        let unwinding = std::thread::panicking();

        self.handle.abort();
        let _ = self.registry.remove_if_current(&self.key, &self.handle);

        if !unwinding {
            return;
        }
        let tenant = &self.key.0;
        let task_id = &self.key.1;
        tracing::error!(
            target: TARGET_SUPERVISOR_PANIC,
            tenant = %tenant,
            task_id = %task_id,
            "supervisor task panicked; cleanup ran via SupervisorSentinel Drop"
        );
    }
}
/// Runs the cross-instance cancel poll loop until `shutdown` is cancelled.
///
/// Each iteration waits for `interval` and then runs one `poll_once` pass over
/// `registry`, so the first poll happens one full interval after the call. If
/// `shutdown` is cancelled while the loop is waiting, a debug event is logged and
/// the function returns. Shutdown is only checked during the wait, not while a
/// poll pass is in progress.
pub async fn run_cross_instance_cancel_poller(
    registry: std::sync::Arc<InFlightRegistry>,
    supervisor: std::sync::Arc<dyn crate::storage::A2aCancellationSupervisor>,
    interval: std::time::Duration,
    shutdown: tokio_util::sync::CancellationToken,
) {
    loop {
        tokio::select! {
            _ = tokio::time::sleep(interval) => {}
            _ = shutdown.cancelled() => {
                tracing::debug!(
                    target: "turul_a2a::cross_instance_cancel_poll",
                    "cross-instance cancel poller shutting down"
                );
                return;
            }
        }
        // `&registry` is `&Arc<InFlightRegistry>` and deref-coerces to the
        // `&InFlightRegistry` that `poll_once` expects.
        poll_once(&registry, supervisor.as_ref()).await;
    }
}
/// Public entry point that runs exactly one `poll_once` pass; per its name it is
/// meant to be driven from tests.
pub async fn poll_once_for_tests(
    registry: &InFlightRegistry,
    supervisor: &dyn crate::storage::A2aCancellationSupervisor,
) {
    poll_once(registry, supervisor).await;
}
/// Runs a single cancel-poll pass over every tenant currently in `registry`.
///
/// For each tenant in the snapshot, asks `supervisor` which of the tenant's
/// registered task ids are marked cancel-requested and cancels the
/// `cancellation` token of every such task that is still registered. A failing
/// supervisor call is logged as a warning and only skips that tenant.
async fn poll_once(
    registry: &InFlightRegistry,
    supervisor: &dyn crate::storage::A2aCancellationSupervisor,
) {
    // The snapshot is an owned copy, so no registry lock is held across the awaits.
    for (tenant, task_ids) in registry.snapshot_by_tenant() {
        let lookup = supervisor
            .supervisor_list_cancel_requested(&tenant, &task_ids)
            .await;
        let marked = match lookup {
            Ok(marked) => marked,
            Err(poll_err) => {
                tracing::warn!(
                    target: "turul_a2a::cross_instance_cancel_poll_error",
                    tenant = %tenant,
                    error = %poll_err,
                    "cross-instance cancel poll failed; will retry next tick"
                );
                continue;
            }
        };

        for task_id in marked {
            let lookup_key = (tenant.clone(), task_id);
            // The task may have left the registry since the snapshot was taken.
            if let Some(handle) = registry.get_handle(&lookup_key) {
                handle.cancellation.cancel();
            }
        }
    }
}