use std::net::{SocketAddr, ToSocketAddrs};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use axum::extract::State;
use axum::routing::any;
use axum::Router;
use rivet_core::{set_current, Application, Dispatcher};
use rivet_http::Response;
use tokio::sync::Semaphore;
use tokio::task::JoinError;
use tokio::time::timeout;
use crate::error::AdapterError;
use crate::mapping::{from_axum_request, to_axum_response, MappingError};
/// Tuning knobs for the axum adapter built by `into_router_with_options`.
#[derive(Debug, Clone)]
pub struct AdapterOptions {
    /// Limit handed to `from_axum_request`; a body over it is answered with `413`
    /// (presumably measured in bytes — confirm against `crate::mapping`).
    pub body_limit: usize,
    /// Deadline the request handler applies to dispatching; exceeding it yields `504`.
    pub timeout: Duration,
    /// Number of permits in the semaphore that gates blocking dispatches.
    pub max_blocking_dispatches: usize,
}

impl Default for AdapterOptions {
    /// 2 MiB body limit, 30 second timeout, 256 blocking-dispatch permits.
    fn default() -> Self {
        const TWO_MIB: usize = 2 * 1024 * 1024;
        AdapterOptions {
            body_limit: TWO_MIB,
            timeout: Duration::from_secs(30),
            max_blocking_dispatches: 256,
        }
    }
}
/// Router-wide state handed to `dispatch_handler` through axum's `State` extractor.
#[derive(Clone)]
struct AdapterState {
    /// Dispatcher that turns a mapped request into a response.
    app: Arc<dyn Dispatcher>,
    /// Options the router was built with.
    options: AdapterOptions,
    /// Shared gate for blocking dispatches, sized from `max_blocking_dispatches`
    /// when the router is built.
    semaphore: Arc<Semaphore>,
}
/// Builds an axum `Router` for `app` using `AdapterOptions::default()`.
///
/// Shorthand for `into_router_with_options(app, AdapterOptions::default())`.
pub fn into_router(app: Arc<dyn Dispatcher>) -> Router {
    let defaults = AdapterOptions::default();
    into_router_with_options(app, defaults)
}
/// Builds an axum `Router` that forwards every method on `/` and `/{*path}`
/// to `app` via `dispatch_handler`, configured by `options`.
///
/// `app` is also passed to `rivet_core::set_current`.
pub fn into_router_with_options(app: Arc<dyn Dispatcher>, options: AdapterOptions) -> Router {
    // NOTE(review): the result of `set_current` is discarded — presumably so that
    // building a second router does not fail if a dispatcher is already
    // registered; confirm against rivet_core.
    let _ = set_current(Arc::clone(&app));
    // `Semaphore::new` panics above `Semaphore::MAX_PERMITS`, and zero permits
    // would mean no request could ever be dispatched; clamp to a usable range.
    let permits = options
        .max_blocking_dispatches
        .clamp(1, Semaphore::MAX_PERMITS);
    let state = AdapterState {
        app,
        semaphore: Arc::new(Semaphore::new(permits)),
        options,
    };
    Router::new()
        .route("/", any(dispatch_handler))
        .route("/{*path}", any(dispatch_handler))
        .with_state(state)
}
/// Binds the address configured under `app.server.url` and serves `app` on it
/// with axum until the server stops.
///
/// # Errors
/// Returns an error if the configured URL is missing or does not resolve, if
/// binding or querying the listener fails, or if `axum::serve` fails.
pub async fn serve(app: Application) -> Result<(), AdapterError> {
    // NOTE(review): `serve_addr` resolves via `ToSocketAddrs`, which may perform a
    // blocking DNS lookup on the async runtime; tolerable only because this runs
    // once at startup.
    let addr = serve_addr(&app)?;
    let listener = tokio::net::TcpListener::bind(addr).await?;
    // Announce only after the socket is actually bound, and report the real local
    // address so an ephemeral port (`:0`) shows the port the OS assigned.
    let bound = listener.local_addr()?;
    app.log_serving(bound);
    let app = Arc::new(app);
    axum::serve(listener, into_router(app)).await?;
    Ok(())
}
/// Determines the socket address `serve` should bind.
///
/// Reads the `app.server.url` config value, reduces it to its authority with
/// `url_authority`, and resolves that with `resolve_addr`.
///
/// # Errors
/// Returns an `InvalidInput` I/O error (converted into `AdapterError`) when the
/// key is missing or its value does not resolve to a socket address.
fn serve_addr(app: &Application) -> Result<SocketAddr, AdapterError> {
    let server_url = match app.config_typed::<String>("app.server.url") {
        Some(url) => url,
        None => {
            let missing = std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "missing 'app.server.url' config",
            );
            return Err(missing.into());
        }
    };
    if let Some(addr) = resolve_addr(url_authority(&server_url)) {
        return Ok(addr);
    }
    let invalid = std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        format!("invalid 'app.server.url' value '{server_url}'"),
    );
    Err(invalid.into())
}
/// Resolves `value` to a single socket address.
///
/// A literal `ip:port` (including bracketed IPv6) is parsed directly. Otherwise
/// the value goes through `ToSocketAddrs`, which may perform a blocking DNS
/// lookup, and the first resolved address is returned. Returns `None` if
/// neither step yields an address.
fn resolve_addr(value: &str) -> Option<SocketAddr> {
    match SocketAddr::from_str(value) {
        Ok(literal) => Some(literal),
        Err(_) => value
            .to_socket_addrs()
            .ok()
            .and_then(|mut resolved| resolved.next()),
    }
}
/// Extracts the `host:port` authority from a URL-like string.
///
/// Accepts full URLs (`http://host:port/path`) and bare authorities
/// (`host:port`).
///
/// - The scheme is stripped only when the text before the first `://` is a
///   syntactically valid RFC 3986 scheme. A `://` appearing later, for example
///   inside the query string of a scheme-less value, is therefore not
///   mistaken for one.
/// - Any path, query, or fragment is dropped.
/// - A leading `userinfo@` is dropped.
fn url_authority(url: &str) -> &str {
    let without_scheme = match url.split_once("://") {
        Some((scheme, rest)) if is_url_scheme(scheme) => rest,
        _ => url,
    };
    let authority = without_scheme
        .split_once(['/', '?', '#'])
        .map_or(without_scheme, |(authority, _)| authority);
    // `userinfo` may not contain an unescaped '@'; taking the text after the last
    // one also copes with a raw '@' in a password.
    authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port)
}

