use crate::{DynEndpoint, Endpoint, IntoResponse, Middleware, Next, Request, Result};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
/// HTTP router: per-method route tables, a list of middleware, and a fallback
/// endpoint used when no route matches.
pub struct Router {
/// Optional string prepended verbatim to every route path registered via `at`.
pub prefix: Option<String>,
/// Middleware list handed to `Next` together with the selected endpoint in `dispatch`.
pub middlewares: Vec<Arc<dyn Middleware>>,
/// One `route_recognizer` table per HTTP method, holding the registered endpoints.
pub routes: HashMap<hyper::Method, route_recognizer::Router<Box<DynEndpoint>>>,
/// Endpoint used by `dispatch` when the method has no table or the path is not recognized.
pub not_found_handler: Box<DynEndpoint>,
}
/// Fallback endpoint that `Router::new` installs as the initial
/// `not_found_handler`.
///
/// The incoming request is ignored; the reply is always the same fixed text.
async fn default_handler(_req: Request) -> impl IntoResponse {
    const NOT_FOUND_BODY: &str = "handle not found";
    NOT_FOUND_BODY
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
impl Router {
pub fn new() -> Self {
Router {
prefix: None,
middlewares: Vec::new(),
routes: HashMap::new(),
not_found_handler: Box::new(default_handler),
}
}
pub fn at(&mut self, method: hyper::Method, route: &str, dest: impl Endpoint) {
let path = match &self.prefix {
Some(prefix) => format!("{}{}", prefix, route),
None => route.to_string(),
};
self
.routes
.entry(method)
.or_default()
.add(&path, Box::new(dest));
}
pub fn get(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::GET, route, dest);
}
pub fn post(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::POST, route, dest);
}
pub fn delete(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::DELETE, route, dest);
}
pub fn patch(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::PATCH, route, dest);
}
pub fn put(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::PUT, route, dest);
}
pub fn options(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::OPTIONS, route, dest);
}
pub fn head(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::HEAD, route, dest);
}
pub fn trace(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::TRACE, route, dest);
}
pub fn connect(&mut self, route: &str, dest: impl Endpoint) {
self.at(hyper::Method::CONNECT, route, dest);
}
pub fn with(&mut self, middleware: impl Middleware) {
self.middlewares.push(Arc::new(middleware));
}
pub fn merge(&mut self, target: Router) {
self.routes.extend(target.routes);
}
pub async fn dispatch(&self, mut req: Request, remote_addr: Arc<SocketAddr>) -> Result {
let method = req.method();
let path = req.uri().path();
let mut params = route_recognizer::Params::new();
let endpoint = match self.routes.get(method) {
Some(route) => match route.recognize(path) {
Ok(m) => {
m.params().clone_into(&mut params);
&***m.handler()
}
Err(_e) => &*self.not_found_handler,
},
None => &*self.not_found_handler,
};
req.params = params;
req.remote_addr = Some(remote_addr);
let next = Next {
endpoint,
middlewares: &self.middlewares,
};
next.run(req).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built router has no prefix, middleware or routes.
    #[test]
    fn test_router_new() {
        let fresh = Router::new();
        assert!(fresh.prefix.is_none());
        assert!(fresh.middlewares.is_empty());
        assert!(fresh.routes.is_empty());
    }

    /// Registering a GET route creates the GET table.
    #[test]
    fn test_router_get() {
        let mut app = Router::new();
        app.get("/", |_| async { "home" });
        let has_get = app.routes.contains_key(&hyper::Method::GET);
        assert!(has_get);
    }

    /// Registering a POST route creates the POST table.
    #[test]
    fn test_router_post() {
        let mut app = Router::new();
        app.post("/api/data", |_| async { "created" });
        let has_post = app.routes.contains_key(&hyper::Method::POST);
        assert!(has_post);
    }

    /// Each of the nine method helpers populates its own table.
    #[test]
    fn test_router_all_methods() {
        let mut app = Router::new();
        app.get("/g", |_| async { "get" });
        app.post("/p", |_| async { "post" });
        app.delete("/d", |_| async { "delete" });
        app.patch("/pa", |_| async { "patch" });
        app.put("/pu", |_| async { "put" });
        app.options("/o", |_| async { "options" });
        app.head("/h", |_| async { "head" });
        app.trace("/t", |_| async { "trace" });
        app.connect("/c", |_| async { "connect" });
        let table_count = app.routes.len();
        assert_eq!(table_count, 9);
    }

    /// `with` stores the middleware on the router.
    #[test]
    fn test_router_with_middleware() {
        use crate::{Middleware, Next, Request, Result};

        // Middleware that forwards the request untouched.
        struct PassThrough;

        #[async_trait::async_trait]
        impl Middleware for PassThrough {
            async fn handle(&self, req: Request, next: Next<'_>) -> Result {
                next.run(req).await
            }
        }

        let mut app = Router::new();
        app.get("/", |_| async { "test" });
        app.with(PassThrough);
        assert_eq!(app.middlewares.len(), 1);
    }

    /// After a merge the receiving router still has a GET table.
    #[test]
    fn test_router_merge() {
        let mut base = Router::new();
        base.get("/a", |_| async { "a" });

        let mut extra = Router::new();
        extra.get("/b", |_| async { "b" });

        base.merge(extra);
        assert!(base.routes.contains_key(&hyper::Method::GET));
    }
}