use std::{convert::Infallible, sync::Arc};
use axum::{
Router as AxumRouter,
extract::Request,
handler::Handler,
response::IntoResponse,
routing::{self, MethodRouter, Route},
};
use tower_layer::Layer;
use tower_service::Service;
/// A type-erased middleware step: takes a route's `MethodRouter` and returns it
/// wrapped in one more layer. It is an `Fn` (not `FnOnce`) because the same step
/// is applied to every route registered while it is on the stack, and it sits
/// behind an `Arc` so cloning a `StackFrame` shares the steps instead of copying them.
type BoxedLayer = Arc<dyn Fn(MethodRouter) -> MethodRouter + Send + Sync>;
/// Builder state that applies to routes registered from now on: the accumulated
/// path prefix and the middleware steps in the order they were added.
/// `Router::group` clones this before running its closure and restores it afterwards.
#[derive(Clone, Default)]
struct StackFrame {
/// Prefix joined (via `join_paths`) onto every newly registered route path.
prefix: String,
/// Middleware steps; the first entry ends up as the outermost layer of a route.
layers: Vec<BoxedLayer>,
}
/// Route builder layered on top of axum's `Router`.
///
/// Routes are registered with the HTTP-method helpers (`get`, `post`, ...) under
/// the current path prefix and wrapped in the current middleware stack; both can
/// be scoped with `Router::group`.
#[derive(Default)]
pub struct Router {
/// The axum router that accumulates every registered route.
inner: AxumRouter,
/// Prefix and middleware applied to routes registered from this point on.
stack: StackFrame,
}
/// Generates one route-registration method per listed HTTP verb.
///
/// Each generated method forwards to `axum::routing::<verb>` and hands the
/// resulting `MethodRouter` to `Router::add_route`.
macro_rules! impl_http_methods {
    ([$($verb:ident),*]) => {
        $(
            #[doc = concat!(
                "Registers `handler` for HTTP `",
                stringify!($verb),
                "` requests at `path`.\n\n",
                "`path` is joined onto the current prefix and the route is wrapped ",
                "in every middleware currently on the stack."
            )]
            pub fn $verb<H, T>(&mut self, path: &str, handler: H) -> &mut Self
            where
                T: 'static,
                H: Handler<T, ()>,
            {
                let method_router = routing::$verb(handler);
                self.add_route(path, method_router)
            }
        )*
    };
}
impl Router {
    /// Creates an empty router with no prefix and no middleware.
    pub fn new() -> Self {
        Self::default()
    }

    impl_http_methods!([get, post, put, delete, patch, head, options, trace, connect]);

    /// Appends `prefix` to the current path prefix.
    ///
    /// The prefix applies to every route registered afterwards in the current
    /// scope. Inside [`Router::group`] it is undone when the group closure
    /// returns; at the top level it stays in effect for the rest of the router.
    pub fn prefix(&mut self, prefix: &str) -> &mut Self {
        self.stack.prefix = join_paths(&self.stack.prefix, prefix);
        self
    }

    /// Runs `f` in a nested scope.
    ///
    /// Prefixes and middleware added inside `f` only affect routes registered
    /// inside `f`: the prefix and middleware stack that were active before the
    /// call are restored once `f` returns. Routes registered inside `f` are kept.
    pub fn group<F>(&mut self, f: F) -> &mut Self
    where
        F: FnOnce(&mut Self),
    {
        let previous_stack = self.stack.clone();
        f(self);
        self.stack = previous_stack;
        self
    }

    /// Alias for [`Router::layer`].
    pub fn middleware<L>(&mut self, layer: L) -> &mut Self
    where
        L: Layer<Route> + Clone + Send + Sync + 'static,
        L::Service: Service<Request> + Clone + Send + Sync + 'static,
        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request>>::Future: Send + 'static,
    {
        self.layer(layer)
    }

    /// Pushes `layer` onto the current middleware stack.
    ///
    /// The layer wraps every route registered *after* this call in the current
    /// scope; routes registered earlier are not affected. Layers pushed first
    /// end up outermost, so they see a request before layers pushed later.
    pub fn layer<L>(&mut self, layer: L) -> &mut Self
    where
        L: Layer<Route> + Clone + Send + Sync + 'static,
        L::Service: Service<Request> + Clone + Send + Sync + 'static,
        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request>>::Future: Send + 'static,
    {
        // The stored step is `Fn`, so each route it is applied to gets its own
        // clone of the layer.
        let layer_fn = move |route: MethodRouter| route.layer(layer.clone());
        self.stack.layers.push(Arc::new(layer_fn));
        self
    }

    /// Consumes the builder and returns the assembled axum router.
    pub(crate) fn build(self) -> AxumRouter {
        self.inner
    }

