Skip to main content

github_copilot_sdk/
tool.rs

1//! Typed tool definition framework.
2//!
3//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for implementing tools as named types,
4//! and [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) for automatic dispatch of tool calls within a
5//! [`SessionHandler`](crate::handler::SessionHandler).
6//!
//! Enable the `derive` feature for `schema_for`, `define_tool`, and the
//! `JsonSchema` re-export; these generate JSON Schema from Rust types via `schemars`.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14/// Re-export of [`schemars::JsonSchema`] for deriving tool parameter schemas.
15#[cfg(feature = "derive")]
16pub use schemars::JsonSchema;
17
18use crate::Error;
19use crate::handler::{ExitPlanModeResult, PermissionResult, SessionHandler, UserInputResponse};
20use crate::types::{
21    ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId,
22    SessionEvent, SessionId, Tool, ToolInvocation, ToolResult, ToolResultExpanded,
23};
24
/// Build a JSON Schema [`Value`](serde_json::Value) describing the Rust type `T`.
///
/// The schema comes from `schemars` and is then trimmed: the root-level
/// `$schema` and `title` keys are removed, so the result can be passed
/// straight to [`tool_parameters`] to fill in [`Tool::parameters`].
///
/// # Panics
///
/// Panics only if the generated schema cannot be serialized to JSON, which
/// is not expected to happen for `schemars` output.
///
/// # Example
///
/// ```rust
/// use github_copilot_sdk::tool::{schema_for, JsonSchema};
///
/// #[derive(JsonSchema)]
/// struct Params {
///     /// City name
///     city: String,
/// }
///
/// let schema = schema_for::<Params>();
/// assert_eq!(schema["type"], "object");
/// assert!(schema["properties"]["city"].is_object());
/// ```
#[cfg(feature = "derive")]
pub fn schema_for<T: schemars::JsonSchema>() -> serde_json::Value {
    // Root-level metadata keys that tool parameter schemas do not carry.
    const STRIPPED_ROOT_KEYS: [&str; 2] = ["$schema", "title"];

    let root = schemars::schema_for!(T);
    let mut json = serde_json::to_value(root).expect("JSON Schema serialization cannot fail");
    if let serde_json::Value::Object(map) = &mut json {
        for key in STRIPPED_ROOT_KEYS {
            map.remove(key);
        }
    }
    json
}
55
56/// Convert a JSON Schema [`Value`](serde_json::Value) into the
57/// [`Tool::parameters`] map shape expected by the protocol.
58///
59/// Panics if the input is not a JSON object — tool parameter schemas
60/// are always top-level objects (`{"type": "object", ...}`). Pair with
61/// [`schema_for`] or a `serde_json::json!(...)` literal.
62///
63/// Use [`try_tool_parameters`] when the schema comes from dynamic input and
64/// should return a recoverable error instead of panicking.
65///
66/// # Example
67///
68/// ```rust
69/// use github_copilot_sdk::tool::tool_parameters;
70/// use github_copilot_sdk::Tool;
71///
72/// let mut tool = Tool::default();
73/// tool.name = "ping".to_string();
74/// tool.description = "ping the server".to_string();
75/// tool.parameters = tool_parameters(serde_json::json!({"type": "object"}));
76/// # let _ = tool;
77/// ```
78pub fn tool_parameters(schema: serde_json::Value) -> HashMap<String, serde_json::Value> {
79    try_tool_parameters(schema).expect("tool parameter schema must be a JSON object")
80}
81
82/// Fallible variant of [`tool_parameters`] for callers handling dynamic schema input.
83pub fn try_tool_parameters(
84    schema: serde_json::Value,
85) -> Result<HashMap<String, serde_json::Value>, serde_json::Error> {
86    serde_json::from_value(schema)
87}
88
89/// A client-defined tool with its handler logic.
90///
91/// Implement this trait for each tool you expose to the Copilot agent.
92/// The struct is a named type — visible in stack traces and navigable
93/// via "go to definition" — unlike closure-based alternatives.
94///
95/// # Example
96///
97/// ```rust,ignore
98/// use github_copilot_sdk::tool::{schema_for, tool_parameters, JsonSchema, ToolHandler};
99/// use github_copilot_sdk::{Error, Tool, ToolInvocation, ToolResult};
100/// use serde::Deserialize;
101/// use async_trait::async_trait;
102///
103/// #[derive(Deserialize, JsonSchema)]
104/// struct GetWeatherParams {
105///     /// City name
106///     city: String,
107///     /// Temperature unit
108///     unit: Option<String>,
109/// }
110///
111/// struct GetWeatherTool;
112///
113/// #[async_trait]
114/// impl ToolHandler for GetWeatherTool {
115///     fn tool(&self) -> Tool {
116///         Tool {
117///             name: "get_weather".to_string(),
118///             namespaced_name: None,
119///             description: "Get weather for a city".to_string(),
120///             parameters: tool_parameters(schema_for::<GetWeatherParams>()),
121///             instructions: None,
122///             ..Default::default()
123///         }
124///     }
125///
126///     async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
127///         let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
128///         Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city)))
129///     }
130/// }
131/// ```
132#[async_trait]
133pub trait ToolHandler: Send + Sync {
134    /// The tool definition sent to the CLI during session creation.
135    fn tool(&self) -> Tool;
136
137    /// Handle a tool invocation from the agent.
138    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error>;
139}
140
/// Turn an async function (or closure) taking a typed, `JsonSchema`-derived
/// parameter struct into a [`ToolHandler`].
///
/// The resulting `Box<dyn ToolHandler>` can be handed directly to
/// [`ToolHandlerRouter::new`]. The parameter type's JSON Schema is computed
/// once, via [`schema_for`], when this function is called.
///
/// The handler bound (`Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static`)
/// accepts bare `async fn` items as well as closures, mirroring
/// [`tower::service_fn`][tower-service-fn] and
/// [`hyper::service::service_fn`][hyper-service-fn]. For anything beyond a
/// trivial body, prefer a free `async fn` so it is named in stack traces.
///
/// Besides the deserialized parameters, the handler is given the whole
/// [`ToolInvocation`], so `inv.session_id`, `inv.tool_call_id`, and the other
/// invocation metadata stay available. Handlers that do not need them can
/// ignore the first argument with `|_inv, params|`.
///
/// # Example
///
/// ```rust,no_run
/// use github_copilot_sdk::tool::{define_tool, JsonSchema};
/// use github_copilot_sdk::types::ToolInvocation;
/// use github_copilot_sdk::{Error, ToolResult};
/// use serde::Deserialize;
///
/// #[derive(Deserialize, JsonSchema)]
/// struct GetWeatherParams {
///     /// City name
///     city: String,
/// }
///
/// async fn get_weather(
///     inv: ToolInvocation,
///     params: GetWeatherParams,
/// ) -> Result<ToolResult, Error> {
///     // `inv.session_id` and `inv.tool_call_id` are available for telemetry,
///     // streaming updates, scoping DB lookups, etc.
///     let _ = inv.session_id;
///     Ok(ToolResult::Text(format!("Sunny in {}", params.city)))
/// }
///
/// // A free async fn — the preferred form for non-trivial tools.
/// let tool = define_tool("get_weather", "Get weather for a city", get_weather);
///
/// // An inline closure works too when the body is tiny.
/// let tool = define_tool(
///     "echo",
///     "Echo the input",
///     |_inv, params: GetWeatherParams| async move {
///         Ok(ToolResult::Text(params.city))
///     },
/// );
/// # let _ = tool;
/// ```
///
/// [tower-service-fn]: https://docs.rs/tower/latest/tower/fn.service_fn.html
/// [hyper-service-fn]: https://docs.rs/hyper/latest/hyper/service/fn.service_fn.html
#[cfg(feature = "derive")]
pub fn define_tool<P, F, Fut>(
    name: impl Into<String>,
    description: impl Into<String>,
    handler: F,
) -> Box<dyn ToolHandler>
where
    P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
    F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
{
    /// Adapter that exposes a typed handler function as a [`ToolHandler`].
    struct TypedFnTool<P, F> {
        definition_name: String,
        definition_description: String,
        schema: HashMap<String, serde_json::Value>,
        func: F,
        // No `P` is ever stored; `fn(P)` records the type while keeping the
        // adapter `Send + Sync` regardless of `P`.
        _params: std::marker::PhantomData<fn(P)>,
    }

    #[async_trait]
    impl<P, F, Fut> ToolHandler for TypedFnTool<P, F>
    where
        P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
        F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
    {
        fn tool(&self) -> Tool {
            Tool {
                name: self.definition_name.clone(),
                description: self.definition_description.clone(),
                parameters: self.schema.clone(),
                ..Default::default()
            }
        }

        async fn call(&self, mut invocation: ToolInvocation) -> Result<ToolResult, Error> {
            // Move the raw arguments out (leaving `Null`) so the rest of the
            // invocation can still be handed to the handler without a clone.
            let raw_args = std::mem::take(&mut invocation.arguments);
            let typed = serde_json::from_value::<P>(raw_args)?;
            (self.func)(invocation, typed).await
        }
    }

    let definition_name = name.into();
    let definition_description = description.into();
    let schema = tool_parameters(schema_for::<P>());
    Box::new(TypedFnTool {
        definition_name,
        definition_description,
        schema,
        func: handler,
        _params: std::marker::PhantomData,
    })
}
249
250/// A [`SessionHandler`] that dispatches tool calls to registered
251/// [`ToolHandler`] implementations by name.
252///
253/// For tool calls matching a registered handler, the handler is invoked
254/// directly. All other events (permissions, user input, unrecognized tools)
255/// are forwarded to the inner handler.
256///
257/// # Example
258///
259/// ```rust,no_run
260/// use std::sync::Arc;
261/// use github_copilot_sdk::handler::ApproveAllHandler;
262/// use github_copilot_sdk::tool::ToolHandlerRouter;
263///
264/// let router = ToolHandlerRouter::new(
265///     vec![/* Box::new(MyTool), ... */],
266///     Arc::new(ApproveAllHandler),
267/// );
268///
269/// // Use router.tools() in SessionConfig
270/// // Use Arc::new(router) as the session handler
271/// ```
272pub struct ToolHandlerRouter {
273    handlers: HashMap<String, Box<dyn ToolHandler>>,
274    inner: Arc<dyn SessionHandler>,
275}
276
277impl std::fmt::Debug for ToolHandlerRouter {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        let mut tools: Vec<_> = self.handlers.keys().collect();
280        tools.sort();
281        f.debug_struct("ToolHandlerRouter")
282            .field("tool_count", &self.handlers.len())
283            .field("tools", &tools)
284            .finish()
285    }
286}
287
288impl ToolHandlerRouter {
289    /// Create a router from tool handler impls and a fallback handler.
290    ///
291    /// Call [`tools()`](Self::tools) to get the tool definitions for
292    /// [`SessionConfig::tools`](crate::SessionConfig::tools).
293    pub fn new(tools: Vec<Box<dyn ToolHandler>>, inner: Arc<dyn SessionHandler>) -> Self {
294        let mut handlers = HashMap::new();
295        for tool in tools {
296            handlers.insert(tool.tool().name.clone(), tool);
297        }
298        Self { handlers, inner }
299    }
300
301    /// Tool definitions for [`SessionConfig::tools`](crate::SessionConfig::tools).
302    pub fn tools(&self) -> Vec<Tool> {
303        self.handlers.values().map(|h| h.tool()).collect()
304    }
305}
306
307#[async_trait]
308impl SessionHandler for ToolHandlerRouter {
309    async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult {
310        let Some(handler) = self.handlers.get(&invocation.tool_name) else {
311            return self.inner.on_external_tool(invocation).await;
312        };
313        match handler.call(invocation).await {
314            Ok(result) => result,
315            Err(e) => {
316                let msg = e.to_string();
317                ToolResult::Expanded(ToolResultExpanded {
318                    text_result_for_llm: msg.clone(),
319                    result_type: "failure".to_string(),
320                    session_log: None,
321                    error: Some(msg),
322                })
323            }
324        }
325    }
326
327    async fn on_session_event(&self, session_id: SessionId, event: SessionEvent) {
328        self.inner.on_session_event(session_id, event).await
329    }
330
331    async fn on_permission_request(
332        &self,
333        session_id: SessionId,
334        request_id: RequestId,
335        data: PermissionRequestData,
336    ) -> PermissionResult {
337        self.inner
338            .on_permission_request(session_id, request_id, data)
339            .await
340    }
341
342    async fn on_user_input(
343        &self,
344        session_id: SessionId,
345        question: String,
346        choices: Option<Vec<String>>,
347        allow_freeform: Option<bool>,
348    ) -> Option<UserInputResponse> {
349        self.inner
350            .on_user_input(session_id, question, choices, allow_freeform)
351            .await
352    }
353
354    async fn on_elicitation(
355        &self,
356        session_id: SessionId,
357        request_id: RequestId,
358        request: ElicitationRequest,
359    ) -> ElicitationResult {
360        self.inner
361            .on_elicitation(session_id, request_id, request)
362            .await
363    }
364
365    async fn on_exit_plan_mode(
366        &self,
367        session_id: SessionId,
368        data: ExitPlanModeData,
369    ) -> ExitPlanModeResult {
370        self.inner.on_exit_plan_mode(session_id, data).await
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{PermissionRequestData, RequestId, SessionId};

    /// Builds an invocation of `tool_name` in session `s1` with call id `tc1`.
    fn invocation(tool_name: &str, arguments: serde_json::Value) -> ToolInvocation {
        ToolInvocation {
            session_id: SessionId::from("s1"),
            tool_call_id: "tc1".to_string(),
            tool_name: tool_name.to_string(),
            arguments,
            traceparent: None,
            tracestate: None,
        }
    }

    /// Tool definition with an empty parameter map.
    fn bare_tool(name: &str, description: &str) -> Tool {
        Tool {
            name: name.to_string(),
            namespaced_name: None,
            description: description.to_string(),
            parameters: HashMap::new(),
            instructions: None,
            ..Default::default()
        }
    }

    /// Echoes its raw JSON arguments back as text.
    struct EchoTool;

    #[async_trait]
    impl ToolHandler for EchoTool {
        fn tool(&self) -> Tool {
            Tool {
                parameters: tool_parameters(serde_json::json!({"type": "object"})),
                ..bare_tool("echo", "Echo the input")
            }
        }

        async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
            Ok(ToolResult::Text(inv.arguments.to_string()))
        }
    }

    #[test]
    fn tool_handler_returns_tool_definition() {
        let def = EchoTool.tool();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo the input");
        assert!(def.parameters.contains_key("type"));
    }

    #[test]
    fn try_tool_parameters_rejects_non_object_schema() {
        let err = try_tool_parameters(serde_json::json!(["not", "an", "object"]))
            .expect_err("non-object schemas should be rejected");

        assert!(err.is_data());
    }

    #[tokio::test]
    async fn tool_handler_call_returns_result() {
        let result = EchoTool
            .call(invocation("echo", serde_json::json!({"msg": "hello"})))
            .await
            .unwrap();
        let ToolResult::Text(s) = result else {
            panic!("expected Text result");
        };
        assert!(s.contains("hello"));
    }

    #[cfg(feature = "derive")]
    #[tokio::test]
    async fn define_tool_builds_schema_and_dispatches() {
        use serde::Deserialize;

        #[derive(Deserialize, schemars::JsonSchema)]
        struct Params {
            city: String,
        }

        let tool = define_tool(
            "weather",
            "Get the weather for a city",
            |_inv, params: Params| async move {
                Ok(ToolResult::Text(format!("sunny in {}", params.city)))
            },
        );

        let def = tool.tool();
        assert_eq!(def.name, "weather");
        assert_eq!(def.description, "Get the weather for a city");
        assert_eq!(def.parameters["type"], "object");
        assert!(def.parameters["properties"]["city"].is_object());

        let outcome = tool
            .call(invocation("weather", serde_json::json!({"city": "Seattle"})))
            .await
            .unwrap();
        let ToolResult::Text(s) = outcome else {
            panic!("expected Text result");
        };
        assert_eq!(s, "sunny in Seattle");
    }

    #[tokio::test]
    async fn router_dispatches_to_correct_handler() {
        struct ToolA;
        #[async_trait]
        impl ToolHandler for ToolA {
            fn tool(&self) -> Tool {
                bare_tool("tool_a", "A")
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Ok(ToolResult::Text("a_result".to_string()))
            }
        }

        struct ToolB;
        #[async_trait]
        impl ToolHandler for ToolB {
            fn tool(&self) -> Tool {
                bare_tool("tool_b", "B")
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Ok(ToolResult::Text("b_result".to_string()))
            }
        }

        let router = ToolHandlerRouter::new(
            vec![Box::new(ToolA), Box::new(ToolB)],
            Arc::new(crate::handler::ApproveAllHandler),
        );

        assert_eq!(router.tools().len(), 2);

        let response = router
            .on_external_tool(invocation("tool_b", serde_json::json!({})))
            .await;
        let ToolResult::Text(s) = response else {
            panic!("expected ToolResult::Text");
        };
        assert_eq!(s, "b_result");
    }

    #[tokio::test]
    async fn router_falls_through_for_unknown_tool() {
        use std::sync::atomic::{AtomicBool, Ordering};

        struct FallbackHandler {
            called: AtomicBool,
        }
        #[async_trait]
        impl SessionHandler for FallbackHandler {
            async fn on_external_tool(&self, _inv: ToolInvocation) -> ToolResult {
                self.called.store(true, Ordering::Relaxed);
                ToolResult::Text("fallback".to_string())
            }
        }

        let fallback = Arc::new(FallbackHandler {
            called: AtomicBool::new(false),
        });
        let router = ToolHandlerRouter::new(vec![], fallback.clone());

        let response = router
            .on_external_tool(invocation("unknown", serde_json::json!({})))
            .await;
        assert!(fallback.called.load(Ordering::Relaxed));
        let ToolResult::Text(s) = response else {
            panic!("expected fallback result");
        };
        assert_eq!(s, "fallback");
    }

    #[tokio::test]
    async fn router_returns_failure_on_handler_error() {
        struct FailTool;
        #[async_trait]
        impl ToolHandler for FailTool {
            fn tool(&self) -> Tool {
                bare_tool("bad_tool", "Always fails")
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Err(Error::Rpc {
                    code: -1,
                    message: "intentional failure".to_string(),
                })
            }
        }

        let router = ToolHandlerRouter::new(
            vec![Box::new(FailTool)],
            Arc::new(crate::handler::ApproveAllHandler),
        );

        let response = router
            .on_external_tool(invocation("bad_tool", serde_json::json!({})))
            .await;
        let ToolResult::Expanded(exp) = response else {
            panic!("expected expanded failure result");
        };
        assert_eq!(exp.result_type, "failure");
        assert!(exp.error.unwrap().contains("intentional failure"));
    }

    #[tokio::test]
    async fn router_forwards_non_tool_events() {
        struct PermHandler;
        #[async_trait]
        impl SessionHandler for PermHandler {
            async fn on_permission_request(
                &self,
                _session_id: SessionId,
                _request_id: RequestId,
                _data: PermissionRequestData,
            ) -> PermissionResult {
                PermissionResult::Denied
            }
        }

        let router = ToolHandlerRouter::new(vec![], Arc::new(PermHandler));

        let data = PermissionRequestData {
            extra: serde_json::json!({}),
            ..Default::default()
        };
        let response = router
            .on_permission_request(SessionId::from("s1"), RequestId::new("r1"), data)
            .await;
        assert!(matches!(response, PermissionResult::Denied));
    }

    #[tokio::test]
    async fn router_default_on_event_dispatches_via_per_event_methods() {
        // Regression: callers using the legacy on_event entry point should
        // still get correct dispatch through the inherited default impl.
        use crate::handler::{HandlerEvent, HandlerResponse};

        struct OkTool;
        #[async_trait]
        impl ToolHandler for OkTool {
            fn tool(&self) -> Tool {
                bare_tool("ok_tool", "ok")
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Ok(ToolResult::Text("ok".to_string()))
            }
        }

        let router = ToolHandlerRouter::new(
            vec![Box::new(OkTool)],
            Arc::new(crate::handler::ApproveAllHandler),
        );

        let response = router
            .on_event(HandlerEvent::ExternalTool {
                invocation: invocation("ok_tool", serde_json::json!({})),
            })
            .await;
        let HandlerResponse::ToolResult(ToolResult::Text(s)) = response else {
            panic!("expected ToolResult via default on_event");
        };
        assert_eq!(s, "ok");
    }

    // Tests requiring `schemars` (the `derive` feature).
    #[cfg(feature = "derive")]
    mod derive_tests {
        use serde::Deserialize;

        use super::super::*;
        use super::invocation;

        #[derive(Deserialize, schemars::JsonSchema)]
        struct GetWeatherParams {
            /// City name to get weather for.
            city: String,
            /// Temperature unit (celsius or fahrenheit).
            unit: Option<String>,
        }

        #[test]
        fn schema_for_generates_clean_schema() {
            let schema = schema_for::<GetWeatherParams>();
            assert_eq!(schema["type"], "object");
            assert!(schema["properties"]["city"].is_object());
            assert!(schema["properties"]["unit"].is_object());
            // `city` is non-Option and therefore required; `unit` is optional.
            let required = schema["required"].as_array().unwrap();
            assert!(required.contains(&serde_json::json!("city")));
            assert!(!required.contains(&serde_json::json!("unit")));
            // Root-level metadata must be stripped.
            assert!(schema.get("$schema").is_none());
            assert!(schema.get("title").is_none());
        }

        struct GetWeatherTool;

        #[async_trait]
        impl ToolHandler for GetWeatherTool {
            fn tool(&self) -> Tool {
                Tool {
                    name: "get_weather".to_string(),
                    namespaced_name: None,
                    description: "Get weather for a city".to_string(),
                    parameters: tool_parameters(schema_for::<GetWeatherParams>()),
                    instructions: None,
                    ..Default::default()
                }
            }

            async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
                let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
                let unit = params.unit.unwrap_or_default();
                Ok(ToolResult::Text(format!("{} {}", params.city, unit)))
            }
        }

        #[test]
        fn tool_handler_with_schema_for() {
            let def = GetWeatherTool.tool();
            assert_eq!(def.name, "get_weather");
            let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters");
            assert_eq!(schema["type"], "object");
            assert!(schema["properties"]["city"].is_object());
        }

        #[tokio::test]
        async fn tool_handler_deserializes_typed_params() {
            let args = serde_json::json!({"city": "Seattle", "unit": "celsius"});
            let result = GetWeatherTool
                .call(invocation("get_weather", args))
                .await
                .unwrap();
            let ToolResult::Text(s) = result else {
                panic!("expected Text result");
            };
            assert_eq!(s, "Seattle celsius");
        }

        #[tokio::test]
        async fn tool_handler_returns_error_on_bad_params() {
            let args = serde_json::json!({"wrong_field": 42});
            let err = GetWeatherTool
                .call(invocation("get_weather", args))
                .await
                .unwrap_err();
            assert!(matches!(err, Error::Json(_)));
        }

        #[tokio::test]
        async fn router_with_schema_for_tools() {
            let router = ToolHandlerRouter::new(
                vec![Box::new(GetWeatherTool)],
                Arc::new(crate::handler::ApproveAllHandler),
            );

            let tools = router.tools();
            assert_eq!(tools.len(), 1);
            assert_eq!(tools[0].name, "get_weather");

            let response = router
                .on_external_tool(invocation(
                    "get_weather",
                    serde_json::json!({"city": "Portland"}),
                ))
                .await;
            let ToolResult::Text(s) = response else {
                panic!("expected ToolResult::Text");
            };
            assert!(s.contains("Portland"));
        }
    }
}