axum-bootstrap 0.1.22

A way to bootstrap a web server with axum, including TLS, logging, monitoring, and more.
Documentation
use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
pub mod error;
pub mod init_log;
#[cfg(feature = "jwt")]
pub mod jwt;
pub mod util;
/// Boxed, thread-safe error type used for hyper connection errors (`handle_hyper_error`)
/// and signal-handler setup failures (`wait_signal`).
type DynError = Box<dyn std::error::Error + Send + Sync>;
use crate::util::{
    io::{self, create_dual_stack_listener},
    tls::{TlsAcceptor, tls_config},
};

use axum::{
    Router,
    extract::Request,
    response::{IntoResponse, Response},
};

use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use log::{info, warn};
use tokio::{
    sync::broadcast::{self, Receiver, Sender, error::RecvError},
    time,
};
use tokio_rustls::rustls::ServerConfig;
use tower::{Service, ServiceExt};
use util::format::SocketAddrFormat;

/// How often `serve_tls` rebuilds the TLS configuration from its `TlsParam` (every 24 hours).
const REFRESH_INTERVAL: Duration = Duration::from_secs(86_400);
/// Upper bound on how long the serve loops wait for in-flight connections to finish after
/// shutdown has been requested.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);

/// An HTTP(S) server built around an axum [`Router`].
///
/// Create one with `new_server`, adjust it with the `with_*` builder methods, and start it with
/// `Server::run`. `I` is the request-interceptor type and defaults to [`DummyInterceptor`].
pub struct Server<I: ReqInterceptor = DummyInterceptor> {
    /// Port handed to `create_dual_stack_listener` when the server starts.
    pub port: u16,
    /// TLS settings; TLS is only used when this is `Some` and its `tls` flag is `true`.
    pub tls_param: Option<TlsParam>,
    /// Application router that receives every request the interceptor lets through.
    router: Router,
    /// Optional hook run before the router for each request.
    pub interceptor: Option<I>,
    /// Timeout handed to each connection's `TimeoutIO` wrapper
    /// (presumably an I/O inactivity limit — TODO confirm in `util::io`).
    pub idle_timeout: Duration,
    /// When `recv()` on this completes, the server stops accepting and begins graceful shutdown.
    shutdown_rx: broadcast::Receiver<()>,
}

/// TLS settings for [`Server`].
#[derive(Debug, Clone)]
pub struct TlsParam {
    /// Enables TLS; when `false` the server serves plain HTTP even if this struct is present.
    pub tls: bool,
    /// Certificate argument passed to `tls_config` (presumably a file path — confirm in `util::tls`).
    pub cert: String,
    /// Private-key argument passed to `tls_config` (presumably a file path — confirm in `util::tls`).
    pub key: String,
}

/// Decision returned by a [`ReqInterceptor`] for one request.
pub enum InterceptResult<T: IntoResponse> {
    /// Send this response immediately; the router is not called.
    Return(Response),
    /// Do not answer: the request is turned into an `io::Error` ("Request dropped by interceptor"),
    /// so hyper receives a service error instead of a response.
    Drop,
    /// Forward this (possibly modified) request to the router.
    Continue(Request<Incoming>),
    /// Convert the error into a response via `IntoResponse` and send it.
    Error(T),
}

/// Hook that inspects every request before it reaches the router.
///
/// When a [`Server`] has an interceptor, `intercept` is awaited for each request with the
/// connection's peer address as `ip`. The returned future must be `Send` because connections
/// are served on Tokio worker tasks.
pub trait ReqInterceptor: Send {
    /// Error type produced by [`InterceptResult::Error`]; it is rendered with `IntoResponse`.
    type Error: IntoResponse + Send + Sync + 'static;
    fn intercept(&self, req: Request<Incoming>, ip: SocketAddr) -> impl std::future::Future<Output = InterceptResult<Self::Error>> + Send;
}

/// No-op interceptor: every request is passed straight to the router.
///
/// This is the default interceptor type of [`Server`]. `Debug` and `Default` are derived so the
/// type behaves like any other public value type (printable, constructible via `Default`).
#[derive(Clone, Debug, Default)]
pub struct DummyInterceptor;

impl ReqInterceptor for DummyInterceptor {
    type Error = error::AppError;

    async fn intercept(&self, req: Request<Incoming>, _ip: SocketAddr) -> InterceptResult<Self::Error> {
        InterceptResult::Continue(req)
    }
}

/// A [`Server`] without a custom interceptor (the type returned by `new_server`).
pub type DefaultServer = Server<DummyInterceptor>;

pub fn new_server(port: u16, router: Router, shutdown_rx: broadcast::Receiver<()>) -> Server {
    Server {
        port,
        tls_param: None, // No TLS by default
        router,
        interceptor: None,
        idle_timeout: Duration::from_secs(120),
        shutdown_rx,
    }
}

