Skip to main content

github_copilot_sdk/
tool.rs

1//! Typed tool definition framework.
2//!
3//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for implementing tools as named types,
4//! and [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) for automatic dispatch of tool calls within a
5//! [`SessionHandler`](crate::handler::SessionHandler).
6//!
//! Enable the `derive` feature for `schema_for` (JSON Schema generation from
//! Rust types via `schemars`), `define_tool`, and the re-exported `JsonSchema`.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14/// Re-export of [`schemars::JsonSchema`] for deriving tool parameter schemas.
15#[cfg(feature = "derive")]
16pub use schemars::JsonSchema;
17
18use crate::Error;
19use crate::handler::{PermissionResult, SessionHandler, UserInputResponse};
20use crate::types::{
21    ElicitationRequest, ElicitationResult, PermissionRequestData, RequestId, SessionEvent,
22    SessionId, Tool, ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded,
23};
24
/// Generate a JSON Schema [`Value`](serde_json::Value) for the Rust type `T`.
///
/// The schema is produced by `schemars::schema_for!`. The root-level `$schema`
/// and `title` metadata keys are then removed, so the result can be handed
/// directly to [`tool_parameters`] and used as [`Tool::parameters`].
///
/// # Example
///
/// ```rust
/// use github_copilot_sdk::tool::{schema_for, JsonSchema};
///
/// #[derive(JsonSchema)]
/// struct Params {
///     /// City name
///     city: String,
/// }
///
/// let schema = schema_for::<Params>();
/// assert_eq!(schema["type"], "object");
/// assert!(schema["properties"]["city"].is_object());
/// ```
#[cfg(feature = "derive")]
pub fn schema_for<T: schemars::JsonSchema>() -> serde_json::Value {
    let root = schemars::schema_for!(T);
    let mut json = serde_json::to_value(root).expect("JSON Schema serialization cannot fail");
    // Only the top level carries this metadata; nested schemas are left alone.
    if let serde_json::Value::Object(fields) = &mut json {
        for metadata_key in ["$schema", "title"] {
            fields.remove(metadata_key);
        }
    }
    json
}
55
56/// Convert a JSON Schema [`Value`](serde_json::Value) into the
57/// [`Tool::parameters`] map shape expected by the protocol.
58///
59/// Panics if the input is not a JSON object — tool parameter schemas
60/// are always top-level objects (`{"type": "object", ...}`). Pair with
61/// [`schema_for`] or a `serde_json::json!(...)` literal.
62///
63/// Use [`try_tool_parameters`] when the schema comes from dynamic input and
64/// should return a recoverable error instead of panicking.
65///
66/// # Example
67///
68/// ```rust
69/// use github_copilot_sdk::tool::tool_parameters;
70/// use github_copilot_sdk::Tool;
71///
72/// let mut tool = Tool::default();
73/// tool.name = "ping".to_string();
74/// tool.description = "ping the server".to_string();
75/// tool.parameters = tool_parameters(serde_json::json!({"type": "object"}));
76/// # let _ = tool;
77/// ```
78pub fn tool_parameters(schema: serde_json::Value) -> HashMap<String, serde_json::Value> {
79    try_tool_parameters(schema).expect("tool parameter schema must be a JSON object")
80}
81
82/// Fallible variant of [`tool_parameters`] for callers handling dynamic schema input.
83pub fn try_tool_parameters(
84    schema: serde_json::Value,
85) -> Result<HashMap<String, serde_json::Value>, serde_json::Error> {
86    serde_json::from_value(schema)
87}
88
89/// Convert an MCP `CallToolResult` JSON value into a Copilot tool result.
90///
91/// Returns `None` when the value is not shaped like a `CallToolResult`.
92pub fn convert_mcp_call_tool_result(value: &serde_json::Value) -> Option<ToolResult> {
93    let content = value.get("content")?.as_array()?;
94    let mut text_parts = Vec::new();
95    let mut binary_results = Vec::new();
96
97    for block in content {
98        match block.get("type").and_then(serde_json::Value::as_str) {
99            Some("text") => {
100                if let Some(text) = block.get("text").and_then(serde_json::Value::as_str) {
101                    text_parts.push(text.to_string());
102                }
103            }
104            Some("image") => {
105                let data = block
106                    .get("data")
107                    .and_then(serde_json::Value::as_str)
108                    .filter(|s| !s.is_empty());
109                let mime_type = block
110                    .get("mimeType")
111                    .and_then(serde_json::Value::as_str)
112                    .filter(|s| !s.is_empty());
113                if let (Some(data), Some(mime_type)) = (data, mime_type) {
114                    binary_results.push(ToolBinaryResult {
115                        data: data.to_string(),
116                        mime_type: mime_type.to_string(),
117                        r#type: "image".to_string(),
118                        description: None,
119                    });
120                }
121            }
122            Some("resource") => {
123                let Some(resource) = block.get("resource").and_then(serde_json::Value::as_object)
124                else {
125                    continue;
126                };
127                if let Some(text) = resource
128                    .get("text")
129                    .and_then(serde_json::Value::as_str)
130                    .filter(|s| !s.is_empty())
131                {
132                    text_parts.push(text.to_string());
133                }
134                if let Some(blob) = resource
135                    .get("blob")
136                    .and_then(serde_json::Value::as_str)
137                    .filter(|s| !s.is_empty())
138                {
139                    let mime_type = resource
140                        .get("mimeType")
141                        .and_then(serde_json::Value::as_str)
142                        .filter(|s| !s.is_empty())
143                        .unwrap_or("application/octet-stream");
144                    let description = resource
145                        .get("uri")
146                        .and_then(serde_json::Value::as_str)
147                        .filter(|s| !s.is_empty())
148                        .map(ToString::to_string);
149                    binary_results.push(ToolBinaryResult {
150                        data: blob.to_string(),
151                        mime_type: mime_type.to_string(),
152                        r#type: "resource".to_string(),
153                        description,
154                    });
155                }
156            }
157            _ => {}
158        }
159    }
160
161    Some(ToolResult::Expanded(ToolResultExpanded {
162        text_result_for_llm: text_parts.join("\n"),
163        result_type: if value.get("isError").and_then(serde_json::Value::as_bool) == Some(true) {
164            "failure".to_string()
165        } else {
166            "success".to_string()
167        },
168        binary_results_for_llm: (!binary_results.is_empty()).then_some(binary_results),
169        session_log: None,
170        error: None,
171        tool_telemetry: None,
172    }))
173}
174
/// A client-defined tool with its handler logic.
///
/// Implement this trait for each tool you expose to the Copilot agent.
/// The struct is a named type, so it is visible in stack traces and navigable
/// via "go to definition", unlike closure-based alternatives.
///
/// Register implementations with [`ToolHandlerRouter::new`]. The router keys
/// each handler by the `name` of the [`Tool`] returned from
/// [`tool`](Self::tool) and sends invocations with a matching tool name to
/// [`call`](Self::call).
///
/// # Example
///
/// ```rust,ignore
/// use github_copilot_sdk::tool::{schema_for, tool_parameters, JsonSchema, ToolHandler};
/// use github_copilot_sdk::{Error, Tool, ToolInvocation, ToolResult};
/// use serde::Deserialize;
/// use async_trait::async_trait;
///
/// #[derive(Deserialize, JsonSchema)]
/// struct GetWeatherParams {
///     /// City name
///     city: String,
///     /// Temperature unit
///     unit: Option<String>,
/// }
///
/// struct GetWeatherTool;
///
/// #[async_trait]
/// impl ToolHandler for GetWeatherTool {
///     fn tool(&self) -> Tool {
///         Tool {
///             name: "get_weather".to_string(),
///             namespaced_name: None,
///             description: "Get weather for a city".to_string(),
///             parameters: tool_parameters(schema_for::<GetWeatherParams>()),
///             instructions: None,
///             ..Default::default()
///         }
///     }
///
///     async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
///         let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
///         Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city)))
///     }
/// }
/// ```
#[async_trait]
pub trait ToolHandler: Send + Sync {
    /// The tool definition sent to the CLI during session creation.
    ///
    /// [`ToolHandlerRouter`] calls this both when registering the handler (to
    /// read its `name`) and again when listing tool definitions. It should
    /// therefore return the same definition, and in particular the same
    /// `name`, every time.
    fn tool(&self) -> Tool;

