use std::sync::Arc;
use std::sync::Weak;
#[cfg(feature = "plugins")]
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use arc_swap::ArcSwap;
use http::Method;
use http::StatusCode;
use smallvec::SmallVec;
use crate::body::TakoBody;
use crate::extractors::params::PathParams;
use crate::handler::BoxHandler;
use crate::handler::Handler;
use crate::middleware::Next;
#[cfg(feature = "plugins")]
use crate::plugins::TakoPlugin;
use crate::responder::Responder;
use crate::route::Route;
#[cfg(feature = "signals")]
use crate::signals::Signal;
#[cfg(feature = "signals")]
use crate::signals::SignalArbiter;
#[cfg(feature = "signals")]
use crate::signals::ids;
use crate::state::set_state;
use crate::types::BoxMiddleware;
use crate::types::Request;
use crate::types::Response;
/// Callback the [`Router`] applies to responses whose status is a server error
/// (5xx) before returning them from `dispatch`; it receives the response and
/// returns the response to send instead.
#[doc(alias = "router")]
pub type ErrorHandler = Arc<dyn Fn(Response) -> Response + Send + Sync + 'static>;
/// HTTP request router.
///
/// Holds per-method path matchers for registered [`Route`]s, a global
/// middleware chain, optional fallback / timeout handling and an optional
/// error handler for 5xx responses. See `dispatch` for how a request is resolved.
pub struct Router {
/// Per-method path matchers; they hold the strong references to registered routes.
inner: MethodMap<matchit::Router<Arc<Route>>>,
/// Weak references to every registered route, per method, in registration
/// order; used to enumerate routes (`merge`, OpenAPI collection).
routes: MethodMap<Vec<Weak<Route>>>,
/// Global middleware chain, kept in an `ArcSwap` so it can be replaced through `&self`.
pub(crate) middlewares: ArcSwap<Vec<BoxMiddleware>>,
/// Handler used when neither a route nor a trailing-slash redirect matches.
fallback: Option<BoxHandler>,
/// Router-level plugins, set up at most once by `setup_plugins_once`.
#[cfg(feature = "plugins")]
plugins: Vec<Box<dyn TakoPlugin>>,
/// Set to `true` the first time `setup_plugins_once` runs.
#[cfg(feature = "plugins")]
plugins_initialized: AtomicBool,
/// Router-level signal arbiter; `new` also installs it as global state when none exists.
#[cfg(feature = "signals")]
signals: SignalArbiter,
/// Default request timeout; a route's own timeout takes precedence.
pub(crate) timeout: Option<Duration>,
/// Produces the response for timed-out requests; an empty 408 is used when unset.
timeout_fallback: Option<BoxHandler>,
/// Applied to 5xx responses before `dispatch` returns them.
error_handler: Option<ErrorHandler>,
}
impl Default for Router {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl Router {
#[must_use]
pub fn new() -> Self {
let router = Self {
inner: MethodMap::new(),
routes: MethodMap::new(),
middlewares: ArcSwap::new(Arc::default()),
fallback: None,
#[cfg(feature = "plugins")]
plugins: Vec::new(),
#[cfg(feature = "plugins")]
plugins_initialized: AtomicBool::new(false),
#[cfg(feature = "signals")]
signals: SignalArbiter::new(),
timeout: None,
timeout_fallback: None,
error_handler: None,
};
#[cfg(feature = "signals")]
{
if crate::state::get_state::<SignalArbiter>().is_none() {
set_state::<SignalArbiter>(router.signals.clone());
}
}
router
}
/// Registers `handler` for `method` requests to `path`.
///
/// Returns the created [`Route`] so the caller can configure it further.
///
/// # Panics
///
/// Panics if `path` cannot be inserted into the matcher for `method` (the
/// matcher rejects it, e.g. because it conflicts with an existing path).
pub fn route<H, T>(&mut self, method: Method, path: &str, handler: H) -> Arc<Route>
where
  H: Handler<T> + Clone + 'static,
{
  let endpoint = BoxHandler::new::<H, T>(handler);
  let route = Arc::new(Route::new(path.to_owned(), method.clone(), endpoint, None));

  let matcher = self.inner.get_or_default_mut(&method);
  match matcher.insert(path.to_owned(), Arc::clone(&route)) {
    Ok(_) => {}
    Err(err) => panic!("Failed to register route: {err}"),
  }

  // Keep only a weak handle here; the matcher above owns the route.
  let registry = self.routes.get_or_default_mut(&method);
  registry.push(Arc::downgrade(&route));
  route
}
/// Registers `handler` like [`Router::route`], and additionally enables
/// trailing-slash redirects for the route.
///
/// When no route matches a request, but this route matches the request path
/// with its trailing `/` removed (or with one added), `dispatch` answers with
/// a `307 Temporary Redirect` to that adjusted path.
///
/// # Panics
///
/// Panics when `path` is `"/"`, or when it cannot be inserted into the
/// matcher for `method`.
pub fn route_with_tsr<H, T>(&mut self, method: Method, path: &str, handler: H) -> Arc<Route>
where
  H: Handler<T> + Clone + 'static,
{
  assert!(path != "/", "Cannot route with TSR for root path");

  let route = Arc::new(Route::new(
    path.to_owned(),
    method.clone(),
    BoxHandler::new::<H, T>(handler),
    Some(true),
  ));

  let inserted = self
    .inner
    .get_or_default_mut(&method)
    .insert(path.to_owned(), Arc::clone(&route));
  if let Err(err) = inserted {
    panic!("Failed to register route: {err}");
  }

  let registry = self.routes.get_or_default_mut(&method);
  registry.push(Arc::downgrade(&route));
  route
}
/// Runs `endpoint` behind the global middleware chain only (no route-level
/// middleware). Calls it directly when no global middleware is registered.
async fn run_with_global_middlewares_for_endpoint(
  &self,
  req: Request,
  endpoint: BoxHandler,
) -> Response {
  // The load guard is a temporary of this statement, so it is released
  // before anything is awaited.
  let has_globals = !self.middlewares.load().is_empty();
  if !has_globals {
    return endpoint.call(req).await;
  }

  let chain = Next {
    global_middlewares: self.middlewares.load_full(),
    route_middlewares: Arc::default(),
    index: 0,
    endpoint,
  };
  chain.run(req).await
}
/// Runs the middleware/handler chain `next` for `req`.
///
/// With `Some(duration)`, the chain races a timer of that length. If the timer
/// fires first, the chain is abandoned and the response comes from
/// `handle_timeout`. With `None`, the chain runs unbounded.
async fn run_with_timeout(
  &self,
  req: Request,
  next: Next,
  timeout_duration: Option<Duration>,
) -> Response {
  let Some(duration) = timeout_duration else {
    return next.run(req).await;
  };

  #[cfg(not(feature = "compio"))]
  {
    let outcome = tokio::time::timeout(duration, next.run(req)).await;
    match outcome {
      Ok(response) => response,
      Err(_) => self.handle_timeout().await,
    }
  }

  #[cfg(feature = "compio")]
  {
    use futures_util::future::{Either, select};
    let deadline = std::pin::pin!(compio::time::sleep(duration));
    let work = std::pin::pin!(next.run(req));
    match select(work, deadline).await {
      Either::Left((response, _)) => response,
      Either::Right(_) => self.handle_timeout().await,
    }
  }
}
/// Builds the response for a timed-out request.
///
/// Uses the configured timeout fallback, which is invoked with
/// `Request::default()` rather than the original request. Without one, the
/// response is an empty `408 Request Timeout`.
async fn handle_timeout(&self) -> Response {
  match &self.timeout_fallback {
    Some(handler) => handler.call(Request::default()).await,
    None => http::Response::builder()
      .status(StatusCode::REQUEST_TIMEOUT)
      .body(TakoBody::empty())
      .expect("valid 408 response"),
  }
}
pub async fn dispatch(&self, mut req: Request) -> Response {
let method = req.method().clone();
let route_match = {
if let Some(method_router) = self.inner.get(&method)
&& let Ok(matched) = method_router.at(req.uri().path())
{
let route = Arc::clone(matched.value);
let mut it = matched.params.iter();
let first = it.next();
let params = first.map(|(fk, fv)| {
let mut p = SmallVec::<[(String, String); 4]>::new();
p.push((fk.to_string(), fv.to_string()));
for (k, v) in it {
p.push((k.to_string(), v.to_string()));
}
PathParams(p)
});
Some((route, params))
} else {
None
}
};
let response = if let Some((route, params)) = route_match {
if let Some(res) = Self::enforce_protocol_guard(&route, &req) {
return self.maybe_apply_error_handler(res);
}
#[cfg(feature = "signals")]
let route_signals = route.signal_arbiter();
#[cfg(feature = "plugins")]
route.setup_plugins_once();
if let Some(mode) = route.get_simd_json_mode() {
req.extensions_mut().insert(mode);
}
if let Some(params) = params {
req.extensions_mut().insert(params);
}
let effective_timeout = route.get_timeout().or(self.timeout);
let g = self.middlewares.load();
let r = route.middlewares.load();
let needs_chain = !g.is_empty() || !r.is_empty();
drop(g);
drop(r);
#[cfg(feature = "signals")]
{
let method_str = method.to_string();
let path_str = req.uri().path().to_string();
route_signals
.emit(
Signal::with_capacity(ids::ROUTE_REQUEST_STARTED, 2)
.meta("method", method_str.clone())
.meta("path", path_str.clone()),
)
.await;
let response = if !needs_chain && effective_timeout.is_none() {
route.handler.call(req).await
} else {
let next = Next {
global_middlewares: self.middlewares.load_full(),
route_middlewares: route.middlewares.load_full(),
index: 0,
endpoint: route.handler.clone(),
};
self.run_with_timeout(req, next, effective_timeout).await
};
route_signals
.emit(
Signal::with_capacity(ids::ROUTE_REQUEST_COMPLETED, 3)
.meta("method", method_str)
.meta("path", path_str)
.meta("status", response.status().as_u16().to_string()),
)
.await;
response
}
#[cfg(not(feature = "signals"))]
{
if !needs_chain && effective_timeout.is_none() {
route.handler.call(req).await
} else {
let next = Next {
global_middlewares: self.middlewares.load_full(),
route_middlewares: route.middlewares.load_full(),
index: 0,
endpoint: route.handler.clone(),
};
self.run_with_timeout(req, next, effective_timeout).await
}
}
} else {
let tsr_path = {
let p = req.uri().path();
if p.ends_with('/') {
p.trim_end_matches('/').to_string()
} else {
format!("{p}/")
}
};
if let Some(method_router) = self.inner.get(&method)
&& let Ok(matched) = method_router.at(&tsr_path)
&& matched.value.tsr
{
let handler = move |_req: Request| async move {
http::Response::builder()
.status(StatusCode::TEMPORARY_REDIRECT)
.header("Location", tsr_path.clone())
.body(TakoBody::empty())
.expect("valid redirect response")
};
self
.run_with_global_middlewares_for_endpoint(req, BoxHandler::new::<_, (Request,)>(handler))
.await
} else if let Some(handler) = &self.fallback {
self
.run_with_global_middlewares_for_endpoint(req, handler.clone())
.await
} else {
let handler = |_req: Request| async {
http::Response::builder()
.status(StatusCode::NOT_FOUND)
.body(TakoBody::empty())
.expect("valid 404 response")
};
self
.run_with_global_middlewares_for_endpoint(req, BoxHandler::new::<_, (Request,)>(handler))
.await
}
};
self.maybe_apply_error_handler(response)
}
/// Passes a 5xx `response` through the configured error handler.
///
/// Any other response, or any response when no handler is set, is returned
/// unchanged.
fn maybe_apply_error_handler(&self, response: Response) -> Response {
  match &self.error_handler {
    Some(handler) if response.status().is_server_error() => handler(response),
    _ => response,
  }
}
/// Stores `value` through `crate::state::set_state`. This router itself keeps
/// no copy of it.
pub fn state<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
  set_state::<T>(value);
}
/// Borrows this router's signal arbiter.
#[cfg(feature = "signals")]
pub fn signals(&self) -> &SignalArbiter {
  let Self { signals, .. } = self;
  signals
}
/// Returns a clone of this router's signal arbiter.
#[cfg(feature = "signals")]
pub fn signal_arbiter(&self) -> SignalArbiter {
  SignalArbiter::clone(&self.signals)
}
/// Registers `handler` on this router's signal arbiter under `id`.
#[cfg(feature = "signals")]
pub fn on_signal<F, Fut>(&self, id: impl Into<String>, handler: F)
where
  Fut: std::future::Future<Output = ()> + Send + 'static,
  F: Fn(Signal) -> Fut + Send + Sync + 'static,
{
  let arbiter = &self.signals;
  arbiter.on(id, handler);
}
/// Emits `signal` through this router's signal arbiter and waits for the emit
/// to complete.
#[cfg(feature = "signals")]
pub async fn emit_signal(&self, signal: Signal) {
  let arbiter = &self.signals;
  arbiter.emit(signal).await;
}
/// Appends `f` to the global middleware list.
///
/// Whatever `f`'s future resolves to is converted with
/// `Responder::into_response`.
///
/// This takes `&self`, so registrations may race. The list is therefore
/// replaced with `ArcSwap::rcu`, which retries until its compare-and-swap
/// succeeds, so concurrent calls cannot overwrite each other's middleware.
pub fn middleware<F, Fut, R>(&self, f: F) -> &Self
where
  F: Fn(Request, Next) -> Fut + Clone + Send + Sync + 'static,
  Fut: std::future::Future<Output = R> + Send + 'static,
  R: Responder + Send + 'static,
{
  let mw: BoxMiddleware = Arc::new(move |req, next| {
    let fut = f(req, next);
    Box::pin(async move { fut.await.into_response() })
  });

  // A plain load-then-store could lose a middleware registered concurrently.
  self.middlewares.rcu(|current| {
    let mut updated = Vec::with_capacity(current.len() + 1);
    updated.extend(current.iter().cloned());
    updated.push(Arc::clone(&mw));
    Arc::new(updated)
  });
  self
}
/// Sets the handler used when neither a route nor a trailing-slash redirect
/// matches. It runs behind the global middleware only and receives the
/// original request.
pub fn fallback<F, Fut, R>(&mut self, handler: F) -> &mut Self
where
  F: Fn(Request) -> Fut + Clone + Send + Sync + 'static,
  Fut: std::future::Future<Output = R> + Send + 'static,
  R: Responder + Send + 'static,
{
  let boxed = BoxHandler::new::<F, (Request,)>(handler);
  self.fallback = Some(boxed);
  self
}
/// Sets the fallback handler from any `Handler<T>` (for example, one that
/// takes extractors) instead of a plain `Fn(Request)`.
pub fn fallback_with_extractors<H, T>(&mut self, handler: H) -> &mut Self
where
  H: Handler<T> + Clone + 'static,
{
  let boxed = BoxHandler::new::<H, T>(handler);
  self.fallback = Some(boxed);
  self
}
/// Sets the router-wide request timeout. It applies to routes that do not
/// define their own timeout.
pub fn timeout(&mut self, duration: Duration) -> &mut Self {
  self.timeout = duration.into();
  self
}
/// Sets the handler that produces the response when a request times out.
///
/// The handler is invoked with `Request::default()`, not the timed-out
/// request. Without it, timeouts yield an empty `408 Request Timeout`.
pub fn timeout_fallback<F, Fut, R>(&mut self, handler: F) -> &mut Self
where
  F: Fn(Request) -> Fut + Clone + Send + Sync + 'static,
  Fut: std::future::Future<Output = R> + Send + 'static,
  R: Responder + Send + 'static,
{
  let boxed = BoxHandler::new::<F, (Request,)>(handler);
  self.timeout_fallback = Some(boxed);
  self
}
pub fn error_handler(
&mut self,
handler: impl Fn(Response) -> Response + Send + Sync + 'static,
) -> &mut Self {
self.error_handler = Some(Arc::new(handler));
self
}
/// Adds a router-level plugin. Its `setup` runs in `setup_plugins_once`.
#[cfg(feature = "plugins")]
#[cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
pub fn plugin<P>(&mut self, plugin: P) -> &mut Self
where
  P: TakoPlugin + Clone + Send + Sync + 'static,
{
  let boxed: Box<dyn TakoPlugin> = Box::new(plugin);
  self.plugins.push(boxed);
  self
}
/// Borrows every registered router-level plugin, in registration order.
#[cfg(feature = "plugins")]
#[cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
pub(crate) fn plugins(&self) -> Vec<&dyn TakoPlugin> {
  let mut out: Vec<&dyn TakoPlugin> = Vec::with_capacity(self.plugins.len());
  for plugin in &self.plugins {
    out.push(&**plugin);
  }
  out
}
/// Runs `setup` for every router-level plugin, only on the first call.
///
/// The `plugins_initialized` flag is flipped atomically before any setup runs,
/// so later calls return immediately. Results returned by `setup` are
/// discarded.
#[cfg(feature = "plugins")]
#[cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
pub(crate) fn setup_plugins_once(&self) {
  use std::sync::atomic::Ordering;
  let already_initialized = self.plugins_initialized.swap(true, Ordering::SeqCst);
  if already_initialized {
    return;
  }
  for plugin in self.plugins() {
    let _ = plugin.setup(self);
  }
}
pub fn merge(&mut self, other: Router) {
let upstream_globals = other.middlewares.load_full();
for (method, weak_vec) in other.routes.iter() {
for weak in weak_vec {
if let Some(route) = weak.upgrade() {
let existing = route.middlewares.load_full();
let mut merged = Vec::with_capacity(upstream_globals.len() + existing.len());
merged.extend(upstream_globals.iter().cloned());
merged.extend(existing.iter().cloned());
route.middlewares.store(Arc::new(merged));
let _ = self
.inner
.get_or_default_mut(&method)
.insert(route.path.clone(), route.clone());
self
.routes
.get_or_default_mut(&method)
.push(Arc::downgrade(&route));
}
}
}
#[cfg(feature = "signals")]
self.signals.merge_from(&other.signals);
}
/// Checks the route's HTTP-version restriction, if it has one.
///
/// Returns an empty `505 HTTP Version Not Supported` response when the request
/// uses a different version. Returns `None` when the request may proceed.
fn enforce_protocol_guard(route: &Route, req: &Request) -> Option<Response> {
  let required = route.protocol_guard()?;
  if required == req.version() {
    return None;
  }
  let rejection = http::Response::builder()
    .status(StatusCode::HTTP_VERSION_NOT_SUPPORTED)
    .body(TakoBody::empty())
    .expect("valid HTTP version not supported response");
  Some(rejection)
}
/// Lists `(method, path, metadata)` for every still-alive route that carries
/// OpenAPI metadata.
///
/// Entries are ordered by method slot (standard methods first, then custom
/// ones) and then by registration order.
#[cfg(any(feature = "utoipa", feature = "vespera"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "utoipa", feature = "vespera"))))]
pub fn collect_openapi_routes(&self) -> Vec<(Method, String, crate::openapi::RouteOpenApi)> {
  self
    .routes
    .iter()
    .flat_map(|(method, weak_vec)| {
      weak_vec
        .iter()
        .filter_map(Weak::upgrade)
        .filter_map(move |route| {
          let openapi = route.openapi_metadata()?;
          Some((method.clone(), route.path.clone(), openapi))
        })
    })
    .collect()
}
}
/// Maps one of the nine standard HTTP methods to its fixed index in
/// `MethodMap::standard`. Returns `None` for any other (extension) method.
///
/// Must stay in sync with `method_from_slot`.
#[inline]
fn method_slot(method: &Method) -> Option<usize> {
  match *method {
    Method::GET => Some(0),
    Method::POST => Some(1),
    Method::PUT => Some(2),
    Method::DELETE => Some(3),
    Method::PATCH => Some(4),
    Method::HEAD => Some(5),
    Method::OPTIONS => Some(6),
    Method::CONNECT => Some(7),
    Method::TRACE => Some(8),
    _ => None,
  }
}
/// Inverse of `method_slot`: returns the standard method stored at `idx`.
///
/// # Panics
///
/// Panics (unreachable) when `idx` is not a valid slot (`idx >= 9`).
#[inline]
fn method_from_slot(idx: usize) -> Method {
  // Index order must mirror `method_slot`.
  static BY_SLOT: [Method; 9] = [
    Method::GET,
    Method::POST,
    Method::PUT,
    Method::DELETE,
    Method::PATCH,
    Method::HEAD,
    Method::OPTIONS,
    Method::CONNECT,
    Method::TRACE,
  ];
  match BY_SLOT.get(idx) {
    Some(method) => method.clone(),
    None => unreachable!(),
  }
}
/// Map keyed by HTTP method. The nine standard methods live in a fixed array
/// indexed by `method_slot`; any other method lives in a linearly scanned `Vec`.
struct MethodMap<V> {
/// Values for standard methods, indexed by `method_slot`.
standard: [Option<V>; 9],
/// Values for non-standard methods, in first-insertion order.
custom: Vec<(Method, V)>,
}
impl<V> MethodMap<V> {
  /// Creates a map with no entries.
  fn new() -> Self {
    Self {
      standard: Default::default(),
      custom: Vec::new(),
    }
  }

