// sinan 0.1.0
//
// A boilerplate for rapid Axum web service deployment.
// Documentation
use std::iter::once;
use std::time::Duration;

use axum::error_handling::HandleErrorLayer;
use axum::extract::DefaultBodyLimit;
use axum::Router;
use eyre::{OptionExt, Result};
use http::header::AUTHORIZATION;
use http::HeaderName;
use tokio::net;
use tower::ServiceBuilder;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::compression::{CompressionLayer, DefaultPredicate};
use tower_http::decompression::DecompressionLayer;
use tower_http::propagate_header::PropagateHeaderLayer;
use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer;
use tower_http::trace::TraceLayer;

use crate::contracts::{Application, Service};
use crate::error::Error;
use crate::foundation::ServerTag;
use crate::http::{error, fallback, panic};
use crate::services::service;

/// Service that runs the application's HTTP servers (main and metrics) when booted.
#[derive(Debug, Clone, Copy, Default)]
pub struct Server;

impl Service for Server {
    fn register<A: Application + ?Sized>() -> Self
    where
        Self: Sized,
    {
        Self
    }

    fn boot<A: Application + ?Sized>() -> Result<()>
    where
        Self: Sized,
    {
        if service().get::<Server>().is_some() {
            A::runtime().block_on(async {
                let (main_sever, metrics_server) = tokio::join!(serve::<A>(), metrics::<A>());

                if let Some(err) = main_sever.err() {
                    tracing::error!(server = ServerTag::Main.as_string(), "error: {err}");
                }

                if let Some(err) = metrics_server.err() {
                    tracing::error!(server = ServerTag::Metrics.as_string(), "error: {err}");
                }
            });
        }

        Ok(())
    }
}

async fn serve<A: Application + ?Sized>() -> Result<()> {
    let routes = A::with_routing()
        .get(&ServerTag::Main)
        .ok_or_eyre(Error::Message("routes is empty".to_string()))?
        .to_owned();

    let mut app = Router::new();

    for route in routes {
        app = app.merge(route.to_owned());
    }

    app = app.fallback(fallback).layer(
        ServiceBuilder::new()
            .layer(DefaultBodyLimit::max(128 * 1024 * 1024))
            .layer(
                CompressionLayer::new()
                    .gzip(true)
                    .compress_when(DefaultPredicate::new()),
            )
            .layer(PropagateHeaderLayer::new(HeaderName::from_static(
                "x-request-id",
            )))
            .layer(DecompressionLayer::new().gzip(true))
            .layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION)))
            .layer(TraceLayer::new_for_http())
            .layer(HandleErrorLayer::new(error))
            .layer(CatchPanicLayer::custom(panic))
            .timeout(Duration::from_secs(30)),
    );

    let listener = net::TcpListener::bind(std::env::var("APP_URL")?).await?;

    tracing::info!(
        server = ServerTag::Main.as_string(),
        "[service] listening on {}",
        listener.local_addr()?
    );

    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown(ServerTag::Main))
        .await?;

    Ok(())
}

async fn metrics<A: Application + ?Sized>() -> Result<()> {
    let routes = A::with_routing()
        .get(&ServerTag::Metrics)
        .ok_or_eyre(Error::Message("routes is empty".to_string()))?
        .to_owned();

    let mut app = Router::new();

    for route in routes {
        app = app.merge(route.to_owned());
    }

    let listener = net::TcpListener::bind(std::env::var("METRICS_URL")?).await?;

    tracing::info!(
        server = ServerTag::Metrics.as_string(),
        "[service] listening on {}",
        listener.local_addr()?
    );

    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown(ServerTag::Metrics))
        .await?;

    Ok(())
}

/// Completes when the process receives Ctrl+C or, on Unix, the terminate
/// signal, then logs that graceful shutdown is starting for the server
/// identified by `tag`.
///
/// On non-Unix targets only Ctrl+C is observed.
///
/// # Panics
///
/// Panics if the Ctrl+C or terminate signal handler cannot be installed.
async fn shutdown(tag: ServerTag) {
    let interrupt = async {
        let installed = tokio::signal::ctrl_c().await;
        installed.expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        use tokio::signal::unix::{signal, SignalKind};

        let mut sigterm =
            signal(SignalKind::terminate()).expect("fail to install the terminate signal handler");
        sigterm.recv().await;
    };

    // No terminate signal to wait for here: this future never resolves.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever signal arrives first wins.
    tokio::select! {
        _ = interrupt => {},
        _ = terminate => {},
    }

    tracing::warn!(
        server = tag.as_string(),
        "signal received, starting graceful shutdown"
    );
}