use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::{Method, StatusCode};
use crate::error::Result;
use crate::extract::RequestContext;
use crate::hooks::{ErrorEvent, RequestEvent, ResponseEvent, ValidationErrorEvent};
use crate::response::Response;
pub mod matcher;
/// Heap-allocated, pinned, type-erased future that is `Send` and yields `T`.
///
/// The `'a` lifetime bounds whatever data the erased future may borrow.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Reference-counted, type-erased hook that takes a [`RequestEvent`] by value
/// and returns a `'static` boxed future resolving to `()`.
pub(crate) type SharedRequestHook =
Arc<dyn Fn(RequestEvent) -> BoxFuture<'static, ()> + Send + Sync>;
/// Reference-counted, type-erased hook that takes a [`ResponseEvent`] by value
/// and returns a `'static` boxed future resolving to `()`.
pub(crate) type SharedResponseHook =
Arc<dyn Fn(ResponseEvent) -> BoxFuture<'static, ()> + Send + Sync>;
/// Reference-counted, type-erased hook that takes an [`ErrorEvent`] by value
/// and returns a `'static` boxed future resolving to `()`.
pub(crate) type SharedErrorHook = Arc<dyn Fn(ErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;
/// Reference-counted, type-erased hook that takes a [`ValidationErrorEvent`]
/// by value and returns a `'static` boxed future resolving to `()`.
pub(crate) type SharedValidationErrorHook =
Arc<dyn Fn(ValidationErrorEvent) -> BoxFuture<'static, ()> + Send + Sync>;
// Expands to the four consuming builder methods `on_request`, `on_response`,
// `on_error` and `on_validation_error`.
//
// The macro must be invoked inside an `impl` block whose `Self` type has the
// fields `request_hooks`, `response_hooks`, `error_hooks` and
// `validation_hooks`; each generated method pushes onto the matching field.
// In this file both `Route` and `Router` invoke it.
macro_rules! scoped_hook_builders {
() => {
/// Appends `hook` to the end of this value's request hooks.
///
/// The callback is wrapped in an `Arc`, and the future it returns is
/// boxed and pinned so that hooks of different concrete types can be
/// stored in the same list.
pub fn on_request<F, Fut>(mut self, hook: F) -> Self
where
F: Fn(RequestEvent) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.request_hooks
.push(Arc::new(move |event| Box::pin(hook(event))));
self
}
/// Appends `hook` to the end of this value's response hooks.
///
/// The callback is wrapped in an `Arc` and its future is boxed, exactly
/// as for `on_request`.
pub fn on_response<F, Fut>(mut self, hook: F) -> Self
where
F: Fn(ResponseEvent) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.response_hooks
.push(Arc::new(move |event| Box::pin(hook(event))));
self
}
/// Appends `hook` to the end of this value's error hooks.
///
/// The callback is wrapped in an `Arc` and its future is boxed, exactly
/// as for `on_request`.
pub fn on_error<F, Fut>(mut self, hook: F) -> Self
where
F: Fn(ErrorEvent) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.error_hooks
.push(Arc::new(move |event| Box::pin(hook(event))));
self
}
/// Appends `hook` to the end of this value's validation-error hooks.
///
/// The callback is wrapped in an `Arc` and its future is boxed, exactly
/// as for `on_request`.
pub fn on_validation_error<F, Fut>(mut self, hook: F) -> Self
where
F: Fn(ValidationErrorEvent) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.validation_hooks
.push(Arc::new(move |event| Box::pin(hook(event))));
self
}
};
}
/// Reference-counted, type-erased route handler.
///
/// It takes a [`RequestContext`] by value and returns a `'static` boxed
/// future resolving to the crate's `Result<Response>`.
pub type HandlerFn =
Arc<dyn Fn(RequestContext) -> BoxFuture<'static, Result<Response>> + Send + Sync>;
/// Plain function pointer that produces a [`schemars::Schema`] from a
/// [`schemars::SchemaGenerator`].
///
/// Being a `fn` pointer (not a closure type), it cannot capture state and is
/// `Copy`.
pub type SchemaThunk = fn(&mut schemars::SchemaGenerator) -> schemars::Schema;
/// Kind of request body recorded in a route's [`RouteMeta::request_kind`].
///
/// NOTE(review): the exact content types each variant maps to are decided by
/// code outside this file — confirm against the body extractors.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum RequestBodyKind {
/// JSON request body; this is the default kind.
#[default]
Json,
/// Form request body.
Form,
/// Multipart request body.
Multipart,
}
/// Descriptive metadata attached to a [`Route`].
///
/// The values are populated through the builder methods on [`Route`] and read
/// back through [`Route::meta`].
/// NOTE(review): presumably consumed by schema/documentation generation —
/// confirm against the consumers of `Route::meta`.
#[derive(Clone, Debug)]
pub struct RouteMeta {
/// Optional short summary, set by `Route::summary`.
pub summary: Option<String>,
/// Optional longer description, set by `Route::description`.
pub description: Option<String>,
/// Tags of the route. `Route::tag` and router tag inheritance only add a
/// tag that is not already present.
pub tags: Vec<String>,
/// Status code recorded for the route; the `Default` impl uses `StatusCode::OK`.
pub status_code: StatusCode,
/// Type name of the response model, captured with `std::any::type_name`
/// by `Route::response_model`.
pub response_model: Option<&'static str>,
/// Thunk producing the schema of the request body, if one was declared.
pub request_schema: Option<SchemaThunk>,
/// Kind of the request body; the `Default` impl uses `RequestBodyKind::Json`.
pub request_kind: RequestBodyKind,
/// Thunk producing the schema of the response body, if one was declared.
pub response_schema: Option<SchemaThunk>,
/// Set to `true` by `Route::streaming`.
pub streaming: bool,
/// Set to `true` by `Route::websocket`.
pub websocket: bool,
/// Thunk producing the schema set by `Route::ws_incoming`, if any.
pub ws_incoming: Option<SchemaThunk>,
/// Thunk producing the schema set by `Route::ws_outgoing`, if any.
pub ws_outgoing: Option<SchemaThunk>,
}
impl Default for RouteMeta {
fn default() -> Self {
Self {
summary: None,
description: None,
tags: Vec::new(),
status_code: StatusCode::OK,
response_model: None,
request_schema: None,
request_kind: RequestBodyKind::Json,
response_schema: None,
streaming: false,
websocket: false,
ws_incoming: None,
ws_outgoing: None,
}
}
}
/// A single route: an HTTP method and path bound to a handler, together with
/// its metadata and the hooks scoped to it.
///
/// The handler and every hook are stored behind `Arc`, so cloning a `Route`
/// shares them rather than duplicating the callbacks.
#[derive(Clone)]
pub struct Route {
// HTTP method of the route.
method: Method,
// Path of the route; rewritten by `prepend_prefix` when a router prefix is applied.
path: String,
// Type-erased handler of the route.
handler: HandlerFn,
// Descriptive metadata filled in by the builder methods.
meta: RouteMeta,
// Hook lists; `prepend_hooks` inserts inherited router hooks at the front.
request_hooks: Vec<SharedRequestHook>,
response_hooks: Vec<SharedResponseHook>,
error_hooks: Vec<SharedErrorHook>,
validation_hooks: Vec<SharedValidationErrorHook>,
}
impl Route {
    /// Creates a route for `method` and `path` that is served by `handler`.
    ///
    /// The metadata starts from [`RouteMeta::default`] and all four hook lists
    /// start out empty.
    pub fn new(method: Method, path: impl Into<String>, handler: HandlerFn) -> Self {
        let path: String = path.into();
        Self {
            method,
            path,
            handler,
            meta: RouteMeta::default(),
            request_hooks: Vec::new(),
            response_hooks: Vec::new(),
            error_hooks: Vec::new(),
            validation_hooks: Vec::new(),
        }
    }

