use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use hyper::Method;
use crate::error::Error;
use crate::http::{Request, Response};
use crate::middleware::{MiddlewareFn, Next};
/// A pinned, heap-allocated, `Send` future that resolves to `T`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// A shared, type-erased route handler: it receives the request together with the
/// captured path parameters and returns a boxed future resolving to a response or an error.
type HandlerFn = Arc<dyn Fn(Request, Params) -> BoxFuture<Result<Response, Error>> + Send + Sync>;
/// Named path parameters captured when a request path is matched against a route pattern.
///
/// Values are stored as `(name, value)` pairs in the order they were captured.
/// The type derives the common traits so it can be logged with `{:?}`, cloned,
/// compared in assertions, and created empty through `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Params {
    /// Captured `(name, value)` pairs, in capture order.
    pairs: Vec<(String, String)>,
}
impl Params {
fn empty() -> Self {
Self { pairs: Vec::new() }
}
pub fn get(&self, name: &str) -> Option<&str> {
self.pairs
.iter()
.find_map(|(k, v)| (k == name).then_some(v.as_str()))
}
pub fn len(&self) -> usize {
self.pairs.len()
}
pub fn is_empty(&self) -> bool {
self.pairs.is_empty()
}
}
/// One `/`-separated component of a route pattern.
enum Segment {
    /// Text that must equal the corresponding request path component exactly.
    Literal(String),
    /// A named placeholder (written `:name` in the pattern); the payload is the name
    /// without the leading colon.
    Param(String),
}
/// A single registered route: an HTTP method, a parsed path pattern, and its handler.
struct Route {
    /// HTTP method this route responds to.
    method: Method,
    /// Parsed path pattern, one entry per non-empty path component.
    segments: Vec<Segment>,
    /// Type-erased handler invoked when both method and pattern match.
    handler: HandlerFn,
}
/// Request router holding the registered routes and the middleware stack.
pub struct Router {
    /// Routes in registration order; dispatch scans them front to back.
    routes: Vec<Route>,
    /// Middlewares in registration order.
    middlewares: Vec<MiddlewareFn>,
}
impl Router {
pub fn new() -> Self {
Self {
routes: Vec::new(),
middlewares: Vec::new(),
}
}
pub fn wrap<F, Fut>(mut self, middleware: F) -> Self
where
F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response, Error>> + Send + 'static,
{
self.middlewares
.push(Arc::new(move |req, next| Box::pin(middleware(req, next))));
self
}
pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response, Error>> + Send + 'static,
{
self.route(Method::GET, path, handler)
}
pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response, Error>> + Send + 'static,
{
self.route(Method::POST, path, handler)
}
fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
where
F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response, Error>> + Send + 'static,
{
let handler: HandlerFn = Arc::new(move |req, params| Box::pin(handler(req, params)));
self.routes.push(Route {
method,
segments: parse_path(path),
handler,
});
self
}
pub fn has_route(&self, method: &Method, path: &str) -> bool {
let target = parse_path(path);
self.routes
.iter()
.any(|r| r.method == *method && segments_equal(&r.segments, &target))
}
pub async fn dispatch(&self, req: Request) -> Response {
let path = req.uri().path().to_owned();
let actual: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
let method = req.method().clone();
let mut found: Option<(HandlerFn, Params)> = None;
let mut path_matched = false;
for route in &self.routes {
if let Some(params) = match_segments(&route.segments, &actual) {
path_matched = true;
if route.method == method {
found = Some((route.handler.clone(), params));
break;
}
}
}
let (handler, params) = found.unwrap_or_else(|| {
let method_not_allowed = path_matched;
let fallback: HandlerFn = Arc::new(move |_req, _params| {
let err = if method_not_allowed {
Error::MethodNotAllowed
} else {
Error::NotFound
};
Box::pin(async move { Err(err) })
});
(fallback, Params::empty())
});
let chain = build_chain(&self.middlewares, handler, params);
match chain(req).await {
Ok(resp) => resp,
Err(err) => err.into_response(),
}
}
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
fn build_chain(
middlewares: &[MiddlewareFn],
handler: HandlerFn,
params: Params,
) -> Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> {
let mut chain: Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> =
Box::new(move |req| handler(req, params));
for mw in middlewares.iter().rev() {
let mw = mw.clone();
let inner = chain;
chain = Box::new(move |req| {
let next = Next::new(inner);
mw(req, next)
});
}
chain
}
/// Splits a route pattern on `/` into its segments.
///
/// Empty components (from leading, trailing, or repeated slashes) are skipped,
/// so `"/"` yields an empty list. A component starting with `:` becomes a
/// `Segment::Param` named by the rest of the component; anything else becomes a
/// `Segment::Literal`.
fn parse_path(path: &str) -> Vec<Segment> {
    let mut segments = Vec::new();
    for piece in path.split('/') {
        if piece.is_empty() {
            continue;
        }
        let segment = if let Some(name) = piece.strip_prefix(':') {
            Segment::Param(name.to_owned())
        } else {
            Segment::Literal(piece.to_owned())
        };
        segments.push(segment);
    }
    segments
}
/// Returns `true` when two parsed patterns are structurally identical.
///
/// Both lists must have the same length, and each pair of segments must be the
/// same variant with the same text: literals compare by value and parameters
/// compare by name.
fn segments_equal(a: &[Segment], b: &[Segment]) -> bool {
    a.len() == b.len()
        && a.iter().zip(b).all(|pair| match pair {
            (Segment::Literal(left), Segment::Literal(right))
            | (Segment::Param(left), Segment::Param(right)) => left == right,
            (Segment::Literal(_), Segment::Param(_))
            | (Segment::Param(_), Segment::Literal(_)) => false,
        })
}
fn match_segments(patterns: &[Segment], actual: &[&str]) -> Option<Params> {
if patterns.len() != actual.len() {
return None;
}
let mut params = Params::empty();
for (pat, seg) in patterns.iter().zip(actual.iter()) {
match pat {
Segment::Literal(lit) => {
if lit != seg {
return None;
}
}
Segment::Param(name) => {
params.pairs.push((name.clone(), (*seg).to_owned()));
}
}
}
Some(params)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a route pattern into its segments.
    fn pattern(path: &str) -> Vec<Segment> {
        parse_path(path)
    }

    /// Splits a concrete request path into its non-empty components.
    fn components(path: &str) -> Vec<&str> {
        path.split('/').filter(|piece| !piece.is_empty()).collect()
    }

    #[test]
    fn root_path_is_empty_segment_list() {
        let root = parse_path("/");
        assert!(root.is_empty());
    }

    #[test]
    fn literal_match() {
        let users = pattern("/users");
        assert!(match_segments(&users, &components("/users")).is_some());
        assert!(match_segments(&users, &components("/posts")).is_none());
    }

    #[test]
    fn param_captures_value() {
        let captured = match_segments(&pattern("/users/:id"), &components("/users/42")).unwrap();
        assert_eq!(captured.get("id"), Some("42"));
    }

    #[test]
    fn length_mismatch_does_not_match() {
        // Pattern longer than the path.
        assert!(match_segments(&pattern("/users/:id"), &components("/users")).is_none());
        // Path longer than the pattern.
        assert!(match_segments(&pattern("/users"), &components("/users/42")).is_none());
    }

    #[test]
    fn multiple_params_captured_by_name() {
        let captured =
            match_segments(&pattern("/a/:x/b/:y"), &components("/a/first/b/second")).unwrap();
        assert_eq!(captured.get("x"), Some("first"));
        assert_eq!(captured.get("y"), Some("second"));
    }

    #[test]
    fn has_route_detects_registered_path() {
        let router = Router::new()
            .get("/", |_req, _p| async {
                Ok::<Response, Error>(crate::http::text("home"))
            })
            .get("/users/:id", |_req, _p| async {
                Ok::<Response, Error>(crate::http::text("user"))
            });

        // Registered method/pattern pairs are found.
        assert!(router.has_route(&Method::GET, "/"));
        assert!(router.has_route(&Method::GET, "/users/:id"));

        // Unknown paths, other methods, and differently named parameters are not.
        assert!(!router.has_route(&Method::GET, "/missing"));
        assert!(!router.has_route(&Method::POST, "/"));
        assert!(!router.has_route(&Method::GET, "/users/:other"));
    }
}