use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, trace};
use zentinel_common::ids::{QualifiedId, Scope};
use zentinel_common::RouteId;
use zentinel_config::{FlattenedConfig, RouteConfig};
use crate::routing::{RequestInfo, RouteError, RouteMatch, RouteMatcher};
/// Route matcher that resolves requests across a hierarchy of [`Scope`]s.
///
/// One [`RouteMatcher`] is kept per scope that has routes. Lookups walk the
/// scopes yielded by `Scope::chain` (presumably the caller's scope followed by
/// its enclosing scopes), then fall back to per-scope and global default
/// routes. All state sits behind async `RwLock`s so routes can be reloaded
/// through a shared reference.
pub struct ScopedRouteMatcher {
    /// Compiled matcher for every scope that has at least one route.
    matchers: Arc<RwLock<HashMap<Scope, RouteMatcher>>>,
    /// Every loaded route, keyed by `QualifiedId::canonical()`.
    routes_by_qid: Arc<RwLock<HashMap<String, Arc<RouteConfig>>>>,
    /// Per-scope fallback route IDs, resolved within the same scope.
    default_routes: Arc<RwLock<HashMap<Scope, RouteId>>>,
    /// Last-resort fallback route ID, resolved as a global route.
    global_default: Arc<RwLock<Option<RouteId>>>,
}
/// Result of a scoped lookup: the underlying [`RouteMatch`] plus the scope in
/// which it was resolved.
#[derive(Debug, Clone)]
pub struct ScopedRouteMatch {
    /// Route ID and configuration of the selected route.
    pub inner: RouteMatch,
    /// Route name qualified with the scope it was found in.
    pub qualified_id: QualifiedId,
    /// Scope whose matcher (or default route) produced this result.
    pub matched_scope: Scope,
}
impl ScopedRouteMatch {
    /// ID of the matched route, borrowed from the inner [`RouteMatch`].
    pub fn route_id(&self) -> &str {
        let id = &self.inner.route_id;
        id.as_str()
    }

    /// Shared configuration of the matched route.
    pub fn config(&self) -> &Arc<RouteConfig> {
        let inner = &self.inner;
        &inner.config
    }

    /// Namespace of the matched scope.
    ///
    /// This is `None` for the global scope. For a service scope it is the
    /// service's namespace.
    pub fn namespace(&self) -> Option<&str> {
        if let Scope::Namespace(ns) | Scope::Service { namespace: ns, .. } = &self.matched_scope {
            Some(ns.as_str())
        } else {
            None
        }
    }

    /// Service name of the matched scope.
    ///
    /// This is only present when the match came from a service scope.
    pub fn service(&self) -> Option<&str> {
        match &self.matched_scope {
            Scope::Service { service, .. } => Some(service.as_str()),
            Scope::Global | Scope::Namespace(_) => None,
        }
    }
}
impl ScopedRouteMatcher {
    /// Creates an empty matcher with no scopes, routes, or default routes.
    pub fn new() -> Self {
        Self {
            matchers: Arc::new(RwLock::new(HashMap::new())),
            routes_by_qid: Arc::new(RwLock::new(HashMap::new())),
            default_routes: Arc::new(RwLock::new(HashMap::new())),
            global_default: Arc::new(RwLock::new(None)),
        }
    }

    /// Builds a matcher and loads every route from `config`.
    ///
    /// # Errors
    ///
    /// Returns the first [`RouteError`] produced while building a per-scope
    /// [`RouteMatcher`].
    pub async fn from_flattened(config: &FlattenedConfig) -> Result<Self, RouteError> {
        let matcher = Self::new();
        matcher.load_from_flattened(config).await?;
        Ok(matcher)
    }

