Skip to main content

nestforge_http/
factory.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    pin::Pin,
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc, Mutex,
8    },
9    time::{Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Result;
13use axum::{
14    body::Body,
15    http::{header::HeaderName, HeaderValue},
16    middleware::from_fn,
17    response::{IntoResponse, Response},
18    Router,
19};
20use nestforge_core::{
21    apply_exception_filters, execute_pipeline, framework_log_event, initialize_module_runtime,
22    AuthIdentity, Container, ExceptionFilter, Guard, HttpException, InitializedModule, Interceptor,
23    ModuleDefinition, NextFn, RequestContext, RequestId,
24};
25
26use crate::middleware::{
27    run_middleware_chain, MiddlewareBinding, MiddlewareConsumer, NestMiddleware,
28};
29
30/// The main entry point for creating a NestForge application.
31///
32/// It handles the bootstrap process:
33/// 1. Creating the DI Container.
34/// 2. Initializing the Module Graph (resolving imports and providers).
35/// 3. Merging all Controller routers into a single Axum app.
36/// 4. Attaching global middleware, guards, interceptors, and exception filters.
37///
38/// # Example
39/// ```rust,no_run
40/// use nestforge_http::NestForgeFactory;
41///
42/// #[tokio::main]
43/// async fn main() {
44///     let app = NestForgeFactory::<AppModule>::create()
45///         .expect("failed to start")
46///         .listen(3000)
47///         .await;
48/// }
49/// ```
50pub struct NestForgeFactory<M: ModuleDefinition> {
51    _marker: std::marker::PhantomData<M>,
52    container: Container,
53    runtime: Arc<InitializedModule>,
54    controllers: Vec<Router<Container>>,
55    extra_routers: Vec<Router<Container>>,
56    global_prefix: Option<String>,
57    version: Option<String>,
58    auth_resolver: Option<Arc<AuthResolver>>,
59    global_guards: Vec<Arc<dyn Guard>>,
60    global_interceptors: Vec<Arc<dyn Interceptor>>,
61    global_exception_filters: Vec<Arc<dyn ExceptionFilter>>,
62    middleware_bindings: Vec<MiddlewareBinding>,
63}
64
65type AuthFuture = Pin<Box<dyn Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send>>;
66type AuthResolver = dyn Fn(Option<String>, Container) -> AuthFuture + Send + Sync;
67
68impl<M: ModuleDefinition> NestForgeFactory<M> {
69    /// Creates a new application instance from the root module.
70    ///
71    /// This triggers the DI container initialization and module lifecycle hooks (e.g., `on_module_init`).
72    pub fn create() -> Result<Self> {
73        let container = Container::new();
74        let runtime = Arc::new(initialize_module_runtime::<M>(&container)?);
75        runtime.run_module_init(&container)?;
76        runtime.run_application_bootstrap(&container)?;
77        let controllers = runtime.controllers.clone();
78
79        Ok(Self {
80            _marker: std::marker::PhantomData,
81            container,
82            runtime,
83            controllers,
84            extra_routers: Vec::new(),
85            global_prefix: None,
86            version: None,
87            auth_resolver: None,
88            global_guards: Vec::new(),
89            global_interceptors: Vec::new(),
90            global_exception_filters: Vec::new(),
91            middleware_bindings: Vec::new(),
92        })
93    }
94
95    /// Sets a global prefix for all routes (e.g., "api").
96    pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
97        let prefix = prefix.into().trim().trim_matches('/').to_string();
98        if !prefix.is_empty() {
99            framework_log_event("global_prefix_configured", &[("prefix", prefix.clone())]);
100            self.global_prefix = Some(prefix);
101        }
102        self
103    }
104
105    /// Sets a global API version for all routes (e.g., "v1").
106    pub fn with_version(mut self, version: impl Into<String>) -> Self {
107        let version = version.into().trim().trim_matches('/').to_string();
108        if !version.is_empty() {
109            framework_log_event("api_version_configured", &[("version", version.clone())]);
110            self.version = Some(version);
111        }
112        self
113    }
114
115    /// Registers a global guard.
116    ///
117    /// Global guards run for *every* route in the application.
118    pub fn use_guard<G>(mut self) -> Self
119    where
120        G: Guard + Default,
121    {
122        framework_log_event(
123            "global_guard_register",
124            &[("guard", std::any::type_name::<G>().to_string())],
125        );
126        self.global_guards.push(Arc::new(G::default()));
127        self
128    }
129
130    /// Registers a global interceptor.
131    ///
132    /// Global interceptors wrap *every* route handler.
133    pub fn use_interceptor<I>(mut self) -> Self
134    where
135        I: Interceptor + Default,
136    {
137        framework_log_event(
138            "global_interceptor_register",
139            &[("interceptor", std::any::type_name::<I>().to_string())],
140        );
141        self.global_interceptors.push(Arc::new(I::default()));
142        self
143    }
144
145    /// Registers a global exception filter.
146    ///
147    /// Catches unhandled exceptions from *any* route.
148    pub fn use_exception_filter<F>(mut self) -> Self
149    where
150        F: ExceptionFilter + Default,
151    {
152        framework_log_event(
153            "global_exception_filter_register",
154            &[("filter", std::any::type_name::<F>().to_string())],
155        );
156        self.global_exception_filters.push(Arc::new(F::default()));
157        self
158    }
159
160    /// Applies middleware to the application.
161    ///
162    /// Use the builder to select which routes the middleware applies to.
163    pub fn use_middleware<T>(mut self) -> Self
164    where
165        T: NestMiddleware + Default,
166    {
167        let mut consumer = MiddlewareConsumer::new();
168        consumer.apply::<T>().for_all_routes();
169        self.middleware_bindings.extend(consumer.into_bindings());
170        self
171    }
172
173    /// Advanced middleware configuration using a consumer builder.
174    pub fn configure_middleware<F>(mut self, configure: F) -> Self
175    where
176        F: FnOnce(&mut MiddlewareConsumer),
177    {
178        let mut consumer = MiddlewareConsumer::new();
179        configure(&mut consumer);
180        self.middleware_bindings.extend(consumer.into_bindings());
181        self
182    }
183
184    /// Sets the authentication resolver.
185    ///
186    /// This function is called for every request to resolve the `AuthIdentity`
187    /// from the bearer token.
188    pub fn with_auth_resolver<F, Fut>(mut self, resolver: F) -> Self
189    where
190        F: Fn(Option<String>, Container) -> Fut + Send + Sync + 'static,
191        Fut: Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send + 'static,
192    {
193        self.auth_resolver = Some(Arc::new(move |token, container| {
194            Box::pin(resolver(token, container))
195        }));
196        self
197    }
198
199    /// Merges an external Axum router into the application.
200    ///
201    /// Useful for integrating other libraries or raw Axum handlers.
202    pub fn merge_router(mut self, router: Router<Container>) -> Self {
203        self.extra_routers.push(router);
204        self
205    }
206
207    /// Returns a reference to the underlying DI Container.
208    pub fn container(&self) -> &Container {
209        &self.container
210    }
211
212    /// Consumes the factory and returns the fully configured Axum Router.
213    ///
214    /// Use this if you want to run the app with your own server (e.g. Lambda, Shuttle).
215    pub fn into_router(self) -> Router {
216        /*
217        Build a router that EXPECTS Container state.
218        We don't attach the actual state yet.
219        */
220        let mut app: Router<Container> = Router::new();
221
222        /*
223        Mount all controller routers (they are also Router<Container>)
224        */
225        for controller_router in self.controllers {
226            app = app.merge(controller_router);
227        }
228        for extra_router in self.extra_routers {
229            app = app.merge(extra_router);
230        }
231
232        if let Some(version) = &self.version {
233            app = Router::new().nest(&format!("/{}", version), app);
234        }
235
236        if let Some(prefix) = &self.global_prefix {
237            app = Router::new().nest(&format!("/{}", prefix), app);
238        }
239
240        let global_guards = Arc::new(self.global_guards);
241        let global_interceptors = Arc::new(self.global_interceptors);
242        let global_exception_filters = Arc::new(self.global_exception_filters);
243        let middleware_bindings = Arc::new(self.middleware_bindings);
244        let auth_resolver = self.auth_resolver.clone();
245        let request_container = self.container.clone();
246
247        let route_exception_filters = Arc::clone(&global_exception_filters);
248        let app = app.route_layer(from_fn(move |req, next| {
249            let guards = Arc::clone(&global_guards);
250            let interceptors = Arc::clone(&global_interceptors);
251            let filters = Arc::clone(&route_exception_filters);
252            async move { execute_pipeline(req, next, guards, interceptors, filters).await }
253        }));
254
255        let app = app.layer(from_fn(
256            move |req: axum::extract::Request, next: axum::middleware::Next| {
257                let middlewares = Arc::clone(&middleware_bindings);
258                async move {
259                    if middlewares.is_empty() {
260                        return next.run(req).await;
261                    }
262
263                    let terminal = next_to_fn(next);
264                    run_middleware_chain(middlewares, 0, req, terminal).await
265                }
266            },
267        ));
268
269        Router::new()
270            .merge(app)
271            .layer(from_fn(move |req, next| {
272                let auth_resolver = auth_resolver.clone();
273                let request_container = request_container.clone();
274                let exception_filters = Arc::clone(&global_exception_filters);
275                async move {
276                    request_context_middleware(
277                        req,
278                        next,
279                        request_container,
280                        auth_resolver,
281                        exception_filters,
282                    )
283                    .await
284                }
285            }))
286            .with_state(self.container)
287    }
288
289    /// Starts the HTTP server on the specified port.
290    ///
291    /// This will block the current thread (it should be awaited).
292    /// Upon shutdown (Ctrl+C), it runs the `on_module_destroy` and `on_application_shutdown` hooks.
293    pub async fn listen(self, port: u16) -> Result<()> {
294        let runtime = Arc::clone(&self.runtime);
295        let container = self.container.clone();
296        let app = self.into_router();
297
298        let addr = SocketAddr::from(([127, 0, 0, 1], port));
299        let listener = tokio::net::TcpListener::bind(addr).await?;
300
301        framework_log_event("server_listening", &[("addr", addr.to_string())]);
302
303        axum::serve(listener, app).await?;
304        runtime.run_module_destroy(&container)?;
305        runtime.run_application_shutdown(&container)?;
306        Ok(())
307    }
308}
309
/// Header used to return the per-request id to clients.
///
/// It must stay lowercase, because it is passed to `HeaderName::from_static`.
const REQUEST_ID_HEADER: &str = "x-request-id";