    // Generates `on_request`, `on_response`, `on_error` and `on_validation_error`.
    scoped_hook_builders!();

    /// Sets the summary in the route metadata, replacing any previous one.
    pub fn summary(mut self, summary: impl Into<String>) -> Self {
        let text: String = summary.into();
        self.meta.summary = Some(text);
        self
    }

    /// Sets the description in the route metadata, replacing any previous one.
    pub fn description(mut self, description: impl Into<String>) -> Self {
        let text: String = description.into();
        self.meta.description = Some(text);
        self
    }

    /// Adds `tag` to the route metadata unless an equal tag is already present.
    pub fn tag(mut self, tag: impl Into<String>) -> Self {
        let candidate: String = tag.into();
        let is_new = self.meta.tags.iter().all(|existing| *existing != candidate);
        if is_new {
            self.meta.tags.push(candidate);
        }
        self
    }

    /// Sets the status code recorded in the route metadata.
    pub fn status_code(mut self, status_code: StatusCode) -> Self {
        self.meta.status_code = status_code;
        self
    }

    /// Records the type name of `T` (via `std::any::type_name`) as the
    /// response model.
    pub fn response_model<T: ?Sized>(mut self) -> Self {
        let model_name: &'static str = std::any::type_name::<T>();
        self.meta.response_model = Some(model_name);
        self
    }

