rivet-adapter-axum 0.1.0

Rivet framework crates and adapters.
Documentation
use std::net::{SocketAddr, ToSocketAddrs};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use axum::extract::State;
use axum::routing::any;
use axum::Router;
use rivet_core::{set_current, Application, Dispatcher};
use rivet_http::Response;
use tokio::sync::Semaphore;
use tokio::task::JoinError;
use tokio::time::timeout;

use crate::error::AdapterError;
use crate::mapping::{from_axum_request, to_axum_response, MappingError};

/// Tuning knobs for the axum adapter.
#[derive(Debug, Clone)]
pub struct AdapterOptions {
    /// Maximum request body size handed to the request mapper; larger bodies are answered
    /// with `413`. Presumably measured in bytes (the default is `2 << 20`, i.e. 2 MiB).
    pub body_limit: usize,
    /// How long a request may wait for its dispatch before the adapter answers `504`.
    pub timeout: Duration,
    /// Number of semaphore permits bounding how many requests dispatch into the
    /// blocking pool at once.
    pub max_blocking_dispatches: usize,
}

impl Default for AdapterOptions {
    /// 2 MiB body limit, 30 second timeout, 256 concurrent blocking dispatches.
    fn default() -> Self {
        const TWO_MIB: usize = 2 << 20;
        const DISPATCH_SLOTS: usize = 256;

        let timeout = Duration::from_millis(30_000);
        Self {
            body_limit: TWO_MIB,
            timeout,
            max_blocking_dispatches: DISPATCH_SLOTS,
        }
    }
}

/// Router state handed to every `dispatch_handler` call through axum's `State` extractor
/// (which clones it per request, hence `Clone`; all fields are cheap to clone).
#[derive(Clone)]
struct AdapterState {
    /// Dispatcher that produces the response; invoked inside `spawn_blocking`.
    app: Arc<dyn Dispatcher>,
    /// Body limit, timeout and dispatch-concurrency settings.
    options: AdapterOptions,
    /// Limits concurrent blocking dispatches; sized from `options.max_blocking_dispatches`.
    semaphore: Arc<Semaphore>,
}

pub fn into_router(app: Arc<dyn Dispatcher>) -> Router {
    into_router_with_options(app, AdapterOptions::default())
}

/// Builds a catch-all axum [`Router`] that forwards every path and method to `app`.
///
/// `app` is also passed to `set_current`; that call's result is deliberately ignored.
/// NOTE(review): presumably it fails when a current dispatcher is already registered — confirm.
///
/// `options.max_blocking_dispatches` is clamped to `1..=Semaphore::MAX_PERMITS`.
/// `Semaphore::new` panics above `MAX_PERMITS`, and a semaphore with zero permits would
/// leave every request waiting forever for a dispatch slot.
pub fn into_router_with_options(app: Arc<dyn Dispatcher>, options: AdapterOptions) -> Router {
    let _ = set_current(Arc::clone(&app));

    let permits = options
        .max_blocking_dispatches
        .clamp(1, Semaphore::MAX_PERMITS);

    let state = AdapterState {
        app,
        semaphore: Arc::new(Semaphore::new(permits)),
        options,
    };

    // The root is registered separately from the wildcard route; the catch-all
    // segment is expected not to match the bare `/`.
    Router::new()
        .route("/", any(dispatch_handler))
        .route("/{*path}", any(dispatch_handler))
        .with_state(state)
}

/// Binds the address configured under `app.server.url` and serves `app` with axum.
///
/// # Errors
/// Returns an [`AdapterError`] when the configured address is missing or unresolvable,
/// when binding the listener fails, or when the axum server itself fails.
pub async fn serve(app: Application) -> Result<(), AdapterError> {
    let addr = serve_addr(&app)?;
    let listener = tokio::net::TcpListener::bind(addr).await?;

    // Announce only once the socket is really bound, and report the bound address:
    // with a configured port of 0 the OS picks the port, and `addr` would still say `:0`.
    app.log_serving(listener.local_addr()?);

    let app = Arc::new(app);
    axum::serve(listener, into_router(app)).await?;
    Ok(())
}

