Skip to main content

api_gateway/
module.rs

1//! API Gateway Module definition
2//!
3//! Contains the `ApiGateway` module struct and its trait implementations.
4
5use async_trait::async_trait;
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use dashmap::DashMap;
10
11use anyhow::Result;
12use axum::http::Method;
13use axum::middleware::from_fn_with_state;
14use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
15use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
16use modkit::lifecycle::ReadySignal;
17use parking_lot::Mutex;
18use std::net::SocketAddr;
19use std::time::Duration;
20use tokio_util::sync::CancellationToken;
21use tower_http::{
22    catch_panic::CatchPanicLayer,
23    limit::RequestBodyLimitLayer,
24    request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
25    timeout::TimeoutLayer,
26};
27use tracing::debug;
28
29use authn_resolver_sdk::AuthNResolverClient;
30
31use crate::config::ApiGatewayConfig;
32use crate::middleware::auth;
33use modkit_security::SecurityContext;
34use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
35
36use crate::middleware;
37use crate::router_cache::RouterCache;
38use crate::web;
39
/// Main API Gateway module — owns the HTTP server (`rest_host`) and collects
/// typed operation specs to emit a single `OpenAPI` document.
///
/// The `modkit::module` attribute wires `serve` in as the lifecycle entry point
/// (30s stop timeout, `await_ready`); `serve` calls `ready.notify()` only after the
/// TCP listener has been bound.
#[modkit::module(
	name = "api-gateway",
	capabilities = [rest_host, rest, stateful],
    deps = ["grpc-hub", "authn-resolver"],
	lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
)]
pub struct ApiGateway {
    // Lock-free config using arc-swap for read-mostly access; replaced wholesale in `init`.
    pub(crate) config: ArcSwap<ApiGatewayConfig>,
    // OpenAPI registry for operations and schemas
    pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
    // Cache of the last router stored by `build_router` / `rebuild_and_cache_router`
    pub(crate) router_cache: RouterCache<axum::Router>,
    // Finalized router produced by `rest_finalize`; moved out (`take`) when `serve` starts
    pub(crate) final_router: Mutex<Option<axum::Router>>,
    // AuthN Resolver client (resolved during init, None when auth_disabled)
    pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,

