use crate::{
app::AppContext,
config::Config,
compression::build_compression_layer,
cors::build_cors_layer,
http::RouteModule,
middleware::MakeRequestUuid,
ratelimit::build_rate_limit_layer,
request_logging::build_request_logging_layer,
security::build_security_headers_layer,
timeout::build_timeout_layer,
};
use axum::{extract::DefaultBodyLimit, Router};
use std::sync::Arc;
use std::time::Duration;
use tokio::signal;
use tower_http::request_id::{PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::trace::TraceLayer;
#[cfg(feature = "metrics")]
use crate::metrics::{build_metrics_layer, metrics_handler, MetricsCollector};
#[cfg(feature = "jobs")]
use crate::jobs::{JobRegistry, WorkerPool};
/// An HTTP application: a router, its configuration, and the shared [`AppContext`].
///
/// Construct it with [`App::new`], [`App::with_config`], or [`App::builder`], then run it
/// with [`App::serve`].
pub struct App {
    /// Routes whose handlers receive the shared [`AppContext`] as state.
    router: Router<AppContext>,
    /// Settings used to build the middleware stack and to bind the server.
    config: Config,
    /// State handed to `router` when serving.
    context: AppContext,
    /// Routers with no remaining state (`Router<()>`), merged into the app when serving.
    extra_routers: Vec<Router>,
    /// Collector backing the metrics layer; `None` when no collector was created.
    #[cfg(feature = "metrics")]
    metrics_collector: Option<Arc<MetricsCollector>>,
    /// Background job workers, populated by `App::start_workers`.
    #[cfg(feature = "jobs")]
    worker_pool: Option<WorkerPool>,
}
impl App {
pub fn new() -> Self {
Self::with_config(Config::default())
}
pub fn with_config(config: Config) -> Self {
let context = AppContext::new();
let router = Self::build_router(&config);
#[cfg(feature = "metrics")]
let metrics_collector = if config.metrics.enabled {
Some(Arc::new(
MetricsCollector::new().expect("Failed to create metrics collector"),
))
} else {
None
};
Self {
router,
config,
context,
extra_routers: Vec::new(),
#[cfg(feature = "metrics")]
metrics_collector,
#[cfg(feature = "jobs")]
worker_pool: None,
}
}
pub fn builder() -> AppBuilder {
AppBuilder::new()
}
fn build_router(_config: &Config) -> Router<AppContext> {
Router::<AppContext>::new()
}
pub fn register_module<M: RouteModule>(mut self, module: M) -> Self {
let module_router = module.routes();
if let Some(prefix) = module.prefix() {
self.router = self.router.nest(prefix, module_router);
} else {
self.router = self.router.merge(module_router);
}
self
}
pub fn merge_router(mut self, router: Router) -> Self {
self.extra_routers.push(router);
self
}
pub fn with_context(mut self, context: AppContext) -> Self {
self.context = context;
use axum::routing::get;
use crate::health;
let health_routes = Router::<AppContext>::new()
.route("/health", get(health::health_handler));
self.router = self.router.merge(health_routes);
self
}
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<axum::http::Request<axum::body::Body>, Error = std::convert::Infallible> + Clone + Send + Sync + 'static,
<L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Response: axum::response::IntoResponse + 'static,
<L::Service as tower::Service<axum::http::Request<axum::body::Body>>>::Future: Send + 'static,
{
self.router = self.router.layer(layer);
self
}
pub fn into_test_router(self) -> Router {
self.router.with_state(self.context)
}
fn with_middleware(mut self) -> Self {
let mut router = self.router;
#[cfg(feature = "metrics")]
{
if let Some(ref collector) = self.metrics_collector {
let _metrics_router: Router = Router::new()
.route(
self.config.metrics.path.as_str(),
axum::routing::get(metrics_handler),
)
.with_state(collector.clone());
router = router.layer(build_metrics_layer(collector.clone()));
}
}
router = router.layer(DefaultBodyLimit::max(self.config.server.max_body_size));
if let Some(timeout_layer) = build_timeout_layer(&self.config.timeout) {
router = router.layer(timeout_layer);
}
if let Some(security_layer) = build_security_headers_layer(&self.config.security) {
router = router.layer(security_layer);
}
if let Some(compression_layer) = build_compression_layer(&self.config.compression) {
router = router.layer(compression_layer);
}
if let Some(rate_limit_layer) = build_rate_limit_layer(&self.config.rate_limit) {
router = router.layer(rate_limit_layer);
}
if let Some(cors_layer) = build_cors_layer(&self.config.cors) {
router = router.layer(cors_layer);
}
router = router
.layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
.layer(PropagateRequestIdLayer::x_request_id());
router = router.layer(TraceLayer::new_for_http());
if let Some(logging_layer) = build_request_logging_layer(&self.config.request_logging) {
router = router.layer(logging_layer);
}
self.router = router;
self
}
#[cfg(feature = "jobs")]
pub fn start_workers(mut self, registry: Arc<JobRegistry>) -> Self {
if let Some(ref queue) = self.context.jobs {
if self.config.jobs.enabled {
let pool = WorkerPool::new(
queue.clone(),
registry,
Arc::new(self.context.clone()),
self.config.jobs.worker_count,
);
self.worker_pool = Some(pool);
tracing::info!(
worker_count = self.config.jobs.worker_count,
"Background job workers started"
);
}
}
self
}
pub async fn serve(self) -> Result<(), std::io::Error> {
let addr = self
.config
.server
.addr()
.expect("Invalid server address in config");
#[allow(unused_mut)] let mut app = self.with_middleware();
let listener = tokio::net::TcpListener::bind(addr).await?;
tracing::info!("Server starting on http://{}", addr);
tracing::info!("Health check available at http://{}/health", addr);
#[cfg(feature = "jobs")]
let worker_pool = app.worker_pool.take();
let shutdown = async move {
shutdown_signal().await;
#[cfg(feature = "jobs")]
{
if let Some(pool) = worker_pool {
pool.shutdown().await;
}
}
};
let mut final_router = app.router.with_state(app.context);
for extra in app.extra_routers {
final_router = final_router.merge(extra);
}
if let Some(cors_layer) = crate::cors::build_cors_layer(&app.config.cors) {
final_router = final_router.layer(cors_layer);
}
axum::serve(listener, final_router)
.with_graceful_shutdown(shutdown)
.await
}
}
impl Default for App {
fn default() -> Self {
Self::new()
}
}
/// Step-by-step constructor for [`App`]; finish with [`AppBuilder::build`].
#[must_use = "builder does nothing until you call build()"]
pub struct AppBuilder {
    /// Configuration passed to `App::with_config` in `build`.
    config: Config,
    /// Context installed on the built app, replacing the one `App::with_config` creates.
    context: AppContext,
    /// Route-module routers merged into the app's router in `build`.
    modules: Vec<Router<AppContext>>,
    /// Optional metrics collector for the built app.
    #[cfg(feature = "metrics")]
    metrics_collector: Option<Arc<MetricsCollector>>,
}
impl AppBuilder {
pub fn new() -> Self {
Self {
config: Config::default(),
context: AppContext::new(),
modules: Vec::new(),
#[cfg(feature = "metrics")]
metrics_collector: None,
}
}
pub fn with_config(mut self, config: Config) -> Self {
self.config = config;
self
}
pub fn with_context(mut self, context: AppContext) -> Self {
self.context = context;
self
}
pub fn with_cors(mut self, cors: crate::cors::CorsConfig) -> Self {
self.config.cors = cors;
self
}
pub fn register_module<M: RouteModule>(mut self, module: M) -> Self {
self.modules.push(module.routes());
self
}
pub fn build(self) -> App {
let mut app = App::with_config(self.config);
app.context = self.context;
#[cfg(feature = "metrics")]
{
app.metrics_collector = self.metrics_collector;
}
for module_router in self.modules {
app.router = app.router.merge(module_router);
}
app
}
}
impl Default for AppBuilder {
fn default() -> Self {
Self::new()
}
}
/// Resolves once the process receives Ctrl+C, or SIGTERM on Unix.
///
/// After the signal it waits one second before returning. Shutdown is *not* complete when
/// this returns: `App::serve` awaits it inside its graceful-shutdown future, and axum only
/// starts draining connections after that future resolves.
///
/// # Panics
///
/// Panics if the Ctrl+C or SIGTERM handler cannot be installed.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };
    // Non-Unix platforms have no SIGTERM; this branch never fires.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            tracing::info!("Received Ctrl+C signal, starting graceful shutdown");
        },
        _ = terminate => {
            tracing::info!("Received terminate signal, starting graceful shutdown");
        },
    }

    // NOTE(review): presumably a grace period (e.g. for load balancers to deregister)
    // before connections stop being accepted. Confirm the intent.
    tokio::time::sleep(Duration::from_secs(1)).await;
    // The server still has to drain in-flight requests after this returns, so the
    // previous "Shutdown complete" message was premature.
    tracing::info!("Shutdown signal handled, draining connections");
}