use crate::mcp::tools::ToolRegistry;
use crate::sources::SourceRegistry;
use async_trait::async_trait;
use pmcp::{
server::streamable_http_server::{StreamableHttpServer, StreamableHttpServerConfig},
Error, RequestHandlerExtra, Server, ServerCapabilities, ToolHandler, ToolInfo,
};
use serde_json::Value;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// MCP server for `research-master`, exposing the tools derived from a [`SourceRegistry`].
///
/// Cloning is cheap: every clone shares the same underlying `pmcp` server through the
/// inner `Arc`.
#[derive(Debug, Clone)]
pub struct McpServer {
    /// Shared, async-lock-protected `pmcp` server; the same handle is given to HTTP transports.
    server: Arc<Mutex<Server>>,
}
impl McpServer {
pub fn new(sources: Arc<SourceRegistry>) -> Result<Self, pmcp::Error> {
let tools = ToolRegistry::from_sources(&sources);
let server = Self::build_server_impl(tools)?;
Ok(Self {
server: Arc::new(Mutex::new(server)),
})
}
pub fn tools(&self) -> Arc<Mutex<Server>> {
self.server.clone()
}
fn build_server_impl(tools: ToolRegistry) -> Result<Server, pmcp::Error> {
let mut builder = Server::builder()
.name("research-master")
.version(env!("CARGO_PKG_VERSION"))
.capabilities(ServerCapabilities::default());
for tool in tools.all() {
let name = tool.name.clone();
let description = tool.description.clone();
let input_schema = tool.input_schema.clone();
let handler = tool.handler.clone();
let tool_handler = ToolWrapper {
name,
description: Some(description),
input_schema,
handler,
};
builder = builder.tool(tool_handler.name.clone(), tool_handler);
}
builder.build()
}
pub async fn run(&self) -> Result<(), pmcp::Error> {
tracing::info!("Starting MCP server in stdio mode");
let server = Arc::try_unwrap(self.server.clone())
.map_err(|_| Error::internal("Cannot unwrap Arc - multiple references exist"))?
.into_inner();
tracing::info!("MCP server initialized");
server.run_stdio().await
}
pub async fn run_http(&self, addr: &str) -> Result<(SocketAddr, JoinHandle<()>), pmcp::Error> {
tracing::info!("Starting MCP server in HTTP/SSE mode on {}", addr);
let socket_addr: SocketAddr = addr
.parse()
.map_err(|e| Error::invalid_params(format!("Invalid address: {}", e)))?;
let http_server = StreamableHttpServer::new(socket_addr, self.server.clone());
http_server.start().await
}
pub async fn run_http_with_config(
&self,
addr: &str,
config: StreamableHttpServerConfig,
) -> Result<(SocketAddr, JoinHandle<()>), pmcp::Error> {
tracing::info!(
"Starting MCP server in HTTP/SSE mode on {} (with custom config)",
addr
);
let socket_addr: SocketAddr = addr
.parse()
.map_err(|e| Error::invalid_params(format!("Invalid address: {}", e)))?;
let http_server =
StreamableHttpServer::with_config(socket_addr, self.server.clone(), config);
http_server.start().await
}
}
/// Adapter that exposes a crate-level tool through `pmcp`'s `ToolHandler` trait.
#[derive(Clone)]
struct ToolWrapper {
    /// Tool name reported in the tool metadata.
    name: String,
    /// Optional description reported in the tool metadata.
    description: Option<String>,
    /// Input schema (a JSON `Value`) advertised for the tool's arguments.
    input_schema: Value,
    /// Crate-level handler that performs the tool's work.
    handler: Arc<dyn crate::mcp::tools::ToolHandler>,
}
#[async_trait]
impl ToolHandler for ToolWrapper {
async fn handle(&self, args: Value, _extra: RequestHandlerExtra) -> Result<Value, Error> {
self.handler
.execute(args)
.await
.map_err(|e| Error::internal(&e))
}
fn metadata(&self) -> Option<ToolInfo> {
Some(ToolInfo::new(
self.name.clone(),
self.description.clone(),
self.input_schema.clone(),
))
}
}
/// Convenience constructor equivalent to [`McpServer::new`].
///
/// # Errors
///
/// Propagates any error from building the underlying `pmcp` server.
pub fn create_mcp_server(sources: Arc<SourceRegistry>) -> Result<McpServer, pmcp::Error> {
    let server = McpServer::new(sources)?;
    Ok(server)
}