#![allow(clippy::disallowed_methods)]
use crate::tools;
use crate::types::{
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ToolCallResult, ToolDefinition,
};
use std::collections::HashMap;
use std::sync::mpsc::{self, Sender};
use std::sync::{Arc, Mutex};
/// Callback receiving JSON-RPC notifications (e.g. progress updates) produced while a
/// request is handled. `Send + Sync` so a sink can be moved into a worker thread.
pub type NotificationSink = Box<dyn Fn(JsonRpcNotification) + Send + Sync>;
/// Cancellation handle for one in-flight `tools/call` request.
#[derive(Debug)]
pub struct CancelHandle {
    /// Sender half of the channel whose receiver is handed to the tool as `cancel_rx`;
    /// sending `()` signals cancellation. How promptly a tool reacts is up to the tool.
    pub cancel_tx: Sender<()>,
}
/// Shared table of in-flight requests keyed by JSON-RPC request id, used to route
/// `notifications/cancelled` to the matching worker.
type InFlight = Arc<Mutex<HashMap<serde_json::Value, CancelHandle>>>;
/// MCP server exposing the `apr.*` tools over JSON-RPC 2.0.
#[derive(Debug, Default)]
pub struct AprMcpServer {
    /// Cancellation handles for `tools/call` requests currently registered as in flight.
    in_flight: InFlight,
}
impl AprMcpServer {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn handle_request(&mut self, request: &JsonRpcRequest) -> JsonRpcResponse {
if request.jsonrpc != "2.0" {
return JsonRpcResponse::error(
request.id.clone(),
-32600,
format!(
"Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
request.jsonrpc
),
);
}
match request.method.as_str() {
"initialize" => self.handle_initialize(request),
"tools/list" => self.handle_tools_list(request),
"tools/call" => self.handle_tools_call_sync(request),
other => JsonRpcResponse::error(
request.id.clone(),
-32601,
format!("Method not found: {other}"),
),
}
}
fn handle_initialize(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
if let Some(client_version) = request
.params
.get("protocolVersion")
.and_then(|v| v.as_str())
{
if client_version != crate::PROTOCOL_VERSION {
return JsonRpcResponse::error(
request.id.clone(),
-32602,
format!(
"Unsupported protocolVersion: client requested \"{}\", server speaks \"{}\"",
client_version,
crate::PROTOCOL_VERSION
),
);
}
}
JsonRpcResponse::success(
request.id.clone(),
serde_json::json!({
"protocolVersion": crate::PROTOCOL_VERSION,
"capabilities": {
"tools": { "listChanged": false }
},
"serverInfo": {
"name": crate::SERVER_NAME,
"version": env!("CARGO_PKG_VERSION"),
},
}),
)
}
fn handle_tools_list(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
let tools: Vec<ToolDefinition> = self.tool_definitions();
JsonRpcResponse::success(request.id.clone(), serde_json::json!({ "tools": tools }))
}
fn handle_tools_call_sync(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
let (_tx, rx) = mpsc::channel::<()>();
let result = dispatch_tool_call(&request.params, &rx, None);
JsonRpcResponse::success(
request.id.clone(),
serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
)
}
#[must_use]
pub fn handle_request_with_sink(
&mut self,
request: &JsonRpcRequest,
sink: &NotificationSink,
) -> Option<JsonRpcResponse> {
if request.jsonrpc != "2.0" {
return Some(JsonRpcResponse::error(
request.id.clone(),
-32600,
format!(
"Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
request.jsonrpc
),
));
}
if request.method.starts_with("notifications/") {
return None;
}
if request.method != "tools/call" {
return Some(self.handle_request(request));
}
let progress_token = extract_progress_token(&request.params);
let (_tx, rx) = mpsc::channel::<()>();
let sink_for_dispatch = progress_token.as_ref().map(|_| sink);
let result =
dispatch_tool_call_with_sink(&request.params, &rx, sink_for_dispatch, progress_token);
Some(JsonRpcResponse::success(
request.id.clone(),
serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
))
}
#[must_use]
pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
vec![
tools::version_tool_definition(),
tools::validate_tool_definition(),
tools::tensors_tool_definition(),
tools::bench_tool_definition(),
tools::qa_tool_definition(),
tools::trace_tool_definition(),
tools::run_tool_definition(),
tools::serve_tool_definition(),
tools::finetune_tool_definition(),
]
}
#[must_use]
pub fn register_in_flight(in_flight: &InFlight, id: serde_json::Value) -> mpsc::Receiver<()> {
let (tx, rx) = mpsc::channel::<()>();
let mut guard = in_flight
.lock()
.expect("in_flight mutex not poisoned during register");
guard.insert(id, CancelHandle { cancel_tx: tx });
rx
}
pub fn cancel_in_flight(in_flight: &InFlight, id: &serde_json::Value) -> bool {
let mut guard = in_flight
.lock()
.expect("in_flight mutex not poisoned during cancel");
if let Some(handle) = guard.remove(id) {
let _ = handle.cancel_tx.send(());
true
} else {
false
}
}
fn deregister_in_flight(in_flight: &InFlight, id: &serde_json::Value) {
if let Ok(mut guard) = in_flight.lock() {
guard.remove(id);
}
}
#[cfg(feature = "native")]
pub fn run_stdio(&mut self) -> anyhow::Result<()> {
use std::io::{self, BufRead};
let stdin = io::stdin();
let stdout = Arc::new(Mutex::new(io::stdout()));
for line in stdin.lock().lines() {
let line = line?;
if line.trim().is_empty() {
continue;
}
let parsed: Result<JsonRpcRequest, _> = serde_json::from_str(&line);
match parsed {
Ok(req) => self.route_stdio_message(req, &stdout)?,
Err(e) => {
let resp = JsonRpcResponse::error(None, -32700, format!("Parse error: {e}"));
write_response(&stdout, &resp)?;
}
}
}
Ok(())
}
#[cfg(feature = "native")]
fn route_stdio_message(
&mut self,
req: JsonRpcRequest,
stdout: &Arc<Mutex<std::io::Stdout>>,
) -> anyhow::Result<()> {
if req.jsonrpc != "2.0" {
let resp = JsonRpcResponse::error(
req.id.clone(),
-32600,
format!(
"Invalid Request: jsonrpc must be \"2.0\", got \"{}\"",
req.jsonrpc
),
);
return write_response(stdout, &resp);
}
match req.method.as_str() {
"notifications/cancelled" => {
if let Some(request_id) = req.params.get("requestId").cloned() {
let _ = Self::cancel_in_flight(&self.in_flight, &request_id);
}
Ok(())
}
"notifications/initialized" => {
Ok(())
}
"tools/call" => self.spawn_tools_call_worker(req, stdout),
_ => {
let resp = self.handle_request(&req);
write_response(stdout, &resp)
}
}
}
#[cfg(feature = "native")]
fn spawn_tools_call_worker(
&mut self,
req: JsonRpcRequest,
stdout: &Arc<Mutex<std::io::Stdout>>,
) -> anyhow::Result<()> {
let Some(id) = req.id.clone() else {
let resp =
JsonRpcResponse::error(None, -32600, "Invalid Request: tools/call requires an id");
return write_response(stdout, &resp);
};
let cancel_rx = Self::register_in_flight(&self.in_flight, id.clone());
let stdout_clone = Arc::clone(stdout);
let in_flight_clone = Arc::clone(&self.in_flight);
let params = req.params.clone();
let id_for_worker = id.clone();
let progress_token = extract_progress_token(¶ms);
let sink_stdout = Arc::clone(stdout);
let sink: NotificationSink = Box::new(move |notif| {
let _ = write_notification(&sink_stdout, ¬if);
});
let builder = std::thread::Builder::new().name(format!("apr-mcp-call-{id}"));
let spawn_result = builder.spawn(move || {
let sink_ref = progress_token.as_ref().map(|_| &sink);
let result =
dispatch_tool_call_with_sink(¶ms, &cancel_rx, sink_ref, progress_token);
let resp = JsonRpcResponse::success(
Some(id_for_worker.clone()),
serde_json::to_value(result).unwrap_or_else(|_| serde_json::json!({})),
);
let _ = write_response(&stdout_clone, &resp);
Self::deregister_in_flight(&in_flight_clone, &id_for_worker);
});
match spawn_result {
Ok(_handle) => Ok(()),
Err(e) => {
Self::deregister_in_flight(&self.in_flight, &id);
let resp = JsonRpcResponse::error(
Some(id),
-32603,
format!("Internal error: failed to spawn worker thread: {e}"),
);
write_response(stdout, &resp)
}
}
}
#[must_use]
pub fn in_flight_handle(&self) -> InFlight {
Arc::clone(&self.in_flight)
}
}
/// Dispatches a `tools/call` with no progress token.
///
/// Thin wrapper over `dispatch_tool_call_with_sink`: the selected tool receives `None` as
/// its progress token, together with the given cancellation receiver and optional sink.
fn dispatch_tool_call(
    params: &serde_json::Value,
    cancel_rx: &mpsc::Receiver<()>,
    sink: Option<&NotificationSink>,
) -> ToolCallResult {
    let no_progress_token: Option<serde_json::Value> = None;
    dispatch_tool_call_with_sink(params, cancel_rx, sink, no_progress_token)
}
fn dispatch_tool_call_with_sink(
params: &serde_json::Value,
cancel_rx: &mpsc::Receiver<()>,
sink: Option<&NotificationSink>,
progress_token: Option<serde_json::Value>,
) -> ToolCallResult {
let name = params.get("name").and_then(|v| v.as_str());
let arguments = params
.get("arguments")
.cloned()
.unwrap_or_else(|| serde_json::json!({}));
match name {
Some(tools::version::NAME) => tools::version::call(&arguments),
Some(tools::validate::NAME) => tools::validate::call(&arguments),
Some(tools::tensors::NAME) => tools::tensors::call(&arguments),
Some(tools::bench::NAME) => tools::bench::call(&arguments),
Some(tools::qa::NAME) => tools::qa::call(&arguments),
Some(tools::trace::NAME) => tools::trace::call(&arguments),
Some(tools::run::NAME) => {
tools::run::call_with_sink(&arguments, cancel_rx, sink, progress_token)
}
Some(tools::serve::NAME) => tools::serve::call(&arguments),
Some(tools::finetune::NAME) => {
tools::finetune::call_with_sink(&arguments, sink, progress_token)
}
Some(other) => ToolCallResult::error(format!("Unknown tool: {other}")),
None => ToolCallResult::error("Missing tool name"),
}
}
/// Returns the client's `_meta.progressToken` from request params, if it is usable.
///
/// MCP progress tokens are strings or numbers. Any other value, notably an explicit
/// `null`, is treated as "no progress requested". This keeps the server from enabling
/// progress notifications that would carry an invalid token.
fn extract_progress_token(params: &serde_json::Value) -> Option<serde_json::Value> {
    params
        .get("_meta")
        .and_then(|meta| meta.get("progressToken"))
        .filter(|token| token.is_string() || token.is_number())
        .cloned()
}
/// Writes `resp` to the shared stdout as one JSON line and flushes it.
///
/// # Errors
/// Fails if serialization fails, the stdout mutex is poisoned, or the write/flush fails.
#[cfg(feature = "native")]
fn write_response(
    stdout: &Arc<Mutex<std::io::Stdout>>,
    resp: &JsonRpcResponse,
) -> anyhow::Result<()> {
    use std::io::Write as _;
    // Serialize outside the lock so the critical section covers only the write itself.
    let line = serde_json::to_string(resp)?;
    let mut out = match stdout.lock() {
        Ok(guard) => guard,
        Err(poisoned) => return Err(anyhow::anyhow!("stdout mutex poisoned: {poisoned}")),
    };
    writeln!(out, "{line}")?;
    out.flush()?;
    Ok(())
}
/// Writes `notif` to the shared stdout followed by a newline and flushes it.
///
/// # Errors
/// Fails if encoding fails, the stdout mutex is poisoned, or the write/flush fails.
#[cfg(feature = "native")]
fn write_notification(
    stdout: &Arc<Mutex<std::io::Stdout>>,
    notif: &JsonRpcNotification,
) -> anyhow::Result<()> {
    use std::io::Write as _;
    // Encode outside the lock; only the write needs exclusive access to stdout.
    let encoded = notif.to_json_line()?;
    let mut out = match stdout.lock() {
        Ok(guard) => guard,
        Err(poisoned) => return Err(anyhow::anyhow!("stdout mutex poisoned: {poisoned}")),
    };
    // NOTE(review): `writeln!` appends '\n'; if `to_json_line` already ends with one, an
    // empty line is emitted between messages — confirm against `JsonRpcNotification`.
    writeln!(out, "{encoded}")?;
    out.flush()?;
    Ok(())
}
#[cfg(test)]
// Tests use `expect` freely; presumably `clippy::disallowed_methods` is configured to
// forbid it in production code — TODO confirm against clippy.toml.
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;

    /// Builds a JSON-RPC 2.0 request with id `1`.
    fn make_request(method: &str, params: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!(1)),
            method: method.to_string(),
            params,
        }
    }

    #[test]
    fn initialize_returns_protocol_version() {
        let mut server = AprMcpServer::new();
        let req = make_request("initialize", serde_json::json!({}));
        let resp = server.handle_request(&req);
        assert!(resp.error.is_none());
        let result = resp.result.expect("result present");
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "aprender-mcp");
        assert!(result["capabilities"]["tools"].is_object());
    }

    #[test]
    fn tools_list_returns_registered_tools() {
        let mut server = AprMcpServer::new();
        let req = make_request("tools/list", serde_json::json!({}));
        let resp = server.handle_request(&req);
        let result = resp.result.expect("result present");
        let tools = result["tools"].as_array().expect("tools array");
        let names: Vec<&str> = tools.iter().filter_map(|t| t["name"].as_str()).collect();
        for expected in [
            "apr.version",
            "apr.validate",
            "apr.tensors",
            "apr.bench",
            "apr.qa",
            "apr.trace",
            "apr.run",
            "apr.serve",
            "apr.finetune",
        ] {
            assert!(names.contains(&expected), "{expected} registered");
        }
        for tool in tools {
            assert_eq!(tool["inputSchema"]["type"], "object");
        }
    }

    #[test]
    fn tools_call_version_returns_metadata() {
        let mut server = AprMcpServer::new();
        let req = make_request(
            "tools/call",
            serde_json::json!({ "name": "apr.version", "arguments": {} }),
        );
        let resp = server.handle_request(&req);
        let result = resp.result.expect("result present");
        let text = result["content"][0]["text"].as_str().expect("text");
        let parsed: serde_json::Value = serde_json::from_str(text).expect("json");
        assert_eq!(parsed["server"], "aprender-mcp");
        assert_eq!(parsed["protocol_version"], "2024-11-05");
    }

    #[test]
    fn unknown_method_returns_method_not_found() {
        let mut server = AprMcpServer::new();
        let req = make_request("tools/explode", serde_json::json!({}));
        let resp = server.handle_request(&req);
        assert!(resp.result.is_none());
        let err = resp.error.expect("error present");
        assert_eq!(err.code, -32601);
    }

    #[test]
    fn wrong_jsonrpc_version_is_invalid_request() {
        let mut server = AprMcpServer::new();
        let mut req = make_request("tools/list", serde_json::json!({}));
        req.jsonrpc = "1.0".to_string();
        let resp = server.handle_request(&req);
        assert!(resp.result.is_none());
        assert_eq!(resp.id, Some(serde_json::json!(1)));
        let err = resp.error.expect("error present");
        assert_eq!(err.code, -32600);
    }

    #[test]
    fn tools_call_validate_missing_model_path_is_error() {
        let mut server = AprMcpServer::new();
        let req = make_request(
            "tools/call",
            serde_json::json!({ "name": "apr.validate", "arguments": {} }),
        );
        let resp = server.handle_request(&req);
        let result = resp.result.expect("result present");
        assert_eq!(result["isError"], true);
        let text = result["content"][0]["text"].as_str().expect("text");
        assert!(text.contains("model_path"));
    }

    #[test]
    fn tools_call_unknown_tool_returns_is_error() {
        let mut server = AprMcpServer::new();
        let req = make_request(
            "tools/call",
            serde_json::json!({ "name": "apr.nonexistent" }),
        );
        let resp = server.handle_request(&req);
        let result = resp.result.expect("result present");
        assert_eq!(result["isError"], true);
    }

    #[test]
    fn tools_call_missing_name_returns_is_error() {
        let mut server = AprMcpServer::new();
        let req = make_request("tools/call", serde_json::json!({}));
        let resp = server.handle_request(&req);
        let result = resp.result.expect("result present");
        assert_eq!(result["isError"], true);
    }

    #[test]
    fn id_is_echoed_back() {
        let mut server = AprMcpServer::new();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!("req-42")),
            method: "initialize".to_string(),
            params: serde_json::json!({}),
        };
        let resp = server.handle_request(&req);
        assert_eq!(resp.id, Some(serde_json::json!("req-42")));
    }

    #[test]
    fn sink_path_never_answers_notifications() {
        let mut server = AprMcpServer::new();
        let sink: NotificationSink = Box::new(|_| {});
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: None,
            method: "notifications/initialized".to_string(),
            params: serde_json::json!({}),
        };
        assert!(server.handle_request_with_sink(&req, &sink).is_none());
    }

    #[test]
    fn sink_path_answers_tools_call_without_progress_token() {
        let mut server = AprMcpServer::new();
        // No progress token is sent, so the sink must never be invoked.
        let sink: NotificationSink = Box::new(|_| panic!("unexpected notification"));
        let req = make_request(
            "tools/call",
            serde_json::json!({ "name": "apr.version", "arguments": {} }),
        );
        let resp = server
            .handle_request_with_sink(&req, &sink)
            .expect("tools/call gets a response");
        assert_eq!(resp.id, Some(serde_json::json!(1)));
        assert!(resp.error.is_none());
        assert!(resp.result.is_some());
    }

    #[test]
    fn cancel_in_flight_signals_and_deregisters() {
        let server = AprMcpServer::new();
        let id = serde_json::json!(99);
        let rx = AprMcpServer::register_in_flight(&server.in_flight, id.clone());
        let signalled = AprMcpServer::cancel_in_flight(&server.in_flight, &id);
        assert!(signalled, "live id should signal");
        let received = rx.try_recv();
        assert!(received.is_ok(), "cancel signal must be deliverable");
        let signalled_again = AprMcpServer::cancel_in_flight(&server.in_flight, &id);
        assert!(
            !signalled_again,
            "cancelling an already-removed id is a no-op"
        );
    }

    #[test]
    fn cancel_unknown_id_is_noop() {
        let server = AprMcpServer::new();
        let id = serde_json::json!("never-registered");
        let signalled = AprMcpServer::cancel_in_flight(&server.in_flight, &id);
        assert!(!signalled);
    }
}