use std::{borrow::Cow, future::Future, sync::Arc};
use serde::{Deserialize, Serialize};
use crate::{
ErrorData,
handler::server::{
common::schema_for_empty_input,
tool::{schema_for_output, schema_for_type},
wrapper::{Json, Parameters},
},
model::{Icon, JsonObject, Meta, ToolAnnotations, ToolExecution},
schemars::JsonSchema,
service::{MaybeSend, MaybeSendFuture},
};
pub trait ToolBase {
type Parameter: for<'de> Deserialize<'de> + JsonSchema + Send + Default + 'static;
type Output: Serialize + JsonSchema + Send + 'static;
type Error: Into<ErrorData> + Send + 'static;
fn name() -> Cow<'static, str>;
fn title() -> Option<String> {
None
}
fn description() -> Option<Cow<'static, str>> {
None
}
fn input_schema() -> Option<Arc<JsonObject>> {
Some(schema_for_type::<Parameters<Self::Parameter>>())
}
fn output_schema() -> Option<Arc<JsonObject>> {
Some(schema_for_output::<Self::Output>().unwrap_or_else(|e| {
panic!(
"Invalid output schema for ToolBase::Output type `{0}`: {1}",
std::any::type_name::<Self::Output>(),
e,
);
}))
}
fn annotations() -> Option<ToolAnnotations> {
None
}
fn execution() -> Option<ToolExecution> {
None
}
fn icons() -> Option<Vec<Icon>> {
None
}
fn meta() -> Option<Meta> {
None
}
}
#[allow(private_bounds)]
pub trait SyncTool<S: MaybeSend + 'static>: ToolBase {
fn invoke(service: &S, param: Self::Parameter) -> Result<Self::Output, Self::Error>;
}
#[allow(private_bounds)]
pub trait AsyncTool<S: MaybeSend + 'static>: ToolBase {
fn invoke(
service: &S,
param: Self::Parameter,
) -> impl Future<Output = Result<Self::Output, Self::Error>> + MaybeSendFuture;
}
pub(crate) fn tool_attribute<T: ToolBase>() -> crate::model::Tool {
crate::model::Tool {
name: T::name(),
title: T::title(),
description: T::description(),
input_schema: T::input_schema().unwrap_or_else(schema_for_empty_input),
output_schema: T::output_schema(),
annotations: T::annotations(),
execution: T::execution(),
icons: T::icons(),
meta: T::meta(),
}
}
pub(crate) fn sync_tool_wrapper<S: MaybeSend + 'static, T: SyncTool<S>>(
service: &S,
Parameters(params): Parameters<T::Parameter>,
) -> Result<Json<T::Output>, ErrorData> {
T::invoke(service, params).map(Json).map_err(Into::into)
}
pub(crate) fn sync_tool_wrapper_with_empty_params<S: MaybeSend + 'static, T: SyncTool<S>>(
service: &S,
) -> Result<Json<T::Output>, ErrorData> {
T::invoke(service, T::Parameter::default())
.map(Json)
.map_err(Into::into)
}
pub(crate) fn async_tool_wrapper<S: MaybeSend + 'static, T: AsyncTool<S>>(
service: &S,
Parameters(params): Parameters<T::Parameter>,
) -> crate::service::MaybeBoxFuture<'_, Result<Json<T::Output>, ErrorData>> {
Box::pin(async move {
T::invoke(service, params)
.await
.map(Json)
.map_err(Into::into)
})
}
pub(crate) fn async_tool_wrapper_with_empty_params<S: MaybeSend + 'static, T: AsyncTool<S>>(
service: &S,
) -> crate::service::MaybeBoxFuture<'_, Result<Json<T::Output>, ErrorData>> {
Box::pin(async move {
T::invoke(service, T::Parameter::default())
.await
.map(Json)
.map_err(Into::into)
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate as rmcp;
use crate::tool;
#[derive(Deserialize, schemars::JsonSchema, Default)]
struct AddParameter {
left: usize,
right: usize,
}
#[derive(Serialize, schemars::JsonSchema, PartialEq, Debug)]
struct AddOutput {
sum: usize,
}
struct MacroBasedToolServer;
impl MacroBasedToolServer {
#[expect(unused)]
#[tool(name = "adder", description = "Modular add two integers")]
fn add(
&self,
Parameters(AddParameter { left, right }): Parameters<AddParameter>,
) -> Json<AddOutput> {
Json(AddOutput {
sum: left.wrapping_add(right),
})
}
#[expect(unused)]
#[tool(name = "empty", description = "Empty tool")]
fn empty(&self) {}
}
struct AddTool;
impl ToolBase for AddTool {
type Parameter = AddParameter;
type Output = AddOutput;
type Error = ErrorData;
fn name() -> Cow<'static, str> {
"adder".into()
}
fn description() -> Option<Cow<'static, str>> {
Some("Modular add two integers".into())
}
}
impl SyncTool<TraitBasedToolServer> for AddTool {
fn invoke(
_service: &TraitBasedToolServer,
AddParameter { left, right }: Self::Parameter,
) -> Result<Self::Output, Self::Error> {
Ok(AddOutput {
sum: left.wrapping_add(right),
})
}
}
impl AsyncTool<TraitBasedToolServer> for AddTool {
async fn invoke(
_service: &TraitBasedToolServer,
AddParameter { left, right }: Self::Parameter,
) -> Result<Self::Output, Self::Error> {
Ok(AddOutput {
sum: left.wrapping_add(right),
})
}
}
enum EmptyToolCustomError {
Internal,
InvalidParams,
}
impl From<EmptyToolCustomError> for ErrorData {
fn from(value: EmptyToolCustomError) -> Self {
match value {
EmptyToolCustomError::Internal => Self::internal_error("internal error", None),
EmptyToolCustomError::InvalidParams => Self::invalid_params("invalid params", None),
}
}
}
struct EmptyTool;
impl ToolBase for EmptyTool {
type Parameter = ();
type Output = ();
type Error = EmptyToolCustomError;
fn name() -> Cow<'static, str> {
"empty".into()
}
fn description() -> Option<Cow<'static, str>> {
Some("Empty tool".into())
}
fn input_schema() -> Option<Arc<JsonObject>> {
None
}
fn output_schema() -> Option<Arc<JsonObject>> {
None
}
}
impl SyncTool<TraitBasedToolServer> for EmptyTool {
fn invoke(
_service: &TraitBasedToolServer,
_param: Self::Parameter,
) -> Result<Self::Output, Self::Error> {
Err(EmptyToolCustomError::Internal)
}
}
impl AsyncTool<TraitBasedToolServer> for EmptyTool {
async fn invoke(
_service: &TraitBasedToolServer,
_param: Self::Parameter,
) -> Result<Self::Output, Self::Error> {
Err(EmptyToolCustomError::InvalidParams)
}
}
struct TraitBasedToolServer;
#[test]
fn test_macro_and_trait_have_same_attrs() {
let macro_attrs = MacroBasedToolServer::add_tool_attr();
let trait_attrs = tool_attribute::<AddTool>();
assert_eq!(macro_attrs, trait_attrs);
}
#[test]
fn test_macro_and_trait_have_same_attrs_for_empty_tool() {
let macro_attrs = MacroBasedToolServer::empty_tool_attr();
let trait_attrs = tool_attribute::<EmptyTool>();
assert_eq!(macro_attrs, trait_attrs);
}
#[test]
fn test_sync_tool_wrapper_happy_path() {
let left = 1;
let right = 2;
let result = sync_tool_wrapper::<_, AddTool>(
&TraitBasedToolServer,
Parameters(AddParameter { left, right }),
);
assert!(result.is_ok());
if let Ok(result) = result {
assert_eq!(result.0, AddOutput { sum: 3 });
}
}
#[tokio::test]
async fn test_async_tool_wrapper_happy_path() {
let left = 1;
let right = 2;
let result = async_tool_wrapper::<_, AddTool>(
&TraitBasedToolServer,
Parameters(AddParameter { left, right }),
)
.await;
assert!(result.is_ok());
if let Ok(result) = result {
assert_eq!(result.0, AddOutput { sum: 3 });
}
}
#[test]
fn test_sync_tool_wrapper_error_conversion() {
let result = sync_tool_wrapper::<_, EmptyTool>(&TraitBasedToolServer, Parameters(()));
assert!(result.is_err());
if let Err(result) = result {
assert_eq!(result, ErrorData::internal_error("internal error", None));
}
}
#[tokio::test]
async fn test_async_tool_wrapper_error_conversion() {
let result =
async_tool_wrapper::<_, EmptyTool>(&TraitBasedToolServer, Parameters(())).await;
assert!(result.is_err());
if let Err(result) = result {
assert_eq!(result, ErrorData::invalid_params("invalid params", None));
}
}
}