use crate::builder_utils::IpFilter;
use crate::events::McpEventHandler;
use crate::protocol::ToolProtocol;
use std::error::Error;
use std::net::SocketAddr;
use std::sync::Arc;
#[cfg(feature = "server")]
use axum::Router;
/// Human-readable explanation placed in the JSON body of `401 Unauthorized`
/// responses (under both the `message` and `error_description` keys) when
/// bearer-token authentication rejects a request.
pub const BEARER_TOKEN_REQUIRED_MESSAGE: &str =
"This MCP endpoint requires a Bearer token in the `Authorization: Bearer <token>` header. \
Either supply a valid token, or set the server's `MENTISDB_BEARER_TOKEN_ACCESS=false` to \
disable bearer-token enforcement for the daemon.";
/// Builds the JSON body returned with every `401 Unauthorized` response.
///
/// The object carries an `error` tag, the [`BEARER_TOKEN_REQUIRED_MESSAGE`]
/// text twice (`message` and the OAuth-style `error_description`), and a
/// `hint` telling the caller how to obtain a token. Keys are inserted in that
/// order so the output is stable even with serde_json's `preserve_order`.
#[cfg(feature = "server")]
fn bearer_token_required_body() -> serde_json::Value {
    let mut body = serde_json::Map::new();
    body.insert("error".to_owned(), serde_json::Value::from("Unauthorized"));
    body.insert(
        "message".to_owned(),
        serde_json::Value::from(BEARER_TOKEN_REQUIRED_MESSAGE),
    );
    body.insert(
        "error_description".to_owned(),
        serde_json::Value::from(BEARER_TOKEN_REQUIRED_MESSAGE),
    );
    body.insert(
        "hint".to_owned(),
        serde_json::Value::from(
            "Send `Authorization: Bearer <token>` with a token issued by `mentisdb bearertoken create --alias <name>`.",
        ),
    );
    serde_json::Value::Object(body)
}
/// Settings used to start an MCP HTTP server.
///
/// When both `bearer_token` and `bearer_authorizer` are `None`, bearer-token
/// authentication is disabled and only `ip_filter` gates requests.
///
/// `Clone` is derived: every field is either `Copy`, an owned clonable value,
/// or an `Arc` (cloning bumps the refcount). `Debug` is implemented by hand so
/// the bearer token itself is never printed.
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Socket address the server binds to.
    pub addr: SocketAddr,
    /// Static bearer token; a request presenting exactly this token is admitted.
    pub bearer_token: Option<String>,
    /// Optional dynamic policy consulted for missing tokens and for tokens that
    /// do not match `bearer_token`.
    pub bearer_authorizer: Option<Arc<dyn BearerTokenAuthorizer>>,
    /// Decides which client IPs may reach the routes; checked before
    /// authentication.
    pub ip_filter: IpFilter,
    /// Optional sink for server and request lifecycle events.
    pub event_handler: Option<Arc<dyn McpEventHandler>>,
}
/// Redacting `Debug`: secrets and trait objects are reported only as
/// `has_*` presence flags, never by value.
impl std::fmt::Debug for HttpServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this
        // impl to be revisited, so nothing sensitive leaks by default.
        let Self {
            addr,
            bearer_token,
            bearer_authorizer,
            ip_filter,
            event_handler,
        } = self;

        let mut out = f.debug_struct("HttpServerConfig");
        out.field("addr", addr);
        out.field("has_bearer_token", &bearer_token.is_some());
        out.field("has_bearer_authorizer", &bearer_authorizer.is_some());
        out.field("ip_filter", ip_filter);
        out.field("has_event_handler", &event_handler.is_some());
        out.finish()
    }
}
/// Request details handed to a [`BearerTokenAuthorizer`] so it can decide
/// whether to admit an HTTP request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BearerAuthContext {
    /// Remote socket address of the client.
    pub client_addr: SocketAddr,
    /// HTTP route path, e.g. `/tools/execute`.
    pub route: String,
    /// Logical action name, e.g. `tools/execute` (the route without its
    /// leading slash).
    pub action: String,
    /// Parsed JSON request body for routes that take one (`/tools/execute`,
    /// `/resources/read`); `None` for body-less routes.
    pub payload: Option<serde_json::Value>,
}
/// Pluggable bearer-token policy consulted by the HTTP routes.
///
/// Instances are stored as `Arc<dyn BearerTokenAuthorizer>` in
/// [`HttpServerConfig`] and used from async request handlers, hence the
/// `Send + Sync` bound.
pub trait BearerTokenAuthorizer: Send + Sync {
    /// Called when a request carries no `Authorization: Bearer` token.
    /// Returning `true` admits the request anyway; the default rejects it.
    fn allow_missing_bearer_token(&self, _context: &BearerAuthContext) -> bool {
        false
    }
    /// Called with the presented token when it does not match the static
    /// `HttpServerConfig::bearer_token` (or no static token is configured).
    /// Return `true` to admit the request.
    fn authorize_bearer_token(&self, token: &str, context: &BearerAuthContext) -> bool;
}
/// A started HTTP server: the address it is bound to plus an opaque,
/// adapter-specific handle that the starter can downcast to control it.
pub struct HttpServerInstance {
    /// Address the server is bound to.
    pub addr: SocketAddr,
    /// Type-erased handle supplied by the adapter (for example a task handle).
    shutdown_handle: Box<dyn std::any::Any + Send + Sync>,
}

