#[cfg(feature = "websocket")]
use crate::ws::WebSocket;
use crate::{
context::{Context, HandlerMetadata, RouteId, State},
controller::{Controller, ControllerHandler},
error::Error,
handler::Handler,
request::Request,
resolver::{Resolver, ResolverResult},
responder::Responder,
response::Builder,
Result,
};
use async_trait::async_trait;
use std::{collections::HashMap, sync::Arc};
/// A handler chain that `Router::dispatch` hands matched requests to.
///
/// Chains are assembled by `RouterBuilder`: `HandlerChain` is the innermost
/// link and every `ControllerChain` created by `RouterBuilder::controller`
/// wraps the previous chain.
#[async_trait]
pub trait Route {
    /// Runs the handler registered for `resolver_id` (the id produced by
    /// `Router::resolve`) and the request's method.
    ///
    /// Returns `None` when no handler is registered for that pair;
    /// `Router::dispatch` then answers with a 404.
    ///
    /// NOTE(review): the `&'static self` receiver is why `Router::dispatch`
    /// has to extend its own borrow with an `unsafe` transmute.
    async fn dispatch(&'static self, resolver_id: u64, req: Request) -> Option<Builder>;
    /// Registers `handler` for `method` requests on the resolver identified
    /// by `handler_id`.
    fn add_handler(
        &mut self,
        handler_id: u64,
        method: &'static str,
        handler: Box<dyn Handler + Send + Sync>,
    );
}
pub struct RouterBuilder<T>
where
T: Route + Send + Sync + Unpin + 'static,
{
resolver: HashMap<String, Resolver>,
route: T,
}
impl Default for RouterBuilder<HandlerChain> {
    /// An empty builder: no resolvers yet, and a `HandlerChain` holding no
    /// handlers as the innermost route.
    #[inline]
    fn default() -> Self {
        let route = HandlerChain {
            handlers: HashMap::default(),
        };
        Self {
            route,
            resolver: HashMap::default(),
        }
    }
}
/// Generates one builder method per HTTP verb.
///
/// Each `fn name = VERB;` entry expands to
/// `fn name<H>(self, path: &str, handler: H) -> Self`, which forwards to
/// `self.add(path, VERB, handler)`. Any attributes (including doc comments)
/// written before an entry are copied onto the generated method.
macro_rules! method {
    (
        $(
            $(#[$attr:meta])*
            $vis:vis fn $name:ident = $verb:expr;
        )+
    ) => {
        $(
            $(#[$attr])*
            $vis fn $name<H>(self, path: &str, handler: H) -> Self
            where
                H: 'static + Handler + Send + Sync,
            {
                self.add(path, $verb, handler)
            }
        )+
    };
}
impl<T> RouterBuilder<T>
where
T: 'static + Route + Send + Sync + Unpin,
{
#[inline]
fn add<H>(mut self, path: &str, method: &'static str, handler: H) -> Self
where
H: 'static + Handler + Send + Sync,
{
let handler_id = if let Some(resolver) = self.resolver.get_mut(path) {
let _ = resolver.add_method(method);
resolver.id()
} else {
let resolver = Resolver::new(path, method).expect("Unable to construct resolver");
let resolver_id = resolver.id();
self.resolver.insert(path.to_string(), resolver);
resolver_id
};
self.route
.add_handler(handler_id, method, Box::new(handler));
self
}
method![
#[inline]
pub fn get = "GET";
#[inline]
pub fn post = "POST";
#[inline]
pub fn options = "OPTIONS";
#[inline]
pub fn put = "PUT";
#[inline]
pub fn delete = "DELETE";
#[inline]
pub fn head = "HEAD";
#[inline]
pub fn trace = "TRACE";
#[inline]
pub fn connect = "CONNECT";
#[inline]
pub fn patch = "PATCH";
];
#[cfg(feature = "static_file")]
#[inline]
pub fn static_dir(self, uri_path: &str, dir_path: impl Into<std::path::PathBuf>) -> Self {
self.add(uri_path, "GET", crate::static_dir::StaticDir::new(dir_path))
}
#[cfg(feature = "static_file")]
#[inline]
pub fn static_file(self, uri_path: &str, dir_path: impl Into<std::path::PathBuf>) -> Self {
self.add(uri_path, "GET", crate::fs::NamedFileBuilder::new(dir_path))
}
#[cfg(feature = "websocket")]
#[inline]
pub fn ws<H, F>(self, path: &str, handler: H) -> Self
where
H: Send + Sync + 'static + Fn(Request,WebSocket) -> F,
F: std::future::Future<Output = Result<()>> + Send + 'static,
{
self.add(path, "GET", crate::ws::new_ws(handler))
}
#[inline]
pub fn controller<C>(mut self, controller: C) -> RouterBuilder<ControllerChain<C, T>>
where
C: Controller + Send + Sync + Unpin,
{
let mut handlers = HashMap::new();
for (name, method, sub_path, handler) in controller.method() {
let path = format!("{}{}", C::BASE_PATH, sub_path);
let meta = name.map(|name| HandlerMetadata {
route_id: Default::default(),
name: Some(name),
});
let handler_id = if let Some(resolver) = self.resolver.get_mut(&path) {
let _ = resolver.add_method_with_metadata(method, meta);
resolver.id()
} else {
let resolver = Resolver::new_with_metadata(&path, method, meta)
.expect("Unable to construct resolver");
let resolver_id = resolver.id();
self.resolver.insert(path, resolver);
resolver_id
};
handlers.insert((handler_id, method), handler);
}
RouterBuilder {
resolver: self.resolver,
route: ControllerChain {
controller,
handlers,
route: self.route,
},
}
}
#[inline]
pub(crate) fn build(self) -> Router {
let RouterBuilder {
resolver,
route: controllers,
..
} = self;
let mut resolvers = resolver.into_values().collect::<Vec<_>>();
resolvers.sort_unstable();
Router {
inner: Arc::new((resolvers, Box::new(controllers))),
}
}
}
/// The immutable routing table produced by `RouterBuilder::build`.
///
/// `inner.0` holds the resolvers in the order `build` sorted them; `resolve`
/// tries them in that order and stops at the first match. `inner.1` is the
/// handler chain that requests are dispatched into. Cloning a `Router` only
/// clones the `Arc`, so clones share the same table.
#[doc(hidden)]
#[derive(Clone)]
pub struct Router {
    inner: Arc<(Vec<Resolver>, Box<dyn Route + Send + Sync + Unpin>)>,
}
#[allow(dead_code)]
impl<'a> Router {
#[inline]
pub fn builder() -> RouterBuilder<HandlerChain> {
RouterBuilder::default()
}
#[inline]
pub fn resolve(&self, req: &mut Request) -> Result<u64, u16> {
let mut method_not_allowed = false;
for resolver in &self.inner.0 {
match resolver.resolve(req) {
ResolverResult::InvalidPath => continue,
ResolverResult::MethodNotAllowed => method_not_allowed = true,
ResolverResult::Match(_) => return Ok(resolver.id()),
}
}
if method_not_allowed {
Err(405)
} else {
Err(404)
}
}
#[inline]
pub fn resolve_metadata(&self, req: &mut Request) -> HandlerMetadata {
let mut method_not_allowed = false;
for resolver in &self.inner.0 {
match resolver.resolve(req) {
ResolverResult::InvalidPath => continue,
ResolverResult::MethodNotAllowed => method_not_allowed = true,
ResolverResult::Match(meta) => return meta.clone(),
}
}
if method_not_allowed {
HandlerMetadata::not_allowed()
} else {
HandlerMetadata::not_found()
}
}
#[inline]
pub async fn dispatch(&self, mut ctx: Context) -> Result<Context, Error> {
let req = ctx.state.take_request()?;
let static_self = unsafe { std::mem::transmute::<&'_ Self, &'static Self>(self) };
let builder = Builder::new();
let route_id = match ctx.metadata.route_id {
RouteId::Id(id) => id,
RouteId::Error(e) => {
return e.response(builder).build().map(|r| {
ctx.state = State::After(Box::new(r));
ctx
});
}
};
let res = if let Some(responder) = static_self.inner.1.dispatch(route_id, req).await {
responder.response(builder)
} else {
404.response(builder)
}.build();
res.map(|r| {
ctx.state = State::After(Box::new(r));
ctx
})
}
}
/// Innermost link of the route chain.
///
/// Holds the plain handlers registered through `RouterBuilder`'s per-method
/// helpers, keyed by `(resolver id, HTTP method)`. Its `Route::dispatch`
/// returns `None` instead of delegating when no key matches.
#[doc(hidden)]
pub struct HandlerChain {
    handlers: HashMap<(u64, &'static str), Box<dyn Handler + Send + Sync>>,
}
#[async_trait]
impl Route for HandlerChain {
    /// Runs the handler registered for `(resolver_id, request method)`, or
    /// returns `None` when there is none.
    #[inline]
    async fn dispatch(&'static self, resolver_id: u64, req: Request) -> Option<Builder> {
        // Own a copy of the method so the lookup key (and therefore the
        // handler borrow) does not borrow `req`, which `handle` consumes.
        let method = req.method().clone();
        let key = (resolver_id, method.as_str());
        match self.handlers.get(&key) {
            Some(handler) => Some(handler.handle(req).await),
            None => None,
        }
    }

    /// Stores `handler` under `(endpoint_id, method)`. A later registration
    /// for the same key replaces the earlier handler.
    #[inline]
    fn add_handler(
        &mut self,
        endpoint_id: u64,
        method: &'static str,
        handler: Box<dyn Handler + Send + Sync>,
    ) {
        let key = (endpoint_id, method);
        self.handlers.insert(key, handler);
    }
}
#[doc(hidden)]
pub struct ControllerChain<C, T: Route> {
controller: C,
route: T,
handlers: HashMap<(u64, &'static str), Box<dyn ControllerHandler<C> + Send + Sync>>,
}
#[async_trait]
impl<C: Sync + Send, Rest: Route + Sync + Send> Route for ControllerChain<C, Rest> {
    /// Tries this controller's handlers first and falls back to the wrapped
    /// route when none is registered for `(resolver_id, request method)`.
    #[inline]
    async fn dispatch(&'static self, resolver_id: u64, req: Request) -> Option<Builder> {
        // Own a copy of the method so the handler borrow is independent of
        // `req`, which is moved into whichever branch runs.
        let method = req.method().clone();
        match self.handlers.get(&(resolver_id, method.as_str())) {
            Some(handler) => Some(handler.handle(&self.controller, req).await),
            None => self.route.dispatch(resolver_id, req).await,
        }
    }

    /// Plain (non-controller) handlers are not stored at this level; they
    /// are forwarded to the wrapped route unchanged.
    #[inline]
    fn add_handler(
        &mut self,
        handler_id: u64,
        method: &'static str,
        handler: Box<dyn Handler + Send + Sync>,
    ) {
        let inner = &mut self.route;
        inner.add_handler(handler_id, method, handler);
    }
}