#[cfg(feature = "websocket")]
use crate::ws::WebSocket;
use crate::{
handler::Handler,
middleware::{Middleware, Next},
regex::Rule,
request::Request,
response::Response,
Result,
};
use async_trait::async_trait;
use path_tree::PathTree;
use std::{collections::HashMap, sync::Arc};
/// Generates one routing shortcut per HTTP method.
///
/// Each `pub fn name = "METHOD";` entry expands to a method that takes a path and
/// any `H: Handler`, wraps the handler in an `Arc`, and forwards to `Router::route`
/// with the given method string. Attributes (including doc comments) written before
/// an entry are copied onto the generated method.
macro_rules! method {
    ($($(#[$attr:meta])* $visibility:vis fn $fn_name:ident = $verb:expr;)+) => {
        $(
            $(#[$attr])*
            $visibility fn $fn_name<H>(&mut self, path: &str, handler: H) -> &mut Self
            where
                H: Handler,
            {
                let shared: Arc<dyn Handler> = Arc::new(handler);
                self.route(path, String::from($verb), shared)
            }
        )+
    };
}
/// Route table: registered handlers, path-scoped middlewares and a lookup tree
/// keyed by HTTP method and path.
#[derive(Default)]
pub struct Router {
    /// Prefix prepended to routes registered on this router. `Router::new`
    /// normalizes it to end with exactly one `/`; the derived `Default` leaves it empty.
    pub(crate) prefix: String,
    /// Middlewares, each paired with the `Rule` built from the include/exclude
    /// path lists given to `hook`.
    pub(crate) middlewares: Vec<(Rule, Arc<dyn Middleware>)>,
    /// Registered routes as `(path, method, handler)`; `path` is stored without a
    /// leading `/`.
    pub(crate) routes: Vec<(String, String, Arc<dyn Handler>)>,
    /// Lookup tree keyed by `/{METHOD}/{path}`, searched by `handle`.
    pub(crate) tree: PathTree<Arc<dyn Handler>>,
    /// Paths collected by `push`/`swagger` and handed to the Swagger UI config.
    #[cfg(feature = "swagger_ui")]
    pub(crate) openapis: Vec<String>,
    /// When `true`, `route` stores paths as given instead of prepending `prefix`.
    pub(crate) is_parent: bool,
}
impl std::fmt::Debug for Router {
    /// Formats the router for diagnostics.
    ///
    /// Handlers and middlewares are trait objects, so only the prefix, the
    /// `(path, method)` pair of each route, the path tree's `node` field and the
    /// `is_parent` flag are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let route_summary: Vec<(String, String)> = self
            .routes
            .iter()
            .map(|(path, method, _handler)| (path.clone(), method.clone()))
            .collect();

        let mut builder = f.debug_struct("Router");
        builder.field("prefix", &self.prefix);
        builder.field("routes", &route_summary);
        builder.field("tree", &self.tree.node);
        builder.field("is_parent", &self.is_parent);
        builder.finish()
    }
}
impl Router {
    /// Creates an empty router whose routes are registered under `prefix`.
    ///
    /// Trailing `/` characters are collapsed so the stored prefix ends with exactly
    /// one `/` (an empty prefix becomes `"/"`).
    #[inline]
    pub fn new(prefix: &str) -> Router {
        let path = prefix.trim_end_matches('/').to_owned() + "/";
        Self {
            prefix: path,
            middlewares: Vec::new(),
            routes: Vec::new(),
            tree: PathTree::new(),
            #[cfg(feature = "swagger_ui")]
            openapis: Vec::new(),
            is_parent: false,
        }
    }

    /// Registers `middleware` for the paths selected by `include_path` and the
    /// optional `exclude_path` list.
    ///
    /// Both lists are handed to `Rule::new`, which defines how they are matched.
    /// Returns `self` for chaining.
    #[inline]
    pub fn hook<'a, M, E>(
        &mut self,
        middleware: M,
        include_path: Vec<&str>,
        exclude_path: E,
    ) -> &mut Self
    where
        M: Middleware,
        E: Into<Option<Vec<&'a str>>>,
    {
        let rule = Rule::new(include_path, exclude_path.into());
        self.middlewares.push((rule, Arc::new(middleware)));
        self
    }

    /// Registers `handler` for requests with the given `method` and `path`.
    ///
    /// Leading `/` characters are stripped from `path`. Unless this router is
    /// marked as a parent (see `push` and `swagger`), the router's prefix (without
    /// its leading `/`) is prepended. The handler is inserted into the lookup tree
    /// under `/{method}/{path}` - the key shape `handle` builds from requests - and
    /// recorded in `routes`. Returns `self` for chaining.
    #[inline]
    pub fn route(&mut self, path: &str, method: String, handler: Arc<dyn Handler>) -> &mut Self {
        let path = path.trim_start_matches('/');
        let new_path = if self.is_parent {
            path.to_owned()
        } else {
            self.prefix.trim_start_matches('/').to_owned() + path
        };
        let tree_path = "/".to_owned() + method.as_str() + "/" + new_path.as_str();
        // The value returned by `PathTree::insert` is not used here.
        let _ = self.tree.insert(tree_path.as_str(), Arc::clone(&handler));
        self.routes.push((new_path, method, handler));
        self
    }

    /// Mounts `child`'s routes under this router's prefix.
    ///
    /// Does nothing when `child` has no routes. Otherwise:
    /// - the child's middlewares are moved into this router;
    /// - with `swagger_ui`, the child's OpenAPI paths are re-rooted under this
    ///   prefix (or, if it has none, this prefix and `prefix + child.prefix` are recorded);
    /// - every child route is re-registered as `self.prefix + child_path`.
    ///
    /// This router is then marked as a parent, so any later `route` call on it
    /// (including `get`, `post`, ...) no longer prepends its prefix.
    #[inline]
    pub fn push(&mut self, mut child: Router) -> &mut Self {
        if child.routes.is_empty() {
            return self;
        }
        // `append` on an empty vector is a no-op, so no length check is needed.
        self.middlewares.append(&mut child.middlewares);
        #[cfg(feature = "swagger_ui")]
        if !child.openapis.is_empty() {
            for path in child.openapis {
                self.openapis
                    .push(self.prefix.clone() + path.trim_start_matches('/'));
            }
        } else {
            self.openapis.push(self.prefix.clone());
            self.openapis
                .push(self.prefix.clone() + child.prefix.trim_start_matches('/'));
        }
        // Child paths already carry the child's prefix; `self.prefix` is added
        // explicitly below, so `route` must not prepend it a second time.
        self.is_parent = true;
        for (path, method, handler) in child.routes {
            let full_path = self.prefix.clone() + path.as_str();
            self.route(full_path.as_str(), method, handler);
        }
        self
    }

    method![
        /// Registers `handler` for `GET` requests on `path`.
        #[inline]
        pub fn get = "GET";
        /// Registers `handler` for `POST` requests on `path`.
        #[inline]
        pub fn post = "POST";
        /// Registers `handler` for `OPTIONS` requests on `path`.
        #[inline]
        pub fn options = "OPTIONS";
        /// Registers `handler` for `PUT` requests on `path`.
        #[inline]
        pub fn put = "PUT";
        /// Registers `handler` for `DELETE` requests on `path`.
        #[inline]
        pub fn delete = "DELETE";
        /// Registers `handler` for `HEAD` requests on `path`.
        #[inline]
        pub fn head = "HEAD";
        /// Registers `handler` for `TRACE` requests on `path`.
        #[inline]
        pub fn trace = "TRACE";
        /// Registers `handler` for `CONNECT` requests on `path`.
        #[inline]
        pub fn connect = "CONNECT";
        /// Registers `handler` for `PATCH` requests on `path`.
        #[inline]
        pub fn patch = "PATCH";
    ];

    /// Registers a `crate::static_dir::StaticDir` handler built from `dir_path`
    /// for `GET` requests on `uri_path`.
    #[cfg(feature = "static_file")]
    #[inline]
    pub fn static_dir(
        &mut self,
        uri_path: &str,
        dir_path: impl Into<std::path::PathBuf>,
    ) -> &mut Self {
        self.route(
            uri_path,
            "GET".to_owned(),
            Arc::new(crate::static_dir::StaticDir::new(dir_path)),
        )
    }

    /// Registers a `crate::fs::NamedFileBuilder` built from `dir_path` for `GET`
    /// requests on `uri_path`.
    ///
    /// NOTE(review): despite its name, `dir_path` presumably names a single file
    /// here - confirm against `NamedFileBuilder::new`.
    #[cfg(feature = "static_file")]
    #[inline]
    pub fn static_file(
        &mut self,
        uri_path: &str,
        dir_path: impl Into<std::path::PathBuf>,
    ) -> &mut Self {
        self.route(
            uri_path,
            "GET".to_owned(),
            Arc::new(crate::fs::NamedFileBuilder::new(dir_path)),
        )
    }

    /// Registers a WebSocket endpoint for `GET` requests on `path`.
    ///
    /// `handler` is wrapped by `crate::ws::new_ws` and receives the originating
    /// `Request` together with the `WebSocket`.
    #[cfg(feature = "websocket")]
    #[inline]
    pub fn ws<H, F>(&mut self, path: &str, handler: H) -> &mut Self
    where
        H: Send + Sync + 'static + Fn(Request, WebSocket) -> F,
        F: std::future::Future<Output = Result<()>> + Send + 'static,
    {
        self.route(path, "GET".to_owned(), Arc::new(crate::ws::new_ws(handler)))
    }

    /// Serves `openapi` for `GET` requests on this router's own path (empty sub-path).
    #[cfg(feature = "swagger_ui")]
    #[inline]
    pub fn openapi(&mut self, openapi: impl Handler) -> &mut Self {
        self.route("", "GET".to_owned(), Arc::new(openapi))
    }

    /// Mounts Swagger UI at `swagger_ui`, serving its assets on `GET {swagger_ui}/*`.
    ///
    /// The UI is configured with the paths collected in `openapis`. If none were
    /// collected, this router's prefix is used instead and the router is marked as
    /// a parent, so the UI route itself is registered without the prefix.
    #[cfg(feature = "swagger_ui")]
    #[inline]
    pub fn swagger(&mut self, swagger_ui: &str) -> &mut Self {
        let collected = self.openapis.len();
        if collected == 0 {
            self.openapis.push(self.prefix.clone());
            self.is_parent = true;
        }
        if self.prefix != "/" && collected > 0 {
            // NOTE(review): presumably drops the bare-prefix entry that `push`
            // records first for children without OpenAPI paths - confirm.
            self.openapis.remove(0);
        }
        let prefix = swagger_ui.trim_matches('/').to_owned();
        let path = swagger_ui.trim_end_matches('/').to_owned() + "/*";
        let config = Arc::new(utoipa_swagger_ui::Config::new(self.openapis.clone()));
        self.route(
            &path,
            "GET".to_owned(),
            Arc::new(swagger::ServeSwagger { prefix, config }),
        )
    }
}
#[async_trait]
impl Handler for Router {
    /// Dispatches `req` to the handler registered for its method and path.
    ///
    /// The lookup key is `/{METHOD}{uri_path}`, the same shape as the
    /// `/{METHOD}/{path}` keys inserted by `Router::route` (assuming the request
    /// path begins with `/`).
    ///
    /// On a match, the captured path parameters replace the request's params and
    /// the handler runs through this router's middlewares via `Next`. Unmatched
    /// requests get a 404 HTML page.
    ///
    /// # Errors
    /// Propagates any error returned by the middleware/handler chain.
    async fn handle(&self, mut req: Request) -> Result<Response> {
        // Body of the 404 response sent when no route matches.
        const NOT_FOUND_PAGE: &str = r#"<!DOCTYPE html>
<html>
<head>
<title>404 Not Found</title>
</head>
<body>
<h1>404 Not Found</h1>
<h3>The requested resource could not be found.</h3>
</body>
</html>"#;

        let uri_path = req.uri().path().to_owned();
        let path = "/".to_owned() + req.method().as_str() + uri_path.as_str();
        let responded = match self.tree.find(path.as_str()) {
            Some((handler, route)) => {
                *req.params_mut() = route
                    .params()
                    .iter()
                    .map(|p| (p.0.to_owned(), p.1.to_owned()))
                    .collect::<HashMap<String, String>>();
                let next = Next {
                    path: uri_path,
                    handler: handler.clone(),
                    middleware: self.middlewares.clone(),
                };
                next.next(req).await?
            }
            None => Response::default().status(404).html(NOT_FOUND_PAGE),
        };
        Ok(responded)
    }
}
#[cfg(feature = "swagger_ui")]
mod swagger {
    use super::{async_trait, Arc, Handler, Request, Response, Result};
    use crate::error::Error;
    use serde_json::json;
    use utoipa::openapi::OpenApi;
    use utoipa_swagger_ui::Config;

    /// Handler serving the Swagger UI assets mounted by `Router::swagger`.
    pub struct ServeSwagger {
        /// Mount path with leading and trailing `/` removed (e.g. `docs` for `/docs/`).
        pub prefix: String,
        /// Swagger UI configuration listing the OpenAPI document URLs.
        pub config: Arc<Config<'static>>,
    }

    #[async_trait]
    impl Handler for ServeSwagger {
        /// Serves the Swagger UI file named by the part of the request path after
        /// `/{prefix}/`. Paths that do not start with `/{prefix}/` are served as `/`.
        ///
        /// # Errors
        /// - `Error::Status(404)` when `utoipa_swagger_ui::serve` has no file for the path.
        /// - `Error::Response(500, ..)` carrying the message when `serve` itself fails.
        async fn handle(&self, req: Request) -> Result<Response> {
            let path = req.uri().path().to_string();
            let prefix = "/".to_owned() + &self.prefix + "/";
            let tail = path.strip_prefix(prefix.as_str()).unwrap_or("/");
            match utoipa_swagger_ui::serve(tail, self.config.clone()) {
                Ok(Some(file)) => Ok(Response::default()
                    .content_type(&file.content_type)
                    .write(file.bytes.to_vec())),
                // `serve` returns `Ok(None)` when no bundled asset matches the path:
                // that is a missing resource, not a server failure.
                Ok(None) => Err(Error::Status(404)),
                Err(error) => Err(Error::Response(500, json!(error.to_string()))),
            }
        }
    }

    /// Serves the OpenAPI document itself as JSON.
    #[async_trait]
    impl Handler for OpenApi {
        async fn handle(&self, _: Request) -> Result<Response> {
            Ok(Response::default().json(&self))
        }
    }
}