use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// A single content item inside a [`ToolResult`].
///
/// Serialized as an internally tagged object: the variant name is written to a
/// `"type"` field (`"text"` or `"image"`) alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ToolContent {
    /// Plain text, serialized as `{"type": "text", "text": ...}`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Image content, serialized as `{"type": "image", "data": ..., "mimeType": ...}`.
    #[serde(rename = "image")]
    Image {
        /// Image payload as a string (presumably base64-encoded — confirm with producers).
        data: String,
        /// MIME type of `data`; serialized under the camelCase key `mimeType`.
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
}
impl ToolContent {
pub fn text(text: impl Into<String>) -> Self {
Self::Text { text: text.into() }
}
pub fn image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
Self::Image {
data: data.into(),
mime_type: mime_type.into(),
}
}
}
/// The outcome of a tool call: an ordered list of content items plus an optional error flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Content items returned to the caller, in order.
    pub content: Vec<ToolContent>,
    /// `Some(true)` marks the result as a failure (as set by `ToolResult::error`);
    /// `None` is left out of the serialized output entirely.
    ///
    /// Serialized as `isError`, the camelCase key used by the MCP tool-result schema and
    /// consistent with `mimeType` and the `ToolAnnotations` hints. The old snake_case key
    /// `is_error` is still accepted on input so previously serialized data keeps parsing.
    #[serde(
        rename = "isError",
        alias = "is_error",
        default,
        skip_serializing_if = "Option::is_none"
    )]
    pub is_error: Option<bool>,
}
impl ToolResult {
    /// Builds a successful result containing a single text item.
    pub fn text(text: impl Into<String>) -> Self {
        Self::with_content(vec![ToolContent::text(text)])
    }

    /// Builds a failed result whose single text item carries `message`.
    pub fn error(message: impl Into<String>) -> Self {
        Self {
            is_error: Some(true),
            ..Self::text(message)
        }
    }

    /// Builds a successful result from an arbitrary list of content items.
    pub fn with_content(content: Vec<ToolContent>) -> Self {
        Self {
            is_error: None,
            content,
        }
    }
}
/// Optional behavioral hints attached to a tool via `SdkMcpTool::with_annotations`.
///
/// Fields serialize in camelCase (`title`, `readOnlyHint`, `destructiveHint`,
/// `idempotentHint`, `openWorldHint`), and any field left as `None` is omitted from the
/// output. `Default` leaves every field unset. Nothing in this module reads or enforces
/// these values; they are advisory metadata only.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolAnnotations {
    /// Optional human-readable title, serialized as `title`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,
    /// Serialized as `readOnlyHint`. Presumably follows the MCP meaning (the tool does not
    /// modify its environment) — confirm against the spec.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub read_only_hint: Option<bool>,
    /// Serialized as `destructiveHint`. Presumably follows the MCP meaning (the tool may
    /// perform destructive updates) — confirm against the spec.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub destructive_hint: Option<bool>,
    /// Serialized as `idempotentHint`. Presumably follows the MCP meaning (repeating a call
    /// with the same arguments has no additional effect) — confirm against the spec.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idempotent_hint: Option<bool>,
    /// Serialized as `openWorldHint`. Presumably follows the MCP meaning (the tool talks to
    /// external entities) — confirm against the spec.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub open_world_hint: Option<bool>,
}
/// A JSON Schema describing a tool's input, normally built with the methods on this type.
///
/// `properties` is a `HashMap`, so the key order in the serialized output is unspecified.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInputSchema {
    /// The JSON Schema `type` keyword (serialized as `type`); `ToolInputSchema::object()`
    /// sets it to `"object"`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Property name → JSON Schema fragment for that property. Omitted from the output when
    /// empty; defaults to empty when absent on input.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub properties: HashMap<String, Value>,
    /// Names of required properties. Omitted from the output when empty; defaults to empty
    /// when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub required: Vec<String>,
}
impl ToolInputSchema {
    /// Creates an empty schema of JSON type `"object"` with no properties and nothing required.
    pub fn object() -> Self {
        Self {
            schema_type: "object".to_string(),
            properties: HashMap::new(),
            required: Vec::new(),
        }
    }

