use std::sync::Arc;
use futures::future::BoxFuture;
use rmcp::{
ErrorData,
model::{CallToolRequestParams, CallToolResult, Tool},
};
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use super::PluginContext;
/// Type-erased tool callback stored in a [`ToolDescriptor`].
///
/// It is called with a context value erased to `Arc<dyn Any + Send + Sync>` and the raw MCP
/// request parameters. It returns a boxed `'static` future that yields the tool call result.
pub(crate) type ToolHandler = Arc<
    dyn Fn(Arc<dyn std::any::Any + Send + Sync>, CallToolRequestParams)
            -> BoxFuture<'static, Result<CallToolResult, ErrorData>>
        + Send
        + Sync,
>;
pub struct ToolDescriptor {
pub name: &'static str,
pub description: &'static str,
pub(crate) tool: Tool,
pub(crate) handler: ToolHandler,
}
impl std::fmt::Debug for ToolDescriptor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ToolDescriptor")
.field("name", &self.name)
.field("description", &self.description)
.finish()
}
}
impl Clone for ToolDescriptor {
fn clone(&self) -> Self {
Self {
name: self.name,
description: self.description,
tool: self.tool.clone(),
handler: Arc::clone(&self.handler),
}
}
}
/// Builds the MCP [`Tool`] definition for `name`/`description`. The input schema is the JSON
/// Schema that `schemars` generates for `T`.
///
/// If the schema fails to serialize, or serializes to something other than a JSON object,
/// an empty JSON object is used as the input schema instead.
fn build_tool<T: JsonSchema>(name: &'static str, description: &'static str) -> Tool {
    let input_schema: serde_json::Map<String, serde_json::Value> =
        match serde_json::to_value(schemars::schema_for!(T)) {
            Ok(serde_json::Value::Object(fields)) => fields,
            // Both a serialization error and a non-object root fall back to `{}`.
            _ => serde_json::Map::new(),
        };
    Tool::new(name, description, Arc::new(input_schema))
}
/// Creates a [`ToolDescriptor`] for a tool whose handler needs no plugin context.
///
/// The advertised input schema is derived from `T`. On dispatch, the request's `arguments`
/// object is deserialized into `T` and passed to `handler`; `{}` is used when `arguments` is
/// absent. The context value passed to the dispatcher is ignored.
///
/// # Errors
///
/// The stored handler's future resolves to [`ErrorData::invalid_params`] (carrying the
/// deserialization error text) when the arguments cannot be deserialized into `T`.
/// Otherwise it resolves to whatever `handler` returns.
pub fn make_descriptor<T, F>(
    name: &'static str,
    description: &'static str,
    handler: F,
) -> ToolDescriptor
where
    T: DeserializeOwned + JsonSchema + 'static,
    F: Fn(T) -> BoxFuture<'static, Result<CallToolResult, ErrorData>> + Send + Sync + 'static,
{
    let tool = build_tool::<T>(name, description);
    let handler = Arc::new(
        move |_ctx: Arc<dyn std::any::Any + Send + Sync>,
              params: CallToolRequestParams|
              -> BoxFuture<'static, Result<CallToolResult, ErrorData>> {
            // `params` is owned, so the argument map is moved out instead of cloned.
            let value = serde_json::Value::Object(params.arguments.unwrap_or_default());
            match serde_json::from_value::<T>(value) {
                Ok(typed) => handler(typed),
                Err(e) => {
                    let message = e.to_string();
                    Box::pin(async move { Err(ErrorData::invalid_params(message, None)) })
                }
            }
        },
    );
    ToolDescriptor {
        name,
        description,
        tool,
        handler,
    }
}
/// Creates a [`ToolDescriptor`] for a tool whose handler receives a typed plugin context `Ctx`.
///
/// The advertised input schema is derived from `T`. On dispatch, the erased context is
/// downcast to `Ctx`. The request's `arguments` object is then deserialized into `T`, using
/// `{}` when `arguments` is absent. Both values are passed to `handler`.
///
/// # Errors
///
/// The stored handler's future resolves to:
/// - [`ErrorData::internal_error`] if the context passed to the dispatcher is not a `Ctx`.
///   This is a wiring bug, but it is reported as an error rather than a panic so the
///   server keeps running.
/// - [`ErrorData::invalid_params`] (carrying the deserialization error text) if the
///   arguments cannot be deserialized into `T`.
/// - otherwise, whatever `handler` returns.
pub fn make_descriptor_ctx<Ctx, T, F>(
    name: &'static str,
    description: &'static str,
    handler: F,
) -> ToolDescriptor
where
    Ctx: PluginContext,
    T: DeserializeOwned + JsonSchema + 'static,
    F: Fn(Arc<Ctx>, T) -> BoxFuture<'static, Result<CallToolResult, ErrorData>>
        + Send
        + Sync
        + 'static,
{
    let tool = build_tool::<T>(name, description);
    let handler = Arc::new(
        move |ctx: Arc<dyn std::any::Any + Send + Sync>,
              params: CallToolRequestParams|
              -> BoxFuture<'static, Result<CallToolResult, ErrorData>> {
            let ctx = match ctx.downcast::<Ctx>() {
                Ok(ctx) => ctx,
                Err(_) => {
                    let message = format!(
                        "context type mismatch: expected {}",
                        std::any::type_name::<Ctx>()
                    );
                    return Box::pin(async move { Err(ErrorData::internal_error(message, None)) });
                }
            };
            // `params` is owned, so the argument map is moved out instead of cloned.
            let value = serde_json::Value::Object(params.arguments.unwrap_or_default());
            match serde_json::from_value::<T>(value) {
                Ok(typed) => handler(ctx, typed),
                Err(e) => {
                    let message = e.to_string();
                    Box::pin(async move { Err(ErrorData::invalid_params(message, None)) })
                }
            }
        },
    );
    ToolDescriptor {
        name,
        description,
        tool,
        handler,
    }
}
/// Static registration record for a plugin-provided tool.
///
/// The type is made collectable through the `inventory` crate by the `inventory::collect!`
/// invocation that follows it.
/// NOTE(review): presumably `name` should match the `name` of the descriptor returned by
/// `constructor`. Nothing here enforces that; confirm with the code that consumes these records.
#[derive(Debug)]
pub struct PluginToolRegistration {
    /// Identifier of the plugin providing this tool.
    pub plugin: &'static str,
    /// Name under which the tool is registered.
    pub name: &'static str,
    /// Function pointer that builds the tool's `ToolDescriptor`.
    pub constructor: fn() -> ToolDescriptor,
}
// Declares `PluginToolRegistration` as an `inventory` collection type, so records submitted
// with `inventory::submit!` can be enumerated with `inventory::iter`.
inventory::collect!(PluginToolRegistration);
impl ToolDescriptor {
pub fn as_tool(&self) -> Tool {
self.tool.clone()
}
pub fn dispatch(
&self,
ctx: Arc<dyn std::any::Any + Send + Sync>,
params: CallToolRequestParams,
) -> BoxFuture<'static, Result<CallToolResult, ErrorData>> {
(self.handler)(ctx, params)
}
}