use crate::Result;
use std::collections::HashMap;
/// HTTP request method a route is registered for.
///
/// Serialized in lowercase (`GET` <-> `"get"`) because of the `rename_all` attribute.
/// Implements `Eq + Hash` so it can key the per-method route indexes.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum HttpMethod {
    /// `GET` (serialized as `"get"`).
    GET,
    /// `POST` (serialized as `"post"`).
    POST,
    /// `PUT` (serialized as `"put"`).
    PUT,
    /// `DELETE` (serialized as `"delete"`).
    DELETE,
    /// `PATCH` (serialized as `"patch"`).
    PATCH,
    /// `HEAD` (serialized as `"head"`).
    HEAD,
    /// `OPTIONS` (serialized as `"options"`).
    OPTIONS,
}
/// A single registered route: the method and path pattern it answers, plus
/// caller-supplied priority and metadata.
///
/// Within `path`, a segment that is exactly `*` stands for any single path
/// segment when the route is matched.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct Route {
    /// Method the route is registered for.
    pub method: HttpMethod,
    /// Literal path or `/`-separated pattern using `*` wildcard segments.
    pub path: String,
    /// Caller-defined priority (`Route::new` starts it at `0`).
    pub priority: i32,
    /// Free-form JSON metadata attached by the caller.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl Route {
    /// Creates a route for `method` at `path` with priority `0` and no metadata.
    pub fn new(method: HttpMethod, path: String) -> Self {
        let metadata = HashMap::new();
        Self {
            method,
            path,
            priority: 0,
            metadata,
        }
    }

    /// Returns this route with its priority replaced by `priority`.
    pub fn with_priority(self, priority: i32) -> Self {
        Self { priority, ..self }
    }

    /// Returns this route with `value` stored under `key` in its metadata,
    /// overwriting any value previously stored under that key.
    pub fn with_metadata(self, key: String, value: serde_json::Value) -> Self {
        let mut route = self;
        route.metadata.insert(key, value);
        route
    }
}
/// Translates a `*`-wildcard route pattern into `matchit` parameter syntax.
///
/// Every segment that is exactly `*` becomes a named parameter `{wN}`, where `N`
/// is the segment's position in `pattern.split('/')`. The leading empty segment
/// of an absolute path counts, so `/api/*` becomes `/api/{w2}`. Naming by
/// position gives the same name to wildcards at the same depth across patterns.
/// Patterns without `*` are returned unchanged.
fn to_matchit_pattern(pattern: &str) -> String {
    if pattern.contains('*') {
        let mut translated = String::with_capacity(pattern.len() + 8);
        for (position, segment) in pattern.split('/').enumerate() {
            if position > 0 {
                translated.push('/');
            }
            if segment == "*" {
                translated.push_str("{w");
                translated.push_str(&position.to_string());
                translated.push('}');
            } else {
                translated.push_str(segment);
            }
        }
        translated
    } else {
        pattern.to_string()
    }
}
/// Route index for a single HTTP method.
///
/// - `routes` owns every route registered for the method, in insertion order.
/// - `router` maps each translated pattern (see `to_matchit_pattern`) to the indices
///   into `routes` registered under it, so routes sharing a pattern are all returned.
/// - `unindexed` lists patterns `router` refused to index. They are matched by a linear
///   segment-by-segment scan, so they stay reachable through `find` instead of vanishing.
#[derive(Clone)]
struct MethodIndex {
    router: matchit::Router<Vec<usize>>,
    routes: Vec<Route>,
    /// Indices into `routes` whose patterns could not be placed in `router`.
    unindexed: Vec<usize>,
}

impl std::fmt::Debug for MethodIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the route list is printed; the router itself is left out of the output.
        f.debug_struct("MethodIndex").field("routes", &self.routes).finish()
    }
}

impl MethodIndex {
    /// Creates an empty index.
    fn new() -> Self {
        Self {
            router: matchit::Router::new(),
            routes: Vec::new(),
            unindexed: Vec::new(),
        }
    }

