use clap::Parser;
use database_mcp_config::{Config, HttpConfig};
use rmcp::transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::error::Error;
use crate::server::Server;
// Command-line options for the HTTP transport. Every option can also be supplied
// through the environment variable named in its `env` key; when neither is given,
// the corresponding `HttpConfig` default is used.
#[derive(Parser, Debug)]
pub struct HttpCommand {
    // `--host` / `HTTP_HOST`, defaulting to `HttpConfig::DEFAULT_HOST`.
    #[arg(long, env = "HTTP_HOST", default_value = HttpConfig::DEFAULT_HOST)]
    pub host: String,

    // `--port` / `HTTP_PORT`, defaulting to `HttpConfig::DEFAULT_PORT`.
    #[arg(long, env = "HTTP_PORT", default_value_t = HttpConfig::DEFAULT_PORT)]
    pub port: u16,

    // `--allowed-origins` / `HTTP_ALLOWED_ORIGINS`; a single value is split on commas.
    // clap derives the flag name `allowed-origins` from the field name.
    #[arg(
        long,
        env = "HTTP_ALLOWED_ORIGINS",
        value_delimiter = ',',
        default_values_t = HttpConfig::default_allowed_origins()
    )]
    pub allowed_origins: Vec<String>,

    // `--allowed-hosts` / `HTTP_ALLOWED_HOSTS`; a single value is split on commas.
    // clap derives the flag name `allowed-hosts` from the field name.
    #[arg(
        long,
        env = "HTTP_ALLOWED_HOSTS",
        value_delimiter = ',',
        default_values_t = HttpConfig::default_allowed_hosts()
    )]
    pub allowed_hosts: Vec<String>,
}
impl HttpCommand {
pub async fn execute(&self, config: &Config, server: Server) -> Result<(), Error> {
let http_config = config
.http
.as_ref()
.ok_or_else(|| Error::Config("HTTP configuration is missing".into()))?;
let bind_addr = format!("{}:{}", http_config.host, http_config.port);
info!("Starting MCP server via HTTP transport on {bind_addr}...");
let ct = CancellationToken::new();
let cors = tower_http::cors::CorsLayer::new()
.allow_origin(
http_config
.allowed_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect::<Vec<axum::http::HeaderValue>>(),
)
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::OPTIONS,
])
.allow_headers([axum::http::header::CONTENT_TYPE, axum::http::header::ACCEPT]);
let service = StreamableHttpService::new(
move || Ok(server.clone()),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default()
.with_stateful_mode(false)
.with_json_response(true)
.with_cancellation_token(ct.child_token()),
);
let router = axum::Router::new().nest_service("/mcp", service).layer(cors);
let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
info!("Listening on http://{bind_addr}/mcp");
let ct_shutdown = ct.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
info!("Ctrl-C received, shutting down...");
ct_shutdown.cancel();
});
axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled().await })
.await?;
Ok(())
}
}