reinhardt_urls/routers/server_router/
builder.rs1use super::ServerRouter;
7use super::types::MiddlewareInfo;
8use crate::routers::UrlReverser;
9use matchit::Router as MatchitRouter;
10use reinhardt_di::InjectionContext;
11use reinhardt_http::ExcludeMiddleware;
12use reinhardt_middleware::Middleware;
13use std::collections::HashMap;
14use std::sync::{Arc, RwLock};
15
16impl ServerRouter {
17 fn validate_prefix(prefix: &str) {
43 if prefix.contains('{') || prefix.contains('}') {
48 panic!(
49 "`mount()` prefix `{prefix}` contains a path parameter placeholder (`{{...}}`); this is not supported.\nUse `route()` with the full path on the child router instead, or mount at a literal prefix."
50 );
51 }
52
53 if !prefix.ends_with('/') {
55 if prefix.is_empty() {
56 panic!(
57 "URL route prefix cannot be an empty string. \
58 Use '/' instead of ''. \
59 This follows Django URL configuration conventions."
60 );
61 } else {
62 panic!(
63 "URL route '{}' must end with a trailing slash '/'. \
64 Use '{}/' instead of '{}'. \
65 This follows Django URL configuration conventions.",
66 prefix, prefix, prefix,
67 );
68 }
69 }
70 }
71
72 pub fn new() -> Self {
82 Self {
83 prefix: String::new(),
84 namespace: None,
85 routes: Vec::new(),
86 viewsets: HashMap::new(),
87 functions: Vec::new(),
88 views: Vec::new(),
89 children: Vec::new(),
90 di_context: None,
91 pending_middleware_di: reinhardt_di::DiRegistrationList::new(),
92 middleware: Vec::new(),
93 middleware_names: Vec::new(),
94 middleware_exclusions: Vec::new(),
95 reverser: UrlReverser::new(),
96 get_router: RwLock::new(MatchitRouter::new()),
97 post_router: RwLock::new(MatchitRouter::new()),
98 put_router: RwLock::new(MatchitRouter::new()),
99 delete_router: RwLock::new(MatchitRouter::new()),
100 patch_router: RwLock::new(MatchitRouter::new()),
101 head_router: RwLock::new(MatchitRouter::new()),
102 options_router: RwLock::new(MatchitRouter::new()),
103 routes_compiled: RwLock::new(false),
104 }
105 }
106
107 pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
118 self.prefix = prefix.into();
119 self
120 }
121
122 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
133 self.namespace = Some(namespace.into());
134 self
135 }
136
137 pub fn with_di_context(mut self, ctx: Arc<InjectionContext>) -> Self {
152 Self::adopt_di_context_recursive(&mut self, &ctx);
161 self.di_context = Some(ctx);
162 self
163 }
164
165 fn inherit_context_from_parent_if_any(parent: &ServerRouter, child: &mut ServerRouter) {
176 if child.di_context.is_none()
177 && let Some(parent_ctx) = parent.di_context.as_ref()
178 {
179 Self::adopt_di_context_recursive(child, parent_ctx);
180 child.di_context = Some(Arc::clone(parent_ctx));
181 }
182 }
183
184 fn adopt_di_context_recursive(router: &mut ServerRouter, ctx: &Arc<InjectionContext>) {
185 if !router.pending_middleware_di.is_empty() {
186 let pending = std::mem::take(&mut router.pending_middleware_di);
187 pending.apply_to(ctx.singleton_scope());
188 }
189 for child in router.children.iter_mut() {
190 if child.di_context.is_none() {
191 Self::adopt_di_context_recursive(child, ctx);
192 child.di_context = Some(Arc::clone(ctx));
193 }
194 }
195 }
196
197 pub(crate) fn di_context(&self) -> Option<&Arc<InjectionContext>> {
199 self.di_context.as_ref()
200 }
201
202 pub fn with_middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
214 let full_type_name = std::any::type_name::<M>().to_string();
215 let short_name = full_type_name
216 .rsplit("::")
217 .next()
218 .unwrap_or(&full_type_name)
219 .to_string();
220 let di_entries = mw.di_registrations();
232 if !di_entries.is_empty() {
233 if let Some(ctx) = self.di_context.as_ref() {
234 let scope = ctx.singleton_scope();
235 for (type_id, value) in di_entries {
236 scope.set_arc_any(type_id, value);
237 }
238 } else {
239 for (type_id, value) in di_entries {
240 self.pending_middleware_di.register_arc_any(type_id, value);
241 }
242 }
243 }
244 self.middleware_names.push(MiddlewareInfo {
245 name: short_name,
246 type_name: full_type_name,
247 });
248 self.middleware.push(Arc::new(mw));
249 self.middleware_exclusions.push(Vec::new());
250 self
251 }
252
253 pub fn exclude(mut self, pattern: &str) -> Self {
279 assert!(
280 !self.middleware_exclusions.is_empty(),
281 "exclude() called with no middleware. Call with_middleware() first."
282 );
283 self.middleware_exclusions
284 .last_mut()
285 .unwrap()
286 .push(pattern.to_string());
287 self
288 }
289
290 pub(crate) fn build_middleware_with_exclusions(&self) -> Vec<Arc<dyn Middleware>> {
292 let mut result: Vec<Arc<dyn Middleware>> = Vec::with_capacity(self.middleware.len());
293
294 for (mw, exclusions) in self
295 .middleware
296 .iter()
297 .zip(self.middleware_exclusions.iter())
298 {
299 if exclusions.is_empty() {
300 result.push(mw.clone());
301 } else {
302 let mut exclude_mw = ExcludeMiddleware::new(mw.clone());
303 for pattern in exclusions {
304 exclude_mw.add_exclusion_mut(pattern);
305 }
306 result.push(Arc::new(exclude_mw) as Arc<dyn Middleware>);
307 }
308 }
309
310 result
311 }
312
313 pub fn mount(mut self, prefix: &str, mut child: ServerRouter) -> Self {
351 Self::validate_prefix(prefix);
353
354 if child.prefix.is_empty() {
356 child.prefix = prefix.to_string();
357 }
358
359 Self::inherit_context_from_parent_if_any(&self, &mut child);
367
368 self.children.push(child);
369 self
370 }
371
372 pub fn mount_mut(&mut self, prefix: &str, mut child: ServerRouter) {
385 Self::validate_prefix(prefix);
387
388 if child.prefix.is_empty() {
389 child.prefix = prefix.to_string();
390 }
391 Self::inherit_context_from_parent_if_any(self, &mut child);
393 self.children.push(child);
394 }
395
396 pub fn group(mut self, routers: Vec<ServerRouter>) -> Self {
413 for mut router in routers {
414 Self::inherit_context_from_parent_if_any(&self, &mut router);
421 self.children.push(router);
422 }
423 self
424 }
425}
426
#[cfg(test)]
mod middleware_di_tests {
    //! Tests for where middleware DI registrations end up. They should land in
    //! an attached `InjectionContext` when one is reachable through the router
    //! tree, and in the global deferred list only when none is.

