use std::net::SocketAddr;
use std::sync::Arc;
use axum::Router;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::{self, Next};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use rmcp::transport::streamable_http_server::StreamableHttpService;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use rmcp::transport::streamable_http_server::tower::StreamableHttpServerConfig;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use crate::config::Config;
use crate::context::AdapterContext;
use crate::error::AdapterError;
use crate::server::DeribitMcpServer;
pub async fn serve(
config: Arc<Config>,
ctx: Arc<AdapterContext>,
cancel: CancellationToken,
) -> Result<(), AdapterError> {
let listen: SocketAddr = config.http_listen;
let listener = TcpListener::bind(listen).await.map_err(|err| {
tracing::error!(error = %err, addr = %listen, "failed to bind HTTP listener");
AdapterError::internal("failed to bind HTTP listener")
})?;
serve_with_listener(config, ctx, listener, cancel).await
}
/// Serves the HTTP transport on an already-bound `listener` until `cancel` fires.
///
/// Routes:
/// - `/healthz`: unauthenticated liveness probe.
/// - `/mcp`: the rmcp streamable-HTTP service, wrapped in the [`bearer_auth`]
///   middleware, which uses `config.http_bearer_token`.
///
/// Shutdown is graceful: once `cancel` is cancelled the server stops accepting
/// connections. The same token is also handed to the MCP service. A server error
/// is logged and mapped to an internal `AdapterError`.
pub async fn serve_with_listener(
    config: Arc<Config>,
    ctx: Arc<AdapterContext>,
    listener: TcpListener,
    cancel: CancellationToken,
) -> Result<(), AdapterError> {
    let auth_state = Arc::new(BearerState {
        bearer: config.http_bearer_token.clone(),
    });

    // The auth layer is added after the fallback so that it wraps the MCP service.
    let mcp_router = Router::new()
        .fallback_service(build_streamable_service(ctx, cancel.clone()))
        .layer(middleware::from_fn_with_state(auth_state, bearer_auth));

    let app = Router::new()
        .route("/healthz", get(healthz))
        .nest_service("/mcp", mcp_router);

    if let Ok(addr) = listener.local_addr() {
        tracing::info!(addr = %addr, "HTTP transport listening");
    }

    let shutdown = async move { cancel.cancelled().await };
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown)
        .await
        .map_err(|err| {
            tracing::error!(error = %err, "HTTP server exited with error");
            AdapterError::internal("HTTP server exited with error")
        })?;
    Ok(())
}
/// Builds the rmcp streamable-HTTP service that backs the `/mcp` route.
///
/// Each new MCP session gets a fresh [`DeribitMcpServer`] sharing `ctx`. Sessions
/// are tracked by a [`LocalSessionManager`], and `cancel` is forwarded to the
/// service config.
///
/// NOTE(review): the Host allowlist only names IPv4/hostname loopback forms and
/// `0.0.0.0`. IPv6 loopback and non-loopback bind addresses are not listed;
/// confirm this matches how the adapter is deployed.
fn build_streamable_service(
    ctx: Arc<AdapterContext>,
    cancel: CancellationToken,
) -> StreamableHttpService<DeribitMcpServer, LocalSessionManager> {
    let allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"].map(String::from);
    let server_config = StreamableHttpServerConfig::default()
        .with_cancellation_token(cancel)
        .with_allowed_hosts(allowed_hosts);
    let session_manager = Arc::new(LocalSessionManager::default());
    StreamableHttpService::new(
        move || Ok(DeribitMcpServer::new(ctx.clone())),
        session_manager,
        server_config,
    )
}
/// Liveness probe: always answers `200 OK` with the plain-text body `ok`.
async fn healthz() -> impl IntoResponse {
    let body = "ok";
    (StatusCode::OK, body)
}
/// State for the [`bearer_auth`] middleware: the bearer token that requests must present.
///
/// `None` means no token is configured, and the middleware lets every request through.
///
/// `Debug` is implemented by hand rather than derived, so that formatting this
/// value (logs, panics, `dbg!`) never reveals the secret token.
#[derive(Clone)]
struct BearerState {
    bearer: Option<String>,
}