/// Returns `true` if `s` matches RFC 3986 `scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`.
fn is_url_scheme(s: &str) -> bool {
    let mut chars = s.chars();
    matches!(chars.next(), Some(first) if first.is_ascii_alphabetic())
        && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
}
async fn dispatch_handler(
State(state): State<AdapterState>,
req: axum::http::Request<axum::body::Body>,
) -> axum::response::Response {
let request = match from_axum_request(req, state.options.body_limit).await {
Ok(req) => req,
Err(MappingError::UnsupportedMethod(_)) => {
return to_axum_response(Response::not_implemented())
}
Err(MappingError::BodyLimit) => return to_axum_response(Response::new(413)),
};
let permit = state.semaphore.acquire().await;
if permit.is_err() {
return to_axum_response(Response::internal_error());
}
let app = Arc::clone(&state.app);
let dispatch = tokio::task::spawn_blocking(move || app.dispatch(request));
match timeout(state.options.timeout, dispatch).await {
Ok(Ok(response)) => to_axum_response(response),
Ok(Err(join_err)) => to_axum_response(map_join_error(join_err)),
Err(_) => to_axum_response(Response::new(504)),
}
}
/// Converts a failed blocking-dispatch join into a response.
///
/// A `JoinError` means the dispatcher produced no response, because it either
/// panicked or was cancelled. Both cases are reported to the client as a
/// generic internal error. The former `is_panic` branch returned the same value
/// and has been folded away.
fn map_join_error(_join_err: JoinError) -> Response {
    Response::internal_error()
}