use std::collections::HashMap;
use std::sync::Arc;
use serde::Deserialize;
use worker::{Headers, Request, Response};
use super::context::RequestContext;
use super::server::{McpServer, PromptHandlerKind, ResourceHandlerKind, ToolHandlerKind};
use super::types::{JsonRpcRequest, JsonRpcResponse, error_codes};
use turbomcp_core::PROTOCOL_VERSION;
use turbomcp_core::types::capabilities::ClientCapabilities;
use turbomcp_core::types::core::Implementation;
use turbomcp_core::types::initialization::InitializeResult;
/// Upper bound, in bytes, on an accepted JSON-RPC request body (1 MiB).
///
/// Checked against a declared `Content-Length` before the body is read and against the
/// actual body length after it is read.
const MAX_BODY_SIZE: usize = 1 << 20;
/// HTTP entry point that decodes Worker requests into JSON-RPC calls answered by an
/// [`McpServer`].
///
/// Holds only a shared borrow of the server, so constructing one is cheap.
pub struct McpHandler<'a> {
/// Source of registered tools, resources, resource templates, prompts, and the server
/// metadata returned from `initialize`.
server: &'a McpServer,
}
impl<'a> McpHandler<'a> {
pub fn new(server: &'a McpServer) -> Self {
Self { server }
}
/// Copies an allow-listed subset of request headers into an owned map.
///
/// Keys are stored exactly as listed (lower case). Headers that are absent or cannot be
/// read are skipped rather than reported.
fn extract_headers(req: &Request) -> HashMap<String, String> {
    const FORWARDED: [&str; 9] = [
        "authorization",
        "content-type",
        "user-agent",
        "x-request-id",
        "x-session-id",
        "x-client-id",
        "mcp-session-id",
        "origin",
        "referer",
    ];
    let source = req.headers();
    FORWARDED
        .iter()
        .filter_map(|&name| {
            source
                .get(name)
                .ok()
                .flatten()
                .map(|value| (name.to_owned(), value))
        })
        .collect()
}
fn create_context_from_request(req: &Request) -> RequestContext {
let headers = Self::extract_headers(req);
let session_id = headers
.get("mcp-session-id")
.or_else(|| headers.get("x-session-id"))
.cloned();
let request_id = headers.get("x-request-id").cloned();
RequestContext::from_worker_request(request_id, session_id, headers)
}
pub async fn handle(&self, mut req: Request) -> worker::Result<Response> {
let request_origin = req.headers().get("origin").ok().flatten();
let origin_ref = request_origin.as_deref();
if req.method() == worker::Method::Options {
return self.cors_preflight_response(origin_ref);
}
if req.method() != worker::Method::Post {
return self.error_response(
405,
"Method not allowed. Use POST for JSON-RPC requests.",
origin_ref,
);
}
if !self.is_valid_content_type(&req) {
return self.error_response(
415,
"Unsupported Media Type. Use Content-Type: application/json",
origin_ref,
);
}
let context = Arc::new(Self::create_context_from_request(&req));
if let Some(content_length) = req.headers().get("content-length").ok().flatten()
&& let Ok(length) = content_length.parse::<usize>()
&& length > MAX_BODY_SIZE
{
return self.error_response(413, "Request body too large", origin_ref);
}
let body = match req.text().await {
Ok(b) => {
if b.len() > MAX_BODY_SIZE {
return self.error_response(413, "Request body too large", origin_ref);
}
if b.is_empty() {
let response = JsonRpcResponse::error(
None,
error_codes::INVALID_REQUEST,
"Empty request body",
);
return self.json_response(&response, origin_ref);
}
b
}
Err(e) => {
let response = JsonRpcResponse::error(
None,
error_codes::PARSE_ERROR,
format!("Failed to read request body: {e}"),
);
return self.json_response(&response, origin_ref);
}
};
let rpc_request: JsonRpcRequest = match serde_json::from_str(&body) {
Ok(r) => r,
Err(e) => {
let response = JsonRpcResponse::error(
None,
error_codes::PARSE_ERROR,
format!("Parse error: {e}"),
);
return self.json_response(&response, origin_ref);
}
};
if rpc_request.jsonrpc != "2.0" {
let response = JsonRpcResponse::error(
rpc_request.id,
error_codes::INVALID_REQUEST,
"Invalid JSON-RPC version. Expected \"2.0\".",
);
return self.json_response(&response, origin_ref);
}
let is_notification = rpc_request.id.is_none();
let response = self.route_request_with_ctx(&rpc_request, context).await;
if is_notification && response.error.is_none() {
return Response::empty()
.map(|r| r.with_status(204))
.map(|r| r.with_headers(self.cors_headers(origin_ref)));
}
self.json_response(&response, origin_ref)
}
/// Returns `true` when the request's `Content-Type` names a JSON media type.
///
/// Accepts any value containing `application/json` or `text/json`. The comparison is
/// case-insensitive, because media types are case-insensitive per RFC 9110 §8.3.1.
/// Parameters such as `; charset=utf-8` are tolerated.
///
/// A missing or unreadable header is accepted. The body is still validated when it is
/// parsed as JSON.
fn is_valid_content_type(&self, req: &Request) -> bool {
    let Some(content_type) = req.headers().get("Content-Type").ok().flatten() else {
        return true;
    };
    let content_type = content_type.to_ascii_lowercase();
    content_type.contains("application/json") || content_type.contains("text/json")
}
async fn route_request_with_ctx(
&self,
req: &JsonRpcRequest,
ctx: Arc<RequestContext>,
) -> JsonRpcResponse {
match req.method.as_str() {
"initialize" => self.handle_initialize(req),
"notifications/initialized" => self.handle_initialized_notification(req),
"ping" => self.handle_ping(req),
"tools/list" => self.handle_tools_list(req),
"tools/call" => self.handle_tools_call(req, ctx.clone()).await,
"resources/list" => self.handle_resources_list(req),
"resources/templates/list" => self.handle_resource_templates_list(req),
"resources/read" => self.handle_resources_read(req, ctx.clone()).await,
"prompts/list" => self.handle_prompts_list(req),
"prompts/get" => self.handle_prompts_get(req, ctx.clone()).await,
"logging/setLevel" => self.handle_logging_set_level(req),
_ => JsonRpcResponse::error(
req.id.clone(),
error_codes::METHOD_NOT_FOUND,
format!("Method not found: {}", req.method),
),
}
}
fn handle_initialize(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
let _params: Option<InitializeParams> = req
.params
.as_ref()
.and_then(|p| serde_json::from_value(p.clone()).ok());
let result = InitializeResult {
protocol_version: PROTOCOL_VERSION.into(),
capabilities: self.server.capabilities.clone(),
server_info: self.server.server_info.clone(),
instructions: self.server.instructions.clone(),
_meta: None,
};
match serde_json::to_value(&result) {
Ok(value) => JsonRpcResponse::success(req.id.clone(), value),
Err(e) => JsonRpcResponse::error(
req.id.clone(),
error_codes::INTERNAL_ERROR,
format!("Failed to serialize result: {e}"),
),
}
}
/// Acknowledges `notifications/initialized` with an empty-object result.
///
/// No handler state is changed.
fn handle_initialized_notification(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
    let acknowledged = serde_json::Value::Object(serde_json::Map::new());
    JsonRpcResponse::success(req.id.clone(), acknowledged)
}
/// Answers `ping` with an empty-object result.
fn handle_ping(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
    let pong = serde_json::Value::Object(serde_json::Map::new());
    JsonRpcResponse::success(req.id.clone(), pong)
}
/// Accepts `logging/setLevel` and replies with an empty-object result.
///
/// The requested level is neither read nor applied by this handler.
fn handle_logging_set_level(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
    let accepted = serde_json::Value::Object(serde_json::Map::new());
    JsonRpcResponse::success(req.id.clone(), accepted)
}
/// Lists every registered tool definition as `{ "tools": [...] }`.
///
/// Entries appear in the iteration order of the server's tool map.
fn handle_tools_list(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
    let tools = self
        .server
        .tools
        .values()
        .map(|registered| &registered.tool)
        .collect::<Vec<_>>();
    JsonRpcResponse::success(req.id.clone(), serde_json::json!({ "tools": tools }))
}
/// Executes `tools/call`: looks up the tool by `name` and runs its handler on `arguments`.
///
/// Missing `arguments`, or an explicit `null`, are passed to the handler as `{}`.
/// Failures inside the tool are whatever the handler encodes in its result value; they are
/// serialized as-is rather than turned into JSON-RPC errors.
///
/// # Errors (JSON-RPC)
/// * `INVALID_PARAMS` — params missing or malformed, or no tool with that name is
///   registered. MCP reports unknown tools as `-32602`, the same code this handler set uses
///   for unknown prompts.
/// * `INTERNAL_ERROR` — the tool result could not be serialized.
async fn handle_tools_call(
    &self,
    req: &JsonRpcRequest,
    ctx: Arc<RequestContext>,
) -> JsonRpcResponse {
    #[derive(Deserialize)]
    struct CallToolParams {
        name: String,
        #[serde(default)]
        arguments: Option<serde_json::Value>,
    }

    let Some(raw_params) = req.params.as_ref() else {
        return JsonRpcResponse::error(
            req.id.clone(),
            error_codes::INVALID_PARAMS,
            "Missing params: expected {name, arguments?}",
        );
    };
    let params: CallToolParams = match serde_json::from_value(raw_params.clone()) {
        Ok(params) => params,
        Err(e) => {
            return JsonRpcResponse::error(
                req.id.clone(),
                error_codes::INVALID_PARAMS,
                format!("Invalid params: {e}"),
            );
        }
    };

    let Some(registered_tool) = self.server.tools.get(&params.name) else {
        // The method (`tools/call`) exists; it is the `name` parameter that is invalid.
        return JsonRpcResponse::error(
            req.id.clone(),
            error_codes::INVALID_PARAMS,
            format!("Tool not found: {}", params.name),
        );
    };

    let args = params
        .arguments
        .unwrap_or_else(|| serde_json::json!({}));
    let tool_result = match &registered_tool.handler {
        ToolHandlerKind::NoCtx(handler) => handler(args).await,
        ToolHandlerKind::WithCtx(handler) => handler(ctx, args).await,
    };

    match serde_json::to_value(&tool_result) {
        Ok(value) => JsonRpcResponse::success(req.id.clone(), value),
        Err(e) => JsonRpcResponse::error(
            req.id.clone(),
            error_codes::INTERNAL_ERROR,
            format!("Failed to serialize result: {e}"),
        ),
    }
}
/// Lists every statically registered resource as `{ "resources": [...] }`.
///
/// Resource templates are listed separately by `resources/templates/list`.
fn handle_resources_list(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
    let mut resources = Vec::new();
    for registered in self.server.resources.values() {
        resources.push(&registered.resource);
    }
    JsonRpcResponse::success(req.id.clone(), serde_json::json!({ "resources": resources }))
}
/// Lists every registered resource template as `{ "resourceTemplates": [...] }`.
fn handle_resource_templates_list(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
    let registry = &self.server.resource_templates;
    let templates: Vec<_> = registry.values().map(|entry| &entry.template).collect();
    let payload = serde_json::json!({ "resourceTemplates": templates });
    JsonRpcResponse::success(req.id.clone(), payload)
}
async fn handle_resources_read(
&self,
req: &JsonRpcRequest,
ctx: Arc<RequestContext>,
) -> JsonRpcResponse {
#[derive(Deserialize)]
struct ReadResourceParams {
uri: String,
}
let params: ReadResourceParams = match req.params.as_ref() {
Some(p) => match serde_json::from_value(p.clone()) {
Ok(params) => params,
Err(e) => {
return JsonRpcResponse::error(
req.id.clone(),
error_codes::INVALID_PARAMS,
format!("Invalid params: {e}"),
);
}
},
None => {
return JsonRpcResponse::error(
req.id.clone(),
error_codes::INVALID_PARAMS,
"Missing params: expected {uri}",
);
}
};
if let Some(registered_resource) = self.server.resources.get(¶ms.uri) {
let result = match ®istered_resource.handler {
ResourceHandlerKind::NoCtx(handler) => handler(params.uri.clone()).await,
ResourceHandlerKind::WithCtx(handler) => {
handler(ctx.clone(), params.uri.clone()).await
}
};
return match result {
Ok(resource_result) => match serde_json::to_value(&resource_result) {
Ok(value) => JsonRpcResponse::success(req.id.clone(), value),
Err(e) => JsonRpcResponse::error(
req.id.clone(),
error_codes::INTERNAL_ERROR,
format!("Failed to serialize result: {e}"),
),
},
Err(e) => JsonRpcResponse::error(req.id.clone(), error_codes::INTERNAL_ERROR, e),
};
}
for (template_uri, registered_template) in &self.server.resource_templates {
if Self::matches_template(template_uri, ¶ms.uri) {
let result = match ®istered_template.handler {
ResourceHandlerKind::NoCtx(handler) => handler(params.uri.clone()).await,
ResourceHandlerKind::WithCtx(handler) => {
handler(ctx.clone(), params.uri.clone()).await
}
};
return match result {
Ok(resource_result) => match serde_json::to_value(&resource_result) {
Ok(value) => JsonRpcResponse::success(req.id.clone(), value),
Err(e) => JsonRpcResponse::error(
req.id.clone(),
error_codes::INTERNAL_ERROR,
format!("Failed to serialize result: {e}"),
),
},
Err(e) => {
JsonRpcResponse::error(req.id.clone(), error_codes::INTERNAL_ERROR, e)
}
};
}
}
JsonRpcResponse::error(
req.id.clone(),
error_codes::INVALID_PARAMS,
format!("Resource not found: {}", params.uri),
)
}
/// Lists every registered prompt definition as `{ "prompts": [...] }`.
fn handle_prompts_list(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
    let mut prompts = Vec::new();
    prompts.extend(self.server.prompts.values().map(|entry| &entry.prompt));
    let payload = serde_json::json!({ "prompts": prompts });
    JsonRpcResponse::success(req.id.clone(), payload)
}
async fn handle_prompts_get(
&self,
req: &JsonRpcRequest,
ctx: Arc<RequestContext>,
) -> JsonRpcResponse {
#[derive(Deserialize)]
struct GetPromptParams {
name: String,
#[serde(default)]
arguments: Option<serde_json::Value>,
}
let params: GetPromptParams = match req.params.as_ref() {
Some(p) => match serde_json::from_value(p.clone()) {
Ok(params) => params,
Err(e) => {
return JsonRpcResponse::error(
req.id.clone(),
error_codes::INVALID_PARAMS,
format!("Invalid params: {e}"),
);
}
},
None => {
return JsonRpcResponse::error(
req.id.clone(),
error_codes::INVALID_PARAMS,
"Missing params: expected {name, arguments?}",
);
}
};
let registered_prompt = match self.server.prompts.get(¶ms.name) {
Some(prompt) => prompt,
None => {
return JsonRpcResponse::error(
req.id.clone(),
error_codes::INVALID_PARAMS,
format!("Prompt not found: {}", params.name),
);
}
};
let result = match ®istered_prompt.handler {
PromptHandlerKind::NoCtx(handler) => handler(params.arguments).await,
PromptHandlerKind::WithCtx(handler) => handler(ctx, params.arguments).await,
};
match result {
Ok(prompt_result) => match serde_json::to_value(&prompt_result) {
Ok(value) => JsonRpcResponse::success(req.id.clone(), value),
Err(e) => JsonRpcResponse::error(
req.id.clone(),
error_codes::INTERNAL_ERROR,
format!("Failed to serialize result: {e}"),
),
},
Err(e) => JsonRpcResponse::error(req.id.clone(), error_codes::INTERNAL_ERROR, e),
}
}
/// Returns `true` when `uri` fits the `/`-separated shape of `template`.
///
/// Both strings are split on `/` and must have the same number of segments. A template
/// segment written as `{name}` accepts any non-empty URI segment that contains no `..`,
/// NUL byte, or `%` (rejecting path traversal and percent-encoded bypasses). Every other
/// template segment must equal the URI segment exactly.
fn matches_template(template: &str, uri: &str) -> bool {
    let mut template_segments = template.split('/');
    let mut uri_segments = uri.split('/');
    loop {
        match (template_segments.next(), uri_segments.next()) {
            (None, None) => return true,
            (Some(t), Some(u)) => {
                let is_placeholder = t.starts_with('{') && t.ends_with('}');
                let accepted = if is_placeholder {
                    !(u.is_empty() || u.contains("..") || u.contains('\0') || u.contains('%'))
                } else {
                    t == u
                };
                if !accepted {
                    return false;
                }
            }
            // One side ran out first: segment counts differ.
            _ => return false,
        }
    }
}
/// Extracts `{name}` placeholder values from `uri` according to `template`.
///
/// Uses the same segment rules as `matches_template`. On any mismatch the result is empty:
/// differing segment counts, a literal segment that differs, or a placeholder value that is
/// empty or contains `..`, NUL, or `%`.
///
/// Note: an empty map is also returned for a successful match of a template that has no
/// placeholders.
// NOTE(review): appears to be used only by tests; the allow keeps builds warning-free.
#[allow(dead_code)]
pub fn extract_template_params(template: &str, uri: &str) -> HashMap<String, String> {
    let template_segments: Vec<&str> = template.split('/').collect();
    let uri_segments: Vec<&str> = uri.split('/').collect();
    if template_segments.len() != uri_segments.len() {
        return HashMap::new();
    }

    let mut captured = HashMap::new();
    for (t, u) in template_segments.into_iter().zip(uri_segments) {
        match t.strip_prefix('{').and_then(|rest| rest.strip_suffix('}')) {
            Some(name) => {
                let unsafe_value =
                    u.is_empty() || u.contains("..") || u.contains('\0') || u.contains('%');
                if unsafe_value {
                    return HashMap::new();
                }
                captured.insert(name.to_string(), u.to_string());
            }
            None if t != u => return HashMap::new(),
            None => {}
        }
    }
    captured
}
/// Builds the CORS headers attached to every response.
///
/// The request's `Origin` is echoed back when present, with `Vary: Origin` so shared caches
/// keep per-origin copies; otherwise `*` is used. No `Access-Control-Allow-Credentials`
/// header is sent, so credentialed cross-origin requests are not permitted.
///
/// The allowed request headers cover those this handler reads from requests (content type,
/// authorization, request/session/client ids, `Mcp-Session-Id`). They also include
/// `Mcp-Protocol-Version`, which MCP HTTP clients send after initialization. Without them,
/// browser preflights fail for clients that send these headers.
fn cors_headers(&self, request_origin: Option<&str>) -> Headers {
    let headers = Headers::new();
    // `set` only fails on invalid names/values; these are constants or an echoed header value.
    let _ = headers.set("Access-Control-Allow-Origin", request_origin.unwrap_or("*"));
    if request_origin.is_some() {
        let _ = headers.set("Vary", "Origin");
    }
    let _ = headers.set("Access-Control-Allow-Methods", "POST, OPTIONS");
    let _ = headers.set(
        "Access-Control-Allow-Headers",
        "Content-Type, Authorization, X-Request-ID, X-Session-ID, X-Client-ID, \
         Mcp-Session-Id, Mcp-Protocol-Version",
    );
    let _ = headers.set("Access-Control-Max-Age", "86400");
    headers
}
/// Answers a CORS preflight (`OPTIONS`) with an empty `204 No Content` and the CORS headers.
fn cors_preflight_response(&self, request_origin: Option<&str>) -> worker::Result<Response> {
    let response = Response::empty()?.with_status(204);
    Ok(response.with_headers(self.cors_headers(request_origin)))
}
fn json_response(
&self,
body: &JsonRpcResponse,
request_origin: Option<&str>,
) -> worker::Result<Response> {
let json = serde_json::to_string(body).map_err(|e| worker::Error::from(e.to_string()))?;
let headers = self.cors_headers(request_origin);
let _ = headers.set("Content-Type", "application/json");
Ok(Response::ok(json)?.with_headers(headers))
}
fn error_response(
&self,
status: u16,
message: &str,
request_origin: Option<&str>,
) -> worker::Result<Response> {
Response::error(message, status).map(|r| r.with_headers(self.cors_headers(request_origin)))
}
}
/// Client-supplied parameters of an `initialize` request.
///
/// On the wire the fields are camelCase (`protocolVersion`, `capabilities`, `clientInfo`).
/// Every field has a serde default, so `{}` or partial params still deserialize.
// NOTE(review): decoded by `handle_initialize` but never read (no version negotiation yet),
// hence the dead_code allowance.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)] struct InitializeParams {
// Protocol version the client asked for; empty string when omitted.
#[serde(default)]
protocol_version: String,
// Capabilities the client advertises; the type's `Default` when omitted.
#[serde(default)]
capabilities: ClientCapabilities,
// Client name/version metadata, if sent.
#[serde(default)]
client_info: Option<Implementation>,
}
#[cfg(test)]
mod tests {
    // Unit tests for the pure URI-template helpers and the JSON-RPC error-code constants.
    use super::*;