    /// Registers `route` so that it is returned by [`MethodIndex::all`] and is
    /// reachable through [`MethodIndex::find`].
    ///
    /// - A pattern the router accepts gets its own router entry.
    /// - On an `InsertError::Conflict` (e.g. the exact same pattern registered twice),
    ///   the pattern text is looked up as a path and the route joins the entry it hits.
    /// - On any other insert error, or a conflict that lookup cannot resolve, the
    ///   route is recorded in `unindexed`.
    fn insert(&mut self, route: Route) {
        let idx = self.routes.len();
        let matchit_path = to_matchit_pattern(&route.path);
        self.routes.push(route);
        let err = match self.router.insert(matchit_path.clone(), vec![idx]) {
            Ok(()) => return,
            Err(err) => err,
        };
        // Merge only on a conflict. For a pattern rejected for another reason, looking
        // its text up as a path could land on an unrelated parameter route and wrongly
        // attach this route to it.
        if matches!(err, matchit::InsertError::Conflict { .. }) {
            if let Ok(matched) = self.router.at_mut(&matchit_path) {
                matched.value.push(idx);
                return;
            }
        }
        self.unindexed.push(idx);
    }

    /// Returns every route whose pattern matches `path`.
    ///
    /// Router hits come first (in registration order for their pattern), followed
    /// by matching `unindexed` routes in registration order.
    fn find(&self, path: &str) -> Vec<&Route> {
        let mut found: Vec<&Route> = match self.router.at(path) {
            Ok(matched) => matched.value.iter().map(|&i| &self.routes[i]).collect(),
            Err(_) => Vec::new(),
        };
        found.extend(
            self.unindexed
                .iter()
                .map(|&i| &self.routes[i])
                .filter(|route| Self::segments_match(&route.path, path)),
        );
        found
    }

    /// Returns every route registered for this method, in insertion order.
    fn all(&self) -> Vec<&Route> {
        self.routes.iter().collect()
    }

    /// Linear matcher for `unindexed` routes. `pattern` must equal `path` exactly,
    /// except that a segment that is exactly `*` matches any single segment. When
    /// `*` is present, both sides must have the same number of `/`-separated segments.
    fn segments_match(pattern: &str, path: &str) -> bool {
        if pattern == path {
            return true;
        }
        if !pattern.contains('*') {
            return false;
        }
        let mut want = pattern.split('/');
        let mut got = path.split('/');
        loop {
            match (want.next(), got.next()) {
                (None, None) => return true,
                (Some(w), Some(g)) if w == "*" || w == g => {}
                _ => return false,
            }
        }
    }
}
/// Registry of HTTP, WebSocket and gRPC routes.
///
/// HTTP routes are indexed separately per [`HttpMethod`]. WebSocket routes are
/// kept in a flat list. gRPC routes are grouped by service name.
#[derive(Debug, Clone)]
pub struct RouteRegistry {
    /// One index per HTTP method that has had a route added since the last clear.
    http_routes: HashMap<HttpMethod, MethodIndex>,
    /// WebSocket routes, in registration order.
    ws_routes: Vec<Route>,
    /// gRPC routes keyed by service name, each list in registration order.
    grpc_routes: HashMap<String, Vec<Route>>,
}
impl RouteRegistry {
    /// Creates a registry with no routes of any kind.
    pub fn new() -> Self {
        let http_routes = HashMap::new();
        let grpc_routes = HashMap::new();
        Self {
            http_routes,
            ws_routes: Vec::new(),
            grpc_routes,
        }
    }

    /// Registers an HTTP route in the index for its own method.
    ///
    /// Currently always returns `Ok(())`.
    pub fn add_http_route(&mut self, route: Route) -> Result<()> {
        let method = route.method.clone();
        let index = self.http_routes.entry(method).or_insert_with(MethodIndex::new);
        index.insert(route);
        Ok(())
    }

    /// Registers a WebSocket route. Currently always returns `Ok(())`.
    pub fn add_ws_route(&mut self, route: Route) -> Result<()> {
        self.ws_routes.push(route);
        Ok(())
    }

    /// Removes every HTTP, WebSocket and gRPC route.
    pub fn clear(&mut self) {
        self.http_routes.clear();
        self.ws_routes.clear();
        self.grpc_routes.clear();
    }

    /// Alias for [`RouteRegistry::add_http_route`].
    pub fn add_route(&mut self, route: Route) -> Result<()> {
        self.add_http_route(route)
    }

    /// Registers a gRPC route under `service`. Currently always returns `Ok(())`.
    pub fn add_grpc_route(&mut self, service: String, route: Route) -> Result<()> {
        let service_routes = self.grpc_routes.entry(service).or_default();
        service_routes.push(route);
        Ok(())
    }

    /// Returns the HTTP routes registered for `method` whose pattern matches `path`.
    pub fn find_http_routes(&self, method: &HttpMethod, path: &str) -> Vec<&Route> {
        match self.http_routes.get(method) {
            Some(index) => index.find(path),
            None => Vec::new(),
        }
    }

