1use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14#[cfg(feature = "derive")]
16pub use schemars::JsonSchema;
17
18use crate::Error;
19use crate::handler::{ExitPlanModeResult, PermissionResult, SessionHandler, UserInputResponse};
20use crate::types::{
21 ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId,
22 SessionEvent, SessionId, Tool, ToolInvocation, ToolResult, ToolResultExpanded,
23};
24
/// Generates the JSON Schema for `T` as a [`serde_json::Value`], ready to be used
/// as tool parameters.
///
/// The top-level `$schema` and `title` keys emitted by `schemars` are stripped
/// when the schema is a JSON object; all other keys are kept as generated.
///
/// # Panics
///
/// Panics only if the generated schema cannot be serialized to JSON, which is
/// treated as an invariant violation.
#[cfg(feature = "derive")]
pub fn schema_for<T: schemars::JsonSchema>() -> serde_json::Value {
    let mut generated = serde_json::to_value(schemars::schema_for!(T))
        .expect("JSON Schema serialization cannot fail");
    if let serde_json::Value::Object(fields) = &mut generated {
        // These keys are metadata about the schema document, not parameter constraints.
        for noise in ["$schema", "title"] {
            fields.remove(noise);
        }
    }
    generated
}
55
56pub fn tool_parameters(schema: serde_json::Value) -> HashMap<String, serde_json::Value> {
79 try_tool_parameters(schema).expect("tool parameter schema must be a JSON object")
80}
81
82pub fn try_tool_parameters(
84 schema: serde_json::Value,
85) -> Result<HashMap<String, serde_json::Value>, serde_json::Error> {
86 serde_json::from_value(schema)
87}
88
/// A tool implementation that can be registered with a [`ToolHandlerRouter`].
///
/// Implementors describe themselves through [`ToolHandler::tool`] and execute
/// invocations through [`ToolHandler::call`]. Implementors must be `Send + Sync`.
#[async_trait]
pub trait ToolHandler: Send + Sync {
    /// Returns this tool's definition (name, description, parameter schema, ...).
    ///
    /// [`ToolHandlerRouter::new`] calls this once per tool and uses the returned
    /// `name` as the routing key.
    fn tool(&self) -> Tool;

