1use std::future::Future;
23use std::sync::Arc;
24
25use serde::de::DeserializeOwned;
26use serde_json::Value;
27use tokio_util::sync::CancellationToken;
28
29use crate::tool::{
30 AgentTool, AgentToolResult, ToolFuture, debug_validated_schema, permissive_object_schema,
31 validated_schema_for,
32};
33
34type ExecuteFn = Arc<
37 dyn Fn(
38 String,
39 Value,
40 CancellationToken,
41 Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
42 ) -> ToolFuture<'static>
43 + Send
44 + Sync,
45>;
46
47type ApprovalContextFn = Arc<dyn Fn(&Value) -> Option<Value> + Send + Sync>;
48
49pub struct FnTool {
57 name: String,
58 label: String,
59 description: String,
60 schema: Value,
61 requires_approval: bool,
62 execute_fn: ExecuteFn,
63 approval_context_fn: Option<ApprovalContextFn>,
64}
65
66impl FnTool {
67 #[must_use]
72 pub fn new(
73 name: impl Into<String>,
74 label: impl Into<String>,
75 description: impl Into<String>,
76 ) -> Self {
77 Self {
78 name: name.into(),
79 label: label.into(),
80 description: description.into(),
81 schema: permissive_object_schema(),
82 requires_approval: false,
83 execute_fn: Arc::new(|_, _, _, _| {
84 Box::pin(async { AgentToolResult::error("not implemented") })
85 }),
86 approval_context_fn: None,
87 }
88 }
89
90 #[must_use]
93 pub fn with_schema_for<T: schemars::JsonSchema>(mut self) -> Self {
94 self.schema = validated_schema_for::<T>();
95 self
96 }
97
98 #[must_use]
100 pub fn with_schema(mut self, schema: Value) -> Self {
101 self.schema = debug_validated_schema(schema);
102 self
103 }
104
105 #[must_use]
107 pub const fn with_requires_approval(mut self, requires: bool) -> Self {
108 self.requires_approval = requires;
109 self
110 }
111
112 #[must_use]
116 pub fn with_execute<F, Fut>(mut self, f: F) -> Self
117 where
118 F: Fn(
119 String,
120 Value,
121 CancellationToken,
122 Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
123 ) -> Fut
124 + Send
125 + Sync
126 + 'static,
127 Fut: Future<Output = AgentToolResult> + Send + 'static,
128 {
129 self.execute_fn = Arc::new(move |id, params, cancel, on_update| {
130 Box::pin(f(id, params, cancel, on_update))
131 });
132 self
133 }
134
135 #[must_use]
140 pub fn with_execute_simple<F, Fut>(mut self, f: F) -> Self
141 where
142 F: Fn(Value, CancellationToken) -> Fut + Send + Sync + 'static,
143 Fut: Future<Output = AgentToolResult> + Send + 'static,
144 {
145 self.execute_fn =
146 Arc::new(move |_id, params, cancel, _on_update| Box::pin(f(params, cancel)));
147 self
148 }
149
150 #[must_use]
155 pub fn with_execute_async<F, Fut>(self, f: F) -> Self
156 where
157 F: Fn(Value, CancellationToken) -> Fut + Send + Sync + 'static,
158 Fut: Future<Output = AgentToolResult> + Send + 'static,
159 {
160 self.with_execute_simple(f)
161 }
162
163 #[must_use]
169 pub fn with_execute_typed<T, F, Fut>(mut self, f: F) -> Self
170 where
171 T: DeserializeOwned + schemars::JsonSchema + Send + 'static,
172 F: Fn(T, CancellationToken) -> Fut + Send + Sync + 'static,
173 Fut: Future<Output = AgentToolResult> + Send + 'static,
174 {
175 self.schema = validated_schema_for::<T>();
176 self.execute_fn = Arc::new(move |_id, params, cancel, _on_update| {
177 let parsed: T = match serde_json::from_value(params) {
178 Ok(parsed) => parsed,
179 Err(err) => {
180 return Box::pin(async move {
181 AgentToolResult::error(format!("invalid parameters: {err}"))
182 });
183 }
184 };
185 Box::pin(f(parsed, cancel))
186 });
187 self
188 }
189
190 #[must_use]
195 pub fn with_approval_context<F>(mut self, f: F) -> Self
196 where
197 F: Fn(&Value) -> Option<Value> + Send + Sync + 'static,
198 {
199 self.approval_context_fn = Some(Arc::new(f));
200 self
201 }
202}
203
204impl AgentTool for FnTool {
205 fn name(&self) -> &str {
206 &self.name
207 }
208
209 fn label(&self) -> &str {
210 &self.label
211 }
212
213 fn description(&self) -> &str {
214 &self.description
215 }
216
217 fn parameters_schema(&self) -> &Value {
218 &self.schema
219 }
220
221 fn requires_approval(&self) -> bool {
222 self.requires_approval
223 }
224
225 fn approval_context(&self, params: &Value) -> Option<Value> {
226 self.approval_context_fn.as_ref().and_then(|f| f(params))
227 }
228
229 fn execute(
230 &self,
231 tool_call_id: &str,
232 params: Value,
233 cancellation_token: CancellationToken,
234 on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
235 _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
236 _credential: Option<crate::credential::ResolvedCredential>,
237 ) -> ToolFuture<'_> {
238 let fut = (self.execute_fn)(
239 tool_call_id.to_owned(),
240 params,
241 cancellation_token,
242 on_update,
243 );
244 Box::pin(fut)
245 }
246}
247
248impl std::fmt::Debug for FnTool {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 f.debug_struct("FnTool")
251 .field("name", &self.name)
252 .field("label", &self.label)
253 .field("description", &self.description)
254 .field("requires_approval", &self.requires_approval)
255 .finish_non_exhaustive()
256 }
257}
258
259const _: () = {
262 const fn assert_send_sync<T: Send + Sync>() {}
263 assert_send_sync::<FnTool>();
264};
265
#[cfg(test)]
mod tests {
    use schemars::JsonSchema;
    use serde::Deserialize;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::ContentBlock;

