use std::collections::HashMap;
use async_trait::async_trait;
#[cfg(feature = "derive")]
pub use schemars::JsonSchema;
use crate::Error;
#[cfg(any(feature = "derive", test))]
use crate::types::Tool;
use crate::types::{ToolBinaryResult, ToolInvocation, ToolResult, ToolResultExpanded};
#[cfg(feature = "derive")]
pub fn schema_for<T: schemars::JsonSchema>() -> serde_json::Value {
let schema = schemars::schema_for!(T);
let mut value = serde_json::to_value(schema).expect("JSON Schema serialization cannot fail");
if let Some(obj) = value.as_object_mut() {
obj.remove("$schema");
obj.remove("title");
}
value
}
pub fn tool_parameters(schema: serde_json::Value) -> HashMap<String, serde_json::Value> {
try_tool_parameters(schema).expect("tool parameter schema must be a JSON object")
}
pub fn try_tool_parameters(
schema: serde_json::Value,
) -> Result<HashMap<String, serde_json::Value>, serde_json::Error> {
serde_json::from_value(schema)
}
pub fn convert_mcp_call_tool_result(value: &serde_json::Value) -> Option<ToolResult> {
let content = value.get("content")?.as_array()?;
let mut text_parts = Vec::new();
let mut binary_results = Vec::new();
for block in content {
match block.get("type").and_then(serde_json::Value::as_str) {
Some("text") => {
if let Some(text) = block.get("text").and_then(serde_json::Value::as_str) {
text_parts.push(text.to_string());
}
}
Some("image") => {
let data = block
.get("data")
.and_then(serde_json::Value::as_str)
.filter(|s| !s.is_empty());
let mime_type = block
.get("mimeType")
.and_then(serde_json::Value::as_str)
.filter(|s| !s.is_empty());
if let (Some(data), Some(mime_type)) = (data, mime_type) {
binary_results.push(ToolBinaryResult {
data: data.to_string(),
mime_type: mime_type.to_string(),
r#type: "image".to_string(),
description: None,
});
}
}
Some("resource") => {
let Some(resource) = block.get("resource").and_then(serde_json::Value::as_object)
else {
continue;
};
if let Some(text) = resource
.get("text")
.and_then(serde_json::Value::as_str)
.filter(|s| !s.is_empty())
{
text_parts.push(text.to_string());
}
if let Some(blob) = resource
.get("blob")
.and_then(serde_json::Value::as_str)
.filter(|s| !s.is_empty())
{
let mime_type = resource
.get("mimeType")
.and_then(serde_json::Value::as_str)
.filter(|s| !s.is_empty())
.unwrap_or("application/octet-stream");
let description = resource
.get("uri")
.and_then(serde_json::Value::as_str)
.filter(|s| !s.is_empty())
.map(ToString::to_string);
binary_results.push(ToolBinaryResult {
data: blob.to_string(),
mime_type: mime_type.to_string(),
r#type: "resource".to_string(),
description,
});
}
}
_ => {}
}
}
Some(ToolResult::Expanded(ToolResultExpanded {
text_result_for_llm: text_parts.join("\n"),
result_type: if value.get("isError").and_then(serde_json::Value::as_bool) == Some(true) {
"failure".to_string()
} else {
"success".to_string()
},
binary_results_for_llm: (!binary_results.is_empty()).then_some(binary_results),
session_log: None,
error: None,
tool_telemetry: None,
}))
}
#[async_trait]
pub trait ToolHandler: Send + Sync + 'static {
async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error>;
}
#[cfg(feature = "derive")]
pub fn define_tool<P, F, Fut>(
name: impl Into<String>,
description: impl Into<String>,
handler: F,
) -> Tool
where
P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
{
struct FnHandler<P, F> {
handler: F,
_marker: std::marker::PhantomData<fn(P)>,
}
#[async_trait]
impl<P, F, Fut> ToolHandler for FnHandler<P, F>
where
P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
{
async fn call(&self, mut invocation: ToolInvocation) -> Result<ToolResult, Error> {
let arguments = std::mem::take(&mut invocation.arguments);
let params: P = serde_json::from_value(arguments)?;
(self.handler)(invocation, params).await
}
}
Tool {
name: name.into(),
description: description.into(),
parameters: tool_parameters(schema_for::<P>()),
..Default::default()
}
.with_handler(std::sync::Arc::new(FnHandler {
handler,
_marker: std::marker::PhantomData,
}))
}
#[cfg(feature = "derive")]
pub fn define_tool_declaration<P>(name: impl Into<String>, description: impl Into<String>) -> Tool
where
P: schemars::JsonSchema,
{
Tool {
name: name.into(),
description: description.into(),
parameters: tool_parameters(schema_for::<P>()),
..Default::default()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::SessionId;
struct EchoTool;
fn echo_tool() -> Tool {
Tool {
name: "echo".to_string(),
description: "Echo the input".to_string(),
parameters: tool_parameters(serde_json::json!({"type": "object"})),
..Default::default()
}
.with_handler(std::sync::Arc::new(EchoTool))
}
#[async_trait]
impl ToolHandler for EchoTool {
async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
Ok(ToolResult::Text(inv.arguments.to_string()))
}
}
#[test]
fn tool_handler_returns_tool_definition() {
let def = echo_tool();
assert_eq!(def.name, "echo");
assert_eq!(def.description, "Echo the input");
assert!(def.parameters.contains_key("type"));
assert!(def.handler.is_some());
}
#[test]
fn try_tool_parameters_rejects_non_object_schema() {
let err = try_tool_parameters(serde_json::json!(["not", "an", "object"]))
.expect_err("non-object schemas should be rejected");
assert!(err.is_data());
}
#[test]
fn convert_mcp_call_tool_result_collects_text_and_binary_content() {
let result = convert_mcp_call_tool_result(&serde_json::json!({
"isError": true,
"content": [
{ "type": "text", "text": "hello" },
{ "type": "image", "data": "aW1n", "mimeType": "image/png" },
{
"type": "resource",
"resource": {
"uri": "file:///tmp/data.bin",
"blob": "Ymlu",
"mimeType": "application/octet-stream",
"text": "resource text"
}
}
]
}))
.expect("valid CallToolResult should convert");
let ToolResult::Expanded(expanded) = result else {
panic!("expected expanded tool result");
};
assert_eq!(expanded.text_result_for_llm, "hello\nresource text");
assert_eq!(expanded.result_type, "failure");
let binary_results = expanded
.binary_results_for_llm
.expect("binary results should be captured");
assert_eq!(binary_results.len(), 2);
assert_eq!(binary_results[0].r#type, "image");
assert_eq!(binary_results[0].data, "aW1n");
assert_eq!(binary_results[0].mime_type, "image/png");
assert_eq!(
binary_results[1].description.as_deref(),
Some("file:///tmp/data.bin")
);
}
#[test]
fn convert_mcp_call_tool_result_converts_image_content() {
let result = convert_mcp_call_tool_result(&serde_json::json!({
"content": [
{ "type": "image", "data": "aW1hZ2U=", "mimeType": "image/jpeg" }
]
}))
.expect("valid CallToolResult should convert");
let ToolResult::Expanded(expanded) = result else {
panic!("expected expanded tool result");
};
assert_eq!(expanded.text_result_for_llm, "");
assert_eq!(expanded.result_type, "success");
let binary_results = expanded
.binary_results_for_llm
.expect("image result should be captured");
assert_eq!(binary_results.len(), 1);
assert_eq!(binary_results[0].data, "aW1hZ2U=");
assert_eq!(binary_results[0].mime_type, "image/jpeg");
assert_eq!(binary_results[0].r#type, "image");
assert!(binary_results[0].description.is_none());
}
#[test]
fn convert_mcp_call_tool_result_converts_resource_blob_content() {
let result = convert_mcp_call_tool_result(&serde_json::json!({
"content": [
{
"type": "resource",
"resource": {
"uri": "file:///tmp/report.pdf",
"blob": "cGRm",
"mimeType": "application/pdf"
}
}
]
}))
.expect("valid CallToolResult should convert");
let ToolResult::Expanded(expanded) = result else {
panic!("expected expanded tool result");
};
let binary_results = expanded
.binary_results_for_llm
.expect("resource result should be captured");
assert_eq!(binary_results.len(), 1);
assert_eq!(binary_results[0].data, "cGRm");
assert_eq!(binary_results[0].mime_type, "application/pdf");
assert_eq!(binary_results[0].r#type, "resource");
assert_eq!(
binary_results[0].description.as_deref(),
Some("file:///tmp/report.pdf")
);
}
#[test]
fn convert_mcp_call_tool_result_defaults_resource_blob_mime_type() {
let result = convert_mcp_call_tool_result(&serde_json::json!({
"content": [
{
"type": "resource",
"resource": {
"uri": "file:///tmp/data.bin",
"blob": "Ymlu"
}
},
{
"type": "resource",
"resource": {
"blob": "YmluMg==",
"mimeType": ""
}
}
]
}))
.expect("valid CallToolResult should convert");
let ToolResult::Expanded(expanded) = result else {
panic!("expected expanded tool result");
};
let binary_results = expanded
.binary_results_for_llm
.expect("resource blobs should be captured");
assert_eq!(binary_results.len(), 2);
assert_eq!(binary_results[0].mime_type, "application/octet-stream");
assert_eq!(binary_results[1].mime_type, "application/octet-stream");
}
#[test]
fn convert_mcp_call_tool_result_omits_binary_results_without_binary_content() {
let result = convert_mcp_call_tool_result(&serde_json::json!({
"content": [
{ "type": "text", "text": "hello" },
{
"type": "resource",
"resource": {
"uri": "file:///tmp/readme.md",
"text": "resource text"
}
}
]
}))
.expect("valid CallToolResult should convert");
let ToolResult::Expanded(expanded) = result else {
panic!("expected expanded tool result");
};
assert_eq!(expanded.text_result_for_llm, "hello\nresource text");
assert!(expanded.binary_results_for_llm.is_none());
}
#[tokio::test]
async fn tool_handler_call_returns_result() {
let tool = EchoTool;
let inv = ToolInvocation {
session_id: SessionId::from("s1"),
tool_call_id: "tc1".to_string(),
tool_name: "echo".to_string(),
arguments: serde_json::json!({"msg": "hello"}),
traceparent: None,
tracestate: None,
};
let result = tool.call(inv).await.unwrap();
match result {
ToolResult::Text(s) => assert!(s.contains("hello")),
_ => panic!("expected Text result"),
}
}
#[cfg(feature = "derive")]
#[tokio::test]
async fn define_tool_builds_schema_and_dispatches() {
use serde::Deserialize;
#[derive(Deserialize, schemars::JsonSchema)]
struct Params {
city: String,
}
let tool = define_tool(
"weather",
"Get the weather for a city",
|_inv, params: Params| async move {
Ok(ToolResult::Text(format!("sunny in {}", params.city)))
},
);
assert_eq!(tool.name, "weather");
assert_eq!(tool.description, "Get the weather for a city");
assert_eq!(tool.parameters["type"], "object");
assert!(tool.parameters["properties"]["city"].is_object());
let handler = tool.handler.as_ref().expect("define_tool attaches handler");
let inv = ToolInvocation {
session_id: SessionId::from("s1"),
tool_call_id: "tc1".to_string(),
tool_name: "weather".to_string(),
arguments: serde_json::json!({"city": "Seattle"}),
traceparent: None,
tracestate: None,
};
match handler.call(inv).await.unwrap() {
ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"),
_ => panic!("expected Text result"),
}
}
#[cfg(feature = "derive")]
mod derive_tests {
use serde::Deserialize;
use super::super::*;
use crate::SessionId;
#[derive(Deserialize, schemars::JsonSchema)]
struct GetWeatherParams {
city: String,
unit: Option<String>,
}
#[test]
fn schema_for_generates_clean_schema() {
let schema = schema_for::<GetWeatherParams>();
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["city"].is_object());
assert!(schema["properties"]["unit"].is_object());
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&serde_json::json!("city")));
assert!(!required.contains(&serde_json::json!("unit")));
assert!(schema.get("$schema").is_none());
assert!(schema.get("title").is_none());
}
struct GetWeatherTool;
fn get_weather_tool() -> Tool {
Tool {
name: "get_weather".to_string(),
description: "Get weather for a city".to_string(),
parameters: tool_parameters(schema_for::<GetWeatherParams>()),
..Default::default()
}
.with_handler(std::sync::Arc::new(GetWeatherTool))
}
#[async_trait]
impl ToolHandler for GetWeatherTool {
async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
Ok(ToolResult::Text(format!(
"{} {}",
params.city,
params.unit.unwrap_or_default()
)))
}
}
#[test]
fn tool_handler_with_schema_for() {
let def = get_weather_tool();
assert_eq!(def.name, "get_weather");
let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters");
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["city"].is_object());
assert!(def.handler.is_some());
}
#[tokio::test]
async fn tool_handler_deserializes_typed_params() {
let tool = GetWeatherTool;
let inv = ToolInvocation {
session_id: SessionId::from("s1"),
tool_call_id: "tc1".to_string(),
tool_name: "get_weather".to_string(),
arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}),
traceparent: None,
tracestate: None,
};
let result = tool.call(inv).await.unwrap();
match result {
ToolResult::Text(s) => assert_eq!(s, "Seattle celsius"),
_ => panic!("expected Text result"),
}
}
#[tokio::test]
async fn tool_handler_returns_error_on_bad_params() {
let tool = GetWeatherTool;
let inv = ToolInvocation {
session_id: SessionId::from("s1"),
tool_call_id: "tc1".to_string(),
tool_name: "get_weather".to_string(),
arguments: serde_json::json!({"wrong_field": 42}),
traceparent: None,
tracestate: None,
};
let err = tool.call(inv).await.unwrap_err();
assert!(matches!(err, Error::Json(_)));
}
#[tokio::test]
async fn schema_for_derived_tool_round_trips_through_call() {
let tool = GetWeatherTool;
let result = tool
.call(ToolInvocation {
session_id: SessionId::from("s1"),
tool_call_id: "tc1".to_string(),
tool_name: "get_weather".to_string(),
arguments: serde_json::json!({"city": "Portland"}),
traceparent: None,
tracestate: None,
})
.await
.expect("ToolHandler::call should succeed for matching args");
match result {
ToolResult::Text(s) => assert!(s.contains("Portland")),
_ => panic!("expected ToolResult::Text"),
}
}
}
}