pub mod segment;
use std::{collections::HashMap, fmt, sync::Arc};
pub use segment::{RouteSegment, SegmentType};
use tairitsu_vdom::VNode;
/// One entry of the routing table: a path pattern bound to a view handler.
///
/// Patterns are `/`-separated; as interpreted by [`Router`], a segment written
/// `:name` captures the corresponding path segment under `name`, a `*` segment
/// matches any single segment, and any other segment must match literally.
#[derive(Clone)]
pub struct Route {
    /// Path pattern, e.g. `/users/:id`.
    pub path: String,
    /// Callback producing the view for a matched path.
    pub handler: RouteHandler,
    /// Optional name used by [`Router::url_for`] for reverse lookup.
    pub name: Option<String>,
    /// Route-specific middleware, invoked in insertion order before the handler.
    pub middleware: Vec<RouteMiddleware>,
    /// `true` for routes built with [`Route::new`]; `false` marks a prefix
    /// route built with [`Route::prefix`].
    pub exact: bool,
}

impl Route {
    /// Builds an exact-match route with no name and no middleware.
    pub fn new(path: impl Into<String>, handler: impl Into<RouteHandler>) -> Self {
        Self::with_exactness(path.into(), handler.into(), true)
    }

    /// Builds a prefix route (`exact == false`) with no name and no middleware.
    pub fn prefix(path: impl Into<String>, handler: impl Into<RouteHandler>) -> Self {
        Self::with_exactness(path.into(), handler.into(), false)
    }

    /// Shared constructor for [`Route::new`] and [`Route::prefix`].
    fn with_exactness(path: String, handler: RouteHandler, exact: bool) -> Self {
        Self {
            path,
            handler,
            name: None,
            middleware: Vec::new(),
            exact,
        }
    }

    /// Sets (or replaces) the route's name, consuming and returning the route.
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Appends `middleware` to the end of this route's middleware chain.
    pub fn middleware(mut self, middleware: RouteMiddleware) -> Self {
        self.middleware.push(middleware);
        self
    }
}
/// Named parameters captured from a matched path (`:name` segment → value).
pub type Params = HashMap<String, String>;

/// Produces the view for a matched route from its captured parameters.
pub type RouteHandler = Arc<dyn Fn(Params) -> VNode + Send + Sync>;

/// Hook invoked before a handler; it receives the parameters mutably and may
/// reject the request by returning a [`MiddlewareError`].
pub type RouteMiddleware =
    Arc<dyn Fn(&mut Params) -> Result<(), MiddlewareError> + Send + Sync>;
/// Reason a middleware rejected a request.
///
/// The [`Display`](fmt::Display) text is what the router embeds in its error
/// page. The type also implements [`std::error::Error`], so it can be boxed as
/// `dyn Error` or converted into other error types with `?`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MiddlewareError {
    /// Displays as `Forbidden`.
    Forbidden,
    /// Displays as `Unauthorized`.
    Unauthorized,
    /// Arbitrary rejection; displays as the contained message verbatim.
    Custom(String),
}

impl fmt::Display for MiddlewareError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            MiddlewareError::Forbidden => "Forbidden",
            MiddlewareError::Unauthorized => "Unauthorized",
            MiddlewareError::Custom(msg) => msg.as_str(),
        })
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl std::error::Error for MiddlewareError {}
/// A successful lookup returned by [`Router::match_route`]: the route whose
/// pattern matched, plus the parameters captured from the path.
#[derive(Clone)]
pub struct RouteMatch {
    /// Clone of the router's entry whose pattern matched.
    pub route: Route,
    /// Values captured from the pattern's `:name` segments, keyed by name.
    pub params: Params,
}
pub struct Router {
routes: Vec<Route>,
fallback: Option<RouteHandler>,
global_middleware: Vec<RouteMiddleware>,
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
impl Router {
    /// Creates an empty router: no routes, no fallback, no global middleware.
    pub fn new() -> Self {
        Self {
            routes: Vec::new(),
            fallback: None,
            global_middleware: Vec::new(),
        }
    }

    /// Registers an exact-match route (see [`Route::new`]).
    pub fn route(mut self, path: impl Into<String>, handler: impl Into<RouteHandler>) -> Self {
        self.routes.push(Route::new(path, handler));
        self
    }

