use crate::builder_utils::IpFilter;
use crate::events::McpEventHandler;
use crate::protocol::ToolProtocol;
use std::error::Error;
use std::net::SocketAddr;
use std::sync::Arc;
#[cfg(feature = "server")]
use axum::Router;
/// Configuration handed to an [`HttpServerAdapter`] when starting an HTTP server.
///
/// `Clone` is derived: every field is cheap or reference-counted to clone
/// (`event_handler` only bumps an `Arc`).
#[derive(Clone)]
pub struct HttpServerConfig {
    /// Socket address the server binds to.
    pub addr: SocketAddr,
    /// Token expected as `Authorization: Bearer <token>`; `None` disables authentication.
    pub bearer_token: Option<String>,
    /// Filter applied to each request's client IP, before authentication.
    pub ip_filter: IpFilter,
    /// Optional sink for server and request lifecycle events.
    pub event_handler: Option<Arc<dyn McpEventHandler>>,
}
impl std::fmt::Debug for HttpServerConfig {
    /// Formats the config for diagnostics without exposing secrets.
    ///
    /// The bearer token is rendered only as `Some("<redacted>")` or `None`, so logging a
    /// config with `{:?}` cannot leak the credential. The event handler, which is not
    /// `Debug`, is summarised as `has_event_handler`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpServerConfig")
            .field("addr", &self.addr)
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("ip_filter", &self.ip_filter)
            .field("has_event_handler", &self.event_handler.is_some())
            .finish()
    }
}
/// Handle to a running HTTP server, as returned by [`HttpServerAdapter::start`].
#[derive(Debug)]
pub struct HttpServerInstance {
    /// Address the server is listening on. The axum adapter fills this from the bound
    /// listener, so a requested port of `0` shows up as the real port.
    pub addr: SocketAddr,
    /// Adapter-specific, type-erased handle for stopping or awaiting the server.
    /// The axum adapter stores the `tokio::task::JoinHandle` of its serving task here.
    shutdown_handle: Box<dyn std::any::Any + Send + Sync>,
}

impl HttpServerInstance {
    /// Wraps a listening address and an adapter-specific shutdown handle.
    pub fn new(addr: SocketAddr, shutdown_handle: Box<dyn std::any::Any + Send + Sync>) -> Self {
        Self {
            addr,
            shutdown_handle,
        }
    }

    /// Returns the address the server is listening on (same value as the `addr` field).
    pub fn get_addr(&self) -> SocketAddr {
        self.addr
    }