    /// Declares `T` as the request schema; the schema is produced on demand
    /// through `SchemaGenerator::subschema_for::<T>`.
    pub fn request_schema<T: schemars::JsonSchema>(mut self) -> Self {
        self.meta.request_schema = Some(Self::schema_thunk::<T>());
        self
    }

    /// Declares the request schema through an explicit thunk.
    pub fn request_schema_fn(mut self, thunk: SchemaThunk) -> Self {
        self.meta.request_schema = Some(thunk);
        self
    }

    /// Sets the kind of request body recorded in the route metadata.
    pub fn request_kind(mut self, kind: RequestBodyKind) -> Self {
        self.meta.request_kind = kind;
        self
    }

    /// Declares `T` as the response schema; the schema is produced on demand
    /// through `SchemaGenerator::subschema_for::<T>`.
    pub fn response_schema<T: schemars::JsonSchema>(mut self) -> Self {
        self.meta.response_schema = Some(Self::schema_thunk::<T>());
        self
    }

    /// Sets the `streaming` flag in the route metadata.
    pub fn streaming(mut self) -> Self {
        self.meta.streaming = true;
        self
    }

    /// Sets the `websocket` flag in the route metadata.
    pub fn websocket(mut self) -> Self {
        self.meta.websocket = true;
        self
    }

    /// Declares `T` as the `ws_incoming` schema.
    pub fn ws_incoming<T: schemars::JsonSchema>(mut self) -> Self {
        self.meta.ws_incoming = Some(Self::schema_thunk::<T>());
        self
    }

    /// Declares `T` as the `ws_outgoing` schema.
    pub fn ws_outgoing<T: schemars::JsonSchema>(mut self) -> Self {
        self.meta.ws_outgoing = Some(Self::schema_thunk::<T>());
        self
    }

    /// Returns the HTTP method of the route.
    pub fn method(&self) -> &Method {
        &self.method
    }

    /// Returns the path of the route.
    pub fn path(&self) -> &str {
        self.path.as_str()
    }

    /// Returns the metadata of the route.
    pub fn meta(&self) -> &RouteMeta {
        &self.meta
    }

    /// Returns the handler of the route.
    pub fn handler(&self) -> &HandlerFn {
        &self.handler
    }

    /// Returns the request hooks in their stored order.
    pub(crate) fn request_hooks(&self) -> &[SharedRequestHook] {
        self.request_hooks.as_slice()
    }

    /// Returns the response hooks in their stored order.
    pub(crate) fn response_hooks(&self) -> &[SharedResponseHook] {
        self.response_hooks.as_slice()
    }

    /// Returns the error hooks in their stored order.
    pub(crate) fn error_hooks(&self) -> &[SharedErrorHook] {
        self.error_hooks.as_slice()
    }

    /// Returns the validation-error hooks in their stored order.
    pub(crate) fn validation_hooks(&self) -> &[SharedValidationErrorHook] {
        self.validation_hooks.as_slice()
    }

    /// Returns `true` when at least one hook of any kind is attached.
    pub(crate) fn has_hooks(&self) -> bool {
        let all_empty = self.request_hooks.is_empty()
            && self.response_hooks.is_empty()
            && self.error_hooks.is_empty()
            && self.validation_hooks.is_empty();
        !all_empty
    }

