use async_trait::async_trait;
pub use route_service::RouteService;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use crate::handler::Handler;
#[cfg(feature = "static")]
use crate::handler::{StaticOptions, static_handler_with_options};
use crate::middleware::MiddleWareHandler;
#[cfg(feature = "static")]
use crate::prelude::HandlerGetter;
use crate::{Method, Request, Response};
pub(crate) mod handler_append;
mod handler_match;
mod route_service;
mod route_tree;
pub use route_tree::RouteTree;
#[cfg(all(feature = "worker", target_arch = "wasm32"))]
pub mod worker;
/// Conversion of a value into a [`Route`].
///
/// Anything implementing this trait can be handed to [`Route::append`],
/// [`Route::push`] or [`Route::extend`].
pub trait RouterAdapt {
    /// Consumes `self` and yields the [`Route`] that should be mounted.
    fn into_router(self) -> Route;
}
/// One node of the routing tree.
///
/// A node matches a single path segment, may register one handler per HTTP
/// method, carries its own middleware list and nested child routes, and can
/// optionally hold route-level state.
#[derive(Clone)]
pub struct Route {
    /// The single path segment matched by this node (for example `api` or `<path:**>`).
    pub path: String,
    /// Handlers registered on this node, keyed by HTTP method.
    pub handler: HashMap<Method, Arc<dyn Handler>>,
    /// Routes nested below this segment.
    pub children: Vec<Route>,
    /// Middlewares registered on this node (e.g. via `hook` / `hook_first`).
    pub middlewares: Vec<Arc<dyn MiddleWareHandler>>,
    /// `true` when `path` is wrapped in `<` and `>`, i.e. a special (non-literal) segment.
    special_match: bool,
    /// The full path given to `Route::new` with leading `/` trimmed; used to locate
    /// the deepest node of a multi-segment route when appending children.
    create_path: String,
    /// Route-level state: `Some` for `Route::new_root`, `None` for `Route::new`.
    state: Option<crate::State>,
    /// Whether `set_session_store` has been called on this node.
    #[cfg(feature = "session")]
    session_set: bool,
}
/// A [`Route`] is trivially adaptable: it is returned unchanged.
impl RouterAdapt for Route {
    fn into_router(self) -> Self {
        self
    }
}
impl Default for Route {
fn default() -> Self {
Self::new("")
}
}
impl fmt::Debug for Route {
    /// Formats the route tree as an indented outline, one node per line.
    ///
    /// Each line is `<indent><path>(<METHODS>)`. The parenthesised method list is
    /// omitted when the node has no handlers, and every nesting level adds one
    /// leading space. Nodes with neither handlers nor children are skipped, so a
    /// bare route formats as an empty string.
    ///
    /// Method names are sorted so the output is deterministic. The handlers live in
    /// a `HashMap`, whose iteration order is unspecified.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Pushes the outline line for `route`, then those of its descendants.
        fn collect_lines(prefix: &str, route: &Route, lines: &mut Vec<String>) {
            if route.handler.is_empty() && route.children.is_empty() {
                return;
            }
            let mut methods: Vec<String> = route.handler.keys().map(|m| m.to_string()).collect();
            methods.sort_unstable();
            let methods_str = if methods.is_empty() {
                String::new()
            } else {
                format!("({})", methods.join(","))
            };
            lines.push(format!("{prefix}{}{methods_str}", route.path));
            let child_prefix = format!(" {prefix}");
            for child in &route.children {
                collect_lines(&child_prefix, child, lines);
            }
        }

