pub mod api_doc;
pub mod auth_middleware;
pub mod common;
pub mod handlers;
pub mod routes;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use axum_server::tls_rustls::RustlsConfig;
use manta_backend_dispatcher::error::Error;
use std::time::Duration;
use crate::dispatcher::StaticBackendDispatcher;
use crate::server::common::app_context::InfraContext;
use crate::server::common::kafka::Kafka;
/// Per-site connection settings owned by the server.
///
/// One value is stored per site in `ServerState::sites`, and
/// `ServerState::infra_context` lends every field out by reference as an
/// `InfraContext`.
pub struct SiteBackend {
/// Backend dispatcher used for requests targeting this site.
pub backend: StaticBackendDispatcher,
/// NOTE(review): presumably the base URL of the site's Shasta API — confirm.
pub shasta_base_url: String,
/// Raw bytes of the Shasta root certificate.
/// NOTE(review): encoding (PEM vs. DER) is not visible here — confirm.
pub shasta_root_cert: Vec<u8>,
/// Optional SOCKS5 proxy setting; `None` when not configured.
/// NOTE(review): expected format (URL vs. host:port) is not visible here.
pub socks5_proxy: Option<String>,
/// Optional Vault base URL; `None` when not configured.
pub vault_base_url: Option<String>,
/// NOTE(review): presumably the base URL of the site's Gitea instance — confirm.
pub gitea_base_url: String,
/// Optional Kubernetes API URL; `None` when not configured.
pub k8s_api_url: Option<String>,
}
/// Shared, process-wide server configuration.
///
/// `start_server` receives it wrapped in an `Arc` and hands it to
/// `routes::build_router`.
pub struct ServerState {
/// Configured sites, keyed by the site name that `infra_context` looks up.
pub sites: HashMap<String, SiteBackend>,
/// NOTE(review): presumably the idle limit for console sessions; it is not
/// read anywhere in this file — confirm against the console handlers.
pub console_inactivity_timeout: Duration,
/// Optional Kafka handle.
/// NOTE(review): presumably used to publish audit events — confirm.
pub auditor: Option<Kafka>,
/// NOTE(review): presumably the per-minute cap on authentication requests,
/// with `None` disabling the limit — confirm against the auth middleware.
pub auth_rate_limit_per_minute: Option<u32>,
/// NOTE(review): presumably the global per-request timeout applied by the
/// router; it is not read anywhere in this file — confirm.
pub request_timeout: Duration,
/// Drain limit that `start_server` passes to `install_shutdown_handler`,
/// which forwards it to `Handle::graceful_shutdown`.
pub shutdown_grace_period: Duration,
/// NOTE(review): presumably the root directory for migration backups, with
/// `None` meaning not configured — confirm.
pub migrate_backup_root: Option<std::path::PathBuf>,
}
impl ServerState {
pub fn infra_context<'a>(
&'a self,
site_name: &'a str,
) -> Result<InfraContext<'a>, Error> {
let site = self.sites.get(site_name).ok_or_else(|| {
Error::NotFound(format!("site '{site_name}' not found"))
})?;
Ok(InfraContext {
backend: &site.backend,
site_name,
shasta_base_url: &site.shasta_base_url,
shasta_root_cert: &site.shasta_root_cert,
socks5_proxy: site.socks5_proxy.as_deref(),
vault_base_url: site.vault_base_url.as_deref(),
gitea_base_url: &site.gitea_base_url,
k8s_api_url: site.k8s_api_url.as_deref(),
})
}
}
/// Middleware that writes one `info` line per request, containing the
/// method, the URI and the status of the response produced by the rest of
/// the stack.
///
/// The line is emitted only after the inner service has answered.
async fn log_requests(
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    // Both values must be copied out before `request` is moved into `next`.
    let (method, uri) = (request.method().clone(), request.uri().clone());

    let response = next.run(request).await;
    let status = response.status();
    tracing::info!("{} {} → {}", method, uri, status);

    response
}
pub async fn start_server(
state: Arc<ServerState>,
listen_addr: &str,
port: u16,
cert_path: Option<&str>,
key_path: Option<&str>,
) -> Result<(), Error> {
let shutdown_grace_period = state.shutdown_grace_period;
let app =
routes::build_router(state).layer(axum::middleware::from_fn(log_requests));
let addr: SocketAddr = format!("{listen_addr}:{port}")
.parse()
.map_err(|e| Error::BadRequest(format!("Invalid listen address: {e}")))?;
match (cert_path, key_path) {
(Some(cert), Some(key)) => {
let tls_config = RustlsConfig::from_pem_file(cert, key).await?;
let handle = axum_server::Handle::new();
let ready_handle = handle.clone();
tokio::spawn(async move {
ready_handle.listening().await;
tracing::info!(
"HTTPS server ready, accepting requests on https://{}",
addr
);
eprintln!("HTTPS server ready, accepting requests on https://{addr}");
});
install_shutdown_handler(handle.clone(), shutdown_grace_period);
axum_server::bind_rustls(addr, tls_config)
.handle(handle)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
}
(None, None) => {
let handle = axum_server::Handle::new();
let ready_handle = handle.clone();
tokio::spawn(async move {
ready_handle.listening().await;
tracing::info!(
"HTTP server ready, accepting requests on http://{}",
addr
);
eprintln!("HTTP server ready, accepting requests on http://{addr}");
});
install_shutdown_handler(handle.clone(), shutdown_grace_period);
axum_server::bind(addr)
.handle(handle)
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
}
_ => {
return Err(Error::BadRequest(
"--cert and --key must be provided together".to_string(),
));
}
}
Ok(())
}
/// Spawns a background task that starts a graceful shutdown of the server
/// behind `handle` on SIGTERM or Ctrl+C.
///
/// `grace_period` is passed to `Handle::graceful_shutdown` as the drain
/// limit. If the SIGTERM listener cannot be installed, a warning is logged
/// and only Ctrl+C is awaited.
fn install_shutdown_handler(
    handle: axum_server::Handle<SocketAddr>,
    grace_period: Duration,
) {
    use tokio::signal::ctrl_c;
    use tokio::signal::unix::{SignalKind, signal};

    tokio::spawn(async move {
        match signal(SignalKind::terminate()) {
            Ok(mut term_stream) => {
                let grace_secs = grace_period.as_secs();
                // Whichever signal arrives first wins; the other is dropped.
                tokio::select! {
                    _ = term_stream.recv() => {
                        tracing::info!("SIGTERM received; draining for up to {grace_secs}s");
                    }
                    _ = ctrl_c() => {
                        tracing::info!("Ctrl+C received; draining for up to {grace_secs}s");
                    }
                }
            }
            Err(e) => {
                tracing::warn!(
                    "failed to install SIGTERM handler; falling back to Ctrl+C only: {e}"
                );
                let _ = ctrl_c().await;
            }
        }
        // Both paths end here: start draining.
        handle.graceful_shutdown(Some(grace_period));
    });
}
/// Checks how `TimeoutLayer` behaves when layered over a router.
#[cfg(test)]
mod timeout_layer_tests {
    use std::time::Duration;

    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        routing::get,
    };
    use tower::ServiceExt as _;
    use tower_http::timeout::TimeoutLayer;

    /// Builds a GET request for `uri` with an empty body.
    fn get_req(uri: &str) -> Request<Body> {
        let builder = Request::builder().method("GET").uri(uri);
        builder.body(Body::empty()).unwrap()
    }

    /// Handler body: waits for `delay`, then answers `"ok"`.
    async fn sleep_handler(delay: Duration) -> &'static str {
        tokio::time::sleep(delay).await;
        "ok"
    }

    /// Sends one GET to `path` and returns the response status.
    ///
    /// The route's handler sleeps for `handler_delay`, and the router is
    /// wrapped in a `TimeoutLayer` that answers 408 after `limit`.
    async fn status_with_timeout(
        path: &str,
        handler_delay: Duration,
        limit: Duration,
    ) -> StatusCode {
        let timeout =
            TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, limit);
        let router = Router::new()
            .route(path, get(move || sleep_handler(handler_delay)))
            .layer(timeout);
        router.oneshot(get_req(path)).await.unwrap().status()
    }

    #[tokio::test]
    async fn global_timeout_returns_408_when_handler_exceeds_limit() {
        let status = status_with_timeout(
            "/slow",
            Duration::from_millis(400),
            Duration::from_millis(50),
        )
        .await;
        assert_eq!(status, StatusCode::REQUEST_TIMEOUT);
    }

    #[tokio::test]
    async fn fast_handler_finishes_before_timeout_fires() {
        let status = status_with_timeout(
            "/fast",
            Duration::from_millis(10),
            Duration::from_secs(5),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }
}