    /// Handle a tool invocation from the agent.
    ///
    /// # Errors
    ///
    /// When dispatched through [`ToolHandlerRouter`], a returned [`Error`] is
    /// not propagated. It is converted into a `"failure"` [`ToolResult`]
    /// carrying the error's `Display` text.
    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error>;
}
226
/// Build a [`ToolHandler`] from an async function (or closure) that takes a
/// typed parameter struct deriving `JsonSchema` and `Deserialize`.
///
/// The result is a `Box<dyn ToolHandler>` ready to pass to
/// [`ToolHandlerRouter::new`]. The JSON Schema for `P` is generated once, when
/// `define_tool` is called, via [`schema_for`]. That same schema is returned
/// from every [`ToolHandler::tool`] call.
///
/// The handler bound, `Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static`,
/// is satisfied by plain `async fn` items as well as closures. This is the same
/// shape used by [`tower::service_fn`][tower-service-fn] and
/// [`hyper::service::service_fn`][hyper-service-fn]. For anything beyond a
/// trivial body, prefer a named `async fn` so it appears by name in stack
/// traces.
///
/// On each call, the invocation's `arguments` are deserialized into `P`. If
/// deserialization fails, the error is returned without running the handler.
/// Otherwise the handler receives the [`ToolInvocation`] together with the
/// typed parameters, so it can use `inv.session_id`, `inv.tool_call_id`, and
/// other invocation metadata. Note that by then `inv.arguments` has been moved
/// out and replaced with `serde_json::Value::Null`; read the typed parameters
/// instead. Handlers that need no metadata can ignore it with `|_inv, params|`.
///
/// # Panics
///
/// Panics if the schema generated for `P` is not a JSON object (see
/// [`tool_parameters`]).
///
/// # Example
///
/// ```rust,no_run
/// use github_copilot_sdk::tool::{define_tool, JsonSchema};
/// use github_copilot_sdk::types::ToolInvocation;
/// use github_copilot_sdk::{Error, ToolResult};
/// use serde::Deserialize;
///
/// #[derive(Deserialize, JsonSchema)]
/// struct GetWeatherParams {
///     /// City name
///     city: String,
/// }
///
/// async fn get_weather(
///     inv: ToolInvocation,
///     params: GetWeatherParams,
/// ) -> Result<ToolResult, Error> {
///     // `inv.session_id` and `inv.tool_call_id` are available for telemetry,
///     // streaming updates, scoping DB lookups, etc.
///     let _ = inv.session_id;
///     Ok(ToolResult::Text(format!("Sunny in {}", params.city)))
/// }
///
/// // Pass a free async fn — preferred for non-trivial tools.
/// let tool = define_tool("get_weather", "Get weather for a city", get_weather);
///
/// // ...or an inline closure when the body is trivial.
/// let tool = define_tool(
///     "echo",
///     "Echo the input",
///     |_inv, params: GetWeatherParams| async move {
///         Ok(ToolResult::Text(params.city))
///     },
/// );
/// # let _ = tool;
/// ```
///
/// [tower-service-fn]: https://docs.rs/tower/latest/tower/fn.service_fn.html
/// [hyper-service-fn]: https://docs.rs/hyper/latest/hyper/service/fn.service_fn.html
#[cfg(feature = "derive")]
pub fn define_tool<P, F, Fut>(
    name: impl Into<String>,
    description: impl Into<String>,
    handler: F,
) -> Box<dyn ToolHandler>
where
    P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
    F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
{
    /// Adapter that exposes a typed async function as a [`ToolHandler`].
    struct TypedFnTool<Params, Handler> {
        tool_name: String,
        tool_description: String,
        schema: HashMap<String, serde_json::Value>,
        run: Handler,
        // `fn(Params)` records the parameter type without storing a value, so
        // the adapter's `Send`/`Sync` do not depend on `Params`.
        _params: std::marker::PhantomData<fn(Params)>,
    }

    #[async_trait]
    impl<Params, Handler, HandlerFuture> ToolHandler for TypedFnTool<Params, Handler>
    where
        Params: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
        Handler: Fn(ToolInvocation, Params) -> HandlerFuture + Send + Sync + 'static,
        HandlerFuture: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
    {
        fn tool(&self) -> Tool {
            Tool {
                name: self.tool_name.clone(),
                description: self.tool_description.clone(),
                parameters: self.schema.clone(),
                ..Default::default()
            }
        }

        async fn call(&self, mut invocation: ToolInvocation) -> Result<ToolResult, Error> {
            // Move the raw JSON out rather than cloning it; the handler sees
            // `Value::Null` in its place.
            let raw_arguments =
                std::mem::replace(&mut invocation.arguments, serde_json::Value::Null);
            let typed: Params = serde_json::from_value(raw_arguments)?;
            let pending = (self.run)(invocation, typed);
            pending.await
        }
    }

    let adapter = TypedFnTool {
        tool_name: name.into(),
        tool_description: description.into(),
        schema: tool_parameters(schema_for::<P>()),
        run: handler,
        _params: std::marker::PhantomData,
    };
    Box::new(adapter)
}
335
/// A [`SessionHandler`] that dispatches tool calls to registered
/// [`ToolHandler`] implementations by name.
///
/// When a tool call's `tool_name` matches a registered handler, that handler
/// is invoked directly. If the handler returns an error, the error is
/// converted into a `"failure"` tool result. Everything else is forwarded
/// unchanged to the inner handler: session events, permission requests, user
/// input, elicitation requests, and calls to unrecognized tools.
///
/// # Example
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use github_copilot_sdk::handler::ApproveAllHandler;
/// use github_copilot_sdk::tool::ToolHandlerRouter;
///
/// let router = ToolHandlerRouter::new(
///     vec![/* Box::new(MyTool), ... */],
///     Arc::new(ApproveAllHandler),
/// );
///
/// // Use router.tools() in SessionConfig
/// // Use Arc::new(router) as the session handler
/// ```
pub struct ToolHandlerRouter {
    // Registered handlers, keyed by the `name` of the `Tool` each one reports
    // from `ToolHandler::tool` at construction time.
    handlers: HashMap<String, Box<dyn ToolHandler>>,
    // Fallback handler that receives every event the router does not handle
    // itself (non-tool events and calls to unregistered tool names).
    inner: Arc<dyn SessionHandler>,
}
362
363impl std::fmt::Debug for ToolHandlerRouter {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        let mut tools: Vec<_> = self.handlers.keys().collect();
366        tools.sort();
367        f.debug_struct("ToolHandlerRouter")
368            .field("tool_count", &self.handlers.len())
369            .field("tools", &tools)
370            .finish()
371    }
372}
373
374impl ToolHandlerRouter {
375    /// Create a router from tool handler impls and a fallback handler.
376    ///
377    /// Call [`tools()`](Self::tools) to get the tool definitions for
378    /// [`SessionConfig::tools`](crate::SessionConfig::tools).
379    pub fn new(tools: Vec<Box<dyn ToolHandler>>, inner: Arc<dyn SessionHandler>) -> Self {
380        let mut handlers = HashMap::new();
381        for tool in tools {
382            handlers.insert(tool.tool().name.clone(), tool);
383        }
384        Self { handlers, inner }
385    }
386
387    /// Tool definitions for [`SessionConfig::tools`](crate::SessionConfig::tools).
388    pub fn tools(&self) -> Vec<Tool> {
389        self.handlers.values().map(|h| h.tool()).collect()
390    }
391}
392
393#[async_trait]
394impl SessionHandler for ToolHandlerRouter {
395    async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult {
396        let Some(handler) = self.handlers.get(&invocation.tool_name) else {
397            return self.inner.on_external_tool(invocation).await;
398        };
399        match handler.call(invocation).await {
400            Ok(result) => result,
401            Err(e) => {
402                let msg = e.to_string();
403                ToolResult::Expanded(ToolResultExpanded {
404                    text_result_for_llm: msg.clone(),
405                    result_type: "failure".to_string(),
406                    binary_results_for_llm: None,
407                    session_log: None,
408                    error: Some(msg),
409                    tool_telemetry: None,
410                })
411            }
412        }
413    }
414
415    async fn on_session_event(&self, session_id: SessionId, event: SessionEvent) {
416        self.inner.on_session_event(session_id, event).await
417    }
418
419    async fn on_permission_request(
420        &self,
421        session_id: SessionId,
422        request_id: RequestId,
423        data: PermissionRequestData,
424    ) -> PermissionResult {
425        self.inner
426            .on_permission_request(session_id, request_id, data)
427            .await
428    }
429
430    async fn on_user_input(
431        &self,
432        session_id: SessionId,
433        question: String,
434        choices: Option<Vec<String>>,
435        allow_freeform: Option<bool>,
436    ) -> Option<UserInputResponse> {
437        self.inner
438            .on_user_input(session_id, question, choices, allow_freeform)
439            .await
440    }
441
442    async fn on_elicitation(
443        &self,
444        session_id: SessionId,
445        request_id: RequestId,
446        request: ElicitationRequest,
447    ) -> ElicitationResult {
448        self.inner
449            .on_elicitation(session_id, request_id, request)
450            .await
451    }
452}
453
454#[cfg(test)]
455mod tests {
456    use super::*;
457    use crate::types::{PermissionRequestData, RequestId, SessionId};
458
459    struct EchoTool;
460
461    #[async_trait]
462    impl ToolHandler for EchoTool {
463        fn tool(&self) -> Tool {
464            Tool {
465                name: "echo".to_string(),
466                namespaced_name: None,
467                description: "Echo the input".to_string(),
468                parameters: tool_parameters(serde_json::json!({"type": "object"})),
469                instructions: None,
470                ..Default::default()
471            }
472        }
473
474        async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
475            Ok(ToolResult::Text(inv.arguments.to_string()))
476        }
477    }
478
479    #[test]
480    fn tool_handler_returns_tool_definition() {
481        let tool = EchoTool;
482        let def = tool.tool();
483        assert_eq!(def.name, "echo");
484        assert_eq!(def.description, "Echo the input");
485        assert!(def.parameters.contains_key("type"));
486    }
487
488    #[test]
489    fn try_tool_parameters_rejects_non_object_schema() {
490        let err = try_tool_parameters(serde_json::json!(["not", "an", "object"]))
491            .expect_err("non-object schemas should be rejected");
492
493        assert!(err.is_data());
494    }
495
496    #[test]
497    fn convert_mcp_call_tool_result_collects_text_and_binary_content() {
498        let result = convert_mcp_call_tool_result(&serde_json::json!({
499            "isError": true,
500            "content": [
501                { "type": "text", "text": "hello" },
502                { "type": "image", "data": "aW1n", "mimeType": "image/png" },
503                {
504                    "type": "resource",
505                    "resource": {
506                        "uri": "file:///tmp/data.bin",
507                        "blob": "Ymlu",
508                        "mimeType": "application/octet-stream",
509                        "text": "resource text"
510                    }
511                }
512            ]
513        }))
514        .expect("valid CallToolResult should convert");
515
516        let ToolResult::Expanded(expanded) = result else {
517            panic!("expected expanded tool result");
518        };
519
520        assert_eq!(expanded.text_result_for_llm, "hello\nresource text");
521        assert_eq!(expanded.result_type, "failure");
522        let binary_results = expanded
523            .binary_results_for_llm
524            .expect("binary results should be captured");
525        assert_eq!(binary_results.len(), 2);
526        assert_eq!(binary_results[0].r#type, "image");
527        assert_eq!(binary_results[0].data, "aW1n");
528        assert_eq!(binary_results[0].mime_type, "image/png");
529        assert_eq!(
530            binary_results[1].description.as_deref(),
531            Some("file:///tmp/data.bin")
532        );
533    }
534
535    #[test]
536    fn convert_mcp_call_tool_result_converts_image_content() {
537        let result = convert_mcp_call_tool_result(&serde_json::json!({
538            "content": [
539                { "type": "image", "data": "aW1hZ2U=", "mimeType": "image/jpeg" }
540            ]
541        }))
542        .expect("valid CallToolResult should convert");
543
544        let ToolResult::Expanded(expanded) = result else {
545            panic!("expected expanded tool result");
546        };
547
548        assert_eq!(expanded.text_result_for_llm, "");
549        assert_eq!(expanded.result_type, "success");
550        let binary_results = expanded
551            .binary_results_for_llm
552            .expect("image result should be captured");
553        assert_eq!(binary_results.len(), 1);
554        assert_eq!(binary_results[0].data, "aW1hZ2U=");
555        assert_eq!(binary_results[0].mime_type, "image/jpeg");
556        assert_eq!(binary_results[0].r#type, "image");
557        assert!(binary_results[0].description.is_none());
558    }
559
560    #[test]
561    fn convert_mcp_call_tool_result_converts_resource_blob_content() {
562        let result = convert_mcp_call_tool_result(&serde_json::json!({
563            "content": [
564                {
565                    "type": "resource",
566                    "resource": {
567                        "uri": "file:///tmp/report.pdf",
568                        "blob": "cGRm",
569                        "mimeType": "application/pdf"
570                    }
571                }
572            ]
573        }))
574        .expect("valid CallToolResult should convert");
575
576        let ToolResult::Expanded(expanded) = result else {
577            panic!("expected expanded tool result");
578        };
579
580        let binary_results = expanded
581            .binary_results_for_llm
582            .expect("resource result should be captured");
583        assert_eq!(binary_results.len(), 1);
584        assert_eq!(binary_results[0].data, "cGRm");
585        assert_eq!(binary_results[0].mime_type, "application/pdf");
586        assert_eq!(binary_results[0].r#type, "resource");
587        assert_eq!(
588            binary_results[0].description.as_deref(),
589            Some("file:///tmp/report.pdf")
590        );
591    }
592
593    #[test]
594    fn convert_mcp_call_tool_result_defaults_resource_blob_mime_type() {
595        let result = convert_mcp_call_tool_result(&serde_json::json!({
596            "content": [
597                {
598                    "type": "resource",
599                    "resource": {
600                        "uri": "file:///tmp/data.bin",
601                        "blob": "Ymlu"
602                    }
603                },
604                {
605                    "type": "resource",
606                    "resource": {
607                        "blob": "YmluMg==",
608                        "mimeType": ""
609                    }
610                }
611            ]
612        }))
613        .expect("valid CallToolResult should convert");
614
615        let ToolResult::Expanded(expanded) = result else {
616            panic!("expected expanded tool result");
617        };
618
619        let binary_results = expanded
620            .binary_results_for_llm
621            .expect("resource blobs should be captured");
622        assert_eq!(binary_results.len(), 2);
623        assert_eq!(binary_results[0].mime_type, "application/octet-stream");
624        assert_eq!(binary_results[1].mime_type, "application/octet-stream");
625    }
626
627    #[test]
628    fn convert_mcp_call_tool_result_omits_binary_results_without_binary_content() {
629        let result = convert_mcp_call_tool_result(&serde_json::json!({
630            "content": [
631                { "type": "text", "text": "hello" },
632                {
633                    "type": "resource",
634                    "resource": {
635                        "uri": "file:///tmp/readme.md",
636                        "text": "resource text"
637                    }
638                }
639            ]
640        }))
641        .expect("valid CallToolResult should convert");
642
643        let ToolResult::Expanded(expanded) = result else {
644            panic!("expected expanded tool result");
645        };
646
647        assert_eq!(expanded.text_result_for_llm, "hello\nresource text");
648        assert!(expanded.binary_results_for_llm.is_none());
649    }
650
651    #[tokio::test]
652    async fn tool_handler_call_returns_result() {
653        let tool = EchoTool;
654        let inv = ToolInvocation {
655            session_id: SessionId::from("s1"),
656            tool_call_id: "tc1".to_string(),
657            tool_name: "echo".to_string(),
658            arguments: serde_json::json!({"msg": "hello"}),
659            traceparent: None,
660            tracestate: None,
661        };
662
663        let result = tool.call(inv).await.unwrap();
664        match result {
665            ToolResult::Text(s) => assert!(s.contains("hello")),
666            _ => panic!("expected Text result"),
667        }
668    }
669
670    #[cfg(feature = "derive")]
671    #[tokio::test]
672    async fn define_tool_builds_schema_and_dispatches() {
673        use serde::Deserialize;
674
675        #[derive(Deserialize, schemars::JsonSchema)]
676        struct Params {
677            city: String,
678        }
679
680        let tool = define_tool(
681            "weather",
682            "Get the weather for a city",
683            |_inv, params: Params| async move {
684                Ok(ToolResult::Text(format!("sunny in {}", params.city)))
685            },
686        );
687
688        let def = tool.tool();
689        assert_eq!(def.name, "weather");
690        assert_eq!(def.description, "Get the weather for a city");
691        assert_eq!(def.parameters["type"], "object");
692        assert!(def.parameters["properties"]["city"].is_object());
693
694        let inv = ToolInvocation {
695            session_id: SessionId::from("s1"),
696            tool_call_id: "tc1".to_string(),
697            tool_name: "weather".to_string(),
698            arguments: serde_json::json!({"city": "Seattle"}),
699            traceparent: None,
700            tracestate: None,
701        };
702        match tool.call(inv).await.unwrap() {
703            ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"),
704            _ => panic!("expected Text result"),
705        }
706    }
707
708    #[tokio::test]
709    async fn router_dispatches_to_correct_handler() {
710        struct ToolA;
711        #[async_trait]
712        impl ToolHandler for ToolA {
713            fn tool(&self) -> Tool {
714                Tool {
715                    name: "tool_a".to_string(),
716                    namespaced_name: None,
717                    description: "A".to_string(),
718                    parameters: HashMap::new(),
719                    instructions: None,
720                    ..Default::default()
721                }
722            }
723
724            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
725                Ok(ToolResult::Text("a_result".to_string()))
726            }
727        }
728
729        struct ToolB;
730        #[async_trait]
731        impl ToolHandler for ToolB {
732            fn tool(&self) -> Tool {
733                Tool {
734                    name: "tool_b".to_string(),
735                    namespaced_name: None,
736                    description: "B".to_string(),
737                    parameters: HashMap::new(),
738                    instructions: None,
739                    ..Default::default()
740                }
741            }
742
743            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
744                Ok(ToolResult::Text("b_result".to_string()))
745            }
746        }
747
748        let router = ToolHandlerRouter::new(
749            vec![Box::new(ToolA), Box::new(ToolB)],
750            Arc::new(crate::handler::ApproveAllHandler),
751        );
752
753        let tools = router.tools();
754        assert_eq!(tools.len(), 2);
755
756        let response = router
757            .on_external_tool(ToolInvocation {
758                session_id: SessionId::from("s1"),
759                tool_call_id: "tc1".to_string(),
760                tool_name: "tool_b".to_string(),
761                arguments: serde_json::json!({}),
762                traceparent: None,
763                tracestate: None,
764            })
765            .await;
766        match response {
767            ToolResult::Text(s) => assert_eq!(s, "b_result"),
768            _ => panic!("expected ToolResult::Text"),
769        }
770    }
771
772    #[tokio::test]
773    async fn router_falls_through_for_unknown_tool() {
774        use std::sync::atomic::{AtomicBool, Ordering};
775
776        struct FallbackHandler {
777            called: AtomicBool,
778        }
779        #[async_trait]
780        impl SessionHandler for FallbackHandler {
781            async fn on_external_tool(&self, _inv: ToolInvocation) -> ToolResult {
782                self.called.store(true, Ordering::Relaxed);
783                ToolResult::Text("fallback".to_string())
784            }
785        }
786
787        let fallback = Arc::new(FallbackHandler {
788            called: AtomicBool::new(false),
789        });
790        let router = ToolHandlerRouter::new(vec![], fallback.clone());
791
792        let response = router
793            .on_external_tool(ToolInvocation {
794                session_id: SessionId::from("s1"),
795                tool_call_id: "tc1".to_string(),
796                tool_name: "unknown".to_string(),
797                arguments: serde_json::json!({}),
798                traceparent: None,
799                tracestate: None,
800            })
801            .await;
802        assert!(fallback.called.load(Ordering::Relaxed));
803        match response {
804            ToolResult::Text(s) => assert_eq!(s, "fallback"),
805            _ => panic!("expected fallback result"),
806        }
807    }
808
809    #[tokio::test]
810    async fn router_returns_failure_on_handler_error() {
811        struct FailTool;
812        #[async_trait]
813        impl ToolHandler for FailTool {
814            fn tool(&self) -> Tool {
815                Tool {
816                    name: "bad_tool".to_string(),
817                    namespaced_name: None,
818                    description: "Always fails".to_string(),
819                    parameters: HashMap::new(),
820                    instructions: None,
821                    ..Default::default()
822                }
823            }
824
825            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
826                Err(Error::Rpc {
827                    code: -1,
828                    message: "intentional failure".to_string(),
829                })
830            }
831        }
832
833        let router = ToolHandlerRouter::new(
834            vec![Box::new(FailTool)],
835            Arc::new(crate::handler::ApproveAllHandler),
836        );
837
838        let response = router
839            .on_external_tool(ToolInvocation {
840                session_id: SessionId::from("s1"),
841                tool_call_id: "tc1".to_string(),
842                tool_name: "bad_tool".to_string(),
843                arguments: serde_json::json!({}),
844                traceparent: None,
845                tracestate: None,
846            })
847            .await;
848        match response {
849            ToolResult::Expanded(exp) => {
850                assert_eq!(exp.result_type, "failure");
851                assert!(exp.error.unwrap().contains("intentional failure"));
852            }
853            _ => panic!("expected expanded failure result"),
854        }
855    }
856
857    #[tokio::test]
858    async fn router_forwards_non_tool_events() {
859        struct PermHandler;
860        #[async_trait]
861        impl SessionHandler for PermHandler {
862            async fn on_permission_request(
863                &self,
864                _session_id: SessionId,
865                _request_id: RequestId,
866                _data: PermissionRequestData,
867            ) -> PermissionResult {
868                PermissionResult::Denied
869            }
870        }
871
872        let router = ToolHandlerRouter::new(vec![], Arc::new(PermHandler));
873
874        let response = router
875            .on_permission_request(
876                SessionId::from("s1"),
877                RequestId::new("r1"),
878                PermissionRequestData {
879                    extra: serde_json::json!({}),
880                    ..Default::default()
881                },
882            )
883            .await;
884        assert!(matches!(response, PermissionResult::Denied));
885    }
886
887    #[tokio::test]
888    async fn router_default_on_event_dispatches_via_per_event_methods() {
889        // Regression: callers using the legacy on_event entry point should
890        // still get correct dispatch through the inherited default impl.
891        use crate::handler::{HandlerEvent, HandlerResponse};
892
893        struct OkTool;
894        #[async_trait]
895        impl ToolHandler for OkTool {
896            fn tool(&self) -> Tool {
897                Tool {
898                    name: "ok_tool".to_string(),
899                    namespaced_name: None,
900                    description: "ok".to_string(),
901                    parameters: HashMap::new(),
902                    instructions: None,
903                    ..Default::default()
904                }
905            }
906
907            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
908                Ok(ToolResult::Text("ok".to_string()))
909            }
910        }
911
912        let router = ToolHandlerRouter::new(
913            vec![Box::new(OkTool)],
914            Arc::new(crate::handler::ApproveAllHandler),
915        );
916
917        let response = router
918            .on_event(HandlerEvent::ExternalTool {
919                invocation: ToolInvocation {
920                    session_id: SessionId::from("s1"),
921                    tool_call_id: "tc1".to_string(),
922                    tool_name: "ok_tool".to_string(),
923                    arguments: serde_json::json!({}),
924                    traceparent: None,
925                    tracestate: None,
926                },
927            })
928            .await;
929        match response {
930            HandlerResponse::ToolResult(ToolResult::Text(s)) => assert_eq!(s, "ok"),
931            _ => panic!("expected ToolResult via default on_event"),
932        }
933    }
934
935    // Tests requiring `schemars` (the `derive` feature).
936    #[cfg(feature = "derive")]
937    mod derive_tests {
938        use serde::Deserialize;
939
940        use super::super::*;
941        use crate::SessionId;
942
943        #[derive(Deserialize, schemars::JsonSchema)]
944        struct GetWeatherParams {
945            /// City name to get weather for.
946            city: String,
947            /// Temperature unit (celsius or fahrenheit).
948            unit: Option<String>,
949        }
950
951        #[test]
952        fn schema_for_generates_clean_schema() {
953            let schema = schema_for::<GetWeatherParams>();
954            assert_eq!(schema["type"], "object");
955            assert!(schema["properties"]["city"].is_object());
956            assert!(schema["properties"]["unit"].is_object());
957            // city is required (non-Option), unit is not
958            let required = schema["required"].as_array().unwrap();
959            assert!(required.contains(&serde_json::json!("city")));
960            assert!(!required.contains(&serde_json::json!("unit")));
961            // Root-level metadata stripped
962            assert!(schema.get("$schema").is_none());
963            assert!(schema.get("title").is_none());
964        }
965
966        struct GetWeatherTool;
967
968        #[async_trait]
969        impl ToolHandler for GetWeatherTool {
970            fn tool(&self) -> Tool {
971                Tool {
972                    name: "get_weather".to_string(),
973                    namespaced_name: None,
974                    description: "Get weather for a city".to_string(),
975                    parameters: tool_parameters(schema_for::<GetWeatherParams>()),
976                    instructions: None,
977                    ..Default::default()
978                }
979            }
980
981            async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
982                let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
983                Ok(ToolResult::Text(format!(
984                    "{} {}",
985                    params.city,
986                    params.unit.unwrap_or_default()
987                )))
988            }
989        }
990
991        #[test]
992        fn tool_handler_with_schema_for() {
993            let tool = GetWeatherTool;
994            let def = tool.tool();
995            assert_eq!(def.name, "get_weather");
996            let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters");
997            assert_eq!(schema["type"], "object");
998            assert!(schema["properties"]["city"].is_object());
999        }
1000
1001        #[tokio::test]
1002        async fn tool_handler_deserializes_typed_params() {
1003            let tool = GetWeatherTool;
1004            let inv = ToolInvocation {
1005                session_id: SessionId::from("s1"),
1006                tool_call_id: "tc1".to_string(),
1007                tool_name: "get_weather".to_string(),
1008                arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}),
1009                traceparent: None,
1010                tracestate: None,
1011            };
1012
1013            let result = tool.call(inv).await.unwrap();
1014            match result {
1015                ToolResult::Text(s) => assert_eq!(s, "Seattle celsius"),
1016                _ => panic!("expected Text result"),
1017            }
1018        }
1019
1020        #[tokio::test]
1021        async fn tool_handler_returns_error_on_bad_params() {
1022            let tool = GetWeatherTool;
1023            let inv = ToolInvocation {
1024                session_id: SessionId::from("s1"),
1025                tool_call_id: "tc1".to_string(),
1026                tool_name: "get_weather".to_string(),
1027                arguments: serde_json::json!({"wrong_field": 42}),
1028                traceparent: None,
1029                tracestate: None,
1030            };
1031
1032            let err = tool.call(inv).await.unwrap_err();
1033            assert!(matches!(err, Error::Json(_)));
1034        }
1035
1036        #[tokio::test]
1037        async fn router_with_schema_for_tools() {
1038            let router = ToolHandlerRouter::new(
1039                vec![Box::new(GetWeatherTool)],
1040                Arc::new(crate::handler::ApproveAllHandler),
1041            );
1042
1043            let tools = router.tools();
1044            assert_eq!(tools.len(), 1);
1045            assert_eq!(tools[0].name, "get_weather");
1046
1047            let response = router
1048                .on_external_tool(ToolInvocation {
1049                    session_id: SessionId::from("s1"),
1050                    tool_call_id: "tc1".to_string(),
1051                    tool_name: "get_weather".to_string(),
1052                    arguments: serde_json::json!({"city": "Portland"}),
1053                    traceparent: None,
1054                    tracestate: None,
1055                })
1056                .await;
1057            match response {
1058                ToolResult::Text(s) => assert!(s.contains("Portland")),
1059                _ => panic!("expected ToolResult::Text"),
1060            }
1061        }
1062    }
1063}