    /// Replaces the path with `prefix` joined in front of it.
    fn prepend_prefix(mut self, prefix: &str) -> Self {
        let joined = join_paths(prefix, &self.path);
        self.path = joined;
        self
    }

    /// Appends every tag from `tags` that the route does not carry yet,
    /// keeping the route's own tags first.
    fn inherit_tags(mut self, tags: &[String]) -> Self {
        for inherited in tags {
            let already_present = self.meta.tags.iter().any(|own| own == inherited);
            if !already_present {
                self.meta.tags.push(inherited.to_owned());
            }
        }
        self
    }

    /// Inserts clones of the given hooks in front of the route's own hooks,
    /// preserving the order within each slice.
    fn prepend_hooks(
        mut self,
        request: &[SharedRequestHook],
        response: &[SharedResponseHook],
        error: &[SharedErrorHook],
        validation: &[SharedValidationErrorHook],
    ) -> Self {
        Self::prepend_cloned(&mut self.request_hooks, request);
        Self::prepend_cloned(&mut self.response_hooks, response);
        Self::prepend_cloned(&mut self.error_hooks, error);
        Self::prepend_cloned(&mut self.validation_hooks, validation);
        self
    }

    /// Returns a thunk that asks the generator for the subschema of `T`.
    fn schema_thunk<T: schemars::JsonSchema>() -> SchemaThunk {
        fn generate<S: schemars::JsonSchema>(
            generator: &mut schemars::SchemaGenerator,
        ) -> schemars::Schema {
            generator.subschema_for::<S>()
        }
        generate::<T>
    }