impl<I> Server<I>
where
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    /// Returns a server that runs `interceptor` before the router for every request.
    ///
    /// Port, TLS parameters, router, timeout and shutdown receiver are carried over unchanged;
    /// only the interceptor (and therefore its type) is replaced.
    pub fn with_interceptor<R>(self, interceptor: R) -> Server<R>
    where
        R: ReqInterceptor + Clone + Send + Sync + 'static,
    {
        Server::<R> {
            port: self.port,
            tls_param: self.tls_param,
            router: self.router,
            interceptor: Some(interceptor),
            idle_timeout: self.idle_timeout,
            shutdown_rx: self.shutdown_rx,
        }
    }

    /// Sets the TLS parameters.
    ///
    /// TLS is only used when this is `Some` *and* its `tls` flag is `true`; `None` serves plain HTTP.
    pub fn with_tls_param(mut self, tls_param: Option<TlsParam>) -> Self {
        self.tls_param = tls_param;
        self
    }

    /// Sets the timeout handed to each connection's `TimeoutIO` wrapper.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Listens on the configured port and serves until the shutdown receiver fires.
    ///
    /// After shutdown is requested, in-flight connections are drained for up to
    /// `GRACEFUL_SHUTDOWN_TIMEOUT`.
    ///
    /// # Errors
    /// Returns the I/O error if the listener cannot be created or, when TLS is enabled, the
    /// initial TLS configuration cannot be built.
    pub async fn run(self) -> Result<(), std::io::Error> {
        // Destructure so each field can be moved or borrowed independently, without cloning.
        let Server { port, tls_param, router, interceptor, idle_timeout, mut shutdown_rx } = self;

        // TLS only when parameters are present *and* enabled. Keeping the parameters inside the
        // `Option` removes the need for a separate bool plus an `expect` later on.
        let tls_param = tls_param.filter(|param| param.tls);
        log::info!("listening on port {}, use_tls: {}", port, tls_param.is_some());

        let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
        let graceful = hyper_util::server::graceful::GracefulShutdown::new();
        match tls_param {
            Some(tls_param) => serve_tls(&router, server, graceful, port, &tls_param, interceptor, idle_timeout, &mut shutdown_rx).await,
            None => serve_plantext(&router, server, graceful, port, interceptor, idle_timeout, &mut shutdown_rx).await,
        }
    }
}

async fn handle<I>(
    request: Request<Incoming>, client_socket_addr: SocketAddr, app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>>,
    interceptor: Option<I>,
) -> std::result::Result<Response, std::io::Error>
where
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    if let Some(interceptor) = interceptor {
        match interceptor.intercept(request, client_socket_addr).await {
            InterceptResult::Return(res) => Ok(res),
            InterceptResult::Drop => Err(std::io::Error::other("Request dropped by interceptor")),
            InterceptResult::Continue(req) => app
                .oneshot(req)
                .await
                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err)),
            InterceptResult::Error(err) => {
                let res = err.into_response();
                Ok(res)
            }
        }
    } else {
        app.oneshot(request)
            .await
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Interrupted, err))
    }
}

async fn handle_connection<C, I>(
    conn: C, client_socket_addr: std::net::SocketAddr, app: Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>,
    interceptor: Option<I>, graceful: &hyper_util::server::graceful::GracefulShutdown, timeout: Duration,
) where
    C: tokio::io::AsyncRead + tokio::io::AsyncWrite + 'static + Send + Sync,
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    let timeout_io = Box::pin(io::TimeoutIO::new(conn, timeout));
    use hyper::Request;
    use hyper_util::rt::TokioIo;
    let stream = TokioIo::new(timeout_io);
    let mut app = app.into_make_service_with_connect_info::<SocketAddr>();
    let app: axum::middleware::AddExtension<Router, axum::extract::ConnectInfo<SocketAddr>> = unwrap_infallible(app.call(client_socket_addr).await);
    // https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs#L81
    let hyper_service = hyper::service::service_fn(move |request: Request<hyper::body::Incoming>| {
        handle(request, client_socket_addr, app.clone(), interceptor.clone())
    });

    let conn = server.serve_connection_with_upgrades(stream, hyper_service);
    let conn = graceful.watch(conn.into_owned());

    tokio::spawn(async move {
        if let Err(err) = conn.await {
            handle_hyper_error(client_socket_addr, err);
        }
        log::debug!("connection dropped: {client_socket_addr}");
    });
}

fn handle_hyper_error(client_socket_addr: SocketAddr, http_err: DynError) {
    use std::error::Error;
    match http_err.downcast_ref::<hyper::Error>() {
        Some(hyper_err) => {
            let level = if hyper_err.is_user() { log::Level::Warn } else { log::Level::Debug };
            let source = hyper_err.source().unwrap_or(hyper_err);
            log::log!(
                level,
                "[hyper {}]: {:?} from {}",
                if hyper_err.is_user() { "user" } else { "system" },
                source,
                SocketAddrFormat(&client_socket_addr)
            );
        }
        None => match http_err.downcast_ref::<std::io::Error>() {
            Some(io_err) => {
                warn!("[hyper io]: [{}] {} from {}", io_err.kind(), io_err, SocketAddrFormat(&client_socket_addr));
            }
            None => {
                warn!("[hyper]: {} from {}", http_err, SocketAddrFormat(&client_socket_addr));
            }
        },
    }
}