    #[test]
    fn test_template_matching_exact() {
        assert!(McpHandler::matches_template(
            "file:///path/to/file",
            "file:///path/to/file"
        ));
        assert!(McpHandler::matches_template("config://app", "config://app"));
    }

    #[test]
    fn test_template_matching_with_params() {
        assert!(McpHandler::matches_template(
            "file:///{name}",
            "file:///test.txt"
        ));
        assert!(McpHandler::matches_template(
            "user://{id}/profile",
            "user://123/profile"
        ));
        assert!(McpHandler::matches_template(
            "data://{type}/{id}",
            "data://users/42"
        ));
    }

    #[test]
    fn test_template_matching_non_matching() {
        assert!(!McpHandler::matches_template(
            "file:///path",
            "file:///other"
        ));
        assert!(!McpHandler::matches_template(
            "file:///{name}/extra",
            "file:///test.txt"
        ));
        assert!(!McpHandler::matches_template(
            "http://example.com",
            "https://example.com"
        ));
    }

    #[test]
    fn test_template_matching_segment_count_mismatch() {
        // A placeholder matches exactly one segment, never zero or several.
        assert!(!McpHandler::matches_template("a/{b}", "a/b/c"));
        assert!(!McpHandler::matches_template("a/{b}/c", "a/b"));
    }

