1use std::{
2 future::Future,
3 net::SocketAddr,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc, Mutex,
8 },
9 time::{Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Result;
13use axum::{
14 body::Body,
15 http::{header::HeaderName, HeaderValue},
16 middleware::from_fn,
17 response::{IntoResponse, Response},
18 Router,
19};
20use nestforge_core::{
21 apply_exception_filters, execute_pipeline, framework_log_event, initialize_module_runtime,
22 AuthIdentity, Container, ExceptionFilter, Guard, HttpException, InitializedModule, NextFn,
23 Interceptor, ModuleDefinition, RequestContext, RequestId,
24};
25
26use crate::middleware::{run_middleware_chain, MiddlewareBinding, MiddlewareConsumer, NestMiddleware};
27
28pub struct NestForgeFactory<M: ModuleDefinition> {
40 _marker: std::marker::PhantomData<M>,
41 container: Container,
42 runtime: Arc<InitializedModule>,
43 controllers: Vec<Router<Container>>,
44 extra_routers: Vec<Router<Container>>,
45 global_prefix: Option<String>,
46 version: Option<String>,
47 auth_resolver: Option<Arc<AuthResolver>>,
48 global_guards: Vec<Arc<dyn Guard>>,
49 global_interceptors: Vec<Arc<dyn Interceptor>>,
50 global_exception_filters: Vec<Arc<dyn ExceptionFilter>>,
51 middleware_bindings: Vec<MiddlewareBinding>,
52}
53
54type AuthFuture = Pin<Box<dyn Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send>>;
55type AuthResolver = dyn Fn(Option<String>, Container) -> AuthFuture + Send + Sync;
56
57impl<M: ModuleDefinition> NestForgeFactory<M> {
58 pub fn create() -> Result<Self> {
59 let container = Container::new();
60 let runtime = Arc::new(initialize_module_runtime::<M>(&container)?);
61 runtime.run_module_init(&container)?;
62 runtime.run_application_bootstrap(&container)?;
63 let controllers = runtime.controllers.clone();
64
65 Ok(Self {
66 _marker: std::marker::PhantomData,
67 container,
68 runtime,
69 controllers,
70 extra_routers: Vec::new(),
71 global_prefix: None,
72 version: None,
73 auth_resolver: None,
74 global_guards: Vec::new(),
75 global_interceptors: Vec::new(),
76 global_exception_filters: Vec::new(),
77 middleware_bindings: Vec::new(),
78 })
79 }
80
81 pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
82 let prefix = prefix.into().trim().trim_matches('/').to_string();
83 if !prefix.is_empty() {
84 framework_log_event(
85 "global_prefix_configured",
86 &[("prefix", prefix.clone())],
87 );
88 self.global_prefix = Some(prefix);
89 }
90 self
91 }
92
93 pub fn with_version(mut self, version: impl Into<String>) -> Self {
94 let version = version.into().trim().trim_matches('/').to_string();
95 if !version.is_empty() {
96 framework_log_event(
97 "api_version_configured",
98 &[("version", version.clone())],
99 );
100 self.version = Some(version);
101 }
102 self
103 }
104
105 pub fn use_guard<G>(mut self) -> Self
106 where
107 G: Guard + Default,
108 {
109 framework_log_event(
110 "global_guard_register",
111 &[("guard", std::any::type_name::<G>().to_string())],
112 );
113 self.global_guards.push(Arc::new(G::default()));
114 self
115 }
116
117 pub fn use_interceptor<I>(mut self) -> Self
118 where
119 I: Interceptor + Default,
120 {
121 framework_log_event(
122 "global_interceptor_register",
123 &[("interceptor", std::any::type_name::<I>().to_string())],
124 );
125 self.global_interceptors.push(Arc::new(I::default()));
126 self
127 }
128
129 pub fn use_exception_filter<F>(mut self) -> Self
130 where
131 F: ExceptionFilter + Default,
132 {
133 framework_log_event(
134 "global_exception_filter_register",
135 &[("filter", std::any::type_name::<F>().to_string())],
136 );
137 self.global_exception_filters.push(Arc::new(F::default()));
138 self
139 }
140
141 pub fn use_middleware<T>(mut self) -> Self
142 where
143 T: NestMiddleware + Default,
144 {
145 let mut consumer = MiddlewareConsumer::new();
146 consumer.apply::<T>().for_all_routes();
147 self.middleware_bindings.extend(consumer.into_bindings());
148 self
149 }
150
151 pub fn configure_middleware<F>(mut self, configure: F) -> Self
152 where
153 F: FnOnce(&mut MiddlewareConsumer),
154 {
155 let mut consumer = MiddlewareConsumer::new();
156 configure(&mut consumer);
157 self.middleware_bindings.extend(consumer.into_bindings());
158 self
159 }
160
161 pub fn with_auth_resolver<F, Fut>(mut self, resolver: F) -> Self
162 where
163 F: Fn(Option<String>, Container) -> Fut + Send + Sync + 'static,
164 Fut: Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send + 'static,
165 {
166 self.auth_resolver = Some(Arc::new(move |token, container| {
167 Box::pin(resolver(token, container))
168 }));
169 self
170 }
171
172 pub fn merge_router(mut self, router: Router<Container>) -> Self {
173 self.extra_routers.push(router);
174 self
175 }
176
177 pub fn container(&self) -> &Container {
178 &self.container
179 }
180
181 pub fn into_router(self) -> Router {
182 let mut app: Router<Container> = Router::new();
187
188 for controller_router in self.controllers {
192 app = app.merge(controller_router);
193 }
194 for extra_router in self.extra_routers {
195 app = app.merge(extra_router);
196 }
197
198 if let Some(version) = &self.version {
199 app = Router::new().nest(&format!("/{}", version), app);
200 }
201
202 if let Some(prefix) = &self.global_prefix {
203 app = Router::new().nest(&format!("/{}", prefix), app);
204 }
205
206 let global_guards = Arc::new(self.global_guards);
207 let global_interceptors = Arc::new(self.global_interceptors);
208 let global_exception_filters = Arc::new(self.global_exception_filters);
209 let middleware_bindings = Arc::new(self.middleware_bindings);
210 let auth_resolver = self.auth_resolver.clone();
211 let request_container = self.container.clone();
212
213 let route_exception_filters = Arc::clone(&global_exception_filters);
214 let app = app.route_layer(from_fn(move |req, next| {
215 let guards = Arc::clone(&global_guards);
216 let interceptors = Arc::clone(&global_interceptors);
217 let filters = Arc::clone(&route_exception_filters);
218 async move { execute_pipeline(req, next, guards, interceptors, filters).await }
219 }));
220
221 let app = app.layer(from_fn(move |req: axum::extract::Request, next: axum::middleware::Next| {
222 let middlewares = Arc::clone(&middleware_bindings);
223 async move {
224 if middlewares.is_empty() {
225 return next.run(req).await;
226 }
227
228 let terminal = next_to_fn(next);
229 run_middleware_chain(middlewares, 0, req, terminal).await
230 }
231 }));
232
233 Router::new()
234 .merge(app)
235 .layer(from_fn(move |req, next| {
236 let auth_resolver = auth_resolver.clone();
237 let request_container = request_container.clone();
238 let exception_filters = Arc::clone(&global_exception_filters);
239 async move {
240 request_context_middleware(
241 req,
242 next,
243 request_container,
244 auth_resolver,
245 exception_filters,
246 )
247 .await
248 }
249 }))
250 .with_state(self.container)
251 }
252
253 pub async fn listen(self, port: u16) -> Result<()> {
254 let runtime = Arc::clone(&self.runtime);
255 let container = self.container.clone();
256 let app = self.into_router();
257
258 let addr = SocketAddr::from(([127, 0, 0, 1], port));
259 let listener = tokio::net::TcpListener::bind(addr).await?;
260
261 framework_log_event("server_listening", &[("addr", addr.to_string())]);
262
263 axum::serve(listener, app).await?;
264 runtime.run_module_destroy(&container)?;
265 runtime.run_application_shutdown(&container)?;
266 Ok(())
267 }
268}
269
/// Response header that carries the id assigned to each request.
///
/// Must stay lowercase: it is passed to `HeaderName::from_static`.
const REQUEST_ID_HEADER: &str = "x-request-id";

