1use std::{
2 future::Future,
3 net::SocketAddr,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc, Mutex,
8 },
9 time::{Instant, SystemTime, UNIX_EPOCH},
10};
11
12use anyhow::Result;
13use axum::{
14 body::Body,
15 http::{header::HeaderName, HeaderValue},
16 middleware::from_fn,
17 response::{IntoResponse, Response},
18 Router,
19};
20use nestforge_core::{
21 apply_exception_filters, execute_pipeline, framework_log_event, initialize_module_runtime,
22 AuthIdentity, Container, ExceptionFilter, Guard, HttpException, InitializedModule, Interceptor,
23 ModuleDefinition, NextFn, RequestContext, RequestId,
24};
25
26use crate::middleware::{
27 run_middleware_chain, MiddlewareBinding, MiddlewareConsumer, NestMiddleware,
28};
29
/// Builder that bootstraps the root module `M` and assembles it into an axum application.
///
/// Obtain one with [`NestForgeFactory::create`], configure it with the chainable
/// `with_*` / `use_*` / `configure_*` methods, and finish with
/// [`NestForgeFactory::into_router`] or [`NestForgeFactory::listen`].
pub struct NestForgeFactory<M: ModuleDefinition> {
    /// Ties the factory to its root module type without storing a value of it.
    _marker: std::marker::PhantomData<M>,
    /// Root container passed to module initialization and used as the router state.
    container: Container,
    /// Initialized module runtime; `listen` uses it to run the destroy/shutdown hooks.
    runtime: Arc<InitializedModule>,
    /// Controller routers copied from the initialized runtime in `create`.
    controllers: Vec<Router<Container>>,
    /// Additional routers registered through `merge_router`.
    extra_routers: Vec<Router<Container>>,
    /// Global route prefix with surrounding whitespace and slashes stripped, if configured.
    global_prefix: Option<String>,
    /// API version segment with surrounding whitespace and slashes stripped, if configured.
    version: Option<String>,
    /// Optional hook that turns the request's bearer token into an [`AuthIdentity`].
    auth_resolver: Option<Arc<AuthResolver>>,
    /// Guards handed to the per-route pipeline for every matched route.
    global_guards: Vec<Arc<dyn Guard>>,
    /// Interceptors handed to the per-route pipeline for every matched route.
    global_interceptors: Vec<Arc<dyn Interceptor>>,
    /// Exception filters applied to route failures and to auth-resolution errors.
    global_exception_filters: Vec<Arc<dyn ExceptionFilter>>,
    /// Middleware registrations collected from `use_middleware` / `configure_middleware`.
    middleware_bindings: Vec<MiddlewareBinding>,
}

