use clap::{Args, Parser};
use database_mcp_config::{ConfigError, DatabaseConfig, HttpConfig};
use rmcp::transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info, warn};
use crate::commands::common::{self, DatabaseArguments};
use crate::error::Error;
// Command-line / environment options for the HTTP transport, shown under the
// "HTTP Transport" help heading. Converted into an `HttpConfig` via `TryFrom`.
#[derive(Debug, Args)]
#[command(next_help_heading = "HTTP Transport")]
struct HttpArguments {
    /// Address the HTTP server binds to
    // The explicit `id` differs from the `--host` flag name; presumably this keeps it
    // from clashing with another `host` argument id in the same command (e.g. in
    // `DatabaseArguments`) — NOTE(review): confirm.
    #[arg(
        id = "http-host",
        long = "host",
        env = "HTTP_HOST",
        value_name = "HOST",
        default_value = HttpConfig::DEFAULT_HOST
    )]
    host: String,

    /// Port the HTTP server listens on
    // See the note on `host` about the explicit `id`.
    #[arg(
        id = "http-port",
        long = "port",
        env = "HTTP_PORT",
        value_name = "PORT",
        default_value_t = HttpConfig::DEFAULT_PORT
    )]
    port: u16,

    /// Comma-separated origins allowed to make cross-origin (CORS) requests
    #[arg(
        long = "allowed-origins",
        env = "HTTP_ALLOWED_ORIGINS",
        value_delimiter = ',',
        default_values_t = HttpConfig::default_allowed_origins()
    )]
    allowed_origins: Vec<String>,

    /// Comma-separated list of allowed hosts
    #[arg(
        long = "allowed-hosts",
        env = "HTTP_ALLOWED_HOSTS",
        value_delimiter = ',',
        default_values_t = HttpConfig::default_allowed_hosts()
    )]
    allowed_hosts: Vec<String>,
}
impl TryFrom<&HttpArguments> for HttpConfig {
type Error = Vec<ConfigError>;
fn try_from(http: &HttpArguments) -> Result<Self, Self::Error> {
let config = Self {
host: http.host.clone(),
port: http.port,
allowed_origins: http.allowed_origins.clone(),
allowed_hosts: http.allowed_hosts.clone(),
};
config.validate()?;
Ok(config)
}
}
// `http` subcommand: serves the MCP server over the streamable HTTP transport.
// Plain `//` comments are used here on purpose: `///` on a clap `Parser` struct or on
// flattened fields would change the generated help output.
#[derive(Debug, Parser)]
pub(crate) struct HttpCommand {
    // Database connection options, shared with other commands via `common`.
    #[command(flatten)]
    db_arguments: DatabaseArguments,
    // Bind address, port, and CORS/host allow-lists for the HTTP listener.
    #[command(flatten)]
    http_arguments: HttpArguments,
}
impl HttpCommand {
    /// Runs the MCP server over the streamable HTTP transport, mounted at `/mcp`,
    /// until Ctrl-C (or SIGTERM on Unix) triggers a graceful shutdown.
    ///
    /// # Errors
    ///
    /// Returns an error if the database or HTTP configuration fails validation, if the
    /// listener cannot be bound or its local address read, or if serving fails.
    pub(crate) async fn execute(&self) -> Result<(), Error> {
        let db_config = DatabaseConfig::try_from(&self.db_arguments)?;
        let http_config = HttpConfig::try_from(&self.http_arguments)?;
        // NOTE(review): `http_config.allowed_hosts` is validated above but is not applied
        // anywhere in this file — confirm Host-header checks are enforced elsewhere.
        let server = common::create_server(&db_config);

        // Cancelled once the shutdown signal fires; a child token is handed to the
        // MCP service configuration.
        let cancel_token = CancellationToken::new();
        let service = StreamableHttpService::new(
            move || Ok(server.clone()),
            Arc::new(LocalSessionManager::default()),
            StreamableHttpServerConfig::default()
                .with_stateful_mode(false)
                .with_json_response(true)
                .with_cancellation_token(cancel_token.child_token()),
        );

        let router = axum::Router::new()
            .nest_service("/mcp", service)
            .layer(build_cors_layer(&http_config));

        let bind_addr = format!("{}:{}", http_config.host, http_config.port);
        info!("Starting MCP server via HTTP transport on {bind_addr}...");
        let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
        // Report the address actually bound rather than the requested string: with
        // port 0 the OS picks an ephemeral port, and `SocketAddr`'s `Display` brackets
        // IPv6 hosts so the logged URL stays well-formed.
        let local_addr = listener.local_addr()?;
        info!("Listening on http://{local_addr}/mcp");

        axum::serve(listener, router)
            .with_graceful_shutdown(async move {
                shutdown_signal().await;
                cancel_token.cancel();
            })
            .await?;
        Ok(())
    }
}
fn build_cors_layer(http_config: &HttpConfig) -> CorsLayer {
let origins: Vec<axum::http::HeaderValue> = http_config
.allowed_origins
.iter()
.filter_map(|origin| origin.parse().ok())
.collect();
CorsLayer::new()
.allow_origin(origins)
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::OPTIONS,
])
.allow_headers([axum::http::header::CONTENT_TYPE, axum::http::header::ACCEPT])
}
/// Completes when the process receives Ctrl-C or, on Unix, SIGTERM, and logs which
/// one arrived first. On non-Unix targets only Ctrl-C is watched.
///
/// # Panics
///
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install Ctrl-C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("failed to install SIGTERM handler");
        stream.recv().await;
    };

    // Without Unix signals, this branch simply never completes.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => info!("Ctrl-C received, shutting down..."),
        () = sigterm => info!("SIGTERM received, shutting down..."),
    }
}