use crate::path::clean;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::str;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::{future, ready};
use hyper::service::Service;
use hyper::{header, Body, Method, Request, Response, StatusCode};
use matchit::{Match, Node};
/// An HTTP request router that dispatches on request method and path.
///
/// Construct it with `Router::default()`, register routes with the chaining methods
/// (`handle`, `get`, `post`, ...), then convert it into a hyper service with `into_service`.
pub struct Router<'path> {
    /// One routing tree per HTTP method; `'path` is the lifetime of the registered route strings.
    trees: HashMap<Method, Node<'path, Box<dyn Handler>>>,
    /// On a miss, redirect when the tree reports a route with/without a trailing slash.
    redirect_trailing_slash: bool,
    /// On a miss, try a cleaned, case-insensitive match and redirect to it if found.
    redirect_fixed_path: bool,
    /// Answer 405 Method Not Allowed when the path is routed under some other method.
    handle_method_not_allowed: bool,
    /// Answer `OPTIONS` requests automatically with the allowed methods.
    handle_options: bool,
    /// Handler used in place of the built-in automatic `OPTIONS` response, if set.
    global_options: Option<Box<dyn Handler>>,
    /// Fallback handler when nothing else applied; `None` yields a bare 404 response.
    not_found: Option<Box<dyn Handler>>,
    /// Handler used in place of the built-in 405 response, if set.
    method_not_allowed: Option<Box<dyn Handler>>,
}
impl<'path> Router<'path> {
    /// Registers `handler` for requests using `method` whose path matches the route `path`.
    ///
    /// A routing tree for `method` is created on first use. Returns the router so calls
    /// can be chained.
    ///
    /// # Panics
    ///
    /// Panics if `path` does not begin with `'/'`.
    pub fn handle(
        mut self,
        path: &'path str,
        method: Method,
        handler: impl Handler + 'static,
    ) -> Self {
        if !path.starts_with('/') {
            panic!("expect path beginning with '/', found: '{}'", path);
        }
        self.trees
            .entry(method)
            .or_insert_with(Node::default)
            .insert(path, Box::new(handler));
        self
    }

