use std::{marker::PhantomData, pin::pin, sync::Arc};
use futures::future::Either;
use fxhash::FxHashMap;
use rmcp::{
ErrorData, ServerHandler,
handler::server::tool::cached_schema_for_type,
model::{CallToolResult, ListToolsResult, Tool},
};
use sacp::{BoxFuture, ByteStreams, Component};
mod tool;
use schemars::JsonSchema;
use serde::{Serialize, de::DeserializeOwned};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
pub use tool::*;
use crate::McpContext;
#[derive(Clone, Default)]
pub struct McpServer {
instructions: Option<String>,
tool_models: Vec<rmcp::model::Tool>,
tools: FxHashMap<String, Arc<dyn ErasedMcpTool>>,
}
impl McpServer {
    /// Creates an empty server with no instructions and no tools.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the instructions string reported to clients via `get_info`.
    pub fn instructions(mut self, instructions: impl ToString) -> Self {
        self.instructions = Some(instructions.to_string());
        self
    }

    /// Registers `tool`, making it visible in `tools/list` and callable via `tools/call`.
    ///
    /// Tools are keyed by `McpTool::name`. Registering a second tool with a name that is
    /// already taken replaces the earlier tool, both for dispatch and in the advertised
    /// tool list. The replacement keeps the earlier tool's position in that list.
    pub fn tool(mut self, tool: impl McpTool + 'static) -> Self {
        let name = tool.name();
        let tool_model = make_tool_model(&tool);

        // `tools` is a map, so a duplicate name overwrites the previous entry there.
        // Mirror that in `tool_models` so `tools/list` never advertises a duplicate
        // (and uncallable) descriptor.
        let existing = self
            .tool_models
            .iter()
            .position(|model| *model.name == *name);
        match existing {
            Some(index) => self.tool_models[index] = tool_model,
            None => self.tool_models.push(tool_model),
        }

        self.tools.insert(name, make_erased_mcp_tool(tool));
        self
    }

    /// Registers a tool backed by an async function.
    ///
    /// `func` receives the deserialized parameters (`P`) and the connection's
    /// `McpContext`, and produces a serializable result (`R`). The tool's input and output
    /// schemas are derived from `P` and `R`. The same replacement rule as
    /// `McpServer::tool` applies to duplicate names.
    ///
    /// `to_future_hack` is called with `&func` and must return the future of invoking it,
    /// boxed as a `BoxFuture`. NOTE(review): this presumably exists because the future
    /// returned by an `AsyncFn` cannot be named or bounded as `Send` directly; confirm
    /// before relying on that rationale.
    pub fn tool_fn<P, R, F, H>(
        self,
        name: impl ToString,
        description: impl ToString,
        func: F,
        to_future_hack: H,
    ) -> Self
    where
        P: JsonSchema + DeserializeOwned + 'static + Send,
        R: JsonSchema + Serialize + 'static + Send,
        F: AsyncFn(P, McpContext) -> Result<R, sacp::Error> + Send + Sync + 'static,
        H: Fn(&F, P, McpContext) -> BoxFuture<'_, Result<R, sacp::Error>> + Send + Sync + 'static,
    {
        /// Adapter turning a (function, boxing shim) pair into an `McpTool`.
        struct ToolFnTool<P, R, F, H> {
            name: String,
            description: String,
            func: F,
            to_future_hack: H,
            // `fn(P) -> R` records the parameter/result types without owning a `P` or `R`.
            phantom: PhantomData<fn(P) -> R>,
        }

        impl<P, R, F, H> McpTool for ToolFnTool<P, R, F, H>
        where
            P: JsonSchema + DeserializeOwned + 'static + Send,
            R: JsonSchema + Serialize + 'static + Send,
            F: AsyncFn(P, McpContext) -> Result<R, sacp::Error> + Send + Sync + 'static,
            H: Fn(&F, P, McpContext) -> BoxFuture<'_, Result<R, sacp::Error>>
                + Send
                + Sync
                + 'static,
        {
            type Input = P;
            type Output = R;

            fn name(&self) -> String {
                self.name.clone()
            }

            fn description(&self) -> String {
                self.description.clone()
            }

            async fn call_tool(&self, params: P, cx: McpContext) -> Result<R, sacp::Error> {
                // Go through the shim rather than calling `func` directly.
                (self.to_future_hack)(&self.func, params, cx).await
            }
        }

        self.tool(ToolFnTool {
            name: name.to_string(),
            description: description.to_string(),
            func,
            to_future_hack,
            phantom: PhantomData::<fn(P) -> R>,
        })
    }

