use std::sync::Arc;
use rmcp::{
RoleServer, ServerHandler, ServiceExt,
model::{
CallToolRequestParams, CallToolResult, Implementation, ListToolsResult,
PaginatedRequestParams, ServerCapabilities, ServerInfo, Tool as RmcpTool,
ToolsCapability,
},
service::RequestContext,
transport::streamable_http_server::{
session::local::LocalSessionManager,
tower::{StreamableHttpService, StreamableHttpServerConfig},
},
};
use serde_json::Value;
use crate::tool_trait::Tool;
use crate::tool_trait::ToolBundle;
/// Errors returned by the `serve_*` methods of [`McpServer`].
#[derive(Debug, thiserror::Error)]
pub enum McpServerError {
/// An I/O error, converted automatically through `From<std::io::Error>`.
/// In this file it arises when binding the TCP listener in `serve_http` fails.
#[error("failed to bind address: {0}")]
Bind(#[from] std::io::Error),
/// A failure while starting or running the server; holds the underlying
/// error rendered to a string.
#[error("server error: {0}")]
Serve(String),
}
/// [`ServerHandler`] implementation that exposes a shared [`ToolBundle`] over MCP.
///
/// Every field is an `Arc`, so cloning a handler only bumps reference counts
/// and all clones share the same bundle, name and version.
#[derive(Clone)]
pub struct BundleHandler {
/// Tool bundle whose tools are listed and invoked.
bundle: Arc<ToolBundle>,
/// Implementation name reported by `get_info`.
name: Arc<str>,
/// Implementation version reported by `get_info`.
version: Arc<str>,
}
/// MCP request handling for [`BundleHandler`]: advertises the tools capability,
/// lists the bundle's tools and forwards tool calls to the bundle.
impl ServerHandler for BundleHandler {
    /// Builds the server description sent to clients.
    ///
    /// Only the `tools` capability is enabled; the implementation name and
    /// version are taken from the handler's `name` and `version` fields.
    fn get_info(&self) -> ServerInfo {
        let mut capabilities = ServerCapabilities::default();
        capabilities.tools = Some(ToolsCapability::default());
        ServerInfo::new(capabilities)
            .with_server_info(Implementation::new(&*self.name, &*self.version))
    }

    /// Lists every tool of the bundle in one response.
    ///
    /// The pagination parameters and the request context are ignored. Each raw
    /// tool's `parameters` value is used as its input schema when it is a JSON
    /// object; otherwise the minimal schema `{"type": "object"}` is advertised.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    async fn list_tools(
        &self,
        _request: Option<PaginatedRequestParams>,
        _context: RequestContext<RoleServer>,
    ) -> Result<ListToolsResult, rmcp::ErrorData> {
        let tools = self
            .bundle
            .raw_tools()
            .into_iter()
            .map(|raw| {
                let input_schema: serde_json::Map<String, Value> = match raw.function.parameters {
                    Value::Object(map) => map,
                    // The MCP tool schema requires `inputSchema.type == "object"`.
                    // An empty map would be rejected by strict clients, so fall
                    // back to the smallest valid "no parameters" schema instead.
                    _ => {
                        let mut schema = serde_json::Map::new();
                        schema.insert("type".to_owned(), Value::String("object".to_owned()));
                        schema
                    }
                };
                RmcpTool::new_with_raw(
                    raw.function.name,
                    raw.function.description.map(Into::into),
                    Arc::new(input_schema),
                )
            })
            .collect();
        Ok(ListToolsResult::with_all_items(tools))
    }

    /// Invokes the tool named in `request` on the bundle.
    ///
    /// Missing arguments are passed to the bundle as an empty JSON object.
    /// A result that is a JSON object containing a top-level `"error"` key is
    /// returned as a structured error; any other result is returned as a
    /// structured success.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`; tool failures are reported
    /// inside the `CallToolResult`.
    async fn call_tool(
        &self,
        request: CallToolRequestParams,
        _context: RequestContext<RoleServer>,
    ) -> Result<CallToolResult, rmcp::ErrorData> {
        let args = Value::Object(request.arguments.unwrap_or_default());
        let result = self.bundle.call(&request.name, args).await;
        // Presence of the key alone marks an error, whatever its value is.
        let is_error = result
            .as_object()
            .is_some_and(|obj| obj.contains_key("error"));
        if is_error {
            Ok(CallToolResult::structured_error(result))
        } else {
            Ok(CallToolResult::structured(result))
        }
    }
}
/// Serves a [`ToolBundle`] as an MCP server, over stdio or HTTP.
///
/// Built with `new` and optionally customised through `with_name` and
/// `with_version` before calling one of the `serve_*` methods.
pub struct McpServer {
/// Tool bundle shared with every handler created by this server.
bundle: Arc<ToolBundle>,
/// Implementation name handed to each handler.
name: Arc<str>,
/// Implementation version handed to each handler.
version: Arc<str>,
}
impl McpServer {
pub fn new(bundle: ToolBundle) -> Self {
Self {
bundle: Arc::new(bundle),
name: "ds-api-mcp-server".into(),
version: env!("CARGO_PKG_VERSION").into(),
}
}
pub fn with_name(mut self, name: impl Into<Arc<str>>) -> Self {
self.name = name.into();
self
}
pub fn with_version(mut self, version: impl Into<Arc<str>>) -> Self {
self.version = version.into();
self
}
fn make_handler(&self) -> BundleHandler {
BundleHandler {
bundle: self.bundle.clone(),
name: self.name.clone(),
version: self.version.clone(),
}
}
pub async fn serve_stdio(self) -> Result<(), McpServerError> {
let handler = self.make_handler();
let (stdin, stdout) = rmcp::transport::stdio();
handler
.serve((stdin, stdout))
.await
.map_err(|e| McpServerError::Serve(e.to_string()))?
.waiting()
.await
.map_err(|e| McpServerError::Serve(e.to_string()))?;
Ok(())
}
pub async fn serve_http(self, addr: &str) -> Result<(), McpServerError> {
let service = self.into_http_service(Default::default());
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, router)
.await
.map_err(|e| McpServerError::Serve(e.to_string()))?;
Ok(())
}
pub fn into_http_service(
self,
config: StreamableHttpServerConfig,
) -> StreamableHttpService<BundleHandler, LocalSessionManager> {
let handler = self.make_handler();
StreamableHttpService::new(move || Ok(handler.clone()), Default::default(), config)
}
}