    /// Executes a single invocation of this tool.
    ///
    /// When called through [`ToolHandlerRouter`], an `Err` is converted into a
    /// `ToolResult::Expanded` whose `result_type` is `"failure"`.
    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error>;
}
140
/// Builds a boxed [`ToolHandler`] from a name, a description and an async
/// closure that receives typed parameters.
///
/// The JSON Schema for `P` is generated once, via [`schema_for`] and
/// [`tool_parameters`], and reported as the tool's parameters. On each call the
/// invocation's `arguments` are moved out, leaving `serde_json::Value::Null` in
/// the invocation passed to `handler`, and deserialized into `P`. If that
/// deserialization fails, the error is returned and `handler` is not invoked.
///
/// # Panics
///
/// Panics if the schema generated for `P` is not a JSON object (see
/// [`tool_parameters`]).
#[cfg(feature = "derive")]
pub fn define_tool<P, F, Fut>(
    name: impl Into<String>,
    description: impl Into<String>,
    handler: F,
) -> Box<dyn ToolHandler>
where
    P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
    F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
{
    // Closure-backed tool. `PhantomData<fn(P)>` ties the parameter type to the
    // struct without storing a `P`; a fn-pointer phantom is always Send + Sync.
    struct ClosureTool<P, F> {
        tool_name: String,
        tool_description: String,
        schema_params: HashMap<String, serde_json::Value>,
        callback: F,
        _params: std::marker::PhantomData<fn(P)>,
    }

    #[async_trait]
    impl<P, F, Fut> ToolHandler for ClosureTool<P, F>
    where
        F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
        P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
    {
        fn tool(&self) -> Tool {
            let name = self.tool_name.clone();
            let description = self.tool_description.clone();
            let parameters = self.schema_params.clone();
            Tool {
                name,
                description,
                parameters,
                ..Default::default()
            }
        }

        async fn call(&self, mut invocation: ToolInvocation) -> Result<ToolResult, Error> {
            // Move the raw arguments out instead of cloning them.
            let raw = std::mem::replace(&mut invocation.arguments, serde_json::Value::Null);
            let typed = serde_json::from_value::<P>(raw)?;
            (self.callback)(invocation, typed).await
        }
    }

    Box::new(ClosureTool {
        tool_name: name.into(),
        tool_description: description.into(),
        schema_params: tool_parameters(schema_for::<P>()),
        callback: handler,
        _params: std::marker::PhantomData,
    })
}
249
/// A [`SessionHandler`] that routes external tool invocations to registered
/// [`ToolHandler`]s by tool name and delegates everything else to an inner
/// handler.
pub struct ToolHandlerRouter {
    // Registered tools, keyed by the `name` each tool reported from
    // `ToolHandler::tool` when the router was constructed.
    handlers: HashMap<String, Box<dyn ToolHandler>>,
    // Fallback used for unregistered tool names and for every non-tool callback.
    inner: Arc<dyn SessionHandler>,
}
276
277impl std::fmt::Debug for ToolHandlerRouter {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 let mut tools: Vec<_> = self.handlers.keys().collect();
280 tools.sort();
281 f.debug_struct("ToolHandlerRouter")
282 .field("tool_count", &self.handlers.len())
283 .field("tools", &tools)
284 .finish()
285 }
286}
287
288impl ToolHandlerRouter {
289 pub fn new(tools: Vec<Box<dyn ToolHandler>>, inner: Arc<dyn SessionHandler>) -> Self {
294 let mut handlers = HashMap::new();
295 for tool in tools {
296 handlers.insert(tool.tool().name.clone(), tool);
297 }
298 Self { handlers, inner }
299 }
300
301 pub fn tools(&self) -> Vec<Tool> {
303 self.handlers.values().map(|h| h.tool()).collect()
304 }
305}
306
307#[async_trait]
308impl SessionHandler for ToolHandlerRouter {
309 async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult {
310 let Some(handler) = self.handlers.get(&invocation.tool_name) else {
311 return self.inner.on_external_tool(invocation).await;
312 };
313 match handler.call(invocation).await {
314 Ok(result) => result,
315 Err(e) => {
316 let msg = e.to_string();
317 ToolResult::Expanded(ToolResultExpanded {
318 text_result_for_llm: msg.clone(),
319 result_type: "failure".to_string(),
320 session_log: None,
321 error: Some(msg),
322 })
323 }
324 }
325 }
326
327 async fn on_session_event(&self, session_id: SessionId, event: SessionEvent) {
328 self.inner.on_session_event(session_id, event).await
329 }
330
331 async fn on_permission_request(
332 &self,
333 session_id: SessionId,
334 request_id: RequestId,
335 data: PermissionRequestData,
336 ) -> PermissionResult {
337 self.inner
338 .on_permission_request(session_id, request_id, data)
339 .await
340 }
341
342 async fn on_user_input(
343 &self,
344 session_id: SessionId,
345 question: String,
346 choices: Option<Vec<String>>,
347 allow_freeform: Option<bool>,
348 ) -> Option<UserInputResponse> {
349 self.inner
350 .on_user_input(session_id, question, choices, allow_freeform)
351 .await
352 }
353
354 async fn on_elicitation(
355 &self,
356 session_id: SessionId,
357 request_id: RequestId,
358 request: ElicitationRequest,
359 ) -> ElicitationResult {
360 self.inner
361 .on_elicitation(session_id, request_id, request)
362 .await
363 }
364
365 async fn on_exit_plan_mode(
366 &self,
367 session_id: SessionId,
368 data: ExitPlanModeData,
369 ) -> ExitPlanModeResult {
370 self.inner.on_exit_plan_mode(session_id, data).await
371 }
372}
373
// Unit tests for `ToolHandler` implementations, `try_tool_parameters`,
// `define_tool` (feature "derive"), and `ToolHandlerRouter` dispatch and forwarding.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{PermissionRequestData, RequestId, SessionId};

    // Minimal handler that returns its raw JSON arguments as text.
    struct EchoTool;

    #[async_trait]
    impl ToolHandler for EchoTool {
        fn tool(&self) -> Tool {
            Tool {
                name: "echo".to_string(),
                namespaced_name: None,
                description: "Echo the input".to_string(),
                parameters: tool_parameters(serde_json::json!({"type": "object"})),
                instructions: None,
                ..Default::default()
            }
        }

        async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
            Ok(ToolResult::Text(inv.arguments.to_string()))
        }
    }

    #[test]
    fn tool_handler_returns_tool_definition() {
        let tool = EchoTool;
        let def = tool.tool();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo the input");
        assert!(def.parameters.contains_key("type"));
    }

    // A JSON array cannot become a parameter map; serde_json reports this as a
    // data (type mismatch) error.
    #[test]
    fn try_tool_parameters_rejects_non_object_schema() {
        let err = try_tool_parameters(serde_json::json!(["not", "an", "object"]))
            .expect_err("non-object schemas should be rejected");

        assert!(err.is_data());
    }

    #[tokio::test]
    async fn tool_handler_call_returns_result() {
        let tool = EchoTool;
        let inv = ToolInvocation {
            session_id: SessionId::from("s1"),
            tool_call_id: "tc1".to_string(),
            tool_name: "echo".to_string(),
            arguments: serde_json::json!({"msg": "hello"}),
            traceparent: None,
            tracestate: None,
        };

        let result = tool.call(inv).await.unwrap();
        match result {
            ToolResult::Text(s) => assert!(s.contains("hello")),
            _ => panic!("expected Text result"),
        }
    }

    // Covers both halves of `define_tool`: the generated schema and the typed
    // argument dispatch to the closure.
    #[cfg(feature = "derive")]
    #[tokio::test]
    async fn define_tool_builds_schema_and_dispatches() {
        use serde::Deserialize;

        #[derive(Deserialize, schemars::JsonSchema)]
        struct Params {
            city: String,
        }

        let tool = define_tool(
            "weather",
            "Get the weather for a city",
            |_inv, params: Params| async move {
                Ok(ToolResult::Text(format!("sunny in {}", params.city)))
            },
        );

        let def = tool.tool();
        assert_eq!(def.name, "weather");
        assert_eq!(def.description, "Get the weather for a city");
        assert_eq!(def.parameters["type"], "object");
        assert!(def.parameters["properties"]["city"].is_object());

        let inv = ToolInvocation {
            session_id: SessionId::from("s1"),
            tool_call_id: "tc1".to_string(),
            tool_name: "weather".to_string(),
            arguments: serde_json::json!({"city": "Seattle"}),
            traceparent: None,
            tracestate: None,
        };
        match tool.call(inv).await.unwrap() {
            ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"),
            _ => panic!("expected Text result"),
        }
    }

    // Two registered tools: the router must pick the one matching `tool_name`.
    #[tokio::test]
    async fn router_dispatches_to_correct_handler() {
        struct ToolA;
        #[async_trait]
        impl ToolHandler for ToolA {
            fn tool(&self) -> Tool {
                Tool {
                    name: "tool_a".to_string(),
                    namespaced_name: None,
                    description: "A".to_string(),
                    parameters: HashMap::new(),
                    instructions: None,
                    ..Default::default()
                }
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Ok(ToolResult::Text("a_result".to_string()))
            }
        }

        struct ToolB;
        #[async_trait]
        impl ToolHandler for ToolB {
            fn tool(&self) -> Tool {
                Tool {
                    name: "tool_b".to_string(),
                    namespaced_name: None,
                    description: "B".to_string(),
                    parameters: HashMap::new(),
                    instructions: None,
                    ..Default::default()
                }
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Ok(ToolResult::Text("b_result".to_string()))
            }
        }

        let router = ToolHandlerRouter::new(
            vec![Box::new(ToolA), Box::new(ToolB)],
            Arc::new(crate::handler::ApproveAllHandler),
        );

        let tools = router.tools();
        assert_eq!(tools.len(), 2);

        let response = router
            .on_external_tool(ToolInvocation {
                session_id: SessionId::from("s1"),
                tool_call_id: "tc1".to_string(),
                tool_name: "tool_b".to_string(),
                arguments: serde_json::json!({}),
                traceparent: None,
                tracestate: None,
            })
            .await;
        match response {
            ToolResult::Text(s) => assert_eq!(s, "b_result"),
            _ => panic!("expected ToolResult::Text"),
        }
    }

    // No tools registered: an unknown tool name must reach the inner handler.
    #[tokio::test]
    async fn router_falls_through_for_unknown_tool() {
        use std::sync::atomic::{AtomicBool, Ordering};

        struct FallbackHandler {
            called: AtomicBool,
        }
        #[async_trait]
        impl SessionHandler for FallbackHandler {
            async fn on_external_tool(&self, _inv: ToolInvocation) -> ToolResult {
                self.called.store(true, Ordering::Relaxed);
                ToolResult::Text("fallback".to_string())
            }
        }

        let fallback = Arc::new(FallbackHandler {
            called: AtomicBool::new(false),
        });
        let router = ToolHandlerRouter::new(vec![], fallback.clone());

        let response = router
            .on_external_tool(ToolInvocation {
                session_id: SessionId::from("s1"),
                tool_call_id: "tc1".to_string(),
                tool_name: "unknown".to_string(),
                arguments: serde_json::json!({}),
                traceparent: None,
                tracestate: None,
            })
            .await;
        assert!(fallback.called.load(Ordering::Relaxed));
        match response {
            ToolResult::Text(s) => assert_eq!(s, "fallback"),
            _ => panic!("expected fallback result"),
        }
    }

    // A handler `Err` must be surfaced as an expanded "failure" result, not propagated.
    #[tokio::test]
    async fn router_returns_failure_on_handler_error() {
        struct FailTool;
        #[async_trait]
        impl ToolHandler for FailTool {
            fn tool(&self) -> Tool {
                Tool {
                    name: "bad_tool".to_string(),
                    namespaced_name: None,
                    description: "Always fails".to_string(),
                    parameters: HashMap::new(),
                    instructions: None,
                    ..Default::default()
                }
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Err(Error::Rpc {
                    code: -1,
                    message: "intentional failure".to_string(),
                })
            }
        }

        let router = ToolHandlerRouter::new(
            vec![Box::new(FailTool)],
            Arc::new(crate::handler::ApproveAllHandler),
        );

        let response = router
            .on_external_tool(ToolInvocation {
                session_id: SessionId::from("s1"),
                tool_call_id: "tc1".to_string(),
                tool_name: "bad_tool".to_string(),
                arguments: serde_json::json!({}),
                traceparent: None,
                tracestate: None,
            })
            .await;
        match response {
            ToolResult::Expanded(exp) => {
                assert_eq!(exp.result_type, "failure");
                assert!(exp.error.unwrap().contains("intentional failure"));
            }
            _ => panic!("expected expanded failure result"),
        }
    }

    // Non-tool callbacks (here: permission requests) must reach the inner handler.
    #[tokio::test]
    async fn router_forwards_non_tool_events() {
        struct PermHandler;
        #[async_trait]
        impl SessionHandler for PermHandler {
            async fn on_permission_request(
                &self,
                _session_id: SessionId,
                _request_id: RequestId,
                _data: PermissionRequestData,
            ) -> PermissionResult {
                PermissionResult::Denied
            }
        }

        let router = ToolHandlerRouter::new(vec![], Arc::new(PermHandler));

        let response = router
            .on_permission_request(
                SessionId::from("s1"),
                RequestId::new("r1"),
                PermissionRequestData {
                    extra: serde_json::json!({}),
                    ..Default::default()
                },
            )
            .await;
        assert!(matches!(response, PermissionResult::Denied));
    }

    // The router does not override `on_event` here; this checks that calling it
    // with an `ExternalTool` event still reaches the router's `on_external_tool`.
    #[tokio::test]
    async fn router_default_on_event_dispatches_via_per_event_methods() {
        use crate::handler::{HandlerEvent, HandlerResponse};

        struct OkTool;
        #[async_trait]
        impl ToolHandler for OkTool {
            fn tool(&self) -> Tool {
                Tool {
                    name: "ok_tool".to_string(),
                    namespaced_name: None,
                    description: "ok".to_string(),
                    parameters: HashMap::new(),
                    instructions: None,
                    ..Default::default()
                }
            }

            async fn call(&self, _inv: ToolInvocation) -> Result<ToolResult, Error> {
                Ok(ToolResult::Text("ok".to_string()))
            }
        }

        let router = ToolHandlerRouter::new(
            vec![Box::new(OkTool)],
            Arc::new(crate::handler::ApproveAllHandler),
        );

        let response = router
            .on_event(HandlerEvent::ExternalTool {
                invocation: ToolInvocation {
                    session_id: SessionId::from("s1"),
                    tool_call_id: "tc1".to_string(),
                    tool_name: "ok_tool".to_string(),
                    arguments: serde_json::json!({}),
                    traceparent: None,
                    tracestate: None,
                },
            })
            .await;
        match response {
            HandlerResponse::ToolResult(ToolResult::Text(s)) => assert_eq!(s, "ok"),
            _ => panic!("expected ToolResult via default on_event"),
        }
    }

    // Tests for `schema_for` together with a hand-written `ToolHandler` that
    // deserializes typed parameters itself.
    #[cfg(feature = "derive")]
    mod derive_tests {
        use serde::Deserialize;

        use super::super::*;
        use crate::SessionId;

        // `city` is required; `unit` is optional, so it must not appear in `required`.
        #[derive(Deserialize, schemars::JsonSchema)]
        struct GetWeatherParams {
            city: String,
            unit: Option<String>,
        }

        #[test]
        fn schema_for_generates_clean_schema() {
            let schema = schema_for::<GetWeatherParams>();
            assert_eq!(schema["type"], "object");
            assert!(schema["properties"]["city"].is_object());
            assert!(schema["properties"]["unit"].is_object());
            let required = schema["required"].as_array().unwrap();
            assert!(required.contains(&serde_json::json!("city")));
            assert!(!required.contains(&serde_json::json!("unit")));
            // `schema_for` strips these top-level metadata keys.
            assert!(schema.get("$schema").is_none());
            assert!(schema.get("title").is_none());
        }

        struct GetWeatherTool;

        #[async_trait]
        impl ToolHandler for GetWeatherTool {
            fn tool(&self) -> Tool {
                Tool {
                    name: "get_weather".to_string(),
                    namespaced_name: None,
                    description: "Get weather for a city".to_string(),
                    parameters: tool_parameters(schema_for::<GetWeatherParams>()),
                    instructions: None,
                    ..Default::default()
                }
            }

            async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
                let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
                Ok(ToolResult::Text(format!(
                    "{} {}",
                    params.city,
                    params.unit.unwrap_or_default()
                )))
            }
        }

        #[test]
        fn tool_handler_with_schema_for() {
            let tool = GetWeatherTool;
            let def = tool.tool();
            assert_eq!(def.name, "get_weather");
            let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters");
            assert_eq!(schema["type"], "object");
            assert!(schema["properties"]["city"].is_object());
        }

        #[tokio::test]
        async fn tool_handler_deserializes_typed_params() {
            let tool = GetWeatherTool;
            let inv = ToolInvocation {
                session_id: SessionId::from("s1"),
                tool_call_id: "tc1".to_string(),
                tool_name: "get_weather".to_string(),
                arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}),
                traceparent: None,
                tracestate: None,
            };

            let result = tool.call(inv).await.unwrap();
            match result {
                ToolResult::Text(s) => assert_eq!(s, "Seattle celsius"),
                _ => panic!("expected Text result"),
            }
        }

        // Missing required `city` -> deserialization fails -> `Error::Json` via `?`.
        #[tokio::test]
        async fn tool_handler_returns_error_on_bad_params() {
            let tool = GetWeatherTool;
            let inv = ToolInvocation {
                session_id: SessionId::from("s1"),
                tool_call_id: "tc1".to_string(),
                tool_name: "get_weather".to_string(),
                arguments: serde_json::json!({"wrong_field": 42}),
                traceparent: None,
                tracestate: None,
            };

            let err = tool.call(inv).await.unwrap_err();
            assert!(matches!(err, Error::Json(_)));
        }

        #[tokio::test]
        async fn router_with_schema_for_tools() {
            let router = ToolHandlerRouter::new(
                vec![Box::new(GetWeatherTool)],
                Arc::new(crate::handler::ApproveAllHandler),
            );

            let tools = router.tools();
            assert_eq!(tools.len(), 1);
            assert_eq!(tools[0].name, "get_weather");

            let response = router
                .on_external_tool(ToolInvocation {
                    session_id: SessionId::from("s1"),
                    tool_call_id: "tc1".to_string(),
                    tool_name: "get_weather".to_string(),
                    arguments: serde_json::json!({"city": "Portland"}),
                    traceparent: None,
                    tracestate: None,
                })
                .await;
            match response {
                ToolResult::Text(s) => assert!(s.contains("Portland")),
                _ => panic!("expected ToolResult::Text"),
            }
        }
    }
}