/// Process-wide counter appended to generated request ids; starts at 1.
static NEXT_REQUEST_SEQUENCE: AtomicU64 = AtomicU64::new(1);
272
273fn next_to_fn(next: axum::middleware::Next) -> NextFn {
274 let next = Arc::new(Mutex::new(Some(next)));
275
276 Arc::new(move |req: axum::extract::Request<Body>| {
277 let next = Arc::clone(&next);
278 Box::pin(async move {
279 let next = {
280 let mut guard = match next.lock() {
281 Ok(guard) => guard,
282 Err(_) => {
283 return HttpException::internal_server_error("Middleware lock poisoned")
284 .into_response();
285 }
286 };
287 guard.take()
288 };
289
290 match next {
291 Some(next) => next.run(req).await,
292 None => HttpException::internal_server_error("Middleware next called multiple times")
293 .into_response(),
294 }
295 })
296 })
297}
298
299async fn request_context_middleware(
300 mut req: axum::extract::Request,
301 next: axum::middleware::Next,
302 container: Container,
303 auth_resolver: Option<Arc<AuthResolver>>,
304 exception_filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
305) -> Response {
306 let scoped_container = container.scoped();
307 let request_id = RequestId::new(generate_request_id());
308 let request_id_value = request_id.value().to_string();
309 let method = req.method().to_string();
310 let path = req.uri().path().to_string();
311 let started = Instant::now();
312 let bearer_token = req
313 .headers()
314 .get(axum::http::header::AUTHORIZATION)
315 .and_then(|value| value.to_str().ok())
316 .and_then(|value| value.strip_prefix("Bearer "))
317 .map(str::trim)
318 .filter(|value| !value.is_empty())
319 .map(str::to_string);
320
321 req.extensions_mut().insert(scoped_container.clone());
322 req.extensions_mut().insert(request_id.clone());
323 let _ = scoped_container.override_value(request_id.clone());
324 framework_log_event(
325 "request_start",
326 &[
327 ("request_id", request_id_value.clone()),
328 ("method", method.clone()),
329 ("path", path.clone()),
330 ],
331 );
332
333 if let Some(resolver) = auth_resolver {
334 match resolver(bearer_token, container).await {
335 Ok(Some(identity)) => {
336 framework_log_event(
337 "auth_identity_resolved",
338 &[
339 ("request_id", request_id_value.clone()),
340 ("subject", identity.subject.clone()),
341 ],
342 );
343 let _ = scoped_container.override_value(identity.clone());
344 req.extensions_mut().insert(Arc::new(identity));
345 }
346 Ok(None) => {}
347 Err(err) => {
348 let ctx = RequestContext::from_request(&req);
349 let _ = scoped_container.override_value(ctx.clone());
350 let mut response = apply_exception_filters(
351 err.with_request_id(request_id_value.clone()),
352 &ctx,
353 exception_filters.as_slice(),
354 )
355 .into_response();
356 attach_request_id_header(&mut response, &request_id_value);
357 framework_log_event(
358 "request_complete",
359 &[
360 ("request_id", request_id_value),
361 ("method", method),
362 ("path", path),
363 ("status", response.status().as_u16().to_string()),
364 ("duration_ms", started.elapsed().as_millis().to_string()),
365 ],
366 );
367 return response;
368 }
369 }
370 }
371
372 let ctx = RequestContext::from_request(&req);
373 let _ = scoped_container.override_value(ctx);
374
375 let mut response = next.run(req).await;
376 attach_request_id_header(&mut response, &request_id_value);
377
378 framework_log_event(
379 "request_complete",
380 &[
381 ("request_id", request_id_value),
382 ("method", method),
383 ("path", path),
384 ("status", response.status().as_u16().to_string()),
385 ("duration_ms", started.elapsed().as_millis().to_string()),
386 ],
387 );
388
389 response
390}
391
392fn generate_request_id() -> String {
393 let sequence = NEXT_REQUEST_SEQUENCE.fetch_add(1, Ordering::Relaxed);
394 let millis = SystemTime::now()
395 .duration_since(UNIX_EPOCH)
396 .map(|duration| duration.as_millis())
397 .unwrap_or_default();
398 format!("req-{millis}-{sequence}")
399}
400
401fn attach_request_id_header(response: &mut Response, request_id: &str) {
402 if let Ok(value) = HeaderValue::from_str(request_id) {
403 response.headers_mut().insert(
404 HeaderName::from_static(REQUEST_ID_HEADER),
405 value,
406 );
407 }
408}
409
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use axum::Json;
    use nestforge_core::{
        register_provider, ApiResult, AuthUser, Container, ControllerBasePath, ControllerDefinition,
        ExceptionFilter, HttpException, Inject, ModuleDefinition, Provider,
        RequestContext as FrameworkRequestContext, RouteBuilder,
    };
    use tower::ServiceExt;

    use super::*;

    struct HealthController;

    /// Exception filter that replaces any 400 with a fixed "filtered bad request" message,
    /// keeping the original request id.
    #[derive(Default)]
    struct RewriteBadRequestFilter;

    /// Request-scoped provider that captures the path of the request it was built for.
    struct RequestScopedService {
        path: String,
    }

    impl ControllerBasePath for HealthController {
        fn base_path() -> &'static str {
            "/health"
        }
    }

    impl HealthController {
        async fn ok(request_id: RequestId) -> ApiResult<String> {
            Ok(Json(request_id.value().to_string()))
        }

        async fn fail(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn fail_locally(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("local broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn me(user: AuthUser) -> ApiResult<String> {
            Ok(Json(user.subject.clone()))
        }

        async fn scoped(service: Inject<RequestScopedService>) -> ApiResult<String> {
            Ok(Json(service.path.clone()))
        }
    }

    impl ControllerDefinition for HealthController {
        fn router() -> Router<Container> {
            RouteBuilder::<Self>::new()
                .get("/", Self::ok)
                .get("/fail", Self::fail)
                .get_with_pipeline(
                    "/fail-local",
                    Self::fail_locally,
                    Vec::new(),
                    Vec::new(),
                    vec![Arc::new(RewriteBadRequestFilter) as Arc<dyn ExceptionFilter>],
                    None,
                )
                .get("/me", Self::me)
                .get("/scoped", Self::scoped)
                .build()
        }
    }

    impl ExceptionFilter for RewriteBadRequestFilter {
        fn catch(&self, exception: HttpException, _ctx: &RequestContext) -> HttpException {
            if exception.status == axum::http::StatusCode::BAD_REQUEST {
                HttpException::bad_request("filtered bad request")
                    .with_optional_request_id(exception.request_id)
            } else {
                exception
            }
        }
    }

    struct TestModule;

    impl ModuleDefinition for TestModule {
        fn register(container: &Container) -> Result<()> {
            register_provider(
                container,
                Provider::request_factory(|container| {
                    let ctx = container.resolve::<FrameworkRequestContext>()?;
                    Ok(RequestScopedService {
                        path: ctx.uri.path().to_string(),
                    })
                }),
            )?;
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HealthController::router()]
        }
    }

    /// Builds an empty-bodied GET request for `uri`.
    fn get(uri: &str) -> axum::http::Request<axum::body::Body> {
        axum::http::Request::builder()
            .uri(uri)
            .body(axum::body::Body::empty())
            .expect("request should build")
    }

    /// Collects a response body into a UTF-8 string.
    async fn body_text(response: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        String::from_utf8(bytes.to_vec()).expect("body should be valid UTF-8")
    }

    /// Returns the `x-request-id` header value, panicking if it is missing.
    fn request_id_header(response: &axum::response::Response) -> String {
        response
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("x-request-id header should be present")
            .to_str()
            .expect("x-request-id header should be ASCII")
            .to_string()
    }

    #[tokio::test]
    async fn request_middleware_sets_request_id_header_and_extension() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get("/health/"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        // The handler echoes the `RequestId` it extracted; it must match the response header.
        let header = request_id_header(&response);
        assert_eq!(body_text(response).await, format!("\"{header}\""));
    }

    #[tokio::test]
    async fn error_responses_keep_request_id_header() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .use_exception_filter::<RewriteBadRequestFilter>()
            .into_router();

        let response = app
            .oneshot(get("/health/fail"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
        // The globally registered filter must have rewritten the handler's message.
        let body = body_text(response).await;
        assert!(body.contains("filtered bad request"), "unexpected body: {body}");
        assert!(!body.contains("broken request"), "unexpected body: {body}");
    }

    #[tokio::test]
    async fn route_specific_exception_filters_rewrite_route_failures() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get("/health/fail-local"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
        let body = body_text(response).await;
        assert!(body.contains("filtered bad request"), "unexpected body: {body}");
        assert!(!body.contains("local broken request"), "unexpected body: {body}");
    }

    #[tokio::test]
    async fn auth_resolver_inserts_identity_for_auth_user_extractor() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .with_auth_resolver(|token, _container| async move {
                Ok(token.map(|_| AuthIdentity::new("demo-user").with_roles(["admin"])))
            })
            .into_router();

        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .uri("/health/me")
                    .header(axum::http::header::AUTHORIZATION, "Bearer demo-token")
                    .body(axum::body::Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        assert_eq!(body_text(response).await, "\"demo-user\"");
    }

    #[tokio::test]
    async fn request_scoped_provider_resolves_from_per_request_container() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get("/health/scoped"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        // The request-scoped provider was built from this request's context.
        assert_eq!(body_text(response).await, "\"/health/scoped\"");
    }

    #[tokio::test]
    async fn global_prefix_and_version_nest_routes() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .with_global_prefix("/api/")
            .with_version("v1")
            .into_router();

        let response = app
            .clone()
            .oneshot(get("/api/v1/health/"))
            .await
            .expect("request should succeed");
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let response = app
            .oneshot(get("/health/"))
            .await
            .expect("request should succeed");
        assert_eq!(response.status(), axum::http::StatusCode::NOT_FOUND);
    }
}