impl HttpServerInstance {
    /// Pairs a bound address with its adapter-specific shutdown handle.
    pub fn new(addr: SocketAddr, shutdown_handle: Box<dyn std::any::Any + Send + Sync>) -> Self {
        Self {
            shutdown_handle,
            addr,
        }
    }

    /// Returns the bound address (same value as the public `addr` field).
    pub fn get_addr(&self) -> SocketAddr {
        let Self { addr, .. } = self;
        *addr
    }

    /// Gives mutable access to the type-erased handle, e.g. so the caller can
    /// `downcast_mut` it to the concrete type its adapter stored.
    pub fn shutdown_handle_mut(&mut self) -> &mut Box<dyn std::any::Any + Send + Sync> {
        let Self {
            shutdown_handle, ..
        } = self;
        shutdown_handle
    }
}
/// Abstraction over an HTTP server implementation that exposes a
/// [`ToolProtocol`] over HTTP (see `AxumHttpAdapter` for the axum-based one).
#[async_trait::async_trait]
pub trait HttpServerAdapter: Send + Sync {
    /// Starts serving `protocol` according to `config` and returns a handle
    /// describing the running server.
    ///
    /// # Errors
    /// Implementation-defined. For example, the axum adapter propagates
    /// socket bind and local-address errors.
    async fn start(
        &self,
        config: HttpServerConfig,
        protocol: Arc<dyn ToolProtocol>,
    ) -> Result<HttpServerInstance, Box<dyn Error + Send + Sync>>;
    /// Short identifier of the adapter; defaults to `"unknown"`.
    fn name(&self) -> &str {
        "unknown"
    }
}
/// Builds the axum [`Router`] that exposes `protocol` over HTTP.
///
/// Four `POST` routes are registered: `/tools/list`, `/tools/execute`,
/// `/resources/list` and `/resources/read`. Every route applies the same
/// admission steps, in this order:
///
/// 1. The client IP must pass `config.ip_filter`. Otherwise the response is
///    `403 {"error": "Access denied"}` and a `RequestRejected` event is emitted.
/// 2. Bearer-token authentication (see `check_auth` below). Failures get
///    `401` with [`BEARER_TOKEN_REQUIRED_MESSAGE`].
///
/// Handlers extract `ConnectInfo<SocketAddr>`, so the router must be served
/// via `into_make_service_with_connect_info::<SocketAddr>()` (as
/// `AxumHttpAdapter::start` does).
#[cfg(feature = "server")]
pub fn axum_router(config: &HttpServerConfig, protocol: Arc<dyn ToolProtocol>) -> Router {
    use crate::events::McpEvent;
    use axum::{
        extract::ConnectInfo,
        http::{header::AUTHORIZATION, HeaderMap, StatusCode},
        response::{IntoResponse, Response},
        routing::post,
        Json, Router,
    };
    use serde_json::json;
    use sha2::{Digest, Sha256};
    use subtle::ConstantTimeEq;

    /// Extracts the credentials from an `Authorization: Bearer <token>` header.
    ///
    /// The scheme is matched case-insensitively (RFC 7235 §2.1), so
    /// `bearer <token>` is accepted; whitespace around the token is ignored.
    /// A malformed or non-Bearer header is treated like a missing one.
    fn bearer_from_headers(headers: &HeaderMap) -> Option<&str> {
        let value = headers.get(AUTHORIZATION)?.to_str().ok()?;
        let (scheme, token) = value.split_once(' ')?;
        scheme.eq_ignore_ascii_case("Bearer").then(|| token.trim())
    }

    /// Decides whether a request passes bearer-token authentication.
    ///
    /// * If neither a static token nor an authorizer is configured, every
    ///   request passes.
    /// * Without a bearer token, only the authorizer's
    ///   `allow_missing_bearer_token` can admit the request.
    /// * A token equal to `expected_token` passes.
    /// * Otherwise the authorizer, if any, has the final say.
    fn check_auth(
        expected_token: Option<&str>,
        authorizer: Option<&dyn BearerTokenAuthorizer>,
        headers: &HeaderMap,
        context: &BearerAuthContext,
    ) -> bool {
        if expected_token.is_none() && authorizer.is_none() {
            return true;
        }
        let Some(provided) = bearer_from_headers(headers) else {
            return authorizer.is_some_and(|auth| auth.allow_missing_bearer_token(context));
        };
        if let Some(expected) = expected_token {
            // Compare fixed-length digests: subtle's slice `ct_eq` bails out
            // early on a length mismatch, which would leak the token length.
            let expected_hash = Sha256::digest(expected.as_bytes());
            let provided_hash = Sha256::digest(provided.as_bytes());
            if bool::from(expected_hash.ct_eq(&provided_hash)) {
                return true;
            }
        }
        authorizer.is_some_and(|auth| auth.authorize_bearer_token(provided, context))
    }

    /// Milliseconds elapsed since `start`, saturating instead of truncating
    /// the `u128` returned by `Duration::as_millis`.
    fn elapsed_ms(start: std::time::Instant) -> u64 {
        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// Uniform `501` answer for resource routes when the protocol has none.
    fn resources_not_supported() -> Response {
        (
            StatusCode::NOT_IMPLEMENTED,
            Json(json!({"error": "Resources not supported"})),
        )
            .into_response()
    }

    /// Read-only admission state shared by every route handler.
    struct Gatekeeper {
        bearer_token: Option<String>,
        bearer_authorizer: Option<Arc<dyn BearerTokenAuthorizer>>,
        ip_filter: IpFilter,
        event_handler: Option<Arc<dyn McpEventHandler>>,
    }

    impl Gatekeeper {
        /// Delivers an event to the configured handler, if any.
        ///
        /// The event is built lazily so that no work (string formatting,
        /// cloning tool parameters) happens when no handler is installed.
        async fn emit(&self, make_event: impl FnOnce() -> McpEvent) {
            if let Some(handler) = &self.event_handler {
                handler.on_mcp_event(&make_event()).await;
            }
        }

        /// Runs the IP filter and bearer-token check shared by all routes.
        ///
        /// `action` is the route path without its leading slash. On rejection
        /// this returns `Err` with a ready-to-send response.
        async fn admit(
            &self,
            addr: SocketAddr,
            headers: &HeaderMap,
            action: &str,
            payload: Option<&serde_json::Value>,
        ) -> Result<(), Response> {
            if !self.ip_filter.is_allowed(addr.ip()) {
                self.emit(|| McpEvent::RequestRejected {
                    client_addr: addr.ip().to_string(),
                    reason: "IP not allowed".to_string(),
                })
                .await;
                return Err((
                    StatusCode::FORBIDDEN,
                    Json(json!({"error": "Access denied"})),
                )
                    .into_response());
            }
            // The payload is cloned only once the IP check has passed.
            let context = BearerAuthContext {
                client_addr: addr,
                route: format!("/{action}"),
                action: action.to_string(),
                payload: payload.cloned(),
            };
            if !check_auth(
                self.bearer_token.as_deref(),
                self.bearer_authorizer.as_deref(),
                headers,
                &context,
            ) {
                return Err(
                    (StatusCode::UNAUTHORIZED, Json(bearer_token_required_body())).into_response(),
                );
            }
            Ok(())
        }
    }

    let gate = Arc::new(Gatekeeper {
        bearer_token: config.bearer_token.clone(),
        bearer_authorizer: config.bearer_authorizer.clone(),
        ip_filter: config.ip_filter.clone(),
        event_handler: config.event_handler.clone(),
    });

    let list_tools = {
        let gate = Arc::clone(&gate);
        let protocol = Arc::clone(&protocol);
        move |ConnectInfo(addr): ConnectInfo<SocketAddr>, headers: HeaderMap| {
            let gate = Arc::clone(&gate);
            let proto = Arc::clone(&protocol);
            async move {
                if let Err(rejection) = gate.admit(addr, &headers, "tools/list", None).await {
                    return rejection;
                }
                gate.emit(|| McpEvent::ToolListRequested {
                    client_addr: addr.ip().to_string(),
                })
                .await;
                match proto.list_tools().await {
                    Ok(tools) => {
                        let tool_count = tools.len();
                        gate.emit(|| McpEvent::ToolListReturned {
                            client_addr: addr.ip().to_string(),
                            tool_count,
                        })
                        .await;
                        (StatusCode::OK, Json(json!({"tools": tools}))).into_response()
                    }
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(json!({"error": e.to_string()})),
                    )
                        .into_response(),
                }
            }
        }
    };

    let execute_tool = {
        let gate = Arc::clone(&gate);
        let protocol = Arc::clone(&protocol);
        move |ConnectInfo(addr): ConnectInfo<SocketAddr>,
              headers: HeaderMap,
              Json(payload): Json<serde_json::Value>| {
            let gate = Arc::clone(&gate);
            let proto = Arc::clone(&protocol);
            async move {
                if let Err(rejection) = gate
                    .admit(addr, &headers, "tools/execute", Some(&payload))
                    .await
                {
                    return rejection;
                }
                let tool_name = payload["tool"].as_str().unwrap_or("").to_string();
                let params = payload["parameters"].clone();
                gate.emit(|| McpEvent::ToolCallReceived {
                    client_addr: addr.ip().to_string(),
                    tool_name: tool_name.clone(),
                    parameters: params.clone(),
                })
                .await;
                let exec_start = std::time::Instant::now();
                match proto.execute(&tool_name, params).await {
                    Ok(result) => {
                        let duration_ms = elapsed_ms(exec_start);
                        let success = result.success;
                        let error = result.error.clone();
                        gate.emit(|| McpEvent::ToolCallCompleted {
                            client_addr: addr.ip().to_string(),
                            tool_name: tool_name.clone(),
                            success,
                            error,
                            duration_ms,
                        })
                        .await;
                        (StatusCode::OK, Json(json!({"result": result}))).into_response()
                    }
                    Err(e) => {
                        let duration_ms = elapsed_ms(exec_start);
                        let err_msg = e.to_string();
                        gate.emit(|| McpEvent::ToolError {
                            source: addr.ip().to_string(),
                            tool_name: tool_name.clone(),
                            error: err_msg.clone(),
                            duration_ms,
                        })
                        .await;
                        (StatusCode::BAD_REQUEST, Json(json!({"error": err_msg})))
                            .into_response()
                    }
                }
            }
        }
    };

    let list_resources = {
        let gate = Arc::clone(&gate);
        let protocol = Arc::clone(&protocol);
        move |ConnectInfo(addr): ConnectInfo<SocketAddr>, headers: HeaderMap| {
            let gate = Arc::clone(&gate);
            let proto = Arc::clone(&protocol);
            async move {
                if let Err(rejection) = gate.admit(addr, &headers, "resources/list", None).await {
                    return rejection;
                }
                if !proto.supports_resources() {
                    return resources_not_supported();
                }
                match proto.list_resources().await {
                    Ok(resources) => {
                        (StatusCode::OK, Json(json!({"resources": resources}))).into_response()
                    }
                    Err(e) => (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(json!({"error": e.to_string()})),
                    )
                        .into_response(),
                }
            }
        }
    };

    let read_resource = {
        let gate = Arc::clone(&gate);
        let protocol = Arc::clone(&protocol);
        move |ConnectInfo(addr): ConnectInfo<SocketAddr>,
              headers: HeaderMap,
              Json(payload): Json<serde_json::Value>| {
            let gate = Arc::clone(&gate);
            let proto = Arc::clone(&protocol);
            async move {
                if let Err(rejection) = gate
                    .admit(addr, &headers, "resources/read", Some(&payload))
                    .await
                {
                    return rejection;
                }
                if !proto.supports_resources() {
                    return resources_not_supported();
                }
                let uri = payload["uri"].as_str().unwrap_or("");
                match proto.read_resource(uri).await {
                    Ok(content) => (
                        StatusCode::OK,
                        Json(json!({"uri": uri, "content": content})),
                    )
                        .into_response(),
                    Err(e) => {
                        (StatusCode::NOT_FOUND, Json(json!({"error": e.to_string()})))
                            .into_response()
                    }
                }
            }
        }
    };

    Router::new()
        .route("/tools/list", post(list_tools))
        .route("/tools/execute", post(execute_tool))
        .route("/resources/list", post(list_resources))
        .route("/resources/read", post(read_resource))
}
/// [`HttpServerAdapter`] implementation backed by axum on a tokio runtime.
#[cfg(feature = "server")]
pub struct AxumHttpAdapter;