    /// Mutable access to the type-erased shutdown handle.
    ///
    /// Callers recover the concrete type with `downcast_mut`/`downcast_ref`; the concrete
    /// type depends on which adapter started the server.
    pub fn shutdown_handle_mut(&mut self) -> &mut Box<dyn std::any::Any + Send + Sync> {
        &mut self.shutdown_handle
    }
}
#[async_trait::async_trait]
pub trait HttpServerAdapter: Send + Sync {
async fn start(
&self,
config: HttpServerConfig,
protocol: Arc<dyn ToolProtocol>,
) -> Result<HttpServerInstance, Box<dyn Error + Send + Sync>>;
fn name(&self) -> &str {
"unknown"
}
}
/// Builds the axum [`Router`] that exposes `protocol` over HTTP.
///
/// All routes are `POST` and respond with JSON. Errors use the body `{"error": message}`.
/// - `/tools/list` returns `{"tools": [...]}`; a protocol error is `500`.
/// - `/tools/execute` takes `{"tool": name, "parameters": value}` and returns
///   `{"result": ...}`; a protocol error is `400`.
/// - `/resources/list` returns `{"resources": [...]}`; a protocol error is `500`.
/// - `/resources/read` takes `{"uri": uri}` and returns `{"uri": uri, "content": ...}`;
///   a protocol error is `404`.
/// - Both resource routes answer `501` when the protocol does not support resources.
///
/// Every route first checks the client IP against `config.ip_filter`. A rejected IP gets
/// `403` and emits a `McpEvent::RequestRejected` event. The route then checks the bearer
/// token when `config.bearer_token` is set, returning `401` on failure. Lifecycle events
/// go to `config.event_handler` when one is configured.
///
/// Handlers extract `ConnectInfo<SocketAddr>`, so the router must be served via
/// `into_make_service_with_connect_info::<SocketAddr>()`.
#[cfg(feature = "server")]
pub fn axum_router(config: &HttpServerConfig, protocol: Arc<dyn ToolProtocol>) -> Router {
    use crate::events::McpEvent;
    use axum::{
        extract::ConnectInfo,
        http::{HeaderMap, StatusCode},
        response::{IntoResponse, Response},
        routing::post,
        Json,
    };
    use serde_json::{json, Value};
    use sha2::{Digest, Sha256};
    use std::time::Instant;
    use subtle::ConstantTimeEq;

    /// Returns `true` when no token is configured. Otherwise returns `true` only when the
    /// request carries `Authorization: Bearer <token>` with exactly the configured token.
    fn check_auth(expected_token: &Option<String>, headers: &HeaderMap) -> bool {
        let expected = match expected_token.as_deref() {
            None => return true,
            Some(expected) => expected,
        };
        // Reject a missing, non-UTF-8 or non-Bearer header outright. Comparing it as an
        // empty token would wrongly let requests through when the token is `Some("")`.
        let provided = match headers
            .get("Authorization")
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.strip_prefix("Bearer "))
        {
            Some(provided) => provided,
            None => return false,
        };
        // Hashing gives both sides the same length, so `ct_eq` can compare them in
        // constant time.
        let expected_hash = Sha256::digest(expected.as_bytes());
        let provided_hash = Sha256::digest(provided.as_bytes());
        expected_hash.ct_eq(&provided_hash).into()
    }

    /// Builds the `{"error": message}` response used by every route.
    fn error_response(status: StatusCode, message: &str) -> Response {
        (status, Json(json!({ "error": message }))).into_response()
    }

    /// Milliseconds elapsed since `start`, saturating rather than truncating on overflow.
    fn elapsed_ms(start: Instant) -> u64 {
        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// State shared by every route. Each request clones a single `Arc` of it.
    struct RouteContext {
        bearer_token: Option<String>,
        ip_filter: IpFilter,
        event_handler: Option<Arc<dyn McpEventHandler>>,
        protocol: Arc<dyn ToolProtocol>,
    }

    impl RouteContext {
        /// Applies the IP filter, then bearer authentication.
        ///
        /// Returns the response to send when the request must be rejected, or `None` when
        /// it may proceed. IP rejections are reported as `RequestRejected` on every route.
        async fn reject(&self, addr: SocketAddr, headers: &HeaderMap) -> Option<Response> {
            if !self.ip_filter.is_allowed(addr.ip()) {
                if let Some(handler) = &self.event_handler {
                    handler
                        .on_mcp_event(&McpEvent::RequestRejected {
                            client_addr: addr.ip().to_string(),
                            reason: "IP not allowed".to_string(),
                        })
                        .await;
                }
                return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
            }
            if !check_auth(&self.bearer_token, headers) {
                return Some(error_response(StatusCode::UNAUTHORIZED, "Unauthorized"));
            }
            None
        }

        /// `POST /tools/list`.
        async fn handle_list_tools(&self, addr: SocketAddr, headers: HeaderMap) -> Response {
            if let Some(rejection) = self.reject(addr, &headers).await {
                return rejection;
            }
            if let Some(handler) = &self.event_handler {
                handler
                    .on_mcp_event(&McpEvent::ToolListRequested {
                        client_addr: addr.ip().to_string(),
                    })
                    .await;
            }
            match self.protocol.list_tools().await {
                Ok(tools) => {
                    if let Some(handler) = &self.event_handler {
                        handler
                            .on_mcp_event(&McpEvent::ToolListReturned {
                                client_addr: addr.ip().to_string(),
                                tool_count: tools.len(),
                            })
                            .await;
                    }
                    (StatusCode::OK, Json(json!({ "tools": tools }))).into_response()
                }
                Err(e) => error_response(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()),
            }
        }

        /// `POST /tools/execute`.
        ///
        /// NOTE: axum's `Json` extractor parses the body before this runs. A malformed or
        /// non-JSON body is therefore rejected by axum before the IP/auth checks.
        async fn handle_execute(
            &self,
            addr: SocketAddr,
            headers: HeaderMap,
            payload: Value,
        ) -> Response {
            if let Some(rejection) = self.reject(addr, &headers).await {
                return rejection;
            }
            // A missing/non-string "tool" becomes "" and is left to the protocol to reject.
            let tool_name = payload["tool"].as_str().unwrap_or("").to_string();
            let params = payload["parameters"].clone();
            if let Some(handler) = &self.event_handler {
                handler
                    .on_mcp_event(&McpEvent::ToolCallReceived {
                        client_addr: addr.ip().to_string(),
                        tool_name: tool_name.clone(),
                        parameters: params.clone(),
                    })
                    .await;
            }
            let exec_start = Instant::now();
            match self.protocol.execute(&tool_name, params).await {
                Ok(result) => {
                    let duration_ms = elapsed_ms(exec_start);
                    if let Some(handler) = &self.event_handler {
                        handler
                            .on_mcp_event(&McpEvent::ToolCallCompleted {
                                client_addr: addr.ip().to_string(),
                                tool_name: tool_name.clone(),
                                success: result.success,
                                error: result.error.clone(),
                                duration_ms,
                            })
                            .await;
                    }
                    (StatusCode::OK, Json(json!({ "result": result }))).into_response()
                }
                Err(e) => {
                    let duration_ms = elapsed_ms(exec_start);
                    let err_msg = e.to_string();
                    if let Some(handler) = &self.event_handler {
                        handler
                            .on_mcp_event(&McpEvent::ToolError {
                                source: addr.ip().to_string(),
                                tool_name: tool_name.clone(),
                                error: err_msg.clone(),
                                duration_ms,
                            })
                            .await;
                    }
                    error_response(StatusCode::BAD_REQUEST, &err_msg)
                }
            }
        }

        /// `POST /resources/list`.
        async fn handle_list_resources(&self, addr: SocketAddr, headers: HeaderMap) -> Response {
            if let Some(rejection) = self.reject(addr, &headers).await {
                return rejection;
            }
            if !self.protocol.supports_resources() {
                return error_response(StatusCode::NOT_IMPLEMENTED, "Resources not supported");
            }
            match self.protocol.list_resources().await {
                Ok(resources) => {
                    (StatusCode::OK, Json(json!({ "resources": resources }))).into_response()
                }
                Err(e) => error_response(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()),
            }
        }

        /// `POST /resources/read`. A missing/non-string "uri" is passed on as "".
        async fn handle_read_resource(
            &self,
            addr: SocketAddr,
            headers: HeaderMap,
            payload: Value,
        ) -> Response {
            if let Some(rejection) = self.reject(addr, &headers).await {
                return rejection;
            }
            if !self.protocol.supports_resources() {
                return error_response(StatusCode::NOT_IMPLEMENTED, "Resources not supported");
            }
            let uri = payload["uri"].as_str().unwrap_or("");
            match self.protocol.read_resource(uri).await {
                Ok(content) => (
                    StatusCode::OK,
                    Json(json!({ "uri": uri, "content": content })),
                )
                    .into_response(),
                Err(e) => error_response(StatusCode::NOT_FOUND, &e.to_string()),
            }
        }
    }

    let ctx = Arc::new(RouteContext {
        bearer_token: config.bearer_token.clone(),
        ip_filter: config.ip_filter.clone(),
        event_handler: config.event_handler.clone(),
        protocol,
    });

    Router::new()
        .route(
            "/tools/list",
            post({
                let ctx = Arc::clone(&ctx);
                move |ConnectInfo(addr): ConnectInfo<SocketAddr>, headers: HeaderMap| {
                    let ctx = Arc::clone(&ctx);
                    async move { ctx.handle_list_tools(addr, headers).await }
                }
            }),
        )
        .route(
            "/tools/execute",
            post({
                let ctx = Arc::clone(&ctx);
                move |ConnectInfo(addr): ConnectInfo<SocketAddr>,
                      headers: HeaderMap,
                      Json(payload): Json<Value>| {
                    let ctx = Arc::clone(&ctx);
                    async move { ctx.handle_execute(addr, headers, payload).await }
                }
            }),
        )
        .route(
            "/resources/list",
            post({
                let ctx = Arc::clone(&ctx);
                move |ConnectInfo(addr): ConnectInfo<SocketAddr>, headers: HeaderMap| {
                    let ctx = Arc::clone(&ctx);
                    async move { ctx.handle_list_resources(addr, headers).await }
                }
            }),
        )
        .route(
            "/resources/read",
            post({
                let ctx = Arc::clone(&ctx);
                move |ConnectInfo(addr): ConnectInfo<SocketAddr>,
                      headers: HeaderMap,
                      Json(payload): Json<Value>| {
                    let ctx = Arc::clone(&ctx);
                    async move { ctx.handle_read_resource(addr, headers, payload).await }
                }
            }),
        )
}
/// [`HttpServerAdapter`] that serves the protocol with axum on the tokio runtime.
///
/// Stateless: all per-server settings come from the `HttpServerConfig` passed to `start`.
#[cfg(feature = "server")]
#[derive(Debug, Clone, Copy, Default)]
pub struct AxumHttpAdapter;
#[cfg(feature = "server")]
#[async_trait::async_trait]
impl HttpServerAdapter for AxumHttpAdapter {
    /// Binds `config.addr` and reports `ServerStarted` with the bound address through the
    /// event handler, if any. It then spawns the axum server as a tokio task.
    ///
    /// The returned instance carries the bound address, so a requested port of `0` is
    /// reported as the port actually assigned. Its shutdown handle is the
    /// `tokio::task::JoinHandle` of the serving task.
    ///
    /// # Errors
    ///
    /// Fails when the listener cannot be bound or its local address cannot be read.
    async fn start(
        &self,
        config: HttpServerConfig,
        protocol: Arc<dyn ToolProtocol>,
    ) -> Result<HttpServerInstance, Box<dyn Error + Send + Sync>> {
        use crate::events::McpEvent;
        use tokio::net::TcpListener;

        // The routes extract `ConnectInfo<SocketAddr>`, so the service must provide it.
        let make_service = axum_router(&config, protocol)
            .into_make_service_with_connect_info::<SocketAddr>();

        let listener = TcpListener::bind(config.addr).await?;
        let bound_addr = listener.local_addr()?;

        if let Some(handler) = config.event_handler.as_ref() {
            let started = McpEvent::ServerStarted {
                addr: bound_addr.to_string(),
            };
            handler.on_mcp_event(&started).await;
        }

        let serve_task = tokio::spawn(async move { axum::serve(listener, make_service).await });
        Ok(HttpServerInstance::new(bound_addr, Box::new(serve_task)))
    }

    /// Identifies this backend as `"axum"`.
    fn name(&self) -> &str {
        "axum"
    }
}