1use std::{
2 future::Future,
3 net::SocketAddr,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc, Mutex,
8 },
9 time::{Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Result;
13use axum::{
14 body::Body,
15 http::{header::HeaderName, HeaderValue},
16 middleware::from_fn,
17 response::{IntoResponse, Response},
18 Router,
19};
20use nestforge_core::{
21 apply_exception_filters, execute_pipeline, framework_log_event, initialize_module_runtime,
22 AuthIdentity, Container, ExceptionFilter, Guard, HttpException, InitializedModule, Interceptor,
23 ModuleDefinition, NextFn, RequestContext, RequestId,
24};
25
26use crate::middleware::{
27 run_middleware_chain, MiddlewareBinding, MiddlewareConsumer, NestMiddleware,
28};
29
30pub struct NestForgeFactory<M: ModuleDefinition> {
51 _marker: std::marker::PhantomData<M>,
52 container: Container,
53 runtime: Arc<InitializedModule>,
54 controllers: Vec<Router<Container>>,
55 extra_routers: Vec<Router<Container>>,
56 global_prefix: Option<String>,
57 version: Option<String>,
58 auth_resolver: Option<Arc<AuthResolver>>,
59 global_guards: Vec<Arc<dyn Guard>>,
60 global_interceptors: Vec<Arc<dyn Interceptor>>,
61 global_exception_filters: Vec<Arc<dyn ExceptionFilter>>,
62 middleware_bindings: Vec<MiddlewareBinding>,
63}
64
65type AuthFuture = Pin<Box<dyn Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send>>;
66type AuthResolver = dyn Fn(Option<String>, Container) -> AuthFuture + Send + Sync;
67
68impl<M: ModuleDefinition> NestForgeFactory<M> {
69 pub fn create() -> Result<Self> {
73 let container = Container::new();
74 let runtime = Arc::new(initialize_module_runtime::<M>(&container)?);
75 runtime.run_module_init(&container)?;
76 runtime.run_application_bootstrap(&container)?;
77 let controllers = runtime.controllers.clone();
78
79 Ok(Self {
80 _marker: std::marker::PhantomData,
81 container,
82 runtime,
83 controllers,
84 extra_routers: Vec::new(),
85 global_prefix: None,
86 version: None,
87 auth_resolver: None,
88 global_guards: Vec::new(),
89 global_interceptors: Vec::new(),
90 global_exception_filters: Vec::new(),
91 middleware_bindings: Vec::new(),
92 })
93 }
94
95 pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
97 let prefix = prefix.into().trim().trim_matches('/').to_string();
98 if !prefix.is_empty() {
99 framework_log_event("global_prefix_configured", &[("prefix", prefix.clone())]);
100 self.global_prefix = Some(prefix);
101 }
102 self
103 }
104
105 pub fn with_version(mut self, version: impl Into<String>) -> Self {
107 let version = version.into().trim().trim_matches('/').to_string();
108 if !version.is_empty() {
109 framework_log_event("api_version_configured", &[("version", version.clone())]);
110 self.version = Some(version);
111 }
112 self
113 }
114
115 pub fn use_guard<G>(mut self) -> Self
119 where
120 G: Guard + Default,
121 {
122 framework_log_event(
123 "global_guard_register",
124 &[("guard", std::any::type_name::<G>().to_string())],
125 );
126 self.global_guards.push(Arc::new(G::default()));
127 self
128 }
129
130 pub fn use_interceptor<I>(mut self) -> Self
134 where
135 I: Interceptor + Default,
136 {
137 framework_log_event(
138 "global_interceptor_register",
139 &[("interceptor", std::any::type_name::<I>().to_string())],
140 );
141 self.global_interceptors.push(Arc::new(I::default()));
142 self
143 }
144
145 pub fn use_exception_filter<F>(mut self) -> Self
149 where
150 F: ExceptionFilter + Default,
151 {
152 framework_log_event(
153 "global_exception_filter_register",
154 &[("filter", std::any::type_name::<F>().to_string())],
155 );
156 self.global_exception_filters.push(Arc::new(F::default()));
157 self
158 }
159
160 pub fn use_middleware<T>(mut self) -> Self
164 where
165 T: NestMiddleware + Default,
166 {
167 let mut consumer = MiddlewareConsumer::new();
168 consumer.apply::<T>().for_all_routes();
169 self.middleware_bindings.extend(consumer.into_bindings());
170 self
171 }
172
173 pub fn configure_middleware<F>(mut self, configure: F) -> Self
175 where
176 F: FnOnce(&mut MiddlewareConsumer),
177 {
178 let mut consumer = MiddlewareConsumer::new();
179 configure(&mut consumer);
180 self.middleware_bindings.extend(consumer.into_bindings());
181 self
182 }
183
184 pub fn with_auth_resolver<F, Fut>(mut self, resolver: F) -> Self
189 where
190 F: Fn(Option<String>, Container) -> Fut + Send + Sync + 'static,
191 Fut: Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send + 'static,
192 {
193 self.auth_resolver = Some(Arc::new(move |token, container| {
194 Box::pin(resolver(token, container))
195 }));
196 self
197 }
198
199 pub fn merge_router(mut self, router: Router<Container>) -> Self {
203 self.extra_routers.push(router);
204 self
205 }
206
207 pub fn container(&self) -> &Container {
209 &self.container
210 }
211
212 pub fn into_router(self) -> Router {
216 let mut app: Router<Container> = Router::new();
221
222 for controller_router in self.controllers {
226 app = app.merge(controller_router);
227 }
228 for extra_router in self.extra_routers {
229 app = app.merge(extra_router);
230 }
231
232 if let Some(version) = &self.version {
233 app = Router::new().nest(&format!("/{}", version), app);
234 }
235
236 if let Some(prefix) = &self.global_prefix {
237 app = Router::new().nest(&format!("/{}", prefix), app);
238 }
239
240 let global_guards = Arc::new(self.global_guards);
241 let global_interceptors = Arc::new(self.global_interceptors);
242 let global_exception_filters = Arc::new(self.global_exception_filters);
243 let middleware_bindings = Arc::new(self.middleware_bindings);
244 let auth_resolver = self.auth_resolver.clone();
245 let request_container = self.container.clone();
246
247 let route_exception_filters = Arc::clone(&global_exception_filters);
248 let app = app.route_layer(from_fn(move |req, next| {
249 let guards = Arc::clone(&global_guards);
250 let interceptors = Arc::clone(&global_interceptors);
251 let filters = Arc::clone(&route_exception_filters);
252 async move { execute_pipeline(req, next, guards, interceptors, filters).await }
253 }));
254
255 let app = app.layer(from_fn(
256 move |req: axum::extract::Request, next: axum::middleware::Next| {
257 let middlewares = Arc::clone(&middleware_bindings);
258 async move {
259 if middlewares.is_empty() {
260 return next.run(req).await;
261 }
262
263 let terminal = next_to_fn(next);
264 run_middleware_chain(middlewares, 0, req, terminal).await
265 }
266 },
267 ));
268
269 Router::new()
270 .merge(app)
271 .layer(from_fn(move |req, next| {
272 let auth_resolver = auth_resolver.clone();
273 let request_container = request_container.clone();
274 let exception_filters = Arc::clone(&global_exception_filters);
275 async move {
276 request_context_middleware(
277 req,
278 next,
279 request_container,
280 auth_resolver,
281 exception_filters,
282 )
283 .await
284 }
285 }))
286 .with_state(self.container)
287 }
288
289 pub async fn listen(self, port: u16) -> Result<()> {
294 let runtime = Arc::clone(&self.runtime);
295 let container = self.container.clone();
296 let app = self.into_router();
297
298 let addr = SocketAddr::from(([127, 0, 0, 1], port));
299 let listener = tokio::net::TcpListener::bind(addr).await?;
300
301 framework_log_event("server_listening", &[("addr", addr.to_string())]);
302
303 axum::serve(listener, app).await?;
304 runtime.run_module_destroy(&container)?;
305 runtime.run_application_shutdown(&container)?;
306 Ok(())
307 }
308}
309
/// Name of the response header that carries the per-request id.
const REQUEST_ID_HEADER: &str = "x-request-id";

