use crate::{
endpoint::{BoxedEnpoint, Endpoint, Handler, HandlerOutput},
http::{Method, Response, StatusCode},
request::Request,
server::state::State,
};
use routefinder::Captures;
use std::collections::HashMap;
pub mod route;
pub use route::Route;
/// Dispatch table that maps a request path (and HTTP method) to a handler.
///
/// Routes registered for one specific method are kept in `method_map`, keyed by
/// that method; routes that accept every method are kept in `all_method_router`.
#[derive(Debug)]
pub(crate) struct Router<S>
where
    S: State,
{
    /// One `routefinder` router per HTTP method; consulted first by `route`.
    method_map: HashMap<http_types::Method, routefinder::Router<BoxedEnpoint<S>>>,
    /// Router holding handlers registered for all HTTP methods.
    all_method_router: routefinder::Router<BoxedEnpoint<S>>,
}
impl<S: State> Router<S> {
pub(crate) fn new() -> Self {
Self {
method_map: HashMap::default(),
all_method_router: routefinder::Router::new(),
}
}
pub(crate) fn add(
&mut self,
path: &str,
method: http_types::Method,
ep: impl Handler<Request<S>, HandlerOutput>,
) {
let ep = Box::new(ep);
self.method_map
.entry(method)
.or_default()
.add(path, ep)
.unwrap();
}
pub(crate) fn add_all(&mut self, path: &str, ep: impl Handler<Request<S>, HandlerOutput>) {
let ep = Box::new(ep);
self.all_method_router.add(path, ep).unwrap();
}
pub(crate) fn route(&self, path: &str, method: Method) -> Selection<'_, S> {
if let Some(m) = self
.method_map
.get(&method)
.and_then(|r| r.best_match(path))
{
Selection {
handler: m.handler(),
params: m.captures().into_owned(),
}
} else if let Some(m) = self.all_method_router.best_match(path) {
Selection {
handler: m.handler(),
params: m.captures().into_owned(),
}
} else if method == http_types::Method::Head {
self.route(path, http_types::Method::Get)
} else if self
.method_map
.iter()
.filter(|(k, _)| **k != method)
.any(|(_, r)| r.best_match(path).is_some())
{
Selection {
handler: &method_not_allowed,
params: Captures::default(),
}
} else {
Selection {
handler: ¬_found_endpoint,
params: Captures::default(),
}
}
}
}
/// Result of `Router::route`: the handler chosen for a request together with
/// the path parameters captured while matching.
pub(crate) struct Selection<'a, S>
where
    S: State,
{
    /// Handler that should serve the request (a registered route or a 404/405 fallback).
    pub(crate) handler: &'a Endpoint<S>,
    /// Parameters captured from the request path; empty for the fallback handlers.
    pub(crate) params: Captures<'a, 'static>,
}
/// Fallback handler chosen by `Router::route` when no route matches the path
/// under any method. Ignores the request and always answers `404 Not Found`.
async fn not_found_endpoint<S>(_: Request<S>) -> HandlerOutput
where
    S: State,
{
    let response = Response::new(StatusCode::NotFound);
    Ok(response)
}
/// Fallback handler chosen by `Router::route` when the path matches a route
/// registered under some other HTTP method. Ignores the request and always
/// answers `405 Method Not Allowed`.
///
/// NOTE(review): RFC 7231 §6.5.5 says a 405 response must carry an `Allow`
/// header listing the supported methods; this handler does not set one.
async fn method_not_allowed<S>(_: Request<S>) -> HandlerOutput
where
    S: State,
{
    let response = Response::new(StatusCode::MethodNotAllowed);
    Ok(response)
}