#[cfg(feature = "websocket")]
use crate::ws::{WebSocketReceiver, WebSocketSender};
use crate::{
context::{Context, HandlerMetadata, RouteId, State},
controller::{Controller, ControllerHandler},
error::Error,
handler::Handler,
request::Request,
resolver::{Resolver, ResolverResult},
responder::Responder,
response::Response,
Result,
};
use async_trait::async_trait;
use std::{collections::HashMap, sync::Arc};
/// A link in the chain of handler tables that a [`Router`] dispatches through.
///
/// The implementations in this file are [`HandlerChain`] (plain handlers) and
/// [`ControllerChain`] (a controller's handlers layered over another `Route`).
#[async_trait]
pub trait Route {
/// Runs the handler registered for `resolver_id` and the method of `req`.
///
/// The implementations in this file return `None` when neither they nor the
/// route they delegate to hold a handler for that pair. The receiver is
/// `&'static self`; [`Router::dispatch`] is the caller that supplies it.
async fn dispatch(&'static self, resolver_id: u64, req: Request) -> Option<Response>;
/// Registers `handler` under the `(handler_id, method)` pair.
fn add_handler(
&mut self,
handler_id: u64,
method: &'static str,
handler: Box<dyn Handler + Send + Sync>,
);
}
/// Builder that collects routes and is turned into a [`Router`] by `build`.
///
/// `T` is the [`Route`] chain that stores the registered handlers. The default
/// builder uses [`HandlerChain`]; each call to `controller` wraps the current
/// chain in a [`ControllerChain`].
pub struct RouterBuilder<'a, T>
where
T: Route + Send + Sync + Unpin + 'static,
{
/// Prepended, followed by `/`, to every path registered through `add`.
prefix: &'a str,
/// One resolver per distinct path string, keyed by that string.
resolver: HashMap<String, Resolver>,
/// Handler storage; receives every handler registered through `add`.
route: T,
}
impl<'a> Default for RouterBuilder<'a, HandlerChain> {
#[inline]
fn default() -> Self {
Self {
prefix: "",
resolver: Default::default(),
route: HandlerChain {
handlers: Default::default(),
},
}
}
}
/// Generates one builder method per HTTP verb.
///
/// Each `fn name = "VERB";` entry expands to
/// `fn name<H>(self, path: &str, handler: H) -> Self`, whose body forwards to
/// `self.add(path, "VERB", handler)`. Attributes and the visibility written on
/// an entry are copied onto the generated method.
macro_rules! method {
    ($($(#[$attr:meta])* $vis:vis fn $name:ident = $verb:expr;)+) => {
        $(
            $(#[$attr])*
            $vis fn $name<H>(self, path: &str, handler: H) -> Self
            where
                H: 'static + Handler + Send + Sync,
            {
                self.add(path, $verb, handler)
            }
        )+
    };
}
impl<'a, T> RouterBuilder<'a, T>
where
    T: 'static + Route + Send + Sync + Unpin,
{
    /// Registers `handler` for `method` on `path`.
    ///
    /// The resolver key is `"{prefix}/{path}"`. If a resolver already exists
    /// for that key, `method` is added to it; otherwise a new resolver is
    /// created. The handler is then stored in the route chain under the
    /// resolver's id.
    ///
    /// # Panics
    ///
    /// Panics with "Unable to construct resolver" when `Resolver::new` fails.
    #[inline]
    fn add<H>(mut self, path: &str, method: &'static str, handler: H) -> Self
    where
        H: 'static + Handler + Send + Sync,
    {
        // Build the key once and hand it to the map: the entry API needs a
        // single hash lookup and no second copy of the string.
        let full_path = format!("{}/{}", self.prefix, path);
        let handler_id = match self.resolver.entry(full_path) {
            std::collections::hash_map::Entry::Occupied(mut slot) => {
                let resolver = slot.get_mut();
                // The outcome of adding the method is deliberately ignored,
                // exactly as before.
                let _ = resolver.add_method(method);
                resolver.id()
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                let resolver =
                    Resolver::new(slot.key(), method).expect("Unable to construct resolver");
                let resolver_id = resolver.id();
                slot.insert(resolver);
                resolver_id
            }
        };
        self.route
            .add_handler(handler_id, method, Box::new(handler));
        self
    }

    /// Returns a fresh builder whose routes are registered under `prefix`.
    ///
    /// NOTE(review): `self` is consumed and none of its resolvers or handlers
    /// are carried into the returned builder, so routes registered before
    /// `group` are dropped. Confirm this is intended.
    #[inline]
    pub fn group(self, prefix: &str) -> RouterBuilder<HandlerChain> {
        RouterBuilder {
            prefix,
            resolver: Default::default(),
            route: HandlerChain {
                handlers: Default::default(),
            },
        }
    }

    method![
        /// Registers `handler` for `GET` requests on `path`.
        #[inline]
        pub fn get = "GET";
        /// Registers `handler` for `POST` requests on `path`.
        #[inline]
        pub fn post = "POST";
        /// Registers `handler` for `OPTIONS` requests on `path`.
        #[inline]
        pub fn options = "OPTIONS";
        /// Registers `handler` for `PUT` requests on `path`.
        #[inline]
        pub fn put = "PUT";
        /// Registers `handler` for `DELETE` requests on `path`.
        #[inline]
        pub fn delete = "DELETE";
        /// Registers `handler` for `HEAD` requests on `path`.
        #[inline]
        pub fn head = "HEAD";
        /// Registers `handler` for `TRACE` requests on `path`.
        #[inline]
        pub fn trace = "TRACE";
        /// Registers `handler` for `CONNECT` requests on `path`.
        #[inline]
        pub fn connect = "CONNECT";
        /// Registers `handler` for `PATCH` requests on `path`.
        #[inline]
        pub fn patch = "PATCH";
    ];

    /// Registers a `crate::static_dir::StaticDir` built from `dir_path` as the
    /// `GET` handler for `uri_path`.
    #[cfg(feature = "static_file")]
    #[inline]
    pub fn static_dir(self, uri_path: &str, dir_path: impl Into<std::path::PathBuf>) -> Self {
        self.add(
            uri_path,
            "GET",
            crate::static_dir::StaticDir::new(dir_path),
        )
    }

    /// Registers a `crate::fs::NamedFileBuilder` built from `dir_path` as the
    /// `GET` handler for `uri_path`.
    ///
    /// NOTE(review): despite the parameter name, `dir_path` presumably points
    /// at a single file — confirm against `NamedFileBuilder`.
    #[cfg(feature = "static_file")]
    #[inline]
    pub fn static_file(self, uri_path: &str, dir_path: impl Into<std::path::PathBuf>) -> Self {
        self.add(
            uri_path,
            "GET",
            crate::fs::NamedFileBuilder::new(dir_path),
        )
    }

    /// Registers a WebSocket handler as the `GET` handler for `path`.
    ///
    /// `handler` is wrapped with `crate::ws::new_ws` before registration.
    #[cfg(feature = "websocket")]
    #[inline]
    pub fn ws<H, F>(self, path: &str, handler: H) -> Self
    where
        H: Send + Sync + 'static + Fn(Request, WebSocketSender, WebSocketReceiver) -> F,
        F: std::future::Future<Output = Result<()>> + Send + 'static,
    {
        self.add(path, "GET", crate::ws::new_ws(handler))
    }

    /// Registers every handler exposed by `controller` and wraps the current
    /// route chain in a [`ControllerChain`] that owns the controller.
    ///
    /// Each handler's path is `C::BASE_PATH` followed by its sub-path; the
    /// builder prefix is not applied. Handlers that carry a name get a
    /// `HandlerMetadata` with that name attached to their resolver.
    ///
    /// NOTE(review): the returned builder's prefix is reset to `""`, so routes
    /// chained after `controller` no longer use the previous prefix. Confirm
    /// this is intended.
    ///
    /// # Panics
    ///
    /// Panics with "Unable to construct resolver" when
    /// `Resolver::new_with_metadata` fails.
    #[inline]
    pub fn controller<C>(mut self, controller: C) -> RouterBuilder<'a, ControllerChain<C, T>>
    where
        C: Controller + Send + Sync + Unpin,
    {
        let mut handlers = HashMap::new();
        for (name, method, sub_path, handler) in controller.method() {
            let path = format!("{}{}", C::BASE_PATH, sub_path);
            let meta = name.map(|name| HandlerMetadata {
                route_id: Default::default(),
                name: Some(name),
            });
            // Single lookup: reuse the resolver for this path or create it.
            let handler_id = match self.resolver.entry(path) {
                std::collections::hash_map::Entry::Occupied(mut slot) => {
                    let resolver = slot.get_mut();
                    let _ = resolver.add_method_with_metadata(method, meta);
                    resolver.id()
                }
                std::collections::hash_map::Entry::Vacant(slot) => {
                    let resolver = Resolver::new_with_metadata(slot.key(), method, meta)
                        .expect("Unable to construct resolver");
                    let resolver_id = resolver.id();
                    slot.insert(resolver);
                    resolver_id
                }
            };
            handlers.insert((handler_id, method), handler);
        }
        RouterBuilder {
            prefix: "",
            resolver: self.resolver,
            route: ControllerChain {
                controller,
                handlers,
                route: self.route,
            },
        }
    }

    /// Finalizes the builder into a [`Router`].
    ///
    /// The resolvers are taken out of the map and sorted with `sort_unstable`
    /// (using `Resolver`'s own ordering); the prefix is discarded.
    #[inline]
    pub(crate) fn build(self) -> Router {
        let RouterBuilder {
            resolver,
            route: controllers,
            ..
        } = self;
        let mut resolvers = resolver.into_values().collect::<Vec<_>>();
        resolvers.sort_unstable();
        Router {
            inner: Arc::new((resolvers, Box::new(controllers))),
        }
    }
}
/// Routing table produced by `RouterBuilder::build`.
///
/// The only field is an `Arc`, so clones share the same resolvers and handlers.
#[doc(hidden)]
#[derive(Clone)]
pub struct Router {
// `.0`: resolvers, in the sorted order produced by `RouterBuilder::build`.
// `.1`: the route chain that owns the handlers.
inner: Arc<(Vec<Resolver>, Box<dyn Route + Send + Sync + Unpin>)>,
}
#[allow(dead_code)]
impl<'a> Router {
    /// Returns an empty [`RouterBuilder`].
    #[inline]
    pub fn builder() -> RouterBuilder<'a, HandlerChain> {
        Default::default()
    }

    /// Finds the resolver that matches `req`.
    ///
    /// Resolvers are tried in stored order and the id of the first match is
    /// returned. Without a match the error is `405` if at least one resolver
    /// reported `MethodNotAllowed`, and `404` otherwise.
    #[inline]
    pub fn resolve(&self, req: &mut Request) -> Result<u64, u16> {
        let mut wrong_method_seen = false;
        for candidate in self.inner.0.iter() {
            match candidate.resolve(req) {
                ResolverResult::Match(_) => return Ok(candidate.id()),
                ResolverResult::MethodNotAllowed => wrong_method_seen = true,
                ResolverResult::InvalidPath => {}
            }
        }
        Err(if wrong_method_seen { 405 } else { 404 })
    }

    /// Like [`Router::resolve`], but returns handler metadata.
    ///
    /// Yields a clone of the first matching resolver's metadata. Without a
    /// match it yields `HandlerMetadata::not_allowed()` if at least one
    /// resolver reported `MethodNotAllowed`, and `HandlerMetadata::not_found()`
    /// otherwise.
    #[inline]
    pub fn resolve_metadata(&self, req: &mut Request) -> HandlerMetadata {
        let mut wrong_method_seen = false;
        for candidate in self.inner.0.iter() {
            match candidate.resolve(req) {
                ResolverResult::Match(meta) => return meta.clone(),
                ResolverResult::MethodNotAllowed => wrong_method_seen = true,
                ResolverResult::InvalidPath => {}
            }
        }
        if wrong_method_seen {
            HandlerMetadata::not_allowed()
        } else {
            HandlerMetadata::not_found()
        }
    }

    /// Runs the handler selected by `ctx.metadata.route_id` and stores the
    /// response in `ctx.state` as `State::After`.
    ///
    /// The request is first taken out of `ctx.state`; a failure there is
    /// returned as the error. A `RouteId::Error` is turned directly into a
    /// response without calling any handler. When the route chain has no
    /// handler for the id, a `404` response is produced.
    #[inline]
    pub async fn dispatch(&self, mut ctx: Context) -> Result<Context, Error> {
        let req = ctx.state.take_request()?;
        // SAFETY: `Route::dispatch` demands `&'static self`, so the borrow is
        // lifetime-extended here. The future it returns is awaited in place,
        // inside this function, while the original `&self` borrow is still live.
        // NOTE(review): this is only sound if no handler keeps the `'static`
        // reference (or anything borrowed through it) past that await; nothing
        // visible here enforces that — confirm.
        let static_self = unsafe { std::mem::transmute::<&'_ Self, &'static Self>(self) };
        let builder = Response::new(crate::body::OutgoingBody::Empty);
        let route_id = match ctx.metadata.route_id {
            RouteId::Error(e) => {
                // Routing already failed: answer with the error's response.
                ctx.state = State::After(Box::new(e.response(builder)));
                return Ok(ctx);
            }
            RouteId::Id(id) => id,
        };
        let res = match static_self.inner.1.dispatch(route_id, req).await {
            Some(responder) => responder.response(builder),
            None => 404.response(builder),
        };
        ctx.state = State::After(Box::new(res));
        Ok(ctx)
    }
}
/// [`Route`] that stores plain handlers and does not delegate to another route.
#[doc(hidden)]
pub struct HandlerChain {
// Keyed by the resolver id paired with the method string the handler was
// registered under.
handlers: HashMap<(u64, &'static str), Box<dyn Handler + Send + Sync>>,
}
#[async_trait]
impl Route for HandlerChain {
    /// Runs the handler stored for `(resolver_id, request method)`, or returns
    /// `None` when there is none.
    #[inline]
    async fn dispatch(&'static self, resolver_id: u64, req: Request) -> Option<Response> {
        // The method is cloned so the lookup key does not borrow `req`, which
        // is moved into the handler below.
        match self.handlers.get(&(resolver_id, req.method().clone().as_str())) {
            Some(handler) => Some(handler.handle(req).await),
            None => None,
        }
    }

    /// Stores `handler` under `(handler_id, method)`, replacing any handler
    /// previously stored under the same key.
    #[inline]
    fn add_handler(
        &mut self,
        handler_id: u64,
        method: &'static str,
        handler: Box<dyn Handler + Send + Sync>,
    ) {
        let key = (handler_id, method);
        self.handlers.insert(key, handler);
    }
}
/// [`Route`] layer that owns a controller `C` together with its handlers and
/// falls back to the wrapped route `T` when it has no handler for a request.
#[doc(hidden)]
pub struct ControllerChain<C, T: Route> {
// Controller instance, passed by reference to each of its handlers.
controller: C,
// Next route in the chain; also receives every handler given to `add_handler`.
route: T,
// Controller handlers keyed by the resolver id paired with the method string.
handlers: HashMap<(u64, &'static str), Box<dyn ControllerHandler<C> + Send + Sync>>,
}
#[async_trait]
impl<C: Sync + Send, Rest: Route + Sync + Send> Route for ControllerChain<C, Rest> {
    /// Runs the controller handler stored for `(resolver_id, request method)`;
    /// when there is none, the request is passed on to the wrapped route.
    #[inline]
    async fn dispatch(&'static self, resolver_id: u64, req: Request) -> Option<Response> {
        // The method is cloned so the lookup key does not borrow `req`, which
        // is moved into whichever branch handles it.
        match self.handlers.get(&(resolver_id, req.method().clone().as_str())) {
            Some(handler) => Some(handler.handle(&self.controller, req).await),
            None => self.route.dispatch(resolver_id, req).await,
        }
    }

    /// Forwards plain handlers to the wrapped route; this layer's own table
    /// only holds controller handlers.
    #[inline]
    fn add_handler(
        &mut self,
        handler_id: u64,
        method: &'static str,
        handler: Box<dyn Handler + Send + Sync>,
    ) {
        let inner = &mut self.route;
        inner.add_handler(handler_id, method, handler);
    }
}