    /// Returns the WebSocket routes whose pattern matches `path`, in registration order.
    pub fn find_ws_routes(&self, path: &str) -> Vec<&Route> {
        let mut hits = Vec::new();
        for route in &self.ws_routes {
            if self.matches_path(&route.path, path) {
                hits.push(route);
            }
        }
        hits
    }

    /// Returns the routes registered under `service` whose pattern matches the
    /// gRPC `method` name. Returns an empty list for an unknown service.
    pub fn find_grpc_routes(&self, service: &str, method: &str) -> Vec<&Route> {
        match self.grpc_routes.get(service) {
            Some(service_routes) => service_routes
                .iter()
                .filter(|route| self.matches_path(&route.path, method))
                .collect(),
            None => Vec::new(),
        }
    }

    /// Reports whether `path` matches `pattern`.
    ///
    /// An identical string always matches. Otherwise the pattern must contain `*`.
    /// Both sides are split on `/` and must have the same number of segments, and
    /// each pattern segment must be `*` or equal to the corresponding path segment.
    fn matches_path(&self, pattern: &str, path: &str) -> bool {
        if pattern == path {
            return true;
        }
        if !pattern.contains('*') {
            return false;
        }
        let mut expected = pattern.split('/');
        let mut actual = path.split('/');
        loop {
            match (expected.next(), actual.next()) {
                // Both exhausted together: same segment count and every segment agreed.
                (None, None) => return true,
                (Some(want), Some(got)) if want == "*" || want == got => continue,
                // Segment mismatch or differing segment counts.
                _ => return false,
            }
        }
    }

    /// Returns every HTTP route registered for `method`, in registration order.
    pub fn get_http_routes(&self, method: &HttpMethod) -> Vec<&Route> {
        match self.http_routes.get(method) {
            Some(index) => index.all(),
            None => Vec::new(),
        }
    }

    /// Returns every WebSocket route, in registration order.
    pub fn get_ws_routes(&self) -> Vec<&Route> {
        self.ws_routes.iter().collect::<Vec<_>>()
    }