    use super::*;
    use async_trait::async_trait;
    use reinhardt_core::exception::Result;
    use reinhardt_di::{InjectionContext, SingletonScope};
    use reinhardt_http::{Handler, Request, Response};
    use rstest::rstest;
    use std::any::TypeId;
    use std::sync::Arc;

    #[derive(Debug, PartialEq, Eq)]
    struct DummyState(&'static str);

    /// Second state type, so one test can observe two independent
    /// registrations. Two registrations sharing a `TypeId` would compete for
    /// the same scope slot, and only one of them could be checked.
    #[derive(Debug, PartialEq, Eq)]
    struct OtherState(&'static str);

    /// Pass-through middleware whose only effect is exposing `state` as a DI
    /// registration keyed by `TypeId::of::<T>()`.
    struct DummyMiddleware<T = DummyState> {
        state: Arc<T>,
    }

    #[async_trait]
    impl<T: Send + Sync + 'static> Middleware for DummyMiddleware<T> {
        async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
            handler.handle(request).await
        }

        fn di_registrations(&self) -> Vec<reinhardt_http::MiddlewareDiRegistration> {
            vec![(
                TypeId::of::<T>(),
                Arc::clone(&self.state) as Arc<dyn std::any::Any + Send + Sync>,
            )]
        }
    }

    fn make_mw(tag: &'static str) -> DummyMiddleware {
        DummyMiddleware {
            state: Arc::new(DummyState(tag)),
        }
    }

    fn make_other_mw(tag: &'static str) -> DummyMiddleware<OtherState> {
        DummyMiddleware {
            state: Arc::new(OtherState(tag)),
        }
    }

    /// Drains the global deferred DI list so each test starts from a clean
    /// slate, regardless of what earlier tests left behind.
    fn reset_global_di() {
        let _ = crate::routers::take_di_registrations();
    }

    /// Builds a fresh singleton scope and a context backed by it. The scope
    /// handle is returned so tests can inspect what was registered.
    fn new_context() -> (Arc<SingletonScope>, Arc<InjectionContext>) {
        let scope = Arc::new(SingletonScope::new());
        let ctx = Arc::new(InjectionContext::builder(Arc::clone(&scope)).build());
        (scope, ctx)
    }

    #[rstest]
    #[serial_test::serial(global_di)]
    fn with_middleware_before_with_di_context_applies_to_context() {
        reset_global_di();
        let (scope, ctx) = new_context();

        let _router = ServerRouter::new()
            .with_middleware(make_mw("before-context"))
            .with_di_context(ctx);

        let leaked = crate::routers::take_di_registrations();

        let resolved = scope
            .get::<DummyState>()
            .expect("with_di_context must drain pending middleware DI into the new context");
        assert_eq!(resolved.0, "before-context");
        assert!(
            leaked.is_none(),
            "pending middleware DI must not leak into the global deferred list when a context is attached later"
        );
    }

