use super::{HttpMethod, RouteInfo, RouteRegistry, params::{ParamExtractor, ParamType}};
use axum::{
Router as AxumRouter,
routing::{get, post, put, delete, patch},
handler::Handler,
response::IntoResponse,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Wrapper around an axum [`AxumRouter`] that also records route metadata.
///
/// Routes added through the typed helpers (`get`, `post`, `put`, `delete`, `patch`) are
/// forwarded to the inner axum router and recorded in a shared [`RouteRegistry`], which
/// `url_for` consults to build URLs from route names.
#[derive(Debug)]
pub struct Router<S = ()>
where
S: Clone + Send + Sync + 'static,
{
/// The axum router that actually serves requests.
axum_router: AxumRouter<S>,
/// Metadata for registered routes; a handle to it is handed out by `registry()`.
registry: Arc<Mutex<RouteRegistry>>,
/// Counter used to generate route ids of the form `route_N`.
route_counter: Arc<Mutex<usize>>,
}
impl<S> Router<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
axum_router: AxumRouter::new(),
registry: Arc::new(Mutex::new(RouteRegistry::new())),
route_counter: Arc::new(Mutex::new(0)),
}
}
pub fn with_state(state: S) -> Self {
Self {
axum_router: AxumRouter::new().with_state(state),
registry: Arc::new(Mutex::new(RouteRegistry::new())),
route_counter: Arc::new(Mutex::new(0)),
}
}
fn next_route_id(&self) -> String {
let mut counter = self.route_counter.lock().unwrap();
*counter += 1;
format!("route_{}", counter)
}
fn register_route(&self, method: HttpMethod, path: &str, name: Option<String>) -> String {
let route_id = self.next_route_id();
let params = self.extract_param_names(path);
let route_info = RouteInfo {
name: name.clone(),
path: path.to_string(),
method,
params,
group: None, };
self.registry.lock().unwrap().register(route_id.clone(), route_info);
route_id
}
fn extract_param_names(&self, path: &str) -> Vec<String> {
path.split('/')
.filter_map(|segment| {
if segment.starts_with('{') && segment.ends_with('}') {
Some(segment[1..segment.len()-1].to_string())
} else {
None
}
})
.collect()
}
pub fn get<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
self.register_route(HttpMethod::GET, path, None);
self.axum_router = self.axum_router.route(path, get(handler));
self
}
pub fn post<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
self.register_route(HttpMethod::POST, path, None);
self.axum_router = self.axum_router.route(path, post(handler));
self
}
pub fn put<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
self.register_route(HttpMethod::PUT, path, None);
self.axum_router = self.axum_router.route(path, put(handler));
self
}
pub fn delete<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
self.register_route(HttpMethod::DELETE, path, None);
self.axum_router = self.axum_router.route(path, delete(handler));
self
}
pub fn patch<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
self.register_route(HttpMethod::PATCH, path, None);
self.axum_router = self.axum_router.route(path, patch(handler));
self
}
pub fn merge(mut self, other: Router<S>) -> Self {
if let (Ok(mut self_registry), Ok(other_registry)) =
(self.registry.lock(), other.registry.lock()) {
for (id, route_info) in other_registry.all_routes() {
self_registry.register(id.clone(), route_info.clone());
}
}
self.axum_router = self.axum_router.merge(other.axum_router);
self
}
pub(crate) fn merge_axum(mut self, other: AxumRouter<S>) -> Self {
self.axum_router = self.axum_router.merge(other);
self
}
pub fn nest(mut self, path: &str, router: Router<S>) -> Self {
if let (Ok(mut self_registry), Ok(router_registry)) =
(self.registry.lock(), router.registry.lock()) {
for (id, route_info) in router_registry.all_routes() {
self_registry.register(id.clone(), route_info.clone());
}
}
self.axum_router = self.axum_router.nest(path, router.axum_router);
self
}
pub(crate) fn nest_axum(mut self, path: &str, router: AxumRouter<S>) -> Self {
self.axum_router = self.axum_router.nest(path, router);
self
}
pub fn into_axum_router(self) -> AxumRouter<S> {
self.axum_router
}
pub fn registry(&self) -> Arc<Mutex<RouteRegistry>> {
Arc::clone(&self.registry)
}
pub fn url_for(&self, name: &str, params: &HashMap<String, String>) -> Option<String> {
let registry = self.registry.lock().unwrap();
if let Some(route) = registry.get_by_name(name) {
let mut url = route.path.clone();
for (key, value) in params {
url = url.replace(&format!("{{{}}}", key), value);
}
Some(url)
} else {
None
}
}
}
impl<S> Default for Router<S>
where
S: Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
/// Builder for [`Route`]; obtain one with `RouteBuilder::new()` or `Route::builder()`.
#[derive(Debug, Default)]
pub struct RouteBuilder {
/// Optional route name.
name: Option<String>,
/// Declared types for path parameters, keyed by parameter name.
param_types: HashMap<String, ParamType>,
/// Middleware entries, stored as plain strings.
middleware: Vec<String>, }
impl RouteBuilder {
    /// Creates a builder with no name, no parameter types and no middleware.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the route name, replacing any previously set name.
    #[must_use]
    pub fn name(mut self, name: &str) -> Self {
        self.name = Some(name.to_string());
        self
    }