  /// Returns the value stored for `method`, if any.
  #[inline]
  fn get(&self, method: &Method) -> Option<&V> {
    match method_slot(method) {
      Some(idx) => self.standard[idx].as_ref(),
      None => self
        .custom
        .iter()
        .find_map(|(candidate, value)| (candidate == method).then_some(value)),
    }
  }

  /// Returns a mutable reference to the value for `method`, inserting
  /// `V::default()` first when the method has no entry yet.
  fn get_or_default_mut(&mut self, method: &Method) -> &mut V
  where
    V: Default,
  {
    if let Some(idx) = method_slot(method) {
      return self.standard[idx].get_or_insert_with(V::default);
    }
    let pos = match self.custom.iter().position(|(candidate, _)| candidate == method) {
      Some(found) => found,
      None => {
        self.custom.push((method.clone(), V::default()));
        self.custom.len() - 1
      }
    };
    &mut self.custom[pos].1
  }

  /// Iterates over `(method, value)` pairs: occupied standard slots in slot
  /// order first, then custom methods in insertion order.
  fn iter(&self) -> impl Iterator<Item = (Method, &V)> {
    let standard = self.standard.iter().enumerate().filter_map(|(idx, slot)| {
      let value = slot.as_ref()?;
      Some((method_from_slot(idx), value))
    });
    let custom = self
      .custom
      .iter()
      .map(|(method, value)| (method.clone(), value));
    standard.chain(custom)
  }
}