agentic_tools_core/
context.rs1use crate::ToolError;
4use std::future::Future;
5use tokio_util::sync::CancellationToken;
6use tokio_util::sync::WaitForCancellationFutureOwned;
7
8#[derive(Clone, Debug)]
31pub struct ToolContext {
32 cancel: CancellationToken,
33}
34
35impl Default for ToolContext {
36 fn default() -> Self {
37 Self::with_cancel(CancellationToken::new())
38 }
39}
40
41impl ToolContext {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_cancel(cancel: CancellationToken) -> Self {
49 Self { cancel }
50 }
51
52 pub fn cancellation_token(&self) -> CancellationToken {
54 self.cancel.clone()
55 }
56
57 pub fn cancelled(&self) -> WaitForCancellationFutureOwned {
59 self.cancel.clone().cancelled_owned()
60 }
61
62 pub fn is_cancelled(&self) -> bool {
64 self.cancel.is_cancelled()
65 }
66
67 pub async fn run_cancellable<F, T, E>(&self, fut: F) -> Result<T, ToolError>
69 where
70 F: Future<Output = Result<T, E>>,
71 E: Into<ToolError>,
72 {
73 tokio::select! {
74 _ = self.cancelled() => {
75 tracing::info!("tool request cancelled during run_cancellable");
76 Err(ToolError::cancelled(None))
77 }
78 result = fut => result.map_err(Into::into),
79 }
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;
    use tokio::time::sleep;
    use tokio::time::timeout;

    #[tokio::test]
    async fn default_context_is_never_cancelled() {
        let ctx = ToolContext::default();

        assert!(!ctx.is_cancelled());
        assert!(
            timeout(Duration::from_millis(25), ctx.cancelled())
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn new_context_starts_uncancelled() {
        let ctx = ToolContext::new();

        assert!(!ctx.is_cancelled());
        assert!(!ctx.cancellation_token().is_cancelled());
    }

    #[tokio::test]
    async fn with_cancel_propagates_cancellation() {
        let cancel = CancellationToken::new();
        let ctx = ToolContext::with_cancel(cancel.clone());

        cancel.cancel();
        ctx.cancelled().await;

        assert!(ctx.is_cancelled());
        assert!(ctx.cancellation_token().is_cancelled());
    }

    #[tokio::test]
    async fn cancelling_returned_token_cancels_context() {
        let ctx = ToolContext::new();

        // The returned token shares state with the context's own token.
        ctx.cancellation_token().cancel();

        assert!(ctx.is_cancelled());
    }

    #[tokio::test]
    async fn run_cancellable_returns_inner_success() {
        let ctx = ToolContext::default();

        let result = ctx
            .run_cancellable(async { Ok::<_, ToolError>("done") })
            .await;

        assert!(matches!(result, Ok("done")));
    }

    #[tokio::test]
    async fn run_cancellable_returns_cancelled_when_request_is_cancelled() {
        let cancel = CancellationToken::new();
        let ctx = ToolContext::with_cancel(cancel.clone());

        let canceller = tokio::spawn(async move {
            sleep(Duration::from_millis(25)).await;
            cancel.cancel();
        });

        let result = ctx
            .run_cancellable(async {
                sleep(Duration::from_secs(5)).await;
                Ok::<(), ToolError>(())
            })
            .await;

        let join_result = canceller.await;
        assert!(join_result.is_ok());
        assert!(matches!(result, Err(ToolError::Cancelled { reason: None })));
    }

    #[tokio::test]
    async fn run_cancellable_short_circuits_when_already_cancelled() {
        let cancel = CancellationToken::new();
        cancel.cancel();
        let ctx = ToolContext::with_cancel(cancel);

        // The inner future never completes, so the only way out is the cancellation branch.
        let result = ctx
            .run_cancellable(std::future::pending::<Result<(), ToolError>>())
            .await;

        assert!(matches!(result, Err(ToolError::Cancelled { reason: None })));
    }
}