use serde::{Deserialize, Serialize};
use std::future::Future;
use std::sync::Arc;
use crate::error::Result;
use super::ToolHandler;
/// An in-process tool: a name, a description, a JSON input schema, and an
/// async handler that turns a JSON input into a [`ToolResult`].
pub struct SdkMcpTool {
    /// Name reported by [`SdkMcpTool::name`] and in the tool-info JSON.
    pub(crate) name: String,
    /// Description reported by [`SdkMcpTool::description`] and in the tool-info JSON.
    pub(crate) description: String,
    /// Schema value reported under the `inputSchema` key of the tool-info JSON.
    pub(crate) input_schema: serde_json::Value,
    /// Type-erased async handler called by [`SdkMcpTool::invoke`].
    pub(crate) handler: ToolHandler,
}

impl SdkMcpTool {
    /// Creates a tool from its metadata and an async handler.
    ///
    /// `handler` receives the raw JSON input on each invocation. The future it
    /// returns is boxed and pinned so that every tool, whatever its concrete
    /// handler type, is stored behind the same `ToolHandler` type.
    pub fn new<F, Fut>(
        name: impl Into<String>,
        description: impl Into<String>,
        input_schema: serde_json::Value,
        handler: F,
    ) -> Self
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<ToolResult>> + Send + 'static,
    {
        let name: String = name.into();
        let description: String = description.into();
        Self {
            name,
            description,
            input_schema,
            // Erase the concrete closure/future types behind `ToolHandler`.
            handler: Arc::new(move |input| Box::pin(handler(input))),
        }
    }

    /// Returns the tool's name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the tool's human-readable description.
    pub fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Returns the JSON value describing the tool's expected input.
    pub fn input_schema(&self) -> &serde_json::Value {
        &self.input_schema
    }

    /// Runs the handler on `input` and awaits its result.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the handler's future resolves to.
    pub async fn invoke(&self, input: serde_json::Value) -> Result<ToolResult> {
        let pending = (self.handler)(input);
        pending.await
    }

    /// Builds the tool-listing JSON object with keys `name`, `description`
    /// and `inputSchema`.
    pub(crate) fn to_tool_info(&self) -> serde_json::Value {
        serde_json::json!({
            "name": &self.name,
            "description": &self.description,
            "inputSchema": &self.input_schema,
        })
    }
}
/// Debug output shows the tool's metadata. The handler is an opaque closure
/// and is not printed; `finish_non_exhaustive` renders a trailing `..` so
/// readers can tell that a field was left out.
impl std::fmt::Debug for SdkMcpTool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `SdkMcpTool` becomes a
        // compile error here, so this impl cannot silently fall out of date.
        let Self {
            name,
            description,
            input_schema,
            handler: _,
        } = self;
        f.debug_struct("SdkMcpTool")
            .field("name", name)
            .field("description", description)
            .field("input_schema", input_schema)
            .finish_non_exhaustive()
    }
}
/// Outcome of a tool invocation: a list of content blocks plus an optional
/// error flag.
///
/// Field names on the wire use the MCP camelCase convention, matching the
/// `mimeType` and `inputSchema` keys used elsewhere for tools.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Content blocks produced by the tool.
    pub content: Vec<ToolContent>,
    /// `Some(true)` when the tool reports a failure.
    ///
    /// Serialized as `isError` and omitted entirely when `None`. The legacy
    /// `is_error` key is still accepted when deserializing, for backward
    /// compatibility.
    #[serde(rename = "isError", alias = "is_error")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
impl ToolResult {
    /// Builds a successful result holding a single text block.
    ///
    /// `is_error` is left as `None`.
    pub fn text(text: impl Into<String>) -> Self {
        Self::single_text_block(text.into(), None)
    }

    /// Builds a failed result holding a single text block describing the error.
    ///
    /// `is_error` is set to `Some(true)`.
    pub fn error(text: impl Into<String>) -> Self {
        Self::single_text_block(text.into(), Some(true))
    }