    /// Replaces all loaded routes with those in `config`.
    ///
    /// Routes are grouped by the scope of their qualified ID, and one
    /// [`RouteMatcher`] is built per scope. Everything is built before any lock
    /// is taken, so if building fails the previously loaded routes stay in
    /// place.
    ///
    /// The per-scope matchers and the qualified-ID index are swapped while
    /// both write locks are held. Readers therefore never see matchers from one
    /// load paired with the route index from another.
    ///
    /// Default routes registered through [`Self::set_default_route`] and
    /// [`Self::set_global_default`] are not modified. If two entries share a
    /// qualified ID, the later one wins in the index used by `get_route` and
    /// default-route resolution.
    ///
    /// # Errors
    ///
    /// Returns the first [`RouteError`] produced while building a per-scope
    /// [`RouteMatcher`].
    pub async fn load_from_flattened(&self, config: &FlattenedConfig) -> Result<(), RouteError> {
        let mut routes_by_scope: HashMap<Scope, Vec<RouteConfig>> = HashMap::new();
        let mut routes_map = HashMap::with_capacity(config.routes.len());
        for (qid, route) in &config.routes {
            routes_by_scope
                .entry(qid.scope.clone())
                .or_default()
                .push(route.clone());
            routes_map.insert(qid.canonical(), Arc::new(route.clone()));
        }

        let mut matchers = HashMap::with_capacity(routes_by_scope.len());
        for (scope, routes) in routes_by_scope {
            debug!(
                scope = ?scope,
                route_count = routes.len(),
                "Creating route matcher for scope"
            );
            let matcher = RouteMatcher::new(routes, None)?;
            matchers.insert(scope, matcher);
        }

        // Publish both maps atomically. The lock order (matchers, then
        // routes_by_qid) is the same one `match_request` uses for its read
        // guards, so holding both write guards cannot form a lock cycle.
        let mut matchers_guard = self.matchers.write().await;
        let mut routes_guard = self.routes_by_qid.write().await;
        *matchers_guard = matchers;
        *routes_guard = routes_map;
        Ok(())
    }

    /// Registers `route_id` as the fallback route for `scope`.
    ///
    /// The ID is resolved within `scope` itself, so the route must be loaded
    /// in that same scope. Any previous default for `scope` is replaced.
    pub async fn set_default_route(&self, scope: Scope, route_id: impl Into<String>) {
        self.default_routes
            .write()
            .await
            .insert(scope, RouteId::new(route_id));
    }

    /// Registers `route_id` as the last-resort fallback.
    ///
    /// The ID is resolved as a global-scope route.
    pub async fn set_global_default(&self, route_id: impl Into<String>) {
        *self.global_default.write().await = Some(RouteId::new(route_id));
    }

    /// Resolves `req` as seen from `from_scope`.
    ///
    /// Resolution order:
    /// 1. For each scope in `from_scope.chain()`, in order, ask that scope's
    ///    matcher. The first match wins.
    /// 2. For each scope in the same chain, use that scope's default route.
    ///    It is used only if a route with that ID is loaded in that scope.
    /// 3. Use the global default. It is used only if it names a loaded global
    ///    route.
    ///
    /// An explicit match in any visible scope therefore takes precedence over
    /// every default. The returned `qualified_id` and `matched_scope` name the
    /// scope where the route was found. Returns `None` when nothing applies.
    pub async fn match_request(
        &self,
        req: &RequestInfo<'_>,
        from_scope: &Scope,
    ) -> Option<ScopedRouteMatch> {
        trace!(
            method = %req.method,
            path = %req.path,
            host = %req.host,
            scope = ?from_scope,
            "Starting scoped route matching"
        );
        // Lock order must stay matchers -> routes_by_qid (see load_from_flattened).
        let matchers = self.matchers.read().await;
        let routes_by_qid = self.routes_by_qid.read().await;

        for scope in from_scope.chain() {
            if let Some(matcher) = matchers.get(&scope) {
                if let Some(route_match) = matcher.match_request(req) {
                    let qid = QualifiedId {
                        name: route_match.route_id.as_str().to_string(),
                        scope: scope.clone(),
                    };
                    debug!(
                        route_id = %route_match.route_id,
                        scope = ?scope,
                        from_scope = ?from_scope,
                        "Route matched in scope"
                    );
                    return Some(ScopedRouteMatch {
                        inner: route_match,
                        qualified_id: qid,
                        matched_scope: scope,
                    });
                }
            }
        }

        let defaults = self.default_routes.read().await;
        for scope in from_scope.chain() {
            if let Some(default_id) = defaults.get(&scope) {
                let qid = QualifiedId {
                    name: default_id.as_str().to_string(),
                    scope: scope.clone(),
                };
                // A default naming a route that is not loaded in this scope is
                // skipped rather than treated as an error.
                if let Some(config) = routes_by_qid.get(&qid.canonical()) {
                    debug!(
                        route_id = %default_id,
                        scope = ?scope,
                        "Using scope default route"
                    );
                    return Some(ScopedRouteMatch {
                        inner: RouteMatch {
                            route_id: default_id.clone(),
                            config: Arc::clone(config),
                        },
                        qualified_id: qid,
                        matched_scope: scope,
                    });
                }
            }
        }

        if let Some(ref global_default) = *self.global_default.read().await {
            let qid = QualifiedId::global(global_default.as_str());
            if let Some(config) = routes_by_qid.get(&qid.canonical()) {
                debug!(
                    route_id = %global_default,
                    "Using global default route"
                );
                return Some(ScopedRouteMatch {
                    inner: RouteMatch {
                        route_id: global_default.clone(),
                        config: Arc::clone(config),
                    },
                    qualified_id: qid,
                    matched_scope: Scope::Global,
                });
            }
        }

        debug!(
            method = %req.method,
            path = %req.path,
            from_scope = ?from_scope,
            "No route matched in any visible scope"
        );
        None
    }