    /// Returns every route registered under `service` (empty for an unknown service).
    pub fn get_grpc_routes(&self, service: &str) -> Vec<&Route> {
        match self.grpc_routes.get(service) {
            Some(service_routes) => service_routes.iter().collect(),
            None => Vec::new(),
        }
    }
}
impl Default for RouteRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `Route`, `RouteRegistry`, path matching and pattern translation.
    use super::*;

    #[test]
    fn test_route_new() {
        let route = Route::new(HttpMethod::GET, "/api/users".to_string());
        assert_eq!(route.method, HttpMethod::GET);
        assert_eq!(route.path, "/api/users");
        assert_eq!(route.priority, 0);
        assert!(route.metadata.is_empty());
    }

    #[test]
    fn test_route_with_priority() {
        let route = Route::new(HttpMethod::POST, "/api/users".to_string()).with_priority(10);
        assert_eq!(route.priority, 10);
    }

    #[test]
    fn test_route_with_metadata() {
        let route = Route::new(HttpMethod::GET, "/api/users".to_string())
            .with_metadata("version".to_string(), serde_json::json!("v1"))
            .with_metadata("auth".to_string(), serde_json::json!(true));
        assert_eq!(route.metadata.get("version"), Some(&serde_json::json!("v1")));
        assert_eq!(route.metadata.get("auth"), Some(&serde_json::json!(true)));
    }

    #[test]
    fn test_route_registry_new() {
        let registry = RouteRegistry::new();
        assert!(registry.http_routes.is_empty());
        assert!(registry.ws_routes.is_empty());
        assert!(registry.grpc_routes.is_empty());
    }

    #[test]
    fn test_route_registry_default() {
        let registry = RouteRegistry::default();
        assert!(registry.http_routes.is_empty());
    }

    #[test]
    fn test_add_http_route() {
        let mut registry = RouteRegistry::new();
        let route = Route::new(HttpMethod::GET, "/api/users".to_string());
        assert!(registry.add_http_route(route).is_ok());
        assert_eq!(registry.get_http_routes(&HttpMethod::GET).len(), 1);
    }

    #[test]
    fn test_add_multiple_http_routes() {
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/users".to_string()))
            .unwrap();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/posts".to_string()))
            .unwrap();
        registry
            .add_http_route(Route::new(HttpMethod::POST, "/api/users".to_string()))
            .unwrap();
        assert_eq!(registry.get_http_routes(&HttpMethod::GET).len(), 2);
        assert_eq!(registry.get_http_routes(&HttpMethod::POST).len(), 1);
    }

    #[test]
    fn test_add_ws_route() {
        let mut registry = RouteRegistry::new();
        let route = Route::new(HttpMethod::GET, "/ws/chat".to_string());
        assert!(registry.add_ws_route(route).is_ok());
        assert_eq!(registry.get_ws_routes().len(), 1);
    }

    #[test]
    fn test_add_grpc_route() {
        let mut registry = RouteRegistry::new();
        let route = Route::new(HttpMethod::POST, "GetUser".to_string());
        assert!(registry.add_grpc_route("UserService".to_string(), route).is_ok());
        assert_eq!(registry.get_grpc_routes("UserService").len(), 1);
    }

    #[test]
    fn test_add_route_alias() {
        let mut registry = RouteRegistry::new();
        let route = Route::new(HttpMethod::GET, "/api/test".to_string());
        assert!(registry.add_route(route).is_ok());
        assert_eq!(registry.get_http_routes(&HttpMethod::GET).len(), 1);
    }

    #[test]
    fn test_clear() {
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/users".to_string()))
            .unwrap();
        registry
            .add_ws_route(Route::new(HttpMethod::GET, "/ws/chat".to_string()))
            .unwrap();
        registry
            .add_grpc_route(
                "Service".to_string(),
                Route::new(HttpMethod::POST, "Method".to_string()),
            )
            .unwrap();
        assert!(!registry.get_http_routes(&HttpMethod::GET).is_empty());
        assert!(!registry.get_ws_routes().is_empty());
        registry.clear();
        assert!(registry.get_http_routes(&HttpMethod::GET).is_empty());
        assert!(registry.get_ws_routes().is_empty());
        assert!(registry.get_grpc_routes("Service").is_empty());
    }

    #[test]
    fn test_find_http_routes_exact_match() {
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/users".to_string()))
            .unwrap();
        let found = registry.find_http_routes(&HttpMethod::GET, "/api/users");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].path, "/api/users");
    }

    #[test]
    fn test_find_http_routes_no_match() {
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/users".to_string()))
            .unwrap();
        let found = registry.find_http_routes(&HttpMethod::GET, "/api/posts");
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_find_http_routes_other_method_no_match() {
        // Routes are indexed per method, so a path registered for GET is not found for DELETE.
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/users".to_string()))
            .unwrap();
        assert!(registry.find_http_routes(&HttpMethod::DELETE, "/api/users").is_empty());
    }

    #[test]
    fn test_find_http_routes_wildcard_match() {
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/*/details".to_string()))
            .unwrap();
        let found = registry.find_http_routes(&HttpMethod::GET, "/api/users/details");
        assert_eq!(found.len(), 1);
        let found = registry.find_http_routes(&HttpMethod::GET, "/api/posts/details");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_find_http_routes_wildcard_no_match_different_length() {
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/*/details".to_string()))
            .unwrap();
        let found = registry.find_http_routes(&HttpMethod::GET, "/api/users");
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_find_http_routes_duplicate_patterns_all_returned() {
        // Registering the same pattern twice must keep both routes reachable, in
        // registration order, for both wildcard and static patterns.
        let mut registry = RouteRegistry::new();
        registry
            .add_http_route(
                Route::new(HttpMethod::GET, "/api/*/details".to_string()).with_priority(1),
            )
            .unwrap();
        registry
            .add_http_route(
                Route::new(HttpMethod::GET, "/api/*/details".to_string()).with_priority(2),
            )
            .unwrap();
        let priorities: Vec<i32> = registry
            .find_http_routes(&HttpMethod::GET, "/api/users/details")
            .iter()
            .map(|route| route.priority)
            .collect();
        assert_eq!(priorities, vec![1, 2]);

        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/users".to_string()))
            .unwrap();
        registry
            .add_http_route(Route::new(HttpMethod::GET, "/api/users".to_string()))
            .unwrap();
        assert_eq!(registry.find_http_routes(&HttpMethod::GET, "/api/users").len(), 2);
        assert_eq!(registry.get_http_routes(&HttpMethod::GET).len(), 4);
    }

    #[test]
    fn test_find_ws_routes() {
        let mut registry = RouteRegistry::new();
        registry
            .add_ws_route(Route::new(HttpMethod::GET, "/ws/chat".to_string()))
            .unwrap();
        let found = registry.find_ws_routes("/ws/chat");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_find_ws_routes_wildcard() {
        let mut registry = RouteRegistry::new();
        registry
            .add_ws_route(Route::new(HttpMethod::GET, "/ws/*".to_string()))
            .unwrap();
        let found = registry.find_ws_routes("/ws/chat");
        assert_eq!(found.len(), 1);
        let found = registry.find_ws_routes("/ws/notifications");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_find_grpc_routes() {
        let mut registry = RouteRegistry::new();
        registry
            .add_grpc_route(
                "UserService".to_string(),
                Route::new(HttpMethod::POST, "GetUser".to_string()),
            )
            .unwrap();
        let found = registry.find_grpc_routes("UserService", "GetUser");
        assert_eq!(found.len(), 1);
        assert!(registry.find_grpc_routes("UserService", "DeleteUser").is_empty());
    }

    #[test]
    fn test_find_grpc_routes_wildcard() {
        // A `*` method pattern matches any single method name within its service only.
        let mut registry = RouteRegistry::new();
        registry
            .add_grpc_route(
                "UserService".to_string(),
                Route::new(HttpMethod::POST, "*".to_string()),
            )
            .unwrap();
        assert_eq!(registry.find_grpc_routes("UserService", "GetUser").len(), 1);
        assert_eq!(registry.find_grpc_routes("UserService", "ListUsers").len(), 1);
        assert!(registry.find_grpc_routes("OtherService", "GetUser").is_empty());
    }

    #[test]
    fn test_matches_path_exact() {
        let registry = RouteRegistry::new();
        assert!(registry.matches_path("/api/users", "/api/users"));
        assert!(!registry.matches_path("/api/users", "/api/posts"));
    }

    #[test]
    fn test_matches_path_wildcard_single_segment() {
        let registry = RouteRegistry::new();
        assert!(registry.matches_path("/api/*", "/api/users"));
        assert!(registry.matches_path("/api/*", "/api/posts"));
        assert!(!registry.matches_path("/api/*", "/api"));
        assert!(!registry.matches_path("/api/*", "/api/users/123"));
    }

    #[test]
    fn test_matches_path_wildcard_multiple_segments() {
        let registry = RouteRegistry::new();
        assert!(registry.matches_path("/api/*/details", "/api/users/details"));
        assert!(registry.matches_path("/api/*/*", "/api/users/123"));
        assert!(!registry.matches_path("/api/*/*", "/api/users"));
    }

    #[test]
    fn test_get_http_routes_empty() {
        let registry = RouteRegistry::new();
        assert!(registry.get_http_routes(&HttpMethod::GET).is_empty());
    }

    #[test]
    fn test_get_ws_routes_empty() {
        let registry = RouteRegistry::new();
        assert!(registry.get_ws_routes().is_empty());
    }

    #[test]
    fn test_get_grpc_routes_empty() {
        let registry = RouteRegistry::new();
        assert!(registry.get_grpc_routes("Service").is_empty());
    }

    #[test]
    fn test_http_method_serialization() {
        let method = HttpMethod::GET;
        let json = serde_json::to_string(&method).unwrap();
        assert_eq!(json, r#""get""#);
        let method = HttpMethod::POST;
        let json = serde_json::to_string(&method).unwrap();
        assert_eq!(json, r#""post""#);
    }

    #[test]
    fn test_http_method_deserialization() {
        let method: HttpMethod = serde_json::from_str(r#""get""#).unwrap();
        assert_eq!(method, HttpMethod::GET);
        let method: HttpMethod = serde_json::from_str(r#""post""#).unwrap();
        assert_eq!(method, HttpMethod::POST);
    }

    #[test]
    fn test_to_matchit_pattern() {
        assert_eq!(to_matchit_pattern("/api/users"), "/api/users");
        assert_eq!(to_matchit_pattern("/api/*/details"), "/api/{w2}/details");
        assert_eq!(to_matchit_pattern("/api/*/*"), "/api/{w2}/{w3}");
        assert_eq!(to_matchit_pattern("/*"), "/{w1}");
    }

    #[test]
    fn test_matchit_many_routes_performance() {
        let mut registry = RouteRegistry::new();
        for i in 0..200 {
            registry
                .add_http_route(Route::new(HttpMethod::GET, format!("/api/v1/resource{i}")))
                .unwrap();
        }
        let found = registry.find_http_routes(&HttpMethod::GET, "/api/v1/resource199");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].path, "/api/v1/resource199");
        let found = registry.find_http_routes(&HttpMethod::GET, "/api/v1/resource999");
        assert_eq!(found.len(), 0);
    }
}