// nestforge_http: factory.rs

use std::{
    future::Future,
    net::SocketAddr,
    pin::Pin,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Mutex,
    },
    time::{Instant, SystemTime, UNIX_EPOCH},
};

use anyhow::Result;
use axum::{
    body::Body,
    http::{header::HeaderName, HeaderValue},
    middleware::from_fn,
    response::{IntoResponse, Response},
    Router,
};
use nestforge_core::{
    apply_exception_filters, execute_pipeline, framework_log_event, initialize_module_runtime,
    AuthIdentity, Container, ExceptionFilter, Guard, HttpException, InitializedModule, NextFn,
    Interceptor, ModuleDefinition, RequestContext, RequestId,
};

use crate::middleware::{run_middleware_chain, MiddlewareBinding, MiddlewareConsumer, NestMiddleware};

28/*
29NestForgeFactory = app bootstrapper.
30
31This is the NestFactory.create(AppModule) vibe.
32
33Now it:
34- builds DI container
35- asks the module to register providers
36- asks the module for controllers
37- merges controller routers into one app router
38*/
39pub struct NestForgeFactory<M: ModuleDefinition> {
40    _marker: std::marker::PhantomData<M>,
41    container: Container,
42    runtime: Arc<InitializedModule>,
43    controllers: Vec<Router<Container>>,
44    extra_routers: Vec<Router<Container>>,
45    global_prefix: Option<String>,
46    version: Option<String>,
47    auth_resolver: Option<Arc<AuthResolver>>,
48    global_guards: Vec<Arc<dyn Guard>>,
49    global_interceptors: Vec<Arc<dyn Interceptor>>,
50    global_exception_filters: Vec<Arc<dyn ExceptionFilter>>,
51    middleware_bindings: Vec<MiddlewareBinding>,
52}
53
54type AuthFuture = Pin<Box<dyn Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send>>;
55type AuthResolver = dyn Fn(Option<String>, Container) -> AuthFuture + Send + Sync;
56
57impl<M: ModuleDefinition> NestForgeFactory<M> {
58    pub fn create() -> Result<Self> {
59        let container = Container::new();
60        let runtime = Arc::new(initialize_module_runtime::<M>(&container)?);
61        runtime.run_module_init(&container)?;
62        runtime.run_application_bootstrap(&container)?;
63        let controllers = runtime.controllers.clone();
64
65        Ok(Self {
66            _marker: std::marker::PhantomData,
67            container,
68            runtime,
69            controllers,
70            extra_routers: Vec::new(),
71            global_prefix: None,
72            version: None,
73            auth_resolver: None,
74            global_guards: Vec::new(),
75            global_interceptors: Vec::new(),
76            global_exception_filters: Vec::new(),
77            middleware_bindings: Vec::new(),
78        })
79    }
80
81    pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
82        let prefix = prefix.into().trim().trim_matches('/').to_string();
83        if !prefix.is_empty() {
84            framework_log_event(
85                "global_prefix_configured",
86                &[("prefix", prefix.clone())],
87            );
88            self.global_prefix = Some(prefix);
89        }
90        self
91    }
92
93    pub fn with_version(mut self, version: impl Into<String>) -> Self {
94        let version = version.into().trim().trim_matches('/').to_string();
95        if !version.is_empty() {
96            framework_log_event(
97                "api_version_configured",
98                &[("version", version.clone())],
99            );
100            self.version = Some(version);
101        }
102        self
103    }
104
105    pub fn use_guard<G>(mut self) -> Self
106    where
107        G: Guard + Default,
108    {
109        framework_log_event(
110            "global_guard_register",
111            &[("guard", std::any::type_name::<G>().to_string())],
112        );
113        self.global_guards.push(Arc::new(G::default()));
114        self
115    }
116
117    pub fn use_interceptor<I>(mut self) -> Self
118    where
119        I: Interceptor + Default,
120    {
121        framework_log_event(
122            "global_interceptor_register",
123            &[("interceptor", std::any::type_name::<I>().to_string())],
124        );
125        self.global_interceptors.push(Arc::new(I::default()));
126        self
127    }
128
129    pub fn use_exception_filter<F>(mut self) -> Self
130    where
131        F: ExceptionFilter + Default,
132    {
133        framework_log_event(
134            "global_exception_filter_register",
135            &[("filter", std::any::type_name::<F>().to_string())],
136        );
137        self.global_exception_filters.push(Arc::new(F::default()));
138        self
139    }
140
141    pub fn use_middleware<T>(mut self) -> Self
142    where
143        T: NestMiddleware + Default,
144    {
145        let mut consumer = MiddlewareConsumer::new();
146        consumer.apply::<T>().for_all_routes();
147        self.middleware_bindings.extend(consumer.into_bindings());
148        self
149    }
150
151    pub fn configure_middleware<F>(mut self, configure: F) -> Self
152    where
153        F: FnOnce(&mut MiddlewareConsumer),
154    {
155        let mut consumer = MiddlewareConsumer::new();
156        configure(&mut consumer);
157        self.middleware_bindings.extend(consumer.into_bindings());
158        self
159    }
160
161    pub fn with_auth_resolver<F, Fut>(mut self, resolver: F) -> Self
162    where
163        F: Fn(Option<String>, Container) -> Fut + Send + Sync + 'static,
164        Fut: Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send + 'static,
165    {
166        self.auth_resolver = Some(Arc::new(move |token, container| {
167            Box::pin(resolver(token, container))
168        }));
169        self
170    }
171
172    pub fn merge_router(mut self, router: Router<Container>) -> Self {
173        self.extra_routers.push(router);
174        self
175    }
176
177    pub fn container(&self) -> &Container {
178        &self.container
179    }
180
181    pub fn into_router(self) -> Router {
182        /*
183        Build a router that EXPECTS Container state.
184        We don't attach the actual state yet.
185        */
186        let mut app: Router<Container> = Router::new();
187
188        /*
189        Mount all controller routers (they are also Router<Container>)
190        */
191        for controller_router in self.controllers {
192            app = app.merge(controller_router);
193        }
194        for extra_router in self.extra_routers {
195            app = app.merge(extra_router);
196        }
197
198        if let Some(version) = &self.version {
199            app = Router::new().nest(&format!("/{}", version), app);
200        }
201
202        if let Some(prefix) = &self.global_prefix {
203            app = Router::new().nest(&format!("/{}", prefix), app);
204        }
205
206        let global_guards = Arc::new(self.global_guards);
207        let global_interceptors = Arc::new(self.global_interceptors);
208        let global_exception_filters = Arc::new(self.global_exception_filters);
209        let middleware_bindings = Arc::new(self.middleware_bindings);
210        let auth_resolver = self.auth_resolver.clone();
211        let request_container = self.container.clone();
212
213        let route_exception_filters = Arc::clone(&global_exception_filters);
214        let app = app.route_layer(from_fn(move |req, next| {
215            let guards = Arc::clone(&global_guards);
216            let interceptors = Arc::clone(&global_interceptors);
217            let filters = Arc::clone(&route_exception_filters);
218            async move { execute_pipeline(req, next, guards, interceptors, filters).await }
219        }));
220
221        let app = app.layer(from_fn(move |req: axum::extract::Request, next: axum::middleware::Next| {
222            let middlewares = Arc::clone(&middleware_bindings);
223            async move {
224                if middlewares.is_empty() {
225                    return next.run(req).await;
226                }
227
228                let terminal = next_to_fn(next);
229                run_middleware_chain(middlewares, 0, req, terminal).await
230            }
231        }));
232
233        Router::new()
234            .merge(app)
235            .layer(from_fn(move |req, next| {
236                let auth_resolver = auth_resolver.clone();
237                let request_container = request_container.clone();
238                let exception_filters = Arc::clone(&global_exception_filters);
239                async move {
240                    request_context_middleware(
241                        req,
242                        next,
243                        request_container,
244                        auth_resolver,
245                        exception_filters,
246                    )
247                    .await
248                }
249            }))
250            .with_state(self.container)
251    }
252
253    pub async fn listen(self, port: u16) -> Result<()> {
254        let runtime = Arc::clone(&self.runtime);
255        let container = self.container.clone();
256        let app = self.into_router();
257
258        let addr = SocketAddr::from(([127, 0, 0, 1], port));
259        let listener = tokio::net::TcpListener::bind(addr).await?;
260
261        framework_log_event("server_listening", &[("addr", addr.to_string())]);
262
263        axum::serve(listener, app).await?;
264        runtime.run_module_destroy(&container)?;
265        runtime.run_application_shutdown(&container)?;
266        Ok(())
267    }
268}
269
/// Process-wide counter appended to generated request ids; starts at 1.
static NEXT_REQUEST_SEQUENCE: AtomicU64 = AtomicU64::new(1);
/// Response header that carries the id generated for each request.
const REQUEST_ID_HEADER: &str = "x-request-id";