/// Resolves the socket address to bind from the `app.server.url` config value.
///
/// The URL's scheme, path, query and fragment are discarded (via `url_authority`), and
/// the remaining authority is parsed or looked up (via `resolve_addr`).
///
/// # Errors
/// An `InvalidInput` I/O error (converted into [`AdapterError`]) when the key is absent
/// or its value does not resolve to any socket address.
fn serve_addr(app: &Application) -> Result<SocketAddr, AdapterError> {
    let Some(server_url) = app.config_typed::<String>("app.server.url") else {
        let missing = std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "missing 'app.server.url' config",
        );
        return Err(missing.into());
    };

    match resolve_addr(url_authority(&server_url)) {
        Some(addr) => Ok(addr),
        None => {
            let invalid = std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("invalid 'app.server.url' value '{server_url}'"),
            );
            Err(invalid.into())
        }
    }
}

/// Turns `host:port` text into a socket address.
///
/// A literal `ip:port` (including bracketed IPv6) is parsed directly. Anything else goes
/// through `ToSocketAddrs`, which may perform a blocking name lookup; the first result wins.
/// Returns `None` when neither step yields an address.
fn resolve_addr(value: &str) -> Option<SocketAddr> {
    match SocketAddr::from_str(value) {
        Ok(literal) => Some(literal),
        Err(_) => value
            .to_socket_addrs()
            .ok()
            .and_then(|mut resolved| resolved.next()),
    }
}

/// Extracts the `host[:port]` part of a URL-like string.
///
/// Accepts full URLs (`http://host:port/path?q#f`) as well as bare authorities
/// (`host:port`). The scheme, path, query and fragment are dropped. An RFC 3986 userinfo
/// prefix (`user:pass@`) is dropped too, because it is never part of a bindable address.
fn url_authority(url: &str) -> &str {
    let without_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
    let authority = without_scheme
        .split_once(['/', '?', '#'])
        .map_or(without_scheme, |(authority, _)| authority);
    // Neither userinfo nor host may contain an unescaped '@', so everything after the
    // last '@' is the host[:port].
    authority.rsplit_once('@').map_or(authority, |(_, host)| host)
}

/// Catch-all axum handler.
///
/// Maps the axum request, runs the dispatcher on tokio's blocking pool, and maps the
/// response back.
///
/// Status mapping:
/// * unsupported HTTP method → `Response::not_implemented()`
/// * body over `options.body_limit` → `413`
/// * closed semaphore → `Response::internal_error()`
/// * failed join of the blocking task → `map_join_error`
/// * no response within `options.timeout` → `504`. The budget includes time spent queued
///   for a dispatch permit, so a saturated server answers instead of hanging.
///
/// The semaphore permit is owned by the blocking closure, so it is released only when the
/// dispatch really finishes. `spawn_blocking` work cannot be aborted. If the permit were
/// released when this future ended (timeout or client disconnect), the number of running
/// dispatches could exceed `options.max_blocking_dispatches`.
async fn dispatch_handler(
    State(state): State<AdapterState>,
    req: axum::http::Request<axum::body::Body>,
) -> axum::response::Response {
    let request = match from_axum_request(req, state.options.body_limit).await {
        Ok(req) => req,
        Err(MappingError::UnsupportedMethod(_)) => {
            return to_axum_response(Response::not_implemented())
        }
        Err(MappingError::BodyLimit) => return to_axum_response(Response::new(413)),
    };

    let semaphore = Arc::clone(&state.semaphore);
    let app = Arc::clone(&state.app);
    let run = async move {
        let Ok(permit) = semaphore.acquire_owned().await else {
            return Response::internal_error();
        };
        let dispatch = tokio::task::spawn_blocking(move || {
            // Keep the permit alive for exactly as long as the dispatch runs.
            let _permit = permit;
            app.dispatch(request)
        });
        match dispatch.await {
            Ok(response) => response,
            Err(join_err) => map_join_error(join_err),
        }
    };

    match timeout(state.options.timeout, run).await {
        Ok(response) => to_axum_response(response),
        Err(_) => to_axum_response(Response::new(504)),
    }
}

/// Converts a failed join of the blocking dispatch task into an HTTP response.
///
/// * Cancelled task: the dispatch never completed because tokio cancelled it (e.g.
///   during runtime shutdown). The result is `503 Service Unavailable`.
/// * Anything else, notably a panic inside the dispatcher: `Response::internal_error()`.
fn map_join_error(join_err: JoinError) -> Response {
    if join_err.is_cancelled() {
        return Response::new(503);
    }

    Response::internal_error()
}