    #[test]
    fn test_template_matching_empty_segments() {
        assert!(!McpHandler::matches_template("file:///{name}", "file:///"));
        assert!(!McpHandler::matches_template("a/{b}/c", "a//c"));
    }

    #[test]
    fn test_template_matching_rejects_path_traversal() {
        assert!(!McpHandler::matches_template(
            "file:///{name}",
            "file:///../etc/passwd"
        ));
        assert!(!McpHandler::matches_template(
            "data://{id}/content",
            "data://../secret/content"
        ));
        assert!(!McpHandler::matches_template(
            "user://{id}",
            "user://../../root"
        ));
    }

    #[test]
    fn test_template_matching_rejects_null_bytes() {
        assert!(!McpHandler::matches_template(
            "file:///{name}",
            "file:///test\0.txt"
        ));
    }

    #[test]
    fn test_template_matching_rejects_percent_encoding() {
        assert!(!McpHandler::matches_template(
            "file:///{name}",
            "file:///%2e%2e%2fetc%2fpasswd"
        ));
        assert!(!McpHandler::matches_template(
            "data://{type}/{id}",
            "data://users/%2e%2e"
        ));
    }

    #[test]
    fn test_extract_template_params_valid() {
        let params = McpHandler::extract_template_params("file:///{name}", "file:///document.txt");
        assert_eq!(params.get("name"), Some(&"document.txt".to_string()));
        let params =
            McpHandler::extract_template_params("user://{id}/profile", "user://123/profile");
        assert_eq!(params.get("id"), Some(&"123".to_string()));
        let params = McpHandler::extract_template_params("data://{type}/{id}", "data://users/42");
        assert_eq!(params.get("type"), Some(&"users".to_string()));
        assert_eq!(params.get("id"), Some(&"42".to_string()));
        let params =
            McpHandler::extract_template_params("file:///{name}", "file:///document-2024.txt");
        assert_eq!(params.get("name"), Some(&"document-2024.txt".to_string()));
    }

