use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use axum::extract::Request;
use axum::handler::Handler;
use axum::middleware::Next;
use axum::response::Response;
use axum::Router;
use tokio::net::TcpListener;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;
use crate::config::AppConfig;
use crate::controller::Controller;
use crate::logging;
use crate::auth::{AuthConfig, AuthLayer};
use crate::middleware::{self, InjectStateLayer, RequestTimeoutLayer};
use crate::rate_limit::RateLimitLayer;
use crate::router::{Method, OxideRouter};
use crate::state::{AppState, TypeMap};
/// A deferred, one-shot transformation of the axum [`Router`].
///
/// Builder methods such as `before`, `after`, `scoped_state` and `layer` queue
/// one of these; `build_router` later applies them in registration order.
type RouterTransform = Box<dyn FnOnce(Router) -> Router>;
/// Builder for an Oxide web application.
///
/// Routes, shared state and middleware settings are collected through the
/// chained builder methods and only assembled into an axum [`Router`] when the
/// server (or a test server) is started.
pub struct App {
/// Configuration held by the builder; starts as `AppConfig::default()` and is
/// replaced with the loaded configuration in `serve`.
config: AppConfig,
/// Routes registered directly on the builder.
router: OxideRouter,
/// Optional configuration path recorded by `config` and handed to
/// `AppConfig::load` in `serve`.
config_path: Option<String>,
/// Values registered through `state`; moved into `AppState` when the router
/// is built.
type_map: TypeMap,
/// Whether the request-logging middleware is installed (defaults to `true`).
request_logging: bool,
/// `(max_requests, window)` passed to `RateLimitLayer::new`, if enabled.
rate_limit: Option<(u64, Duration)>,
/// CORS layer to install, if one was configured.
cors: Option<CorsLayer>,
/// Duration passed to `RequestTimeoutLayer::new`, if configured.
request_timeout: Option<Duration>,
/// Deferred controller constructors; each is invoked with a clone of the
/// final `AppState` while the router is being built.
controller_factories: Vec<Box<dyn FnOnce(AppState) -> OxideRouter>>,
/// User-supplied router transformations, applied in registration order.
user_layers: Vec<RouterTransform>,
/// Authentication settings passed to `AuthLayer::new`, if enabled.
auth: Option<AuthConfig>,
}
impl App {
/// Creates an application builder with default settings.
///
/// Calls `logging::init()` as a side effect. Request logging starts enabled;
/// rate limiting, CORS, request timeout and auth start disabled, and no
/// configuration path is set.
pub fn new() -> Self {
    logging::init();

    // Build the non-trivial defaults first, in the same order as before.
    let config = AppConfig::default();
    let router = OxideRouter::new();
    let type_map = TypeMap::default();

    Self {
        config,
        router,
        config_path: None,
        type_map,
        request_logging: true,
        rate_limit: None,
        cors: None,
        request_timeout: None,
        controller_factories: Vec::new(),
        user_layers: Vec::new(),
        auth: None,
    }
}
pub fn config(mut self, path: &str) -> Self {
self.config_path = Some(path.to_string());
self
}
/// Registers a shared value in the application's type map.
///
/// The type map is moved into `AppState` when the router is built.
pub fn state<T: Send + Sync + 'static>(self, value: T) -> Self {
    let mut app = self;
    app.type_map.insert(value);
    app
}
pub fn route<H, T>(mut self, method: Method, path: &str, handler: H) -> Self
where
H: Handler<T, ()>,
T: 'static,
{
self.router = self.router.route(method, path, handler);
self
}
pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, ()>,
T: 'static,
{
self.router = self.router.get(path, handler);
self
}
pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, ()>,
T: 'static,
{
self.router = self.router.post(path, handler);
self
}
pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, ()>,
T: 'static,
{
self.router = self.router.put(path, handler);
self
}
pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, ()>,
T: 'static,
{
self.router = self.router.delete(path, handler);
self
}
pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, ()>,
T: 'static,
{
self.router = self.router.patch(path, handler);
self
}
pub fn controller<C: Controller>(mut self) -> Self {
self.controller_factories.push(Box::new(|state: AppState| {
let instance = Arc::new(C::from_state(&state));
let routes = C::register(instance);
let inner = C::configure_router(routes.into_inner());
OxideRouter::from_router(inner).nest_self(C::PREFIX)
}));
self
}
pub fn routes(mut self, router: OxideRouter) -> Self {
self.router = self.router.merge(router);
self
}
pub fn nest(mut self, prefix: &str, router: OxideRouter) -> Self {
self.router = self.router.nest(prefix, router);
self
}
pub fn rate_limit(mut self, max_requests: u64, window_secs: u64) -> Self {
self.rate_limit = Some((max_requests, Duration::from_secs(window_secs)));
self
}
pub fn cors_permissive(mut self) -> Self {
self.cors = Some(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
);
self
}
pub fn cors_origins<I, S>(mut self, origins: I) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let origins: Vec<_> = origins
.into_iter()
.filter_map(|o| o.as_ref().parse().ok())
.collect();
self.cors = Some(
CorsLayer::new()
.allow_origin(origins)
.allow_methods(Any)
.allow_headers(Any),
);
self
}
pub fn request_timeout(mut self, secs: u64) -> Self {
self.request_timeout = Some(Duration::from_secs(secs));
self
}
pub fn disable_request_logging(mut self) -> Self {
self.request_logging = false;
self
}
pub fn auth(mut self, config: AuthConfig) -> Self {
assert!(
!config.secret.is_empty(),
"AuthConfig.secret must not be empty"
);
self.auth = Some(config);
self
}
pub fn before<F, Fut>(mut self, f: F) -> Self
where
F: Fn(Request, Next) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.user_layers.push(Box::new(move |router: Router| {
router.layer(axum::middleware::from_fn(f))
}));
self
}
pub fn scoped_state<F, Fut, T>(mut self, factory: F) -> Self
where
F: Fn(&axum::http::request::Parts) -> Fut + Send + Sync + 'static,
Fut: Future<Output = T> + Send + 'static,
T: Clone + Send + Sync + 'static,
{
let factory = Arc::new(factory);
self.user_layers.push(Box::new(move |router: Router| {
let f = factory.clone();
router.layer(axum::middleware::from_fn(move |req: Request, next: Next| {
let f = f.clone();
async move {
let (mut parts, body) = req.into_parts();
let val = f(&parts).await;
parts.extensions.insert(val);
let req = axum::extract::Request::from_parts(parts, body);
next.run(req).await
}
}))
}));
self
}
pub fn after<F, Fut>(mut self, f: F) -> Self
where
F: Fn(Response) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.user_layers.push(Box::new(move |router: Router| {
router.layer(axum::middleware::map_response(f))
}));
self
}
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<axum::routing::Route> + Clone + Send + Sync + 'static,
L::Service: tower::Service<Request, Response = Response, Error = std::convert::Infallible>
+ Clone
+ Send
+ Sync
+ 'static,
<L::Service as tower::Service<Request>>::Future: Send + 'static,
{
self.user_layers.push(Box::new(move |router: Router| {
router.layer(layer)
}));
self
}
/// Assembles the final axum [`Router`] and the shared [`AppState`].
///
/// Steps, in order:
/// 1. Build `AppState` from `config` and the registered type map.
/// 2. Merge each controller's routes (built from a clone of the state) into the
///    directly registered routes.
/// 3. Apply user-supplied transformations in registration order.
/// 4. Add the built-in layers: auth (if set), state injection, panic catching,
///    rate limiting (if set), request timeout (if set), CORS (if set) and
///    request logging (if enabled).
///
/// Per axum's documented layering semantics, each `.layer(..)` call wraps what
/// was built before it, so layers added later sit further out. The order of
/// the calls below therefore must not be changed.
fn build_router(self, config: AppConfig) -> (Router, AppState) {
    let app_state = AppState::new(config, self.type_map);

    let routes = self
        .controller_factories
        .into_iter()
        .fold(self.router, |acc, factory| {
            acc.merge(factory(app_state.clone()))
        });

    let router = self
        .user_layers
        .into_iter()
        .fold(routes.into_inner(), |router, transform| transform(router));

    let router = match self.auth {
        Some(auth_cfg) => router.layer(AuthLayer::new(auth_cfg)),
        None => router,
    };

    let router = router
        .layer(InjectStateLayer::new(app_state.clone()))
        .layer(CatchPanicLayer::custom(middleware::panic_json_response));

    let router = match self.rate_limit {
        Some((max, window)) => router.layer(RateLimitLayer::new(max, window)),
        None => router,
    };

    let router = match self.request_timeout {
        Some(timeout) => router.layer(RequestTimeoutLayer::new(timeout)),
        None => router,
    };

    let router = match self.cors {
        Some(cors) => router.layer(cors),
        None => router,
    };

    let router = if self.request_logging {
        router.layer(axum::middleware::from_fn(middleware::request_logger))
    } else {
        router
    };

    (router, app_state)
}
/// Starts the server and blocks the current thread until it stops.
///
/// Creates a new Tokio runtime and drives [`App::serve`] to completion on it.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be created, or if `serve` panics.
pub fn run(self) {
    tokio::runtime::Runtime::new()
        .expect("failed to create tokio runtime")
        .block_on(self.serve());
}
/// Loads the configuration, builds the router and serves until shutdown.
///
/// The configuration is loaded with `AppConfig::load`, using the path set via
/// `config` if there is one. The server binds to `host:port` from that
/// configuration and stops once `shutdown_signal` resolves.
///
/// # Panics
///
/// Panics if the address cannot be bound or if the server returns an error.
pub async fn serve(mut self) {
    self.config = AppConfig::load(self.config_path.as_deref());
    let config = self.config.clone();

    let addr = format!("{}:{}", config.host, config.port);
    // Fall back to a fixed name when the configuration leaves it empty.
    let app_name = match config.app_name.is_empty() {
        true => String::from("oxide-app"),
        false => config.app_name.clone(),
    };

    let (router, _state) = self.build_router(config);

    let listener = match TcpListener::bind(&addr).await {
        Ok(listener) => listener,
        Err(e) => panic!("failed to bind to {addr}: {e}"),
    };

    info!(
        name = %app_name,
        address = %addr,
        "Oxide server started"
    );

    let service = router.into_make_service_with_connect_info::<SocketAddr>();
    axum::serve(listener, service)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .expect("server error");

    info!("Oxide server shut down gracefully");
}
/// Starts the application on an ephemeral local port for tests.
///
/// Binds to `127.0.0.1:0`, spawns the server on the current Tokio runtime and
/// returns a [`TestServer`] handle that aborts the server task when dropped.
///
/// NOTE(review): this uses the configuration currently held by the builder and
/// does not call `AppConfig::load`, so a path set via `config` appears to be
/// ignored here — confirm this is intended.
///
/// # Panics
///
/// Panics if the listener cannot be bound or its local address cannot be read.
pub async fn into_test_server(self) -> TestServer {
    let (router, _state) = {
        let config = self.config.clone();
        self.build_router(config)
    };

    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("failed to bind test server");
    let addr = listener.local_addr().unwrap();

    let handle = tokio::spawn(async move {
        let service = router.into_make_service_with_connect_info::<SocketAddr>();
        // Errors from the server task are deliberately discarded.
        let _ = axum::serve(listener, service).await;
    });

    TestServer { addr, handle }
}
}
/// Handle to an application instance running in a background task.
///
/// Created by `App::into_test_server`; dropping the handle aborts the task.
pub struct TestServer {
/// Address the test listener is bound to.
addr: SocketAddr,
/// Handle of the spawned task that runs the server.
handle: tokio::task::JoinHandle<()>,
}
impl TestServer {
    /// Returns the socket address the test server is bound to.
    pub fn addr(&self) -> SocketAddr {
        let Self { addr, .. } = self;
        *addr
    }

