#![allow(clippy::type_complexity)]
use crate::{
body::Body,
controller::{Controller, DynControllerHandler},
error::SaphirError,
guard::{Builder as GuardBuilder, GuardChain, GuardChainEnd},
handler::DynHandler,
http_context::{HandlerMetadata, HttpContext, RouteId, State},
request::Request,
responder::{DynResponder, Responder},
utils::{EndpointResolver, EndpointResolverResult},
};
use futures::{future::BoxFuture, FutureExt};
use http::Method;
use std::{collections::HashMap, sync::Arc};
/// Builder that collects plain routes, guarded routes and controllers before
/// producing a [`Router`] via `build`.
///
/// `Chain` is the statically-typed controller chain: `Builder::default` starts
/// from a `RouterChainEnd`, and every call to `controller` wraps the current
/// chain in another `RouterChainLink`.
pub struct Builder<Chain: RouterChain + Send + Unpin + 'static + Sync> {
    /// Endpoint resolvers keyed by the full route string they were registered with.
    resolver: HashMap<String, EndpointResolver>,
    /// Controller/handler chain that requests are dispatched through.
    chain: Chain,
}
impl Default for Builder<RouterChainEnd> {
fn default() -> Self {
Self {
resolver: Default::default(),
chain: RouterChainEnd { handlers: Default::default() },
}
}
}
impl<Controllers: 'static + RouterChain + Unpin + Send + Sync> Builder<Controllers> {
    /// Returns the id of the endpoint resolver for `route` after registering `method` on it.
    ///
    /// An existing resolver for the exact `route` string is reused. Otherwise a new
    /// resolver is constructed and stored under `route`. This is shared by `route`
    /// and `route_with_guards` so both register endpoints identically.
    ///
    /// # Panics
    ///
    /// Panics if `EndpointResolver::new` fails for `route`.
    fn register_endpoint(&mut self, route: &str, method: &Method) -> u64 {
        if let Some(er) = self.resolver.get_mut(route) {
            er.add_method(method.clone());
            er.id()
        } else {
            let er = EndpointResolver::new(route, method.clone()).expect("Unable to construct endpoint resolver");
            let er_id = er.id();
            self.resolver.insert(route.to_string(), er);
            er_id
        }
    }

    /// Registers `handler` for `method` requests on `route`, without any guards.
    ///
    /// # Panics
    ///
    /// Panics if a new endpoint resolver cannot be constructed for `route`.
    pub fn route<H>(mut self, route: &str, method: Method, handler: H) -> Self
    where
        H: 'static + DynHandler<Body> + Send + Sync,
    {
        let endpoint_id = self.register_endpoint(route, &method);
        self.chain
            .add_handler(endpoint_id, method, Box::new(handler), GuardBuilder::default().build());
        self
    }

    /// Registers `handler` for `method` requests on `route`, protected by the guard
    /// chain that `guards` builds on top of an empty guard builder.
    ///
    /// # Panics
    ///
    /// Panics if a new endpoint resolver cannot be constructed for `route`.
    pub fn route_with_guards<H, F, Chain>(mut self, route: &str, method: Method, handler: H, guards: F) -> Self
    where
        H: 'static + DynHandler<Body> + Send + Sync,
        F: FnOnce(GuardBuilder<GuardChainEnd>) -> GuardBuilder<Chain>,
        Chain: GuardChain + 'static,
    {
        let endpoint_id = self.register_endpoint(route, &method);
        self.chain
            .add_handler(endpoint_id, method, Box::new(handler), guards(GuardBuilder::default()).build());
        self
    }

    /// Registers every handler exposed by `controller` and pushes the controller onto
    /// the chain.
    ///
    /// Each handler's route is `C::BASE_PATH` immediately followed by its subroute.
    /// Handlers that report a name get `HandlerMetadata` carrying that name. The
    /// returned builder's chain type grows by one `RouterChainLink`.
    ///
    /// # Panics
    ///
    /// Panics if an endpoint resolver cannot be constructed for one of the routes.
    pub fn controller<C: Controller + Send + Unpin + Sync>(mut self, controller: C) -> Builder<RouterChainLink<C, Controllers>> {
        let mut handlers = HashMap::new();
        for (name, method, subroute, handler, guard_chain) in controller.handlers() {
            // Plain concatenation: no separator is inserted between base path and subroute.
            let route = format!("{}{}", C::BASE_PATH, subroute);
            let meta = name.map(|name| HandlerMetadata {
                route_id: Default::default(),
                name: Some(name),
            });
            let endpoint_id = if let Some(er) = self.resolver.get_mut(&route) {
                er.add_method_with_metadata(method.clone(), meta);
                er.id()
            } else {
                let er = EndpointResolver::new_with_metadata(&route, method.clone(), meta).expect("Unable to construct endpoint resolver");
                let er_id = er.id();
                self.resolver.insert(route, er);
                er_id
            };
            handlers.insert((endpoint_id, method), (handler, guard_chain));
        }
        Builder {
            resolver: self.resolver,
            chain: RouterChainLink {
                controller,
                handlers,
                rest: self.chain,
            },
        }
    }

    /// Consumes the builder and produces the shared [`Router`].
    ///
    /// Resolvers are sorted (by their `Ord` impl) so `Router::resolve` tries them
    /// in that order rather than in `HashMap` iteration order.
    pub(crate) fn build(self) -> Router {
        let Builder { resolver, chain: controllers } = self;
        let mut resolvers: Vec<_> = resolver.into_values().collect();
        resolvers.sort_unstable();
        Router {
            inner: Arc::new(RouterInner {
                resolvers,
                chain: Box::new(controllers),
            }),
        }
    }
}
/// Routing state built once by `Builder::build` and shared (through an `Arc`)
/// by every clone of a [`Router`].
struct RouterInner {
    /// Endpoint resolvers, sorted by `Builder::build` and tried in order by `Router::resolve`.
    resolvers: Vec<EndpointResolver>,
    /// Type-erased controller/handler chain that `Router::dispatch` forwards to.
    chain: Box<dyn RouterChain + Send + Unpin + Sync>,
}
/// Handle to the built routing table.
///
/// Cloning copies only the `Arc`, so all clones share the same `RouterInner`.
#[doc(hidden)]
#[derive(Clone)]
pub struct Router {
    inner: Arc<RouterInner>,
}
impl Router {
    /// Returns an empty [`Builder`] with no routes or controllers registered.
    pub fn builder() -> Builder<RouterChainEnd> {
        Builder::default()
    }