        let mut lines = Vec::new();
        collect_lines("", self, &mut lines);
        f.write_str(&lines.join("\n"))
    }
}
impl Route {
pub fn new_root() -> Self {
Route {
path: String::new(),
handler: HashMap::new(),
children: Vec::new(),
middlewares: Vec::new(),
special_match: false,
create_path: String::new(),
state: Some(crate::State::new()), #[cfg(feature = "session")]
session_set: false,
}
}
pub fn new(path: &str) -> Self {
let path = path.trim_start_matches('/');
let mut paths = path.splitn(2, '/');
let first_path = paths.next().unwrap_or("");
let last_path = paths.next().unwrap_or("");
let route = Route {
path: first_path.to_string(),
handler: HashMap::new(),
children: Vec::new(),
middlewares: Vec::new(),
special_match: first_path.starts_with('<') && first_path.ends_with('>'),
create_path: path.to_string(),
state: None,
#[cfg(feature = "session")]
session_set: false,
};
if last_path.is_empty() {
route
} else {
route.append_route(Route::new(last_path))
}
}
fn append_route(mut self, route: Route) -> Self {
Self::merge_child(&mut self.children, route);
self
}
fn get_append_real_route(&mut self, create_path: &str) -> &mut Self {
if !create_path.contains('/') {
self
} else {
let mut paths = create_path.splitn(2, '/');
let _first_path = paths.next().unwrap_or("");
let last_path = paths.next().unwrap_or("");
let route = self
.children
.iter_mut()
.find(|r| r.create_path == last_path);
let route = route.unwrap();
route.get_append_real_route(last_path)
}
}
pub fn append<R: RouterAdapt>(mut self, route: R) -> Self {
let route = route.into_router();
let real_route = self.get_append_real_route(&self.create_path.clone());
Self::merge_child(&mut real_route.children, route);
self
}
pub fn extend<R: RouterAdapt>(&mut self, routes: Vec<R>) {
let routes: Vec<Route> = routes.into_iter().map(|r| r.into_router()).collect();
let real_route = self.get_append_real_route(&self.create_path.clone());
for route in routes {
Self::merge_child(&mut real_route.children, route);
}
}
pub fn hook(mut self, handler: impl MiddleWareHandler + 'static) -> Self {
self.middlewares.push(Arc::new(handler));
self
}
#[cfg(feature = "quic")]
pub fn with_quic_port(self, port: u16) -> Self {
self.hook(crate::quic::AltSvcMiddleware::new(port))
}
#[cfg(feature = "static")]
pub fn with_static(self, path: &str) -> Self {
self.with_static_options(path, StaticOptions::default())
}
#[cfg(feature = "static")]
pub fn with_static_options(self, path: &str, options: StaticOptions) -> Self {
let handler = static_handler_with_options(path, options);
self.append(Route::new("<path:**>").insert_handler(Method::GET, Arc::new(handler)))
}
#[cfg(feature = "static")]
pub fn with_static_in_url(self, url: &str, path: &str) -> Self {
self.with_static_in_url_options(url, path, StaticOptions::default())
}
#[cfg(feature = "static")]
pub fn with_static_in_url_options(self, url: &str, path: &str, options: StaticOptions) -> Self {
self.append(Route::new(url).with_static_options(path, options))
}
pub fn push<R: RouterAdapt>(&mut self, route: R) {
let route = route.into_router();
let real_route = self.get_append_real_route(&self.create_path.clone());
Self::merge_child(&mut real_route.children, route);
}
pub fn hook_first(&mut self, handler: impl MiddleWareHandler + 'static) {
let handler = Arc::new(handler);
self.middlewares.insert(0, handler);
}
#[cfg(feature = "tower-compat")]
pub fn hook_layer<L>(self, layer: L) -> Self
where
L: tower::Layer<crate::middleware::tower_compat::NextServicePublic>
+ Clone
+ Send
+ Sync
+ 'static,
L::Service:
tower::Service<http::Request<crate::core::req_body::ReqBody>> + Clone + Send + 'static,
<L::Service as tower::Service<http::Request<crate::core::req_body::ReqBody>>>::Response:
crate::middleware::tower_compat::IntoSilentResponse + Send,
<L::Service as tower::Service<http::Request<crate::core::req_body::ReqBody>>>::Error:
Into<crate::error::BoxedError> + Send,
<L::Service as tower::Service<http::Request<crate::core::req_body::ReqBody>>>::Future: Send,
{
self.hook(crate::middleware::tower_compat::TowerLayerAdapter::new(
layer,
))
}
pub fn with_state<T: Send + Sync + Clone + 'static>(mut self, val: T) -> Self {
self.state.get_or_insert_with(crate::State::new).insert(val);
self
}
pub fn set_state(&mut self, state: Option<crate::State>) {
self.state = state;
}
pub(crate) fn get_state(&self) -> Option<&crate::State> {
self.state.as_ref()
}
#[deprecated(since = "2.16.0", note = "请使用 with_state 或 set_state 代替")]
pub fn set_configs(&mut self, configs: Option<crate::State>) {
self.set_state(configs);
}
#[deprecated(since = "2.16.0", note = "请使用 get_state 代替")]
pub fn get_configs(&self) -> Option<&crate::State> {
self.get_state()
}
#[cfg(feature = "session")]
pub fn set_session_store<S: async_session::SessionStore>(&mut self, session: S) -> &mut Self {
self.hook_first(crate::session::middleware::SessionMiddleware::new(session));
self.session_set = true;
self
}
#[cfg(feature = "session")]
pub fn check_session(&mut self) {
if !self.session_set {
self.hook_first(crate::session::middleware::SessionMiddleware::default())
}
}
#[cfg(feature = "cookie")]
pub fn check_cookie(&mut self) {
self.hook_first(crate::cookie::middleware::CookieMiddleware::new())
}
#[cfg(feature = "template")]
pub fn set_template_dir(&mut self, dir: impl Into<String>) -> &mut Self {
let handler = crate::templates::TemplateMiddleware::new(dir.into().as_str());
self.middlewares.push(Arc::new(handler));
self
}
}
impl Route {
    /// Adds `route` to `children`.
    ///
    /// If a sibling already has the same `path` and the same `special_match`
    /// flag, `route` is folded into it. Otherwise `route` is pushed as a new child.
    fn merge_child(children: &mut Vec<Route>, route: Route) {
        let same_node = children
            .iter_mut()
            .find(|c| c.special_match == route.special_match && c.path == route.path);
        match same_node {
            Some(existing) => existing.merge_from(route),
            None => children.push(route),
        }
    }