273fn next_to_fn(next: axum::middleware::Next) -> NextFn {
274    let next = Arc::new(Mutex::new(Some(next)));
275
276    Arc::new(move |req: axum::extract::Request<Body>| {
277        let next = Arc::clone(&next);
278        Box::pin(async move {
279            let next = {
280                let mut guard = match next.lock() {
281                    Ok(guard) => guard,
282                    Err(_) => {
283                        return HttpException::internal_server_error("Middleware lock poisoned")
284                            .into_response();
285                    }
286                };
287                guard.take()
288            };
289
290            match next {
291                Some(next) => next.run(req).await,
292                None => HttpException::internal_server_error("Middleware next called multiple times")
293                    .into_response(),
294            }
295        })
296    })
297}
298
299async fn request_context_middleware(
300    mut req: axum::extract::Request,
301    next: axum::middleware::Next,
302    container: Container,
303    auth_resolver: Option<Arc<AuthResolver>>,
304    exception_filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
305) -> Response {
306    let scoped_container = container.scoped();
307    let request_id = RequestId::new(generate_request_id());
308    let request_id_value = request_id.value().to_string();
309    let method = req.method().to_string();
310    let path = req.uri().path().to_string();
311    let started = Instant::now();
312    let bearer_token = req
313        .headers()
314        .get(axum::http::header::AUTHORIZATION)
315        .and_then(|value| value.to_str().ok())
316        .and_then(|value| value.strip_prefix("Bearer "))
317        .map(str::trim)
318        .filter(|value| !value.is_empty())
319        .map(str::to_string);
320
321    req.extensions_mut().insert(scoped_container.clone());
322    req.extensions_mut().insert(request_id.clone());
323    let _ = scoped_container.override_value(request_id.clone());
324    framework_log_event(
325        "request_start",
326        &[
327            ("request_id", request_id_value.clone()),
328            ("method", method.clone()),
329            ("path", path.clone()),
330        ],
331    );
332
333    if let Some(resolver) = auth_resolver {
334        match resolver(bearer_token, container).await {
335            Ok(Some(identity)) => {
336                framework_log_event(
337                    "auth_identity_resolved",
338                    &[
339                        ("request_id", request_id_value.clone()),
340                        ("subject", identity.subject.clone()),
341                    ],
342                );
343                let _ = scoped_container.override_value(identity.clone());
344                req.extensions_mut().insert(Arc::new(identity));
345            }
346            Ok(None) => {}
347            Err(err) => {
348                let ctx = RequestContext::from_request(&req);
349                let _ = scoped_container.override_value(ctx.clone());
350                let mut response = apply_exception_filters(
351                    err.with_request_id(request_id_value.clone()),
352                    &ctx,
353                    exception_filters.as_slice(),
354                )
355                .into_response();
356                attach_request_id_header(&mut response, &request_id_value);
357                framework_log_event(
358                    "request_complete",
359                    &[
360                        ("request_id", request_id_value),
361                        ("method", method),
362                        ("path", path),
363                        ("status", response.status().as_u16().to_string()),
364                        ("duration_ms", started.elapsed().as_millis().to_string()),
365                    ],
366                );
367                return response;
368            }
369        }
370    }
371
372    let ctx = RequestContext::from_request(&req);
373    let _ = scoped_container.override_value(ctx);
374
375    let mut response = next.run(req).await;
376    attach_request_id_header(&mut response, &request_id_value);
377
378    framework_log_event(
379        "request_complete",
380        &[
381            ("request_id", request_id_value),
382            ("method", method),
383            ("path", path),
384            ("status", response.status().as_u16().to_string()),
385            ("duration_ms", started.elapsed().as_millis().to_string()),
386        ],
387    );
388
389    response
390}
391
392fn generate_request_id() -> String {
393    let sequence = NEXT_REQUEST_SEQUENCE.fetch_add(1, Ordering::Relaxed);
394    let millis = SystemTime::now()
395        .duration_since(UNIX_EPOCH)
396        .map(|duration| duration.as_millis())
397        .unwrap_or_default();
398    format!("req-{millis}-{sequence}")
399}
400
401fn attach_request_id_header(response: &mut Response, request_id: &str) {
402    if let Ok(value) = HeaderValue::from_str(request_id) {
403        response.headers_mut().insert(
404            HeaderName::from_static(REQUEST_ID_HEADER),
405            value,
406        );
407    }
408}
409
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use axum::Json;
    use nestforge_core::{
        register_provider, ApiResult, AuthUser, Container, ControllerBasePath, ControllerDefinition,
        ExceptionFilter, HttpException, Inject, ModuleDefinition, Provider,
        RequestContext as FrameworkRequestContext, RouteBuilder,
    };
    use tower::ServiceExt;

    use super::*;

    struct HealthController;
    #[derive(Default)]
    struct RewriteBadRequestFilter;
    struct RequestScopedService {
        path: String,
    }

    impl ControllerBasePath for HealthController {
        fn base_path() -> &'static str {
            "/health"
        }
    }

    impl HealthController {
        async fn ok(request_id: RequestId) -> ApiResult<String> {
            Ok(Json(request_id.value().to_string()))
        }

        async fn fail(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn fail_locally(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("local broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn me(user: AuthUser) -> ApiResult<String> {
            Ok(Json(user.subject.clone()))
        }

        async fn scoped(service: Inject<RequestScopedService>) -> ApiResult<String> {
            Ok(Json(service.path.clone()))
        }
    }

    impl ControllerDefinition for HealthController {
        fn router() -> Router<Container> {
            RouteBuilder::<Self>::new()
                .get("/", Self::ok)
                .get("/fail", Self::fail)
                .get_with_pipeline(
                    "/fail-local",
                    Self::fail_locally,
                    Vec::new(),
                    Vec::new(),
                    vec![Arc::new(RewriteBadRequestFilter) as Arc<dyn ExceptionFilter>],
                    None,
                )
                .get("/me", Self::me)
                .get("/scoped", Self::scoped)
                .build()
        }
    }

    impl ExceptionFilter for RewriteBadRequestFilter {
        fn catch(&self, exception: HttpException, _ctx: &RequestContext) -> HttpException {
            if exception.status == axum::http::StatusCode::BAD_REQUEST {
                HttpException::bad_request("filtered bad request")
                    .with_optional_request_id(exception.request_id)
            } else {
                exception
            }
        }
    }

    struct TestModule;

    impl ModuleDefinition for TestModule {
        fn register(container: &Container) -> Result<()> {
            register_provider(
                container,
                Provider::request_factory(|container| {
                    let ctx = container.resolve::<FrameworkRequestContext>()?;
                    Ok(RequestScopedService {
                        path: ctx.uri.path().to_string(),
                    })
                }),
            )?;
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HealthController::router()]
        }
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get_request(uri: &str) -> axum::http::Request<axum::body::Body> {
        axum::http::Request::builder()
            .uri(uri)
            .body(axum::body::Body::empty())
            .expect("request should build")
    }

    #[tokio::test]
    async fn request_middleware_sets_request_id_header_and_extension() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get_request("/health/"))
            .await
            .expect("request should succeed");

        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn generated_request_id_header_uses_req_prefix() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get_request("/health/"))
            .await
            .expect("request should succeed");

        let request_id = response
            .headers()
            .get(REQUEST_ID_HEADER)
            .and_then(|value| value.to_str().ok())
            .expect("request id header should be present and ASCII");
        assert!(request_id.starts_with("req-"));
    }

    #[tokio::test]
    async fn global_prefix_and_version_are_prepended_to_routes() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .with_global_prefix("/api/")
            .with_version("v1")
            .into_router();

        let prefixed = app
            .clone()
            .oneshot(get_request("/api/v1/health/"))
            .await
            .expect("request should succeed");
        assert_eq!(prefixed.status(), axum::http::StatusCode::OK);

        let unprefixed = app
            .oneshot(get_request("/health/"))
            .await
            .expect("request should succeed");
        assert_eq!(unprefixed.status(), axum::http::StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn error_responses_keep_request_id_header() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .use_exception_filter::<RewriteBadRequestFilter>()
            .into_router();

        let response = app
            .oneshot(get_request("/health/fail"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn route_specific_exception_filters_rewrite_route_failures() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get_request("/health/fail-local"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn auth_resolver_inserts_identity_for_auth_user_extractor() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .with_auth_resolver(|token, _container| async move {
                Ok(token.map(|_| AuthIdentity::new("demo-user").with_roles(["admin"])))
            })
            .into_router();

        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .uri("/health/me")
                    .header(axum::http::header::AUTHORIZATION, "Bearer demo-token")
                    .body(axum::body::Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    #[tokio::test]
    async fn request_scoped_provider_resolves_from_per_request_container() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get_request("/health/scoped"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }
}