use std::iter::once;
use std::time::Duration;
use axum::error_handling::HandleErrorLayer;
use axum::extract::DefaultBodyLimit;
use axum::Router;
use eyre::{OptionExt, Result};
use http::header::AUTHORIZATION;
use http::HeaderName;
use tokio::net;
use tower::ServiceBuilder;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::compression::{CompressionLayer, DefaultPredicate};
use tower_http::decompression::DecompressionLayer;
use tower_http::propagate_header::PropagateHeaderLayer;
use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer;
use tower_http::trace::TraceLayer;
use crate::contracts::{Application, Service};
use crate::error::Error;
use crate::foundation::ServerTag;
use crate::http::{error, fallback, panic};
use crate::services::service;
pub struct Server;
impl Service for Server {
fn register<A: Application + ?Sized>() -> Self
where
Self: Sized,
{
Self
}
fn boot<A: Application + ?Sized>() -> Result<()>
where
Self: Sized,
{
if service().get::<Server>().is_some() {
A::runtime().block_on(async {
let (main_sever, metrics_server) = tokio::join!(serve::<A>(), metrics::<A>());
if let Some(err) = main_sever.err() {
tracing::error!(server = ServerTag::Main.as_string(), "error: {err}");
}
if let Some(err) = metrics_server.err() {
tracing::error!(server = ServerTag::Metrics.as_string(), "error: {err}");
}
});
}
Ok(())
}
}
async fn serve<A: Application + ?Sized>() -> Result<()> {
let routes = A::with_routing()
.get(&ServerTag::Main)
.ok_or_eyre(Error::Message("routes is empty".to_string()))?
.to_owned();
let mut app = Router::new();
for route in routes {
app = app.merge(route.to_owned());
}
app = app.fallback(fallback).layer(
ServiceBuilder::new()
.layer(DefaultBodyLimit::max(128 * 1024 * 1024))
.layer(
CompressionLayer::new()
.gzip(true)
.compress_when(DefaultPredicate::new()),
)
.layer(PropagateHeaderLayer::new(HeaderName::from_static(
"x-request-id",
)))
.layer(DecompressionLayer::new().gzip(true))
.layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION)))
.layer(TraceLayer::new_for_http())
.layer(HandleErrorLayer::new(error))
.layer(CatchPanicLayer::custom(panic))
.timeout(Duration::from_secs(30)),
);
let listener = net::TcpListener::bind(std::env::var("APP_URL")?).await?;
tracing::info!(
server = ServerTag::Main.as_string(),
"[service] listening on {}",
listener.local_addr()?
);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown(ServerTag::Main))
.await?;
Ok(())
}
async fn metrics<A: Application + ?Sized>() -> Result<()> {
let routes = A::with_routing()
.get(&ServerTag::Metrics)
.ok_or_eyre(Error::Message("routes is empty".to_string()))?
.to_owned();
let mut app = Router::new();
for route in routes {
app = app.merge(route.to_owned());
}
let listener = net::TcpListener::bind(std::env::var("METRICS_URL")?).await?;
tracing::info!(
server = ServerTag::Metrics.as_string(),
"[service] listening on {}",
listener.local_addr()?
);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown(ServerTag::Metrics))
.await?;
Ok(())
}
/// Waits for a shutdown request, then logs that the server identified by `tag` is stopping.
///
/// Resolves on Ctrl+C and, on Unix, also on `SIGTERM`; other platforms watch Ctrl+C only. It is
/// the future handed to `with_graceful_shutdown` by the server functions in this file.
///
/// # Panics
///
/// Panics if the Ctrl+C handler, or on Unix the `SIGTERM` handler, cannot be installed.
async fn shutdown(tag: ServerTag) {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        use tokio::signal::unix::{signal, SignalKind};

        let mut sigterm =
            signal(SignalKind::terminate()).expect("fail to install the terminate signal handler");
        sigterm.recv().await;
    };

    // Without Unix signals there is nothing else to wait on. A future that never completes
    // leaves the `select!` below driven by Ctrl+C alone.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever request arrives first ends the wait.
    tokio::select! {
        () = interrupt => {},
        () = terminate => {},
    }

    tracing::warn!(
        server = tag.as_string(),
        "signal received, starting graceful shutdown"
    );
}