use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use hyper::Method;
use crate::error::{Error, Result};
use crate::http::{response_from_error, Request, Response};
/// A pinned, heap-allocated, type-erased future that is `Send`, yields `T`,
/// and may borrow data for the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// A shared, type-erased route handler.
///
/// Takes ownership of the [`Request`] and returns a `'static` boxed future
/// resolving to `Result<Response>`. The `Arc` makes it cheap to clone, and the
/// `Send + Sync` bounds allow it to be shared across threads.
pub type HandlerFn =
Arc<dyn Fn(Request) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
/// A shared, type-erased middleware function.
///
/// Receives the [`Request`] together with a [`Next`] value representing the
/// remainder of the chain, and returns a `'static` boxed future resolving to
/// `Result<Response>`.
pub type MiddlewareFn =
Arc<dyn Fn(Request, Next) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
/// The remainder of a request's processing pipeline: the middleware that have
/// not run yet, followed by the route handler.
pub struct Next {
// Full middleware chain for this request.
chain: Vec<MiddlewareFn>,
// Handler invoked once `index` has moved past the end of `chain`.
handler: HandlerFn,
// Position in `chain` of the next middleware to invoke.
index: usize,
}
impl Next {
    /// Passes `req` to the next stage of the pipeline.
    ///
    /// If a middleware remains at the current position, it is invoked with the
    /// request and a `Next` whose position has been advanced by one; otherwise
    /// the route handler is invoked. The returned future resolves to whatever
    /// that stage produces.
    pub fn run(mut self, req: Request) -> BoxFuture<'static, Result<Response>> {
        Box::pin(async move {
            // Clone the `Arc` out first so `self` can be moved into the call.
            let pending = self.chain.get(self.index).cloned();
            match pending {
                Some(layer) => {
                    self.index += 1;
                    layer(req, self).await
                }
                None => (self.handler)(req).await,
            }
        })
    }
}
/// A single registered route: an HTTP method, a parsed path pattern, and the
/// handler to run when both match.
struct Route {
// HTTP method this route responds to.
method: Method,
// Path pattern split into per-segment matchers.
segments: Vec<Segment>,
// Handler invoked when the route is selected.
handler: HandlerFn,
}
/// One `/`-separated component of a route pattern.
enum Segment {
// Literal text that the request segment must equal exactly.
Static(String),
// Named parameter (written `:name` in the pattern); matches any request
// segment and captures its text under `name`.
Param(String),
}
/// Maps incoming requests to handlers by method and path, and wraps handler
/// invocation in a chain of middleware.
pub struct Router {
// Registered routes, searched in registration order.
routes: Vec<Route>,
// Middleware chain, in registration order; cloned into a `Next` per dispatch.
middleware: Vec<MiddlewareFn>,
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
impl Router {
pub fn new() -> Self {
Self {
routes: Vec::new(),
middleware: Vec::new(),
}
}
pub fn middleware<F, Fut>(mut self, mw: F) -> Self
where
F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response>> + Send + 'static,
{
let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(mw(req, next)));
self.middleware.push(wrapped);
self
}
pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(Request) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response>> + Send + 'static,
{
self.route(Method::GET, path, handler)
}
pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
where
F: Fn(Request) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response>> + Send + 'static,
{
self.route(Method::POST, path, handler)
}
pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
where
F: Fn(Request) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response>> + Send + 'static,
{
let segments = parse_path(path);
let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
self.routes.push(Route {
method,
segments,
handler,
});
self
}
fn find(&self, method: &Method, path: &str) -> MatchResult {
let mut path_segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
if path_segs.len() > 1 && path_segs.last() == Some(&"") {
path_segs.pop();
}
let mut path_matched = false;
for route in &self.routes {
if !segments_match(&route.segments, &path_segs) {
continue;
}
path_matched = true;
if route.method == *method {
let params = extract_params(&route.segments, &path_segs);
return MatchResult::Ok {
handler: route.handler.clone(),
params,
};
}
}
if path_matched {
MatchResult::MethodNotAllowed
} else {
MatchResult::NotFound
}
}
pub async fn dispatch(&self, mut req: Request) -> Response {
let matched = self.find(req.method(), req.path());
let outcome = match matched {
MatchResult::Ok { handler, params } => {
req.set_params(params);
let next = Next {
chain: self.middleware.clone(),
handler,
index: 0,
};
next.run(req).await
}
MatchResult::NotFound => Err(Error::NotFound(format!("no route for {}", req.path()))),
MatchResult::MethodNotAllowed => Err(Error::MethodNotAllowed(format!(
"{} not allowed",
req.method()
))),
};
match outcome {
Ok(resp) => resp,
Err(err) => response_from_error(&err),
}
}
}
/// Outcome of looking up a request in the route table.
enum MatchResult {
// A route matched both the path and the method.
Ok {
// Handler of the matching route.
handler: HandlerFn,
// Values captured by the route's `:name` segments, keyed by name.
params: std::collections::HashMap<String, String>,
},
// No route's path pattern matched the request path.
NotFound,
// At least one route matched the path, but none with the request's method.
MethodNotAllowed,
}
/// Parses a route pattern such as `/users/:id` into its segments.
///
/// Leading slashes are ignored. A single trailing slash is dropped, mirroring
/// the normalisation `Router::find` applies to request paths; without it a
/// pattern registered as `/users/` would end in an empty segment and could
/// never match any request. The root pattern `/` is kept as one empty static
/// segment.
///
/// A segment starting with `:` becomes [`Segment::Param`] named by the text
/// after the colon; every other segment becomes [`Segment::Static`].
fn parse_path(path: &str) -> Vec<Segment> {
    let trimmed = path.trim_start_matches('/');
    // `trimmed` never starts with '/', so a '/' suffix implies at least two
    // segments; stripping it is the same as popping one trailing empty segment.
    let trimmed = trimmed.strip_suffix('/').unwrap_or(trimmed);
    trimmed
        .split('/')
        .map(|seg| match seg.strip_prefix(':') {
            Some(name) => Segment::Param(name.to_string()),
            None => Segment::Static(seg.to_string()),
        })
        .collect()
}
/// Reports whether the request `path` segments fit the `route` pattern.
///
/// The two must have the same number of segments, and every static pattern
/// segment must equal the request segment at the same position. Parameter
/// segments accept any text.
fn segments_match(route: &[Segment], path: &[&str]) -> bool {
    route.len() == path.len()
        && route.iter().zip(path).all(|(pattern, actual)| match pattern {
            Segment::Static(literal) => literal == actual,
            Segment::Param(_) => true,
        })
}
/// Collects the values captured by the parameter segments of `route`.
///
/// Pattern and request segments are paired by position; each parameter
/// segment contributes an entry mapping its name to the request segment's
/// text. If a name occurs more than once, the last occurrence wins.
fn extract_params(route: &[Segment], path: &[&str]) -> std::collections::HashMap<String, String> {
    route
        .iter()
        .zip(path)
        .filter_map(|(pattern, actual)| match pattern {
            Segment::Param(key) => Some((key.clone(), (*actual).to_string())),
            Segment::Static(_) => None,
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `method` and `path`, with every other
    /// constructor argument left empty or defaulted.
    fn request(method: Method, path: &'static str) -> Request {
        Request::new(
            method,
            path.into(),
            String::new(),
            Default::default(),
            bytes::Bytes::new(),
        )
    }

    #[tokio::test]
    async fn matches_static_path() {
        let router = Router::new().get("/hello", |_req| async { Ok(Response::text("hi")) });

        let resp = router.dispatch(request(Method::GET, "/hello")).await;

        assert_eq!(resp.status.as_u16(), 200);
    }

    #[tokio::test]
    async fn captures_param() {
        let router = Router::new().get("/users/:id", |req| async move {
            let id = req.param("id").unwrap_or("").to_string();
            Ok(Response::text(id))
        });

        let resp = router.dispatch(request(Method::GET, "/users/42")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"42");
    }

    #[tokio::test]
    async fn distinguishes_404_from_405() {
        let router = Router::new().get("/things", |_| async { Ok(Response::text("ok")) });

        // Known path, wrong method.
        let post = request(Method::POST, "/things");
        assert_eq!(router.dispatch(post).await.status.as_u16(), 405);

        // Unknown path.
        let missing = request(Method::GET, "/nope");
        assert_eq!(router.dispatch(missing).await.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn trailing_slash_is_normalised_for_static_and_param_routes() {
        let router = Router::new()
            .get("/admin/:name", |req| async move {
                let name = req.param("name").unwrap_or("").to_string();
                Ok(Response::text(name))
            })
            .get("/admin/:name/:id/edit", |req| async move {
                Ok(Response::text(format!(
                    "{}/{}",
                    req.param("name").unwrap_or(""),
                    req.param("id").unwrap_or(""),
                )))
            });

        for path in ["/admin/posts", "/admin/posts/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts", "GET {path} body");
        }

        for path in ["/admin/posts/1/edit", "/admin/posts/1/edit/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts/1", "GET {path} body");
        }
    }

    #[tokio::test]
    async fn root_path_still_matches_after_trailing_slash_normalisation() {
        let router = Router::new().get("/", |_| async { Ok(Response::text("home")) });

        let resp = router.dispatch(request(Method::GET, "/")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"home");
    }
}