use std::future::Future;
use anyhow::Result;
use rmcp::ErrorData as McpError;
use tokio_util::sync::CancellationToken;
/// Message text carried by the error that [`cancelled_error`] builds, i.e. the
/// error returned by the helpers in this module when the cancellation token fires.
const CANCELLED_MESSAGE: &str = "request cancelled";
pub async fn spawn_blocking_cancellable<F, T>(
token: &CancellationToken,
f: F,
) -> Result<T, McpError>
where
F: FnOnce() -> Result<T> + Send + 'static,
T: Send + 'static,
{
let handle = tokio::task::spawn_blocking(f);
tokio::select! {
() = token.cancelled() => Err(cancelled_error()),
join = handle => match join {
Ok(Ok(value)) => Ok(value),
Ok(Err(domain)) => Err(crate::mcp::error::tool_error(domain)),
Err(join) => Err(crate::mcp::error::tool_error(
anyhow::anyhow!("join error: {join}"),
)),
},
}
}
/// Awaits `fut`, giving up as soon as `token` is cancelled.
///
/// Whichever of the two completes first decides the outcome: the output of
/// `fut` is passed through untouched, while a cancellation yields
/// [`cancelled_error`].
///
/// # Errors
///
/// Returns the error produced by `fut`, or [`cancelled_error`] when `token`
/// is cancelled before `fut` completes.
pub async fn cancellable<F, T>(token: &CancellationToken, fut: F) -> Result<T, McpError>
where
    F: Future<Output = Result<T, McpError>>,
{
    let cancellation = token.cancelled();
    let work = fut;
    tokio::select! {
        () = cancellation => Err(cancelled_error()),
        outcome = work => outcome,
    }
}
/// Builds the [`McpError`] reported for a cancelled request: an internal error
/// whose message is [`CANCELLED_MESSAGE`], with `None` as its second argument.
pub fn cancelled_error() -> McpError {
    let message = String::from(CANCELLED_MESSAGE);
    McpError::internal_error(message, None)
}
#[cfg(test)]
// Tests are allowed to unwrap/expect: a failure there is exactly the signal we want.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Delay before the background task cancels the parent token.
    const CANCEL_AFTER: Duration = Duration::from_millis(10);

    /// A blocking closure that succeeds has its value passed through.
    #[tokio::test]
    async fn spawn_blocking_returns_value_when_not_cancelled() {
        let token = CancellationToken::new();
        let value = spawn_blocking_cancellable(&token, || Ok::<_, anyhow::Error>(42))
            .await
            .unwrap();
        assert_eq!(value, 42);
    }

    /// Cancelling the parent token makes the helper return before the
    /// (much slower) blocking closure finishes.
    #[tokio::test]
    async fn spawn_blocking_returns_early_on_cancellation() {
        let token = CancellationToken::new();
        let child = token.child_token();
        tokio::spawn(async move {
            tokio::time::sleep(CANCEL_AFTER).await;
            token.cancel();
        });
        let result: Result<(), _> = spawn_blocking_cancellable(&child, || {
            std::thread::sleep(Duration::from_millis(500));
            Ok(())
        })
        .await;
        let err = result.expect_err("should be cancelled");
        assert!(err.message.contains("cancelled"), "got: {}", err.message);
    }

    /// A panic inside the blocking closure is reported as a join error
    /// instead of propagating the panic to the caller.
    #[tokio::test]
    async fn spawn_blocking_surfaces_join_error_when_task_panics() {
        let token = CancellationToken::new();
        let result: Result<(), _> = spawn_blocking_cancellable(&token, || {
            panic!("blocking task panicked");
        })
        .await;
        let err = result.expect_err("panic should surface as a join error");
        assert!(
            err.message.contains("join error"),
            "expected join error, got: {}",
            err.message
        );
    }

    /// An error returned by the closure keeps its message after conversion.
    #[tokio::test]
    async fn spawn_blocking_propagates_domain_error() {
        let token = CancellationToken::new();
        let result: Result<(), _> =
            spawn_blocking_cancellable(&token, || Err(anyhow::anyhow!("domain failure"))).await;
        let err = result.expect_err("should propagate error");
        assert!(err.message.contains("domain failure"));
    }

    /// A future that completes successfully has its value passed through.
    #[tokio::test]
    async fn cancellable_returns_value_when_not_cancelled() {
        let token = CancellationToken::new();
        let value = cancellable(&token, async { Ok::<_, McpError>(7) })
            .await
            .unwrap();
        assert_eq!(value, 7);
    }

    /// A future that never completes is abandoned once the token is cancelled.
    #[tokio::test]
    async fn cancellable_returns_early_on_cancellation() {
        let token = CancellationToken::new();
        let child = token.child_token();
        tokio::spawn(async move {
            tokio::time::sleep(CANCEL_AFTER).await;
            token.cancel();
        });
        let result: Result<(), _> = cancellable(&child, std::future::pending()).await;
        let err = result.expect_err("should be cancelled");
        assert!(err.message.contains("cancelled"));
    }

    /// An error produced by the wrapped future is returned as-is.
    #[tokio::test]
    async fn cancellable_preserves_inner_error() {
        let token = CancellationToken::new();
        let inner = McpError::invalid_params("bad".to_string(), None);
        // `inner` is moved into the async block, so no clone is required.
        let result: Result<(), _> = cancellable(&token, async move { Err::<(), _>(inner) }).await;
        let err = result.expect_err("should propagate inner error");
        assert!(err.message.contains("bad"));
    }

    /// The cancellation error carries exactly `CANCELLED_MESSAGE`.
    #[test]
    fn cancelled_error_has_expected_message() {
        let err = cancelled_error();
        assert_eq!(err.message, CANCELLED_MESSAGE);
    }
}