impl std::fmt::Debug for BearerState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a token is configured, but never its value.
        f.debug_struct("BearerState")
            .field("bearer", &self.bearer.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Axum middleware that enforces bearer-token authentication.
///
/// - If `state.bearer` is `None`, every request is forwarded unchanged.
/// - Otherwise the request is forwarded only when [`is_bearer_match`] accepts its
///   headers.
/// - A rejected request gets `401 Unauthorized`, a `WWW-Authenticate: Bearer`
///   challenge and the body `unauthorized`.
async fn bearer_auth(
    State(state): State<Arc<BearerState>>,
    request: axum::extract::Request,
    next: Next,
) -> Response {
    let authorized = match state.bearer.as_deref() {
        // No token configured: authentication is disabled.
        None => true,
        Some(expected) => is_bearer_match(request.headers(), expected),
    };
    if authorized {
        return next.run(request).await;
    }
    let challenge = [(axum::http::header::WWW_AUTHENTICATE, "Bearer")];
    (StatusCode::UNAUTHORIZED, challenge, "unauthorized").into_response()
}
/// Returns `true` when `headers` carries `Authorization: Bearer <expected>`.
///
/// Matching rules:
/// - The auth-scheme is compared ASCII case-insensitively, as RFC 7235 §2.1
///   requires, so `bearer` and `BEARER` are accepted.
/// - One or more spaces may separate the scheme from the token; the grammar
///   allows `1*SP`.
/// - The token is compared with [`constant_time_eq`].
///
/// Returns `false` in these cases:
/// - the header is absent;
/// - `HeaderValue::to_str` rejects the value;
/// - no token follows the scheme;
/// - the scheme is not `Bearer`;
/// - the token differs from `expected`.
fn is_bearer_match(headers: &HeaderMap, expected: &str) -> bool {
    let Some(value) = headers.get(axum::http::header::AUTHORIZATION) else {
        return false;
    };
    let Ok(text) = value.to_str() else {
        return false;
    };
    // Split "<scheme> <token>"; a bare scheme with nothing after it is rejected.
    let Some((scheme, token)) = text.split_once(' ') else {
        return false;
    };
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return false;
    }
    // Tolerate extra spaces between the scheme and the token.
    let token = token.trim_start_matches(' ');
    constant_time_eq(token.as_bytes(), expected.as_bytes())
}
/// Compares two byte slices for equality without exiting early on the first mismatch.
///
/// Every index up to the longer slice's length is visited. Missing bytes are
/// treated as `0`. The full length difference is folded into the result, so
/// slices of different lengths never compare equal.
///
/// Only the content is compared without early exit; the running time still
/// depends on the longer length, so lengths are not hidden.
///
/// NOTE(review): nothing here prevents the optimizer from introducing
/// data-dependent branches. Consider a vetted constant-time crate if this must
/// resist timing attacks rigorously.
#[inline]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let len = a.len().max(b.len());
    // Keep the length difference at full width. Truncating it to `u8` would make
    // lengths that differ by a multiple of 256 look equal.
    let mut diff: usize = a.len() ^ b.len();
    for i in 0..len {
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }
    diff == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header map containing a single `Authorization` header with `value`.
    fn with_authorization(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::AUTHORIZATION,
            value.parse().expect("test header value must be valid"),
        );
        headers
    }

    #[test]
    fn bearer_match_accepts_correct_token() {
        let headers = with_authorization("Bearer secret");
        assert!(is_bearer_match(&headers, "secret"));
    }

    #[test]
    fn bearer_match_rejects_wrong_token() {
        let headers = with_authorization("Bearer wrong");
        assert!(!is_bearer_match(&headers, "secret"));
    }

    #[test]
    fn bearer_match_rejects_missing_header() {
        assert!(!is_bearer_match(&HeaderMap::new(), "secret"));
    }

    #[test]
    fn bearer_match_rejects_non_bearer_scheme() {
        let headers = with_authorization("Basic secret");
        assert!(!is_bearer_match(&headers, "secret"));
    }

    #[test]
    fn constant_time_eq_basic() {
        let cases: [(&[u8], &[u8], bool); 3] = [
            (b"a", b"a", true),
            (b"a", b"b", false),
            (b"a", b"aa", false),
        ];
        for (left, right, expected) in cases {
            assert_eq!(
                constant_time_eq(left, right),
                expected,
                "constant_time_eq({left:?}, {right:?})"
            );
        }
    }
}