    /// Looks up the route registered for `method` that matches `path`.
    ///
    /// # Errors
    ///
    /// Returns `Err(matchit::Tsr::No)` when no route at all is registered for `method`;
    /// otherwise returns whatever that method's tree reports for `path`.
    pub fn lookup(
        &self,
        method: Method,
        path: impl AsRef<str>,
    ) -> Result<Match<'_, Box<dyn Handler>>, matchit::Tsr> {
        self.trees
            .get(&method)
            .map_or(Err(matchit::Tsr::No), |n| n.at(path))
    }

    /// Static file serving. Not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics.
    pub fn serve_files() {
        unimplemented!()
    }

    /// Shorthand for [`handle`](Self::handle) with `Method::GET`.
    pub fn get(self, path: &'path str, handler: impl Handler + 'static) -> Self {
        self.handle(path, Method::GET, handler)
    }

    /// Shorthand for [`handle`](Self::handle) with `Method::HEAD`.
    pub fn head(self, path: &'path str, handler: impl Handler + 'static) -> Self {
        self.handle(path, Method::HEAD, handler)
    }

    /// Shorthand for [`handle`](Self::handle) with `Method::OPTIONS`.
    pub fn options(self, path: &'path str, handler: impl Handler + 'static) -> Self {
        self.handle(path, Method::OPTIONS, handler)
    }

    /// Shorthand for [`handle`](Self::handle) with `Method::POST`.
    pub fn post(self, path: &'path str, handler: impl Handler + 'static) -> Self {
        self.handle(path, Method::POST, handler)
    }

    /// Shorthand for [`handle`](Self::handle) with `Method::PUT`.
    pub fn put(self, path: &'path str, handler: impl Handler + 'static) -> Self {
        self.handle(path, Method::PUT, handler)
    }

    /// Shorthand for [`handle`](Self::handle) with `Method::PATCH`.
    pub fn patch(self, path: &'path str, handler: impl Handler + 'static) -> Self {
        self.handle(path, Method::PATCH, handler)
    }

    /// Shorthand for [`handle`](Self::handle) with `Method::DELETE`.
    pub fn delete(self, path: &'path str, handler: impl Handler + 'static) -> Self {
        self.handle(path, Method::DELETE, handler)
    }

    /// Enables redirecting to the same path with its trailing slash added/removed when
    /// only that variant is routed.
    ///
    /// `Router::default()` already enables this, so the call is currently a no-op; there
    /// is no setter that disables it.
    pub fn redirect_trailing_slash(mut self) -> Self {
        self.redirect_trailing_slash = true;
        self
    }

    /// Enables redirecting to a cleaned, case-corrected path when one is routed.
    ///
    /// `Router::default()` already enables this, so the call is currently a no-op.
    pub fn redirect_fixed_path(mut self) -> Self {
        self.redirect_fixed_path = true;
        self
    }

    /// Enables automatic 405 Method Not Allowed responses.
    ///
    /// `Router::default()` already enables this, so the call is currently a no-op.
    pub fn handle_method_not_allowed(mut self) -> Self {
        self.handle_method_not_allowed = true;
        self
    }

    /// Enables automatic responses to `OPTIONS` requests.
    ///
    /// `Router::default()` already enables this, so the call is currently a no-op.
    pub fn handle_options(mut self) -> Self {
        self.handle_options = true;
        self
    }

    /// Sets the handler used to answer automatic `OPTIONS` requests.
    pub fn global_options(mut self, handler: impl Handler + 'static) -> Self {
        self.global_options = Some(Box::new(handler));
        self
    }

    /// Sets the fallback handler for requests that no route or automatic response handled.
    pub fn not_found(mut self, handler: impl Handler + 'static) -> Self {
        self.not_found = Some(Box::new(handler));
        self
    }

    /// Sets the handler used instead of the built-in 405 Method Not Allowed response.
    pub fn method_not_allowed(mut self, handler: impl Handler + 'static) -> Self {
        self.method_not_allowed = Some(Box::new(handler));
        self
    }

    /// Returns the methods that can serve `path`, suitable for an `Allow` header.
    ///
    /// For the server-wide path `"*"`, every method with at least one registered route
    /// is listed. Otherwise, each method's tree is matched against `path`. `OPTIONS` is
    /// never matched against its tree. It is appended whenever any other method is
    /// allowed. The result is sorted so the `Allow` header does not depend on `HashMap`
    /// iteration order. An empty vector means no method serves `path`.
    pub fn allowed(&self, path: &str) -> Vec<&str> {
        let mut allowed: Vec<&str> = self
            .trees
            .iter()
            .filter(|(method, node)| {
                **method != Method::OPTIONS && (path == "*" || node.at(path).is_ok())
            })
            .map(|(method, _)| method.as_str())
            .collect();
        if !allowed.is_empty() {
            allowed.push(Method::OPTIONS.as_str());
            allowed.sort_unstable();
        }
        allowed
    }
}
impl Default for Router<'_> {
fn default() -> Self {
Self {
trees: HashMap::new(),
redirect_trailing_slash: true,
redirect_fixed_path: true,
handle_method_not_allowed: true,
handle_options: true,
global_options: None,
method_not_allowed: None,
not_found: Some(Box::new(|_| async {
Ok(Response::builder()
.status(400)
.body(Body::from("404: Not Found"))
.unwrap())
})),
}
}
}
/// An asynchronous request handler that the router can store and invoke.
///
/// Implementors must be `Send + Sync`, and the future they return must be `Send + Sync`
/// as well.
pub trait Handler: Send + Sync {
    /// Handles `req`, returning a boxed future that resolves to the response.
    fn handle(
        &self,
        req: Request<Body>,
    ) -> Pin<Box<dyn Future<Output = hyper::Result<Response<Body>>> + Send + Sync>>;
}
/// Lets any `Fn(Request<Body>) -> impl Future<Output = hyper::Result<Response<Body>>>`
/// closure or function be used as a [`Handler`].
impl<F, R> Handler for F
where
    F: Fn(Request<Body>) -> R + Send + Sync,
    R: Future<Output = Result<Response<Body>, hyper::Error>> + Send + Sync + 'static,
{
    /// Invokes the wrapped callable and pins its future on the heap.
    fn handle(
        &self,
        req: Request<Body>,
    ) -> Pin<Box<dyn Future<Output = hyper::Result<Response<Body>>> + Send + Sync>> {
        let response_future = (self)(req);
        Box::pin(response_future)
    }
}
/// Service factory produced by `Router::into_service`; every call hands out a clone of the
/// wrapped [`RouterService`], so all connections share one router.
#[doc(hidden)]
pub struct MakeRouterService<'path>(RouterService<'path>);
impl<'path, T> Service<T> for MakeRouterService<'path> {
    type Response = RouterService<'path>;
    type Error = hyper::Error;
    type Future = future::Ready<Result<Self::Response, Self::Error>>;