    /// Finds the endpoint for `req` and returns its resolver id.
    ///
    /// Resolvers are tried in stored order, and the first `Match` wins.
    ///
    /// # Errors
    ///
    /// Returns `Err(405)` if any resolver reported `MethodNotAllowed` and none matched.
    /// Otherwise it returns `Err(404)`.
    pub fn resolve(&self, req: &mut Request<Body>) -> Result<u64, u16> {
        let mut method_not_allowed = false;
        for endpoint_resolver in &self.inner.resolvers {
            match endpoint_resolver.resolve(req) {
                EndpointResolverResult::InvalidPath => continue,
                EndpointResolverResult::MethodNotAllowed => method_not_allowed = true,
                EndpointResolverResult::Match(_) => return Ok(endpoint_resolver.id()),
            }
        }
        if method_not_allowed {
            Err(405)
        } else {
            Err(404)
        }
    }

    /// Like [`Router::resolve`], but returns a clone of the matched handler metadata.
    ///
    /// Falls back to `HandlerMetadata::not_allowed()` when a resolver reported
    /// `MethodNotAllowed`. Otherwise it falls back to `HandlerMetadata::not_found()`.
    pub fn resolve_metadata(&self, req: &mut Request) -> HandlerMetadata {
        let mut method_not_allowed = false;
        for endpoint_resolver in &self.inner.resolvers {
            match endpoint_resolver.resolve(req) {
                EndpointResolverResult::InvalidPath => continue,
                EndpointResolverResult::MethodNotAllowed => method_not_allowed = true,
                EndpointResolverResult::Match(meta) => return meta.clone(),
            }
        }
        if method_not_allowed {
            HandlerMetadata::not_allowed()
        } else {
            HandlerMetadata::not_found()
        }
    }