    #[rstest]
    #[serial_test::serial(global_di)]
    fn with_middleware_after_with_di_context_applies_to_context() {
        reset_global_di();
        let (scope, ctx) = new_context();

        let _router = ServerRouter::new()
            .with_di_context(ctx)
            .with_middleware(make_mw("after-context"));

        let leaked = crate::routers::take_di_registrations();

        let resolved = scope
            .get::<DummyState>()
            .expect("with_middleware after with_di_context must apply directly to context scope");
        assert_eq!(resolved.0, "after-context");
        assert!(leaked.is_none());
    }

    #[rstest]
    #[serial_test::serial(global_di)]
    fn with_middleware_without_context_flushes_to_global_on_register_all_routes() {
        reset_global_di();
        let mut router = ServerRouter::new().with_middleware(make_mw("no-context"));

        let _errors = router.register_all_routes();

        let taken = crate::routers::take_di_registrations()
            .expect("register_all_routes must flush pending middleware DI when no context is set");
        let scope = SingletonScope::new();
        taken.apply_to(&scope);
        let resolved = scope.get::<DummyState>().expect(
            "flushed registration must resolve from the global deferred list after apply_to",
        );
        assert_eq!(resolved.0, "no-context");
    }

    #[rstest]
    #[serial_test::serial(global_di)]
    fn group_drains_grouped_router_pending_into_parent_context() {
        reset_global_di();
        let (scope, ctx) = new_context();

        let users = ServerRouter::new()
            .with_prefix("/users")
            .with_middleware(make_mw("group-users"));
        // Distinct state type so both drains can be asserted independently.
        let posts_grandchild =
            ServerRouter::new().with_middleware(make_other_mw("group-posts-grandchild"));
        let posts = ServerRouter::new()
            .with_prefix("/posts")
            .mount("/comments/", posts_grandchild);

        let _parent = ServerRouter::new()
            .with_di_context(ctx)
            .group(vec![users, posts]);

        let leaked = crate::routers::take_di_registrations();

        let users_state = scope
            .get::<DummyState>()
            .expect("group must drain a grouped router's pending middleware DI into the parent context");
        assert_eq!(users_state.0, "group-users");

        let grandchild_state = scope.get::<OtherState>().expect(
            "group must recursively drain grouped routers' pending middleware DI into the parent context",
        );
        assert_eq!(grandchild_state.0, "group-posts-grandchild");

        assert!(leaked.is_none());
    }

    #[rstest]
    #[serial_test::serial(global_di)]
    fn nested_grandchild_pending_drains_into_parent_context_on_mount() {
        reset_global_di();
        let (scope, ctx) = new_context();

        let grandchild = ServerRouter::new().with_middleware(make_mw("nested-grandchild"));
        let child = ServerRouter::new().mount("/users/", grandchild);

        let _parent = ServerRouter::new()
            .with_di_context(ctx)
            .mount("/api/", child);

        let leaked = crate::routers::take_di_registrations();

        let resolved = scope.get::<DummyState>().expect(
            "mount must recursively drain grandchildren's pending middleware DI into the parent's context",
        );
        assert_eq!(resolved.0, "nested-grandchild");
        assert!(leaked.is_none());
    }

    #[rstest]
    #[serial_test::serial(global_di)]
    fn child_pending_drains_when_context_attached_after_mount() {
        reset_global_di();
        let (scope, ctx) = new_context();

        let child = ServerRouter::new().with_middleware(make_mw("late-context-child"));

        let _parent = ServerRouter::new()
            .mount("/api/", child)
            .with_di_context(ctx);

        let leaked = crate::routers::take_di_registrations();

        let resolved = scope.get::<DummyState>().expect(
            "attaching a context after mounting a child with pending middleware DI must propagate into the child",
        );
        assert_eq!(resolved.0, "late-context-child");
        assert!(leaked.is_none());
    }

    #[rstest]
    #[serial_test::serial(global_di)]
    fn child_pending_drains_into_parent_context_on_mount() {
        reset_global_di();
        let (scope, ctx) = new_context();

        let child = ServerRouter::new().with_middleware(make_mw("mounted-child"));

        let _parent = ServerRouter::new()
            .with_di_context(ctx)
            .mount("/api/", child);

        let leaked = crate::routers::take_di_registrations();

        let resolved = scope.get::<DummyState>().expect(
            "mounting a child with pending middleware DI into a context-bearing parent must drain into the parent's scope",
        );
        assert_eq!(resolved.0, "mounted-child");
        assert!(leaked.is_none());
    }
}