use std::sync::Arc;
use axum::Router;
use axum::extract::{Request, State};
use axum::http::{StatusCode, header::AUTHORIZATION};
use axum::middleware::{self, Next};
use axum::response::{IntoResponse, Response};
use rmcp::transport::streamable_http_server::StreamableHttpService;
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
use crate::RsigmaMcp;
/// Builds the axum [`Router`] that exposes the MCP streamable-HTTP service under `/mcp`.
///
/// The service factory hands out clones of `handler`. Sessions go through a
/// [`LocalSessionManager`], and the service config is left at its default value.
///
/// When `auth_token` is `Some`, the entire router is wrapped in the `bearer_auth`
/// middleware, so every request must present that token. When it is `None`, no
/// authentication layer is added.
pub fn http_router(handler: RsigmaMcp, auth_token: Option<String>) -> Router {
    let session_manager = Arc::new(LocalSessionManager::default());
    let mcp_service = StreamableHttpService::new(
        move || Ok(handler.clone()),
        session_manager,
        Default::default(),
    );
    let router = Router::new().nest_service("/mcp", mcp_service);
    match auth_token {
        Some(token) => {
            router.layer(middleware::from_fn_with_state(Arc::new(token), bearer_auth))
        }
        None => router,
    }
}
pub async fn serve_http(
handler: RsigmaMcp,
listener: tokio::net::TcpListener,
auth_token: Option<String>,
) -> anyhow::Result<()> {
let router = http_router(handler, auth_token);
axum::serve(listener, router).await?;
Ok(())
}
async fn bearer_auth(
State(expected): State<Arc<String>>,
request: Request,
next: Next,
) -> Response {
let authorized = request
.headers()
.get(AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
.map(|t| constant_time_eq(t.as_bytes(), expected.as_bytes()))
.unwrap_or(false);
if authorized {
next.run(request).await
} else {
(StatusCode::UNAUTHORIZED, "missing or invalid bearer token").into_response()
}
}
/// Compares two byte slices for equality without stopping at the first differing byte.
///
/// Slices of different lengths are rejected up front, so a length mismatch returns early.
/// For equal-length inputs, every byte pair is XOR-ed and OR-ed into a single accumulator
/// over the full length. The result is `true` only if that accumulator stays zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let acc = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (&x, &y)| acc | (x ^ y));
        acc == 0
    } else {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::constant_time_eq;

    /// Table of (left, right, expected) cases covering equality, a same-length mismatch,
    /// a length mismatch, and an empty input.
    #[test]
    fn constant_time_eq_matches() {
        let cases: [(&[u8], &[u8], bool); 4] = [
            (b"secret", b"secret", true),
            (b"secret", b"secrey", false),
            (b"secret", b"secretx", false),
            (b"", b"x", false),
        ];
        for (left, right, expected) in cases {
            assert_eq!(
                constant_time_eq(left, right),
                expected,
                "{left:?} vs {right:?}"
            );
        }
    }
}