    /// Folds everything from `other` into `self`.
    ///
    /// - Handlers: a method already present on `self` keeps its handler.
    /// - Middlewares: `other`'s are appended after `self`'s.
    /// - Children: merged recursively through `merge_child`.
    /// - State: extends `self`'s state, or adopts `other`'s when `self` has none.
    fn merge_from(&mut self, mut other: Route) {
        for (method, incoming) in other.handler.drain() {
            self.handler.entry(method).or_insert(incoming);
        }

        self.middlewares.append(&mut other.middlewares);

        for child in std::mem::take(&mut other.children) {
            Self::merge_child(&mut self.children, child);
        }

        if let Some(incoming) = other.state {
            match self.state.as_mut() {
                Some(state) => state.extend_from(&incoming),
                None => self.state = Some(incoming),
            }
        }

        // `merge_child` only pairs nodes with equal flags; guard against other callers.
        debug_assert!(
            other.special_match == self.special_match,
            "尝试合并特殊匹配标记不一致的路由"
        );
        self.special_match = self.special_match || other.special_match;

        #[cfg(feature = "session")]
        {
            self.session_set = self.session_set || other.session_set;
        }
    }
}
#[async_trait]
impl Handler for Route {
    /// Handles `req` using this route directly as a handler.
    ///
    /// Route-level state, when present, is merged into the request's state first.
    /// The route is then cloned and converted into a route tree, which serves the
    /// request. Both the clone and the conversion happen on every call.
    async fn call(&self, mut req: Request) -> crate::error::SilentResult<Response> {
        if let Some(route_state) = self.state.as_ref() {
            req.state_mut().extend_from(route_state);
        }
        let route_tree = self.clone().convert_to_route_tree();
        route_tree.call(req).await
    }
}
impl crate::server::ConnectionService for Route {
    /// Serves one accepted connection.
    ///
    /// A clone of this route is wrapped in a `RouteConnectionService`, and the
    /// connection and peer address are delegated to it.
    fn call(
        &self,
        stream: crate::server::connection::BoxedConnection,
        peer: crate::core::socket_addr::SocketAddr,
    ) -> crate::server::ConnectionFuture {
        let service = crate::server::RouteConnectionService::from(self.clone());
        service.call(stream, peer)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use crate::{Next, Request, Response, SilentError};

    use super::*;

    /// Pass-through middleware that just forwards to the next handler; used to count hooks.
    #[derive(Clone, Eq, PartialEq)]
    struct MiddlewareTest;

    #[async_trait::async_trait]
    impl MiddleWareHandler for MiddlewareTest {
        async fn handle(&self, req: Request, next: &Next) -> crate::error::SilentResult<Response> {
            next.call(req).await
        }
    }

    #[test]
    fn middleware_tree_test() {
        let route = Route::new("api")
            .hook(MiddlewareTest {})
            .append(Route::new("test"));
        // The hook stays on `api`; the appended child does not receive a copy.
        assert_eq!(route.middlewares.len(), 1);
        assert_eq!(route.children[0].middlewares.len(), 0);
    }

    #[test]
    fn long_path_append_test() {
        // Appending to `api/v1` must attach below `v1`, not beside it.
        let route = Route::new("api/v1")
            .hook(MiddlewareTest {})
            .append(Route::new("test"));
        assert_eq!(route.children.len(), 1);
        assert_eq!(route.children[0].children.len(), 1);
        assert_eq!(route.children[0].children[0].path, "test");
    }

    #[tokio::test]
    async fn test_route_onion_model() {
        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));