    /// Looks up a loaded route by its fully qualified ID.
    pub async fn get_route(&self, qid: &QualifiedId) -> Option<Arc<RouteConfig>> {
        self.routes_by_qid
            .read()
            .await
            .get(&qid.canonical())
            .cloned()
    }

    /// Returns `true` if any per-scope matcher reports that it needs request headers.
    pub async fn needs_headers(&self) -> bool {
        self.matchers
            .read()
            .await
            .values()
            .any(|m| m.needs_headers())
    }

    /// Returns `true` if any per-scope matcher reports that it needs query parameters.
    pub async fn needs_query_params(&self) -> bool {
        self.matchers
            .read()
            .await
            .values()
            .any(|m| m.needs_query_params())
    }

    /// Calls `clear_cache` on every per-scope matcher.
    pub async fn clear_caches(&self) {
        for matcher in self.matchers.read().await.values() {
            matcher.clear_cache();
        }
    }

    /// Number of scopes that currently have at least one route.
    pub async fn scope_count(&self) -> usize {
        self.matchers.read().await.len()
    }

    /// Number of distinct qualified route IDs currently loaded.
    pub async fn total_routes(&self) -> usize {
        self.routes_by_qid.read().await.len()
    }

    /// Scopes that currently have a matcher, in unspecified order.
    pub async fn scopes(&self) -> Vec<Scope> {
        self.matchers.read().await.keys().cloned().collect()
    }
}
impl Default for ScopedRouteMatcher {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use zentinel_common::types::Priority;
    use zentinel_config::{MatchCondition, RoutePolicies, ServiceType};

    /// Builds a minimal route that matches any path starting with `path_prefix`.
    fn test_route(id: &str, path_prefix: &str) -> RouteConfig {
        RouteConfig {
            id: id.to_string(),
            priority: Priority::NORMAL,
            matches: vec![MatchCondition::PathPrefix(path_prefix.to_string())],
            upstream: Some("test-upstream".to_string()),
            service_type: ServiceType::Web,
            policies: RoutePolicies::default(),
            filters: vec![],
            builtin_handler: None,
            waf_enabled: false,
            circuit_breaker: None,
            retry_policy: None,
            static_files: None,
            api_schema: None,
            error_pages: None,
            websocket: false,
            websocket_inspection: false,
            inference: None,
            shadow: None,
            fallback: None,
        }
    }

    /// Config with one route in each scope.
    ///
    /// The scopes are global (`/` catch-all), namespace `api`, and service
    /// `api/payments`.
    fn mock_flattened_config() -> FlattenedConfig {
        let mut config = FlattenedConfig::new();
        config.routes.push((
            QualifiedId::global("global-route"),
            test_route("global-route", "/"),
        ));
        config.routes.push((
            QualifiedId::namespaced("api", "api-route"),
            test_route("api-route", "/api/"),
        ));
        config.routes.push((
            QualifiedId::in_service("api", "payments", "payments-route"),
            test_route("payments-route", "/payments/"),
        ));
        config
    }

    /// Config with a single global route that only matches `/specific/`.
    ///
    /// It has no catch-all, so unrelated paths fall through to default-route
    /// resolution.
    fn specific_only_config() -> FlattenedConfig {
        let mut config = FlattenedConfig::new();
        config.routes.push((
            QualifiedId::global("specific-route"),
            test_route("specific-route", "/specific/"),
        ));
        config
    }