#[cfg(feature = "server")]
#[async_trait::async_trait]
impl HttpServerAdapter for AxumHttpAdapter {
    /// Builds the router, binds `config.addr`, and emits `ServerStarted` with
    /// the address actually bound. It then serves the router on a spawned tokio
    /// task, whose `JoinHandle` becomes the instance's shutdown handle.
    async fn start(
        &self,
        config: HttpServerConfig,
        protocol: Arc<dyn ToolProtocol>,
    ) -> Result<HttpServerInstance, Box<dyn Error + Send + Sync>> {
        use crate::events::McpEvent;
        use tokio::net::TcpListener;

        let service = axum_router(&config, protocol)
            .into_make_service_with_connect_info::<SocketAddr>();
        let listener = TcpListener::bind(config.addr).await?;
        let bound_addr = listener.local_addr()?;

        if let Some(handler) = config.event_handler.as_ref() {
            let started = McpEvent::ServerStarted {
                addr: bound_addr.to_string(),
            };
            handler.on_mcp_event(&started).await;
        }

        let server_task = tokio::spawn(async move { axum::serve(listener, service).await });
        Ok(HttpServerInstance::new(bound_addr, Box::new(server_task)))
    }

    /// Identifies this adapter as `"axum"`.
    fn name(&self) -> &str {
        "axum"
    }
}