    /// Registers `method_router` at `path` joined onto the current prefix,
    /// wrapped in every layer currently on the stack.
    ///
    /// Any panic from axum's `Router::route` (e.g. an invalid or conflicting
    /// path) propagates to the caller.
    fn add_route(&mut self, path: &str, method_router: MethodRouter) -> &mut Self {
        let full_path = join_paths(&self.stack.prefix, path);
        // Apply the most recently pushed layer first so it ends up innermost,
        // leaving the first pushed layer outermost.
        let method_router_with_layers = self
            .stack
            .layers
            .iter()
            .rev()
            .fold(method_router, |router, layer| layer(router));
        // Move the router out instead of cloning it. axum's `Router` keeps its
        // route table behind an `Arc`, and `route` on a shared handle has to
        // copy that table, which made every registration O(number of routes).
        let router = std::mem::take(&mut self.inner);
        self.inner = router.route(&full_path, method_router_with_layers);
        self
    }
}
/// Joins a route `prefix` and a `path` with exactly one `/` between them.
///
/// - With an empty `prefix`, `path` is returned unchanged.
/// - When `path` is empty or consists only of slashes (`""`, `"/"`, `"//"`),
///   the prefix itself is returned. So under `prefix("/api")`, both `get("/", ..)`
///   and `get("", ..)` register `/api`, and `prefix("")` is a no-op.
/// - Otherwise trailing slashes on `prefix` and leading slashes on `path` are
///   collapsed into one separator: `("/api/", "/users")` gives `/api/users`.
fn join_paths(prefix: &str, path: &str) -> String {
    if prefix.is_empty() {
        return path.to_string();
    }
    let path = path.trim_start_matches('/');
    if path.is_empty() {
        // Root of the prefixed scope: appending a slash here would register a
        // different route (`/api/`) than `"/"` does (`/api`).
        return prefix.to_string();
    }
    format!("{}/{}", prefix.trim_end_matches('/'), path)
}
#[cfg(test)]
impl Router {
    /// Test helper: creates a fresh [`Router`] and lets `composer` register
    /// routes and middleware on it before returning it.
    ///
    /// `composer` is invoked exactly once, so any `FnOnce` closure is accepted,
    /// including closures that move captured values out. Every `FnMut` or
    /// `Fn` closure that worked before still does.
    pub(crate) fn compose<F: FnOnce(&mut Router)>(composer: F) -> Router {
        let mut router = Self::new();
        composer(&mut router);
        router
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        sync::{Arc, Mutex},
    };
    use tower_layer::layer_fn;

    #[test]
    fn test_join_paths() {
        assert_eq!(join_paths("", "/users"), "/users");
        assert_eq!(join_paths("/api", "/users"), "/api/users");
        assert_eq!(join_paths("/api/", "/users"), "/api/users");
        assert_eq!(join_paths("/api", "users"), "/api/users");
        assert_eq!(join_paths("/api/", "users"), "/api/users");
        assert_eq!(join_paths("/api", "/"), "/api");
        assert_eq!(join_paths("/api/v1", "/users"), "/api/v1/users");
        // A bare "/" prefix must not produce a double slash.
        assert_eq!(join_paths("/", "/users"), "/users");
        assert_eq!(join_paths("/", "/"), "/");
    }

    /// Test layer that records `<name>_before` when a request enters its
    /// service and `<name>_after` when the inner response future completes.
    #[derive(Clone)]
    struct LogLayer {
        name: &'static str,
        logs: Arc<Mutex<Vec<String>>>,
    }

    impl<S> Layer<S> for LogLayer {
        type Service = LogService<S>;
        fn layer(&self, inner: S) -> Self::Service {
            LogService {
                inner,
                name: self.name,
                logs: self.logs.clone(),
            }
        }
    }

    /// Service produced by [`LogLayer`].
    #[derive(Clone)]
    struct LogService<S> {
        inner: S,
        name: &'static str,
        logs: Arc<Mutex<Vec<String>>>,
    }

