// nestforge_http/factory.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    pin::Pin,
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc, Mutex,
8    },
9    time::{Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Result;
13use axum::{
14    body::Body,
15    http::{header::HeaderName, HeaderValue},
16    middleware::from_fn,
17    response::{IntoResponse, Response},
18    Router,
19};
20use nestforge_core::{
21    apply_exception_filters, execute_pipeline, framework_log_event, initialize_module_runtime,
22    AuthIdentity, Container, ExceptionFilter, Guard, HttpException, InitializedModule, Interceptor,
23    ModuleDefinition, NextFn, RequestContext, RequestId,
24};
25
26use crate::middleware::{
27    run_middleware_chain, MiddlewareBinding, MiddlewareConsumer, NestMiddleware,
28};
29
30/*
31NestForgeFactory = app bootstrapper.
32
33This is the NestFactory.create(AppModule) vibe.
34
35Now it:
36- builds DI container
37- asks the module to register providers
38- asks the module for controllers
39- merges controller routers into one app router
40*/
41pub struct NestForgeFactory<M: ModuleDefinition> {
42    _marker: std::marker::PhantomData<M>,
43    container: Container,
44    runtime: Arc<InitializedModule>,
45    controllers: Vec<Router<Container>>,
46    extra_routers: Vec<Router<Container>>,
47    global_prefix: Option<String>,
48    version: Option<String>,
49    auth_resolver: Option<Arc<AuthResolver>>,
50    global_guards: Vec<Arc<dyn Guard>>,
51    global_interceptors: Vec<Arc<dyn Interceptor>>,
52    global_exception_filters: Vec<Arc<dyn ExceptionFilter>>,
53    middleware_bindings: Vec<MiddlewareBinding>,
54}
55
56type AuthFuture = Pin<Box<dyn Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send>>;
57type AuthResolver = dyn Fn(Option<String>, Container) -> AuthFuture + Send + Sync;
58
59impl<M: ModuleDefinition> NestForgeFactory<M> {
60    pub fn create() -> Result<Self> {
61        let container = Container::new();
62        let runtime = Arc::new(initialize_module_runtime::<M>(&container)?);
63        runtime.run_module_init(&container)?;
64        runtime.run_application_bootstrap(&container)?;
65        let controllers = runtime.controllers.clone();
66
67        Ok(Self {
68            _marker: std::marker::PhantomData,
69            container,
70            runtime,
71            controllers,
72            extra_routers: Vec::new(),
73            global_prefix: None,
74            version: None,
75            auth_resolver: None,
76            global_guards: Vec::new(),
77            global_interceptors: Vec::new(),
78            global_exception_filters: Vec::new(),
79            middleware_bindings: Vec::new(),
80        })
81    }
82
83    pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
84        let prefix = prefix.into().trim().trim_matches('/').to_string();
85        if !prefix.is_empty() {
86            framework_log_event("global_prefix_configured", &[("prefix", prefix.clone())]);
87            self.global_prefix = Some(prefix);
88        }
89        self
90    }
91
92    pub fn with_version(mut self, version: impl Into<String>) -> Self {
93        let version = version.into().trim().trim_matches('/').to_string();
94        if !version.is_empty() {
95            framework_log_event("api_version_configured", &[("version", version.clone())]);
96            self.version = Some(version);
97        }
98        self
99    }
100
101    pub fn use_guard<G>(mut self) -> Self
102    where
103        G: Guard + Default,
104    {
105        framework_log_event(
106            "global_guard_register",
107            &[("guard", std::any::type_name::<G>().to_string())],
108        );
109        self.global_guards.push(Arc::new(G::default()));
110        self
111    }
112
113    pub fn use_interceptor<I>(mut self) -> Self
114    where
115        I: Interceptor + Default,
116    {
117        framework_log_event(
118            "global_interceptor_register",
119            &[("interceptor", std::any::type_name::<I>().to_string())],
120        );
121        self.global_interceptors.push(Arc::new(I::default()));
122        self
123    }
124
125    pub fn use_exception_filter<F>(mut self) -> Self
126    where
127        F: ExceptionFilter + Default,
128    {
129        framework_log_event(
130            "global_exception_filter_register",
131            &[("filter", std::any::type_name::<F>().to_string())],
132        );
133        self.global_exception_filters.push(Arc::new(F::default()));
134        self
135    }
136
137    pub fn use_middleware<T>(mut self) -> Self
138    where
139        T: NestMiddleware + Default,
140    {
141        let mut consumer = MiddlewareConsumer::new();
142        consumer.apply::<T>().for_all_routes();
143        self.middleware_bindings.extend(consumer.into_bindings());
144        self
145    }
146
147    pub fn configure_middleware<F>(mut self, configure: F) -> Self
148    where
149        F: FnOnce(&mut MiddlewareConsumer),
150    {
151        let mut consumer = MiddlewareConsumer::new();
152        configure(&mut consumer);
153        self.middleware_bindings.extend(consumer.into_bindings());
154        self
155    }
156
157    pub fn with_auth_resolver<F, Fut>(mut self, resolver: F) -> Self
158    where
159        F: Fn(Option<String>, Container) -> Fut + Send + Sync + 'static,
160        Fut: Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send + 'static,
161    {
162        self.auth_resolver = Some(Arc::new(move |token, container| {
163            Box::pin(resolver(token, container))
164        }));
165        self
166    }
167
168    pub fn merge_router(mut self, router: Router<Container>) -> Self {
169        self.extra_routers.push(router);
170        self
171    }
172
173    pub fn container(&self) -> &Container {
174        &self.container
175    }
176
177    pub fn into_router(self) -> Router {
178        /*
179        Build a router that EXPECTS Container state.
180        We don't attach the actual state yet.
181        */
182        let mut app: Router<Container> = Router::new();
183
184        /*
185        Mount all controller routers (they are also Router<Container>)
186        */
187        for controller_router in self.controllers {
188            app = app.merge(controller_router);
189        }
190        for extra_router in self.extra_routers {
191            app = app.merge(extra_router);
192        }
193
194        if let Some(version) = &self.version {
195            app = Router::new().nest(&format!("/{}", version), app);
196        }
197
198        if let Some(prefix) = &self.global_prefix {
199            app = Router::new().nest(&format!("/{}", prefix), app);
200        }
201
202        let global_guards = Arc::new(self.global_guards);
203        let global_interceptors = Arc::new(self.global_interceptors);
204        let global_exception_filters = Arc::new(self.global_exception_filters);
205        let middleware_bindings = Arc::new(self.middleware_bindings);
206        let auth_resolver = self.auth_resolver.clone();
207        let request_container = self.container.clone();
208
209        let route_exception_filters = Arc::clone(&global_exception_filters);
210        let app = app.route_layer(from_fn(move |req, next| {
211            let guards = Arc::clone(&global_guards);
212            let interceptors = Arc::clone(&global_interceptors);
213            let filters = Arc::clone(&route_exception_filters);
214            async move { execute_pipeline(req, next, guards, interceptors, filters).await }
215        }));
216
217        let app = app.layer(from_fn(
218            move |req: axum::extract::Request, next: axum::middleware::Next| {
219                let middlewares = Arc::clone(&middleware_bindings);
220                async move {
221                    if middlewares.is_empty() {
222                        return next.run(req).await;
223                    }
224
225                    let terminal = next_to_fn(next);
226                    run_middleware_chain(middlewares, 0, req, terminal).await
227                }
228            },
229        ));
230
231        Router::new()
232            .merge(app)
233            .layer(from_fn(move |req, next| {
234                let auth_resolver = auth_resolver.clone();
235                let request_container = request_container.clone();
236                let exception_filters = Arc::clone(&global_exception_filters);
237                async move {
238                    request_context_middleware(
239                        req,
240                        next,
241                        request_container,
242                        auth_resolver,
243                        exception_filters,
244                    )
245                    .await
246                }
247            }))
248            .with_state(self.container)
249    }
250
251    pub async fn listen(self, port: u16) -> Result<()> {
252        let runtime = Arc::clone(&self.runtime);
253        let container = self.container.clone();
254        let app = self.into_router();
255
256        let addr = SocketAddr::from(([127, 0, 0, 1], port));
257        let listener = tokio::net::TcpListener::bind(addr).await?;
258
259        framework_log_event("server_listening", &[("addr", addr.to_string())]);
260
261        axum::serve(listener, app).await?;
262        runtime.run_module_destroy(&container)?;
263        runtime.run_application_shutdown(&container)?;
264        Ok(())
265    }
266}
267
/// Name of the response header carrying the per-request id (must stay lowercase,
/// because it is passed to `HeaderName::from_static`).
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Process-wide counter starting at 1. `generate_request_id` combines it with the
/// current wall-clock milliseconds.
static NEXT_REQUEST_SEQUENCE: AtomicU64 = AtomicU64::new(1);
270
271fn next_to_fn(next: axum::middleware::Next) -> NextFn {
272    let next = Arc::new(Mutex::new(Some(next)));
273
274    Arc::new(move |req: axum::extract::Request<Body>| {
275        let next = Arc::clone(&next);
276        Box::pin(async move {
277            let next = {
278                let mut guard = match next.lock() {
279                    Ok(guard) => guard,
280                    Err(_) => {
281                        return HttpException::internal_server_error("Middleware lock poisoned")
282                            .into_response();
283                    }
284                };
285                guard.take()
286            };
287
288            match next {
289                Some(next) => next.run(req).await,
290                None => {
291                    HttpException::internal_server_error("Middleware next called multiple times")
292                        .into_response()
293                }
294            }
295        })
296    })
297}
298
299async fn request_context_middleware(
300    mut req: axum::extract::Request,
301    next: axum::middleware::Next,
302    container: Container,
303    auth_resolver: Option<Arc<AuthResolver>>,
304    exception_filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
305) -> Response {
306    let scoped_container = container.scoped();
307    let request_id = RequestId::new(generate_request_id());
308    let request_id_value = request_id.value().to_string();
309    let method = req.method().to_string();
310    let path = req.uri().path().to_string();
311    let started = Instant::now();
312    let bearer_token = req
313        .headers()
314        .get(axum::http::header::AUTHORIZATION)
315        .and_then(|value| value.to_str().ok())
316        .and_then(|value| value.strip_prefix("Bearer "))
317        .map(str::trim)
318        .filter(|value| !value.is_empty())
319        .map(str::to_string);
320
321    req.extensions_mut().insert(scoped_container.clone());
322    req.extensions_mut().insert(request_id.clone());
323    let _ = scoped_container.override_value(request_id.clone());
324    framework_log_event(
325        "request_start",
326        &[
327            ("request_id", request_id_value.clone()),
328            ("method", method.clone()),
329            ("path", path.clone()),
330        ],
331    );
332
333    if let Some(resolver) = auth_resolver {
334        match resolver(bearer_token, container).await {
335            Ok(Some(identity)) => {
336                framework_log_event(
337                    "auth_identity_resolved",
338                    &[
339                        ("request_id", request_id_value.clone()),
340                        ("subject", identity.subject.clone()),
341                    ],
342                );
343                let _ = scoped_container.override_value(identity.clone());
344                req.extensions_mut().insert(Arc::new(identity));
345            }
346            Ok(None) => {}
347            Err(err) => {
348                let ctx = RequestContext::from_request(&req);
349                let _ = scoped_container.override_value(ctx.clone());
350                let mut response = apply_exception_filters(
351                    err.with_request_id(request_id_value.clone()),
352                    &ctx,
353                    exception_filters.as_slice(),
354                )
355                .into_response();
356                attach_request_id_header(&mut response, &request_id_value);
357                framework_log_event(
358                    "request_complete",
359                    &[
360                        ("request_id", request_id_value),
361                        ("method", method),
362                        ("path", path),
363                        ("status", response.status().as_u16().to_string()),
364                        ("duration_ms", started.elapsed().as_millis().to_string()),
365                    ],
366                );
367                return response;
368            }
369        }
370    }
371
372    let ctx = RequestContext::from_request(&req);
373    let _ = scoped_container.override_value(ctx);
374
375    let mut response = next.run(req).await;
376    attach_request_id_header(&mut response, &request_id_value);
377
378    framework_log_event(
379        "request_complete",
380        &[
381            ("request_id", request_id_value),
382            ("method", method),
383            ("path", path),
384            ("status", response.status().as_u16().to_string()),
385            ("duration_ms", started.elapsed().as_millis().to_string()),
386        ],
387    );
388
389    response
390}
391
392fn generate_request_id() -> String {
393    let sequence = NEXT_REQUEST_SEQUENCE.fetch_add(1, Ordering::Relaxed);
394    let millis = SystemTime::now()
395        .duration_since(UNIX_EPOCH)
396        .map(|duration| duration.as_millis())
397        .unwrap_or_default();
398    format!("req-{millis}-{sequence}")
399}
400
401fn attach_request_id_header(response: &mut Response, request_id: &str) {
402    if let Ok(value) = HeaderValue::from_str(request_id) {
403        response
404            .headers_mut()
405            .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
406    }
407}
408
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use axum::Json;
    use nestforge_core::{
        register_provider, ApiResult, AuthUser, Container, ControllerBasePath,
        ControllerDefinition, ExceptionFilter, HttpException, Inject, ModuleDefinition, Provider,
        RequestContext as FrameworkRequestContext, RouteBuilder,
    };
    use tower::ServiceExt;

    use super::*;

    /// Controller mounted at `/health`. It exercises request ids, exception filters,
    /// authentication and request-scoped injection.
    struct HealthController;

    /// Filter that rewrites any 400 into a fixed "filtered bad request" error and
    /// carries the original request id over.
    #[derive(Default)]
    struct RewriteBadRequestFilter;

    /// Request-scoped provider that remembers the path of the request that built it.
    struct RequestScopedService {
        path: String,
    }

    /// Minimal module: one request-scoped provider plus `HealthController`.
    struct TestModule;

    impl ControllerBasePath for HealthController {
        fn base_path() -> &'static str {
            "/health"
        }
    }

    impl HealthController {
        /// Echoes the request id assigned by the factory.
        async fn ok(request_id: RequestId) -> ApiResult<String> {
            let id = request_id.value().to_string();
            Ok(Json(id))
        }

        /// Always fails with a 400; relies on global filters only.
        async fn fail(request_id: RequestId) -> ApiResult<String> {
            let id = request_id.value().to_string();
            Err(HttpException::bad_request("broken request").with_request_id(id))
        }

        /// Always fails with a 400; its route registers a local filter.
        async fn fail_locally(request_id: RequestId) -> ApiResult<String> {
            let id = request_id.value().to_string();
            Err(HttpException::bad_request("local broken request").with_request_id(id))
        }

        /// Returns the subject of the authenticated caller.
        async fn me(user: AuthUser) -> ApiResult<String> {
            Ok(Json(user.subject.clone()))
        }

        /// Returns the path captured by the request-scoped provider.
        async fn scoped(service: Inject<RequestScopedService>) -> ApiResult<String> {
            Ok(Json(service.path.clone()))
        }
    }

    impl ControllerDefinition for HealthController {
        fn router() -> Router<Container> {
            let local_filters: Vec<Arc<dyn ExceptionFilter>> =
                vec![Arc::new(RewriteBadRequestFilter)];

            RouteBuilder::<Self>::new()
                .get("/", Self::ok)
                .get("/fail", Self::fail)
                .get_with_pipeline(
                    "/fail-local",
                    Self::fail_locally,
                    Vec::new(),
                    Vec::new(),
                    local_filters,
                    None,
                )
                .get("/me", Self::me)
                .get("/scoped", Self::scoped)
                .build()
        }
    }

    impl ExceptionFilter for RewriteBadRequestFilter {
        fn catch(&self, exception: HttpException, _ctx: &RequestContext) -> HttpException {
            if exception.status != axum::http::StatusCode::BAD_REQUEST {
                return exception;
            }
            HttpException::bad_request("filtered bad request")
                .with_optional_request_id(exception.request_id)
        }
    }

    impl ModuleDefinition for TestModule {
        fn register(container: &Container) -> Result<()> {
            let scoped_service = Provider::request_factory(|container| {
                let ctx = container.resolve::<FrameworkRequestContext>()?;
                Ok(RequestScopedService {
                    path: ctx.uri.path().to_string(),
                })
            });
            register_provider(container, scoped_service)?;
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HealthController::router()]
        }
    }

    /// Fresh factory for `TestModule`; panics if bootstrapping fails.
    fn factory() -> NestForgeFactory<TestModule> {
        NestForgeFactory::<TestModule>::create().expect("factory should build")
    }

    /// Starts a GET request for `uri`; callers may add headers before sending.
    fn get(uri: &str) -> axum::http::request::Builder {
        axum::http::Request::builder().uri(uri)
    }

    /// Finishes `builder` with an empty body and dispatches it through `app`.
    async fn send(app: Router, builder: axum::http::request::Builder) -> Response {
        let request = builder
            .body(axum::body::Body::empty())
            .expect("request should build");
        app.oneshot(request).await.expect("request should succeed")
    }

    #[tokio::test]
    async fn request_middleware_sets_request_id_header_and_extension() {
        let response = send(factory().into_router(), get("/health/")).await;

        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn error_responses_keep_request_id_header() {
        let app = factory()
            .use_exception_filter::<RewriteBadRequestFilter>()
            .into_router();

        let response = send(app, get("/health/fail")).await;

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn route_specific_exception_filters_rewrite_route_failures() {
        let response = send(factory().into_router(), get("/health/fail-local")).await;

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn auth_resolver_inserts_identity_for_auth_user_extractor() {
        let app = factory()
            .with_auth_resolver(|token, _container| async move {
                Ok(token.map(|_| AuthIdentity::new("demo-user").with_roles(["admin"])))
            })
            .into_router();
        let request = get("/health/me").header(axum::http::header::AUTHORIZATION, "Bearer demo-token");

        let response = send(app, request).await;

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    #[tokio::test]
    async fn request_scoped_provider_resolves_from_per_request_container() {
        let response = send(factory().into_router(), get("/health/scoped")).await;

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }
}