use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use serde_json::Value;
use tokio::sync::Notify;
use super::result::LoopToolResult;
use super::types::ToolExecutionMode;
/// Cloneable cancellation/interjection handle handed to tool executions.
///
/// Every clone points at the same shared state, so calling [`AbortSignal::cancel`]
/// or [`AbortSignal::interject`] through any clone is observed by all of them.
#[derive(Debug, Clone, Default)]
pub struct AbortSignal {
    /// Becomes `true` once `cancel` has been called on any clone.
    cancelled: Arc<AtomicBool>,
    /// Becomes `true` once `interject` has been called on any clone.
    interjected: Arc<AtomicBool>,
    /// Wakes tasks parked in [`AbortSignal::cancelled`] when `cancel` fires.
    notify: Arc<Notify>,
}

impl AbortSignal {
    /// Creates a fresh signal that is neither cancelled nor interjected.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the cancelled flag, then wakes every task currently waiting in
    /// [`AbortSignal::cancelled`].
    pub fn cancel(&self) {
        let flag = &self.cancelled;
        flag.store(true, Ordering::SeqCst);
        self.notify.notify_waiters();
    }

    /// Returns `true` once any clone of this signal has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        let flag = &self.cancelled;
        flag.load(Ordering::SeqCst)
    }

    /// Resolves once the signal is cancelled; returns right away if it already is.
    ///
    /// The waiter is registered with the `Notify` *before* the flag is read, so a
    /// `cancel()` that lands between the read and the await still wakes us.
    pub async fn cancelled(&self) {
        let wakeup = self.notify.notified();
        tokio::pin!(wakeup);
        // Register interest first; only after that is it safe to inspect the flag.
        wakeup.as_mut().enable();
        if !self.is_cancelled() {
            wakeup.await;
        }
    }

    /// Records an interjection request.
    ///
    /// Unlike [`AbortSignal::cancel`], this only sets a flag: it does not wake
    /// tasks waiting in [`AbortSignal::cancelled`].
    pub fn interject(&self) {
        let flag = &self.interjected;
        flag.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once any clone of this signal has been interjected.
    pub fn is_interjected(&self) -> bool {
        let flag = &self.interjected;
        flag.load(Ordering::SeqCst)
    }
}
/// Callback passed to [`LoopTool::execute`] as `on_update`.
///
/// It receives a borrowed [`LoopToolResult`]. Presumably tools call it to report
/// partial/progress results before the final value is returned — confirm with callers.
/// Because it is an `Arc<dyn Fn .. + Send + Sync>`, it is cheap to clone and can be
/// shared across threads.
pub type LoopToolUpdate = Arc<dyn Fn(&LoopToolResult) + Send + Sync>;
/// A tool described by a name, a description, a JSON parameter description, and an
/// asynchronous [`execute`](LoopTool::execute) entry point.
///
/// Implementors must be `Send + Sync + Debug`, so tool objects can be shared across
/// threads and printed for diagnostics.
pub trait LoopTool: Send + Sync + std::fmt::Debug {
/// Identifier of the tool (presumably matched against incoming tool calls — confirm).
fn name(&self) -> &str;
/// Description of the tool.
fn description(&self) -> &str;
/// Label for the tool (presumably a display name — confirm).
// NOTE(review): `dead_code` is allowed here, which suggests nothing in this crate
// currently calls `label()` — confirm, then either use it or document why it is kept.
#[allow(dead_code)]
fn label(&self) -> &str;
/// JSON value describing the tool's arguments (presumably a JSON Schema — confirm).
fn parameters(&self) -> &Value;
/// Optional alternate parameter description (presumably a flattened variant of
/// [`LoopTool::parameters`] — confirm). Returns `None` by default.
fn flat_parameters(&self) -> Option<&Value> {
None
}
/// Optional execution mode for this tool. Returns `None` by default; presumably the
/// caller then applies its own default — confirm.
fn execution_mode(&self) -> Option<ToolExecutionMode> {
None
}
/// Hook for rewriting raw arguments (presumably called before `execute` — confirm).
/// The default implementation returns `args` unchanged.
fn prepare_arguments(&self, args: Value) -> Value {
args
}
/// Runs the tool for a single call.
///
/// * `tool_call_id` — identifier of this invocation, borrowed for `'a`.
/// * `args` — the call's arguments as a JSON value.
/// * `signal` — shared cancellation/interjection handle the implementation can observe
///   (`is_cancelled`, `cancelled`, `is_interjected`).
/// * `on_update` — callback that may be invoked with intermediate results.
///
/// Returns a boxed `Send` future that borrows `self` and `tool_call_id` for `'a`. It
/// resolves to `Ok(LoopToolResult)` on success or `Err(String)` carrying an error message.
fn execute<'a>(
&'a self,
tool_call_id: &'a str,
args: Value,
signal: AbortSignal,
on_update: LoopToolUpdate,
) -> Pin<Box<dyn Future<Output = Result<LoopToolResult, String>> + Send + 'a>>;
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`AbortSignal`]: shared state across clones, the
    //! independence of interjection from cancellation, and the async
    //! `cancelled()` wait (immediate, pending, and woken).
    use super::*;
    use std::time::Duration;

    /// A cancel through either handle is visible through both clones.
    #[test]
    fn abort_signal_shared_state() {
        let sig = AbortSignal::new();
        assert!(!sig.is_cancelled());
        let clone = sig.clone();
        sig.cancel();
        assert!(clone.is_cancelled(), "clone must see the cancel");
        clone.cancel();
        assert!(sig.is_cancelled());
    }

    /// A default-constructed signal starts with both flags cleared.
    #[test]
    fn abort_signal_default_uncancelled() {
        let sig = AbortSignal::default();
        assert!(!sig.is_cancelled());
        assert!(!sig.is_interjected());
    }

    /// Interjection is shared across clones and is tracked separately from cancellation.
    #[test]
    fn interject_is_shared_and_independent_of_cancel() {
        let sig = AbortSignal::new();
        assert!(!sig.is_interjected());
        let clone = sig.clone();
        sig.interject();
        assert!(clone.is_interjected(), "clone must see the interjection");
        assert!(!clone.is_cancelled(), "interject must not cancel");

        let other = AbortSignal::new();
        other.cancel();
        assert!(!other.is_interjected(), "cancel must not interject");
    }

    #[tokio::test]
    async fn cancelled_returns_immediately_when_already_cancelled() {
        let sig = AbortSignal::new();
        sig.cancel();
        tokio::time::timeout(Duration::from_secs(1), sig.cancelled())
            .await
            .expect("cancelled() must resolve immediately when already cancelled");
    }

    /// Without a cancel (an interjection alone is not enough), `cancelled()` stays pending.
    #[tokio::test]
    async fn cancelled_stays_pending_without_cancel() {
        let sig = AbortSignal::new();
        sig.interject();
        let outcome = tokio::time::timeout(Duration::from_millis(20), sig.cancelled()).await;
        assert!(outcome.is_err(), "cancelled() must not resolve before cancel()");
    }

    #[tokio::test]
    async fn cancelled_wakes_on_concurrent_cancel() {
        let sig = AbortSignal::new();
        let waiter = sig.clone();
        let handle = tokio::spawn(async move { waiter.cancelled().await });
        // Let the spawned task run far enough to park in `cancelled()`.
        tokio::task::yield_now().await;
        assert!(!handle.is_finished(), "waiter must be parked before cancel()");
        sig.cancel();
        tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("cancelled() must wake promptly on concurrent cancel")
            .expect("waiter task panicked");
    }
}