    /// Adds (or replaces) a property `name` of JSON type `"string"`.
    ///
    /// The property is optional unless it is also passed to
    /// [`required_property`](Self::required_property).
    #[must_use]
    pub fn string_property(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        self.typed_property(name.into(), "string", description.into())
    }

    /// Adds (or replaces) a property `name` of JSON type `"number"` (any numeric value).
    #[must_use]
    pub fn number_property(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        self.typed_property(name.into(), "number", description.into())
    }

    /// Adds (or replaces) a property `name` of JSON type `"integer"` (whole numbers only).
    #[must_use]
    pub fn integer_property(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        self.typed_property(name.into(), "integer", description.into())
    }

    /// Adds (or replaces) a property `name` of JSON type `"boolean"`.
    #[must_use]
    pub fn boolean_property(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        self.typed_property(name.into(), "boolean", description.into())
    }

    /// Marks `name` as required.
    ///
    /// JSON Schema requires the entries of `required` to be unique, so adding a name that
    /// is already listed is a no-op rather than producing a duplicate entry.
    #[must_use]
    pub fn required_property(mut self, name: impl Into<String>) -> Self {
        let name = name.into();
        if !self.required.contains(&name) {
            self.required.push(name);
        }
        self
    }

    /// Shared body of the typed `*_property` builders: stores
    /// `{"type": json_type, "description": description}` under `name`, replacing any
    /// previous entry with the same name.
    fn typed_property(mut self, name: String, json_type: &str, description: String) -> Self {
        self.properties.insert(
            name,
            serde_json::json!({
                "type": json_type,
                "description": description,
            }),
        );
        self
    }
}
/// Type-erased async tool handler.
///
/// Called with the tool's input `Value`; returns a boxed, pinned, `Send` future that
/// resolves to a [`ToolResult`]. The `Arc` makes clones cheap, and the `Send + Sync`
/// bounds allow the handler to be shared across threads.
pub type ToolHandler =
    Arc<dyn Fn(Value) -> Pin<Box<dyn Future<Output = ToolResult> + Send>> + Send + Sync>;