    /// Declares the expected type of path parameter `name`.
    ///
    /// Replaces any earlier declaration for the same name.
    #[must_use]
    pub fn param(mut self, name: &str, param_type: ParamType) -> Self {
        self.param_types.insert(name.to_string(), param_type);
        self
    }

    /// Appends a middleware entry to the route.
    ///
    /// Before this setter existed, the `middleware` field could never be populated.
    /// Entries are kept as plain strings in insertion order; how they are interpreted is
    /// not visible in this module.
    #[must_use]
    pub fn middleware(mut self, middleware: &str) -> Self {
        self.middleware.push(middleware.to_string());
        self
    }

    /// Finishes the builder, moving the collected settings into a [`Route`].
    #[must_use]
    pub fn build(self) -> Route {
        Route {
            name: self.name,
            param_types: self.param_types,
            middleware: self.middleware,
        }
    }
}
/// Route configuration produced by [`RouteBuilder::build`].
#[derive(Debug)]
pub struct Route {
/// Optional route name.
pub name: Option<String>,
/// Declared types for path parameters, keyed by parameter name.
pub param_types: HashMap<String, ParamType>,
/// Middleware entries, stored as plain strings.
pub middleware: Vec<String>,
}
impl Route {
    /// Starts building a route; shorthand for `RouteBuilder::new()`.
    pub fn builder() -> RouteBuilder {
        RouteBuilder::new()
    }

    /// Builds a [`ParamExtractor`] with one entry per declared parameter type.
    ///
    /// Entries are added in the map's iteration order.
    pub fn param_extractor(&self) -> ParamExtractor {
        self.param_types
            .iter()
            .fold(ParamExtractor::new(), |acc, (param_name, declared)| {
                acc.param(param_name, declared.clone())
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::Html;

    /// Minimal handler shared by all test routes.
    async fn handler() -> Html<&'static str> {
        Html("<h1>Hello, World!</h1>")
    }

    #[test]
    fn test_router_creation() {
        let router = Router::<()>::new()
            .get("/", handler)
            .post("/users", handler)
            .get("/users/{id}", handler);
        let registry = router.registry();
        let reg = registry.lock().unwrap();
        assert_eq!(reg.all_routes().len(), 3);
    }

    #[test]
    fn test_param_extraction() {
        let router = Router::<()>::new();
        let params = router.extract_param_names("/users/{id}/posts/{slug}");
        assert_eq!(params, vec!["id", "slug"]);
    }

    #[test]
    fn test_url_generation() {
        let router = Router::<()>::new().get("/users/{id}/posts/{slug}", handler);
        // The typed helpers (`get`, `post`, ...) register routes without a name, so a
        // named entry is inserted into the registry directly.
        {
            let mut registry = router.registry.lock().unwrap();
            let route_info = RouteInfo {
                name: Some("user.posts.show".to_string()),
                path: "/users/{id}/posts/{slug}".to_string(),
                method: HttpMethod::GET,
                params: vec!["id".to_string(), "slug".to_string()],
                group: None,
            };
            registry.register("test_route".to_string(), route_info);
        }
        let mut params = HashMap::new();
        params.insert("id".to_string(), "123".to_string());
        params.insert("slug".to_string(), "hello-world".to_string());
        let url = router.url_for("user.posts.show", &params);
        assert_eq!(url, Some("/users/123/posts/hello-world".to_string()));
    }
}