    /// Fresh, empty session state for calls that require one.
    fn test_state() -> std::sync::Arc<std::sync::RwLock<crate::SessionState>> {
        std::sync::Arc::new(std::sync::RwLock::new(crate::SessionState::new()))
    }

    /// Tool with only the constructor defaults applied.
    fn sample_tool() -> FnTool {
        FnTool::new("test", "Test", "A test tool.")
    }

    /// Executes `tool` with a fresh cancellation token, no update callback,
    /// an empty session state, and no credential.
    async fn run(tool: &FnTool, tool_call_id: &str, params: Value) -> AgentToolResult {
        tool.execute(
            tool_call_id,
            params,
            CancellationToken::new(),
            None,
            test_state(),
            None,
        )
        .await
    }

    /// Parameter type shared by the schema and typed-executor tests.
    #[derive(Deserialize, JsonSchema)]
    struct TypedParams {
        city: String,
    }

    /// Tool whose typed executor echoes the `city` field back as text.
    fn typed_tool() -> FnTool {
        FnTool::new("typed", "Typed", "Typed params.").with_execute_typed(
            |params: TypedParams, _cancel| async move { AgentToolResult::text(params.city) },
        )
    }

    /// Asserts that `schema` is an object schema listing `city` as required.
    fn assert_requires_city(schema: &Value) {
        assert_eq!(schema["type"], "object");
        assert!(
            schema["required"]
                .as_array()
                .expect("schema should list required properties")
                .contains(&json!("city"))
        );
    }

    #[test]
    fn metadata_matches_constructor() {
        let tool = sample_tool();
        assert_eq!(tool.name(), "test");
        assert_eq!(tool.label(), "Test");
        assert_eq!(tool.description(), "A test tool.");
        assert!(!tool.requires_approval());
    }

    #[tokio::test]
    async fn default_execute_returns_error() {
        let result = run(&sample_tool(), "id", json!({})).await;
        assert!(result.is_error);
        assert!(
            ContentBlock::extract_text(&result.content).contains("not implemented"),
            "expected placeholder error, got: {:?}",
            result.content
        );
    }

    #[tokio::test]
    async fn simple_execute_receives_params() {
        let tool = FnTool::new("echo", "Echo", "Echo params.").with_execute_simple(
            |params, _cancel| async move {
                let msg = params["msg"].as_str().unwrap_or("none").to_owned();
                AgentToolResult::text(msg)
            },
        );

        let result = run(&tool, "id", json!({"msg": "hello"})).await;
        assert!(!result.is_error);
        assert_eq!(result.content.len(), 1);
        // Check the payload itself, not just that some content came back.
        assert_eq!(ContentBlock::extract_text(&result.content), "hello");
    }

    #[tokio::test]
    async fn async_execute_receives_params() {
        let tool = FnTool::new("echo", "Echo", "Echo params.").with_execute_async(
            |params, _cancel| async move {
                let msg = params["msg"].as_str().unwrap_or("none").to_owned();
                AgentToolResult::text(msg)
            },
        );

        let result = run(&tool, "id", json!({"msg": "hello"})).await;
        assert!(!result.is_error);
        assert_eq!(ContentBlock::extract_text(&result.content), "hello");
    }

    #[test]
    fn with_schema_for_sets_schema() {
        let tool = sample_tool().with_schema_for::<TypedParams>();
        assert_requires_city(tool.parameters_schema());
    }

    #[test]
    fn approval_flag_is_configurable() {
        let tool = sample_tool().with_requires_approval(true);
        assert!(tool.requires_approval());
    }

    #[test]
    fn approval_context_defaults_to_none_and_uses_callback() {
        let params = json!({"path": "/tmp/x"});
        assert!(sample_tool().approval_context(&params).is_none());

        let tool = sample_tool()
            .with_approval_context(|params| Some(json!({"target": params["path"].clone()})));
        assert_eq!(
            tool.approval_context(&params),
            Some(json!({"target": "/tmp/x"}))
        );
    }

    #[tokio::test]
    async fn full_execute_receives_all_args() {
        let tool = FnTool::new("full", "Full", "Full signature.").with_execute(
            |id, _params, _cancel, _on_update| async move {
                AgentToolResult::text(format!("id={id}"))
            },
        );

        let result = run(&tool, "call_42", json!({})).await;
        assert!(!result.is_error);
        // Prove the tool-call id actually reached the closure.
        assert_eq!(ContentBlock::extract_text(&result.content), "id=call_42");
    }

    #[tokio::test]
    async fn typed_execute_deserializes_params_and_sets_schema() {
        let tool = typed_tool();
        assert_requires_city(tool.parameters_schema());

        let result = run(&tool, "id", json!({"city": "Chicago"})).await;
        assert!(!result.is_error);
        assert_eq!(ContentBlock::extract_text(&result.content), "Chicago");
    }

    #[tokio::test]
    async fn typed_execute_reports_deserialization_errors() {
        let result = run(&typed_tool(), "id", json!({"city": 42})).await;
        assert!(result.is_error);
        assert!(
            ContentBlock::extract_text(&result.content).contains("invalid parameters"),
            "expected invalid parameters error, got: {:?}",
            result.content
        );
    }
}