pub mod health;
pub(crate) mod management;
#[cfg(feature = "tls")]
pub(crate) mod tls;
pub mod version;
use crate::application::health::{AlwaysReadyAndAlive, HealthExt};
use crate::application::version::{DefaultVersion, VersionExt};
use crate::configuration::{AppConfig, Empty};
use crate::error::Result;
use crate::management::build_management_router;
use crate::middleware::trace_request;
use axum::middleware::from_fn;
use axum::Router;
use hyper::Server;
use std::fmt::{Debug, Formatter};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::signal;
use tracing::info;
/// Builder for the service: holds the configuration and every piece needed to
/// assemble the final router before it is served.
///
/// Type parameters: `H` is the health indicator, `T` is the type parameter of
/// [`AppConfig`], and `V` is the version provider.
pub struct Application<H = AlwaysReadyAndAlive, T = Empty, V = DefaultVersion> {
/// Shared application configuration.
config: Arc<AppConfig<T>>,
/// Health indicator handed to the management router builder.
health_indicator: H,
/// Version provider handed to the management router builder.
version: V,
/// Optional user-supplied router; when present it is merged into the management router.
router: Option<Router>,
/// Optional callback forwarded to the management router builder.
metrics_callback: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
/// Whether the user router gets wrapped in the `trace_request` middleware.
use_default_trace_layer: bool,
}
impl<H: Debug, T: Debug, V: Debug> Debug for Application<H, T, V> {
    /// Formats every field with its own `Debug` impl, except the metrics
    /// callback: the closure is opaque, so only its presence is reported.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a newly added struct field fails to
        // compile here until it is handled.
        let Self {
            config: cfg,
            health_indicator: health,
            version: ver,
            router: user_router,
            metrics_callback: callback,
            use_default_trace_layer: trace_enabled,
        } = self;

        // Printed through `Debug`, so the marker shows up quoted.
        let callback_state: &str = match callback {
            Some(_) => "Some",
            None => "None",
        };

        let mut out = f.debug_struct("Application");
        out.field("config", cfg);
        out.field("health_indicator", health);
        out.field("version", ver);
        out.field("router", user_router);
        out.field("use_default_trace_layer", trace_enabled);
        out.field("metrics_callback", &callback_state);
        out.finish()
    }
}
// NOTE(review): with the struct's default type parameters, `Application<T>`
// in this header names `Application<T, Empty, DefaultVersion>`, i.e. the
// impl's `T` fills the first (health indicator) slot. Both constructors still
// return `Application<AlwaysReadyAndAlive, T, DefaultVersion>`. Confirm this
// is intended before touching the header, since callers may rely on it.
impl<T> Application<T> {
    /// Creates an application from an owned configuration.
    ///
    /// The configuration is wrapped in an [`Arc`] and passed to
    /// [`Self::new_from_arced`], so both constructors share the same defaults.
    pub fn new(config: AppConfig<T>) -> Application<AlwaysReadyAndAlive, T, DefaultVersion> {
        Self::new_from_arced(Arc::new(config))
    }

    /// Creates an application from an already shared configuration.
    ///
    /// Defaults: [`AlwaysReadyAndAlive`] health indicator, [`DefaultVersion`]
    /// version provider, no user router, no metrics callback, and the default
    /// trace layer enabled.
    pub fn new_from_arced(
        config: Arc<AppConfig<T>>,
    ) -> Application<AlwaysReadyAndAlive, T, DefaultVersion> {
        Application {
            config,
            router: None,
            metrics_callback: None,
            use_default_trace_layer: true,
            health_indicator: AlwaysReadyAndAlive,
            version: DefaultVersion,
        }
    }
}
impl<H, T, V> Application<H, T, V> {
    /// Replaces the health indicator, changing the application's `H` type parameter.
    ///
    /// Every other setting is carried over unchanged; the previous indicator
    /// is dropped.
    #[must_use]
    pub fn health_indicator<Hh: HealthExt>(self, health: Hh) -> Application<Hh, T, V> {
        // Exhaustive destructuring: a newly added struct field fails to
        // compile here until it is forwarded.
        let Self {
            config,
            health_indicator: _,
            router,
            metrics_callback,
            use_default_trace_layer,
            version,
        } = self;

        Application::<Hh, T, V> {
            config,
            health_indicator: health,
            router,
            metrics_callback,
            use_default_trace_layer,
            version,
        }
    }

    /// Replaces the version provider, changing the application's `V` type parameter.
    ///
    /// Every other setting is carried over unchanged; the previous provider
    /// is dropped.
    #[must_use]
    pub fn version<Vv: VersionExt<T>>(self, version: Vv) -> Application<H, T, Vv> {
        let Self {
            config,
            health_indicator,
            router,
            metrics_callback,
            use_default_trace_layer,
            version: _,
        } = self;

        Application::<H, T, Vv> {
            config,
            health_indicator,
            router,
            metrics_callback,
            use_default_trace_layer,
            version,
        }
    }

    /// Sets the user router that gets merged into the management router.
    ///
    /// A previously set router is replaced.
    #[must_use]
    pub fn router(self, router: Router) -> Self {
        Self {
            router: Some(router),
            ..self
        }
    }

    /// Sets the metrics callback that is forwarded to the management router builder.
    ///
    /// A previously set callback is replaced.
    // NOTE(review): presumably invoked when metrics are collected — confirm in
    // `build_management_router`.
    #[must_use]
    pub fn metrics_callback(self, metrics_callback: impl Fn() + Send + Sync + 'static) -> Self {
        Self {
            metrics_callback: Some(Arc::new(metrics_callback)),
            ..self
        }
    }

