cloudiful-server 0.2.0

Rust web server bootstrap crate with Actix and Axum adapters
Documentation
use ::axum::Router;
use axum_server::tls_rustls::RustlsConfig;
use log::info;
use std::{
    future::Future,
    net::{SocketAddr, TcpListener},
    sync::Arc,
    time::Duration,
};
use tower_http::cors::{Any, CorsLayer};

use crate::{CorsConfig, ServerError, ValidatedServerConfig, load_tls_config};

/// A server whose TCP listener is already bound but which is not yet serving requests.
///
/// Produced by `Server::bind`; consumed by `run` or `run_with_graceful_shutdown`.
pub struct BoundServer {
    /// Local addresses reported by the bound listener.
    addrs: Vec<SocketAddr>,
    /// The std listener; handed to `axum_server` for TLS, or converted to a tokio
    /// listener for plain HTTP.
    listener: TcpListener,
    /// The application router, with state resolved and CORS layered on by `Server`.
    app: Router,
    /// rustls configuration when TLS is in use; `None` means plain HTTP.
    tls_config: Option<RustlsConfig>,
}

impl BoundServer {
    pub fn addrs(&self) -> &[SocketAddr] {
        self.addrs.as_slice()
    }

    pub async fn run(self) -> Result<(), ServerError> {
        if let Some(tls_config) = self.tls_config {
            axum_server::from_tcp_rustls(self.listener, tls_config)
                .serve(self.app.into_make_service())
                .await
                .map_err(ServerError::from)
        } else {
            self.listener
                .set_nonblocking(true)
                .map_err(ServerError::from)?;

            let listener =
                tokio::net::TcpListener::from_std(self.listener).map_err(ServerError::from)?;

            ::axum::serve(listener, self.app)
                .await
                .map_err(ServerError::from)
        }
    }

    pub async fn run_with_graceful_shutdown<F>(self, signal: F) -> Result<(), ServerError>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        if let Some(tls_config) = self.tls_config {
            let handle = axum_server::Handle::new();
            let shutdown_handle = handle.clone();
            tokio::spawn(async move {
                signal.await;
                shutdown_handle.graceful_shutdown(None);
            });

            axum_server::from_tcp_rustls(self.listener, tls_config)
                .handle(handle)
                .serve(self.app.into_make_service())
                .await
                .map_err(ServerError::from)
        } else {
            self.listener
                .set_nonblocking(true)
                .map_err(ServerError::from)?;

            let listener =
                tokio::net::TcpListener::from_std(self.listener).map_err(ServerError::from)?;

            ::axum::serve(listener, self.app)
                .with_graceful_shutdown(signal)
                .await
                .map_err(ServerError::from)
        }
    }
}

/// An axum server built from a validated config but not yet bound to a socket.
///
/// `U` is the type of the application state carried by the config. It defaults to `()`
/// for stateless servers created with `Server::new`.
pub struct Server<U = ()> {
    /// Validated configuration; read here for the listen address, TLS, CORS and app data.
    config: ValidatedServerConfig<U>,
    /// The router with state already resolved (hence `Router<()>`) and CORS layered on.
    app: Router,
}

impl<U> Server<U> {
    /// Test-only accessor that consumes the server and hands back its finished router,
    /// so tests can exercise routes without binding a socket.
    #[cfg(test)]
    pub(crate) fn into_router(self) -> Router {
        let Self { app, .. } = self;
        app
    }
}

impl Server<()> {
    /// Creates a stateless server from a validated config and an axum router.
    ///
    /// The CORS policy described by `config.cors()` is layered onto `app` before it is
    /// stored.
    pub fn new(config: ValidatedServerConfig<()>, app: Router) -> Self {
        let cors = build_cors_layer(config.cors());
        Self {
            app: app.layer(cors),
            config,
        }
    }
}

impl<U> Server<U>
where
    U: Clone + Send + Sync + 'static,
{
    /// Creates a server whose router needs shared state of type `U`.
    ///
    /// The state is cloned out of `config.app_data()` and installed with
    /// `Router::with_state`, then the CORS policy from `config.cors()` is layered on.
    ///
    /// # Panics
    /// Panics if `config.app_data()` is `None`. A validated stateful config is expected
    /// to always carry app data.
    pub fn new_with_state(config: ValidatedServerConfig<U>, app: Router<U>) -> Self {
        let state = config
            .app_data()
            .cloned()
            .expect("validated stateful config must contain app_data");
        let app: Router = app.with_state(state);
        let app = app.layer(build_cors_layer(config.cors()));

        Self { config, app }
    }

    /// Loads TLS (if configured) and binds the listen address, without serving yet.
    ///
    /// The TLS configuration is loaded before the socket is bound, so a TLS
    /// configuration error is reported without the port having been taken first.
    ///
    /// # Errors
    /// Returns a [`ServerError`] if loading the TLS config fails, if the listen address
    /// cannot be bound, or if the bound address cannot be read back.
    pub fn bind(self) -> Result<BoundServer, ServerError> {
        let tls_config = load_tls_config(&self.config)?
            .map(|config| RustlsConfig::from_config(Arc::new(config)));

        let listener = TcpListener::bind(self.config.listen_addr()).map_err(ServerError::from)?;
        let addrs = vec![listener.local_addr().map_err(ServerError::from)?];

        // Log the scheme that will actually be served: it depends on whether a TLS
        // config was loaded, which is what `BoundServer::run` branches on.
        let scheme = if tls_config.is_some() { "https" } else { "http" };
        info!("starting {scheme} server on {:?}", addrs);

        Ok(BoundServer {
            addrs,
            listener,
            app: self.app,
            tls_config,
        })
    }

    /// Binds and then serves until the server stops. This is shorthand for
    /// `self.bind()?.run().await`.
    ///
    /// # Errors
    /// Propagates any error from [`Server::bind`] or [`BoundServer::run`].
    pub async fn start(self) -> Result<(), ServerError> {
        self.bind()?.run().await
    }
}

/// Translates a validated [`CorsConfig`] into a tower-http [`CorsLayer`].
///
/// A permissive config maps straight to `CorsLayer::permissive()`. Otherwise only the
/// configured origins and methods are allowed, any request header is accepted, and
/// `max_age` (in seconds) is applied when present.
///
/// # Panics
/// Panics if a method or origin fails to parse. The config is expected to have been
/// validated already, so this indicates a validation bug.
fn build_cors_layer(config: &CorsConfig) -> CorsLayer {
    if !config.is_permissive() {
        // Methods are parsed before origins, mirroring the order validation reports in.
        let allowed_methods: Vec<::axum::http::Method> = config
            .allowed_methods()
            .iter()
            .map(|raw| raw.parse().expect("validated CORS method"))
            .collect();
        let allowed_origins: Vec<::axum::http::HeaderValue> = config
            .allowed_origins()
            .iter()
            .map(|raw| raw.parse().expect("validated CORS origin"))
            .collect();

        let base = CorsLayer::new()
            .allow_origin(allowed_origins)
            .allow_methods(allowed_methods)
            .allow_headers(Any);

        match config.max_age() {
            Some(seconds) => base.max_age(Duration::from_secs(seconds as u64)),
            None => base,
        }
    } else {
        CorsLayer::permissive()
    }
}