    /// Registers a prefix route (see [`Route::prefix`]): it matches a path that
    /// matches the pattern itself, optionally followed by further segments.
    pub fn prefix(mut self, path: impl Into<String>, handler: impl Into<RouteHandler>) -> Self {
        self.routes.push(Route::prefix(path, handler));
        self
    }

    /// Registers an exact-match route under `name`, so that
    /// [`Router::url_for`] can build URLs for it.
    pub fn named_route(
        mut self,
        name: impl Into<String>,
        path: impl Into<String>,
        handler: impl Into<RouteHandler>,
    ) -> Self {
        self.routes.push(Route::new(path, handler).name(name));
        self
    }

    /// Sets the handler rendered, with empty params, when no route matches.
    pub fn fallback(mut self, handler: impl Into<RouteHandler>) -> Self {
        self.fallback = Some(handler.into());
        self
    }

    /// Appends middleware that runs for every matched route, before any
    /// route-specific middleware.
    pub fn middleware(mut self, middleware: RouteMiddleware) -> Self {
        self.global_middleware.push(middleware);
        self
    }

    /// Returns the first registered route matching `path`, with its captured
    /// parameters, or `None` if nothing matches.
    ///
    /// Empty segments are ignored on both sides, so leading, trailing and
    /// repeated slashes make no difference (`/users/1/` matches `/users/:id`).
    pub fn match_route(&self, path: impl AsRef<str>) -> Option<RouteMatch> {
        let path = path.as_ref();
        self.routes.iter().find_map(|route| {
            Self::match_pattern(&route.path, path, route.exact).map(|params| RouteMatch {
                route: route.clone(),
                params,
            })
        })
    }

    /// Matches `path` against `pattern` segment by segment.
    ///
    /// `:name` captures the path segment under `name`, `*` accepts any single
    /// segment, and any other segment must be equal. Exact patterns require the
    /// same number of segments; prefix patterns (`exact == false`) also accept
    /// paths with extra trailing segments, which are not captured.
    fn match_pattern(pattern: &str, path: &str, exact: bool) -> Option<Params> {
        let pattern_segments: Vec<&str> = pattern.split('/').filter(|s| !s.is_empty()).collect();
        let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();

        let length_ok = if exact {
            path_segments.len() == pattern_segments.len()
        } else {
            path_segments.len() >= pattern_segments.len()
        };
        if !length_ok {
            return None;
        }

        let mut params = Params::new();
        // `zip` stops at the pattern's end, leaving a prefix route's extra path
        // segments unchecked.
        for (pattern_seg, path_seg) in pattern_segments.iter().zip(&path_segments) {
            if let Some(param_name) = pattern_seg.strip_prefix(':') {
                params.insert(param_name.to_string(), path_seg.to_string());
            } else if *pattern_seg != "*" && pattern_seg != path_seg {
                return None;
            }
        }
        Some(params)
    }

    /// Renders the view for `path`.
    ///
    /// On a match, global middleware and then the route's middleware are run in
    /// order. The first middleware error short-circuits into an error page, and
    /// otherwise the route handler receives the (possibly modified) params.
    /// Without a match, the fallback handler is rendered with empty params, or
    /// a built-in 404 page if no fallback is set.
    pub fn render(&self, path: impl AsRef<str>) -> VNode {
        let matched = match self.match_route(path) {
            Some(matched) => matched,
            None => {
                return match &self.fallback {
                    Some(fallback) => (fallback)(Params::new()),
                    None => self.default_404(),
                };
            }
        };

        let mut params = matched.params;
        let chain = self
            .global_middleware
            .iter()
            .chain(&matched.route.middleware);
        for middleware in chain {
            if let Err(e) = middleware(&mut params) {
                return self.render_error(e);
            }
        }
        (matched.route.handler)(params)
    }

    /// Builds the error page shown when a middleware rejects a request.
    fn render_error(&self, error: MiddlewareError) -> VNode {
        let message = error.to_string();
        VNode::Element(
            tairitsu_vdom::VElement::new("div")
                .attr("class", "error-page")
                .child(VNode::Element(tairitsu_vdom::VElement::new("h1").child(
                    VNode::Text(tairitsu_vdom::VText::new(&format!("Error: {}", message))),
                )))
                .child(VNode::Element(tairitsu_vdom::VElement::new("p").child(
                    VNode::Text(tairitsu_vdom::VText::new(
                        "An error occurred while processing your request.",
                    )),
                ))),
        )
    }