    /// Runs the handler for the route id stored in `ctx.metadata` and stores the
    /// built response in `ctx.state` as `State::After`.
    ///
    /// If `ctx.metadata.route_id` is an error, that error is turned into the response
    /// instead. If no chain link handles the route, a 404 response is produced.
    ///
    /// # Errors
    ///
    /// Returns `SaphirError::RequestMovedBeforeHandler` if the request was already taken
    /// out of `ctx.state`. Also propagates any error from building the response.
    pub async fn dispatch(&self, mut ctx: HttpContext) -> Result<HttpContext, SaphirError> {
        let req = ctx.state.take_request().ok_or(SaphirError::RequestMovedBeforeHandler)?;
        let b = crate::response::Builder::new();
        let route_id = match ctx.metadata.route_id {
            RouteId::Id(id) => id,
            RouteId::Error(e) => {
                return e.respond_with_builder(b, &ctx).build().map(|r| {
                    ctx.state = State::After(Box::new(r));
                    ctx
                });
            }
        };

        // SAFETY: `RouterChain::dispatch` requires `&'static self`. The future it
        // returns is awaited right below, inside this async fn, and this fn holds the
        // `&self` borrow for its whole lifetime. If this future is dropped, the inner
        // one is dropped with it, so the extended reference cannot outlive `self`
        // through this path.
        // NOTE(review): this also relies on handlers/guards not retaining the
        // `'static` reference (e.g. by spawning tasks that capture the controller)
        // beyond this await. The type system does not enforce that.
        let static_self = unsafe { std::mem::transmute::<&'_ Self, &'static Self>(self) };

        let res = if let Some(responder) = static_self.inner.chain.dispatch(route_id, req) {
            responder.await.dyn_respond(b, &ctx)
        } else {
            404.respond_with_builder(b, &ctx)
        }
        .build();
        res.map(|r| {
            ctx.state = State::After(Box::new(r));
            ctx
        })
    }
}
/// One link of the statically-typed handler chain built by [`Builder`].
///
/// In this file it is implemented by `RouterChainLink`, which holds one controller's
/// handlers and defers to the rest of the chain, and by `RouterChainEnd`, which holds
/// plain route handlers.
#[doc(hidden)]
pub trait RouterChain {
    /// Looks up the handler registered for `(resolver_id, req.method())`.
    ///
    /// If one is found, returns a boxed future producing its responder, running the
    /// handler's guards first when it has any. Returns `None` when no link handles
    /// the pair.
    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>>;
    /// Registers a plain (non-controller) handler and its guard chain for `(endpoint_id, method)`.
    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>);
}
/// Terminal link of the router chain, holding handlers registered through
/// `Builder::route` / `Builder::route_with_guards`.
#[doc(hidden)]
pub struct RouterChainEnd {
    /// Handler and guard chain, keyed by `(endpoint resolver id, HTTP method)`.
    handlers: HashMap<(u64, Method), (Box<dyn DynHandler<Body> + Send + Sync>, Box<dyn GuardChain>)>,
}
impl RouterChain for RouterChainEnd {
    /// Dispatches to the plain handler stored for `(resolver_id, method)`.
    ///
    /// Returns `None` when nothing is registered for that pair. Handlers without
    /// guards are invoked directly. Otherwise the guard chain validates the request
    /// first, and a rejection response short-circuits the handler.
    #[inline]
    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>> {
        let key = (resolver_id, req.method().clone());
        let (handler, guards) = self.handlers.get(&key)?;

        if guards.is_end() {
            return Some(handler.dyn_handle(req));
        }

        let guarded = guards.validate(req).then(move |outcome| async move {
            match outcome {
                Ok(req) => handler.dyn_handle(req).await,
                Err(resp) => resp,
            }
        });
        Some(guarded.boxed())
    }

    /// Stores `handler` and `guards` under `(endpoint_id, method)`.
    #[inline]
    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>) {
        let key = (endpoint_id, method);
        self.handlers.insert(key, (handler, guards));
    }
}
/// Router chain link holding one controller and its handlers, followed by the
/// rest of the chain (`rest`).
#[doc(hidden)]
pub struct RouterChainLink<C, Rest: RouterChain> {
    /// Controller instance passed to each of its handlers on dispatch.
    controller: C,
    /// Controller handler and guard chain, keyed by `(endpoint resolver id, HTTP method)`.
    handlers: HashMap<(u64, Method), (Box<dyn DynControllerHandler<C, Body> + Send + Sync>, Box<dyn GuardChain>)>,
    /// Remaining links, consulted when this controller has no handler for a request.
    rest: Rest,
}
impl<C: Sync + Send, Rest: RouterChain + Sync + Send> RouterChain for RouterChainLink<C, Rest> {
    /// Dispatches to this controller's handler for `(resolver_id, method)`, or
    /// forwards to the rest of the chain when the controller has none.
    ///
    /// Unguarded handlers run directly. Guarded ones run only after the guard chain
    /// accepts the request, and a rejection response is returned as-is.
    #[inline]
    fn dispatch(&'static self, resolver_id: u64, req: Request<Body>) -> Option<BoxFuture<'static, Box<dyn DynResponder + Send>>> {
        let key = (resolver_id, req.method().clone());
        match self.handlers.get(&key) {
            Some((handler, guards)) if guards.is_end() => Some(handler.dyn_handle(&self.controller, req)),
            Some((handler, guards)) => {
                let controller = &self.controller;
                let guarded = guards.validate(req).then(move |outcome| async move {
                    match outcome {
                        Ok(req) => handler.dyn_handle(controller, req).await,
                        Err(resp) => resp,
                    }
                });
                Some(guarded.boxed())
            }
            None => self.rest.dispatch(resolver_id, req),
        }
    }

    /// Plain handlers are not owned by controllers, so registration is forwarded
    /// down the chain.
    #[inline]
    fn add_handler(&mut self, endpoint_id: u64, method: Method, handler: Box<dyn DynHandler<Body> + Send + Sync>, guards: Box<dyn GuardChain>) {
        let rest = &mut self.rest;
        rest.add_handler(endpoint_id, method, handler, guards);
    }
}