Skip to main content

swink_agent/
tool_middleware.rs

1//! Middleware wrapper for [`AgentTool`] that intercepts `execute()` while
2//! delegating all metadata methods to the inner tool.
3//!
4//! # Example
5//!
6//! ```no_run
7//! # #[cfg(feature = "builtin-tools")]
8//! # {
9//! use std::sync::Arc;
10//! use swink_agent::{AgentTool, AgentToolResult, BashTool, ToolMiddleware};
11//!
12//! let tool = Arc::new(BashTool::new());
13//! let logged = ToolMiddleware::new(tool, |inner, id, params, cancel, on_update, state, credential| {
14//!     Box::pin(async move {
15//!         println!("before");
16//!         let result = inner.execute(&id, params, cancel, on_update, state, credential).await;
17//!         println!("after");
18//!         result
19//!     })
20//! });
21//!
22//! assert_eq!(logged.name(), "bash");
23//! # }
24//! ```
25
26use std::sync::Arc;
27use std::time::Duration;
28
29use serde_json::Value;
30use tokio_util::sync::CancellationToken;
31
32use crate::tool::{AgentTool, AgentToolResult, ToolFuture};
33
34// ─── Type alias for the middleware closure ──────────────────────────────────
35
36type MiddlewareFn = Arc<
37    dyn Fn(
38            Arc<dyn AgentTool>,
39            String,
40            Value,
41            CancellationToken,
42            Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
43            std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
44            Option<crate::credential::ResolvedCredential>,
45        ) -> ToolFuture<'static>
46        + Send
47        + Sync,
48>;
49
50// ─── ToolMiddleware ─────────────────────────────────────────────────────────
51
52/// Intercepts [`execute()`](AgentTool::execute) on a wrapped [`AgentTool`].
53///
54/// All descriptor methods (`name`, `label`, `description`,
55/// `parameters_schema`, `metadata`, `requires_approval`, `auth_config`)
56/// delegate to the inner tool.
57pub struct ToolMiddleware {
58    inner: Arc<dyn AgentTool>,
59    middleware_fn: MiddlewareFn,
60}
61
62impl ToolMiddleware {
63    /// Create a new middleware wrapping `inner`.
64    ///
65    /// The closure receives `(inner_tool, tool_call_id, params, cancel, on_update, state, credential)`
66    /// and can call through to the inner tool's `execute()` at any point.
67    pub fn new<F>(inner: Arc<dyn AgentTool>, f: F) -> Self
68    where
69        F: Fn(
70                Arc<dyn AgentTool>,
71                String,
72                Value,
73                CancellationToken,
74                Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
75                std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
76                Option<crate::credential::ResolvedCredential>,
77            ) -> ToolFuture<'static>
78            + Send
79            + Sync
80            + 'static,
81    {
82        Self {
83            inner,
84            middleware_fn: Arc::new(f),
85        }
86    }
87
88    /// Create a middleware that enforces a timeout on tool execution.
89    ///
90    /// If the inner tool does not complete within `timeout`, an error result
91    /// is returned.
92    pub fn with_timeout(inner: Arc<dyn AgentTool>, timeout: Duration) -> Self {
93        Self::new(
94            inner,
95            move |tool, id, params, cancel, on_update, state, credential| {
96                Box::pin(async move {
97                    tokio::select! {
98                        result = tool.execute(&id, params, cancel.clone(), on_update, state, credential) => result,
99                        () = tokio::time::sleep(timeout) => {
100                            cancel.cancel();
101                            AgentToolResult::error(format!(
102                                "tool timed out after {}ms",
103                                timeout.as_millis()
104                            ))
105                        }
106                    }
107                })
108            },
109        )
110    }
111
112    /// Create a middleware that calls a logging callback before and after
113    /// tool execution.
114    ///
115    /// The callback receives `(tool_name, tool_call_id, is_start)` where
116    /// `is_start` is `true` before execution and `false` after.
117    pub fn with_logging<F>(inner: Arc<dyn AgentTool>, callback: F) -> Self
118    where
119        F: Fn(&str, &str, bool) + Send + Sync + 'static,
120    {
121        let callback = Arc::new(callback);
122        Self::new(
123            inner,
124            move |tool, id, params, cancel, on_update, state, credential| {
125                let cb = callback.clone();
126                let name = tool.name().to_owned();
127                Box::pin(async move {
128                    cb(&name, &id, true);
129                    let result = tool
130                        .execute(&id, params, cancel, on_update, state, credential)
131                        .await;
132                    cb(&name, &id, false);
133                    result
134                })
135            },
136        )
137    }
138}
139
140impl AgentTool for ToolMiddleware {
141    fn name(&self) -> &str {
142        self.inner.name()
143    }
144
145    fn label(&self) -> &str {
146        self.inner.label()
147    }
148
149    fn description(&self) -> &str {
150        self.inner.description()
151    }
152
153    fn parameters_schema(&self) -> &Value {
154        self.inner.parameters_schema()
155    }
156
157    fn metadata(&self) -> Option<crate::tool::ToolMetadata> {
158        self.inner.metadata()
159    }
160
161    fn requires_approval(&self) -> bool {
162        self.inner.requires_approval()
163    }
164
165    fn approval_context(&self, params: &Value) -> Option<Value> {
166        self.inner.approval_context(params)
167    }
168
169    fn auth_config(&self) -> Option<crate::credential::AuthConfig> {
170        self.inner.auth_config()
171    }
172
173    fn execute(
174        &self,
175        tool_call_id: &str,
176        params: Value,
177        cancellation_token: CancellationToken,
178        on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
179        state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
180        credential: Option<crate::credential::ResolvedCredential>,
181    ) -> ToolFuture<'_> {
182        let inner = self.inner.clone();
183        let id = tool_call_id.to_owned();
184        let fut = (self.middleware_fn)(
185            inner,
186            id,
187            params,
188            cancellation_token,
189            on_update,
190            state,
191            credential,
192        );
193        Box::pin(fut)
194    }
195}
196
197impl std::fmt::Debug for ToolMiddleware {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        f.debug_struct("ToolMiddleware")
200            .field("inner_name", &self.inner.name())
201            .finish_non_exhaustive()
202    }
203}
204
205// ─── Compile-time Send + Sync assertion ─────────────────────────────────────
206
207const _: () = {
208    const fn assert_send_sync<T: Send + Sync>() {}
209    assert_send_sync::<ToolMiddleware>();
210};
211
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU32, Ordering};

    use serde_json::json;

    use super::*;
    use crate::FnTool;
    use crate::tool::AgentTool;

    /// A tool that always succeeds with `"dummy result"` and requires approval.
    fn dummy_tool() -> Arc<dyn AgentTool> {
        Arc::new(
            FnTool::new("dummy", "Dummy", "A dummy tool.")
                .with_requires_approval(true)
                .with_execute_simple(|_params, _cancel| async {
                    AgentToolResult::text("dummy result")
                }),
        )
    }

    /// Wraps `inner` in a middleware that forwards every call unchanged.
    fn passthrough(inner: Arc<dyn AgentTool>) -> ToolMiddleware {
        ToolMiddleware::new(
            inner,
            |tool, id, params, cancel, on_update, state, credential| {
                Box::pin(async move {
                    tool.execute(&id, params, cancel, on_update, state, credential)
                        .await
                })
            },
        )
    }

    fn test_state() -> std::sync::Arc<std::sync::RwLock<crate::SessionState>> {
        std::sync::Arc::new(std::sync::RwLock::new(crate::SessionState::new()))
    }

    /// Executes `mw` once with empty params, a fresh token, and no credential.
    async fn execute_once(mw: &ToolMiddleware, id: &str) -> AgentToolResult {
        mw.execute(
            id,
            json!({}),
            CancellationToken::new(),
            None,
            test_state(),
            None,
        )
        .await
    }

    #[test]
    fn metadata_and_auth_config_delegate_to_inner() {
        struct MetadataAuthTool;

        impl AgentTool for MetadataAuthTool {
            fn name(&self) -> &str {
                "auth_tool"
            }

            fn label(&self) -> &str {
                "Auth Tool"
            }

            fn description(&self) -> &str {
                "A tool with metadata and auth config."
            }

            fn parameters_schema(&self) -> &Value {
                &Value::Null
            }

            fn metadata(&self) -> Option<crate::tool::ToolMetadata> {
                Some(
                    crate::tool::ToolMetadata::with_namespace("middleware-tests")
                        .with_version("1.0.0"),
                )
            }

            fn auth_config(&self) -> Option<crate::credential::AuthConfig> {
                Some(crate::credential::AuthConfig {
                    credential_key: "weather-api".to_string(),
                    auth_scheme: crate::credential::AuthScheme::ApiKeyHeader(
                        "X-Api-Key".to_string(),
                    ),
                    credential_type: crate::credential::CredentialType::ApiKey,
                })
            }

            fn execute(
                &self,
                _tool_call_id: &str,
                _params: Value,
                _cancellation_token: CancellationToken,
                _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
                _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
                _credential: Option<crate::credential::ResolvedCredential>,
            ) -> ToolFuture<'_> {
                Box::pin(async { AgentToolResult::text("ok") })
            }
        }

        let inner: Arc<dyn AgentTool> = Arc::new(MetadataAuthTool);
        let mw = passthrough(inner);

        assert_eq!(mw.name(), "auth_tool");
        assert_eq!(mw.label(), "Auth Tool");
        assert_eq!(mw.description(), "A tool with metadata and auth config.");
        assert!(!mw.requires_approval());
        assert_eq!(
            mw.metadata(),
            Some(
                crate::tool::ToolMetadata::with_namespace("middleware-tests").with_version("1.0.0"),
            )
        );

        let auth_config = mw
            .auth_config()
            .expect("middleware should delegate auth config");
        assert_eq!(auth_config.credential_key, "weather-api");
        assert!(matches!(
            auth_config.auth_scheme,
            crate::credential::AuthScheme::ApiKeyHeader(ref header) if header == "X-Api-Key"
        ));
        assert_eq!(
            auth_config.credential_type,
            crate::credential::CredentialType::ApiKey
        );
    }

    #[test]
    fn requires_approval_true_delegates_to_inner() {
        // `dummy_tool` opts into approval, so the wrapper must report it too.
        let mw = passthrough(dummy_tool());
        assert!(mw.requires_approval());
    }

    #[test]
    fn approval_context_delegates_to_inner() {
        struct ContextTool;

        impl AgentTool for ContextTool {
            fn name(&self) -> &str {
                "context_tool"
            }

            fn label(&self) -> &str {
                "Context Tool"
            }

            fn description(&self) -> &str {
                "Echoes its params as approval context."
            }

            fn parameters_schema(&self) -> &Value {
                &Value::Null
            }

            fn approval_context(&self, params: &Value) -> Option<Value> {
                Some(json!({ "echo": params.clone() }))
            }

            fn execute(
                &self,
                _tool_call_id: &str,
                _params: Value,
                _cancellation_token: CancellationToken,
                _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
                _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
                _credential: Option<crate::credential::ResolvedCredential>,
            ) -> ToolFuture<'_> {
                Box::pin(async { AgentToolResult::text("ok") })
            }
        }

        let inner: Arc<dyn AgentTool> = Arc::new(ContextTool);
        let mw = passthrough(inner);

        assert_eq!(
            mw.approval_context(&json!({ "path": "/tmp" })),
            Some(json!({ "echo": { "path": "/tmp" } }))
        );
    }

    #[tokio::test]
    async fn middleware_intercepts_execute() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let mw = ToolMiddleware::new(
            dummy_tool(),
            move |tool, id, params, cancel, on_update, state, credential| {
                let c = Arc::clone(&counter_clone);
                Box::pin(async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    tool.execute(&id, params, cancel, on_update, state, credential)
                        .await
                })
            },
        );

        let result = execute_once(&mw, "id").await;
        assert!(!result.is_error);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn call_through_returns_inner_result() {
        let mw = passthrough(dummy_tool());
        let result = execute_once(&mw, "id").await;
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn middleware_can_short_circuit_without_calling_inner() {
        // The inner dummy tool never errors, so an error result proves the
        // middleware's own result was returned instead.
        let mw = ToolMiddleware::new(
            dummy_tool(),
            |_tool, _id, _params, _cancel, _on_update, _state, _credential| {
                Box::pin(async { AgentToolResult::error("blocked by middleware") })
            },
        );

        let result = execute_once(&mw, "id").await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn timeout_middleware_returns_error_on_slow_tool() {
        /// A tool that sleeps forever.
        struct SlowTool;
        impl AgentTool for SlowTool {
            fn name(&self) -> &'static str {
                "slow"
            }
            fn label(&self) -> &'static str {
                "Slow"
            }
            fn description(&self) -> &'static str {
                "Sleeps."
            }
            fn parameters_schema(&self) -> &Value {
                &Value::Null
            }
            fn execute(
                &self,
                _id: &str,
                _params: Value,
                cancel: CancellationToken,
                _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
                _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
                _credential: Option<crate::credential::ResolvedCredential>,
            ) -> ToolFuture<'_> {
                Box::pin(async move {
                    cancel.cancelled().await;
                    AgentToolResult::error("cancelled")
                })
            }
        }

        let inner: Arc<dyn AgentTool> = Arc::new(SlowTool);
        let mw = ToolMiddleware::with_timeout(inner, Duration::from_millis(10));

        let result = execute_once(&mw, "id").await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn timeout_middleware_passes_through_fast_tool() {
        let mw = ToolMiddleware::with_timeout(dummy_tool(), Duration::from_secs(5));
        let result = execute_once(&mw, "id").await;
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn logging_middleware_calls_callback() {
        let events: Arc<Mutex<Vec<(String, String, bool)>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&events);

        let mw = ToolMiddleware::with_logging(dummy_tool(), move |name, id, is_start| {
            sink.lock()
                .expect("event log mutex poisoned")
                .push((name.to_owned(), id.to_owned(), is_start));
        });
        let expected_name = mw.name().to_owned();

        execute_once(&mw, "call-1").await;

        // Called exactly twice — once before (`true`), once after (`false`) —
        // with the tool's name and the call id each time.
        let recorded = events.lock().expect("event log mutex poisoned").clone();
        assert_eq!(
            recorded,
            vec![
                (expected_name.clone(), "call-1".to_owned(), true),
                (expected_name, "call-1".to_owned(), false),
            ]
        );
    }
}