    /// Always ready: making a service only clones a shared handle.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Returns a new handle to the shared router; the connection target is ignored.
    fn call(&mut self, _target: T) -> Self::Future {
        future::ready(Ok(self.0.clone()))
    }
}
/// Per-connection hyper service holding a shared, reference-counted [`Router`].
/// Cloning it only clones the `Arc`.
#[doc(hidden)]
#[derive(Clone)]
pub struct RouterService<'path>(Arc<Router<'path>>);
impl<'path> RouterService<'path> {
fn new(router: Router<'path>) -> Self {
RouterService(Arc::new(router))
}
}
impl<'path> Service<Request<Body>> for RouterService<'path> {
type Response = Response<Body>;
type Error = hyper::Error;
type Future = ResponseFut;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
self.0.serve(req)
}
}
impl<'path> Router<'path> {
    /// Converts the router into a [`MakeRouterService`], the service factory hyper's server
    /// expects. The router is moved behind an `Arc` shared by every service handed out.
    pub fn into_service(self) -> MakeRouterService<'path> {
        MakeRouterService(RouterService::new(self))
    }

    /// Routes a single request and returns the future producing its response.
    ///
    /// Resolution order:
    /// 1. If a route for the request's method matches the path, the matched params are
    ///    inserted into the request extensions and that route's handler is called.
    /// 2. Otherwise, if the method has a routing tree, the method is not `CONNECT`, and the
    ///    path is not `/`, the router may redirect. It uses 301 for `GET` and 308 for other
    ///    methods. It first tries the path with its trailing slash toggled
    ///    (`redirect_trailing_slash`), then a cleaned, case-insensitive match
    ///    (`redirect_fixed_path`). The request's query string is kept on the redirect target.
    /// 3. An `OPTIONS` request with `handle_options` set gets the `global_options` handler or
    ///    an automatic `Allow` response, provided some method serves the path. For other
    ///    requests, `handle_method_not_allowed` produces the `method_not_allowed` handler or a
    ///    built-in 405 when some method serves the path.
    /// 4. Everything else goes to the `not_found` handler, or a bare 404 if none is set.
    pub fn serve(&self, mut req: Request<Body>) -> ResponseFut {
        let root = self.trees.get(req.method());
        let path = req.uri().path();
        if let Some(root) = root {
            match root.at(path) {
                Ok(lookup) => {
                    req.extensions_mut().insert(lookup.params);
                    return ResponseFutKind::Boxed(lookup.value.handle(req)).into();
                }
                Err(tsr) => {
                    if req.method() != Method::CONNECT && path != "/" {
                        // 308 (unlike 301) obliges clients to repeat the same method and body.
                        let code = match *req.method() {
                            Method::GET => StatusCode::MOVED_PERMANENTLY,
                            _ => StatusCode::PERMANENT_REDIRECT,
                        };
                        let query = req.uri().query();
                        if tsr == matchit::Tsr::Yes && self.redirect_trailing_slash {
                            // `path != "/"` here, so stripping a trailing slash never empties it.
                            let target = match path.strip_suffix('/') {
                                Some(trimmed) => trimmed.to_owned(),
                                None => format!("{}/", path),
                            };
                            return ResponseFutKind::Redirect(redirect_target(target, query), code)
                                .into();
                        }
                        if self.redirect_fixed_path {
                            if let Some(fixed_path) =
                                root.path_ignore_case(clean(path), self.redirect_trailing_slash)
                            {
                                return ResponseFutKind::Redirect(
                                    redirect_target(fixed_path, query),
                                    code,
                                )
                                .into();
                            }
                        }
                    }
                }
            }
        }
        if req.method() == Method::OPTIONS && self.handle_options {
            let allow = self.allowed(path);
            if !allow.is_empty() {
                return match self.global_options {
                    Some(ref handler) => ResponseFutKind::Boxed(handler.handle(req)).into(),
                    None => ResponseFutKind::Options(allow.join(", ")).into(),
                };
            }
        } else if self.handle_method_not_allowed {
            let allow = self.allowed(path);
            if !allow.is_empty() {
                return match self.method_not_allowed {
                    Some(ref handler) => ResponseFutKind::Boxed(handler.handle(req)).into(),
                    None => ResponseFutKind::MethodNotAllowed(allow.join(", ")).into(),
                };
            }
        }
        match self.not_found {
            Some(ref handler) => ResponseFutKind::Boxed(handler.handle(req)).into(),
            None => ResponseFutKind::NotFound.into(),
        }
    }
}

