use std::collections::HashMap;
use crate::http::request::Request;
use super::route::Route;
/// Characters whose presence anywhere in a route path marks the route as dynamic:
/// `:` introduces a named parameter segment and `*` a wildcard segment.
const DYNAMIC_CHARS: [char; 2] = [':', '*'];
/// Collects routes before producing a [`Router`].
///
/// Routes may be added in any order; [`RouterBuilder::build`] reorders them so that
/// static routes are tried before dynamic ones.
#[derive(Debug)]
pub struct RouterBuilder {
    /// Routes added so far, kept in insertion order until `build` sorts them.
    pub routes: Vec<Route>,
}
impl RouterBuilder {
pub fn new() -> Self {
Self { routes: vec![] }
}
pub fn add_route(mut self, route: Route) -> Self {
self.routes.push(route);
self
}
pub fn add_routes(mut self, routes: &[Route]) -> Self {
self.routes.extend_from_slice(routes);
self
}
pub fn build(mut self) -> Router {
self.sort_routes();
Router {
routes: self.routes,
}
}
fn sort_routes(&mut self) {
self.routes.sort_by(|a, b| {
(!b.path.contains(DYNAMIC_CHARS)).cmp(&!a.path.contains(DYNAMIC_CHARS))
});
}
}
impl Default for RouterBuilder {
fn default() -> Self {
Self::new()
}
}
/// The result of a successful [`Router::route`] lookup.
///
/// Derives `Debug`, `Clone` and `PartialEq` so matches can be logged, copied and compared
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct MatchedRoute {
    /// The route that matched the request.
    pub route: Route,
    /// Values captured by `:name` segments, keyed by `name` (without the leading `:`).
    /// `Router::route` sets this to `Some`.
    pub params: Option<HashMap<String, String>>,
}
/// Matches incoming requests against an ordered list of routes.
///
/// Candidates are tried in order, and the first one that matches wins.
#[derive(Debug)]
pub struct Router {
    /// Candidate routes, in the order they are tried.
    routes: Vec<Route>,
}
impl Router {
    /// Creates a router that tries `routes` in exactly the given order.
    ///
    /// Unlike [`RouterBuilder::build`], this does not reorder the routes, so a dynamic
    /// route placed before a static one can shadow it.
    pub fn new(routes: Vec<Route>) -> Self {
        Self { routes }
    }

    /// Returns an empty [`RouterBuilder`].
    pub fn builder() -> RouterBuilder {
        RouterBuilder::new()
    }

    /// Finds the first route matching `request.path`, consuming the router.
    ///
    /// Paths are compared segment by segment. Empty segments, such as those produced by
    /// leading, trailing or repeated `/`, are ignored.
    ///
    /// - A `:name` segment matches any single request segment and records it under `name`.
    /// - Any other segment must match exactly.
    /// - The route and the request must have the same number of segments.
    ///
    /// Returns `None` if no route matches. On a match, `params` is `Some` and holds only the
    /// captures of the matched route; it is empty for a fully static route.
    ///
    /// # Panics
    ///
    /// Wildcard (`*`) segments are not implemented yet. The method panics via `todo!` when a
    /// `*` segment is reached while comparing segments, or when the request ends just before
    /// a route's `*` segment.
    pub fn route(self, request: Request) -> Option<MatchedRoute> {
        // The request path is the same for every candidate, so split it only once.
        let request_segments: Vec<&str> =
            request.path.split('/').filter(|s| !s.is_empty()).collect();

        'candidates: for route in self.routes {
            let route_segments: Vec<&str> =
                route.path.split('/').filter(|s| !s.is_empty()).collect();
            // Captures are collected per candidate. A route rejected part-way through, or on
            // length, must not leave its parameters behind in a later route's match.
            let mut params = HashMap::new();

            for (route_seg, request_seg) in route_segments.iter().zip(&request_segments) {
                if *route_seg == "*" {
                    todo!("handle wildcard in router path");
                } else if let Some(name) = route_seg.strip_prefix(':') {
                    params.insert(name.to_string(), (*request_seg).to_string());
                } else if route_seg != request_seg {
                    continue 'candidates;
                }
            }

            if route_segments.len() > request_segments.len() {
                // Indexing is in bounds: the route has more segments than the request.
                if route_segments[request_segments.len()] == "*" {
                    todo!("handle wildcard at end of route");
                }
                continue;
            }
            if request_segments.len() > route_segments.len() {
                continue;
            }

            return Some(MatchedRoute {
                route,
                params: Some(params),
            });
        }

        None
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Builds the router shared by the matching tests. `route` consumes the router, so each
    /// test needs its own instance.
    fn posts_router() -> Router {
        Router::builder()
            .add_routes(&[
                Route::get("/"),
                Route::get("/posts"),
                Route::get("/posts/:id"),
                Route::get("/posts/not-found"),
                Route::get("/posts/:id/comments"),
            ])
            .build()
    }

    /// Builds a header-less request for `path`.
    fn request(path: &str) -> Request {
        Request {
            path: path.to_string(),
            headers: HashMap::new(),
        }
    }

    #[test]
    fn it_correctly_sorts_static_routes_before_dynamic_routes() {
        let router = Router::builder()
            .add_routes(&[Route::get("/posts/:id"), Route::get("/posts/1")])
            .build();
        let expected = vec![Route::get("/posts/1"), Route::get("/posts/:id")];
        assert_eq!(router.routes, expected);
    }

    #[test]
    fn it_correctly_matches_a_request_to_a_dynamic_route() {
        let matched = posts_router()
            .route(request("/posts/3"))
            .expect("`/posts/3` should match `/posts/:id`");
        assert_eq!(matched.route.path, "/posts/:id".to_string());
        let expected = HashMap::from([("id".to_string(), "3".to_string())]);
        assert_eq!(matched.params, Some(expected));
    }

    #[test]
    fn it_prefers_a_static_route_over_a_dynamic_route() {
        let matched = posts_router()
            .route(request("/posts/not-found"))
            .expect("`/posts/not-found` should match its static route");
        assert_eq!(matched.route.path, "/posts/not-found".to_string());
        assert_eq!(matched.params, Some(HashMap::new()));
    }

    #[test]
    fn it_returns_none_when_no_route_matches() {
        assert!(posts_router().route(request("/users/1")).is_none());
    }
}