/// Boxed, sendable future produced by an auth resolver.
type AuthFuture = Pin<Box<dyn Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send>>;
/// Type-erased auth resolver: receives the bearer token from the `Authorization` header
/// (if any) and the root container, and yields the caller's identity or an HTTP error.
type AuthResolver = dyn Fn(Option<String>, Container) -> AuthFuture + Send + Sync;
58
59impl<M: ModuleDefinition> NestForgeFactory<M> {
60 pub fn create() -> Result<Self> {
61 let container = Container::new();
62 let runtime = Arc::new(initialize_module_runtime::<M>(&container)?);
63 runtime.run_module_init(&container)?;
64 runtime.run_application_bootstrap(&container)?;
65 let controllers = runtime.controllers.clone();
66
67 Ok(Self {
68 _marker: std::marker::PhantomData,
69 container,
70 runtime,
71 controllers,
72 extra_routers: Vec::new(),
73 global_prefix: None,
74 version: None,
75 auth_resolver: None,
76 global_guards: Vec::new(),
77 global_interceptors: Vec::new(),
78 global_exception_filters: Vec::new(),
79 middleware_bindings: Vec::new(),
80 })
81 }
82
83 pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
84 let prefix = prefix.into().trim().trim_matches('/').to_string();
85 if !prefix.is_empty() {
86 framework_log_event("global_prefix_configured", &[("prefix", prefix.clone())]);
87 self.global_prefix = Some(prefix);
88 }
89 self
90 }
91
92 pub fn with_version(mut self, version: impl Into<String>) -> Self {
93 let version = version.into().trim().trim_matches('/').to_string();
94 if !version.is_empty() {
95 framework_log_event("api_version_configured", &[("version", version.clone())]);
96 self.version = Some(version);
97 }
98 self
99 }
100
101 pub fn use_guard<G>(mut self) -> Self
102 where
103 G: Guard + Default,
104 {
105 framework_log_event(
106 "global_guard_register",
107 &[("guard", std::any::type_name::<G>().to_string())],
108 );
109 self.global_guards.push(Arc::new(G::default()));
110 self
111 }
112
113 pub fn use_interceptor<I>(mut self) -> Self
114 where
115 I: Interceptor + Default,
116 {
117 framework_log_event(
118 "global_interceptor_register",
119 &[("interceptor", std::any::type_name::<I>().to_string())],
120 );
121 self.global_interceptors.push(Arc::new(I::default()));
122 self
123 }
124
125 pub fn use_exception_filter<F>(mut self) -> Self
126 where
127 F: ExceptionFilter + Default,
128 {
129 framework_log_event(
130 "global_exception_filter_register",
131 &[("filter", std::any::type_name::<F>().to_string())],
132 );
133 self.global_exception_filters.push(Arc::new(F::default()));
134 self
135 }
136
137 pub fn use_middleware<T>(mut self) -> Self
138 where
139 T: NestMiddleware + Default,
140 {
141 let mut consumer = MiddlewareConsumer::new();
142 consumer.apply::<T>().for_all_routes();
143 self.middleware_bindings.extend(consumer.into_bindings());
144 self
145 }
146
147 pub fn configure_middleware<F>(mut self, configure: F) -> Self
148 where
149 F: FnOnce(&mut MiddlewareConsumer),
150 {
151 let mut consumer = MiddlewareConsumer::new();
152 configure(&mut consumer);
153 self.middleware_bindings.extend(consumer.into_bindings());
154 self
155 }
156
157 pub fn with_auth_resolver<F, Fut>(mut self, resolver: F) -> Self
158 where
159 F: Fn(Option<String>, Container) -> Fut + Send + Sync + 'static,
160 Fut: Future<Output = Result<Option<AuthIdentity>, HttpException>> + Send + 'static,
161 {
162 self.auth_resolver = Some(Arc::new(move |token, container| {
163 Box::pin(resolver(token, container))
164 }));
165 self
166 }
167
168 pub fn merge_router(mut self, router: Router<Container>) -> Self {
169 self.extra_routers.push(router);
170 self
171 }
172
173 pub fn container(&self) -> &Container {
174 &self.container
175 }
176
177 pub fn into_router(self) -> Router {
178 let mut app: Router<Container> = Router::new();
183
184 for controller_router in self.controllers {
188 app = app.merge(controller_router);
189 }
190 for extra_router in self.extra_routers {
191 app = app.merge(extra_router);
192 }
193
194 if let Some(version) = &self.version {
195 app = Router::new().nest(&format!("/{}", version), app);
196 }
197
198 if let Some(prefix) = &self.global_prefix {
199 app = Router::new().nest(&format!("/{}", prefix), app);
200 }
201
202 let global_guards = Arc::new(self.global_guards);
203 let global_interceptors = Arc::new(self.global_interceptors);
204 let global_exception_filters = Arc::new(self.global_exception_filters);
205 let middleware_bindings = Arc::new(self.middleware_bindings);
206 let auth_resolver = self.auth_resolver.clone();
207 let request_container = self.container.clone();
208
209 let route_exception_filters = Arc::clone(&global_exception_filters);
210 let app = app.route_layer(from_fn(move |req, next| {
211 let guards = Arc::clone(&global_guards);
212 let interceptors = Arc::clone(&global_interceptors);
213 let filters = Arc::clone(&route_exception_filters);
214 async move { execute_pipeline(req, next, guards, interceptors, filters).await }
215 }));
216
217 let app = app.layer(from_fn(
218 move |req: axum::extract::Request, next: axum::middleware::Next| {
219 let middlewares = Arc::clone(&middleware_bindings);
220 async move {
221 if middlewares.is_empty() {
222 return next.run(req).await;
223 }
224
225 let terminal = next_to_fn(next);
226 run_middleware_chain(middlewares, 0, req, terminal).await
227 }
228 },
229 ));
230
231 Router::new()
232 .merge(app)
233 .layer(from_fn(move |req, next| {
234 let auth_resolver = auth_resolver.clone();
235 let request_container = request_container.clone();
236 let exception_filters = Arc::clone(&global_exception_filters);
237 async move {
238 request_context_middleware(
239 req,
240 next,
241 request_container,
242 auth_resolver,
243 exception_filters,
244 )
245 .await
246 }
247 }))
248 .with_state(self.container)
249 }
250
251 pub async fn listen(self, port: u16) -> Result<()> {
252 let runtime = Arc::clone(&self.runtime);
253 let container = self.container.clone();
254 let app = self.into_router();
255
256 let addr = SocketAddr::from(([127, 0, 0, 1], port));
257 let listener = tokio::net::TcpListener::bind(addr).await?;
258
259 framework_log_event("server_listening", &[("addr", addr.to_string())]);
260
261 axum::serve(listener, app).await?;
262 runtime.run_module_destroy(&container)?;
263 runtime.run_application_shutdown(&container)?;
264 Ok(())
265 }
266}
267
/// Process-wide counter supplying the sequence component of generated request ids;
/// the first id uses `1`.
static NEXT_REQUEST_SEQUENCE: AtomicU64 = AtomicU64::new(1);
/// Response header that carries the per-request id back to the client.
const REQUEST_ID_HEADER: &str = "x-request-id";
270
271fn next_to_fn(next: axum::middleware::Next) -> NextFn {
272 let next = Arc::new(Mutex::new(Some(next)));
273
274 Arc::new(move |req: axum::extract::Request<Body>| {
275 let next = Arc::clone(&next);
276 Box::pin(async move {
277 let next = {
278 let mut guard = match next.lock() {
279 Ok(guard) => guard,
280 Err(_) => {
281 return HttpException::internal_server_error("Middleware lock poisoned")
282 .into_response();
283 }
284 };
285 guard.take()
286 };
287
288 match next {
289 Some(next) => next.run(req).await,
290 None => {
291 HttpException::internal_server_error("Middleware next called multiple times")
292 .into_response()
293 }
294 }
295 })
296 })
297}
298
299async fn request_context_middleware(
300 mut req: axum::extract::Request,
301 next: axum::middleware::Next,
302 container: Container,
303 auth_resolver: Option<Arc<AuthResolver>>,
304 exception_filters: Arc<Vec<Arc<dyn ExceptionFilter>>>,
305) -> Response {
306 let scoped_container = container.scoped();
307 let request_id = RequestId::new(generate_request_id());
308 let request_id_value = request_id.value().to_string();
309 let method = req.method().to_string();
310 let path = req.uri().path().to_string();
311 let started = Instant::now();
312 let bearer_token = req
313 .headers()
314 .get(axum::http::header::AUTHORIZATION)
315 .and_then(|value| value.to_str().ok())
316 .and_then(|value| value.strip_prefix("Bearer "))
317 .map(str::trim)
318 .filter(|value| !value.is_empty())
319 .map(str::to_string);
320
321 req.extensions_mut().insert(scoped_container.clone());
322 req.extensions_mut().insert(request_id.clone());
323 let _ = scoped_container.override_value(request_id.clone());
324 framework_log_event(
325 "request_start",
326 &[
327 ("request_id", request_id_value.clone()),
328 ("method", method.clone()),
329 ("path", path.clone()),
330 ],
331 );
332
333 if let Some(resolver) = auth_resolver {
334 match resolver(bearer_token, container).await {
335 Ok(Some(identity)) => {
336 framework_log_event(
337 "auth_identity_resolved",
338 &[
339 ("request_id", request_id_value.clone()),
340 ("subject", identity.subject.clone()),
341 ],
342 );
343 let _ = scoped_container.override_value(identity.clone());
344 req.extensions_mut().insert(Arc::new(identity));
345 }
346 Ok(None) => {}
347 Err(err) => {
348 let ctx = RequestContext::from_request(&req);
349 let _ = scoped_container.override_value(ctx.clone());
350 let mut response = apply_exception_filters(
351 err.with_request_id(request_id_value.clone()),
352 &ctx,
353 exception_filters.as_slice(),
354 )
355 .into_response();
356 attach_request_id_header(&mut response, &request_id_value);
357 framework_log_event(
358 "request_complete",
359 &[
360 ("request_id", request_id_value),
361 ("method", method),
362 ("path", path),
363 ("status", response.status().as_u16().to_string()),
364 ("duration_ms", started.elapsed().as_millis().to_string()),
365 ],
366 );
367 return response;
368 }
369 }
370 }
371
372 let ctx = RequestContext::from_request(&req);
373 let _ = scoped_container.override_value(ctx);
374
375 let mut response = next.run(req).await;
376 attach_request_id_header(&mut response, &request_id_value);
377
378 framework_log_event(
379 "request_complete",
380 &[
381 ("request_id", request_id_value),
382 ("method", method),
383 ("path", path),
384 ("status", response.status().as_u16().to_string()),
385 ("duration_ms", started.elapsed().as_millis().to_string()),
386 ],
387 );
388
389 response
390}
391
392fn generate_request_id() -> String {
393 let sequence = NEXT_REQUEST_SEQUENCE.fetch_add(1, Ordering::Relaxed);
394 let millis = SystemTime::now()
395 .duration_since(UNIX_EPOCH)
396 .map(|duration| duration.as_millis())
397 .unwrap_or_default();
398 format!("req-{millis}-{sequence}")
399}
400
401fn attach_request_id_header(response: &mut Response, request_id: &str) {
402 if let Ok(value) = HeaderValue::from_str(request_id) {
403 response
404 .headers_mut()
405 .insert(HeaderName::from_static(REQUEST_ID_HEADER), value);
406 }
407}
408
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use axum::Json;
    use nestforge_core::{
        register_provider, ApiResult, AuthUser, Container, ControllerBasePath,
        ControllerDefinition, ExceptionFilter, HttpException, Inject, ModuleDefinition, Provider,
        RequestContext as FrameworkRequestContext, RouteBuilder,
    };
    use tower::ServiceExt;

    use super::*;

    struct HealthController;
    #[derive(Default)]
    struct RewriteBadRequestFilter;
    struct RequestScopedService {
        path: String,
    }

    impl ControllerBasePath for HealthController {
        fn base_path() -> &'static str {
            "/health"
        }
    }

    impl HealthController {
        async fn ok(request_id: RequestId) -> ApiResult<String> {
            Ok(Json(request_id.value().to_string()))
        }

        async fn fail(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn fail_locally(request_id: RequestId) -> ApiResult<String> {
            Err(HttpException::bad_request("local broken request")
                .with_request_id(request_id.value().to_string()))
        }

        async fn me(user: AuthUser) -> ApiResult<String> {
            Ok(Json(user.subject.clone()))
        }

        async fn scoped(service: Inject<RequestScopedService>) -> ApiResult<String> {
            Ok(Json(service.path.clone()))
        }
    }

    impl ControllerDefinition for HealthController {
        fn router() -> Router<Container> {
            RouteBuilder::<Self>::new()
                .get("/", Self::ok)
                .get("/fail", Self::fail)
                .get_with_pipeline(
                    "/fail-local",
                    Self::fail_locally,
                    Vec::new(),
                    Vec::new(),
                    vec![Arc::new(RewriteBadRequestFilter) as Arc<dyn ExceptionFilter>],
                    None,
                )
                .get("/me", Self::me)
                .get("/scoped", Self::scoped)
                .build()
        }
    }

    impl ExceptionFilter for RewriteBadRequestFilter {
        fn catch(&self, exception: HttpException, _ctx: &RequestContext) -> HttpException {
            if exception.status == axum::http::StatusCode::BAD_REQUEST {
                HttpException::bad_request("filtered bad request")
                    .with_optional_request_id(exception.request_id)
            } else {
                exception
            }
        }
    }

    struct TestModule;

    impl ModuleDefinition for TestModule {
        fn register(container: &Container) -> Result<()> {
            register_provider(
                container,
                Provider::request_factory(|container| {
                    let ctx = container.resolve::<FrameworkRequestContext>()?;
                    Ok(RequestScopedService {
                        path: ctx.uri.path().to_string(),
                    })
                }),
            )?;
            Ok(())
        }

        fn controllers() -> Vec<Router<Container>> {
            vec![HealthController::router()]
        }
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get_request(uri: &str) -> axum::http::Request<axum::body::Body> {
        axum::http::Request::builder()
            .uri(uri)
            .body(axum::body::Body::empty())
            .expect("request should build")
    }

    /// Returns the `x-request-id` header of `response` as an owned string.
    fn request_id_of(response: &Response) -> String {
        response
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("request id header should be present")
            .to_str()
            .expect("request id header should be ASCII")
            .to_string()
    }

    /// Collects the full response body as UTF-8 text.
    async fn body_text(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("response body should be readable");
        String::from_utf8(bytes.to_vec()).expect("response body should be UTF-8")
    }

    #[tokio::test]
    async fn request_middleware_sets_request_id_header_and_extension() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get_request("/health/"))
            .await
            .expect("request should succeed");

        assert!(request_id_of(&response).starts_with("req-"));
    }

    #[tokio::test]
    async fn request_ids_are_unique_per_request() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let first = app
            .clone()
            .oneshot(get_request("/health/"))
            .await
            .expect("first request should succeed");
        let second = app
            .oneshot(get_request("/health/"))
            .await
            .expect("second request should succeed");

        assert_ne!(request_id_of(&first), request_id_of(&second));
    }

    #[tokio::test]
    async fn error_responses_keep_request_id_header() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .use_exception_filter::<RewriteBadRequestFilter>()
            .into_router();

        let response = app
            .oneshot(get_request("/health/fail"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn route_specific_exception_filters_rewrite_route_failures() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get_request("/health/fail-local"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::BAD_REQUEST);
        assert!(response.headers().contains_key(REQUEST_ID_HEADER));
    }

    #[tokio::test]
    async fn auth_resolver_inserts_identity_for_auth_user_extractor() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .with_auth_resolver(|token, _container| async move {
                Ok(token.map(|_| AuthIdentity::new("demo-user").with_roles(["admin"])))
            })
            .into_router();

        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .uri("/health/me")
                    .header(axum::http::header::AUTHORIZATION, "Bearer demo-token")
                    .body(axum::body::Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        // The handler echoes the resolved subject as a JSON string.
        assert_eq!(body_text(response).await, "\"demo-user\"");
    }

    #[tokio::test]
    async fn request_scoped_provider_resolves_from_per_request_container() {
        let app = NestForgeFactory::<TestModule>::create()
            .expect("factory should build")
            .into_router();

        let response = app
            .oneshot(get_request("/health/scoped"))
            .await
            .expect("request should succeed");

        assert_eq!(response.status(), axum::http::StatusCode::OK);
        // The request-scoped factory captures the current request path.
        assert!(body_text(response).await.contains("/scoped"));
    }
}