    // Duplicate detection (per (method, path) and per handler id). The `()` values are
    // unused: these maps act as concurrent sets for the "first registration wins" policy.
    pub(crate) registered_routes: DashMap<(Method, String), ()>,
    pub(crate) registered_handlers: DashMap<String, ()>,
}
64
65impl Default for ApiGateway {
66    fn default() -> Self {
67        let default_router = Router::new();
68        Self {
69            config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
70            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
71            router_cache: RouterCache::new(default_router),
72            final_router: Mutex::new(None),
73            authn_client: Mutex::new(None),
74            registered_routes: DashMap::new(),
75            registered_handlers: DashMap::new(),
76        }
77    }
78}
79
80impl ApiGateway {
81    fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
82        if prefix.is_empty() {
83            return router;
84        }
85
86        let top = Router::new()
87            .route("/health", get(web::health_check))
88            .route("/healthz", get(|| async { "ok" }));
89
90        router = Router::new().nest(prefix, router);
91        top.merge(router)
92    }
93
94    /// Create a new `ApiGateway` instance with the given configuration
95    #[must_use]
96    pub fn new(config: ApiGatewayConfig) -> Self {
97        let default_router = Router::new();
98        Self {
99            config: ArcSwap::from_pointee(config),
100            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
101            router_cache: RouterCache::new(default_router),
102            final_router: Mutex::new(None),
103            authn_client: Mutex::new(None),
104            registered_routes: DashMap::new(),
105            registered_handlers: DashMap::new(),
106        }
107    }
108
109    /// Get the current configuration (cheap clone from `ArcSwap`)
110    pub fn get_config(&self) -> ApiGatewayConfig {
111        (**self.config.load()).clone()
112    }
113
114    /// Get cached configuration (lock-free with `ArcSwap`)
115    pub fn get_cached_config(&self) -> ApiGatewayConfig {
116        (**self.config.load()).clone()
117    }
118
119    /// Get the cached router without rebuilding (useful for performance-critical paths)
120    pub fn get_cached_router(&self) -> Arc<Router> {
121        self.router_cache.load()
122    }
123
124    /// Force rebuild and cache of the router.
125    ///
126    /// # Errors
127    /// Returns an error if router building fails.
128    pub fn rebuild_and_cache_router(&self) -> Result<()> {
129        let new_router = self.build_router()?;
130        self.router_cache.store(new_router);
131        Ok(())
132    }
133
134    /// Build route policy from operation specs.
135    fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
136        let mut authenticated_routes = std::collections::HashSet::new();
137        let mut public_routes = std::collections::HashSet::new();
138
139        // Always mark built-in health check routes as public
140        public_routes.insert((Method::GET, "/health".to_owned()));
141        public_routes.insert((Method::GET, "/healthz".to_owned()));
142
143        public_routes.insert((Method::GET, "/docs".to_owned()));
144        public_routes.insert((Method::GET, "/openapi.json".to_owned()));
145
146        for spec in &self.openapi_registry.operation_specs {
147            let spec = spec.value();
148
149            let route_key = (spec.method.clone(), spec.path.clone());
150
151            if spec.authenticated {
152                authenticated_routes.insert(route_key.clone());
153            }
154
155            if spec.is_public {
156                public_routes.insert(route_key);
157            }
158        }
159
160        let config = self.get_cached_config();
161        let requirements_count = authenticated_routes.len();
162        let public_routes_count = public_routes.len();
163
164        let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
165
166        tracing::info!(
167            auth_disabled = config.auth_disabled,
168            require_auth_by_default = config.require_auth_by_default,
169            requirements_count = requirements_count,
170            public_routes_count = public_routes_count,
171            "Route policy built from operation specs"
172        );
173
174        Ok(route_policy)
175    }
176
177    fn normalize_prefix_path(raw: &str) -> Result<String> {
178        let trimmed = raw.trim();
179        // Collapse consecutive slashes then strip trailing slash(es).
180        let collapsed: String =
181            trimmed
182                .chars()
183                .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
184                    if c == '/' && acc.ends_with('/') {
185                        // skip duplicate slash
186                    } else {
187                        acc.push(c);
188                    }
189                    acc
190                });
191        let prefix = collapsed.trim_end_matches('/');
192        let result = if prefix.is_empty() {
193            String::new()
194        } else if prefix.starts_with('/') {
195            prefix.to_owned()
196        } else {
197            format!("/{prefix}")
198        };
199        // Reject characters that are unsafe in URL paths or HTML attributes.
200        if !result
201            .bytes()
202            .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
203        {
204            anyhow::bail!(
205                "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
206            );
207        }
208
209        if result.split('/').any(|seg| seg == "." || seg == "..") {
210            anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
211        }
212
213        Ok(result)
214    }
215
216    /// Apply all middleware layers to a router (request ID, tracing, timeout, body limit, CORS, rate limiting, error mapping, auth)
217    pub(crate) fn apply_middleware_stack(
218        &self,
219        mut router: Router,
220        authn_client: Option<Arc<dyn AuthNResolverClient>>,
221    ) -> Result<Router> {
222        // Build route policy once
223        let route_policy = self.build_route_policy_from_specs()?;
224
225        // IMPORTANT: `axum::Router::layer(...)` behaves like Tower layers: the **last** added layer
226        // becomes the **outermost** layer and therefore runs **first** on the request path.
227        //
228        // Desired request execution order (outermost -> innermost):
229        // SetRequestId -> PropagateRequestId -> Trace -> push_req_id
230        // -> HttpMetrics -> CatchPanic
231        // -> Timeout -> BodyLimit -> CORS -> MIME validation -> RateLimit -> ErrorMapping -> Auth -> License
232        // -> [Route matching] -> PropagateMatchedPath -> Handler
233        //
234        // Therefore we must add layers in the reverse order (innermost -> outermost) below.
235        // Due future refactoring, this order must be maintained.
236
237        // 14) Propagate MatchedPath to response extensions (route_layer — innermost).
238        // This copies MatchedPath from the request (populated by Axum route matching)
239        // into the response so outer layer() middleware (metrics) can read it.
240        router = router.route_layer(from_fn(middleware::http_metrics::propagate_matched_path));
241
242        let config = self.get_cached_config();
243
244        // Collect specs once; used by MIME validation + rate limiting maps.
245        let specs: Vec<_> = self
246            .openapi_registry
247            .operation_specs
248            .iter()
249            .map(|e| e.value().clone())
250            .collect();
251
252        // 13) License validation
253        let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
254
255        router = router.layer(from_fn(
256            move |req: axum::extract::Request, next: axum::middleware::Next| {
257                let map = license_map.clone();
258                middleware::license_validation::license_validation_middleware(map, req, next)
259            },
260        ));
261
262        // 12) Auth
263        if config.auth_disabled {
264            // Build security contexts for compatibility during migration
265            let default_security_context = SecurityContext::builder()
266                .subject_id(DEFAULT_SUBJECT_ID)
267                .subject_tenant_id(DEFAULT_TENANT_ID)
268                .build()?;
269
270            tracing::warn!(
271                "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
272                 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
273                 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
274            );
275            router = router.layer(from_fn(
276                move |mut req: axum::extract::Request, next: axum::middleware::Next| {
277                    let sec_context = default_security_context.clone();
278                    async move {
279                        req.extensions_mut().insert(sec_context);
280                        next.run(req).await
281                    }
282                },
283            ));
284        } else if let Some(client) = authn_client {
285            let auth_state = auth::AuthState {
286                authn_client: client,
287                route_policy,
288            };
289            router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
290        } else {
291            return Err(anyhow::anyhow!(
292                "auth is enabled but no AuthN Resolver client is available; \
293                 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
294            ));
295        }
296
297        // 11) Error mapping (outer to auth so it can translate auth/handler errors)
298        router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
299
300        // 10) Per-route rate limiting & in-flight limits
301        let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
302
303        router = router.layer(from_fn(
304            move |req: axum::extract::Request, next: axum::middleware::Next| {
305                let map = rate_map.clone();
306                middleware::rate_limit::rate_limit_middleware(map, req, next)
307            },
308        ));
309
310        // 9) MIME type validation
311        let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
312        router = router.layer(from_fn(
313            move |req: axum::extract::Request, next: axum::middleware::Next| {
314                let map = mime_map.clone();
315                middleware::mime_validation::mime_validation_middleware(map, req, next)
316            },
317        ));
318
319        // 8) CORS (must be outer to auth/limits so OPTIONS preflight short-circuits)
320        if config.cors_enabled {
321            router = router.layer(crate::cors::build_cors_layer(&config));
322        }
323
324        // 7) Body limit
325        router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
326        router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
327
328        // 6) Timeout
329        router = router.layer(TimeoutLayer::with_status_code(
330            axum::http::StatusCode::GATEWAY_TIMEOUT,
331            Duration::from_secs(30),
332        ));
333
334        // 5) CatchPanic (converts panics to 500 before metrics sees them)
335        router = router.layer(CatchPanicLayer::new());
336
337        // 4) HTTP metrics (layer — captures all middleware responses including auth/rate-limit/timeout)
338        let http_metrics = Arc::new(middleware::http_metrics::HttpMetrics::new(
339            Self::MODULE_NAME,
340            &config.metrics.prefix,
341        ));
342        router = router.layer(from_fn_with_state(
343            http_metrics,
344            middleware::http_metrics::http_metrics_middleware,
345        ));
346
347        // 3) Record request_id into span + extensions (requires span to exist first => must be inner to Trace)
348        router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
349
350        // 2) Trace (outer to push_req_id_to_extensions)
351        router = router.layer({
352            use modkit_http::otel;
353            use tower_http::trace::TraceLayer;
354            use tracing::field::Empty;
355
356            TraceLayer::new_for_http()
357                .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
358                    let hdr = middleware::request_id::header();
359                    let rid = req
360                        .headers()
361                        .get(&hdr)
362                        .and_then(|v| v.to_str().ok())
363                        .unwrap_or("n/a");
364
365                    let span = tracing::info_span!(
366                        "http_request",
367                        method = %req.method(),
368                        uri = %req.uri().path(),
369                        version = ?req.version(),
370                        module = "api_gateway",
371                        endpoint = %req.uri().path(),
372                        request_id = %rid,
373                        status = Empty,
374                        latency_ms = Empty,
375                        // OpenTelemetry semantic conventions
376                        "http.method" = %req.method(),
377                        "http.target" = %req.uri().path(),
378                        "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
379                        "http.host" = req.headers().get("host")
380                            .and_then(|h| h.to_str().ok())
381                            .unwrap_or("unknown"),
382                        "user_agent.original" = req.headers().get("user-agent")
383                            .and_then(|h| h.to_str().ok())
384                            .unwrap_or("unknown"),
385                        // Trace context placeholders (for log correlation)
386                        trace_id = Empty,
387                        parent.trace_id = Empty
388                    );
389
390                    // Set parent OTel trace context (W3C traceparent), if any
391                    // This also populates trace_id and parent.trace_id from headers
392                    otel::set_parent_from_headers(&span, req.headers());
393
394                    span
395                })
396                .on_response(
397                    |res: &axum::http::Response<axum::body::Body>,
398                     latency: std::time::Duration,
399                     span: &tracing::Span| {
400                        let ms = latency.as_millis();
401                        span.record("status", res.status().as_u16());
402                        span.record("latency_ms", ms);
403                    },
404                )
405        });
406
407        // 1) Request ID handling (outermost)
408        let x_request_id = crate::middleware::request_id::header();
409        // If missing, generate x-request-id first; then propagate it to the response.
410        router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
411        router = router.layer(SetRequestIdLayer::new(
412            x_request_id,
413            crate::middleware::request_id::MakeReqId,
414        ));
415
416        Ok(router)
417    }
418
419    /// Build the HTTP router from registered routes and operations.
420    ///
421    /// # Errors
422    /// Returns an error if router building or middleware setup fails.
423    pub fn build_router(&self) -> Result<Router> {
424        // If the cached router is currently held elsewhere (e.g., by the running server),
425        // return it without rebuilding to avoid unnecessary allocations.
426        let cached_router = self.router_cache.load();
427        if Arc::strong_count(&cached_router) > 1 {
428            tracing::debug!("Using cached router");
429            return Ok((*cached_router).clone());
430        }
431
432        tracing::debug!("Building new router (standalone/fallback mode)");
433        // In standalone mode (no REST pipeline), register both health endpoints here.
434        // In normal operation, rest_prepare() registers these instead.
435        let mut router = Router::new()
436            .route("/health", get(web::health_check))
437            .route("/healthz", get(|| async { "ok" }));
438
439        // Apply all middleware layers including auth, above the router
440        let authn_client = self.authn_client.lock().clone();
441        router = self.apply_middleware_stack(router, authn_client)?;
442
443        let config = self.get_cached_config();
444        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
445        router = Self::apply_prefix_nesting(router, &prefix);
446
447        // Cache the built router for future use
448        self.router_cache.store(router.clone());
449
450        Ok(router)
451    }
452
453    /// Build `OpenAPI` specification from registered routes and components.
454    ///
455    /// # Errors
456    /// Returns an error if `OpenAPI` specification building fails.
457    pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
458        let config = self.get_cached_config();
459        let info = modkit::api::OpenApiInfo {
460            title: config.openapi.title.clone(),
461            version: config.openapi.version.clone(),
462            description: config.openapi.description,
463        };
464        self.openapi_registry.build_openapi(&info)
465    }
466
467    /// Parse bind address from configuration string.
468    fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
469        bind_addr
470            .parse()
471            .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
472    }
473
474    /// Get the finalized router or build a default one.
475    fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
476        let stored = { self.final_router.lock().take() };
477
478        if let Some(router) = stored {
479            tracing::debug!("Using router from REST phase");
480            Ok(router)
481        } else {
482            tracing::debug!("No router from REST phase, building default router");
483            self.build_router()
484        }
485    }
486
487    /// Background HTTP server: bind, notify ready, serve until cancelled.
488    ///
489    /// This method is the lifecycle entry-point generated by the macro
490    /// (`#[modkit::module(..., lifecycle(...))]`).
491    pub(crate) async fn serve(
492        self: Arc<Self>,
493        cancel: CancellationToken,
494        ready: ReadySignal,
495    ) -> anyhow::Result<()> {
496        let cfg = self.get_cached_config();
497        let addr = Self::parse_bind_address(&cfg.bind_addr)?;
498        let router = self.get_or_build_router()?;
499
500        // Bind the socket, only now consider the service "ready"
501        let listener = tokio::net::TcpListener::bind(addr).await?;
502        tracing::info!("HTTP server bound on {}", addr);
503        ready.notify(); // Starting -> Running
504
505        // Graceful shutdown on cancel
506        let shutdown = {
507            let cancel = cancel.clone();
508            async move {
509                cancel.cancelled().await;
510                tracing::info!("HTTP server shutting down gracefully (cancellation)");
511            }
512        };
513
514        axum::serve(listener, router)
515            .with_graceful_shutdown(shutdown)
516            .await
517            .map_err(|e| anyhow::anyhow!(e))
518    }
519
520    /// Check if `handler_id` is already registered (returns true if duplicate)
521    fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
522        if self
523            .registered_handlers
524            .insert(spec.handler_id.clone(), ())
525            .is_some()
526        {
527            tracing::error!(
528                handler_id = %spec.handler_id,
529                method = %spec.method.as_str(),
530                path = %spec.path,
531                "Duplicate handler_id detected; ignoring subsequent registration"
532            );
533            return true;
534        }
535        false
536    }
537
538    /// Check if route (method, path) is already registered (returns true if duplicate)
539    fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
540        let route_key = (spec.method.clone(), spec.path.clone());
541        if self.registered_routes.insert(route_key, ()).is_some() {
542            tracing::error!(
543                method = %spec.method.as_str(),
544                path = %spec.path,
545                "Duplicate (method, path) detected; ignoring subsequent registration"
546            );
547            return true;
548        }
549        false
550    }
551
552    /// Log successful operation registration
553    fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
554        let current_count = self.openapi_registry.operation_specs.len();
555        tracing::debug!(
556            handler_id = %spec.handler_id,
557            method = %spec.method.as_str(),
558            path = %spec.path,
559            summary = %spec.summary.as_deref().unwrap_or("No summary"),
560            total_operations = current_count,
561            "Registered API operation"
562        );
563    }
564
565    /// Add `OpenAPI` documentation routes to the router
566    fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
567        // Build once, serve as static JSON (no per-request parsing)
568        let op_count = self.openapi_registry.operation_specs.len();
569        tracing::info!(
570            "rest_finalize: emitting OpenAPI with {} operations",
571            op_count
572        );
573
574        let openapi_doc = Arc::new(self.build_openapi()?);
575        let config = self.get_cached_config();
576        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
577        let html_doc = web::serve_docs(&prefix);
578
579        router = router
580            .route(
581                "/openapi.json",
582                get({
583                    use axum::{http::header, response::IntoResponse};
584                    let doc = openapi_doc;
585                    move || async move {
586                        let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
587                            Ok(json) => json,
588                            Err(e) => {
589                                tracing::error!("Failed to serialize OpenAPI doc: {}", e);
590                                return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
591                            }
592                        };
593                        (
594                            [
595                                (header::CONTENT_TYPE, "application/json"),
596                                (header::CACHE_CONTROL, "no-store"),
597                            ],
598                            json_string,
599                        )
600                            .into_response()
601                    }
602                }),
603            )
604            .route("/docs", get(move || async move { html_doc }));
605
606        #[cfg(feature = "embed_elements")]
607        {
608            router = router.route(
609                "/docs/assets/{*file}",
610                get(crate::assets::serve_elements_asset),
611            );
612        }
613
614        Ok(router)
615    }
616}
617
618// Manual implementation of Module trait with config loading
619#[async_trait]
620impl modkit::Module for ApiGateway {
621    async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
622        let cfg = ctx.config::<crate::config::ApiGatewayConfig>()?;
623        self.config.store(Arc::new(cfg.clone()));
624
625        debug!(
626            "Effective api_gateway configuration:\n{:#?}",
627            self.config.load()
628        );
629
630        if cfg.auth_disabled {
631            tracing::info!(
632                tenant_id = %DEFAULT_TENANT_ID,
633                "Auth-disabled mode enabled with default tenant"
634            );
635        } else {
636            // Resolve AuthN Resolver client from ClientHub
637            let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
638            *self.authn_client.lock() = Some(authn_client);
639            tracing::info!("AuthN Resolver client resolved from ClientHub");
640        }
641
642        Ok(())
643    }
644}
645
646// REST host role: prepare/finalize the router, but do not start the server here.
647impl modkit::contracts::ApiGatewayCapability for ApiGateway {
648    fn rest_prepare(
649        &self,
650        _ctx: &modkit::context::ModuleCtx,
651        router: axum::Router,
652    ) -> anyhow::Result<axum::Router> {
653        // Add health check endpoints:
654        // - /health: detailed JSON response with status and timestamp
655        // - /healthz: simple "ok" liveness probe (Kubernetes-style)
656        let router = router
657            .route("/health", get(web::health_check))
658            .route("/healthz", get(|| async { "ok" }));
659
660        // You may attach global middlewares here (trace, compression, cors), but do not start server.
661        tracing::debug!("REST host prepared base router with health check endpoints");
662        Ok(router)
663    }
664
665    fn rest_finalize(
666        &self,
667        _ctx: &modkit::context::ModuleCtx,
668        mut router: axum::Router,
669    ) -> anyhow::Result<axum::Router> {
670        let config = self.get_cached_config();
671
672        if config.enable_docs {
673            router = self.add_openapi_routes(router)?;
674        }
675
676        // Apply middleware stack (including auth) to the final router
677        tracing::debug!("Applying middleware stack to finalized router");
678        let authn_client = self.authn_client.lock().clone();
679        router = self.apply_middleware_stack(router, authn_client)?;
680
681        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
682        router = Self::apply_prefix_nesting(router, &prefix);
683
684        // Keep the finalized router to be used by `serve()`
685        *self.final_router.lock() = Some(router.clone());
686
687        tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
688        Ok(router)
689    }
690
691    fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
692        self
693    }
694}
695
696impl modkit::contracts::RestApiCapability for ApiGateway {
697    fn register_rest(
698        &self,
699        _ctx: &modkit::context::ModuleCtx,
700        router: axum::Router,
701        _openapi: &dyn modkit::contracts::OpenApiRegistry,
702    ) -> anyhow::Result<axum::Router> {
703        // This module acts as both rest_host and rest, but actual REST endpoints
704        // are handled in the host methods above.
705        Ok(router)
706    }
707}
708
709impl OpenApiRegistry for ApiGateway {
710    fn register_operation(&self, spec: &modkit::api::OperationSpec) {
711        // Reject duplicates with "first wins" policy (second registration = programmer error).
712        if self.check_duplicate_handler(spec) {
713            return;
714        }
715
716        if self.check_duplicate_route(spec) {
717            return;
718        }
719
720        // Delegate to the internal registry
721        self.openapi_registry.register_operation(spec);
722        self.log_operation_registration(spec);
723    }
724
725    fn ensure_schema_raw(
726        &self,
727        root_name: &str,
728        schemas: Vec<(
729            String,
730            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
731        )>,
732    ) -> String {
733        // Delegate to the internal registry
734        self.openapi_registry.ensure_schema_raw(root_name, schemas)
735    }
736
737    fn as_any(&self) -> &dyn std::any::Any {
738        self
739    }
740}
741
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// An `OpenAPI` document built with no registered operations must still be well-formed
    /// and carry the configured `info` metadata.
    #[test]
    fn test_openapi_generation() {
        let mut cfg = ApiGatewayConfig::default();
        cfg.openapi.title = "Test API".to_owned();
        cfg.openapi.version = "1.0.0".to_owned();
        cfg.openapi.description = Some("Test Description".to_owned());
        let gateway = ApiGateway::new(cfg);

        let document = gateway.build_openapi().unwrap();
        let value = serde_json::to_value(&document).unwrap();

        // Top-level structure of an OpenAPI document.
        for key in ["openapi", "info", "paths"] {
            assert!(value.get(key).is_some());
        }

        // Info section mirrors the configuration.
        let info = &value["info"];
        assert_eq!(info["title"], "Test API");
        assert_eq!(info["version"], "1.0.0");
        assert_eq!(info["description"], "Test Description");
    }
}
771
772#[cfg(test)]
773#[cfg_attr(coverage_nightly, coverage(off))]
774mod normalize_prefix_path_tests {
775    use super::*;
776
777    #[test]
778    fn empty_string_returns_empty() {
779        assert_eq!(ApiGateway::normalize_prefix_path("").unwrap(), "");
780    }
781
782    #[test]
783    fn sole_slash_returns_empty() {
784        assert_eq!(ApiGateway::normalize_prefix_path("/").unwrap(), "");
785    }
786
787    #[test]
788    fn multiple_slashes_return_empty() {
789        assert_eq!(ApiGateway::normalize_prefix_path("///").unwrap(), "");
790    }
791
792    #[test]
793    fn whitespace_only_returns_empty() {
794        assert_eq!(ApiGateway::normalize_prefix_path("   ").unwrap(), "");
795    }
796
797    #[test]
798    fn simple_prefix_preserved() {
799        assert_eq!(ApiGateway::normalize_prefix_path("/cf").unwrap(), "/cf");
800    }
801
802    #[test]
803    fn trailing_slash_stripped() {
804        assert_eq!(ApiGateway::normalize_prefix_path("/cf/").unwrap(), "/cf");
805    }
806
807    #[test]
808    fn leading_slash_prepended_when_missing() {
809        assert_eq!(ApiGateway::normalize_prefix_path("cf").unwrap(), "/cf");
810    }
811
812    #[test]
813    fn consecutive_leading_slashes_collapsed() {
814        assert_eq!(ApiGateway::normalize_prefix_path("//cf").unwrap(), "/cf");
815    }
816
817    #[test]
818    fn consecutive_slashes_mid_path_collapsed() {
819        assert_eq!(
820            ApiGateway::normalize_prefix_path("/api//v1").unwrap(),
821            "/api/v1"
822        );
823    }
824
825    #[test]
826    fn many_consecutive_slashes_collapsed() {
827        assert_eq!(
828            ApiGateway::normalize_prefix_path("///api///v1///").unwrap(),
829            "/api/v1"
830        );
831    }
832
833    #[test]
834    fn surrounding_whitespace_trimmed() {
835        assert_eq!(ApiGateway::normalize_prefix_path("  /cf  ").unwrap(), "/cf");
836    }
837
838    #[test]
839    fn nested_path_preserved() {
840        assert_eq!(
841            ApiGateway::normalize_prefix_path("/api/v1").unwrap(),
842            "/api/v1"
843        );
844    }
845
846    #[test]
847    fn dot_in_path_allowed() {
848        assert_eq!(
849            ApiGateway::normalize_prefix_path("/api/v1.0").unwrap(),
850            "/api/v1.0"
851        );
852    }
853
854    #[test]
855    fn rejects_html_injection() {
856        let result = ApiGateway::normalize_prefix_path(r#""><script>alert(1)</script>"#);
857        assert!(result.is_err());
858    }
859
860    #[test]
861    fn rejects_spaces_in_path() {
862        let result = ApiGateway::normalize_prefix_path("/my path");
863        assert!(result.is_err());
864    }
865
866    #[test]
867    fn rejects_query_string_chars() {
868        let result = ApiGateway::normalize_prefix_path("/api?foo=bar");
869        assert!(result.is_err());
870    }
871}
872
873#[cfg(test)]
874#[cfg_attr(coverage_nightly, coverage(off))]
875mod problem_openapi_tests {
876    use super::*;
877    use axum::Json;
878    use modkit::api::{Missing, OperationBuilder};
879    use serde_json::Value;
880
881    async fn dummy_handler() -> Json<Value> {
882        Json(serde_json::json!({"ok": true}))
883    }
884
885    #[tokio::test]
886    async fn openapi_includes_problem_schema_and_response() {
887        let api = ApiGateway::default();
888        let router = axum::Router::new();
889
890        // Build a route with a problem+json response
891        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
892            .public()
893            .summary("Problem demo")
894            .problem_response(&api, http::StatusCode::BAD_REQUEST, "Bad Request") // <-- registers Problem + sets content type
895            .handler(dummy_handler)
896            .register(router, &api);
897
898        let doc = api.build_openapi().expect("openapi");
899        let v = serde_json::to_value(&doc).expect("json");
900
901        // 1) Problem exists in components.schemas
902        let problem = v
903            .pointer("/components/schemas/Problem")
904            .expect("Problem schema missing");
905        assert!(
906            problem.get("$ref").is_none(),
907            "Problem must be a real object, not a self-ref"
908        );
909
910        // 2) Response under /paths/... references Problem and has correct media type
911        let path_obj = v
912            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
913            .expect("400 response missing");
914
915        // Check what content types exist
916        let content_obj = path_obj.get("content").expect("content object missing");
917        assert!(
918            content_obj.get("application/problem+json").is_some(),
919            "application/problem+json content missing. Available content: {}",
920            serde_json::to_string_pretty(content_obj).unwrap()
921        );
922
923        let content = path_obj
924            .pointer("/content/application~1problem+json")
925            .expect("application/problem+json content missing");
926        // $ref to Problem
927        let schema_ref = content
928            .pointer("/schema/$ref")
929            .and_then(|r| r.as_str())
930            .unwrap_or("");
931        assert_eq!(schema_ref, "#/components/schemas/Problem");
932    }
933}
934
935#[cfg(test)]
936#[cfg_attr(coverage_nightly, coverage(off))]
937mod sse_openapi_tests {
938    use super::*;
939    use axum::Json;
940    use modkit::api::{Missing, OperationBuilder};
941    use serde_json::Value;
942
943    #[derive(Clone)]
944    #[modkit_macros::api_dto(request, response)]
945    struct UserEvent {
946        id: u32,
947        message: String,
948    }
949
950    async fn sse_handler() -> axum::response::sse::Sse<
951        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
952    > {
953        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
954        b.sse_response()
955    }
956
957    #[tokio::test]
958    async fn openapi_has_sse_content() {
959        let api = ApiGateway::default();
960        let router = axum::Router::new();
961
962        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
963            .summary("Demo SSE")
964            .handler(sse_handler)
965            .public()
966            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
967            .register(router, &api);
968
969        let doc = api.build_openapi().expect("openapi");
970        let v = serde_json::to_value(&doc).expect("json");
971
972        // schema is materialized
973        let schema = v
974            .pointer("/components/schemas/UserEvent")
975            .expect("UserEvent missing");
976        assert!(schema.get("$ref").is_none());
977
978        // content is text/event-stream with $ref to our schema
979        let refp = v
980            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
981            .and_then(|x| x.as_str())
982            .unwrap_or_default();
983        assert_eq!(refp, "#/components/schemas/UserEvent");
984    }
985
986    #[tokio::test]
987    async fn openapi_sse_additional_response() {
988        async fn mixed_handler() -> Json<Value> {
989            Json(serde_json::json!({"ok": true}))
990        }
991
992        let api = ApiGateway::default();
993        let router = axum::Router::new();
994
995        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
996            .summary("Mixed responses")
997            .public()
998            .handler(mixed_handler)
999            .json_response(http::StatusCode::OK, "Success response")
1000            .sse_json::<UserEvent>(&api, "Additional SSE stream")
1001            .register(router, &api);
1002
1003        let doc = api.build_openapi().expect("openapi");
1004        let v = serde_json::to_value(&doc).expect("json");
1005
1006        // Check that both response types are present
1007        let responses = v
1008            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
1009            .expect("responses");
1010
1011        // JSON response exists
1012        assert!(responses.get("200").is_some());
1013
1014        // SSE response exists (could be another 200 or different status)
1015        let response_content = responses.get("200").and_then(|r| r.get("content"));
1016        assert!(response_content.is_some());
1017
1018        // UserEvent schema is registered
1019        let schema = v
1020            .pointer("/components/schemas/UserEvent")
1021            .expect("UserEvent missing");
1022        assert!(schema.get("$ref").is_none());
1023    }
1024
1025    #[tokio::test]
1026    async fn test_axum_to_openapi_path_conversion() {
1027        // Define a route with path parameters using Axum 0.8+ style {id}
1028        async fn user_handler() -> Json<Value> {
1029            Json(serde_json::json!({"user_id": "123"}))
1030        }
1031
1032        let api = ApiGateway::default();
1033        let router = axum::Router::new();
1034
1035        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
1036            .summary("Get user by ID")
1037            .public()
1038            .path_param("id", "User ID")
1039            .handler(user_handler)
1040            .json_response(http::StatusCode::OK, "User details")
1041            .register(router, &api);
1042
1043        // Verify the operation was stored with {id} path (same for Axum 0.8 and OpenAPI)
1044        let ops: Vec<_> = api
1045            .openapi_registry
1046            .operation_specs
1047            .iter()
1048            .map(|e| e.value().clone())
1049            .collect();
1050        assert_eq!(ops.len(), 1);
1051        assert_eq!(ops[0].path, "/tests/v1/users/{id}");
1052
1053        // Verify OpenAPI doc also has {id} (no conversion needed for regular params)
1054        let doc = api.build_openapi().expect("openapi");
1055        let v = serde_json::to_value(&doc).expect("json");
1056
1057        let paths = v.get("paths").expect("paths");
1058        assert!(
1059            paths.get("/tests/v1/users/{id}").is_some(),
1060            "OpenAPI should use {{id}} placeholder"
1061        );
1062    }
1063
1064    #[tokio::test]
1065    async fn test_multiple_path_params_conversion() {
1066        async fn item_handler() -> Json<Value> {
1067            Json(serde_json::json!({"ok": true}))
1068        }
1069
1070        let api = ApiGateway::default();
1071        let router = axum::Router::new();
1072
1073        let _router = OperationBuilder::<Missing, Missing, ()>::get(
1074            "/tests/v1/projects/{project_id}/items/{item_id}",
1075        )
1076        .summary("Get project item")
1077        .public()
1078        .path_param("project_id", "Project ID")
1079        .path_param("item_id", "Item ID")
1080        .handler(item_handler)
1081        .json_response(http::StatusCode::OK, "Item details")
1082        .register(router, &api);
1083
1084        // Verify storage and OpenAPI both use {param} syntax
1085        let ops: Vec<_> = api
1086            .openapi_registry
1087            .operation_specs
1088            .iter()
1089            .map(|e| e.value().clone())
1090            .collect();
1091        assert_eq!(
1092            ops[0].path,
1093            "/tests/v1/projects/{project_id}/items/{item_id}"
1094        );
1095
1096        let doc = api.build_openapi().expect("openapi");
1097        let v = serde_json::to_value(&doc).expect("json");
1098        let paths = v.get("paths").expect("paths");
1099        assert!(
1100            paths
1101                .get("/tests/v1/projects/{project_id}/items/{item_id}")
1102                .is_some()
1103        );
1104    }
1105
1106    #[tokio::test]
1107    async fn test_wildcard_path_conversion() {
1108        async fn static_handler() -> Json<Value> {
1109            Json(serde_json::json!({"ok": true}))
1110        }
1111
1112        let api = ApiGateway::default();
1113        let router = axum::Router::new();
1114
1115        // Axum 0.8 uses {*path} for wildcards
1116        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
1117            .summary("Serve static files")
1118            .public()
1119            .handler(static_handler)
1120            .json_response(http::StatusCode::OK, "File content")
1121            .register(router, &api);
1122
1123        // Verify internal storage keeps Axum wildcard syntax {*path}
1124        let ops: Vec<_> = api
1125            .openapi_registry
1126            .operation_specs
1127            .iter()
1128            .map(|e| e.value().clone())
1129            .collect();
1130        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");
1131
1132        // Verify OpenAPI converts wildcard to {path} (without asterisk)
1133        let doc = api.build_openapi().expect("openapi");
1134        let v = serde_json::to_value(&doc).expect("json");
1135        let paths = v.get("paths").expect("paths");
1136        assert!(
1137            paths.get("/tests/v1/static/{path}").is_some(),
1138            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
1139        );
1140        assert!(
1141            paths.get("/static/{*path}").is_none(),
1142            "OpenAPI should not have Axum-style {{*path}}"
1143        );
1144    }
1145
1146    #[tokio::test]
1147    async fn test_multipart_file_upload_openapi() {
1148        async fn upload_handler() -> Json<Value> {
1149            Json(serde_json::json!({"uploaded": true}))
1150        }
1151
1152        let api = ApiGateway::default();
1153        let router = axum::Router::new();
1154
1155        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
1156            .operation_id("upload_file")
1157            .public()
1158            .summary("Upload a file")
1159            .multipart_file_request("file", Some("File to upload"))
1160            .handler(upload_handler)
1161            .json_response(http::StatusCode::OK, "Upload successful")
1162            .register(router, &api);
1163
1164        // Build OpenAPI and verify multipart schema
1165        let doc = api.build_openapi().expect("openapi");
1166        let v = serde_json::to_value(&doc).expect("json");
1167
1168        let paths = v.get("paths").expect("paths");
1169        let upload_path = paths
1170            .get("/tests/v1/files/upload")
1171            .expect("/tests/v1/files/upload path");
1172        let post_op = upload_path.get("post").expect("POST operation");
1173
1174        // Verify request body exists
1175        let request_body = post_op.get("requestBody").expect("requestBody");
1176        let content = request_body.get("content").expect("content");
1177        let multipart = content
1178            .get("multipart/form-data")
1179            .expect("multipart/form-data content type");
1180
1181        // Verify schema structure
1182        let schema = multipart.get("schema").expect("schema");
1183        assert_eq!(
1184            schema.get("type").and_then(|v| v.as_str()),
1185            Some("object"),
1186            "Schema should be of type object"
1187        );
1188
1189        // Verify properties
1190        let properties = schema.get("properties").expect("properties");
1191        let file_prop = properties.get("file").expect("file property");
1192        assert_eq!(
1193            file_prop.get("type").and_then(|v| v.as_str()),
1194            Some("string"),
1195            "File field should be of type string"
1196        );
1197        assert_eq!(
1198            file_prop.get("format").and_then(|v| v.as_str()),
1199            Some("binary"),
1200            "File field should have format binary"
1201        );
1202
1203        // Verify required fields
1204        let required = schema.get("required").expect("required");
1205        let required_arr = required.as_array().expect("required should be array");
1206        assert_eq!(required_arr.len(), 1);
1207        assert_eq!(required_arr[0].as_str(), Some("file"));
1208    }
1209}