async fn serve_plantext<I>(
    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
    port: u16, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
) -> Result<(), std::io::Error>
where
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    let listener = create_dual_stack_listener(port).await?;
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                info!("start graceful shutdown!");
                drop(listener);
                break;
            }
            conn = listener.accept() => {
                match conn {
                    Ok((conn, client_socket_addr)) => {
                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
                    Err(e) => {
                        warn!("accept error:{e}");
                    }
                }
            }
        }
    }
    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
        Ok(_) => info!("Gracefully shutdown!"),
        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn serve_tls<I>(
    app: &Router, server: hyper_util::server::conn::auto::Builder<TokioExecutor>, graceful: hyper_util::server::graceful::GracefulShutdown,
    port: u16, tls_param: &TlsParam, interceptor: Option<I>, timeout: Duration, shutdown_rx: &mut broadcast::Receiver<()>,
) -> Result<(), std::io::Error>
where
    I: ReqInterceptor + Clone + Send + Sync + 'static,
{
    let (tx, mut rx) = broadcast::channel::<Arc<ServerConfig>>(1);
    let tls_param_clone = tls_param.clone();
    tokio::spawn(async move {
        info!("update tls config every {REFRESH_INTERVAL:?}");
        loop {
            time::sleep(REFRESH_INTERVAL).await;
            if let Ok(new_acceptor) = tls_config(&tls_param_clone.key, &tls_param_clone.cert) {
                info!("update tls config");
                if let Err(e) = tx.send(new_acceptor) {
                    warn!("send tls config error:{e}");
                }
            }
        }
    });
    let mut acceptor: TlsAcceptor = TlsAcceptor::new(tls_config(&tls_param.key, &tls_param.cert)?, create_dual_stack_listener(port).await?);
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                info!("start graceful shutdown!");
                drop(acceptor);
                break;
            }
            message = rx.recv() => {
                match message {
                    Ok(new_config) => {
                        acceptor.replace_config(new_config);
                        info!("replaced tls config");
                    },
                    Err(e) => {
                        match e {
                            RecvError::Closed => {
                                warn!("this channel should not be closed!");
                                break;
                            },
                            RecvError::Lagged(n) => {
                                warn!("lagged {n} messages, this may cause tls config not updated in time");
                            }
                        }
                    }
                }
            }
            conn = acceptor.accept() => {
                match conn {
                    Ok((conn, client_socket_addr)) => {
                        handle_connection(conn,client_socket_addr, app.clone(), server.clone(),interceptor.clone(), &graceful, timeout).await;}
                    Err(e) => {
                        warn!("accept error:{e}");
                    }
                }
            }
        }
    }
    match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, graceful.shutdown()).await {
        Ok(_) => info!("Gracefully shutdown!"),
        Err(_) => info!("Waited {GRACEFUL_SHUTDOWN_TIMEOUT:?} for graceful shutdown, aborting..."),
    }
    Ok(())
}

pub fn generate_shutdown_receiver() -> Receiver<()> {
    let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
    subscribe_shutdown_sender(shutdown_tx);
    shutdown_rx
}

/// Spawns a detached task that waits for a shutdown signal and then sends `()` on `shutdown_tx`.
///
/// If the signal handlers cannot be installed, the task logs the error and panics. The sender is
/// then dropped, and receivers observe a closed channel.
pub fn subscribe_shutdown_sender(shutdown_tx: Sender<()>) {
    tokio::spawn(async move {
        if let Err(e) = wait_signal().await {
            log::error!("wait_signal error: {}", e);
            panic!("wait_signal error: {}", e);
        }
        // A send error only means no receiver is left; there is nobody to notify.
        let _ = shutdown_tx.send(());
    });
}

/// Waits until the process receives SIGTERM or Ctrl-C.
///
/// # Errors
/// Returns an error if either signal handler cannot be installed, instead of treating that
/// failure as a received signal.
#[cfg(unix)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    use tokio::signal::unix::{SignalKind, signal};
    let mut terminate_signal = signal(SignalKind::terminate())?;
    tokio::select! {
        _ = terminate_signal.recv() => {
            info!("receive terminate signal");
        },
        ctrl_c = tokio::signal::ctrl_c() => {
            // `ctrl_c()` resolves immediately with `Err` if its handler can't be registered;
            // propagate that instead of logging a Ctrl-C that never happened.
            ctrl_c?;
            info!("receive ctrl_c signal");
        },
    };
    Ok(())
}

/// Waits until the process receives Ctrl-C.
///
/// # Errors
/// Returns an error if the Ctrl-C handler cannot be installed, instead of treating that failure
/// as a received signal.
#[cfg(windows)]
pub(crate) async fn wait_signal() -> Result<(), DynError> {
    tokio::signal::ctrl_c().await?;
    info!("receive ctrl_c signal");
    Ok(())
}

/// Extracts the value from a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the error closure can never run. The empty `match` proves
/// this to the compiler without introducing a panic path.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}