pub mod method;
pub mod radix;
pub mod route;
use crate::handler::universal_handler;
use crate::middleware::{BoxFuture, Middleware, Next};
use crate::response::IntoResponse;
use crate::{Error, Extensions, Handler, HandlerFn, Request, Response, Result};
use arc_swap::ArcSwap;
use dashmap::DashMap;
use http::Method;
use parking_lot::RwLock;
use std::sync::Arc;
pub use radix::{RadixNode, RadixRouter};
pub use route::Route;
/// Expands to a builder method on `Router` that registers `handler` for one
/// HTTP method at `path`, delegating to `Router::route_with`.
///
/// Arguments: the generated method's name, the `http::Method` value it binds,
/// and the doc string attached to the generated method.
macro_rules! define_http_method {
    ($fn_name:ident, $verb:expr, $summary:expr) => {
        #[doc = $summary]
        pub fn $fn_name<H, T>(self, path: &str, handler: H) -> Self
        where
            H: crate::handler::UniversalHandler<T>,
        {
            Self::route_with(self, path, $verb, handler)
        }
    };
}
/// A handler bundled with its own route-local middleware stack.
///
/// Build one with [`LayeredHandler::new`], add middleware with
/// [`LayeredHandler::layer`], then register it via `Router::route_with_layered`.
/// The middleware wraps only this handler, not the rest of the router.
#[derive(Clone)]
pub struct LayeredHandler {
    /// The endpoint, already converted into a [`HandlerFn`].
    handler: HandlerFn,
    /// Route-local middleware in the order it was added; once wrapped, the
    /// first one added is the outermost layer (sees the request first).
    middleware: Vec<Arc<dyn Middleware>>,
}
impl LayeredHandler {
    /// Wraps any `UniversalHandler` with an initially empty middleware stack.
    pub fn new<H, T>(handler: H) -> Self
    where
        H: crate::handler::UniversalHandler<T>,
    {
        let handler = universal_handler(handler);
        LayeredHandler {
            handler,
            middleware: Vec::new(),
        }
    }
    /// Appends `mw` to this handler's middleware stack (builder style).
    pub fn layer<M: Middleware + 'static>(self, mw: M) -> Self {
        let mut layered = self;
        layered.middleware.push(Arc::new(mw));
        layered
    }
    /// Produces the final [`HandlerFn`] with every layer applied around the handler.
    pub fn into_handler(self) -> HandlerFn {
        let LayeredHandler {
            handler,
            middleware,
        } = self;
        wrap_handler_with_middleware(handler, middleware)
    }
}
/// Read-only, ready-to-serve snapshot of a `Router`, built by
/// `Router::compile_inner` and cached in `Router::compiled`.
#[derive(Clone)]
struct CompiledRouter {
/// Route tree; when the router has middleware, every handler in it has
/// already been wrapped with that middleware.
radix_router: RadixRouter,
/// NOTE(review): appears unused — `compile_inner` always stores an empty Vec
/// here (middleware is folded into the handlers instead), as do `new`/`clone`.
_middleware: Vec<Arc<dyn Middleware>>,
/// Fallback for unmatched requests, wrapped with the router's middleware
/// when there is any.
not_found_handler: Option<HandlerFn>,
}
/// HTTP router with a consuming builder API (`get`, `post`, `nest`, `merge`, ...).
///
/// Route definitions live in `inner` behind a lock; `compiled` caches a
/// middleware-wrapped snapshot that is rebuilt lazily whenever `inner.dirty`
/// is set.
pub struct Router {
/// Mutable route definitions, middleware, state and the `dirty` flag.
pub inner: Arc<RwLock<RouterInner>>,
/// Last compiled snapshot; only authoritative while `inner.dirty` is false.
compiled: ArcSwap<CompiledRouter>,
}
/// Mutable definitions behind `Router::inner`; every builder method edits this
/// and sets `dirty`.
pub struct RouterInner {
/// Route tree whose handlers are not yet wrapped with this router's own middleware.
radix_router: RadixRouter,
/// Router-wide middleware in registration order; folded into every handler at compile time.
middleware: Vec<Arc<dyn Middleware>>,
/// Fallback for unmatched requests (also wrapped with `middleware` at compile time).
not_found_handler: Option<HandlerFn>,
/// NOTE(review): within this file only `merge` writes this (copying another
/// router's list) and nothing reads it for routing; `nest` mounts routes
/// directly into `radix_router`.
nested_routers: Vec<(String, Router)>,
/// True when the definitions changed since `Router::compiled` was last rebuilt.
dirty: bool,
/// Router-level state (see `Router::state`), merged into each request in `handle`.
pub extensions: Extensions,
/// WebSocket handlers keyed by normalized path (see `Router::websocket`).
#[cfg(feature = "websocket")]
#[cfg_attr(docsrs, doc(cfg(feature = "websocket")))]
websocket_routes: DashMap<String, Arc<dyn crate::websocket::WebSocketHandler>>,
}
impl Router {
pub fn new() -> Self {
let inner = RouterInner {
radix_router: RadixRouter::new(),
middleware: Vec::new(),
not_found_handler: None,
nested_routers: Vec::new(),
extensions: Extensions::new(),
dirty: true,
#[cfg(feature = "websocket")]
websocket_routes: DashMap::new(),
};
Self {
inner: Arc::new(RwLock::new(inner)),
compiled: ArcSwap::new(Arc::new(CompiledRouter {
radix_router: RadixRouter::new(),
_middleware: Vec::new(),
not_found_handler: None,
})),
}
}
/// Flattens a radix tree back into a `method -> routes` map so that its routes
/// can be re-inserted into another router (used by `Router::merge`).
fn extract_radix_routes(radix_router: &RadixRouter) -> DashMap<Method, Vec<Route>> {
    let routes = DashMap::new();
    Self::extract_node_routes(&radix_router.root, "", &routes);
    routes
}
/// Recursively walks `node`, rebuilding its full path from `path_prefix`, and
/// pushes one [`Route`] per registered method into `routes`.
///
/// Parameter nodes are rendered back as `/{name}` and wildcard nodes as
/// `/{*name}`. Every emitted route path starts with `/` (an empty path becomes `/`).
fn extract_node_routes(
    node: &RadixNode,
    path_prefix: &str,
    routes: &DashMap<Method, Vec<Route>>,
) {
    let parent = path_prefix.trim_end_matches('/');
    let param_segment = node.param_name.as_ref().map(|param_name| {
        if node.is_wildcard {
            format!("/{{*{}}}", param_name)
        } else {
            format!("/{{{}}}", param_name)
        }
    });
    // `current_path` is also the prefix passed to the children, so it keeps the
    // exact shape used before (a root-level literal segment is left unslashed here).
    let current_path = match param_segment {
        Some(segment) => format!("{}{}", parent, segment),
        None if path_prefix.is_empty() => node.path.clone(),
        None if node.path.is_empty() => path_prefix.to_string(),
        None if node.path.starts_with('/') => format!("{}{}", parent, node.path),
        None => format!("{}/{}", parent, node.path),
    };
    // The public route path does not depend on the method, so build it once.
    let route_path = if current_path.is_empty() {
        "/".to_string()
    } else if current_path.starts_with('/') {
        current_path.clone()
    } else {
        format!("/{}", current_path)
    };
    for entry in &node.handlers {
        let method = entry.key();
        let route = Route::new(&route_path, method.clone(), entry.value().clone());
        routes
            .entry(method.clone())
            .or_insert_with(Vec::new)
            .push(route);
    }
    for child in &node.children {
        Self::extract_node_routes(child, &current_path, routes);
    }
}
/// Registers a pre-built [`HandlerFn`] for `method` at `path`.
///
/// `path` is normalized first (leading `/` added, one trailing `/` removed),
/// and the router is marked dirty so it is recompiled before the next dispatch.
pub fn route(self, path: &str, method: Method, handler: HandlerFn) -> Self {
    let full_path = normalize_path(path);
    {
        let mut guard = self.inner.write();
        guard.dirty = true;
        guard.radix_router.insert(&full_path, method, handler);
    }
    self
}
/// Registers any `UniversalHandler` for `method` at `path`, converting it with
/// `universal_handler` first.
pub fn route_with<H, T>(self, path: &str, method: Method, handler: H) -> Self
where
    H: crate::handler::UniversalHandler<T>,
{
    let handler_fn = universal_handler(handler);
    self.route(path, method, handler_fn)
}
/// Registers a [`LayeredHandler`], with its route-local middleware applied.
pub fn route_with_layered(self, path: &str, method: Method, lh: LayeredHandler) -> Self {
    let wrapped = lh.into_handler();
    self.route(path, method, wrapped)
}
// One builder method per standard HTTP verb; each delegates to `route_with`.
define_http_method!(get, Method::GET, "Adds a GET route: registers `handler` for GET requests to `path`.");
define_http_method!(post, Method::POST, "Adds a POST route: registers `handler` for POST requests to `path`.");
define_http_method!(put, Method::PUT, "Adds a PUT route: registers `handler` for PUT requests to `path`.");
define_http_method!(delete, Method::DELETE, "Adds a DELETE route: registers `handler` for DELETE requests to `path`.");
define_http_method!(patch, Method::PATCH, "Adds a PATCH route: registers `handler` for PATCH requests to `path`.");
define_http_method!(head, Method::HEAD, "Adds a HEAD route: registers `handler` for HEAD requests to `path`.");
define_http_method!(options, Method::OPTIONS, "Adds an OPTIONS route: registers `handler` for OPTIONS requests to `path`.");
define_http_method!(connect, Method::CONNECT, "Adds a CONNECT route: registers `handler` for CONNECT requests to `path`.");
define_http_method!(trace, Method::TRACE, "Adds a TRACE route: registers `handler` for TRACE requests to `path`.");
/// Registers `handler` at `path` for every standard HTTP method, in this order:
/// GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, CONNECT, TRACE.
///
/// `handler` is cloned once per method.
pub fn any<H, T>(self, path: &str, handler: H) -> Self
where
    H: crate::handler::UniversalHandler<T>,
{
    let verbs = [
        Method::GET,
        Method::POST,
        Method::PUT,
        Method::DELETE,
        Method::PATCH,
        Method::HEAD,
        Method::OPTIONS,
        Method::CONNECT,
        Method::TRACE,
    ];
    verbs
        .iter()
        .cloned()
        .fold(self, |router, verb| router.route_with(path, verb, handler.clone()))
}
/// Appends router-wide middleware.
///
/// Nothing is wrapped immediately: at the next compile, `compile_inner` wraps
/// every route handler and the `not_found` fallback with the full list.
pub fn middleware<M: Middleware + 'static>(self, middleware: M) -> Self {
    {
        let mut guard = self.inner.write();
        guard.dirty = true;
        guard.middleware.push(Arc::new(middleware));
    }
    self
}
/// Sets the fallback handler for requests that match no route, replacing any
/// previously set fallback.
pub fn not_found<H, T>(self, handler: H) -> Self
where
    H: crate::handler::UniversalHandler<T>,
{
    let fallback = universal_handler(handler);
    {
        let mut guard = self.inner.write();
        guard.dirty = true;
        guard.not_found_handler = Some(fallback);
    }
    self
}
/// Mounts all of `router`'s routes under `path`.
///
/// - `router`'s own middleware is applied to its handlers right away, before
///   the subtree is inserted under the normalized prefix. This router's
///   middleware is layered on top at compile time.
/// - Its websocket routes are re-keyed under the prefix; its `/` maps to the
///   prefix itself.
/// - Its not-found handler is adopted only if this router has none.
/// - Its state is merged in via `merge_extensions`.
pub fn nest(self, path: &str, router: Router) -> Self {
    let prefix = normalize_path(path);
    {
        let mut parent = self.inner.write();
        parent.dirty = true;
        let child = router.inner.read();

        let mut subtree = child.radix_router.root.clone();
        wrap_tree_handlers(&mut subtree, child.middleware.clone());
        parent
            .radix_router
            .insert_nested(&prefix, &RadixRouter { root: subtree });

        #[cfg(feature = "websocket")]
        {
            for entry in child.websocket_routes.iter() {
                let mounted = match entry.key().as_str() {
                    "/" => prefix.clone(),
                    sub => format!("{}{}", prefix.trim_end_matches('/'), sub),
                };
                parent.websocket_routes.insert(mounted, entry.value().clone());
            }
        }

        if parent.not_found_handler.is_none() {
            parent.not_found_handler = child.not_found_handler.clone();
        }
        Self::merge_extensions(&mut parent.extensions, &child.extensions);
    }
    self
}
/// Merges all of `other`'s definitions into this router.
///
/// - `other`'s routes are re-inserted one by one. Their paths are rebuilt by
///   `extract_radix_routes`, and duplicates are handled by `RadixRouter::insert`.
/// - `other`'s middleware is appended to this router's list. After compilation
///   it therefore wraps *every* route, including the ones this router already had.
/// - `other`'s not-found handler and websocket routes are adopted only where
///   this router has none of its own.
/// - State is merged via `merge_extensions`.
pub fn merge(self, other: Router) -> Self {
    {
        let mut inner = self.inner.write();
        let other_inner = other.inner.read();
        inner.dirty = true;

        let extracted_routes = Self::extract_radix_routes(&other_inner.radix_router);
        for entry in extracted_routes.iter() {
            let method = entry.key();
            for route in entry.value() {
                inner
                    .radix_router
                    .insert(&route.path, method.clone(), route.handler.clone());
            }
        }

        inner
            .middleware
            .extend(other_inner.middleware.iter().cloned());
        inner
            .nested_routers
            .extend(other_inner.nested_routers.iter().cloned());

        // Keep our own fallback if we have one; otherwise take `other`'s (possibly None).
        if inner.not_found_handler.is_none() {
            inner.not_found_handler = other_inner.not_found_handler.clone();
        }

        #[cfg(feature = "websocket")]
        {
            // One lookup per path via the entry API; paths we already have win.
            for entry in other_inner.websocket_routes.iter() {
                inner
                    .websocket_routes
                    .entry(entry.key().clone())
                    .or_insert_with(|| entry.value().clone());
            }
        }

        Self::merge_extensions(&mut inner.extensions, &other_inner.extensions);
    }
    self
}
/// Copies every entry of `source_extensions` into `target_extensions`, keyed
/// by the entry's `TypeId`, via `insert_if_not_exists_typeid`.
fn merge_extensions(target_extensions: &mut Extensions, source_extensions: &Extensions) {
    source_extensions.map.iter().for_each(|item| {
        let (type_id, value) = (*item.key(), item.value().clone());
        target_extensions.insert_if_not_exists_typeid(type_id, value);
    });
}
/// Registers a WebSocket endpoint at `path`.
///
/// The handler is stored in `websocket_routes` under the normalized path.
/// A GET route is also added at that path: it upgrades WebSocket requests and
/// answers anything else with `Response::bad_request("WebSocket upgrade required")`.
#[cfg(feature = "websocket")]
pub fn websocket<H, T>(self, path: &str, handler: H) -> Self
where
    H: crate::websocket::UniversalWebSocketHandler<T>,
    T: Send + Sync + 'static,
{
    let ws_path = normalize_path(path);
    let ws_handler = crate::websocket::universal_ws_handler(handler);
    {
        let mut guard = self.inner.write();
        guard.dirty = true;
        guard
            .websocket_routes
            .insert(ws_path.clone(), Arc::clone(&ws_handler));
    }
    let upgrade_route = Arc::new(move |req: Request| {
        Box::pin(async move {
            if !crate::websocket::is_websocket_request(&req) {
                Ok(Response::bad_request("WebSocket upgrade required"))
            } else {
                Ok(crate::websocket::upgrade_connection(&req)?)
            }
        }) as crate::handler::BoxFuture<'static, crate::Result<Response>>
    });
    self.route(&ws_path, Method::GET, upgrade_route)
}
/// Registers a WebSocket endpoint from a plain async function over the connection.
#[cfg(feature = "websocket")]
pub fn websocket_fn<F, Fut, R>(self, path: &str, f: F) -> Self
where
    F: Fn(crate::websocket::WebSocketConnection) -> Fut + Clone + Send + Sync + 'static,
    Fut: std::future::Future<Output = R> + Send + 'static,
    R: crate::response::IntoResponse,
{
    let handler = crate::websocket::websocket_handler(f);
    self.websocket(path, handler)
}
/// Returns a copy of the registered WebSocket handlers, keyed by path.
#[cfg(feature = "websocket")]
pub fn get_websocket_handlers(
    &self,
) -> DashMap<String, Arc<dyn crate::websocket::WebSocketHandler>> {
    let guard = self.inner.read();
    guard.websocket_routes.clone()
}
/// Stores `state` in the router's extensions; `handle` merges them into every request.
pub fn state<T>(self, state: T) -> Self
where
    T: Clone + Send + Sync + 'static,
{
    {
        let mut guard = self.inner.write();
        guard.dirty = true;
        guard.extensions.insert(state);
    }
    self
}
/// Stores an already-shared `Arc<T>` as state.
///
/// NOTE(review): the value inserted is the `Arc<T>` itself, so it is
/// presumably retrieved as `get_state::<Arc<T>>()` — confirm against `Extensions`.
pub fn state_arc<T>(self, state: Arc<T>) -> Self
where
    T: Send + Sync + 'static,
{
    self.state(state)
}
/// Builds state with `factory` (before taking the lock) and stores it like [`Router::state`].
pub fn state_factory<T, F>(self, factory: F) -> Self
where
    T: Clone + Send + Sync + 'static,
    F: FnOnce() -> T,
{
    let built = factory();
    self.state(built)
}
/// Returns `true` if state of type `T` has been registered on this router.
pub fn has_state<T: Send + Sync + Clone + 'static>(&self) -> bool {
    matches!(self.inner.read().extensions.get::<T>(), Some(_))
}
/// Returns a clone of the registered state of type `T`, if any.
pub fn get_state<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
    self.inner
        .read()
        .extensions
        .get::<T>()
        .map(|shared| shared.as_ref().clone())
}
/// Returns the compiled snapshot, recompiling first if the definitions changed.
///
/// The slow path compiles while holding the write lock and re-checks `dirty`
/// under it. Concurrent callers therefore compile the tree at most once.
/// Clearing `dirty` in the same critical section as compiling also guarantees
/// that a mutation can never land between the two and be silently left out
/// of the snapshot.
fn ensure_compiled(&self) -> Arc<CompiledRouter> {
    // The read guard is a temporary, released at the end of this statement.
    let dirty = self.inner.read().dirty;
    if !dirty {
        return self.compiled.load_full();
    }
    let mut inner = self.inner.write();
    // Another caller may have compiled while we waited for the write lock.
    if inner.dirty {
        let fresh = Arc::new(self.compile_inner(&inner));
        self.compiled.store(fresh);
        inner.dirty = false;
    }
    self.compiled.load_full()
}
/// Builds a [`CompiledRouter`] snapshot from `inner`.
///
/// Router-wide middleware, if any, is folded directly into every route
/// handler and into the `not_found` fallback. The snapshot's `_middleware`
/// list is therefore always empty.
fn compile_inner(&self, inner: &RouterInner) -> CompiledRouter {
    let layers = &inner.middleware;
    if layers.is_empty() {
        return CompiledRouter {
            radix_router: inner.radix_router.clone(),
            _middleware: Vec::new(),
            not_found_handler: inner.not_found_handler.clone(),
        };
    }
    let not_found_handler = inner
        .not_found_handler
        .clone()
        .map(|fallback| wrap_handler_with_middleware(fallback, layers.clone()));
    let mut root = inner.radix_router.root.clone();
    wrap_tree_handlers(&mut root, layers.clone());
    CompiledRouter {
        radix_router: RadixRouter { root },
        _middleware: Vec::new(),
        not_found_handler,
    }
}
/// Dispatches `req` through the (lazily compiled) router.
///
/// Before dispatch, the router's extensions (its state) are merged into the
/// request's extensions via `merge_extensions`.
#[inline]
pub async fn handle(&self, mut req: Request) -> Result<Response> {
    let compiled = self.ensure_compiled();
    // Scoped so the lock guard is released before the await below.
    {
        let shared = self.inner.read();
        Self::merge_extensions(&mut req.extensions, &shared.extensions);
    }
    self.handle_radix_route(&compiled, req).await
}
async fn handle_radix_route(
&self,
compiled: &CompiledRouter,
req: Request,
) -> Result<Response> {
if let Some((handler, params)) = compiled.radix_router.lookup(&req.method, req.uri.path()) {
let mut req = req;
req.params = params;
return handler.handle(req).await;
}
if let Some(handler) = &compiled.not_found_handler {
handler.handle(req.clone()).await
} else {
Err(Error::NotFound(req.uri.path().to_string()))
}
}
/// Returns `true` if `method` + `path` hits a registered route. The `not_found`
/// fallback does not count. Compiles the router if needed.
pub fn matches(&self, method: &Method, path: &str) -> bool {
    self.ensure_compiled()
        .radix_router
        .lookup(method, path)
        .is_some()
}
/// Returns statistics for the compiled route tree (always `Some`).
pub fn stats(&self) -> Option<crate::router::radix::RadixStats> {
    let tree_stats = self.ensure_compiled().radix_router.stats();
    Some(tree_stats)
}
/// Prints the compiled route tree via `RadixRouter::print_tree`.
pub fn print_tree(&self) {
    self.ensure_compiled().radix_router.print_tree();
}
}
/// Wraps `handler` in `middleware` and returns a single [`HandlerFn`].
///
/// The first middleware in the list is the outermost layer: it sees the
/// request first and passes it on through [`Next`]. The innermost layer calls
/// `handler` and turns an `Err` into a response via `IntoResponse`, so the
/// returned handler always resolves to `Ok(response)`. With no middleware,
/// `handler` is returned unchanged.
fn wrap_handler_with_middleware(
    handler: HandlerFn,
    middleware: Vec<Arc<dyn Middleware>>,
) -> HandlerFn {
    type Link = Arc<dyn Fn(Request) -> BoxFuture<'static, Response> + Send + Sync>;

    if middleware.is_empty() {
        return handler;
    }

    // Innermost link: run the real handler, converting errors into responses.
    let mut chain = Arc::new(move |req: Request| -> BoxFuture<'static, Response> {
        let endpoint = handler.clone();
        Box::pin(async move {
            endpoint
                .handle(req)
                .await
                .unwrap_or_else(|err| err.into_response())
        })
    }) as Link;

    // Wrap from the last middleware outwards so the first one ends up outermost.
    for layer in middleware.iter().rev() {
        let layer = Arc::clone(layer);
        let downstream = chain;
        chain = Arc::new(move |req: Request| -> BoxFuture<'static, Response> {
            let layer = layer.clone();
            let downstream = downstream.clone();
            Box::pin(async move {
                let next = Next::new(move |r: Request| -> BoxFuture<'static, Response> {
                    downstream(r)
                });
                layer.handle(req, next).await
            })
        }) as Link;
    }

    Arc::new(
        move |req: Request| -> BoxFuture<'static, crate::Result<Response>> {
            let chain = chain.clone();
            Box::pin(async move { Ok(chain(req).await) })
        },
    )
}
/// Replaces every handler in the subtree rooted at `node` with a copy wrapped
/// in `middleware` (see `wrap_handler_with_middleware`), recursing into all children.
fn wrap_tree_handlers(node: &mut RadixNode, middleware: Vec<Arc<dyn Middleware>>) {
    let rewrapped = DashMap::new();
    for entry in node.handlers.iter() {
        let wrapped = wrap_handler_with_middleware(entry.value().clone(), middleware.clone());
        rewrapped.insert(entry.key().clone(), wrapped);
    }
    node.handlers = rewrapped;
    for child in &mut node.children {
        wrap_tree_handlers(child, middleware.clone());
    }
}
/// Normalizes a route path: ensures a leading `/` and strips a single
/// trailing `/`, except for the root path `/` itself.
#[inline]
fn normalize_path(path: &str) -> String {
    let mut out = String::with_capacity(path.len() + 1);
    if !path.starts_with('/') {
        out.push('/');
    }
    out.push_str(path);
    // Length > 1 means `out` is not just "/", so a trailing slash may go.
    if out.len() > 1 && out.ends_with('/') {
        out.pop();
    }
    out
}
impl Clone for Router {
    /// Copies the route definitions into a new, independent router.
    ///
    /// The compiled snapshot is shared with the clone together with the `dirty`
    /// flag it belongs to. Pairing a copied `dirty == false` with an empty
    /// compiled router would make the clone answer every request with
    /// `NotFound`, because `ensure_compiled` trusts the flag and never
    /// recompiles.
    fn clone(&self) -> Self {
        let inner = self.inner.read();
        // Loaded under the same read lock as `dirty`, so the pair stays consistent.
        let compiled = self.compiled.load_full();
        Self {
            inner: Arc::new(RwLock::new(RouterInner {
                radix_router: inner.radix_router.clone(),
                middleware: inner.middleware.clone(),
                not_found_handler: inner.not_found_handler.clone(),
                nested_routers: inner.nested_routers.clone(),
                dirty: inner.dirty,
                extensions: inner.extensions.clone(),
                #[cfg(feature = "websocket")]
                websocket_routes: inner.websocket_routes.clone(),
            })),
            compiled: ArcSwap::new(compiled),
        }
    }
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}