use axum::{
Json, Router,
extract::Request,
extract::State,
http::StatusCode,
middleware::{self, Next},
response::IntoResponse,
response::Response,
routing::{get, post},
};
use serde::Deserialize;
use std::sync::Arc;
use crate::adapters::constants::MAX_REQUEST_BYTES;
use crate::server::mcp_dispatcher::dispatch;
use crate::server::mcp_handler::EpistemeMCP;
pub type SharedMCP = Arc<EpistemeMCP>;
/// API keys accepted by the bearer-token auth middleware.
///
/// An empty list disables authentication entirely. Deliberately not `Debug` so keys are
/// not accidentally written to logs.
#[derive(Clone)]
pub struct McpAuthKeys(
    /// The accepted bearer tokens.
    pub Vec<String>,
);
/// Build the HTTP router that exposes the MCP server.
///
/// Routes:
/// - `POST /mcp` — JSON-RPC endpoint
/// - `GET /health` — liveness probe, never requires authentication
/// - `GET /tools` — tool schema listing
/// - `GET /resources` — resource schema listing
/// - `POST /tool` — direct tool invocation
///
/// Every route except `/health` requires an `Authorization: Bearer <key>` header whose key
/// is accepted by `validate_api_key` against `allowed_api_keys`. An empty key list disables
/// authentication. Request bodies are capped at `MAX_REQUEST_BYTES`.
pub fn mcp_http_router(mcp: EpistemeMCP, allowed_api_keys: Vec<String>) -> Router {
    let shared: SharedMCP = Arc::new(mcp);

    // The auth middleware carries its own state, independent of the router state.
    let auth_keys = Arc::new(McpAuthKeys(allowed_api_keys));
    let auth_layer = middleware::from_fn_with_state(auth_keys, mcp_auth_middleware);
    let body_limit = axum::extract::DefaultBodyLimit::max(MAX_REQUEST_BYTES);

    let read_routes = Router::new()
        .route("/health", get(handle_health))
        .route("/tools", get(handle_tools_list))
        .route("/resources", get(handle_resources_list));
    let all_routes = read_routes
        .route("/mcp", post(handle_mcp_post))
        .route("/tool", post(handle_tool_call));

    // Layer order: auth wraps the routes, the body limit wraps everything.
    all_routes
        .with_state(shared)
        .layer(auth_layer)
        .layer(body_limit)
}
/// Bearer-token authentication middleware for the MCP router.
///
/// `/health` is always let through. When `keys` is empty, authentication is disabled and
/// every request is let through. Otherwise the request must carry
/// `Authorization: Bearer <token>` with a non-empty token accepted by
/// `crate::server::mcp_auth::validate_api_key`. If it does not, the middleware answers
/// `401 Unauthorized` with a `WWW-Authenticate: Bearer` challenge and the JSON body
/// `{"error": "unauthorized"}`.
async fn mcp_auth_middleware(
    State(keys): State<Arc<McpAuthKeys>>,
    request: Request,
    next: Next,
) -> Response {
    if request.uri().path() == "/health" || keys.0.is_empty() {
        return next.run(request).await;
    }
    // Evaluate to a bool so the header borrow ends before `request` is moved below.
    let authorized = match extract_bearer_token(request.headers()) {
        Some(token) => crate::server::mcp_auth::validate_api_key(token, &keys.0),
        None => false,
    };
    if authorized {
        next.run(request).await
    } else {
        mcp_unauthorized_response()
    }
}

/// Return the token from an `Authorization: Bearer <token>` header, if present.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1), and whitespace around
/// the token is ignored. Returns `None` in these cases:
/// - the header is missing or not visible ASCII;
/// - the header uses another scheme;
/// - the token is empty.
///
/// Rejecting empty tokens here keeps `"Bearer "` from ever matching an empty configured key.
fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
    let raw = headers
        .get(axum::http::header::AUTHORIZATION)?
        .to_str()
        .ok()?;
    let (scheme, token) = raw.trim().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = token.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}

