fuel-web-utils 0.0.29

Fuel library for web utils
use std::{
    net::{Ipv4Addr, SocketAddrV4},
    time::Duration,
};

use axum::{
    extract::{MatchedPath, Request},
    routing::get,
    Router,
};
use tower_http::{
    compression::CompressionLayer,
    cors::{Any, CorsLayer},
    decompression::RequestDecompressionLayer,
    trace::TraceLayer,
};
use tracing::info_span;

use super::{
    http::handlers::{get_health, get_metrics},
    state::StateProvider,
};

/// Version label of the HTTP API.
pub const API_VERSION: &str = "v1";
/// Path prefix that embeds the API version.
///
/// NOTE(review): the version segment is a separate literal that repeats
/// [`API_VERSION`]; the two must be kept in sync by hand. The routes registered
/// in this file are not mounted under this prefix — presumably it is used by
/// routers defined elsewhere; confirm against callers.
pub const API_BASE_PATH: &str = "/api/v1";

/// An HTTP server: an axum [`Router`] paired with the TCP port it is served on.
pub struct Server {
    /// Router that is handed to `axum::serve` when the server runs.
    app: Router,
    /// TCP port the listener binds to when the server runs.
    port: u16,
}

impl Server {
    /// Merges `router` into the server's current router and returns the
    /// updated server, so calls can be chained.
    pub fn with_router(mut self, router: Router) -> Self {
        self.app = self.app.merge(router);
        self
    }

    /// Binds a TCP listener on the unspecified IPv4 address (`0.0.0.0`) at the
    /// configured port and serves the router until `shutdown_signal()`
    /// resolves, at which point axum's graceful shutdown is triggered.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound, if the bound address
    /// cannot be read back from the socket, or if serving fails.
    pub async fn run(self) -> anyhow::Result<()> {
        let addr = std::net::SocketAddr::V4(SocketAddrV4::new(
            Ipv4Addr::UNSPECIFIED,
            self.port,
        ));

        let listener = tokio::net::TcpListener::bind(addr).await?;
        // Reading the bound address is an I/O call; propagate a failure to the
        // caller instead of panicking inside the server task.
        let local_addr = listener.local_addr()?;
        tracing::debug!("listening on {}", local_addr);
        axum::serve(listener, self.app)
            .with_graceful_shutdown(shutdown_signal())
            .await?;
        Ok(())
    }
}

/// Stateless factory for [`Server`] values; see [`ServerBuilder::build`].
pub struct ServerBuilder;

impl ServerBuilder {
    /// Assembles a [`Server`] that listens on `port` and carries a clone of
    /// `state` as its router state.
    ///
    /// The router exposes `GET /health` and `GET /metrics` and is wrapped, from
    /// innermost to outermost, in:
    ///
    /// 1. an HTTP trace layer that opens an `http_request` span per request,
    ///    recording the method and the matched route path,
    /// 2. request-body decompression,
    /// 3. response compression,
    /// 4. a CORS layer allowing any origin, the listed methods and headers,
    ///    with a max age of 3600 seconds.
    pub fn build<S: StateProvider + Clone + 'static>(
        state: &S,
        port: u16,
    ) -> Server {
        use axum::http::{header, Method};

        // Span factory: one span per incoming request.
        let tracing_layer =
            TraceLayer::new_for_http().make_span_with(|request: &Request<_>| {
                let matched_path = request
                    .extensions()
                    .get::<MatchedPath>()
                    .map(MatchedPath::as_str);
                info_span!(
                    "http_request",
                    method = ?request.method(),
                    matched_path,
                    some_other_field = tracing::field::Empty,
                )
            });

        let cors_layer = CorsLayer::new()
            .allow_origin(Any)
            .allow_methods([
                Method::GET,
                Method::POST,
                Method::PUT,
                Method::OPTIONS,
                Method::DELETE,
                Method::PATCH,
                Method::TRACE,
            ])
            .allow_headers([
                header::AUTHORIZATION,
                header::ACCEPT,
                header::CONTENT_TYPE,
            ])
            .max_age(Duration::from_secs(3600));

        // Layers are applied in this order on purpose: each later `.layer`
        // call wraps everything added before it.
        let app = Router::new()
            .route("/health", get(get_health::<S>))
            .route("/metrics", get(get_metrics::<S>))
            .layer(tracing_layer)
            .layer(RequestDecompressionLayer::new())
            .layer(CompressionLayer::new())
            .layer(cors_layer)
            .with_state(state.clone());

        Server { app, port }
    }
}

async fn shutdown_signal() {
    use tokio::signal;
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }
}