/// A tool definition: metadata, input schema, and the async handler that executes it.
///
/// Usually built with `SdkMcpTool::new` or the `tool!` macro. `Debug` is implemented by
/// hand because `handler` (a `dyn Fn`) has no `Debug` impl.
pub struct SdkMcpTool {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Schema describing the input value passed to `handler`.
    pub input_schema: ToolInputSchema,
    /// Async callback that runs the tool.
    pub handler: ToolHandler,
    /// Optional behavioral hints; `None` unless set with `with_annotations`.
    pub annotations: Option<ToolAnnotations>,
}
impl SdkMcpTool {
    /// Builds a tool from its metadata and an async handler.
    ///
    /// `handler` can be any `Fn(Value) -> impl Future<Output = ToolResult>`. It is wrapped
    /// into a shared [`ToolHandler`] that boxes and pins each future it returns. The new
    /// tool has no annotations.
    pub fn new<F, Fut>(
        name: impl Into<String>,
        description: impl Into<String>,
        input_schema: ToolInputSchema,
        handler: F,
    ) -> Self
    where
        F: Fn(Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ToolResult> + Send + 'static,
    {
        let name = name.into();
        let description = description.into();
        Self {
            // Type-erase the caller's closure so every tool has the same handler type.
            handler: Arc::new(move |input| Box::pin(handler(input))),
            input_schema,
            name,
            description,
            annotations: None,
        }
    }

    /// Attaches `annotations` (replacing any set earlier) and returns the updated tool.
    pub fn with_annotations(self, annotations: ToolAnnotations) -> Self {
        Self {
            annotations: Some(annotations),
            ..self
        }
    }
}
// Hand-written `Debug`: `handler` is a type-erased closure with no `Debug` impl, so it is
// skipped. Every other field is printed, and `finish_non_exhaustive` appends `..` so the
// output makes clear that a field was left out.
impl std::fmt::Debug for SdkMcpTool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SdkMcpTool")
            .field("name", &self.name)
            .field("description", &self.description)
            .field("input_schema", &self.input_schema)
            .field("annotations", &self.annotations)
            .finish_non_exhaustive()
    }
}
/// Server descriptor returned by `create_sdk_mcp_server`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpSdkServerConfig {
    /// Server kind, serialized as `type`; `create_sdk_mcp_server` always sets `"sdk"`.
    #[serde(rename = "type")]
    pub server_type: String,
    /// Server name.
    pub name: String,
    /// Server version string (free-form; not validated here).
    pub version: String,
}
/// Pairs `tools` with a server descriptor whose `server_type` is `"sdk"`.
///
/// `name` and `version` are copied into the descriptor; the tools are handed back
/// unchanged as the second element of the tuple.
pub fn create_sdk_mcp_server(
    name: impl Into<String>,
    version: impl Into<String>,
    tools: Vec<SdkMcpTool>,
) -> (McpSdkServerConfig, Vec<SdkMcpTool>) {
    let server_type = String::from("sdk");
    let name = name.into();
    let version = version.into();
    (
        McpSdkServerConfig {
            server_type,
            name,
            version,
        },
        tools,
    )
}
/// Builds an `SdkMcpTool` from function-like syntax.
///
/// ```ignore
/// let greet = tool! {
///     fn greet(name: String, times: Option<u32>) -> ToolResult {
///         ToolResult::text(format!("hello {}", name).repeat(times.unwrap_or(1) as usize))
///     }
/// };
/// ```
///
/// * The tool name is the function name; the tool description is empty.
/// * Each argument becomes a schema property with an empty description. Its JSON type comes
///   from the Rust type *as written*: `bool` → `"boolean"`, built-in integer types →
///   `"integer"`, `f32`/`f64` → `"number"`, anything else → `"string"`. `Option<T>` uses the
///   mapping for `T` and is not listed as required; every other argument is required.
///   Aliases and fully qualified paths (e.g. `std::option::Option<T>`) are not recognised
///   and fall back to a required `"string"`.
/// * At call time each argument is read from the input object with
///   `serde_json::from_value` (a missing key is treated as `null`). If an argument fails to
///   deserialize, the handler returns `ToolResult::error` naming that argument and the body
///   does not run.
/// * The body must evaluate to a `ToolResult`. Attributes (including doc comments) and the
///   `-> $ret` annotation are accepted but currently ignored.
///
/// The expansion refers to `serde_json` by path, so the invoking crate must depend on it.
#[macro_export]
macro_rules! tool {
    (
        $(#[$meta:meta])*
        fn $name:ident($($arg:ident: $type:ty),*) -> $ret:ty $body:block
    ) => {
        {
            // `ToolResult` is only used when the tool has at least one argument.
            #[allow(unused_imports)]
            use $crate::mcp::{SdkMcpTool, ToolInputSchema, ToolResult};

            // Maps a Rust type, as spelled in the macro input, to a JSON Schema `type` plus
            // whether the argument may be omitted. Whitespace is stripped first because
            // `stringify!` may render `Option<T>` as `Option < T >`.
            #[allow(dead_code)] // unused when the tool takes no arguments
            fn __tool_arg_schema(rust_type: &str) -> (&'static str, bool) {
                let compact: String = rust_type.chars().filter(|c| !c.is_whitespace()).collect();
                let (inner, optional) = match compact
                    .strip_prefix("Option<")
                    .and_then(|rest| rest.strip_suffix('>'))
                {
                    Some(inner) => (inner, true),
                    None => (compact.as_str(), false),
                };
                let json_type = match inner {
                    "bool" => "boolean",
                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
                    | "u64" | "u128" | "usize" => "integer",
                    "f32" | "f64" => "number",
                    _ => "string",
                };
                (json_type, optional)
            }

            #[allow(unused_mut)] // a tool with no arguments never modifies the schema
            let mut schema = ToolInputSchema::object();
            $(
                let (json_type, optional) = __tool_arg_schema(stringify!($type));
                schema.properties.insert(
                    stringify!($arg).to_string(),
                    serde_json::json!({ "type": json_type, "description": "" }),
                );
                if !optional {
                    schema = schema.required_property(stringify!($arg));
                }
            )*

            SdkMcpTool::new(
                stringify!($name),
                "",
                schema,
                |input: serde_json::Value| async move {
                    $(
                        // Report bad or missing arguments to the caller instead of
                        // silently substituting `Default::default()`.
                        let $arg: $type = match serde_json::from_value(
                            input
                                .get(stringify!($arg))
                                .cloned()
                                .unwrap_or(serde_json::Value::Null),
                        ) {
                            Ok(value) => value,
                            Err(err) => {
                                return ToolResult::error(format!(
                                    "invalid argument `{}`: {}",
                                    stringify!($arg),
                                    err
                                ));
                            }
                        };
                    )*
                    $body
                },
            )
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The trivial tool shared by several tests: ignores its input and returns "ok".
    fn ok_tool() -> SdkMcpTool {
        SdkMcpTool::new("test", "Test tool", ToolInputSchema::object(), |_| async {
            ToolResult::text("ok")
        })
    }

    #[test]
    fn test_tool_content_text() {
        if let ToolContent::Text { text } = ToolContent::text("Hello") {
            assert_eq!(text, "Hello");
        } else {
            panic!("Expected text content");
        }
    }

    #[test]
    fn test_tool_result_text() {
        let ToolResult { content, is_error } = ToolResult::text("Success");
        assert_eq!(content.len(), 1);
        assert!(is_error.is_none());
    }

    #[test]
    fn test_tool_result_error() {
        assert_eq!(ToolResult::error("Something went wrong").is_error, Some(true));
    }

    #[test]
    fn test_input_schema_builder() {
        let schema = ToolInputSchema::object()
            .string_property("name", "The name")
            .number_property("age", "The age")
            .required_property("name");
        assert_eq!(schema.schema_type, "object");
        for key in ["name", "age"] {
            assert!(schema.properties.contains_key(key));
        }
        assert!(schema.required.iter().any(|r| r == "name"));
    }

    #[test]
    fn test_create_sdk_server() {
        let (config, tools) = create_sdk_mcp_server("test-server", "1.0.0", vec![ok_tool()]);
        assert_eq!(
            (
                config.server_type.as_str(),
                config.name.as_str(),
                config.version.as_str()
            ),
            ("sdk", "test-server", "1.0.0")
        );
        assert_eq!(tools.len(), 1);
    }

    #[test]
    fn test_tool_annotations_default() {
        let ToolAnnotations {
            title,
            read_only_hint,
            destructive_hint,
            idempotent_hint,
            open_world_hint,
        } = ToolAnnotations::default();
        assert!(title.is_none());
        for hint in [read_only_hint, destructive_hint, idempotent_hint, open_world_hint] {
            assert!(hint.is_none());
        }
    }

    #[test]
    fn test_tool_annotations_serialization() {
        let annotations = ToolAnnotations {
            title: Some("My Tool".to_string()),
            read_only_hint: Some(true),
            destructive_hint: Some(false),
            idempotent_hint: Some(true),
            open_world_hint: Some(false),
        };
        let json = serde_json::to_value(&annotations).unwrap();
        assert_eq!(json["title"], "My Tool");
        let expected = [
            ("readOnlyHint", true),
            ("destructiveHint", false),
            ("idempotentHint", true),
            ("openWorldHint", false),
        ];
        for (key, value) in expected {
            assert_eq!(json[key], value);
        }
    }

    #[test]
    fn test_tool_annotations_skips_none() {
        let json = serde_json::to_value(&ToolAnnotations {
            read_only_hint: Some(true),
            ..Default::default()
        })
        .unwrap();
        assert_eq!(json["readOnlyHint"], true);
        for key in ["title", "destructiveHint"] {
            assert!(json.get(key).is_none());
        }
    }

    #[test]
    fn test_sdk_mcp_tool_with_annotations() {
        let tool = ok_tool().with_annotations(ToolAnnotations {
            read_only_hint: Some(true),
            ..Default::default()
        });
        let hint = tool.annotations.as_ref().and_then(|a| a.read_only_hint);
        assert!(tool.annotations.is_some());
        assert_eq!(hint, Some(true));
    }

    #[test]
    fn test_sdk_mcp_tool_no_annotations_by_default() {
        assert!(ok_tool().annotations.is_none());
    }
}