    /// Builds an `http://` URL for `path` on this server.
    ///
    /// `path` is appended verbatim after the address, so it should start with `/`.
    pub fn url(&self, path: &str) -> String {
        let base = self.addr;
        format!("http://{base}{path}")
    }
}
/// Stops the background server task when the handle goes out of scope.
impl Drop for TestServer {
fn drop(&mut self) {
// Abort the spawned serve task so it does not outlive the `TestServer`.
self.handle.abort();
}
}
/// Resolves once the process has been asked to shut down.
///
/// On Unix this waits for Ctrl+C or SIGTERM, whichever arrives first. On other
/// platforms it waits for Ctrl+C only.
///
/// On Unix, a failure to listen for Ctrl+C still resolves the future, so
/// shutdown proceeds, but it is logged as an error rather than being reported
/// as a received Ctrl+C.
///
/// # Panics
///
/// - Unix: if the SIGTERM handler cannot be installed.
/// - Other platforms: if listening for Ctrl+C fails.
async fn shutdown_signal() {
    let ctrl_c = tokio::signal::ctrl_c();
    #[cfg(unix)]
    {
        let mut sigterm =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("failed to install SIGTERM handler");
        tokio::select! {
            result = ctrl_c => match result {
                Ok(()) => info!("received Ctrl+C, shutting down…"),
                // Do not claim a signal arrived when listening itself failed.
                Err(err) => tracing::error!(
                    error = %err,
                    "failed to listen for Ctrl+C, shutting down…"
                ),
            },
            _ = sigterm.recv() => info!("received SIGTERM, shutting down…"),
        }
    }
    #[cfg(not(unix))]
    {
        ctrl_c.await.expect("failed to listen for Ctrl+C");
        info!("received Ctrl+C, shutting down…");
    }
}