use std::net::SocketAddr;
use std::sync::Arc;
use axum::extract::Request;
use axum::http::{HeaderName, Method, StatusCode};
use axum::middleware::{self, Next};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use rmcp::transport::streamable_http_server::{
session::local::LocalSessionManager, StreamableHttpServerConfig, StreamableHttpService,
};
use tower_http::cors::CorsLayer;
use crate::auth;
use crate::client::OriginClient;
use crate::tools::{OriginMcpServer, TransportMode};
/// Settings for running the origin-mcp HTTP server via `run_serve`.
#[derive(Clone, Debug)]
pub struct ServeConfig {
    /// TCP port the HTTP listener binds to.
    pub port: u16,
    /// Host/interface the HTTP listener binds to.
    pub host: String,
    /// Base URL handed to `OriginClient::new`.
    pub origin_url: String,
    /// Expected bearer token; `None` runs the server without authentication.
    pub token: Option<String>,
    /// Agent name passed to the client and to every MCP server instance.
    pub agent_name: String,
    /// Optional user id passed to every MCP server instance.
    pub user_id: Option<String>,
    /// Origins permitted by CORS and the Origin check; an entry of `"*"` allows any origin for CORS.
    pub allowed_origins: Vec<String>,
}
/// Liveness probe: reports status, server name and the crate version as JSON.
async fn health() -> impl IntoResponse {
    let payload = serde_json::json!({
        "status": "ok",
        "server": "origin-mcp",
        "version": env!("CARGO_PKG_VERSION"),
    });
    axum::Json(payload)
}
pub async fn run_serve(config: ServeConfig) -> anyhow::Result<()> {
let client =
OriginClient::new(config.origin_url.clone()).with_agent_name(config.agent_name.clone());
let agent_name = config.agent_name.clone();
let user_id = config.user_id.clone();
let token = config.token.clone();
let allowed_origins = config.allowed_origins.clone();
let mcp_config = StreamableHttpServerConfig::default().disable_allowed_hosts();
let mcp_service = StreamableHttpService::new(
move || {
Ok(OriginMcpServer::new(
client.clone(),
TransportMode::Http,
agent_name.clone(),
user_id.clone(),
))
},
Arc::new(LocalSessionManager::default()),
mcp_config,
);
let cors = build_cors_layer(&config.allowed_origins);
let mut router = Router::new()
.nest_service("/mcp", mcp_service)
.route("/health", get(health))
.layer(cors);
if let Some(ref expected_token) = token {
let token_for_middleware = expected_token.clone();
let origins_for_middleware = allowed_origins.clone();
router = router.layer(middleware::from_fn(move |req: Request, next: Next| {
let token = token_for_middleware.clone();
let origins = origins_for_middleware.clone();
async move { auth_and_origin_middleware(req, next, &token, &origins).await }
}));
}
let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
let listener = tokio::net::TcpListener::bind(addr).await?;
tracing::info!("origin-mcp HTTP server listening on {}", addr);
if token.is_some() {
tracing::info!("Bearer token authentication enabled");
} else {
tracing::warn!("Running without authentication — only safe on loopback");
}
let shutdown = async {
#[cfg(unix)]
{
use tokio::signal::unix::{signal, SignalKind};
let ctrl_c = tokio::signal::ctrl_c();
let mut sigterm =
signal(SignalKind::terminate()).expect("failed to register SIGTERM handler");
tokio::select! {
_ = ctrl_c => {},
_ = sigterm.recv() => {},
}
}
#[cfg(not(unix))]
{
tokio::signal::ctrl_c().await.ok();
}
tracing::info!("Shutting down origin-mcp HTTP server");
};
axum::serve(listener, router)
.with_graceful_shutdown(shutdown)
.await?;
Ok(())
}
fn build_cors_layer(allowed_origins: &[String]) -> CorsLayer {
let cors = CorsLayer::new()
.allow_methods([Method::GET, Method::POST, Method::DELETE, Method::OPTIONS])
.allow_headers([
http::header::AUTHORIZATION,
http::header::CONTENT_TYPE,
http::header::ACCEPT,
HeaderName::from_static("mcp-session-id"),
HeaderName::from_static("mcp-protocol-version"),
]);
if allowed_origins.iter().any(|o| o == "*") {
cors.allow_origin(tower_http::cors::Any)
} else {
let origins: Vec<http::HeaderValue> = allowed_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
cors.allow_origin(origins)
}
}
/// Bearer-token and Origin guard placed in front of the MCP service.
///
/// `OPTIONS` requests and `/health` pass through untouched. Other requests need:
/// - an `Authorization` header. If it is missing, the response is 401
///   "Authorization header required".
/// - a bearer token accepted by `auth::verify_token` against `expected_token`.
///   Otherwise the response is 401 "Invalid bearer token".
/// - if an `Origin` header is present, an origin accepted by
///   `auth::is_origin_allowed`. Otherwise the response is 403 "Origin not allowed".
///   An `Origin` header that is not valid visible ASCII is rejected too (fail closed).
async fn auth_and_origin_middleware(
    req: Request,
    next: Next,
    expected_token: &str,
    allowed_origins: &[String],
) -> axum::response::Response {
    if req.method() == Method::OPTIONS || req.uri().path() == "/health" {
        return next.run(req).await;
    }

    let Some(auth_value) = req.headers().get(http::header::AUTHORIZATION) else {
        return (StatusCode::UNAUTHORIZED, "Authorization header required").into_response();
    };
    // A header that is not visible ASCII is mapped to "". It is then presumably
    // rejected by `extract_bearer_token`/`verify_token` — confirm in `auth`.
    let auth_str = auth_value.to_str().unwrap_or("");
    match auth::extract_bearer_token(auth_str) {
        Some(provided) if auth::verify_token(provided, expected_token) => {}
        _ => return (StatusCode::UNAUTHORIZED, "Invalid bearer token").into_response(),
    }

    if let Some(origin) = req.headers().get(http::header::ORIGIN) {
        // Previously an undecodable Origin skipped the allow-list check entirely.
        // It can never match an allowed origin, so it is rejected.
        let allowed = origin
            .to_str()
            .map(|o| auth::is_origin_allowed(o, allowed_origins))
            .unwrap_or(false);
        if !allowed {
            return (StatusCode::FORBIDDEN, "Origin not allowed").into_response();
        }
    }

    next.run(req).await
}