    // The request type parameter is named `Req` so it does not shadow the
    // `Request` type brought in by `use super::*`.
    impl<S, Req> Service<Req> for LogService<S>
    where
        S: Service<Req> + Clone + Send + 'static,
        S::Future: Send,
        Req: Send + 'static,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future =
            Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: Req) -> Self::Future {
            let log = self.logs.clone();
            let name = self.name;
            // Keep a fresh clone in `self` and drive the instance that was
            // polled ready.
            let inner = self.inner.clone();
            let mut inner = std::mem::replace(&mut self.inner, inner);
            log.lock().unwrap().push(format!("{}_before", name));
            Box::pin(async move {
                let result = inner.call(req).await;
                log.lock().unwrap().push(format!("{}_after", name));
                result
            })
        }
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        use http::Request as HttpRequest;
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut router = Router::new();
        router
            .middleware(LogLayer {
                name: "A",
                logs: log.clone(),
            })
            .middleware(LogLayer {
                name: "B",
                logs: log.clone(),
            })
            .middleware(LogLayer {
                name: "C",
                logs: log.clone(),
            })
            .get("/test", || async { "OK" });
        let mut service = router.build();
        let req = HttpRequest::builder()
            .uri("/test")
            .body(axum::body::Body::empty())
            .unwrap();
        let _response = service.call(req).await.unwrap();
        let recorded = log.lock().unwrap().clone();
        assert_eq!(
            recorded,
            vec![
                "A_before", "B_before", "C_before", "C_after", "B_after", "A_after",
            ]
        );
    }

    #[tokio::test]
    async fn test_group_middleware_isolation() {
        use http::Request as HttpRequest;
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut router = Router::new();
        router.middleware(LogLayer {
            name: "Global",
            logs: log.clone(),
        });
        router.group(|r| {
            r.middleware(LogLayer {
                name: "GroupOnly",
                logs: log.clone(),
            })
            .get("/grouped", || async { "grouped" });
        });
        router.get("/outside", || async { "outside" });
        let mut service = router.build();
        log.lock().unwrap().clear();
        let req = HttpRequest::builder()
            .uri("/grouped")
            .body(axum::body::Body::empty())
            .unwrap();
        let _response = service.call(req).await.unwrap();
        let recorded = log.lock().unwrap().clone();
        assert_eq!(
            recorded,
            vec![
                "Global_before",
                "GroupOnly_before",
                "GroupOnly_after",
                "Global_after",
            ]
        );
        log.lock().unwrap().clear();
        let req = HttpRequest::builder()
            .uri("/outside")
            .body(axum::body::Body::empty())
            .unwrap();
        let _response = service.call(req).await.unwrap();
        let recorded = log.lock().unwrap().clone();
        assert_eq!(recorded, vec!["Global_before", "Global_after",]);
    }

    #[tokio::test]
    async fn test_group_prefix_isolation() {
        use http::{Request as HttpRequest, StatusCode};
        let mut router = Router::new();
        router.group(|r| {
            r.prefix("/api").get("/inside", || async { "inside" });
        });
        // The group's "/api" prefix must not leak past the group.
        router.get("/outside", || async { "outside" });
        let mut service = router.build();
        for (uri, expected) in [
            ("/api/inside", StatusCode::OK),
            ("/outside", StatusCode::OK),
            ("/api/outside", StatusCode::NOT_FOUND),
        ] {
            let req = HttpRequest::builder()
                .uri(uri)
                .body(axum::body::Body::empty())
                .unwrap();
            let response = service.call(req).await.unwrap();
            assert_eq!(response.status(), expected, "unexpected status for {}", uri);
        }
    }

    #[tokio::test]
    async fn test_layer_only_wraps_later_routes() {
        use http::Request as HttpRequest;
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut router = Router::new();
        router.get("/before", || async { "before" });
        router.middleware(LogLayer {
            name: "Late",
            logs: log.clone(),
        });
        router.get("/after", || async { "after" });
        let mut service = router.build();

        let req = HttpRequest::builder()
            .uri("/before")
            .body(axum::body::Body::empty())
            .unwrap();
        let _response = service.call(req).await.unwrap();
        assert!(log.lock().unwrap().is_empty());

        let req = HttpRequest::builder()
            .uri("/after")
            .body(axum::body::Body::empty())
            .unwrap();
        let _response = service.call(req).await.unwrap();
        let recorded = log.lock().unwrap().clone();
        assert_eq!(recorded, vec!["Late_before", "Late_after"]);
    }

    #[tokio::test]
    async fn test_prefix_accumulation() {
        use http::StatusCode;
        let mut router = Router::new();
        // Identity layer: checks that a global layer does not disturb prefixed
        // routing. `layer_fn`'s closure runs once per route registration, not
        // per request, so it is not a place to log requests.
        router.middleware(layer_fn(|svc| svc));
        router.group(|r| {
            r.prefix("/api")
                .prefix("/v1")
                .get("/users", || async { "users" });
        });
        let mut service = router.build();
        let req = http::Request::builder()
            .uri("/api/v1/users")
            .body(axum::body::Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let req = http::Request::builder()
            .uri("/api/users")
            .body(axum::body::Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_path_parameters() {
        use axum::extract::Path;
        use http::{Request as HttpRequest, StatusCode};
        let mut router = Router::new();
        router.get("/users/{id}", |Path(id): Path<u32>| async move {
            format!("User ID: {}", id)
        });
        let mut service = router.build();
        let req = HttpRequest::builder()
            .uri("/users/123")
            .body(axum::body::Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "User ID: 123");
    }

    #[tokio::test]
    async fn test_query_parameters() {
        use axum::extract::Query;
        use http::{Request as HttpRequest, StatusCode};
        use serde::Deserialize;
        #[derive(Deserialize)]
        struct Pagination {
            page: Option<u32>,
            limit: Option<u32>,
        }
        let mut router = Router::new();
        router.get(
            "/search",
            |Query(pagination): Query<Pagination>| async move {
                format!(
                    "Page: {}, Limit: {}",
                    pagination.page.unwrap_or(1),
                    pagination.limit.unwrap_or(10)
                )
            },
        );
        let mut service = router.build();
        let req = HttpRequest::builder()
            .uri("/search?page=2&limit=20")
            .body(axum::body::Body::empty())
            .unwrap();
        let response = service.call(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "Page: 2, Limit: 20");
    }
}