research_master/mcp/
server.rs1use crate::mcp::tools::ToolRegistry;
7use crate::sources::SourceRegistry;
8use async_trait::async_trait;
9use pmcp::{
10 server::streamable_http_server::{StreamableHttpServer, StreamableHttpServerConfig},
11 Error, RequestHandlerExtra, Server, ServerCapabilities, ToolHandler, ToolInfo,
12};
13use serde_json::Value;
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::sync::Mutex;
17use tokio::task::JoinHandle;
18
19#[derive(Debug, Clone)]
24pub struct McpServer {
25 server: Arc<Mutex<Server>>,
26}
27
28impl McpServer {
29 pub fn new(sources: Arc<SourceRegistry>) -> Result<Self, pmcp::Error> {
31 let tools = ToolRegistry::from_sources(&sources);
32 let server = Self::build_server_impl(tools)?;
33 Ok(Self {
34 server: Arc::new(Mutex::new(server)),
35 })
36 }
37
38 pub fn tools(&self) -> Arc<Mutex<Server>> {
40 self.server.clone()
41 }
42
43 fn build_server_impl(tools: ToolRegistry) -> Result<Server, pmcp::Error> {
45 let mut builder = Server::builder()
46 .name("research-master")
47 .version(env!("CARGO_PKG_VERSION"))
48 .capabilities(ServerCapabilities::default());
49
50 for tool in tools.all() {
52 let name = tool.name.clone();
53 let description = tool.description.clone();
54 let input_schema = tool.input_schema.clone();
55 let handler = tool.handler.clone();
56
57 let tool_handler = ToolWrapper {
58 name,
59 description: Some(description),
60 input_schema,
61 handler,
62 };
63 builder = builder.tool(tool_handler.name.clone(), tool_handler);
64 }
65
66 builder.build()
67 }
68
69 pub async fn run(&self) -> Result<(), pmcp::Error> {
71 tracing::info!("Starting MCP server in stdio mode");
72
73 let server = Arc::try_unwrap(self.server.clone())
77 .map_err(|_| Error::internal("Cannot unwrap Arc - multiple references exist"))?
78 .into_inner();
79
80 tracing::info!("MCP server initialized");
81
82 server.run_stdio().await
83 }
84
85 pub async fn run_http(&self, addr: &str) -> Result<(SocketAddr, JoinHandle<()>), pmcp::Error> {
90 tracing::info!("Starting MCP server in HTTP/SSE mode on {}", addr);
91
92 let socket_addr: SocketAddr = addr
93 .parse()
94 .map_err(|e| Error::invalid_params(format!("Invalid address: {}", e)))?;
95
96 let http_server = StreamableHttpServer::new(socket_addr, self.server.clone());
98
99 http_server.start().await
101 }
102
103 pub async fn run_http_with_config(
105 &self,
106 addr: &str,
107 config: StreamableHttpServerConfig,
108 ) -> Result<(SocketAddr, JoinHandle<()>), pmcp::Error> {
109 tracing::info!(
110 "Starting MCP server in HTTP/SSE mode on {} (with custom config)",
111 addr
112 );
113
114 let socket_addr: SocketAddr = addr
115 .parse()
116 .map_err(|e| Error::invalid_params(format!("Invalid address: {}", e)))?;
117
118 let http_server =
120 StreamableHttpServer::with_config(socket_addr, self.server.clone(), config);
121
122 http_server.start().await
124 }
125}
126
127#[derive(Clone)]
129struct ToolWrapper {
130 name: String,
131 description: Option<String>,
132 input_schema: Value,
133 handler: Arc<dyn crate::mcp::tools::ToolHandler>,
134}
135
136#[async_trait]
137impl ToolHandler for ToolWrapper {
138 async fn handle(&self, args: Value, _extra: RequestHandlerExtra) -> Result<Value, Error> {
139 self.handler
140 .execute(args)
141 .await
142 .map_err(|e| Error::internal(&e))
143 }
144
145 fn metadata(&self) -> Option<ToolInfo> {
146 Some(ToolInfo::new(
147 self.name.clone(),
148 self.description.clone(),
149 self.input_schema.clone(),
150 ))
151 }
152}
153
154pub fn create_mcp_server(sources: Arc<SourceRegistry>) -> Result<McpServer, pmcp::Error> {
156 McpServer::new(sources)
157}