1use std::collections::HashMap;
17
18use async_trait::async_trait;
19#[cfg(feature = "derive")]
21pub use schemars::JsonSchema;
22
23use crate::Error;
24#[cfg(any(feature = "derive", test))]
25use crate::types::Tool;
26use crate::types::{ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded};
27
#[cfg(feature = "derive")]
/// Generates the JSON Schema for `T` as a [`serde_json::Value`].
///
/// The root-level `$schema` and `title` keys emitted by `schemars` are
/// stripped, leaving a schema suitable for passing to `tool_parameters`.
/// Only the root object is edited; nested schemas are kept as generated.
///
/// # Panics
///
/// Panics if serializing the generated schema to JSON fails, which is
/// treated as impossible.
pub fn schema_for<T: schemars::JsonSchema>() -> serde_json::Value {
    let root = schemars::schema_for!(T);
    let mut json = serde_json::to_value(root).expect("JSON Schema serialization cannot fail");

    if let Some(root_object) = json.as_object_mut() {
        // Metadata keys that are not part of the parameter description itself.
        for key in ["$schema", "title"] {
            root_object.remove(key);
        }
    }

    json
}
58
59pub fn tool_parameters(schema: serde_json::Value) -> HashMap<String, serde_json::Value> {
84 try_tool_parameters(schema).expect("tool parameter schema must be a JSON object")
85}
86
87pub fn try_tool_parameters(
89 schema: serde_json::Value,
90) -> Result<HashMap<String, serde_json::Value>, serde_json::Error> {
91 serde_json::from_value(schema)
92}
93
94pub fn convert_mcp_call_tool_result(value: &serde_json::Value) -> Option<ToolResult> {
98 let content = value.get("content")?.as_array()?;
99 let mut text_parts = Vec::new();
100 let mut binary_results = Vec::new();
101
102 for block in content {
103 match block.get("type").and_then(serde_json::Value::as_str) {
104 Some("text") => {
105 if let Some(text) = block.get("text").and_then(serde_json::Value::as_str) {
106 text_parts.push(text.to_string());
107 }
108 }
109 Some("image") => {
110 let data = block
111 .get("data")
112 .and_then(serde_json::Value::as_str)
113 .filter(|s| !s.is_empty());
114 let mime_type = block
115 .get("mimeType")
116 .and_then(serde_json::Value::as_str)
117 .filter(|s| !s.is_empty());
118 if let (Some(data), Some(mime_type)) = (data, mime_type) {
119 binary_results.push(ToolBinaryResult {
120 data: data.to_string(),
121 mime_type: mime_type.to_string(),
122 r#type: "image".to_string(),
123 description: None,
124 });
125 }
126 }
127 Some("resource") => {
128 let Some(resource) = block.get("resource").and_then(serde_json::Value::as_object)
129 else {
130 continue;
131 };
132 if let Some(text) = resource
133 .get("text")
134 .and_then(serde_json::Value::as_str)
135 .filter(|s| !s.is_empty())
136 {
137 text_parts.push(text.to_string());
138 }
139 if let Some(blob) = resource
140 .get("blob")
141 .and_then(serde_json::Value::as_str)
142 .filter(|s| !s.is_empty())
143 {
144 let mime_type = resource
145 .get("mimeType")
146 .and_then(serde_json::Value::as_str)
147 .filter(|s| !s.is_empty())
148 .unwrap_or("application/octet-stream");
149 let description = resource
150 .get("uri")
151 .and_then(serde_json::Value::as_str)
152 .filter(|s| !s.is_empty())
153 .map(ToString::to_string);
154 binary_results.push(ToolBinaryResult {
155 data: blob.to_string(),
156 mime_type: mime_type.to_string(),
157 r#type: "resource".to_string(),
158 description,
159 });
160 }
161 }
162 _ => {}
163 }
164 }
165
166 Some(ToolResult::Expanded(ToolResultExpanded {
167 text_result_for_llm: text_parts.join("\n"),
168 result_type: if value.get("isError").and_then(serde_json::Value::as_bool) == Some(true) {
169 "failure".to_string()
170 } else {
171 "success".to_string()
172 },
173 binary_results_for_llm: (!binary_results.is_empty()).then_some(binary_results),
174 session_log: None,
175 error: None,
176 tool_telemetry: None,
177 }))
178}
179
#[async_trait]
/// Asynchronous callback that executes a tool when it is invoked.
///
/// Implementations are written with `#[async_trait]` (as the trait itself is
/// declared) and attached to a tool definition through `with_handler`, which
/// in this module is given the handler wrapped in `std::sync::Arc`.
///
/// The `Send + Sync + 'static` supertraits allow a handler to be shared
/// across threads, e.g. behind an `Arc`.
pub trait ToolHandler: Send + Sync + 'static {
    /// Executes the tool for a single `invocation`.
    ///
    /// `invocation.arguments` carries the call's JSON arguments. Returning
    /// `Err` signals failure instead of producing a `ToolResult`.
    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error>;
}
231
#[cfg(feature = "derive")]
/// Builds a `Tool` whose parameter schema is derived from `P` and whose
/// handler deserializes the invocation arguments into `P` before calling
/// `handler`.
///
/// The declaration (name, description, parameters) is identical to what
/// [`define_tool_declaration`] produces for `P`.
///
/// On each call, `invocation.arguments` is moved out of the invocation with
/// `std::mem::take`, so `handler` sees the default value of that field.
/// The arguments are then deserialized into `P`. If deserialization fails,
/// the error is returned (converted via `?`) and `handler` is not invoked.
///
/// # Panics
///
/// Panics (through `tool_parameters`) if the schema generated for `P` is not
/// a JSON object.
pub fn define_tool<P, F, Fut>(
    name: impl Into<String>,
    description: impl Into<String>,
    handler: F,
) -> Tool
where
    P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
    F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
{
    /// Adapter implementing [`ToolHandler`] for a typed async closure.
    struct TypedFnHandler<P, F> {
        callback: F,
        // `fn(P)` keeps the adapter `Send + Sync` regardless of `P`.
        _params: std::marker::PhantomData<fn(P)>,
    }

    #[async_trait]
    impl<P, F, Fut> ToolHandler for TypedFnHandler<P, F>
    where
        P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
        F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
    {
        async fn call(&self, mut invocation: ToolInvocation) -> Result<ToolResult, Error> {
            let raw_arguments = std::mem::take(&mut invocation.arguments);
            let typed = serde_json::from_value::<P>(raw_arguments)?;
            (self.callback)(invocation, typed).await
        }
    }

    define_tool_declaration::<P>(name, description).with_handler(std::sync::Arc::new(
        TypedFnHandler {
            callback: handler,
            _params: std::marker::PhantomData,
        },
    ))
}
330
#[cfg(feature = "derive")]
/// Builds a handler-less `Tool` declaration whose parameters are the JSON
/// Schema generated for `P` (via [`schema_for`] and `tool_parameters`).
///
/// All other `Tool` fields take their `Default` values.
///
/// # Panics
///
/// Panics (through `tool_parameters`) if the schema generated for `P` is not
/// a JSON object.
pub fn define_tool_declaration<P>(name: impl Into<String>, description: impl Into<String>) -> Tool
where
    P: schemars::JsonSchema,
{
    let parameters = tool_parameters(schema_for::<P>());
    Tool {
        name: name.into(),
        description: description.into(),
        parameters,
        ..Tool::default()
    }
}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SessionId;

    /// Minimal handler that echoes the raw JSON arguments back as text.
    struct EchoTool;

    fn echo_tool() -> Tool {
        Tool {
            name: "echo".to_string(),
            description: "Echo the input".to_string(),
            parameters: tool_parameters(serde_json::json!({"type": "object"})),
            ..Default::default()
        }
        .with_handler(std::sync::Arc::new(EchoTool))
    }

    #[async_trait]
    impl ToolHandler for EchoTool {
        async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
            Ok(ToolResult::Text(inv.arguments.to_string()))
        }
    }

    #[test]
    fn tool_handler_returns_tool_definition() {
        let def = echo_tool();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echo the input");
        assert!(def.parameters.contains_key("type"));
        assert!(def.handler.is_some());
    }

    #[test]
    fn try_tool_parameters_rejects_non_object_schema() {
        let err = try_tool_parameters(serde_json::json!(["not", "an", "object"]))
            .expect_err("non-object schemas should be rejected");

        assert!(err.is_data());
    }

    // The panicking wrapper must surface the same contract as `try_tool_parameters`.
    #[test]
    #[should_panic(expected = "tool parameter schema must be a JSON object")]
    fn tool_parameters_panics_on_non_object_schema() {
        let _ = tool_parameters(serde_json::json!("not an object"));
    }

    #[test]
    fn convert_mcp_call_tool_result_collects_text_and_binary_content() {
        let result = convert_mcp_call_tool_result(&serde_json::json!({
            "isError": true,
            "content": [
                { "type": "text", "text": "hello" },
                { "type": "image", "data": "aW1n", "mimeType": "image/png" },
                {
                    "type": "resource",
                    "resource": {
                        "uri": "file:///tmp/data.bin",
                        "blob": "Ymlu",
                        "mimeType": "application/octet-stream",
                        "text": "resource text"
                    }
                }
            ]
        }))
        .expect("valid CallToolResult should convert");

        let ToolResult::Expanded(expanded) = result else {
            panic!("expected expanded tool result");
        };

        assert_eq!(expanded.text_result_for_llm, "hello\nresource text");
        assert_eq!(expanded.result_type, "failure");
        let binary_results = expanded
            .binary_results_for_llm
            .expect("binary results should be captured");
        assert_eq!(binary_results.len(), 2);
        assert_eq!(binary_results[0].r#type, "image");
        assert_eq!(binary_results[0].data, "aW1n");
        assert_eq!(binary_results[0].mime_type, "image/png");
        assert_eq!(
            binary_results[1].description.as_deref(),
            Some("file:///tmp/data.bin")
        );
    }

    #[test]
    fn convert_mcp_call_tool_result_converts_image_content() {
        let result = convert_mcp_call_tool_result(&serde_json::json!({
            "content": [
                { "type": "image", "data": "aW1hZ2U=", "mimeType": "image/jpeg" }
            ]
        }))
        .expect("valid CallToolResult should convert");

        let ToolResult::Expanded(expanded) = result else {
            panic!("expected expanded tool result");
        };

        assert_eq!(expanded.text_result_for_llm, "");
        assert_eq!(expanded.result_type, "success");
        let binary_results = expanded
            .binary_results_for_llm
            .expect("image result should be captured");
        assert_eq!(binary_results.len(), 1);
        assert_eq!(binary_results[0].data, "aW1hZ2U=");
        assert_eq!(binary_results[0].mime_type, "image/jpeg");
        assert_eq!(binary_results[0].r#type, "image");
        assert!(binary_results[0].description.is_none());
    }

    #[test]
    fn convert_mcp_call_tool_result_converts_resource_blob_content() {
        let result = convert_mcp_call_tool_result(&serde_json::json!({
            "content": [
                {
                    "type": "resource",
                    "resource": {
                        "uri": "file:///tmp/report.pdf",
                        "blob": "cGRm",
                        "mimeType": "application/pdf"
                    }
                }
            ]
        }))
        .expect("valid CallToolResult should convert");

        let ToolResult::Expanded(expanded) = result else {
            panic!("expected expanded tool result");
        };

        let binary_results = expanded
            .binary_results_for_llm
            .expect("resource result should be captured");
        assert_eq!(binary_results.len(), 1);
        assert_eq!(binary_results[0].data, "cGRm");
        assert_eq!(binary_results[0].mime_type, "application/pdf");
        assert_eq!(binary_results[0].r#type, "resource");
        assert_eq!(
            binary_results[0].description.as_deref(),
            Some("file:///tmp/report.pdf")
        );
    }

    #[test]
    fn convert_mcp_call_tool_result_defaults_resource_blob_mime_type() {
        let result = convert_mcp_call_tool_result(&serde_json::json!({
            "content": [
                {
                    "type": "resource",
                    "resource": {
                        "uri": "file:///tmp/data.bin",
                        "blob": "Ymlu"
                    }
                },
                {
                    "type": "resource",
                    "resource": {
                        "blob": "YmluMg==",
                        "mimeType": ""
                    }
                }
            ]
        }))
        .expect("valid CallToolResult should convert");

        let ToolResult::Expanded(expanded) = result else {
            panic!("expected expanded tool result");
        };

        let binary_results = expanded
            .binary_results_for_llm
            .expect("resource blobs should be captured");
        assert_eq!(binary_results.len(), 2);
        assert_eq!(binary_results[0].mime_type, "application/octet-stream");
        assert_eq!(binary_results[1].mime_type, "application/octet-stream");
    }

    #[test]
    fn convert_mcp_call_tool_result_omits_binary_results_without_binary_content() {
        let result = convert_mcp_call_tool_result(&serde_json::json!({
            "content": [
                { "type": "text", "text": "hello" },
                {
                    "type": "resource",
                    "resource": {
                        "uri": "file:///tmp/readme.md",
                        "text": "resource text"
                    }
                }
            ]
        }))
        .expect("valid CallToolResult should convert");

        let ToolResult::Expanded(expanded) = result else {
            panic!("expected expanded tool result");
        };

        assert_eq!(expanded.text_result_for_llm, "hello\nresource text");
        assert!(expanded.binary_results_for_llm.is_none());
    }

    // Inputs without a `content` array are not CallToolResults and yield `None`.
    #[test]
    fn convert_mcp_call_tool_result_requires_content_array() {
        assert!(convert_mcp_call_tool_result(&serde_json::json!({ "isError": true })).is_none());
        assert!(convert_mcp_call_tool_result(&serde_json::json!({ "content": "text" })).is_none());
        assert!(convert_mcp_call_tool_result(&serde_json::json!(null)).is_none());
    }

    // Malformed or unrecognized blocks are skipped rather than failing the conversion.
    #[test]
    fn convert_mcp_call_tool_result_skips_malformed_and_unknown_blocks() {
        let result = convert_mcp_call_tool_result(&serde_json::json!({
            "isError": "true",
            "content": [
                { "type": "image", "data": "", "mimeType": "image/png" },
                { "type": "image", "data": "aW1n" },
                { "type": "resource", "resource": "file:///tmp/not-an-object" },
                {
                    "type": "resource",
                    "resource": { "uri": "file:///tmp/empty", "text": "", "blob": "" }
                },
                { "type": "custom", "text": "ignored" },
                { "text": "untyped" }
            ]
        }))
        .expect("a content array should still convert");

        let ToolResult::Expanded(expanded) = result else {
            panic!("expected expanded tool result");
        };

        assert_eq!(expanded.text_result_for_llm, "");
        // Only a boolean `true` marks failure; the string "true" does not.
        assert_eq!(expanded.result_type, "success");
        assert!(expanded.binary_results_for_llm.is_none());
    }

    #[tokio::test]
    async fn tool_handler_call_returns_result() {
        let tool = EchoTool;
        let inv = ToolInvocation {
            session_id: SessionId::from("s1"),
            tool_call_id: "tc1".to_string(),
            tool_name: "echo".to_string(),
            arguments: serde_json::json!({"msg": "hello"}),
            traceparent: None,
            tracestate: None,
        };

        let result = tool.call(inv).await.unwrap();
        match result {
            ToolResult::Text(s) => assert!(s.contains("hello")),
            _ => panic!("expected Text result"),
        }
    }

    #[cfg(feature = "derive")]
    #[tokio::test]
    async fn define_tool_builds_schema_and_dispatches() {
        use serde::Deserialize;

        #[derive(Deserialize, schemars::JsonSchema)]
        struct Params {
            city: String,
        }

        let tool = define_tool(
            "weather",
            "Get the weather for a city",
            |_inv, params: Params| async move {
                Ok(ToolResult::Text(format!("sunny in {}", params.city)))
            },
        );

        assert_eq!(tool.name, "weather");
        assert_eq!(tool.description, "Get the weather for a city");
        assert_eq!(tool.parameters["type"], "object");
        assert!(tool.parameters["properties"]["city"].is_object());
        let handler = tool.handler.as_ref().expect("define_tool attaches handler");

        let inv = ToolInvocation {
            session_id: SessionId::from("s1"),
            tool_call_id: "tc1".to_string(),
            tool_name: "weather".to_string(),
            arguments: serde_json::json!({"city": "Seattle"}),
            traceparent: None,
            tracestate: None,
        };
        match handler.call(inv).await.unwrap() {
            ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"),
            _ => panic!("expected Text result"),
        }
    }

    // Tests that exercise `schemars`-derived schemas; only built with the `derive` feature.
    #[cfg(feature = "derive")]
    mod derive_tests {
        use serde::Deserialize;

        use super::super::*;
        use crate::{ErrorKind, SessionId};

        #[derive(Deserialize, schemars::JsonSchema)]
        struct GetWeatherParams {
            city: String,
            unit: Option<String>,
        }

        #[test]
        fn schema_for_generates_clean_schema() {
            let schema = schema_for::<GetWeatherParams>();
            assert_eq!(schema["type"], "object");
            assert!(schema["properties"]["city"].is_object());
            assert!(schema["properties"]["unit"].is_object());
            let required = schema["required"].as_array().unwrap();
            assert!(required.contains(&serde_json::json!("city")));
            assert!(!required.contains(&serde_json::json!("unit")));
            assert!(schema.get("$schema").is_none());
            assert!(schema.get("title").is_none());
        }

        struct GetWeatherTool;

        fn get_weather_tool() -> Tool {
            Tool {
                name: "get_weather".to_string(),
                description: "Get weather for a city".to_string(),
                parameters: tool_parameters(schema_for::<GetWeatherParams>()),
                ..Default::default()
            }
            .with_handler(std::sync::Arc::new(GetWeatherTool))
        }

        #[async_trait]
        impl ToolHandler for GetWeatherTool {
            async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
                let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
                Ok(ToolResult::Text(format!(
                    "{} {}",
                    params.city,
                    params.unit.unwrap_or_default()
                )))
            }
        }

        #[test]
        fn tool_handler_with_schema_for() {
            let def = get_weather_tool();
            assert_eq!(def.name, "get_weather");
            let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters");
            assert_eq!(schema["type"], "object");
            assert!(schema["properties"]["city"].is_object());
            assert!(def.handler.is_some());
        }

        #[tokio::test]
        async fn tool_handler_deserializes_typed_params() {
            let tool = GetWeatherTool;
            let inv = ToolInvocation {
                session_id: SessionId::from("s1"),
                tool_call_id: "tc1".to_string(),
                tool_name: "get_weather".to_string(),
                arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}),
                traceparent: None,
                tracestate: None,
            };

            let result = tool.call(inv).await.unwrap();
            match result {
                ToolResult::Text(s) => assert_eq!(s, "Seattle celsius"),
                _ => panic!("expected Text result"),
            }
        }

        #[tokio::test]
        async fn tool_handler_returns_error_on_bad_params() {
            let tool = GetWeatherTool;
            let inv = ToolInvocation {
                session_id: SessionId::from("s1"),
                tool_call_id: "tc1".to_string(),
                tool_name: "get_weather".to_string(),
                arguments: serde_json::json!({"wrong_field": 42}),
                traceparent: None,
                tracestate: None,
            };

            let err = tool.call(inv).await.unwrap_err();
            assert!(matches!(err.kind(), ErrorKind::Json));
        }

        #[tokio::test]
        async fn schema_for_derived_tool_round_trips_through_call() {
            let tool = GetWeatherTool;

            let result = tool
                .call(ToolInvocation {
                    session_id: SessionId::from("s1"),
                    tool_call_id: "tc1".to_string(),
                    tool_name: "get_weather".to_string(),
                    arguments: serde_json::json!({"city": "Portland"}),
                    traceparent: None,
                    tracestate: None,
                })
                .await
                .expect("ToolHandler::call should succeed for matching args");
            match result {
                ToolResult::Text(s) => assert!(s.contains("Portland")),
                _ => panic!("expected ToolResult::Text"),
            }
        }
    }
}