    #[tokio::test]
    async fn test_match_from_global_scope() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        let req = RequestInfo::new("GET", "/test", "example.com");
        let result = matcher.match_request(&req, &Scope::Global).await;
        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "global-route");
        assert_eq!(route_match.matched_scope, Scope::Global);
    }

    #[tokio::test]
    async fn test_match_from_namespace_scope() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        let ns_scope = Scope::Namespace("api".to_string());
        let req = RequestInfo::new("GET", "/api/users", "example.com");
        let result = matcher.match_request(&req, &ns_scope).await;
        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "api-route");
        assert_eq!(
            route_match.matched_scope,
            Scope::Namespace("api".to_string())
        );
        let req = RequestInfo::new("GET", "/other", "example.com");
        let result = matcher.match_request(&req, &ns_scope).await;
        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "global-route");
        assert_eq!(route_match.matched_scope, Scope::Global);
    }

    #[tokio::test]
    async fn test_match_from_service_scope() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        let svc_scope = Scope::Service {
            namespace: "api".to_string(),
            service: "payments".to_string(),
        };
        let req = RequestInfo::new("GET", "/payments/checkout", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await;
        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "payments-route");
        assert!(matches!(route_match.matched_scope, Scope::Service { .. }));
        assert_eq!(
            route_match.qualified_id.canonical(),
            QualifiedId::in_service("api", "payments", "payments-route").canonical()
        );
        let req = RequestInfo::new("GET", "/api/users", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await;
        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "api-route");
        let req = RequestInfo::new("GET", "/other", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await;
        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "global-route");
    }

    #[tokio::test]
    async fn test_scope_info_in_match() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        let svc_scope = Scope::Service {
            namespace: "api".to_string(),
            service: "payments".to_string(),
        };
        let req = RequestInfo::new("GET", "/payments/checkout", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await.unwrap();
        assert_eq!(result.namespace(), Some("api"));
        assert_eq!(result.service(), Some("payments"));
    }

    #[tokio::test]
    async fn test_default_route() {
        // No catch-all route, so the request can only succeed via the default.
        let config = specific_only_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        let req = RequestInfo::new("GET", "/nonexistent", "example.com");
        assert!(matcher.match_request(&req, &Scope::Global).await.is_none());

        matcher.set_global_default("specific-route").await;
        let route_match = matcher
            .match_request(&req, &Scope::Global)
            .await
            .expect("global default should apply");
        assert_eq!(route_match.route_id(), "specific-route");
        assert_eq!(route_match.matched_scope, Scope::Global);
        assert_eq!(
            route_match.qualified_id.canonical(),
            QualifiedId::global("specific-route").canonical()
        );
    }

    #[tokio::test]
    async fn test_global_default_missing_route_is_ignored() {
        let config = specific_only_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        matcher.set_global_default("does-not-exist").await;
        let req = RequestInfo::new("GET", "/nonexistent", "example.com");
        assert!(matcher.match_request(&req, &Scope::Global).await.is_none());
    }

    #[tokio::test]
    async fn test_scope_default_route() {
        let mut config = specific_only_config();
        config.routes.push((
            QualifiedId::namespaced("api", "api-fallback"),
            test_route("api-fallback", "/api-only/"),
        ));
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        let ns_scope = Scope::Namespace("api".to_string());
        matcher
            .set_default_route(ns_scope.clone(), "api-fallback")
            .await;

        // Unmatched request from inside the namespace uses its default.
        let req = RequestInfo::new("GET", "/nothing", "example.com");
        let route_match = matcher
            .match_request(&req, &ns_scope)
            .await
            .expect("namespace default should apply");
        assert_eq!(route_match.route_id(), "api-fallback");
        assert_eq!(route_match.matched_scope, ns_scope);

        // The namespace default is not visible from the global scope.
        assert!(matcher.match_request(&req, &Scope::Global).await.is_none());

        // An explicit match in any visible scope beats a default.
        let req = RequestInfo::new("GET", "/specific/x", "example.com");
        let route_match = matcher.match_request(&req, &ns_scope).await.unwrap();
        assert_eq!(route_match.route_id(), "specific-route");
        assert_eq!(route_match.matched_scope, Scope::Global);
    }

    #[tokio::test]
    async fn test_reload_replaces_routes() {
        let matcher = ScopedRouteMatcher::from_flattened(&mock_flattened_config())
            .await
            .unwrap();
        assert_eq!(matcher.scope_count().await, 3);
        assert_eq!(matcher.total_routes().await, 3);

        matcher
            .load_from_flattened(&specific_only_config())
            .await
            .unwrap();
        assert_eq!(matcher.scope_count().await, 1);
        assert_eq!(matcher.total_routes().await, 1);
        assert!(matcher
            .get_route(&QualifiedId::namespaced("api", "api-route"))
            .await
            .is_none());
        assert!(matcher
            .get_route(&QualifiedId::global("specific-route"))
            .await
            .is_some());

        let req = RequestInfo::new("GET", "/api/users", "example.com");
        let ns_scope = Scope::Namespace("api".to_string());
        assert!(matcher.match_request(&req, &ns_scope).await.is_none());
    }

    #[tokio::test]
    async fn test_no_match() {
        let config = specific_only_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();
        let req = RequestInfo::new("GET", "/other", "example.com");
        let result = matcher.match_request(&req, &Scope::Global).await;
        assert!(result.is_none());
    }
}