/// Builds a redirect `Location` from a target path plus the original query string.
/// An absent or empty query adds nothing.
fn redirect_target(mut path: String, query: Option<&str>) -> String {
    if let Some(query) = query.filter(|q| !q.is_empty()) {
        path.push('?');
        path.push_str(query);
    }
    path
}
/// Future returned by `Router::serve`. It either drives a handler's boxed future or
/// yields one of the router's built-in responses.
pub struct ResponseFut {
    /// Which kind of response this future produces.
    kind: ResponseFutKind,
}
impl From<ResponseFutKind> for ResponseFut {
fn from(kind: ResponseFutKind) -> Self {
Self { kind }
}
}
/// The possible outcomes of routing a request.
enum ResponseFutKind {
    /// A user handler's future; its output is returned as-is.
    Boxed(Pin<Box<dyn Future<Output = hyper::Result<Response<Body>>> + Send + Sync>>),
    /// Redirect to the given `Location` with the given status code.
    Redirect(String, StatusCode),
    /// 405 response; the string is the `Allow` header value.
    MethodNotAllowed(String),
    /// Automatic `OPTIONS` response; the string is the `Allow` header value.
    Options(String),
    /// Bare 404 response, used when no `not_found` handler is configured.
    NotFound,
}
impl Future for ResponseFut {
    type Output = hyper::Result<Response<Body>>;

    /// Polls the wrapped handler future. For built-in responses, it builds the response
    /// immediately; those variants are ready on the first poll.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let builder = Response::builder();
        let builder = match &mut self.kind {
            ResponseFutKind::Boxed(fut) => {
                // Delegate to the handler: pending while it is pending, its output once done.
                let output = ready!(fut.as_mut().poll(cx));
                return Poll::Ready(output);
            }
            ResponseFutKind::Redirect(location, code) => builder
                .status(*code)
                .header(header::LOCATION, location.as_str()),
            ResponseFutKind::NotFound => builder.status(StatusCode::NOT_FOUND),
            ResponseFutKind::Options(allowed) => {
                builder.header(header::ALLOW, allowed.as_str())
            }
            ResponseFutKind::MethodNotAllowed(allowed) => builder
                .status(StatusCode::METHOD_NOT_ALLOWED)
                .header(header::ALLOW, allowed.as_str()),
        };
        // Built-in responses never carry a body.
        Poll::Ready(Ok(builder.body(Body::empty()).unwrap()))
    }
}