    /// Creates a connection that serves a clone of this server definition with `mcp_cx`.
    pub(crate) fn new_connection(&self, mcp_cx: McpContext) -> McpServerConnection {
        McpServerConnection {
            service: self.clone(),
            mcp_cx,
        }
    }
}
/// A single MCP session: its own copy of the server definition plus the context
/// handed to every tool call made on it.
pub(crate) struct McpServerConnection {
    /// Context cloned into each tool invocation.
    mcp_cx: McpContext,
    /// Server definition (tools, descriptors, instructions) served by this connection.
    service: McpServer,
}
impl Component for McpServerConnection {
    /// Serves this MCP server to `client` over an in-process byte pipe.
    ///
    /// A `tokio::io::duplex` pair (8192-byte buffer) links the two sides. One end is
    /// wrapped in `ByteStreams` and served to `client` on a spawned, detached task. The
    /// other end is handed to rmcp as the server transport. The method returns once the
    /// rmcp service has stopped.
    ///
    /// # Errors
    ///
    /// Failures starting the rmcp service, or waiting for it to finish, are mapped with
    /// `sacp::Error::into_internal_error`. Errors from the client task are discarded.
    async fn serve(self, client: impl Component) -> Result<(), sacp::Error> {
        let (server_end, client_end) = tokio::io::duplex(8192);

        // Client side: adapt tokio I/O to futures I/O for sacp and run it in the background.
        let (client_reader, client_writer) = tokio::io::split(client_end);
        let client_transport =
            ByteStreams::new(client_writer.compat_write(), client_reader.compat());
        tokio::spawn(async move {
            // Best-effort: the task is detached and its outcome is intentionally ignored.
            let _ = client_transport.serve(client).await;
        });

        // Server side: rmcp accepts the (reader, writer) tuple directly as a transport.
        // Fully qualified because `Component::serve` is also in scope.
        let server_transport = tokio::io::split(server_end);
        let service = rmcp::ServiceExt::serve(self, server_transport)
            .await
            .map_err(sacp::Error::into_internal_error)?;

        let _quit_reason = service
            .waiting()
            .await
            .map_err(sacp::Error::into_internal_error)?;
        Ok(())
    }
}
impl ServerHandler for McpServerConnection {
async fn call_tool(
&self,
request: rmcp::model::CallToolRequestParam,
context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<CallToolResult, ErrorData> {
let Some(tool) = self.service.tools.get(&request.name[..]) else {
return Err(rmcp::model::ErrorData::invalid_params(
format!("tool `{}` not found", request.name),
None,
));
};
let serde_value = serde_json::to_value(request.arguments).expect("valid json");
match futures::future::select(
tool.call_tool(serde_value, self.mcp_cx.clone()),
pin!(context.ct.cancelled()),
)
.await
{
Either::Left((m, _)) => match m {
Ok(result) => Ok(CallToolResult::structured(result)),
Err(error) => Err(to_rmcp_error(error)),
},
Either::Right(((), _)) => {
Err(rmcp::ErrorData::internal_error("operation cancelled", None))
}
}
}
async fn list_tools(
&self,
_request: Option<rmcp::model::PaginatedRequestParam>,
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::ListToolsResult, ErrorData> {
Ok(ListToolsResult::with_all_items(
self.service.tool_models.clone(),
))
}
fn get_info(&self) -> rmcp::model::ServerInfo {
rmcp::model::ServerInfo {
protocol_version: rmcp::model::ProtocolVersion::default(),
capabilities: rmcp::model::ServerCapabilities::builder()
.enable_tools()
.build(),
server_info: rmcp::model::Implementation::default(),
instructions: self.service.instructions.clone(),
}
}
}
/// Object-safe, JSON-in/JSON-out view of a tool, so tools with different input and
/// output types can live in one map.
trait ErasedMcpTool: Send + Sync {
    /// Runs the tool on raw JSON `arguments` with the connection context `cx`, yielding
    /// the tool's result as JSON.
    fn call_tool(
        &self,
        arguments: serde_json::Value,
        cx: McpContext,
    ) -> BoxFuture<'_, Result<serde_json::Value, sacp::Error>>;
}
fn make_tool_model<M: McpTool>(tool: &M) -> Tool {
rmcp::model::Tool {
name: tool.name().into(),
title: tool.title(),
description: Some(tool.description().into()),
input_schema: cached_schema_for_type::<M::Input>(),
output_schema: Some(cached_schema_for_type::<M::Output>()),
annotations: None,
icons: None,
}
}
/// Wraps a typed `McpTool` in a JSON adapter and returns it as a shared trait object.
///
/// On each call the adapter decodes the JSON arguments into the tool's input type, runs
/// the tool, and encodes its output back to JSON. Decode and encode failures are
/// converted with `sacp::util::internal_error`. Errors returned by the tool itself are
/// propagated unchanged.
fn make_erased_mcp_tool<'s, M: McpTool + 's>(tool: M) -> Arc<dyn ErasedMcpTool + 's> {
    struct JsonAdapter<T> {
        inner: T,
    }

    impl<T: McpTool> ErasedMcpTool for JsonAdapter<T> {
        fn call_tool(
            &self,
            input: serde_json::Value,
            context: McpContext,
        ) -> BoxFuture<'_, Result<serde_json::Value, sacp::Error>> {
            let work = async move {
                let typed_input =
                    serde_json::from_value(input).map_err(sacp::util::internal_error)?;
                let typed_output = self.inner.call_tool(typed_input, context).await?;
                serde_json::to_value(typed_output).map_err(sacp::util::internal_error)
            };
            Box::pin(work)
        }
    }

    Arc::new(JsonAdapter { inner: tool })
}
fn to_rmcp_error(error: sacp::Error) -> rmcp::ErrorData {
rmcp::ErrorData {
code: rmcp::model::ErrorCode(error.code),
message: error.message.into(),
data: error.data,
}
}