    /// Shared constructor for the one-text-block results above.
    fn single_text_block(text: String, is_error: Option<bool>) -> Self {
        let content = vec![ToolContent::Text { text }];
        Self { content, is_error }
    }
}
/// A single content block inside a [`ToolResult`].
///
/// Serialized as an internally tagged object. The lowercased variant name goes
/// in a `"type"` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolContent {
    /// Plain text, serialized as `{"type": "text", "text": ...}`.
    Text { text: String },
    /// Image data, serialized as `{"type": "image", "data": ..., "mimeType": ...}`.
    #[serde(rename_all = "camelCase")]
    Image {
        /// Image payload as a string (presumably base64-encoded — confirm with callers).
        data: String,
        /// MIME type of the image; appears as `mimeType` on the wire.
        mime_type: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_creation() {
        // Construction needs no async runtime, so a plain #[test] is enough.
        let tool = SdkMcpTool::new(
            "test",
            "Test tool",
            json!({"type": "object"}),
            |_input| async { Ok(ToolResult::text("test")) },
        );
        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "Test tool");
        assert_eq!(tool.input_schema(), &json!({"type": "object"}));
    }

    #[tokio::test]
    async fn test_tool_invocation() {
        let tool = SdkMcpTool::new(
            "echo",
            "Echo tool",
            json!({"type": "object"}),
            |input| async move {
                let text = input["text"].as_str().unwrap_or("empty");
                Ok(ToolResult::text(text))
            },
        );
        let result = tool.invoke(json!({"text": "hello"})).await.unwrap();
        assert_eq!(result.content.len(), 1);
        if let ToolContent::Text { text } = &result.content[0] {
            assert_eq!(text, "hello");
        } else {
            panic!("Expected text content");
        }
    }

    #[tokio::test]
    async fn test_tool_invocation_error_result() {
        let tool = SdkMcpTool::new(
            "fail",
            "Failing tool",
            json!({"type": "object"}),
            |_| async { Ok(ToolResult::error("boom")) },
        );
        let result = tool.invoke(json!({})).await.unwrap();
        assert_eq!(result.is_error, Some(true));
        assert_eq!(result.content.len(), 1);
    }

    #[test]
    fn test_tool_result_text() {
        let result = ToolResult::text("success");
        assert_eq!(result.content.len(), 1);
        assert!(result.is_error.is_none());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("failed");
        assert_eq!(result.content.len(), 1);
        assert_eq!(result.is_error, Some(true));
    }

    #[test]
    fn test_tool_result_text_omits_error_flag() {
        // A `None` error flag must not appear in the serialized object at all.
        let value = serde_json::to_value(ToolResult::text("ok")).unwrap();
        let obj = value.as_object().expect("ToolResult serializes to an object");
        assert_eq!(obj.len(), 1);
        assert!(obj.contains_key("content"));
    }

    #[test]
    fn test_tool_content_serialization() {
        let content = ToolContent::Text {
            text: "test".to_string(),
        };
        let json = serde_json::to_string(&content).unwrap();
        assert!(json.contains("\"type\":\"text\""));
        assert!(json.contains("\"text\":\"test\""));
    }

    #[test]
    fn test_tool_content_image_serialization() {
        let content = ToolContent::Image {
            data: "aGVsbG8=".to_string(),
            mime_type: "image/png".to_string(),
        };
        let value = serde_json::to_value(&content).unwrap();
        assert_eq!(
            value,
            json!({"type": "image", "data": "aGVsbG8=", "mimeType": "image/png"})
        );
    }

    #[test]
    fn test_tool_content_deserialization() {
        let parsed: ToolContent =
            serde_json::from_value(json!({"type": "text", "text": "hi"})).unwrap();
        match parsed {
            ToolContent::Text { text } => assert_eq!(text, "hi"),
            ToolContent::Image { .. } => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_tool_info() {
        let tool = SdkMcpTool::new(
            "test",
            "Test tool",
            json!({"type": "object"}),
            |_| async { Ok(ToolResult::text("test")) },
        );
        let info = tool.to_tool_info();
        assert_eq!(info["name"], "test");
        assert_eq!(info["description"], "Test tool");
        assert!(info["inputSchema"].is_object());
    }
}