/// Process-wide counter used as the trailing component of generated request ids.
/// Starts at 1.
static NEXT_REQUEST_SEQUENCE: AtomicU64 = AtomicU64::new(1);
312
313fn next_to_fn(next: axum::middleware::Next) -> NextFn {
314 let next = Arc::new(Mutex::new(Some(next)));
315
316 Arc::new(move |req: axum::extract::Request<Body>| {
317 let next = Arc::clone(&next);
318 Box::pin(async move {
319 let next = {
320 let mut guard = match next.lock() {
321 Ok(guard) => guard,
322 Err(_) => {
323 return HttpException::internal_server_error("Middleware lock poisoned")
324 .into_response();
325 }
326 };
327 guard.take()
328 };
329
330 match next {
331 Some(next) => next.run(req).await,
332 None => {
333 HttpException::internal_server_error("Middleware next called multiple times")
334 .into_response()
335 }
336 }
337 })
338 })
339}
340
341async fn request_context_middleware(
342 mut req: axum::extract::Request,
343 next: axum::middleware::Next,
344 container: Container,
345 auth_resolver: Option<Arc<AuthResolver>>,
346 exception_filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
347) -> Response {
348 let scoped_container = container.scoped();
349 let request_id = RequestId::new(generate_request_id());
350 let request_id_value = request_id.value().to_string();
351 let method = req.method().to_string();
352 let path = req.uri().path().to_string();
353 let started = Instant::now();
354 let bearer_token = req
355 .headers()
356 .get(axum::http::header::AUTHORIZATION)
357 .and_then(|value| value.to_str().ok())
358 .and_then(|value| value.strip_prefix("Bearer "))
359 .map(str::trim)
360 .filter(|value| !value.is_empty())
361 .map(str::to_string);
362
363 req.extensions_mut().insert(scoped_container.clone());
364 req.extensions_mut().insert(request_id.clone());
365 let _ = scoped_container.override_value(request_id.clone());
366 framework_log_event(
367 "request_start",
368 &[
369 ("request_id", request_id_value.clone()),
370 ("method", method.clone()),
371 ("path", path.clone()),
372 ],
373 );
374
375 if let Some(resolver) = auth_resolver {
376 match resolver(bearer_token, container).await {
377 Ok(Some(identity)) => {
378 framework_log_event(
379 "auth_identity_resolved",
380 &[
381 ("request_id", request_id_value.clone()),
382 ("subject", identity.subject.clone()),
383 ],
384 );
385 let _ = scoped_container.override_value(identity.clone());
386 req.extensions_mut().insert(Arc::new(identity));
387 }
388 Ok(None) => {}
389 Err(err) => {
390 let ctx = RequestContext::from_request(&req);
391 let _ = scoped_container.override_value(ctx.clone());
392 let mut response = apply_exception_filters(
393 err.with_request_id(request_id_value.clone()),
394 &ctx,
395 exception_filters.as_slice(),
396 )
397 .into_response();
398 attach_request_id_header(&mut response, &request_id_value);
399 framework_log_event(
400 "request_complete",
401 &[
402 ("request_id", request_id_value),
403 ("method", method),
404 ("path", path),
405 ("status", response.status().as_u16().to_string()),
406 ("duration_ms", started.elapsed().as_millis().to_string()),
407 ],
408 );
409 return response;
410 }
411 }
412 }
413
414 let ctx = RequestContext::from_request(&req);
415 let _ = scoped_container.override_value(ctx);
416
417 let mut response = next.run(req).await;
418 attach_request_id_header(&mut response, &request_id_value);
419
420 framework_log_event(
421 "request_complete",
422 &[
423 ("request_id", request_id_value),
424 ("method", method),
425 ("path", path),
426 ("status", response.status().as_u16().to_string()),
427 ("duration_ms", started.elapsed().as_millis().to_string()),
428 ],
429 );
430
431 response
432}
433
434fn generate_request_id() -> String {
435 let sequence = NEXT_REQUEST_SEQUENCE.fetch_add(1, Ordering::Relaxed);
436 let millis = SystemTime::now()
437 .duration_since(UNIX_EPOCH)
438 .map(|duration| duration.as_millis())
439 .unwrap_or_default();
440 format!("req-{millis}-{sequence}")
441}
442
443fn attach_request_id_header(response: &mut Response, request_id: &str) {
444 if let Ok(value) = HeaderValue::from_str(request_id) {
445 response
446 .headers_mut()
447 .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
448 }
449}
450
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use axum::Json;
    use nestforge_core::{
        register_provider, ApiResult, AuthUser, Container, ControllerBasePath,
        ControllerDefinition, ExceptionFilter, HttpException, Inject, ModuleDefinition, Provider,
        RequestContext as FrameworkRequestContext, RouteBuilder,
    };
    use tower::ServiceExt;

    use super::*;

    // ----- fixtures ------------------------------------------------------------------

    /// Controller mounted at `/health` covering ids, errors, auth, and request scope.
    struct HealthController;

    /// Filter that replaces any 400 with a fixed "filtered bad request" exception.
    #[derive(Default)]
    struct RewriteBadRequestFilter;

    /// Request-scoped provider that records the path of the current request.
    struct RequestScopedService {
        path: String,
    }

    /// Root module registering the request-scoped provider and the health controller.
    struct TestModule;

    impl ControllerBasePath for HealthController {
        fn base_path() -> &'static str {
            "/health"
        }
    }

    impl HealthController {
        async fn ok(request_id: RequestId) -> ApiResult<String> {
            Ok(Json(request_id.value().to_string()))
        }

        async fn fail(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn fail_locally(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("local broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn me(user: AuthUser) -> ApiResult<String> {
            Ok(Json(user.subject.clone()))
        }

        async fn scoped(service: Inject<RequestScopedService>) -> ApiResult<String> {
            Ok(Json(service.path.clone()))
        }
    }

    impl ControllerDefinition for HealthController {
        fn router() -> Router<Container> {
            let local_filters = vec![Arc::new(RewriteBadRequestFilter) as Arc<dyn ExceptionFilter>];
            RouteBuilder::<Self>::new()
                .get("/", Self::ok)
                .get("/fail", Self::fail)
                .get_with_pipeline(
                    "/fail-local",
                    Self::fail_locally,
                    Vec::new(),
                    Vec::new(),
                    local_filters,
                    None,
                )
                .get("/me", Self::me)
                .get("/scoped", Self::scoped)
                .build()
        }
    }

    impl ExceptionFilter for RewriteBadRequestFilter {
        fn catch(&self, exception: HttpException, _ctx: &RequestContext) -> HttpException {
            if exception.status != axum::http::StatusCode::BAD_REQUEST {
                return exception;
            }
            HttpException::bad_request("filtered bad request")
                .with_optional_request_id(exception.request_id)
        }
    }

    impl ModuleDefinition for TestModule {
        fn register(container: &Container) -> Result<()> {
            let scoped_service = Provider::request_factory(|container| {
                let ctx = container.resolve::<FrameworkRequestContext>()?;
                Ok(RequestScopedService {
                    path: ctx.uri.path().to_string(),
                })
            });
            register_provider(container, scoped_service)?;
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HealthController::router()]
        }
    }

    // ----- helpers -------------------------------------------------------------------

    /// Builds the factory for [`TestModule`], panicking if bootstrapping fails.
    fn factory() -> NestForgeFactory<TestModule> {
        NestForgeFactory::<TestModule>::create().expect("factory should build")
    }

    /// Sends an empty-bodied request to `uri` through `app`.
    ///
    /// The request uses the builder's default method. If `authorization` is given, it is
    /// sent as the `Authorization` header.
    async fn send(app: Router, uri: &str, authorization: Option<&str>) -> Response {
        let mut builder = axum::http::Request::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header(axum::http::header::AUTHORIZATION, value);
        }
        let request = builder
            .body(axum::body::Body::empty())
            .expect("request should build");
        app.oneshot(request).await.expect("request should succeed")
    }

    // ----- tests ---------------------------------------------------------------------

    #[tokio::test]
    async fn request_middleware_sets_request_id_header_and_extension() {
        let response = send(factory().into_router(), "/health/", None).await;

        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn error_responses_keep_request_id_header() {
        let app = factory()
            .use_exception_filter::<RewriteBadRequestFilter>()
            .into_router();

        let response = send(app, "/health/fail", None).await;

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn route_specific_exception_filters_rewrite_route_failures() {
        let response = send(factory().into_router(), "/health/fail-local", None).await;

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn auth_resolver_inserts_identity_for_auth_user_extractor() {
        let app = factory()
            .with_auth_resolver(|token, _container| async move {
                Ok(token.map(|_| AuthIdentity::new("demo-user").with_roles(["admin"])))
            })
            .into_router();

        let response = send(app, "/health/me", Some("Bearer demo-token")).await;

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }

    #[tokio::test]
    async fn request_scoped_provider_resolves_from_per_request_container() {
        let response = send(factory().into_router(), "/health/scoped", None).await;

        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }
}