1use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::sync::RwLock;
19use tracing::{debug, trace};
20
21use grapsus_common::ids::{QualifiedId, Scope};
22use grapsus_common::RouteId;
23use grapsus_config::{FlattenedConfig, RouteConfig};
24
25use crate::routing::{RequestInfo, RouteError, RouteMatch, RouteMatcher};
26
/// Route matcher that resolves requests through a hierarchy of scopes.
///
/// Each piece of state sits behind its own `Arc<RwLock<..>>`, so routes and
/// default routes can be replaced through `&self`.
pub struct ScopedRouteMatcher {
    /// One compiled [`RouteMatcher`] per scope that has at least one route.
    matchers: Arc<RwLock<HashMap<Scope, RouteMatcher>>>,

    /// Every loaded route, keyed by the `canonical()` form of its [`QualifiedId`].
    routes_by_qid: Arc<RwLock<HashMap<String, Arc<RouteConfig>>>>,

    /// Per-scope fallback route ids, consulted when no matcher in the scope chain matches.
    default_routes: Arc<RwLock<HashMap<Scope, RouteId>>>,

    /// Last-resort fallback route id, resolved in the global scope.
    global_default: Arc<RwLock<Option<RouteId>>>,
}
44
/// Result of a successful scoped lookup: the underlying [`RouteMatch`] plus
/// the scope in which the route was found.
#[derive(Debug, Clone)]
pub struct ScopedRouteMatch {
    /// Route id and configuration of the matched (or default) route.
    pub inner: RouteMatch,

    /// Route name qualified with `matched_scope`.
    pub qualified_id: QualifiedId,