/// Build the `401 Unauthorized` response, including the `WWW-Authenticate` challenge that
/// RFC 6750 §3 requires on bearer-token failures.
fn mcp_unauthorized_response() -> Response {
    (
        StatusCode::UNAUTHORIZED,
        [(axum::http::header::WWW_AUTHENTICATE, "Bearer")],
        Json(serde_json::json!({"error": "unauthorized"})),
    )
        .into_response()
}
/// Liveness probe that always answers `{"status": "ok"}`.
///
/// The auth middleware lets this path through without credentials.
pub async fn handle_health() -> impl IntoResponse {
    let payload = serde_json::json!({ "status": "ok" });
    Json(payload)
}
/// Serve the tool schemas produced by `mcp_schemas::tool_schemas` as JSON.
pub async fn handle_tools_list() -> impl IntoResponse {
    use crate::server::mcp_schemas::tool_schemas;

    let schemas = tool_schemas();
    Json(schemas)
}
/// Serve the resource schemas produced by `mcp_schemas::resource_schemas` as JSON.
pub async fn handle_resources_list() -> impl IntoResponse {
    use crate::server::mcp_schemas::resource_schemas;

    let schemas = resource_schemas();
    Json(schemas)
}
/// JSON-RPC endpoint (`POST /mcp`).
///
/// Responses:
/// - `400 Bad Request` with a JSON-RPC `-32600` "Invalid Request" error when the payload
///   is not a JSON object (arrays, i.e. batches, are rejected too).
/// - `200 OK` with the dispatcher's JSON-RPC response when `dispatch` returns one.
/// - `204 No Content` with an empty body when `dispatch` returns `None`, presumably for
///   notifications.
///
/// NOTE(review): the MCP Streamable HTTP transport specifies `202 Accepted` for
/// notification-only input. Confirm which status clients expect before changing it.
pub async fn handle_mcp_post(
    State(mcp): State<SharedMCP>,
    Json(request): Json<serde_json::Value>,
) -> impl IntoResponse {
    if !request.is_object() {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "jsonrpc": "2.0",
                "error": { "code": -32600, "message": "Invalid Request: expected JSON object" },
                "id": null,
            })),
        )
            .into_response();
    }
    match dispatch(&mcp, request) {
        Some(response) => (StatusCode::OK, Json(response)).into_response(),
        // A 204 must not carry a body. The previous `Json(Value::Null)` sent
        // `Content-Type: application/json` and a literal `null`.
        None => StatusCode::NO_CONTENT.into_response(),
    }
}
/// JSON body accepted by `POST /tool`.
#[derive(Deserialize)]
pub struct ToolCallRequest {
    /// Name of the tool to invoke.
    pub name: String,

    /// Arguments passed to the tool. When the field is omitted it defaults to
    /// `serde_json::Value::Null`.
    #[serde(default)]
    pub arguments: serde_json::Value,
}
pub async fn handle_tool_call(
State(mcp): State<SharedMCP>,
Json(body): Json<ToolCallRequest>,
) -> impl IntoResponse {
let result = mcp.handle_tool_call(&body.name, &body.arguments);
if let Some(err) = result.get("error") {
let msg = err.as_str().unwrap_or("tool call error");
let status = if msg.contains("Unknown tool") {
StatusCode::NOT_FOUND
} else {
StatusCode::BAD_REQUEST
};
return (status, Json(result));
}
(StatusCode::OK, Json(result))
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    /// Router over an empty knowledge graph, guarded by `keys`.
    fn test_router(keys: Vec<String>) -> Router {
        let no_entities = std::collections::HashMap::new();
        let graph = crate::domain::graph::KnowledgeGraph::from_entities(no_entities);
        mcp_http_router(crate::server::mcp_handler::EpistemeMCP::new(graph), keys)
    }

    /// Send a bodiless GET to `uri` through a router guarded by `keys`, optionally with a
    /// bearer token, and return the response status.
    async fn status_for(keys: &[&str], uri: &str, bearer: Option<&str>) -> StatusCode {
        let app = test_router(keys.iter().map(|key| key.to_string()).collect());
        let mut builder = Request::builder().uri(uri);
        if let Some(token) = bearer {
            builder = builder.header(axum::http::header::AUTHORIZATION, format!("Bearer {token}"));
        }
        let req = builder.body(Body::empty()).unwrap();
        app.oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn health_is_public() {
        assert_eq!(status_for(&["k1"], "/health", None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn tools_requires_bearer_when_keys_present() {
        assert_eq!(
            status_for(&["k1"], "/tools", None).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn tools_allows_valid_bearer() {
        assert_eq!(
            status_for(&["k1"], "/tools", Some("k1")).await,
            StatusCode::OK
        );
    }
}