1use crate::types::{McpContent, McpResource, McpTool, McpToolCallResponse};
13use axum::{
14 extract::State,
15 http::StatusCode,
16 response::{sse::Event as SseEvent, sse::Sse, IntoResponse},
17 routing::{get, post},
18 Json, Router,
19};
20use serde_json::{json, Value};
21use std::{collections::HashMap, convert::Infallible, sync::Arc};
22use tokio::sync::broadcast;
23use tokio_stream::wrappers::BroadcastStream;
24use tokio_stream::StreamExt;
25use tracing::{debug, info, warn};
26
27pub struct RegisteredTool {
31 pub definition: McpTool,
32 pub handler: Box<dyn ToolFn>,
33}
34
35pub trait ToolFn: Send + Sync {
36 fn call(
37 &self,
38 arguments: Value,
39 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<McpContent>, String>> + Send>>;
40}
41
42impl<F, Fut> ToolFn for F
43where
44 F: Fn(Value) -> Fut + Send + Sync,
45 Fut: std::future::Future<Output = Result<Vec<McpContent>, String>> + Send + 'static,
46{
47 fn call(
48 &self,
49 arguments: Value,
50 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<McpContent>, String>> + Send>>
51 {
52 Box::pin(self(arguments))
53 }
54}
55
56#[derive(Clone)]
59struct McpServerState {
60 server_name: String,
61 server_version: String,
62 tools: Arc<HashMap<String, Arc<RegisteredTool>>>,
63 resources: Arc<Vec<McpResource>>,
64 sse_tx: broadcast::Sender<Value>,
65}
66
67fn build_router(state: McpServerState) -> Router {
70 Router::new()
71 .route("/mcp", post(rpc_handler))
72 .route("/mcp/sse", get(sse_handler))
73 .with_state(state)
74}
75
76async fn rpc_handler(
77 State(state): State<McpServerState>,
78 Json(body): Json<Value>,
79) -> impl IntoResponse {
80 let method = body.get("method").and_then(|v| v.as_str()).unwrap_or("");
81 let rpc_id = body.get("id").cloned().unwrap_or(json!(1));
82 let params = body.get("params").cloned().unwrap_or(json!({}));
83
84 debug!(method = %method, "MCP RPC request");
85
86 let result = match method {
87 "initialize" => {
88 json!({
89 "protocolVersion": "2024-11-05",
90 "capabilities": {
91 "tools": { "listChanged": false },
92 "resources": { "subscribe": false, "listChanged": false },
93 "prompts": { "listChanged": false }
94 },
95 "serverInfo": {
96 "name": state.server_name,
97 "version": state.server_version
98 }
99 })
100 }
101
102 "initialized" | "notifications/initialized" => json!({}),
103
104 "tools/list" => {
105 let tools: Vec<Value> = state
106 .tools
107 .values()
108 .map(|t| {
109 json!({
110 "name": t.definition.name,
111 "description": t.definition.description,
112 "inputSchema": t.definition.input_schema,
113 })
114 })
115 .collect();
116 json!({ "tools": tools })
117 }
118
119 "tools/call" => {
120 let name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
121 let args = params.get("arguments").cloned().unwrap_or(json!({}));
122
123 match state.tools.get(name) {
124 Some(tool) => match tool.handler.call(args).await {
125 Ok(content) => {
126 let resp = McpToolCallResponse {
127 content,
128 is_error: Some(false),
129 };
130 serde_json::to_value(resp).unwrap_or(json!({}))
131 }
132 Err(e) => {
133 warn!(tool = %name, error = %e, "Tool call failed");
134 let resp = McpToolCallResponse {
135 content: vec![McpContent::Text { text: e }],
136 is_error: Some(true),
137 };
138 serde_json::to_value(resp).unwrap_or(json!({}))
139 }
140 },
141 None => {
142 return (
143 StatusCode::OK,
144 Json(json!({
145 "jsonrpc": "2.0",
146 "id": rpc_id,
147 "error": {
148 "code": -32601,
149 "message": format!("tool not found: {name}")
150 }
151 })),
152 )
153 .into_response();
154 }
155 }
156 }
157
158 "resources/list" => {
159 let resources: Vec<Value> = state
160 .resources
161 .iter()
162 .map(|r| {
163 json!({
164 "uri": r.uri,
165 "name": r.name,
166 "description": r.description,
167 "mimeType": r.mime_type,
168 })
169 })
170 .collect();
171 json!({ "resources": resources })
172 }
173
174 "prompts/list" => json!({ "prompts": [] }),
175
176 "ping" => json!({}),
177
178 other if other.starts_with("notifications/") => json!({}),
180
181 other => {
182 warn!(method = %other, "Unknown MCP method");
183 return (
184 StatusCode::OK,
185 Json(json!({
186 "jsonrpc": "2.0",
187 "id": rpc_id,
188 "error": {
189 "code": -32601,
190 "message": format!("method not found: {other}")
191 }
192 })),
193 )
194 .into_response();
195 }
196 };
197
198 (
199 StatusCode::OK,
200 Json(json!({
201 "jsonrpc": "2.0",
202 "id": rpc_id,
203 "result": result,
204 })),
205 )
206 .into_response()
207}
208
209async fn sse_handler(State(state): State<McpServerState>) -> impl IntoResponse {
210 let rx = state.sse_tx.subscribe();
211 let stream = BroadcastStream::new(rx).filter_map(|result: Result<Value, _>| {
212 result.ok().and_then(|event| {
213 serde_json::to_string(&event)
214 .ok()
215 .map(|data| Ok::<_, Infallible>(SseEvent::default().data(data)))
216 })
217 });
218 Sse::new(stream)
219}
220
/// Plain configuration values for an MCP server.
///
/// NOTE(review): nothing in this module reads this struct; the field meanings below are
/// inferred from their names and should be confirmed against the callers that construct it.
pub struct McpServerConfig {
    /// TCP port to listen on (presumably handed to `McpServer::new`).
    pub port: u16,
    /// Server name (presumably reported to clients as `serverInfo.name`).
    pub server_name: String,
    /// Server version (presumably reported to clients as `serverInfo.version`).
    pub server_version: String,
    /// Names of tools to expose (presumably a filter; confirm).
    pub exposed_tools: Vec<String>,
    /// Identifiers of resources to expose (presumably a filter; confirm).
    pub exposed_resources: Vec<String>,
}
231
232pub struct McpServer {
234 port: u16,
235 server_name: String,
236 server_version: String,
237 tools: HashMap<String, Arc<RegisteredTool>>,
238 resources: Vec<McpResource>,
239}
240
241impl McpServer {
242 pub fn new(name: impl Into<String>, version: impl Into<String>, port: u16) -> Self {
243 Self {
244 port,
245 server_name: name.into(),
246 server_version: version.into(),
247 tools: HashMap::new(),
248 resources: Vec::new(),
249 }
250 }
251
252 pub fn register_tool<F, Fut>(mut self, definition: McpTool, handler: F) -> Self
254 where
255 F: Fn(Value) -> Fut + Send + Sync + 'static,
256 Fut: std::future::Future<Output = Result<Vec<McpContent>, String>> + Send + 'static,
257 {
258 let name = definition.name.clone();
259 self.tools.insert(
260 name,
261 Arc::new(RegisteredTool {
262 definition,
263 handler: Box::new(handler),
264 }),
265 );
266 self
267 }
268
269 pub fn register_resource(mut self, resource: McpResource) -> Self {
271 self.resources.push(resource);
272 self
273 }
274
275 pub fn into_router(self) -> Router {
280 let (sse_tx, _) = broadcast::channel(64);
281 let state = McpServerState {
282 server_name: self.server_name,
283 server_version: self.server_version,
284 tools: Arc::new(self.tools),
285 resources: Arc::new(self.resources),
286 sse_tx,
287 };
288 build_router(state)
289 }
290
291 pub async fn start(self) -> Result<(), String> {
293 let (sse_tx, _) = broadcast::channel(64);
294 let state = McpServerState {
295 server_name: self.server_name.clone(),
296 server_version: self.server_version.clone(),
297 tools: Arc::new(self.tools),
298 resources: Arc::new(self.resources),
299 sse_tx,
300 };
301
302 let router = build_router(state);
303 let addr = format!("0.0.0.0:{}", self.port);
304 let listener = tokio::net::TcpListener::bind(&addr)
305 .await
306 .map_err(|e| format!("failed to bind MCP server on {addr}: {e}"))?;
307
308 info!(
309 name = %self.server_name,
310 addr = %addr,
311 "MCP server listening"
312 );
313
314 axum::serve(listener, router)
315 .await
316 .map_err(|e| format!("MCP server error: {e}"))?;
317
318 Ok(())
319 }
320}