    #[test]
    fn test_extract_template_params_rejects_dangerous_content() {
        let params = McpHandler::extract_template_params("file:///{name}", "file:///../etc/passwd");
        assert!(params.is_empty());
        let params = McpHandler::extract_template_params("file:///{name}", "file:///test\0.txt");
        assert!(params.is_empty());
        let params = McpHandler::extract_template_params("file:///{name}", "file:///%2e%2e");
        assert!(params.is_empty());
        let params = McpHandler::extract_template_params("file:///{name}/data", "file:////data");
        assert!(params.is_empty());
    }

    #[test]
    fn test_extract_template_params_non_matching() {
        // A literal segment differs after a placeholder was already captured.
        assert!(
            McpHandler::extract_template_params("user://{id}/profile", "user://123/settings")
                .is_empty()
        );
        // Segment counts differ.
        assert!(
            McpHandler::extract_template_params("data://{type}/{id}", "data://users").is_empty()
        );
        // No placeholders: even an exact match yields an empty map.
        assert!(McpHandler::extract_template_params("config://app", "config://app").is_empty());
    }

    #[test]
    fn test_matching_and_extraction_agree() {
        // For templates with at least one placeholder, a match always yields parameters and
        // a rejection never does.
        let cases = [
            ("file:///{name}", "file:///a.txt"),
            ("file:///{name}", "file:///.."),
            ("data://{type}/{id}", "data://users/42"),
            ("data://{type}/{id}", "data://users/%2e"),
            ("a/{b}/c", "a//c"),
            ("user://{id}/profile", "user://1/other"),
        ];
        for (template, uri) in cases {
            assert_eq!(
                McpHandler::matches_template(template, uri),
                !McpHandler::extract_template_params(template, uri).is_empty(),
                "template={template} uri={uri}"
            );
        }
    }

    #[test]
    fn test_json_rpc_error_codes() {
        assert_eq!(error_codes::PARSE_ERROR, -32700);
        assert_eq!(error_codes::INVALID_REQUEST, -32600);
        assert_eq!(error_codes::METHOD_NOT_FOUND, -32601);
        assert_eq!(error_codes::INVALID_PARAMS, -32602);
        assert_eq!(error_codes::INTERNAL_ERROR, -32603);
    }
}