1use std::sync::Arc;
27use std::time::Duration;
28
29use serde_json::Value;
30use tokio_util::sync::CancellationToken;
31
32use crate::tool::{AgentTool, AgentToolResult, ToolFuture};
33
34type MiddlewareFn = Arc<
37 dyn Fn(
38 Arc<dyn AgentTool>,
39 String,
40 Value,
41 CancellationToken,
42 Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
43 std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
44 Option<crate::credential::ResolvedCredential>,
45 ) -> ToolFuture<'static>
46 + Send
47 + Sync,
48>;
49
50pub struct ToolMiddleware {
58 inner: Arc<dyn AgentTool>,
59 middleware_fn: MiddlewareFn,
60}
61
62impl ToolMiddleware {
63 pub fn new<F>(inner: Arc<dyn AgentTool>, f: F) -> Self
68 where
69 F: Fn(
70 Arc<dyn AgentTool>,
71 String,
72 Value,
73 CancellationToken,
74 Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
75 std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
76 Option<crate::credential::ResolvedCredential>,
77 ) -> ToolFuture<'static>
78 + Send
79 + Sync
80 + 'static,
81 {
82 Self {
83 inner,
84 middleware_fn: Arc::new(f),
85 }
86 }
87
88 pub fn with_timeout(inner: Arc<dyn AgentTool>, timeout: Duration) -> Self {
93 Self::new(
94 inner,
95 move |tool, id, params, cancel, on_update, state, credential| {
96 Box::pin(async move {
97 tokio::select! {
98 result = tool.execute(&id, params, cancel.clone(), on_update, state, credential) => result,
99 () = tokio::time::sleep(timeout) => {
100 cancel.cancel();
101 AgentToolResult::error(format!(
102 "tool timed out after {}ms",
103 timeout.as_millis()
104 ))
105 }
106 }
107 })
108 },
109 )
110 }
111
112 pub fn with_logging<F>(inner: Arc<dyn AgentTool>, callback: F) -> Self
118 where
119 F: Fn(&str, &str, bool) + Send + Sync + 'static,
120 {
121 let callback = Arc::new(callback);
122 Self::new(
123 inner,
124 move |tool, id, params, cancel, on_update, state, credential| {
125 let cb = callback.clone();
126 let name = tool.name().to_owned();
127 Box::pin(async move {
128 cb(&name, &id, true);
129 let result = tool
130 .execute(&id, params, cancel, on_update, state, credential)
131 .await;
132 cb(&name, &id, false);
133 result
134 })
135 },
136 )
137 }
138}
139
140impl AgentTool for ToolMiddleware {
141 fn name(&self) -> &str {
142 self.inner.name()
143 }
144
145 fn label(&self) -> &str {
146 self.inner.label()
147 }
148
149 fn description(&self) -> &str {
150 self.inner.description()
151 }
152
153 fn parameters_schema(&self) -> &Value {
154 self.inner.parameters_schema()
155 }
156
157 fn metadata(&self) -> Option<crate::tool::ToolMetadata> {
158 self.inner.metadata()
159 }
160
161 fn requires_approval(&self) -> bool {
162 self.inner.requires_approval()
163 }
164
165 fn approval_context(&self, params: &Value) -> Option<Value> {
166 self.inner.approval_context(params)
167 }
168
169 fn auth_config(&self) -> Option<crate::credential::AuthConfig> {
170 self.inner.auth_config()
171 }
172
173 fn execute(
174 &self,
175 tool_call_id: &str,
176 params: Value,
177 cancellation_token: CancellationToken,
178 on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
179 state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
180 credential: Option<crate::credential::ResolvedCredential>,
181 ) -> ToolFuture<'_> {
182 let inner = self.inner.clone();
183 let id = tool_call_id.to_owned();
184 let fut = (self.middleware_fn)(
185 inner,
186 id,
187 params,
188 cancellation_token,
189 on_update,
190 state,
191 credential,
192 );
193 Box::pin(fut)
194 }
195}
196
197impl std::fmt::Debug for ToolMiddleware {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("ToolMiddleware")
200 .field("inner_name", &self.inner.name())
201 .finish_non_exhaustive()
202 }
203}
204
205const _: () = {
208 const fn assert_send_sync<T: Send + Sync>() {}
209 assert_send_sync::<ToolMiddleware>();
210};
211
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};

    use serde_json::json;

    use super::*;
    use crate::FnTool;
    use crate::tool::AgentTool;

    /// Builds an `FnTool` named "dummy".
    ///
    /// The tool is configured with `with_requires_approval(true)` and always
    /// returns the text "dummy result".
    fn dummy_tool() -> Arc<dyn AgentTool> {
        Arc::new(
            FnTool::new("dummy", "Dummy", "A dummy tool.")
                .with_requires_approval(true)
                .with_execute_simple(|_params, _cancel| async {
                    AgentToolResult::text("dummy result")
                }),
        )
    }

    /// The non-execute accessors of `ToolMiddleware` must report the inner
    /// tool's values: name, label, description, requires_approval, metadata
    /// and auth_config.
    #[test]
    fn metadata_and_auth_config_delegate_to_inner() {
        // Minimal tool that overrides `metadata` and `auth_config`, so the
        // assertions can observe values that only the inner tool provides.
        struct MetadataAuthTool;

        impl AgentTool for MetadataAuthTool {
            fn name(&self) -> &str {
                "auth_tool"
            }

            fn label(&self) -> &str {
                "Auth Tool"
            }

            fn description(&self) -> &str {
                "A tool with metadata and auth config."
            }

            fn parameters_schema(&self) -> &Value {
                &Value::Null
            }

            fn metadata(&self) -> Option<crate::tool::ToolMetadata> {
                Some(
                    crate::tool::ToolMetadata::with_namespace("middleware-tests")
                        .with_version("1.0.0"),
                )
            }

            fn auth_config(&self) -> Option<crate::credential::AuthConfig> {
                Some(crate::credential::AuthConfig {
                    credential_key: "weather-api".to_string(),
                    auth_scheme: crate::credential::AuthScheme::ApiKeyHeader(
                        "X-Api-Key".to_string(),
                    ),
                    credential_type: crate::credential::CredentialType::ApiKey,
                })
            }

            fn execute(
                &self,
                _tool_call_id: &str,
                _params: Value,
                _cancellation_token: CancellationToken,
                _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
                _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
                _credential: Option<crate::credential::ResolvedCredential>,
            ) -> ToolFuture<'_> {
                Box::pin(async { AgentToolResult::text("ok") })
            }
        }

        let inner: Arc<dyn AgentTool> = Arc::new(MetadataAuthTool);
        // Pass-through middleware: forwards every argument to the inner tool.
        let mw = ToolMiddleware::new(
            inner,
            |tool, id, params, cancel, on_update, state, credential| {
                Box::pin(async move {
                    tool.execute(&id, params, cancel, on_update, state, credential)
                        .await
                })
            },
        );

        assert_eq!(mw.name(), "auth_tool");
        assert_eq!(mw.label(), "Auth Tool");
        assert_eq!(mw.description(), "A tool with metadata and auth config.");
        // `MetadataAuthTool` does not override `requires_approval`. This
        // assertion relies on the trait's default returning `false`.
        assert!(!mw.requires_approval());
        assert_eq!(
            mw.metadata(),
            Some(
                crate::tool::ToolMetadata::with_namespace("middleware-tests").with_version("1.0.0"),
            )
        );

        let auth_config = mw
            .auth_config()
            .expect("middleware should delegate auth config");
        assert_eq!(auth_config.credential_key, "weather-api");
        assert!(matches!(
            auth_config.auth_scheme,
            crate::credential::AuthScheme::ApiKeyHeader(ref header) if header == "X-Api-Key"
        ));
        assert_eq!(
            auth_config.credential_type,
            crate::credential::CredentialType::ApiKey
        );
    }

    /// A new `SessionState` wrapped in the `Arc<RwLock<_>>` that `execute`
    /// expects for its `state` argument.
    fn test_state() -> std::sync::Arc<std::sync::RwLock<crate::SessionState>> {
        std::sync::Arc::new(std::sync::RwLock::new(crate::SessionState::new()))
    }

    /// The middleware closure must run exactly once per `execute` call. Its
    /// forwarded result from the inner tool must not be an error.
    #[tokio::test]
    async fn middleware_intercepts_execute() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let inner: Arc<dyn AgentTool> = dummy_tool();
        let mw = ToolMiddleware::new(
            inner,
            move |tool, id, params, cancel, on_update, state, credential| {
                // Clone per call: the closure is `Fn` and may be invoked
                // repeatedly, while the future must own its counter handle.
                let c = counter_clone.clone();
                Box::pin(async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    tool.execute(&id, params, cancel, on_update, state, credential)
                        .await
                })
            },
        );

        let result = mw
            .execute(
                "id",
                json!({}),
                CancellationToken::new(),
                None,
                test_state(),
                None,
            )
            .await;
        assert!(!result.is_error);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// A pure pass-through middleware returns the inner tool's (non-error)
    /// result.
    #[tokio::test]
    async fn call_through_returns_inner_result() {
        let inner: Arc<dyn AgentTool> = dummy_tool();
        let mw = ToolMiddleware::new(
            inner,
            |tool, id, params, cancel, on_update, state, credential| {
                Box::pin(async move {
                    tool.execute(&id, params, cancel, on_update, state, credential)
                        .await
                })
            },
        );

        let result = mw
            .execute(
                "id",
                json!({}),
                CancellationToken::new(),
                None,
                test_state(),
                None,
            )
            .await;
        assert!(!result.is_error);
    }

    /// `with_timeout` must return an error result when the tool outlives the
    /// timeout.
    #[tokio::test]
    async fn timeout_middleware_returns_error_on_slow_tool() {
        // Never completes on its own: it only returns once the cancellation
        // token it was given fires. The 10ms timeout must therefore win.
        struct SlowTool;

        impl AgentTool for SlowTool {
            fn name(&self) -> &'static str {
                "slow"
            }
            fn label(&self) -> &'static str {
                "Slow"
            }
            fn description(&self) -> &'static str {
                "Sleeps."
            }
            fn parameters_schema(&self) -> &Value {
                &Value::Null
            }
            fn execute(
                &self,
                _id: &str,
                _params: Value,
                cancel: CancellationToken,
                _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
                _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
                _credential: Option<crate::credential::ResolvedCredential>,
            ) -> ToolFuture<'_> {
                Box::pin(async move {
                    cancel.cancelled().await;
                    AgentToolResult::error("cancelled")
                })
            }
        }

        let inner: Arc<dyn AgentTool> = Arc::new(SlowTool);
        let mw = ToolMiddleware::with_timeout(inner, Duration::from_millis(10));

        let result = mw
            .execute(
                "id",
                json!({}),
                CancellationToken::new(),
                None,
                test_state(),
                None,
            )
            .await;
        assert!(result.is_error);
    }

    /// `with_logging` invokes its callback once before and once after the
    /// tool runs, so a single `execute` produces exactly two calls.
    #[tokio::test]
    async fn logging_middleware_calls_callback() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_clone = calls.clone();

        let inner: Arc<dyn AgentTool> = dummy_tool();
        let mw = ToolMiddleware::with_logging(inner, move |_name, _id, _is_start| {
            calls_clone.fetch_add(1, Ordering::SeqCst);
        });

        mw.execute(
            "id",
            json!({}),
            CancellationToken::new(),
            None,
            test_state(),
            None,
        )
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}