        /// Records `name:pre` / `name:post` around the inner call and can
        /// short-circuit OPTIONS requests.
        #[derive(Clone)]
        struct LoggingMw {
            name: &'static str,
            log: Arc<Mutex<Vec<String>>>,
            short_on_options: bool,
        }

        #[async_trait::async_trait]
        impl MiddleWareHandler for LoggingMw {
            async fn handle(
                &self,
                req: Request,
                next: &Next,
            ) -> crate::error::SilentResult<Response> {
                {
                    let mut v = self.log.lock().unwrap();
                    v.push(format!("{}:pre", self.name));
                }
                if self.short_on_options && *req.method() == Method::OPTIONS {
                    let mut v = self.log.lock().unwrap();
                    v.push(format!("{}:short", self.name));
                    let mut res = Response::empty();
                    res.headers_mut()
                        .insert("X-Short-Circuit", "true".parse().unwrap());
                    return Ok(res);
                }
                let res = next.call(req).await;
                {
                    let mut v = self.log.lock().unwrap();
                    v.push(format!("{}:post", self.name));
                }
                res
            }
        }

        let log1 = log.clone();
        async fn ok(_: Request) -> Result<String, SilentError> {
            Ok("ok".into())
        }
        let handler = move |req: Request| {
            let l = log1.clone();
            async move {
                {
                    let mut v = l.lock().unwrap();
                    v.push("handler".to_string());
                }
                ok(req).await
            }
        };

        let route = Route::new("")
            .hook(LoggingMw {
                name: "root",
                log: log.clone(),
                short_on_options: false,
            })
            .append(
                Route::new("api")
                    .hook(LoggingMw {
                        name: "api",
                        log: log.clone(),
                        short_on_options: false,
                    })
                    .append(
                        Route::new("v1")
                            .hook(LoggingMw {
                                name: "v1",
                                log: log.clone(),
                                short_on_options: true,
                            })
                            .get(handler),
                    ),
            );
        let routes = route.convert_to_route_tree();