    /// Chooses whether the user router is wrapped in the `trace_request` middleware.
    ///
    /// The constructors enable it. The layer is applied to the user router
    /// only, before it is merged into the management router.
    #[must_use]
    pub fn use_default_tracing_layer(self, use_default: bool) -> Self {
        Self {
            use_default_trace_layer: use_default,
            ..self
        }
    }

    /// Assembles the final router and serves it over plain HTTP on the
    /// configured `host`/`port`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `run_service`.
    pub async fn serve(self) -> Result<()>
    where
        H: HealthExt,
        V: VersionExt<T>,
        T: Send + Sync + 'static,
    {
        let (router, application_socket) = self.prepare_router();
        run_service(&application_socket, router).await
    }

    /// Assembles the final router and serves it through `tls::run_service`.
    ///
    /// The certificate and key are read from the paths in `config.tls`; both
    /// reads run concurrently and finish before the router is built.
    ///
    /// # Errors
    ///
    /// Returns `Error::CustomError` when the certificate or key path is
    /// missing from the configuration, or when either file cannot be read.
    /// Otherwise propagates any error returned by `tls::run_service`.
    #[cfg(feature = "tls")]
    pub async fn serve_tls(self) -> Result<()>
    where
        H: HealthExt,
        V: VersionExt<T>,
        T: Send + Sync + 'static,
    {
        use crate::error::Error;
        use futures_util::TryFutureExt;
        use std::fmt;
        use tokio::{fs, try_join};

        // Builds the closure that turns any displayable failure into the
        // crate error, tagged with the TLS artifact that failed to load.
        fn cant_load<Arg: fmt::Display>(r#type: &str) -> impl FnOnce(Arg) -> Error + '_ {
            move |error| Error::CustomError(format!("Cant load TLS {type}: `{error}`."))
        }

        let tls_handshake_timeout = self.config.tls.handshake_timeout;

        let tls_cert_path = self
            .config
            .tls
            .cert_path
            .as_deref()
            .ok_or_else(|| cant_load("certificate")("No path present."))?;
        let tls_key_path = self
            .config
            .tls
            .key_path
            .as_deref()
            .ok_or_else(|| cant_load("key")("No path present."))?;

        let (tls_cert, tls_key) = try_join!(
            fs::read(tls_cert_path).map_err(cant_load("certificate")),
            fs::read(tls_key_path).map_err(cant_load("key"))
        )?;

        // `self` is consumed only here, after the borrowed paths are no longer needed.
        let (router, application_socket) = self.prepare_router();

        tls::run_service(
            &application_socket,
            router,
            tls_handshake_timeout,
            tls_cert,
            tls_key,
        )
        .await
    }

    /// Consumes the application and returns the final router together with
    /// the socket address built from the configured `host` and `port`.
    ///
    /// The user router (or an empty default when none was set) is optionally
    /// wrapped in the `trace_request` middleware and then merged into the
    /// router produced by `build_management_router`.
    fn prepare_router(self) -> (Router, SocketAddr)
    where
        H: HealthExt,
        V: VersionExt<T>,
        T: Send + Sync + 'static,
    {
        let app_router = self
            .router
            .map(|router| {
                if self.use_default_trace_layer {
                    // Cloned only when the layer is installed; the middleware
                    // closure needs owned copies it can clone per request.
                    let service_name = self.config.observability_cfg.service_name.clone();
                    let component_name = self.config.observability_cfg.component_name.clone();

                    router.layer(from_fn(move |req, next| {
                        trace_request(req, next, service_name.clone(), component_name.clone())
                    }))
                } else {
                    router
                }
            })
            .unwrap_or_default();

        let router = build_management_router(
            &self.config,
            self.health_indicator,
            self.version,
            self.metrics_callback,
        )
        .merge(app_router);

        let application_socket = SocketAddr::new(self.config.host, self.config.port);

        (router, application_socket)
    }
}
/// Completes once the process receives Ctrl+C or, on unix, SIGTERM.
///
/// On non-unix targets only Ctrl+C is observed. Logs a message before
/// returning.
///
/// # Panics
///
/// Panics if either signal handler cannot be installed.
// `expect` is allowed here: the function returns `()` and so cannot report a
// handler-installation failure to its caller.
#[allow(clippy::expect_used)]
async fn shutdown_signal() {
    let interrupt = async {
        let outcome = signal::ctrl_c().await;
        outcome.expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install SIGTERM signal handler");
        stream.recv().await;
    };

    // No SIGTERM outside unix: a future that never resolves keeps the
    // `select!` below uniform across targets.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {},
        () = sigterm => {},
    }

    info!("Termination signal, starting shutdown...");
}
/// Binds `socket` and serves `router` over plain HTTP until `shutdown_signal`
/// resolves, then shuts down gracefully.
///
/// Connection info is exposed to handlers as the peer's [`SocketAddr`].
///
/// # Errors
///
/// Returns an error when the address cannot be bound or when the server
/// terminates with an error.
async fn run_service(socket: &SocketAddr, router: Router) -> Result<()> {
    let app = router.into_make_service_with_connect_info::<SocketAddr>();

    // `Server::bind` panics when binding fails (e.g. the port is already in
    // use); `try_bind` reports the failure through the returned `Result`.
    let server = Server::try_bind(socket)?.serve(app);

    info!(target: "server", "Started: http://{socket}");

    Ok(server.with_graceful_shutdown(shutdown_signal()).await?)
}