    /// Builds the page shown when nothing matches and no fallback is set.
    fn default_404(&self) -> VNode {
        VNode::Element(
            tairitsu_vdom::VElement::new("div")
                .attr("class", "error-404")
                .child(VNode::Element(
                    tairitsu_vdom::VElement::new("h1")
                        .child(VNode::Text(tairitsu_vdom::VText::new("404"))),
                ))
                .child(VNode::Element(tairitsu_vdom::VElement::new("p").child(
                    VNode::Text(tairitsu_vdom::VText::new("Page not found")),
                ))),
        )
    }

    /// Builds a URL for the first route registered under `name`, or returns
    /// `None` if no route has that name.
    ///
    /// Every `:key` segment whose key appears in `params` is replaced by the
    /// first matching value. Segments without a supplied value are left as-is.
    /// Only whole segments are substituted, so a `:id` param never rewrites part
    /// of `:id_type` or text inside an already-inserted value. Values are
    /// inserted verbatim (no percent-encoding).
    pub fn url_for(&self, name: &str, params: &[(&str, &str)]) -> Option<String> {
        let route = self
            .routes
            .iter()
            .find(|route| route.name.as_deref() == Some(name))?;

        let url = route
            .path
            .split('/')
            .map(|segment| {
                segment
                    .strip_prefix(':')
                    .and_then(|key| params.iter().find(|(k, _)| *k == key))
                    .map_or(segment, |(_, value)| *value)
            })
            .collect::<Vec<&str>>()
            .join("/");
        Some(url)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Handler stub: ignores its parameters and renders a fixed text node.
    fn mock_handler(_params: Params) -> VNode {
        VNode::Text(tairitsu_vdom::VText::new("mock"))
    }

    /// Turns a plain function pointer into the shared [`RouteHandler`] type.
    fn wrap_handler(f: fn(Params) -> VNode) -> RouteHandler {
        Arc::new(f)
    }

    /// Shorthand for a freshly wrapped [`mock_handler`].
    fn mock() -> RouteHandler {
        wrap_handler(mock_handler)
    }

    #[test]
    fn test_route_creation() {
        let route = Route::new("/test", mock());
        assert!(route.exact, "Route::new should build an exact route");
        assert_eq!(route.path, "/test");
    }

    #[test]
    fn test_router_static_route() {
        let router = Router::new().route("/test", mock());
        let matched = router
            .match_route("/test")
            .expect("static route should match its own path");
        assert!(matched.params.is_empty());
    }

    #[test]
    fn test_router_dynamic_route() {
        let router = Router::new().route("/users/:id", mock());
        let matched = router
            .match_route("/users/123")
            .expect("dynamic route should match");
        assert_eq!(matched.params.get("id").map(String::as_str), Some("123"));
    }

    #[test]
    fn test_router_no_match() {
        let router = Router::new().route("/test", mock());
        assert!(router.match_route("/other").is_none());
    }

    #[test]
    fn test_router_multiple_params() {
        let router = Router::new().route("/posts/:post_id/comments/:comment_id", mock());
        let RouteMatch { params, .. } = router
            .match_route("/posts/abc/comments/xyz")
            .expect("nested dynamic route should match");
        assert_eq!(params.get("post_id").map(String::as_str), Some("abc"));
        assert_eq!(params.get("comment_id").map(String::as_str), Some("xyz"));
    }

    #[test]
    fn test_router_fallback() {
        let router = Router::new().route("/test", mock()).fallback(mock());
        assert!(matches!(router.render("/other"), VNode::Text(_)));
    }

    #[test]
    fn test_named_route() {
        let router = Router::new().named_route("user", "/users/:id", mock());
        assert_eq!(
            router.url_for("user", &[("id", "123")]).as_deref(),
            Some("/users/123")
        );
    }

    #[test]
    fn test_segment_parsing() {
        let segments = RouteSegment::parse_path("/users/:id/posts/:post_id");
        assert_eq!(segments.len(), 4);
        assert_eq!(segments[0].to_string(), "users");
        assert_eq!(segments[2].to_string(), "posts");
        assert!(segments[1].is_dynamic() && segments[3].is_dynamic());
    }
}