/// Process-wide counter embedded in generated request ids; the first id uses 1.
static NEXT_REQUEST_SEQUENCE: AtomicU64 = AtomicU64::new(1);
312
313fn next_to_fn(next: axum::middleware::Next) -> NextFn {
314    let next = Arc::new(Mutex::new(Some(next)));
315
316    Arc::new(move |req: axum::extract::Request<Body>| {
317        let next = Arc::clone(&next);
318        Box::pin(async move {
319            let next = {
320                let mut guard = match next.lock() {
321                    Ok(guard) => guard,
322                    Err(_) => {
323                        return HttpException::internal_server_error("Middleware lock poisoned")
324                            .into_response();
325                    }
326                };
327                guard.take()
328            };
329
330            match next {
331                Some(next) => next.run(req).await,
332                None => {
333                    HttpException::internal_server_error("Middleware next called multiple times")
334                        .into_response()
335                }
336            }
337        })
338    })
339}
340
341async fn request_context_middleware(
342    mut req: axum::extract::Request,
343    next: axum::middleware::Next,
344    container: Container,
345    auth_resolver: Option<Arc<AuthResolver>>,
346    exception_filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
347) -> Response {
348    let scoped_container = container.scoped();
349    let request_id = RequestId::new(generate_request_id());
350    let request_id_value = request_id.value().to_string();
351    let method = req.method().to_string();
352    let path = req.uri().path().to_string();
353    let started = Instant::now();
354    let bearer_token = req
355        .headers()
356        .get(axum::http::header::AUTHORIZATION)
357        .and_then(|value| value.to_str().ok())
358        .and_then(|value| value.strip_prefix("Bearer "))
359        .map(str::trim)
360        .filter(|value| !value.is_empty())
361        .map(str::to_string);
362
363    req.extensions_mut().insert(scoped_container.clone());
364    req.extensions_mut().insert(request_id.clone());
365    let _ = scoped_container.override_value(request_id.clone());
366    framework_log_event(
367        "request_start",
368        &[
369            ("request_id", request_id_value.clone()),
370            ("method", method.clone()),
371            ("path", path.clone()),
372        ],
373    );
374
375    if let Some(resolver) = auth_resolver {
376        match resolver(bearer_token, container).await {
377            Ok(Some(identity)) => {
378                framework_log_event(
379                    "auth_identity_resolved",
380                    &[
381                        ("request_id", request_id_value.clone()),
382                        ("subject", identity.subject.clone()),
383                    ],
384                );
385                let _ = scoped_container.override_value(identity.clone());
386                req.extensions_mut().insert(Arc::new(identity));
387            }
388            Ok(None) => {}
389            Err(err) => {
390                let ctx = RequestContext::from_request(&req);
391                let _ = scoped_container.override_value(ctx.clone());
392                let mut response = apply_exception_filters(
393                    err.with_request_id(request_id_value.clone()),
394                    &ctx,
395                    exception_filters.as_slice(),
396                )
397                .into_response();
398                attach_request_id_header(&mut response, &request_id_value);
399                framework_log_event(
400                    "request_complete",
401                    &[
402                        ("request_id", request_id_value),
403                        ("method", method),
404                        ("path", path),
405                        ("status", response.status().as_u16().to_string()),
406                        ("duration_ms", started.elapsed().as_millis().to_string()),
407                    ],
408                );
409                return response;
410            }
411        }
412    }
413
414    let ctx = RequestContext::from_request(&req);
415    let _ = scoped_container.override_value(ctx);
416
417    let mut response = next.run(req).await;
418    attach_request_id_header(&mut response, &request_id_value);
419
420    framework_log_event(
421        "request_complete",
422        &[
423            ("request_id", request_id_value),
424            ("method", method),
425            ("path", path),
426            ("status", response.status().as_u16().to_string()),
427            ("duration_ms", started.elapsed().as_millis().to_string()),
428        ],
429    );
430
431    response
432}
433
434fn generate_request_id() -> String {
435    let sequence = NEXT_REQUEST_SEQUENCE.fetch_add(1, Ordering::Relaxed);
436    let millis = SystemTime::now()
437        .duration_since(UNIX_EPOCH)
438        .map(|duration| duration.as_millis())
439        .unwrap_or_default();
440    format!("req-{millis}-{sequence}")
441}
442
443fn attach_request_id_header(response: &mut Response, request_id: &str) {
444    if let Ok(value) = HeaderValue::from_str(request_id) {
445        response
446            .headers_mut()
447            .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
448    }
449}
450
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use axum::Json;
    use nestforge_core::{
        register_provider, ApiResult, AuthUser, Container, ControllerBasePath,
        ControllerDefinition, ExceptionFilter, HttpException, Inject, ModuleDefinition, Provider,
        RequestContext as FrameworkRequestContext, RouteBuilder,
    };
    use tower::ServiceExt;

    use super::*;

    /// Test controller mounted at `/health`. It covers request ids, filters, auth, and
    /// request scoping.
    struct HealthController;

    /// Exception filter that rewrites any 400 into a "filtered bad request" 400, keeping
    /// the request id.
    #[derive(Default)]
    struct RewriteBadRequestFilter;

    /// Request-scoped provider that records the path of the request it was built for.
    struct RequestScopedService {
        path: String,
    }

    impl ControllerBasePath for HealthController {
        fn base_path() -> &'static str {
            "/health"
        }
    }

    impl HealthController {
        async fn ok(request_id: RequestId) -> ApiResult<String> {
            Ok(Json(request_id.value().to_string()))
        }

        async fn fail(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn fail_locally(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("local broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn me(user: AuthUser) -> ApiResult<String> {
            Ok(Json(user.subject.clone()))
        }

        async fn scoped(service: Inject<RequestScopedService>) -> ApiResult<String> {
            Ok(Json(service.path.clone()))
        }
    }

    impl ControllerDefinition for HealthController {
        fn router() -> Router<Container> {
            RouteBuilder::<Self>::new()
                .get("/", Self::ok)
                .get("/fail", Self::fail)
                .get_with_pipeline(
                    "/fail-local",
                    Self::fail_locally,
                    Vec::new(),
                    Vec::new(),
                    vec![Arc::new(RewriteBadRequestFilter) as Arc<dyn ExceptionFilter>],
                    None,
                )
                .get("/me", Self::me)
                .get("/scoped", Self::scoped)
                .build()
        }
    }

    impl ExceptionFilter for RewriteBadRequestFilter {
        fn catch(&self, exception: HttpException, _ctx: &RequestContext) -> HttpException {
            if exception.status == axum::http::StatusCode::BAD_REQUEST {
                HttpException::bad_request("filtered bad request")
                    .with_optional_request_id(exception.request_id)
            } else {
                exception
            }
        }
    }

    struct TestModule;

    impl ModuleDefinition for TestModule {
        fn register(container: &Container) -> Result<()> {
            register_provider(
                container,
                Provider::request_factory(|container| {
                    let ctx = container.resolve::<FrameworkRequestContext>()?;
                    Ok(RequestScopedService {
                        path: ctx.uri.path().to_string(),
                    })
                }),
            )?;
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HealthController::router()]
        }
    }

    /// Builds a factory for [`TestModule`] with default settings.
    fn factory() -> NestForgeFactory<TestModule> {
        NestForgeFactory::<TestModule>::create().expect("factory should build")
    }

    /// Sends a body-less GET to `uri` through `app`, optionally with an `Authorization`
    /// header.
    async fn send(app: Router, uri: &str, authorization: Option<&str>) -> Response {
        let mut request = axum::http::Request::builder().uri(uri);
        if let Some(value) = authorization {
            request = request.header(axum::http::header::AUTHORIZATION, value);
        }
        app.oneshot(request.body(Body::empty()).expect("request should build"))
            .await
            .expect("request should succeed")
    }

    /// Collects a response body into a UTF-8 string.
    async fn body_text(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        String::from_utf8(bytes.to_vec()).expect("body should be UTF-8")
    }

    #[tokio::test]
    async fn request_middleware_sets_request_id_header_and_extension() {
        let response = send(factory().into_router(), "/health/", None).await;

        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn error_responses_keep_request_id_header() {
        let app = factory()
            .use_exception_filter::<RewriteBadRequestFilter>()
            .into_router();
        let response = send(app, "/health/fail", None).await;

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn route_specific_exception_filters_rewrite_route_failures() {
        let response = send(factory().into_router(), "/health/fail-local", None).await;

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
        // Without this check the test would also pass if the route filter never ran.
        let body = body_text(response).await;
        assert!(body.contains("filtered bad request"), "unexpected body: {body}");
    }

    #[tokio::test]
    async fn auth_resolver_inserts_identity_for_auth_user_extractor() {
        let app = factory()
            .with_auth_resolver(|token, _container| async move {
                Ok(token.map(|_| AuthIdentity::new("demo-user").with_roles(["admin"])))
            })
            .into_router();
        let response = send(app, "/health/me", Some("Bearer demo-token")).await;

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let body = body_text(response).await;
        assert!(body.contains("demo-user"), "unexpected body: {body}");
    }

    #[tokio::test]
    async fn auth_user_extractor_rejects_requests_without_identity() {
        let app = factory()
            .with_auth_resolver(|token, _container| async move {
                Ok(token.map(|_| AuthIdentity::new("demo-user")))
            })
            .into_router();
        let response = send(app, "/health/me", None).await;

        assert_ne!(response.status(), axum::http::StatusCode::OK);
    }

    #[tokio::test]
    async fn request_scoped_provider_resolves_from_per_request_container() {
        let response = send(factory().into_router(), "/health/scoped", None).await;

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let body = body_text(response).await;
        assert!(body.contains("/health/scoped"), "unexpected body: {body}");
    }
}