1#![allow(clippy::disallowed_methods)] use crate::types::{
21 JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ToolCallResult, ToolDefinition,
22};
23use std::collections::HashMap;
24use std::sync::mpsc::{self, Sender};
25use std::sync::{Arc, Mutex};
26
/// Boxed callback that receives each outgoing [`JsonRpcNotification`].
///
/// The `Send + Sync` bounds allow the same sink to be used from a worker
/// thread as well as from the thread that created it.
pub type NotificationSink = Box<dyn Fn(JsonRpcNotification) + Send + Sync>;
38
/// Handle stored in the in-flight registry for one running request.
///
/// Sending `()` on [`CancelHandle::cancel_tx`] delivers a cancel signal to
/// whoever holds the matching receiver.
#[derive(Debug)]
pub struct CancelHandle {
    /// Sending half of the per-request cancellation channel.
    pub cancel_tx: Sender<()>,
}
50
/// Shared, mutex-protected map from JSON-RPC request id to the
/// [`CancelHandle`] of the request currently running under that id.
type InFlight = Arc<Mutex<HashMap<serde_json::Value, CancelHandle>>>;
56
/// JSON-RPC server that answers `initialize`, `tools/list` and `tools/call`.
///
/// The only state it owns is the registry of in-flight requests, which is
/// what makes cancellation of running `tools/call` requests possible.
#[derive(Debug, Default)]
pub struct AprMcpServer {
    /// Cancellation handles of running requests, keyed by request id.
    in_flight: InFlight,
}
65
66impl AprMcpServer {
67 #[must_use]
69 pub fn new() -> Self {
70 Self::default()
71 }
72
73 #[must_use]
90 pub fn handle_request(&mut self, request: &JsonRpcRequest) -> JsonRpcResponse {
91 if request.jsonrpc != "2.0" {
92 return JsonRpcResponse::error(
93 request.id.clone(),
94 -32600,
95 format!(
96 "Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
97 request.jsonrpc
98 ),
99 );
100 }
101
102 match request.method.as_str() {
103 "initialize" => self.handle_initialize(request),
104 "tools/list" => self.handle_tools_list(request),
105 "tools/call" => self.handle_tools_call_sync(request),
106 other => JsonRpcResponse::error(
107 request.id.clone(),
108 -32601,
109 format!("Method not found: {other}"),
110 ),
111 }
112 }
113
114 fn handle_initialize(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
115 if let Some(client_version) = request
119 .params
120 .get("protocolVersion")
121 .and_then(|v| v.as_str())
122 {
123 if client_version != crate::PROTOCOL_VERSION {
124 return JsonRpcResponse::error(
125 request.id.clone(),
126 -32602,
127 format!(
128 "Unsupported protocolVersion: client requested \"{}\", server speaks \"{}\"",
129 client_version,
130 crate::PROTOCOL_VERSION
131 ),
132 );
133 }
134 }
135
136 JsonRpcResponse::success(
137 request.id.clone(),
138 serde_json::json!({
139 "protocolVersion": crate::PROTOCOL_VERSION,
140 "capabilities": {
141 "tools": { "listChanged": false }
142 },
143 "serverInfo": {
144 "name": crate::SERVER_NAME,
145 "version": env!("CARGO_PKG_VERSION"),
146 },
147 }),
148 )
149 }
150
151 fn handle_tools_list(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
152 let tools: Vec<ToolDefinition> = self.tool_definitions();
153 JsonRpcResponse::success(request.id.clone(), serde_json::json!({ "tools": tools }))
154 }
155
156 fn handle_tools_call_sync(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
161 let (_tx, rx) = mpsc::channel::<()>();
162 let result = dispatch_tool_call(&request.params, &rx, None);
163 JsonRpcResponse::success(
164 request.id.clone(),
165 serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
166 )
167 }
168
169 #[must_use]
183 pub fn handle_request_with_sink(
184 &mut self,
185 request: &JsonRpcRequest,
186 sink: &NotificationSink,
187 ) -> Option<JsonRpcResponse> {
188 if request.jsonrpc != "2.0" {
189 return Some(JsonRpcResponse::error(
190 request.id.clone(),
191 -32600,
192 format!(
193 "Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
194 request.jsonrpc
195 ),
196 ));
197 }
198
199 if request.method.starts_with("notifications/") {
200 return None;
201 }
202
203 if request.method != "tools/call" {
204 return Some(self.handle_request(request));
205 }
206
207 let progress_token = extract_progress_token(&request.params);
208 let (_tx, rx) = mpsc::channel::<()>();
209 let sink_for_dispatch = progress_token.as_ref().map(|_| sink);
210 let result =
211 dispatch_tool_call_with_sink(&request.params, &rx, sink_for_dispatch, progress_token);
212 Some(JsonRpcResponse::success(
213 request.id.clone(),
214 serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
215 ))
216 }
217
218 #[must_use]
227 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
228 tool_index().definitions().to_vec()
229 }
230
231 #[must_use]
236 pub fn register_in_flight(in_flight: &InFlight, id: serde_json::Value) -> mpsc::Receiver<()> {
237 let (tx, rx) = mpsc::channel::<()>();
238 let mut guard = in_flight
239 .lock()
240 .expect("in_flight mutex not poisoned during register");
241 guard.insert(id, CancelHandle { cancel_tx: tx });
242 rx
243 }
244
245 pub fn cancel_in_flight(in_flight: &InFlight, id: &serde_json::Value) -> bool {
251 let mut guard = in_flight
252 .lock()
253 .expect("in_flight mutex not poisoned during cancel");
254 if let Some(handle) = guard.remove(id) {
255 let _ = handle.cancel_tx.send(());
258 true
259 } else {
260 false
261 }
262 }
263
264 fn deregister_in_flight(in_flight: &InFlight, id: &serde_json::Value) {
267 if let Ok(mut guard) = in_flight.lock() {
268 guard.remove(id);
269 }
270 }
271
272 #[cfg(feature = "native")]
284 pub fn run_stdio(&mut self) -> anyhow::Result<()> {
285 use std::io::{self, BufRead};
286
287 let stdin = io::stdin();
288 let stdout = Arc::new(Mutex::new(io::stdout()));
289
290 for line in stdin.lock().lines() {
291 let line = line?;
292 if line.trim().is_empty() {
293 continue;
294 }
295
296 let parsed: Result<JsonRpcRequest, _> = serde_json::from_str(&line);
297 match parsed {
298 Ok(req) => self.route_stdio_message(req, &stdout)?,
299 Err(e) => {
300 let resp = JsonRpcResponse::error(None, -32700, format!("Parse error: {e}"));
301 write_response(&stdout, &resp)?;
302 }
303 }
304 }
305
306 Ok(())
307 }
308
309 #[cfg(feature = "native")]
312 fn route_stdio_message(
313 &mut self,
314 req: JsonRpcRequest,
315 stdout: &Arc<Mutex<std::io::Stdout>>,
316 ) -> anyhow::Result<()> {
317 if req.jsonrpc != "2.0" {
319 let resp = JsonRpcResponse::error(
320 req.id.clone(),
321 -32600,
322 format!(
323 "Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
324 req.jsonrpc
325 ),
326 );
327 return write_response(stdout, &resp);
328 }
329
330 match req.method.as_str() {
331 "notifications/cancelled" => {
333 if let Some(request_id) = req.params.get("requestId").cloned() {
334 let _ = Self::cancel_in_flight(&self.in_flight, &request_id);
335 }
336 Ok(())
337 }
338 "notifications/initialized" => {
339 Ok(())
341 }
342 "tools/call" => self.spawn_tools_call_worker(req, stdout),
343 _ => {
345 let resp = self.handle_request(&req);
346 write_response(stdout, &resp)
347 }
348 }
349 }
350
351 #[cfg(feature = "native")]
352 fn spawn_tools_call_worker(
353 &mut self,
354 req: JsonRpcRequest,
355 stdout: &Arc<Mutex<std::io::Stdout>>,
356 ) -> anyhow::Result<()> {
357 let Some(id) = req.id.clone() else {
361 let resp =
362 JsonRpcResponse::error(None, -32600, "Invalid Request: tools/call requires an id");
363 return write_response(stdout, &resp);
364 };
365
366 let cancel_rx = Self::register_in_flight(&self.in_flight, id.clone());
367 let stdout_clone = Arc::clone(stdout);
368 let in_flight_clone = Arc::clone(&self.in_flight);
369 let params = req.params.clone();
370 let id_for_worker = id.clone();
371 let progress_token = extract_progress_token(¶ms);
372
373 let sink_stdout = Arc::clone(stdout);
378 let sink: NotificationSink = Box::new(move |notif| {
379 let _ = write_notification(&sink_stdout, ¬if);
381 });
382
383 let builder = std::thread::Builder::new().name(format!("apr-mcp-call-{id}"));
386 let spawn_result = builder.spawn(move || {
387 let sink_ref = progress_token.as_ref().map(|_| &sink);
388 let result =
389 dispatch_tool_call_with_sink(¶ms, &cancel_rx, sink_ref, progress_token);
390 let resp = JsonRpcResponse::success(
391 Some(id_for_worker.clone()),
392 serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
393 );
394 let _ = write_response(&stdout_clone, &resp);
397 Self::deregister_in_flight(&in_flight_clone, &id_for_worker);
398 });
399
400 match spawn_result {
401 Ok(_handle) => Ok(()),
402 Err(e) => {
403 Self::deregister_in_flight(&self.in_flight, &id);
406 let resp = JsonRpcResponse::error(
407 Some(id),
408 -32603,
409 format!("Internal error: failed to spawn worker thread: {e}"),
410 );
411 write_response(stdout, &resp)
412 }
413 }
414 }
415
416 #[must_use]
418 pub fn in_flight_handle(&self) -> InFlight {
419 Arc::clone(&self.in_flight)
420 }
421}
422
423fn dispatch_tool_call(
430 params: &serde_json::Value,
431 cancel_rx: &mpsc::Receiver<()>,
432 sink: Option<&NotificationSink>,
433) -> ToolCallResult {
434 dispatch_tool_call_with_sink(params, cancel_rx, sink, None)
435}
436
437fn dispatch_tool_call_with_sink(
446 params: &serde_json::Value,
447 cancel_rx: &mpsc::Receiver<()>,
448 sink: Option<&NotificationSink>,
449 progress_token: Option<serde_json::Value>,
450) -> ToolCallResult {
451 let name = params.get("name").and_then(|v| v.as_str());
452 let arguments = params
453 .get("arguments")
454 .cloned()
455 .unwrap_or_else(|| serde_json::json!({}));
456
457 let Some(name) = name else {
464 return ToolCallResult::error("Missing tool name");
465 };
466 match tool_index().dispatch_for(name) {
467 Some(dispatch_fn) => dispatch_fn(&arguments, cancel_rx, sink, progress_token),
468 None => ToolCallResult::error(format!("Unknown tool: {name}")),
469 }
470}
471
472fn tool_index() -> &'static crate::tools::ToolIndex {
478 static INDEX: std::sync::OnceLock<crate::tools::ToolIndex> = std::sync::OnceLock::new();
479 INDEX.get_or_init(crate::tools::ToolIndex::from_inventory)
480}
481
482fn extract_progress_token(params: &serde_json::Value) -> Option<serde_json::Value> {
486 params
487 .get("_meta")
488 .and_then(|m| m.get("progressToken"))
489 .cloned()
490}
491
/// Serialises `resp` to JSON and writes it to `stdout` as one line, then
/// flushes.
///
/// The stdout mutex is held for the write and the flush.
///
/// # Errors
///
/// Fails when serialisation fails, when the stdout mutex is poisoned, or
/// when writing or flushing fails.
#[cfg(feature = "native")]
fn write_response(
    stdout: &Arc<Mutex<std::io::Stdout>>,
    resp: &JsonRpcResponse,
) -> anyhow::Result<()> {
    use std::io::Write;

    let payload = serde_json::to_string(resp)?;
    let mut locked = match stdout.lock() {
        Ok(guard) => guard,
        Err(e) => return Err(anyhow::anyhow!("stdout mutex poisoned: {e}")),
    };
    let out: &mut std::io::Stdout = &mut locked;
    writeln!(out, "{payload}")?;
    out.flush()?;
    Ok(())
}
507
/// Writes `notif` to `stdout` as one line, using the text produced by its
/// `to_json_line` method, then flushes.
///
/// The stdout mutex is held for the write and the flush.
///
/// # Errors
///
/// Fails when `to_json_line` fails, when the stdout mutex is poisoned, or
/// when writing or flushing fails.
#[cfg(feature = "native")]
fn write_notification(
    stdout: &Arc<Mutex<std::io::Stdout>>,
    notif: &JsonRpcNotification,
) -> anyhow::Result<()> {
    use std::io::Write;

    let payload = notif.to_json_line()?;
    let mut locked = match stdout.lock() {
        Ok(guard) => guard,
        Err(e) => return Err(anyhow::anyhow!("stdout mutex poisoned: {e}")),
    };
    let out: &mut std::io::Stdout = &mut locked;
    writeln!(out, "{payload}")?;
    out.flush()?;
    Ok(())
}
527
#[cfg(test)]
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a JSON-RPC 2.0 request with the fixed numeric id `1`.
    fn make_request(method: &str, params: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: String::from("2.0"),
            id: Some(json!(1)),
            method: String::from(method),
            params,
        }
    }

    /// Sends one request through a freshly constructed server.
    fn roundtrip(method: &str, params: serde_json::Value) -> JsonRpcResponse {
        AprMcpServer::new().handle_request(&make_request(method, params))
    }

    /// `initialize` reports the protocol version, server name and tool
    /// capabilities.
    #[test]
    fn initialize_returns_protocol_version() {
        let response = roundtrip("initialize", json!({}));

        assert!(response.error.is_none());
        let payload = response.result.expect("result present");
        assert_eq!(payload["protocolVersion"], "2024-11-05");
        assert_eq!(payload["serverInfo"]["name"], "aprender-mcp");
        assert!(payload["capabilities"]["tools"].is_object());
    }

    /// `tools/list` exposes every expected tool, each with an object schema.
    #[test]
    fn tools_list_returns_registered_tools() {
        let response = roundtrip("tools/list", json!({}));

        let payload = response.result.expect("result present");
        let tools = payload["tools"].as_array().expect("tools array");
        let registered: Vec<&str> = tools
            .iter()
            .filter_map(|tool| tool["name"].as_str())
            .collect();

        let required = [
            "apr.version",
            "apr.validate",
            "apr.tensors",
            "apr.bench",
            "apr.qa",
            "apr.trace",
            "apr.run",
            "apr.serve",
            "apr.finetune",
        ];
        for expected in required {
            assert!(registered.contains(&expected), "{expected} registered");
        }

        for tool in tools {
            assert_eq!(tool["inputSchema"]["type"], "object");
        }
    }

    /// `apr.version` returns JSON text naming the server and protocol version.
    #[test]
    fn tools_call_version_returns_metadata() {
        let response = roundtrip(
            "tools/call",
            json!({ "name": "apr.version", "arguments": {} }),
        );

        let payload = response.result.expect("result present");
        let text = payload["content"][0]["text"].as_str().expect("text");
        let metadata: serde_json::Value = serde_json::from_str(text).expect("json");
        assert_eq!(metadata["server"], "aprender-mcp");
        assert_eq!(metadata["protocol_version"], "2024-11-05");
    }

    /// An unknown method is answered with `-32601` and no result.
    #[test]
    fn unknown_method_returns_method_not_found() {
        let response = roundtrip("tools/explode", json!({}));

        assert!(response.result.is_none());
        let failure = response.error.expect("error present");
        assert_eq!(failure.code, -32601);
    }

    /// `apr.validate` without `model_path` yields a tool-level error that
    /// names the missing argument.
    #[test]
    fn tools_call_validate_missing_model_path_is_error() {
        let response = roundtrip(
            "tools/call",
            json!({ "name": "apr.validate", "arguments": {} }),
        );

        let payload = response.result.expect("result present");
        assert_eq!(payload["isError"], true);
        let text = payload["content"][0]["text"].as_str().expect("text");
        assert!(text.contains("model_path"));
    }

    /// Calling a tool that is not registered yields a tool-level error.
    #[test]
    fn tools_call_unknown_tool_returns_is_error() {
        let response = roundtrip("tools/call", json!({ "name": "apr.nonexistent" }));

        let payload = response.result.expect("result present");
        assert_eq!(payload["isError"], true);
    }

    /// A `tools/call` without a `name` yields a tool-level error.
    #[test]
    fn tools_call_missing_name_returns_is_error() {
        let response = roundtrip("tools/call", json!({}));

        let payload = response.result.expect("result present");
        assert_eq!(payload["isError"], true);
    }

    /// The response carries the same id as the request, here a string id.
    #[test]
    fn id_is_echoed_back() {
        let request = JsonRpcRequest {
            id: Some(json!("req-42")),
            ..make_request("initialize", json!({}))
        };
        let response = AprMcpServer::new().handle_request(&request);
        assert_eq!(response.id, Some(json!("req-42")));
    }

    /// Cancelling a registered id delivers the signal and removes the entry,
    /// so a second cancel reports `false`.
    #[test]
    fn cancel_in_flight_signals_and_deregisters() {
        let server = AprMcpServer::new();
        let request_id = json!(99);
        let cancel_rx = AprMcpServer::register_in_flight(&server.in_flight, request_id.clone());

        let first = AprMcpServer::cancel_in_flight(&server.in_flight, &request_id);
        assert!(first, "live id should signal");
        let delivered = cancel_rx.try_recv();
        assert!(delivered.is_ok(), "cancel signal must be deliverable");

        let second = AprMcpServer::cancel_in_flight(&server.in_flight, &request_id);
        assert!(!second, "cancelling an already-removed id is a no-op");
    }

    /// Cancelling an id that was never registered reports `false`.
    #[test]
    fn cancel_unknown_id_is_noop() {
        let server = AprMcpServer::new();
        let unknown = json!("never-registered");
        let signalled = AprMcpServer::cancel_in_flight(&server.in_flight, &unknown);
        assert!(!signalled);
    }
}