use crate::ToolError;
use std::future::Future;
use tokio_util::sync::CancellationToken;
use tokio_util::sync::WaitForCancellationFutureOwned;
/// Execution context for a single tool request.
///
/// Currently it only carries the request's cancellation signal. Cloning a
/// `ToolContext` clones the token, so all clones observe the same cancellation.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Token that is cancelled when the owning request is cancelled.
    cancel: CancellationToken,
}
impl Default for ToolContext {
    /// Builds a context backed by a brand-new, not-yet-cancelled token.
    fn default() -> Self {
        let cancel = CancellationToken::new();
        Self { cancel }
    }
}
impl ToolContext {
    /// Creates a context with a fresh, uncancelled token.
    ///
    /// Equivalent to [`ToolContext::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a context that observes the given `cancel` token.
    ///
    /// Cancelling `cancel` (or any clone of it) is reflected by
    /// [`is_cancelled`](Self::is_cancelled), [`cancelled`](Self::cancelled) and
    /// [`run_cancellable`](Self::run_cancellable).
    pub fn with_cancel(cancel: CancellationToken) -> Self {
        Self { cancel }
    }

    /// Returns a clone of the underlying token, which shares this context's
    /// cancellation state.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

    /// Returns an owned future that resolves once this context is cancelled.
    ///
    /// The future holds its own clone of the token, so it does not borrow `self`.
    pub fn cancelled(&self) -> WaitForCancellationFutureOwned {
        self.cancel.clone().cancelled_owned()
    }

    /// Returns `true` if cancellation has already been requested.
    pub fn is_cancelled(&self) -> bool {
        self.cancel.is_cancelled()
    }

    /// Drives `fut` to completion unless this context is cancelled first.
    ///
    /// If cancellation wins, `fut` is dropped at whatever await point it had
    /// reached. Any work it started is abandoned.
    ///
    /// # Errors
    ///
    /// - Returns `ToolError::cancelled(None)` when cancellation is observed before
    ///   `fut` completes. This includes the case where the context is already
    ///   cancelled on entry; then `fut` is never polled.
    /// - Otherwise returns `fut`'s own error, converted via `Into<ToolError>`.
    pub async fn run_cancellable<F, T, E>(&self, fut: F) -> Result<T, ToolError>
    where
        F: Future<Output = Result<T, E>>,
        E: Into<ToolError>,
    {
        tokio::select! {
            // Without `biased`, select! picks a ready branch at random. A request
            // that is already cancelled could then still poll (and run side effects
            // of) `fut`. Polling cancellation first makes it take priority
            // deterministically.
            biased;
            _ = self.cancel.cancelled() => {
                tracing::info!("tool request cancelled during run_cancellable");
                Err(ToolError::cancelled(None))
            }
            result = fut => result.map_err(Into::into),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, timeout, Duration};

    /// A default context must never report cancellation on its own.
    #[tokio::test]
    async fn default_context_is_never_cancelled() {
        let context = ToolContext::default();
        assert!(!context.is_cancelled());

        // Nobody cancels the token, so waiting on it must hit the timeout.
        let waited = timeout(Duration::from_millis(25), context.cancelled()).await;
        assert!(waited.is_err());
    }

    /// Cancelling the token passed to `with_cancel` is visible through the context.
    #[tokio::test]
    async fn with_cancel_propagates_cancellation() {
        let token = CancellationToken::new();
        let context = ToolContext::with_cancel(token.clone());

        token.cancel();

        // Must resolve right away now that the shared token is cancelled.
        context.cancelled().await;
        assert!(context.is_cancelled());
        assert!(context.cancellation_token().is_cancelled());
    }

    /// A future that finishes normally has its value passed through unchanged.
    #[tokio::test]
    async fn run_cancellable_returns_inner_success() {
        let context = ToolContext::default();
        let outcome = context
            .run_cancellable(async { Ok::<_, ToolError>("done") })
            .await;
        assert!(matches!(outcome, Ok("done")));
    }

    /// Cancelling mid-flight aborts a slow future with a `Cancelled` error.
    #[tokio::test]
    async fn run_cancellable_returns_cancelled_when_request_is_cancelled() {
        let token = CancellationToken::new();
        let context = ToolContext::with_cancel(token.clone());

        // Fire the cancellation long before the wrapped future could finish.
        let canceller = tokio::spawn(async move {
            sleep(Duration::from_millis(25)).await;
            token.cancel();
        });

        let slow_tool = async {
            sleep(Duration::from_secs(5)).await;
            Ok::<(), ToolError>(())
        };
        let outcome = context.run_cancellable(slow_tool).await;

        let joined = canceller.await;
        assert!(joined.is_ok());
        assert!(matches!(outcome, Err(ToolError::Cancelled { reason: None })));
    }
}