    /// Rebuilds `existing` as clones of `front` followed by its old contents.
    fn prepend_cloned<T: Clone>(existing: &mut Vec<T>, front: &[T]) {
        let mut merged = Vec::with_capacity(front.len() + existing.len());
        merged.extend_from_slice(front);
        merged.append(existing);
        *existing = merged;
    }
}
/// Builder that groups routes under a shared path prefix, shared tags and
/// shared hooks.
///
/// The shared settings are only applied to the routes when the router is
/// consumed by `into_routes`.
#[derive(Default)]
pub struct Router {
// Path prefix joined in front of every route path by `into_routes`.
prefix: String,
// Tags that routes inherit in `into_routes`.
tags: Vec<String>,
// Routes as registered through `route` and `include`.
routes: Vec<Route>,
// Router-level hooks; `into_routes` places them in front of each route's own hooks.
request_hooks: Vec<SharedRequestHook>,
response_hooks: Vec<SharedResponseHook>,
error_hooks: Vec<SharedErrorHook>,
validation_hooks: Vec<SharedValidationErrorHook>,
}
impl Router {
pub fn new() -> Self {
Self::default()
}
scoped_hook_builders!();
pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
self.prefix = prefix.into();
self
}
pub fn tags(mut self, tags: &[&str]) -> Self {
self.tags = tags.iter().map(|tag| (*tag).to_owned()).collect();
self
}
pub fn route(mut self, route: Route) -> Self {
self.routes.push(route);
self
}
pub fn include(mut self, child: Router) -> Self {
self.routes.extend(child.into_routes());
self
}
pub fn into_routes(self) -> Vec<Route> {
let Router {
prefix,
tags,
routes,
request_hooks,
response_hooks,
error_hooks,
validation_hooks,
} = self;
routes
.into_iter()
.map(|route| {
route
.prepend_prefix(&prefix)
.inherit_tags(&tags)
.prepend_hooks(
&request_hooks,
&response_hooks,
&error_hooks,
&validation_hooks,
)
})
.collect()
}
pub fn routes(&self) -> &[Route] {
&self.routes
}
}
/// Joins a router prefix and a route path into a single normalized path.
///
/// Trailing slashes of `prefix` and leading slashes of `path` are dropped and
/// the two parts are connected with exactly one `/`. The result always starts
/// with `/` and never ends with one, except for the root path `"/"` itself,
/// which is returned when both parts are empty after trimming. Slashes inside
/// either part are left untouched.
fn join_paths(prefix: &str, path: &str) -> String {
    let base = prefix.trim_end_matches('/');
    let rest = path.trim_start_matches('/');

    // Only add a separator when there is something to put after it.
    let joined = if rest.is_empty() {
        base.to_owned()
    } else {
        [base, rest].join("/")
    };

    // A prefix written without a leading slash still yields a rooted path.
    let rooted = if joined.starts_with('/') {
        joined
    } else {
        ["/", joined.as_str()].concat()
    };

    // The tail may itself end with slashes; strip them, but keep the bare root.
    match rooted.trim_end_matches('/') {
        "" => String::from("/"),
        trimmed => trimmed.to_owned(),
    }
}
// Unit tests for route/router composition: prefix joining, tag inheritance
// and the ordering of inherited hooks.
#[cfg(test)]
mod tests {
use super::*;
use crate::response::empty;
// Builds a handler that ignores its context and resolves to `empty(StatusCode::OK)`.
fn dummy_handler() -> HandlerFn {
Arc::new(
|_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
Box::pin(async { Ok(empty(StatusCode::OK)) })
},
)
}
// Shorthand for a GET route at `path` backed by the dummy handler.
fn get(path: &str) -> Route {
Route::new(Method::GET, path, dummy_handler())
}
// The router prefix is joined in front of the route path and the router
// tags end up on the route.
#[test]
fn prefix_is_prepended_to_routes() {
let routes = Router::new()
.prefix("/users")
.tags(&["users"])
.route(get("/{user_id}"))
.into_routes();
assert_eq!(routes.len(), 1);
assert_eq!(routes[0].path(), "/users/{user_id}");
assert_eq!(routes[0].meta().tags, vec!["users".to_owned()]);
}
// A route registered at "/" under a prefix resolves to the prefix without
// a trailing slash.
#[test]
fn root_route_drops_trailing_slash() {
let routes = Router::new().prefix("/users").route(get("/")).into_routes();
assert_eq!(routes[0].path(), "/users");
}
// Including a child router composes both prefixes; the child's tags come
// first because they are applied when the child is flattened.
#[test]
fn nested_include_composes_prefixes_and_tags() {
let orders = Router::new()
.prefix("/{user_id}/orders")
.tags(&["orders"])
.route(get("/"));
let routes = Router::new()
.prefix("/users")
.tags(&["users"])
.include(orders)
.into_routes();
assert_eq!(routes[0].path(), "/users/{user_id}/orders");
assert_eq!(
routes[0].meta().tags,
vec!["orders".to_owned(), "users".to_owned()]
);
}
// Adding the same tag twice keeps a single copy and preserves insertion order.
#[test]
fn route_tag_deduplicates_repeated_tags() {
let route = get("/x").tag("a").tag("a").tag("b");
assert_eq!(route.meta().tags, vec!["a".to_owned(), "b".to_owned()]);
}
// The default metadata has no summary, description, tags or schemas.
#[test]
fn route_meta_default_has_empty_collections() {
let meta = RouteMeta::default();
assert!(meta.summary.is_none());
assert!(meta.description.is_none());
assert!(meta.tags.is_empty());
assert!(meta.request_schema.is_none());
assert!(meta.response_schema.is_none());
}
// Hooks registered on an outer and an included inner router both reach the
// route, with the outer router's hook stored before the inner one's.
#[tokio::test]
async fn router_hooks_propagate_to_routes_outer_to_inner() {
use crate::hooks::{RequestEvent, RequestInfo};
use std::sync::Mutex;
// Shared log that records the order in which the hooks run.
let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(Vec::new()));
let outer_log = log.clone();
let inner_log = log.clone();
let inner = Router::new().route(get("/x")).on_request(move |_event| {
let log = inner_log.clone();
async move { log.lock().unwrap().push("inner") }
});
let outer = Router::new()
.on_request(move |_event| {
let log = outer_log.clone();
async move { log.lock().unwrap().push("outer") }
})
.include(inner);
let routes = outer.into_routes();
assert_eq!(routes.len(), 1);
let hooks = routes[0].request_hooks();
assert_eq!(hooks.len(), 2, "both router hooks attach to the route");
let info = RequestInfo::new(Method::GET, "/x".into(), Some("/x".into()), None);
// Run the hooks in their stored order so the log reflects that order.
for hook in hooks {
hook(RequestEvent::new(info.clone())).await;
}
assert_eq!(*log.lock().unwrap(), ["outer", "inner"]);
}
}