use crate::error::{Error, Result};
use crate::sse_proxy::events::EventManager;
use crate::sse_proxy::proxy::SSEProxy;
use crate::transport::json_rpc::{JsonRpcResponse, error_codes};
use actix_web::{
HttpRequest, HttpResponse, Responder,
web::{Bytes, Data},
};
use serde_json::{Value, json};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing;
use uuid::Uuid;
/// A client message decoded from JSON with its routing fields pulled out.
struct ParsedMessage {
    /// The `method` field, or `"unknown"` when it is present but not a string.
    method: String,
    /// The client's `id` field, or a placeholder id synthesised by
    /// `parse_message` when the client omitted it.
    request_id: Value,
    /// The whole decoded message, kept so handlers can read `params`.
    json_value: Value,
}
fn parse_message(body: &[u8]) -> Result<ParsedMessage> {
let json_value: Value = match serde_json::from_slice(body) {
Ok(val) => val,
Err(e) => {
tracing::error!(error = %e, "Failed to parse client message as JSON");
return Err(Error::JsonRpc(format!("Invalid JSON format: {}", e)));
}
};
tracing::debug!(
raw_message = ?json_value,
"Raw client message"
);
let method = match json_value.get("method") {
Some(m) => m.as_str().unwrap_or("unknown").to_string(),
None => {
tracing::error!("Message missing 'method' field");
return Err(Error::JsonRpc("Missing 'method' field".to_string()));
}
};
let request_id = match json_value.get("id") {
Some(id) => id.clone(),
None => {
if method == "ping" {
json!("ping-default-id")
} else {
json!(format!("generated-{}", Uuid::new_v4()))
}
}
};
Ok(ParsedMessage {
json_value,
method,
request_id,
})
}
/// Builds the HTTP 202 body used to acknowledge a client message.
///
/// The body echoes `request_id` alongside a human-readable `message`. The
/// handlers in this module publish any JSON-RPC result separately through the
/// event manager.
fn create_acceptance_response(request_id: &Value, message: &str) -> HttpResponse {
    let body = json!({
        "status": "accepted",
        "id": request_id,
        "message": message
    });
    HttpResponse::Accepted().json(body)
}
fn create_jsonrpc_error(request_id: Value, code: i32, message: String) -> HttpResponse {
let error_response = JsonRpcResponse::error(request_id, code, message, None);
HttpResponse::InternalServerError().json(error_response)
}
fn create_jsonrpc_response(request_id: &Value, result: Value) -> Value {
json!({
"jsonrpc": "2.0",
"id": request_id,
"result": result
})
}
/// Acknowledges a `notifications/*` message with HTTP 202.
///
/// No JSON-RPC reply is produced. The `notifications/initialized` message is
/// additionally logged at debug level.
async fn handle_notification_message(method: &str, request_id: &Value) -> Result<HttpResponse> {
    tracing::info!("Processing notification: {}", method);

    let is_initialized_notice = method == "notifications/initialized";
    if is_initialized_notice {
        tracing::debug!("Client sent initialized notification");
    }

    let ack = create_acceptance_response(request_id, "Notification acknowledged");
    Ok(ack)
}
/// Handles an `initialize` request.
///
/// Builds a JSON-RPC result that contains:
/// - the SSE path of every server the runner allows;
/// - the proxy's capabilities, server info and protocol version.
///
/// The result is published through the event manager, and the POST is
/// acknowledged with HTTP 202.
async fn handle_initialize_message(
    proxy: &Arc<Mutex<SSEProxy>>,
    request_id: &Value,
) -> Result<HttpResponse> {
    tracing::info!("Processing initialize request");
    let proxy_lock = proxy.lock().await;
    let event_manager = proxy_lock.event_manager();
    let server_info = proxy_lock.get_server_info().lock().await;
    let runner_access = proxy_lock.get_runner_access();
    let allowed_servers = (runner_access.get_allowed_servers)();

    // Only advertise servers the client may use. This matches the filtering
    // applied when the SSE stream is opened and when handling `tools/list`.
    let mut servers_map = serde_json::Map::new();
    for (_id, info) in server_info.iter() {
        if let Some(allowed_list) = &allowed_servers {
            if !allowed_list.contains(&info.name) {
                tracing::debug!(
                    server = %info.name,
                    "Excluding server from initialize response - not in allowed list"
                );
                continue;
            }
        }
        servers_map.insert(info.name.clone(), json!(format!("/sse/{}", info.name)));
    }

    let initialize_response = create_jsonrpc_response(
        request_id,
        json!({
            "servers": servers_map,
            "capabilities": {
                "streaming": true,
                "roots": {
                    "listChanged": true
                },
                "sampling": {}
            },
            "serverInfo": {
                "name": "mcp-runner",
                "version": env!("CARGO_PKG_VERSION")
            },
            "protocolVersion": "2024-11-05"
        }),
    );
    tracing::info!(
        response = ?initialize_response,
        "Sending initialize response"
    );
    event_manager.send_tool_response(
        &request_id.to_string(),
        "system",
        "initialize",
        initialize_response,
    );
    tracing::debug!("Sent initialization response");
    Ok(create_acceptance_response(
        request_id,
        "Initialization request received and being processed",
    ))
}
/// Handles a `tools/list` request.
///
/// For every known server the runner allows, this function:
/// - resolves its id and client;
/// - initializes the client;
/// - collects its tools, each tagged with the owning server's name.
///
/// A server whose id lookup, client lookup, initialization or listing fails is
/// logged and skipped. The combined JSON-RPC result is published through the
/// event manager, and the POST is acknowledged with HTTP 202.
///
/// NOTE(review): the proxy mutex and the server-info lock stay held across
/// every client `initialize`/`list_tools` await. Other handlers that need the
/// proxy therefore wait for the whole sweep.
async fn handle_tools_list_message(
    proxy: &Arc<Mutex<SSEProxy>>,
    request_id: &Value,
) -> Result<HttpResponse> {
    tracing::info!("Processing tools/list request");
    let guard = proxy.lock().await;
    let events = guard.event_manager();
    let servers = guard.get_server_info().lock().await;
    let access = guard.get_runner_access();
    let mut collected = Vec::new();
    let allowed = (access.get_allowed_servers)();

    for (server_name, _info) in servers.iter() {
        let permitted = match &allowed {
            Some(list) => list.contains(&server_name.to_string()),
            None => true,
        };
        if !permitted {
            tracing::debug!(server = %server_name, "Skipping server not in allowed list");
            continue;
        }

        let server_id = match (access.get_server_id)(server_name) {
            Ok(id) => id,
            Err(e) => {
                tracing::warn!(
                    server = %server_name,
                    error = %e,
                    "Failed to get server ID"
                );
                continue;
            }
        };

        let client = match (access.get_client)(server_id) {
            Ok(handle) => handle,
            Err(e) => {
                tracing::warn!(
                    server = %server_name,
                    error = %e,
                    "Failed to get client for server"
                );
                continue;
            }
        };

        if let Err(e) = client.initialize().await {
            tracing::warn!(
                server = %server_name,
                error = %e,
                "Failed to initialize client, continuing to next server"
            );
            continue;
        }

        match client.list_tools().await {
            Ok(tools) => collected.extend(tools.into_iter().map(|tool| {
                json!({
                    "name": tool.name,
                    "description": tool.description,
                    "server": server_name,
                    "inputSchema": tool.input_schema.unwrap_or_else(|| json!({})),
                    "outputSchema": tool.output_schema.unwrap_or_else(|| json!({}))
                })
            })),
            Err(e) => {
                tracing::warn!(
                    server = %server_name,
                    error = %e,
                    "Failed to list tools for server"
                );
            }
        }
    }

    let tools_response = create_jsonrpc_response(request_id, json!({ "tools": collected }));
    tracing::debug!(
        response = ?tools_response,
        "Sending tools/list response"
    );
    events.send_tool_response(
        &request_id.to_string(),
        "system",
        "tools/list",
        tools_response,
    );
    Ok(create_acceptance_response(
        request_id,
        "Tools list request received and processed",
    ))
}
/// Resolves which server should handle a `tools/call` for `tool_name`.
///
/// A string `server` field in `params` is returned as-is. Otherwise each
/// server the runner allows is probed (client initialized, tools listed), and
/// the first one exposing a tool named `tool_name` is returned. Servers that
/// fail along the way are logged and skipped.
///
/// # Errors
///
/// Returns [`Error::JsonRpc`] if no allowed server exposes `tool_name`.
///
/// NOTE(review): the proxy and server-info locks are held across the client
/// awaits below. When several servers expose the same tool, the winner depends
/// on `server_info` iteration order.
async fn determine_server_for_tool(
    proxy: &Arc<Mutex<SSEProxy>>,
    tool_name: &str,
    params: &Value,
) -> Result<String> {
    if let Some(Value::String(s)) = params.get("server") {
        return Ok(s.clone());
    }
    tracing::info!(
        tool = %tool_name,
        "Server name not provided, attempting to determine from tool name"
    );
    let proxy_lock = proxy.lock().await;
    let server_info = proxy_lock.get_server_info().lock().await;
    let runner_access = proxy_lock.get_runner_access();
    let allowed_servers = (runner_access.get_allowed_servers)();

    for (server_name, _info) in server_info.iter() {
        // Never auto-route to a server the client is not allowed to use.
        // `tools/list` hides those servers, so they must not win here either.
        if let Some(allowed_list) = &allowed_servers {
            if !allowed_list.contains(&server_name.to_string()) {
                tracing::debug!(server = %server_name, "Skipping server not in allowed list");
                continue;
            }
        }

        let server_id = match (runner_access.get_server_id)(server_name) {
            Ok(id) => id,
            Err(e) => {
                tracing::warn!(
                    server = %server_name,
                    error = %e,
                    "Failed to get server ID"
                );
                continue;
            }
        };
        let client = match (runner_access.get_client)(server_id) {
            Ok(client) => client,
            Err(e) => {
                tracing::warn!(
                    server = %server_name,
                    error = %e,
                    "Failed to get client for server"
                );
                continue;
            }
        };
        if let Err(e) = client.initialize().await {
            tracing::warn!(
                server = %server_name,
                error = %e,
                "Failed to initialize client, continuing to next server"
            );
            continue;
        }
        match client.list_tools().await {
            Ok(tools) => {
                if tools.into_iter().any(|tool| tool.name == tool_name) {
                    tracing::info!(
                        tool = %tool_name,
                        server = %server_name,
                        "Automatically determined server for tool"
                    );
                    return Ok(server_name.clone());
                }
            }
            Err(e) => {
                tracing::warn!(
                    server = %server_name,
                    error = %e,
                    "Failed to list tools for server"
                );
            }
        }
    }

    tracing::error!(
        tool = %tool_name,
        "Could not determine which server provides this tool"
    );
    Err(Error::JsonRpc(
        "Could not determine which server provides this tool".to_string(),
    ))
}
/// Answers a `ping` and acknowledges the POST with HTTP 202.
///
/// The JSON-RPC result is `{"type": "pong"}` and is published through the
/// event manager.
async fn handle_ping_message(
    proxy: &Arc<Mutex<SSEProxy>>,
    request_id: &Value,
) -> Result<HttpResponse> {
    tracing::info!("Processing ping request");
    let pong = create_jsonrpc_response(request_id, json!({ "type": "pong" }));

    let guard = proxy.lock().await;
    tracing::debug!(response = ?pong, "Sending ping response");
    guard
        .event_manager()
        .send_tool_response(&request_id.to_string(), "system", "ping", pong);

    Ok(create_acceptance_response(
        request_id,
        "Ping request received and processed",
    ))
}
async fn handle_tools_call_message(
proxy: &Arc<Mutex<SSEProxy>>,
request_id: &Value,
json_value: &Value,
) -> Result<HttpResponse> {
tracing::info!("Processing tools/call request");
let params = match json_value.get("params") {
Some(p) => p,
None => {
tracing::error!("Missing params in tools/call request");
return Err(Error::JsonRpc(
"Missing params in tools/call request".to_string(),
));
}
};
let tool_name = match params.get("name") {
Some(Value::String(name)) => name.clone(),
_ => {
tracing::error!("Missing or invalid tool name in tools/call request");
return Err(Error::JsonRpc("Missing or invalid tool name".to_string()));
}
};
let arguments = match params.get("arguments") {
Some(args) => args.clone(),
None => {
tracing::error!("Missing arguments in tools/call request");
return Err(Error::JsonRpc("Missing arguments".to_string()));
}
};
let server_name = determine_server_for_tool(proxy, &tool_name, params).await?;
tracing::debug!(
req_id = ?request_id,
tool_name = %tool_name,
server_name = %server_name,
"Processing tool call"
);
let proxy_lock = Arc::clone(proxy);
let request_id_str = request_id.to_string();
let server_name_clone = server_name;
let tool_name_clone = tool_name;
let arguments_clone = arguments;
tokio::spawn(async move {
let proxy = proxy_lock.lock().await;
if let Err(e) = proxy
.process_tool_call(
&server_name_clone,
&tool_name_clone,
arguments_clone,
&request_id_str,
)
.await
{
tracing::error!(
req_id = %request_id_str,
server = %server_name_clone,
tool = %tool_name_clone,
error = %e,
"Failed to process tool call from tools/call method"
);
}
});
Ok(create_acceptance_response(
request_id,
"Tool call received and being processed",
))
}
/// Opens the main SSE stream for a client.
///
/// Steps:
/// 1. Subscribe to the event manager.
/// 2. Schedule the initial configuration event: the message URL plus the
///    `/sse/<name>` path of every allowed server.
/// 3. Stream every subsequent event to the client as `text/event-stream`.
///
/// If this subscriber falls behind and the channel reports skipped events,
/// the gap is logged and streaming continues. Any other receive error ends
/// the stream.
pub async fn sse_main_endpoint(
    proxy: Data<Arc<Mutex<SSEProxy>>>,
    req: HttpRequest,
) -> Result<impl Responder> {
    tracing::info!(
        path = ?req.path(),
        remote_addr = ?req.connection_info().peer_addr(),
        "New SSE connection request"
    );
    let proxy = proxy.lock().await;
    let event_manager = proxy.event_manager();
    // Subscribe before the initial config is scheduled below, so this receiver
    // already exists when that event is sent.
    let mut receiver = event_manager.subscribe();
    let server_info = proxy.get_server_info().lock().await;
    let runner_access = proxy.get_runner_access();
    let allowed_servers = (runner_access.get_allowed_servers)();
    let mut servers_map = serde_json::Map::new();
    for (_id, info) in server_info.iter() {
        if let Some(allowed_list) = &allowed_servers {
            if !allowed_list.contains(&info.name) {
                tracing::debug!(server = %info.name, "Excluding server from initial config - not in allowed list");
                continue;
            }
        }
        servers_map.insert(info.name.clone(), json!(format!("/sse/{}", info.name)));
    }
    let servers = Value::Object(servers_map);
    let message_url = "/sse/messages";
    tracing::debug!(
        message_url = %message_url,
        servers = ?servers,
        "Preparing SSE endpoint configuration response"
    );
    // NOTE(review): confirm whether `send_initial_config` targets only this
    // connection. If it broadcasts, every existing subscriber also receives it.
    let event_manager_clone = Arc::clone(event_manager);
    tokio::spawn(async move {
        event_manager_clone.send_initial_config(message_url, &servers);
        tracing::info!(
            message_url = %message_url,
            servers = ?servers,
            "Sent SSE endpoint configuration to client"
        );
    });
    let stream = async_stream::stream! {
        loop {
            match receiver.recv().await {
                Ok(msg) => {
                    tracing::debug!(
                        event_type = %msg.event,
                        event_id = ?msg.id,
                        "Sending SSE event to client"
                    );
                    yield Ok::<_, actix_web::Error>(EventManager::format_sse_message(&msg));
                },
                // A slow client has only missed some events. Keep its stream
                // open instead of disconnecting it.
                Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                    tracing::warn!(
                        skipped = skipped,
                        "SSE client lagged behind; some events were dropped"
                    );
                },
                Err(e) => {
                    tracing::error!(error = %e, "Error receiving SSE event");
                    break;
                }
            }
        }
    };
    let response = HttpResponse::Ok()
        .insert_header(("Content-Type", "text/event-stream"))
        .insert_header(("Cache-Control", "no-cache"))
        .insert_header(("Connection", "keep-alive"))
        .streaming(stream);
    Ok(response)
}
/// Entry point for client JSON-RPC messages posted to the SSE message URL.
///
/// Parses the body and dispatches on `method`:
/// - `notifications/*`, `initialize`, `tools/list`, `tools/call` and `ping`
///   go to their handlers;
/// - any other method gets a JSON-RPC `METHOD_NOT_FOUND` error, built by
///   `create_jsonrpc_error`.
///
/// # Errors
///
/// Propagates parse errors from `parse_message` and errors returned by the
/// handlers.
pub async fn sse_messages(
    proxy: Data<Arc<Mutex<SSEProxy>>>,
    body: Bytes,
    _req: HttpRequest,
) -> Result<impl Responder> {
    tracing::debug!("Received message from client");
    let parsed = parse_message(&body)?;
    tracing::debug!(
        req_id = ?parsed.request_id,
        method = %parsed.method,
        "Processing client message"
    );
    match parsed.method.as_str() {
        method if method.starts_with("notifications/") => {
            handle_notification_message(method, &parsed.request_id).await
        }
        "initialize" => handle_initialize_message(&proxy, &parsed.request_id).await,
        "tools/list" => handle_tools_list_message(&proxy, &parsed.request_id).await,
        "tools/call" => {
            handle_tools_call_message(&proxy, &parsed.request_id, &parsed.json_value).await
        }
        "ping" => handle_ping_message(&proxy, &parsed.request_id).await,
        _ => {
            tracing::warn!(
                req_id = ?parsed.request_id,
                method = %parsed.method,
                "Unknown method received"
            );
            Ok(create_jsonrpc_error(
                parsed.request_id,
                error_codes::METHOD_NOT_FOUND,
                format!("Method '{}' not found", parsed.method),
            ))
        }
    }
}