use rmcp::{
Json,
handler::server::{
tool::{schema_for_output, schema_for_type},
wrapper::Parameters,
},
model::IntoContents,
};
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
/// A single tool that can be described as an [`rmcp::model::Tool`] and invoked
/// through a plain function pointer.
///
/// Implementors supply the tool name, the context/input/output/error types and
/// the synchronous [`handle`](Self::handle) function. The default methods derive
/// the tool metadata ([`tool`](Self::tool)) and a logging wrapper around
/// `handle` ([`handler_func`](Self::handler_func)).
///
/// The `JsonSchema` supertrait is required because the tool description is read
/// from the `description` field of the implementing type's own JSON schema.
pub trait SqliteServerTool: JsonSchema {
    /// Name under which the tool is advertised; also recorded as the `tool`
    /// field of the tracing span created for every call.
    const NAME: &str;

    /// State handed by reference to every [`handle`](Self::handle) call.
    type Context;

    /// Error returned by a failed call.
    type Error: IntoContents;

    /// Request payload; its JSON schema becomes the tool's input schema.
    type Input: Serialize + for<'de> Deserialize<'de> + JsonSchema + 'static;

    /// Response payload; its JSON schema becomes the tool's output schema.
    type Output: Serialize + for<'de> Deserialize<'de> + JsonSchema + 'static;

    /// Executes the tool against `ctx` with the already-deserialized `input`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the tool cannot produce an output.
    fn handle(
        ctx: &Self::Context,
        input: Self::Input,
    ) -> Result<Self::Output, Self::Error>;

    /// Builds the [`rmcp::model::Tool`] descriptor for this tool.
    ///
    /// The description is the `description` entry of `Self`'s JSON schema, or
    /// an empty string when that entry is missing or is not a string.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid output type schema"` when `schema_for_output`
    /// rejects the schema of `Self::Output`.
    fn tool() -> rmcp::model::Tool {
        let description = schema_for!(Self)
            .get("description")
            .and_then(|value| value.as_str())
            .map(str::to_owned)
            .unwrap_or_default();

        let input_schema = schema_for_type::<Parameters<Self::Input>>();
        let output_schema = schema_for_output::<Self::Output>()
            .expect("Invalid output type schema");

        rmcp::model::Tool::new_with_raw(
            Self::NAME,
            Some(description.into()),
            input_schema,
        )
        .with_raw_output_schema(output_schema)
    }

    /// Returns a function pointer that runs [`handle`](Self::handle) inside a
    /// `tool_call` tracing span.
    ///
    /// The wrapper logs the serialized request and response at DEBUG level
    /// (only when DEBUG is enabled, so the serialization cost is skipped
    /// otherwise), logs the elapsed time at INFO on success or WARN on
    /// failure, and wraps a successful output in [`Json`].
    // The allow is kept from the original; presumably the expanded fn-pointer
    // type behind `HandlerFuncFor` trips `clippy::type_complexity`.
    #[allow(clippy::type_complexity)]
    fn handler_func() -> HandlerFuncFor<Self> {
        |ctx, parameters| {
            let span = tracing::info_span!("tool_call", tool = Self::NAME);
            let _guard = span.enter();

            if tracing::enabled!(tracing::Level::DEBUG) {
                // A serialization failure is logged as an empty string rather
                // than aborting the call.
                let input_json =
                    serde_json::to_string(&parameters.0).unwrap_or_default();
                tracing::debug!(input = %input_json, "request");
            }

            // Time only the tool itself, not the logging around it.
            let start = std::time::Instant::now();
            let result = Self::handle(ctx, parameters.0);
            let elapsed = start.elapsed();

            match &result {
                Ok(output) => {
                    if tracing::enabled!(tracing::Level::DEBUG) {
                        let output_json =
                            serde_json::to_string(output).unwrap_or_default();
                        tracing::debug!(output = %output_json, "response");
                    }
                    tracing::info!(elapsed_ms = elapsed.as_millis(), "ok");
                }
                Err(_) => {
                    // `Self::Error` is only bound by `IntoContents`, so the
                    // error value itself cannot be formatted here.
                    tracing::warn!(elapsed_ms = elapsed.as_millis(), "error");
                }
            }

            result.map(Json)
        }
    }
}
/// Function-pointer signature returned by `SqliteServerTool::handler_func`
/// for a tool `T`.
///
/// The function takes the tool's context by reference and the tool's input
/// wrapped in `Parameters`, and returns either the tool's output wrapped in
/// `Json` or the tool's error type.
type HandlerFuncFor<T> = fn(
&<T as SqliteServerTool>::Context,
Parameters<<T as SqliteServerTool>::Input>,
) -> Result<
Json<<T as SqliteServerTool>::Output>,
<T as SqliteServerTool>::Error,
>;