        // GET: every layer wraps the handler, outermost first in and last out.
        {
            let mut req = Request::empty();
            *req.uri_mut() = "/api/v1".parse().unwrap();
            *req.method_mut() = Method::GET;
            let _ = routes.call(req).await.expect("GET should pass");
            let entries = log.lock().unwrap().clone();
            assert_eq!(
                entries,
                vec![
                    "root:pre",
                    "api:pre",
                    "v1:pre",
                    "handler",
                    "v1:post",
                    "api:post",
                    "root:post",
                ]
            );
        }

        // OPTIONS: `v1` short-circuits; outer layers still run their post phase.
        {
            log.lock().unwrap().clear();
            let mut req = Request::empty();
            *req.uri_mut() = "/api/v1".parse().unwrap();
            *req.method_mut() = Method::OPTIONS;
            let res = routes
                .call(req)
                .await
                .expect("OPTIONS should short-circuit");
            assert_eq!(res.status, http::StatusCode::OK);
            assert_eq!(
                res.headers()
                    .get("X-Short-Circuit")
                    .unwrap()
                    .to_str()
                    .unwrap(),
                "true"
            );
            let entries = log.lock().unwrap().clone();
            assert_eq!(
                entries,
                vec![
                    "root:pre",
                    "api:pre",
                    "v1:pre",
                    "v1:short",
                    "api:post",
                    "root:post",
                ]
            );
        }
    }

    #[test]
    fn test_route_new_empty() {
        let route = Route::new("");
        assert_eq!(route.path, "");
        assert!(route.children.is_empty());
    }

    #[test]
    fn test_route_new_with_path() {
        let route = Route::new("api");
        assert_eq!(route.path, "api");
    }

    #[test]
    fn test_route_new_with_leading_slash() {
        assert_eq!(Route::new("/api").path, "api");
        // All leading slashes are trimmed, not just one.
        assert_eq!(Route::new("//api").path, "api");
    }

    #[test]
    fn test_route_new_with_nested_path() {
        let route = Route::new("api/v1");
        assert_eq!(route.path, "api");
        assert_eq!(route.create_path, "api/v1");
        assert_eq!(route.children.len(), 1);
        assert_eq!(route.children[0].path, "v1");
        assert_eq!(route.children[0].create_path, "v1");
    }

    #[test]
    fn test_route_default() {
        let route = Route::default();
        assert_eq!(route.path, "");
        assert!(route.state.is_none());
    }

    #[test]
    fn test_route_new_root() {
        let route = Route::new_root();
        assert_eq!(route.path, "");
        assert!(route.state.is_some());
    }

    #[test]
    fn test_route_special_match_detection() {
        let route1 = Route::new("<path:**>");
        assert!(route1.special_match);
        let route2 = Route::new("api");
        assert!(!route2.special_match);
    }

    #[test]
    fn test_route_extend_empty() {
        let mut route = Route::new("api");
        route.extend::<Route>(vec![]);
        assert_eq!(route.children.len(), 0);
    }

    #[test]
    fn test_route_extend_single() {
        let mut route = Route::new("api");
        route.extend(vec![Route::new("v1")]);
        assert_eq!(route.children.len(), 1);
    }

    #[test]
    fn test_route_extend_multiple() {
        let mut route = Route::new("api");
        route.extend(vec![Route::new("v1"), Route::new("v2"), Route::new("v3")]);
        assert_eq!(route.children.len(), 3);
    }

    #[test]
    fn test_route_extend_with_handlers() {
        let mut route = Route::new("api");
        route.extend(vec![
            Route::new("v1").get(|_req: Request| async { Ok("v1") }),
            Route::new("v2").post(|_req: Request| async { Ok("v2") }),
        ]);
        assert_eq!(route.children.len(), 2);
    }

    #[test]
    fn test_route_push_single() {
        let mut route = Route::new("api");
        route.push(Route::new("v1"));
        assert_eq!(route.children.len(), 1);
    }

    #[test]
    fn test_route_push_nested() {
        let mut route = Route::new("api");
        route.push(Route::new("v1").append(Route::new("users")));
        assert_eq!(route.children.len(), 1);
        assert_eq!(route.children[0].children.len(), 1);
    }

    #[test]
    fn test_route_append_merges_same_path() {
        // Two siblings with the same path collapse into one node holding both methods.
        let route = Route::new("api")
            .append(Route::new("v1").get(|_req: Request| async { Ok("get") }))
            .append(Route::new("v1").post(|_req: Request| async { Ok("post") }));
        assert_eq!(route.children.len(), 1);
        assert!(route.children[0].handler.contains_key(&Method::GET));
        assert!(route.children[0].handler.contains_key(&Method::POST));
    }

    #[test]
    fn test_route_merge_from_same_path() {
        let mut route1 = Route::new("api").get(|_req: Request| async { Ok("r1") });
        let route2 = Route::new("api").post(|_req: Request| async { Ok("r2") });
        route1.merge_from(route2);
        assert!(route1.handler.contains_key(&Method::GET));
        assert!(route1.handler.contains_key(&Method::POST));
        assert_eq!(route1.children.len(), 0);
    }

    #[test]
    fn test_route_merge_from_handlers() {
        // On a method conflict the handler already on the target is kept.
        let mut route1 = Route::new("api").get(|_req: Request| async { Ok("r1") });
        let original_get = route1.handler[&Method::GET].clone();
        let route2 = Route::new("api")
            .get(|_req: Request| async { Ok("r2") })
            .post(|_req: Request| async { Ok("r2") });
        route1.merge_from(route2);
        assert_eq!(route1.handler.len(), 2);
        assert!(Arc::ptr_eq(&route1.handler[&Method::GET], &original_get));
        assert!(route1.handler.contains_key(&Method::POST));
    }

    #[test]
    fn test_route_merge_from_middlewares() {
        let mut route1 = Route::new("api").hook(MiddlewareTest {});
        let route2 = Route::new("api").hook(MiddlewareTest {});
        route1.merge_from(route2);
        assert_eq!(route1.middlewares.len(), 2);
    }

    #[test]
    fn test_route_merge_from_with_children() {
        let mut route1 = Route::new("api");
        let route2 = Route::new("test").append(Route::new("users"));
        route1.merge_from(route2);
        assert_eq!(route1.children.len(), 1);
        assert_eq!(route1.children[0].path, "users");
    }

    #[test]
    fn test_router_adapt_into_router() {
        let converted = Route::new("test").into_router();
        assert_eq!(converted.path, "test");
    }

    #[test]
    fn test_router_adapt_with_append() {
        let result = Route::new("api").append(Route::new("test"));
        assert_eq!(result.path, "api");
        assert_eq!(result.children.len(), 1);
        assert_eq!(result.children[0].path, "test");
    }

    #[test]
    fn test_route_debug_empty() {
        // A node with neither handlers nor children is omitted from the outline.
        let route = Route::new("test");
        assert_eq!(format!("{:?}", route), "");
    }

    #[test]
    fn test_route_debug_with_handler() {
        let route = Route::new("test").get(|_req: Request| async { Ok("test") });
        let debug_str = format!("{:?}", route);
        assert!(debug_str.contains("test"));
        assert!(debug_str.contains("(GET)"));
    }

    #[test]
    fn test_route_debug_with_handlers() {
        let route = Route::new("test")
            .get(|_req: Request| async { Ok("get") })
            .post(|_req: Request| async { Ok("post") });
        let debug_str = format!("{:?}", route);
        assert!(debug_str.starts_with("test("));
        assert!(debug_str.contains("GET"));
        assert!(debug_str.contains("POST"));
    }

    #[test]
    fn test_route_debug_with_children() {
        // Empty children are skipped, so only the parent line remains.
        let route = Route::new("api").append(Route::new("test"));
        assert_eq!(format!("{:?}", route), "api");

        // Children with handlers are listed one level deeper (one extra leading space).
        let route = Route::new("api")
            .append(Route::new("test").get(|_req: Request| async { Ok("test") }));
        assert_eq!(format!("{:?}", route), "api\n test(GET)");
    }
}