1#![allow(clippy::disallowed_methods)] use crate::tools;
21use crate::types::{
22 JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ToolCallResult, ToolDefinition,
23};
24use std::collections::HashMap;
25use std::sync::mpsc::{self, Sender};
26use std::sync::{Arc, Mutex};
27
28pub type NotificationSink = Box<dyn Fn(JsonRpcNotification) + Send + Sync>;
39
/// Cancellation handle stored for one in-flight request.
#[derive(Debug)]
pub struct CancelHandle {
    /// Sending `()` on this channel delivers the cancel signal to whoever
    /// holds the matching receiver.
    pub cancel_tx: Sender<()>,
}
51
52type InFlight = Arc<Mutex<HashMap<serde_json::Value, CancelHandle>>>;
57
58#[derive(Debug, Default)]
63pub struct AprMcpServer {
64 in_flight: InFlight,
65}
66
67impl AprMcpServer {
68 #[must_use]
70 pub fn new() -> Self {
71 Self::default()
72 }
73
74 #[must_use]
91 pub fn handle_request(&mut self, request: &JsonRpcRequest) -> JsonRpcResponse {
92 if request.jsonrpc != "2.0" {
93 return JsonRpcResponse::error(
94 request.id.clone(),
95 -32600,
96 format!(
97 "Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
98 request.jsonrpc
99 ),
100 );
101 }
102
103 match request.method.as_str() {
104 "initialize" => self.handle_initialize(request),
105 "tools/list" => self.handle_tools_list(request),
106 "tools/call" => self.handle_tools_call_sync(request),
107 other => JsonRpcResponse::error(
108 request.id.clone(),
109 -32601,
110 format!("Method not found: {other}"),
111 ),
112 }
113 }
114
115 fn handle_initialize(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
116 if let Some(client_version) = request
120 .params
121 .get("protocolVersion")
122 .and_then(|v| v.as_str())
123 {
124 if client_version != crate::PROTOCOL_VERSION {
125 return JsonRpcResponse::error(
126 request.id.clone(),
127 -32602,
128 format!(
129 "Unsupported protocolVersion: client requested \"{}\", server speaks \"{}\"",
130 client_version,
131 crate::PROTOCOL_VERSION
132 ),
133 );
134 }
135 }
136
137 JsonRpcResponse::success(
138 request.id.clone(),
139 serde_json::json!({
140 "protocolVersion": crate::PROTOCOL_VERSION,
141 "capabilities": {
142 "tools": { "listChanged": false }
143 },
144 "serverInfo": {
145 "name": crate::SERVER_NAME,
146 "version": env!("CARGO_PKG_VERSION"),
147 },
148 }),
149 )
150 }
151
152 fn handle_tools_list(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
153 let tools: Vec<ToolDefinition> = self.tool_definitions();
154 JsonRpcResponse::success(request.id.clone(), serde_json::json!({ "tools": tools }))
155 }
156
157 fn handle_tools_call_sync(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
162 let (_tx, rx) = mpsc::channel::<()>();
163 let result = dispatch_tool_call(&request.params, &rx, None);
164 JsonRpcResponse::success(
165 request.id.clone(),
166 serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
167 )
168 }
169
170 #[must_use]
184 pub fn handle_request_with_sink(
185 &mut self,
186 request: &JsonRpcRequest,
187 sink: &NotificationSink,
188 ) -> Option<JsonRpcResponse> {
189 if request.jsonrpc != "2.0" {
190 return Some(JsonRpcResponse::error(
191 request.id.clone(),
192 -32600,
193 format!(
194 "Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
195 request.jsonrpc
196 ),
197 ));
198 }
199
200 if request.method.starts_with("notifications/") {
201 return None;
202 }
203
204 if request.method != "tools/call" {
205 return Some(self.handle_request(request));
206 }
207
208 let progress_token = extract_progress_token(&request.params);
209 let (_tx, rx) = mpsc::channel::<()>();
210 let sink_for_dispatch = progress_token.as_ref().map(|_| sink);
211 let result =
212 dispatch_tool_call_with_sink(&request.params, &rx, sink_for_dispatch, progress_token);
213 Some(JsonRpcResponse::success(
214 request.id.clone(),
215 serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
216 ))
217 }
218
219 #[must_use]
221 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
222 vec![
223 tools::version_tool_definition(),
224 tools::validate_tool_definition(),
225 tools::tensors_tool_definition(),
226 tools::bench_tool_definition(),
227 tools::qa_tool_definition(),
228 tools::trace_tool_definition(),
229 tools::run_tool_definition(),
230 tools::serve_tool_definition(),
231 tools::finetune_tool_definition(),
232 ]
233 }
234
235 #[must_use]
240 pub fn register_in_flight(in_flight: &InFlight, id: serde_json::Value) -> mpsc::Receiver<()> {
241 let (tx, rx) = mpsc::channel::<()>();
242 let mut guard = in_flight
243 .lock()
244 .expect("in_flight mutex not poisoned during register");
245 guard.insert(id, CancelHandle { cancel_tx: tx });
246 rx
247 }
248
249 pub fn cancel_in_flight(in_flight: &InFlight, id: &serde_json::Value) -> bool {
255 let mut guard = in_flight
256 .lock()
257 .expect("in_flight mutex not poisoned during cancel");
258 if let Some(handle) = guard.remove(id) {
259 let _ = handle.cancel_tx.send(());
262 true
263 } else {
264 false
265 }
266 }
267
268 fn deregister_in_flight(in_flight: &InFlight, id: &serde_json::Value) {
271 if let Ok(mut guard) = in_flight.lock() {
272 guard.remove(id);
273 }
274 }
275
276 #[cfg(feature = "native")]
288 pub fn run_stdio(&mut self) -> anyhow::Result<()> {
289 use std::io::{self, BufRead};
290
291 let stdin = io::stdin();
292 let stdout = Arc::new(Mutex::new(io::stdout()));
293
294 for line in stdin.lock().lines() {
295 let line = line?;
296 if line.trim().is_empty() {
297 continue;
298 }
299
300 let parsed: Result<JsonRpcRequest, _> = serde_json::from_str(&line);
301 match parsed {
302 Ok(req) => self.route_stdio_message(req, &stdout)?,
303 Err(e) => {
304 let resp = JsonRpcResponse::error(None, -32700, format!("Parse error: {e}"));
305 write_response(&stdout, &resp)?;
306 }
307 }
308 }
309
310 Ok(())
311 }
312
313 #[cfg(feature = "native")]
316 fn route_stdio_message(
317 &mut self,
318 req: JsonRpcRequest,
319 stdout: &Arc<Mutex<std::io::Stdout>>,
320 ) -> anyhow::Result<()> {
321 if req.jsonrpc != "2.0" {
323 let resp = JsonRpcResponse::error(
324 req.id.clone(),
325 -32600,
326 format!(
327 "Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
328 req.jsonrpc
329 ),
330 );
331 return write_response(stdout, &resp);
332 }
333
334 match req.method.as_str() {
335 "notifications/cancelled" => {
337 if let Some(request_id) = req.params.get("requestId").cloned() {
338 let _ = Self::cancel_in_flight(&self.in_flight, &request_id);
339 }
340 Ok(())
341 }
342 "notifications/initialized" => {
343 Ok(())
345 }
346 "tools/call" => self.spawn_tools_call_worker(req, stdout),
347 _ => {
349 let resp = self.handle_request(&req);
350 write_response(stdout, &resp)
351 }
352 }
353 }
354
355 #[cfg(feature = "native")]
356 fn spawn_tools_call_worker(
357 &mut self,
358 req: JsonRpcRequest,
359 stdout: &Arc<Mutex<std::io::Stdout>>,
360 ) -> anyhow::Result<()> {
361 let Some(id) = req.id.clone() else {
365 let resp =
366 JsonRpcResponse::error(None, -32600, "Invalid Request: tools/call requires an id");
367 return write_response(stdout, &resp);
368 };
369
370 let cancel_rx = Self::register_in_flight(&self.in_flight, id.clone());
371 let stdout_clone = Arc::clone(stdout);
372 let in_flight_clone = Arc::clone(&self.in_flight);
373 let params = req.params.clone();
374 let id_for_worker = id.clone();
375 let progress_token = extract_progress_token(¶ms);
376
377 let sink_stdout = Arc::clone(stdout);
382 let sink: NotificationSink = Box::new(move |notif| {
383 let _ = write_notification(&sink_stdout, ¬if);
385 });
386
387 let builder = std::thread::Builder::new().name(format!("apr-mcp-call-{id}"));
390 let spawn_result = builder.spawn(move || {
391 let sink_ref = progress_token.as_ref().map(|_| &sink);
392 let result =
393 dispatch_tool_call_with_sink(¶ms, &cancel_rx, sink_ref, progress_token);
394 let resp = JsonRpcResponse::success(
395 Some(id_for_worker.clone()),
396 serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
397 );
398 let _ = write_response(&stdout_clone, &resp);
401 Self::deregister_in_flight(&in_flight_clone, &id_for_worker);
402 });
403
404 match spawn_result {
405 Ok(_handle) => Ok(()),
406 Err(e) => {
407 Self::deregister_in_flight(&self.in_flight, &id);
410 let resp = JsonRpcResponse::error(
411 Some(id),
412 -32603,
413 format!("Internal error: failed to spawn worker thread: {e}"),
414 );
415 write_response(stdout, &resp)
416 }
417 }
418 }
419
420 #[must_use]
422 pub fn in_flight_handle(&self) -> InFlight {
423 Arc::clone(&self.in_flight)
424 }
425}
426
427fn dispatch_tool_call(
434 params: &serde_json::Value,
435 cancel_rx: &mpsc::Receiver<()>,
436 sink: Option<&NotificationSink>,
437) -> ToolCallResult {
438 dispatch_tool_call_with_sink(params, cancel_rx, sink, None)
439}
440
441fn dispatch_tool_call_with_sink(
450 params: &serde_json::Value,
451 cancel_rx: &mpsc::Receiver<()>,
452 sink: Option<&NotificationSink>,
453 progress_token: Option<serde_json::Value>,
454) -> ToolCallResult {
455 let name = params.get("name").and_then(|v| v.as_str());
456 let arguments = params
457 .get("arguments")
458 .cloned()
459 .unwrap_or_else(|| serde_json::json!({}));
460
461 match name {
462 Some(tools::version::NAME) => tools::version::call(&arguments),
463 Some(tools::validate::NAME) => tools::validate::call(&arguments),
464 Some(tools::tensors::NAME) => tools::tensors::call(&arguments),
465 Some(tools::bench::NAME) => tools::bench::call(&arguments),
466 Some(tools::qa::NAME) => tools::qa::call(&arguments),
467 Some(tools::trace::NAME) => tools::trace::call(&arguments),
468 Some(tools::run::NAME) => {
469 tools::run::call_with_sink(&arguments, cancel_rx, sink, progress_token)
475 }
476 Some(tools::serve::NAME) => tools::serve::call(&arguments),
477 Some(tools::finetune::NAME) => {
478 tools::finetune::call_with_sink(&arguments, sink, progress_token)
479 }
480 Some(other) => ToolCallResult::error(format!("Unknown tool: {other}")),
481 None => ToolCallResult::error("Missing tool name"),
482 }
483}
484
485fn extract_progress_token(params: &serde_json::Value) -> Option<serde_json::Value> {
489 params
490 .get("_meta")
491 .and_then(|m| m.get("progressToken"))
492 .cloned()
493}
494
/// Serializes `resp` to JSON and writes it to the shared stdout as a single
/// line, flushing afterwards.
///
/// # Errors
///
/// Fails when serialization fails, when the stdout mutex is poisoned, or
/// when the write or flush fails.
#[cfg(feature = "native")]
fn write_response(
    stdout: &Arc<Mutex<std::io::Stdout>>,
    resp: &JsonRpcResponse,
) -> anyhow::Result<()> {
    use std::io::Write;

    // Serialize before taking the lock so it is held only for the write.
    let line = serde_json::to_string(resp)?;
    let mut out = match stdout.lock() {
        Ok(guard) => guard,
        Err(e) => return Err(anyhow::anyhow!("stdout mutex poisoned: {e}")),
    };
    writeln!(out, "{line}")?;
    out.flush()?;
    Ok(())
}
510
/// Writes `notif` to the shared stdout as a single JSON line, flushing
/// afterwards.
///
/// # Errors
///
/// Fails when `to_json_line` fails, when the stdout mutex is poisoned, or
/// when the write or flush fails.
#[cfg(feature = "native")]
fn write_notification(
    stdout: &Arc<Mutex<std::io::Stdout>>,
    notif: &JsonRpcNotification,
) -> anyhow::Result<()> {
    use std::io::Write;

    // Build the line before taking the lock so it is held only for the write.
    let line = notif.to_json_line()?;
    let mut out = match stdout.lock() {
        Ok(guard) => guard,
        Err(e) => return Err(anyhow::anyhow!("stdout mutex poisoned: {e}")),
    };
    writeln!(out, "{line}")?;
    out.flush()?;
    Ok(())
}
530
#[cfg(test)]
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;

    /// Builds a JSON-RPC 2.0 request carrying the fixed numeric id `1`.
    fn make_request(method: &str, params: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: String::from("2.0"),
            id: Some(serde_json::json!(1)),
            method: String::from(method),
            params,
        }
    }

    /// Sends a single request to a fresh server and returns its response.
    fn roundtrip(method: &str, params: serde_json::Value) -> JsonRpcResponse {
        AprMcpServer::new().handle_request(&make_request(method, params))
    }

    /// Issues a `tools/call` against a fresh server and unwraps the result.
    fn call_tool(params: serde_json::Value) -> serde_json::Value {
        roundtrip("tools/call", params)
            .result
            .expect("result present")
    }

    /// `initialize` succeeds and reports protocol version, name and tools
    /// capability.
    #[test]
    fn initialize_returns_protocol_version() {
        let resp = roundtrip("initialize", serde_json::json!({}));

        assert!(resp.error.is_none());
        let result = resp.result.expect("result present");
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "aprender-mcp");
        assert!(result["capabilities"]["tools"].is_object());
    }

    /// `tools/list` advertises every `apr.*` tool, each with an object input
    /// schema.
    #[test]
    fn tools_list_returns_registered_tools() {
        const EXPECTED: [&str; 9] = [
            "apr.version",
            "apr.validate",
            "apr.tensors",
            "apr.bench",
            "apr.qa",
            "apr.trace",
            "apr.run",
            "apr.serve",
            "apr.finetune",
        ];

        let result = roundtrip("tools/list", serde_json::json!({}))
            .result
            .expect("result present");
        let tools = result["tools"].as_array().expect("tools array");

        let mut names: Vec<&str> = Vec::new();
        for tool in tools {
            if let Some(name) = tool["name"].as_str() {
                names.push(name);
            }
        }
        for expected in EXPECTED {
            assert!(names.contains(&expected), "{expected} registered");
        }

        for tool in tools {
            assert_eq!(tool["inputSchema"]["type"], "object");
        }
    }

    /// `apr.version` returns JSON text naming the server and protocol.
    #[test]
    fn tools_call_version_returns_metadata() {
        let result = call_tool(serde_json::json!({ "name": "apr.version", "arguments": {} }));

        let text = result["content"][0]["text"].as_str().expect("text");
        let parsed: serde_json::Value = serde_json::from_str(text).expect("json");
        assert_eq!(parsed["server"], "aprender-mcp");
        assert_eq!(parsed["protocol_version"], "2024-11-05");
    }

    /// An unrecognised method is a protocol-level error, not a result.
    #[test]
    fn unknown_method_returns_method_not_found() {
        let resp = roundtrip("tools/explode", serde_json::json!({}));

        assert!(resp.result.is_none());
        let err = resp.error.expect("error present");
        assert_eq!(err.code, -32601);
    }

    /// `apr.validate` without `model_path` reports a tool-level error that
    /// names the missing argument.
    #[test]
    fn tools_call_validate_missing_model_path_is_error() {
        let result = call_tool(serde_json::json!({ "name": "apr.validate", "arguments": {} }));

        assert_eq!(result["isError"], true);
        let text = result["content"][0]["text"].as_str().expect("text");
        assert!(text.contains("model_path"));
    }

    /// An unknown tool name is reported through `isError`.
    #[test]
    fn tools_call_unknown_tool_returns_is_error() {
        let result = call_tool(serde_json::json!({ "name": "apr.nonexistent" }));
        assert_eq!(result["isError"], true);
    }

    /// A call with no tool name is reported through `isError`.
    #[test]
    fn tools_call_missing_name_returns_is_error() {
        let result = call_tool(serde_json::json!({}));
        assert_eq!(result["isError"], true);
    }

    /// The response carries the same id as the request, including string ids.
    #[test]
    fn id_is_echoed_back() {
        let req = JsonRpcRequest {
            id: Some(serde_json::json!("req-42")),
            ..make_request("initialize", serde_json::json!({}))
        };
        let resp = AprMcpServer::new().handle_request(&req);
        assert_eq!(resp.id, Some(serde_json::json!("req-42")));
    }

    /// Cancelling a registered id delivers the signal once and removes the
    /// entry.
    #[test]
    fn cancel_in_flight_signals_and_deregisters() {
        let server = AprMcpServer::new();
        let id = serde_json::json!(99);
        let rx = AprMcpServer::register_in_flight(&server.in_flight, id.clone());

        assert!(
            AprMcpServer::cancel_in_flight(&server.in_flight, &id),
            "live id should signal"
        );
        assert!(rx.try_recv().is_ok(), "cancel signal must be deliverable");

        assert!(
            !AprMcpServer::cancel_in_flight(&server.in_flight, &id),
            "cancelling an already-removed id is a no-op"
        );
    }

    /// Cancelling an id that was never registered reports `false`.
    #[test]
    fn cancel_unknown_id_is_noop() {
        let server = AprMcpServer::new();
        let never_registered = serde_json::json!("never-registered");
        assert!(!AprMcpServer::cancel_in_flight(
            &server.in_flight,
            &never_registered
        ));
    }
}