    /// Scope the route was found in; one of the scopes yielded by the
    /// requesting scope's `chain()`, or `Scope::Global` for the global default.
    pub matched_scope: Scope,
}
57
58impl ScopedRouteMatch {
59 pub fn route_id(&self) -> &str {
61 self.inner.route_id.as_str()
62 }
63
64 pub fn config(&self) -> &Arc<RouteConfig> {
66 &self.inner.config
67 }
68
69 pub fn namespace(&self) -> Option<&str> {
71 match &self.matched_scope {
72 Scope::Global => None,
73 Scope::Namespace(ns) => Some(ns),
74 Scope::Service { namespace, .. } => Some(namespace),
75 }
76 }
77
78 pub fn service(&self) -> Option<&str> {
80 match &self.matched_scope {
81 Scope::Service { service, .. } => Some(service),
82 _ => None,
83 }
84 }
85}
86
87impl ScopedRouteMatcher {
88 pub fn new() -> Self {
90 Self {
91 matchers: Arc::new(RwLock::new(HashMap::new())),
92 routes_by_qid: Arc::new(RwLock::new(HashMap::new())),
93 default_routes: Arc::new(RwLock::new(HashMap::new())),
94 global_default: Arc::new(RwLock::new(None)),
95 }
96 }
97
98 pub async fn from_flattened(config: &FlattenedConfig) -> Result<Self, RouteError> {
100 let matcher = Self::new();
101 matcher.load_from_flattened(config).await?;
102 Ok(matcher)
103 }
104
105 pub async fn load_from_flattened(&self, config: &FlattenedConfig) -> Result<(), RouteError> {
107 let mut routes_by_scope: HashMap<Scope, Vec<RouteConfig>> = HashMap::new();
109 let mut routes_map = HashMap::new();
110
111 for (qid, route) in &config.routes {
112 routes_by_scope
113 .entry(qid.scope.clone())
114 .or_default()
115 .push(route.clone());
116 routes_map.insert(qid.canonical(), Arc::new(route.clone()));
117 }
118
119 let mut matchers = HashMap::new();
121 for (scope, routes) in routes_by_scope {
122 debug!(
123 scope = ?scope,
124 route_count = routes.len(),
125 "Creating route matcher for scope"
126 );
127 let matcher = RouteMatcher::new(routes, None)?;
128 matchers.insert(scope, matcher);
129 }
130
131 *self.matchers.write().await = matchers;
133 *self.routes_by_qid.write().await = routes_map;
134
135 Ok(())
136 }
137
138 pub async fn set_default_route(&self, scope: Scope, route_id: impl Into<String>) {
140 self.default_routes
141 .write()
142 .await
143 .insert(scope, RouteId::new(route_id));
144 }
145
146 pub async fn set_global_default(&self, route_id: impl Into<String>) {
148 *self.global_default.write().await = Some(RouteId::new(route_id));
149 }
150
151 pub async fn match_request(
160 &self,
161 req: &RequestInfo<'_>,
162 from_scope: &Scope,
163 ) -> Option<ScopedRouteMatch> {
164 trace!(
165 method = %req.method,
166 path = %req.path,
167 host = %req.host,
168 scope = ?from_scope,
169 "Starting scoped route matching"
170 );
171
172 let matchers = self.matchers.read().await;
173 let routes_by_qid = self.routes_by_qid.read().await;
174
175 for scope in from_scope.chain() {
177 if let Some(matcher) = matchers.get(&scope) {
178 if let Some(route_match) = matcher.match_request(req) {
179 let qid = QualifiedId {
181 name: route_match.route_id.as_str().to_string(),
182 scope: scope.clone(),
183 };
184
185 debug!(
186 route_id = %route_match.route_id,
187 scope = ?scope,
188 from_scope = ?from_scope,
189 "Route matched in scope"
190 );
191
192 return Some(ScopedRouteMatch {
193 inner: route_match,
194 qualified_id: qid,
195 matched_scope: scope,
196 });
197 }
198 }
199 }
200
201 let defaults = self.default_routes.read().await;
203 for scope in from_scope.chain() {
204 if let Some(default_id) = defaults.get(&scope) {
205 let qid = QualifiedId {
206 name: default_id.as_str().to_string(),
207 scope: scope.clone(),
208 };
209 if let Some(config) = routes_by_qid.get(&qid.canonical()) {
210 debug!(
211 route_id = %default_id,
212 scope = ?scope,
213 "Using scope default route"
214 );
215 return Some(ScopedRouteMatch {
216 inner: RouteMatch {
217 route_id: default_id.clone(),
218 config: Arc::clone(config),
219 },
220 qualified_id: qid,
221 matched_scope: scope,
222 });
223 }
224 }
225 }
226
227 if let Some(ref global_default) = *self.global_default.read().await {
229 let qid = QualifiedId::global(global_default.as_str());
230 if let Some(config) = routes_by_qid.get(&qid.canonical()) {
231 debug!(
232 route_id = %global_default,
233 "Using global default route"
234 );
235 return Some(ScopedRouteMatch {
236 inner: RouteMatch {
237 route_id: global_default.clone(),
238 config: Arc::clone(config),
239 },
240 qualified_id: qid,
241 matched_scope: Scope::Global,
242 });
243 }
244 }
245
246 debug!(
247 method = %req.method,
248 path = %req.path,
249 from_scope = ?from_scope,
250 "No route matched in any visible scope"
251 );
252 None
253 }
254
255 pub async fn get_route(&self, qid: &QualifiedId) -> Option<Arc<RouteConfig>> {
257 self.routes_by_qid
258 .read()
259 .await
260 .get(&qid.canonical())
261 .cloned()
262 }
263
264 pub async fn needs_headers(&self) -> bool {
266 self.matchers
267 .read()
268 .await
269 .values()
270 .any(|m| m.needs_headers())
271 }
272
273 pub async fn needs_query_params(&self) -> bool {
275 self.matchers
276 .read()
277 .await
278 .values()
279 .any(|m| m.needs_query_params())
280 }
281
282 pub async fn clear_caches(&self) {
284 for matcher in self.matchers.read().await.values() {
285 matcher.clear_cache();
286 }
287 }
288
289 pub async fn scope_count(&self) -> usize {
291 self.matchers.read().await.len()
292 }
293
294 pub async fn total_routes(&self) -> usize {
296 self.routes_by_qid.read().await.len()
297 }
298
299 pub async fn scopes(&self) -> Vec<Scope> {
301 self.matchers.read().await.keys().cloned().collect()
302 }
303}
304
305impl Default for ScopedRouteMatcher {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use grapsus_common::types::Priority;
    use grapsus_config::{MatchCondition, RoutePolicies, ServiceType};

    /// Builds a minimal web route that matches by path prefix.
    fn test_route(id: &str, path_prefix: &str) -> RouteConfig {
        RouteConfig {
            id: id.to_string(),
            priority: Priority::Normal,
            matches: vec![MatchCondition::PathPrefix(path_prefix.to_string())],
            upstream: Some("test-upstream".to_string()),
            service_type: ServiceType::Web,
            policies: RoutePolicies::default(),
            filters: vec![],
            builtin_handler: None,
            waf_enabled: false,
            circuit_breaker: None,
            retry_policy: None,
            static_files: None,
            api_schema: None,
            error_pages: None,
            websocket: false,
            websocket_inspection: false,
            inference: None,
            shadow: None,
            fallback: None,
        }
    }

    /// Three-level fixture: a catch-all global route, an `/api/` namespace
    /// route, and a `/payments/` service route.
    fn mock_flattened_config() -> FlattenedConfig {
        let mut config = FlattenedConfig::new();

        config.routes.push((
            QualifiedId::global("global-route"),
            test_route("global-route", "/"),
        ));

        config.routes.push((
            QualifiedId::namespaced("api", "api-route"),
            test_route("api-route", "/api/"),
        ));

        config.routes.push((
            QualifiedId::in_service("api", "payments", "payments-route"),
            test_route("payments-route", "/payments/"),
        ));

        config
    }

    /// Fixture with no catch-all route, so unmatched paths reach the
    /// default-route fallbacks.
    fn config_without_catch_all() -> FlattenedConfig {
        let mut config = FlattenedConfig::new();

        config.routes.push((
            QualifiedId::global("specific-route"),
            test_route("specific-route", "/specific/"),
        ));

        config.routes.push((
            QualifiedId::namespaced("api", "api-route"),
            test_route("api-route", "/api/"),
        ));

        config
    }

    #[tokio::test]
    async fn test_match_from_global_scope() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        let req = RequestInfo::new("GET", "/test", "example.com");
        let result = matcher.match_request(&req, &Scope::Global).await;

        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "global-route");
        assert_eq!(route_match.matched_scope, Scope::Global);
    }

    #[tokio::test]
    async fn test_match_from_namespace_scope() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        let ns_scope = Scope::Namespace("api".to_string());

        // Matched in the namespace itself.
        let req = RequestInfo::new("GET", "/api/users", "example.com");
        let result = matcher.match_request(&req, &ns_scope).await;

        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "api-route");
        assert_eq!(
            route_match.matched_scope,
            Scope::Namespace("api".to_string())
        );

        // Falls back to the global scope.
        let req = RequestInfo::new("GET", "/other", "example.com");
        let result = matcher.match_request(&req, &ns_scope).await;

        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "global-route");
        assert_eq!(route_match.matched_scope, Scope::Global);
    }

    #[tokio::test]
    async fn test_match_from_service_scope() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        let svc_scope = Scope::Service {
            namespace: "api".to_string(),
            service: "payments".to_string(),
        };

        // Matched in the service itself.
        let req = RequestInfo::new("GET", "/payments/checkout", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await;

        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "payments-route");
        assert!(matches!(route_match.matched_scope, Scope::Service { .. }));

        // Falls back to the parent namespace.
        let req = RequestInfo::new("GET", "/api/users", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await;

        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "api-route");

        // Falls back to the global scope.
        let req = RequestInfo::new("GET", "/other", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await;

        assert!(result.is_some());
        let route_match = result.unwrap();
        assert_eq!(route_match.route_id(), "global-route");
    }

    #[tokio::test]
    async fn test_scope_info_in_match() {
        let config = mock_flattened_config();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        let svc_scope = Scope::Service {
            namespace: "api".to_string(),
            service: "payments".to_string(),
        };

        let req = RequestInfo::new("GET", "/payments/checkout", "example.com");
        let result = matcher.match_request(&req, &svc_scope).await.unwrap();

        assert_eq!(result.namespace(), Some("api"));
        assert_eq!(result.service(), Some("payments"));
    }

    #[tokio::test]
    async fn test_default_route() {
        // No loaded route matches `/nonexistent`, so only the global default
        // can answer; without one there is no match at all.
        let config = config_without_catch_all();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        let req = RequestInfo::new("GET", "/nonexistent", "example.com");
        assert!(matcher.match_request(&req, &Scope::Global).await.is_none());

        matcher.set_global_default("specific-route").await;

        let route_match = matcher
            .match_request(&req, &Scope::Global)
            .await
            .expect("global default should apply");
        assert_eq!(route_match.route_id(), "specific-route");
        assert_eq!(route_match.matched_scope, Scope::Global);
    }

    #[tokio::test]
    async fn test_scope_default_route() {
        let config = config_without_catch_all();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        let ns_scope = Scope::Namespace("api".to_string());
        matcher.set_default_route(ns_scope.clone(), "api-route").await;

        let req = RequestInfo::new("GET", "/other", "example.com");
        let route_match = matcher
            .match_request(&req, &ns_scope)
            .await
            .expect("namespace default should apply");
        assert_eq!(route_match.route_id(), "api-route");
        assert_eq!(route_match.matched_scope, ns_scope);
        assert_eq!(route_match.namespace(), Some("api"));
    }

    #[tokio::test]
    async fn test_unknown_default_route_is_ignored() {
        let config = config_without_catch_all();
        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        matcher.set_global_default("does-not-exist").await;

        let req = RequestInfo::new("GET", "/other", "example.com");
        assert!(matcher.match_request(&req, &Scope::Global).await.is_none());
    }

    #[tokio::test]
    async fn test_reload_replaces_routes() {
        let matcher = ScopedRouteMatcher::from_flattened(&mock_flattened_config())
            .await
            .unwrap();
        assert_eq!(matcher.total_routes().await, 3);
        assert_eq!(matcher.scope_count().await, 3);

        matcher
            .load_from_flattened(&config_without_catch_all())
            .await
            .unwrap();
        assert_eq!(matcher.total_routes().await, 2);
        assert_eq!(matcher.scope_count().await, 2);

        // The old catch-all route is gone after the reload.
        let req = RequestInfo::new("GET", "/test", "example.com");
        assert!(matcher.match_request(&req, &Scope::Global).await.is_none());
        assert!(matcher
            .get_route(&QualifiedId::global("global-route"))
            .await
            .is_none());
        assert!(matcher
            .get_route(&QualifiedId::global("specific-route"))
            .await
            .is_some());
    }

    #[tokio::test]
    async fn test_no_match() {
        let mut config = FlattenedConfig::new();
        config.routes.push((
            QualifiedId::global("specific-route"),
            test_route("specific-route", "/specific/"),
        ));

        let matcher = ScopedRouteMatcher::from_flattened(&config).await.unwrap();

        let req = RequestInfo::new("GET", "/other", "example.com");
        let result = matcher.match_request(&req, &Scope::Global).await;

        assert!(result.is_none());
    }
}