// api_gateway/module.rs

//! API Gateway Module definition
//!
//! Contains the `ApiGateway` module struct and its trait implementations.

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use authn_resolver_sdk::AuthNResolverClient;
use axum::http::Method;
use axum::middleware::from_fn_with_state;
use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
use dashmap::DashMap;
use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
use modkit::lifecycle::ReadySignal;
use modkit_security::SecurityContext;
use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
use parking_lot::Mutex;
use tokio_util::sync::CancellationToken;
use tower_http::{
    catch_panic::CatchPanicLayer,
    limit::RequestBodyLimitLayer,
    request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
    timeout::TimeoutLayer,
};
use tracing::debug;

use crate::config::ApiGatewayConfig;
use crate::middleware;
use crate::middleware::auth;
use crate::router_cache::RouterCache;
use crate::web;

40/// Main API Gateway module — owns the HTTP server (`rest_host`) and collects
41/// typed operation specs to emit a single `OpenAPI` document.
42#[modkit::module(
43	name = "api-gateway",
44	capabilities = [rest_host, rest, stateful],
45    deps = ["grpc-hub", "authn-resolver"],
46	lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
47)]
48pub struct ApiGateway {
49    // Lock-free config using arc-swap for read-mostly access
50    pub(crate) config: ArcSwap<ApiGatewayConfig>,
51    // OpenAPI registry for operations and schemas
52    pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
53    // Built router cache for zero-lock hot path access
54    pub(crate) router_cache: RouterCache<axum::Router>,
55    // Store the finalized router from REST phase for serving
56    pub(crate) final_router: Mutex<Option<axum::Router>>,
57    // AuthN Resolver client (resolved during init, None when auth_disabled)
58    pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,
59
60    // Duplicate detection (per (method, path) and per handler id)
61    pub(crate) registered_routes: DashMap<(Method, String), ()>,
62    pub(crate) registered_handlers: DashMap<String, ()>,
63}
64
65impl Default for ApiGateway {
66    fn default() -> Self {
67        let default_router = Router::new();
68        Self {
69            config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
70            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
71            router_cache: RouterCache::new(default_router),
72            final_router: Mutex::new(None),
73            authn_client: Mutex::new(None),
74            registered_routes: DashMap::new(),
75            registered_handlers: DashMap::new(),
76        }
77    }
78}
79
80impl ApiGateway {
81    fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
82        if prefix.is_empty() {
83            return router;
84        }
85
86        let top = Router::new()
87            .route("/health", get(web::health_check))
88            .route("/healthz", get(|| async { "ok" }));
89
90        router = Router::new().nest(prefix, router);
91        top.merge(router)
92    }
93
94    /// Create a new `ApiGateway` instance with the given configuration
95    #[must_use]
96    pub fn new(config: ApiGatewayConfig) -> Self {
97        let default_router = Router::new();
98        Self {
99            config: ArcSwap::from_pointee(config),
100            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
101            router_cache: RouterCache::new(default_router),
102            final_router: Mutex::new(None),
103            authn_client: Mutex::new(None),
104            registered_routes: DashMap::new(),
105            registered_handlers: DashMap::new(),
106        }
107    }
108
109    /// Get the current configuration (cheap clone from `ArcSwap`)
110    pub fn get_config(&self) -> ApiGatewayConfig {
111        (**self.config.load()).clone()
112    }
113
114    /// Get cached configuration (lock-free with `ArcSwap`)
115    pub fn get_cached_config(&self) -> ApiGatewayConfig {
116        (**self.config.load()).clone()
117    }
118
119    /// Get the cached router without rebuilding (useful for performance-critical paths)
120    pub fn get_cached_router(&self) -> Arc<Router> {
121        self.router_cache.load()
122    }
123
124    /// Force rebuild and cache of the router.
125    ///
126    /// # Errors
127    /// Returns an error if router building fails.
128    pub fn rebuild_and_cache_router(&self) -> Result<()> {
129        let new_router = self.build_router()?;
130        self.router_cache.store(new_router);
131        Ok(())
132    }
133
134    /// Build route policy from operation specs.
135    fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
136        let mut authenticated_routes = std::collections::HashSet::new();
137        let mut public_routes = std::collections::HashSet::new();
138
139        // Always mark built-in health check routes as public
140        public_routes.insert((Method::GET, "/health".to_owned()));
141        public_routes.insert((Method::GET, "/healthz".to_owned()));
142
143        public_routes.insert((Method::GET, "/docs".to_owned()));
144        public_routes.insert((Method::GET, "/openapi.json".to_owned()));
145
146        for spec in &self.openapi_registry.operation_specs {
147            let spec = spec.value();
148
149            let route_key = (spec.method.clone(), spec.path.clone());
150
151            if spec.authenticated {
152                authenticated_routes.insert(route_key.clone());
153            }
154
155            if spec.is_public {
156                public_routes.insert(route_key);
157            }
158        }
159
160        let config = self.get_cached_config();
161        let requirements_count = authenticated_routes.len();
162        let public_routes_count = public_routes.len();
163
164        let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
165
166        tracing::info!(
167            auth_disabled = config.auth_disabled,
168            require_auth_by_default = config.require_auth_by_default,
169            requirements_count = requirements_count,
170            public_routes_count = public_routes_count,
171            "Route policy built from operation specs"
172        );
173
174        Ok(route_policy)
175    }
176
177    fn normalize_prefix_path(raw: &str) -> Result<String> {
178        let trimmed = raw.trim();
179        // Collapse consecutive slashes then strip trailing slash(es).
180        let collapsed: String =
181            trimmed
182                .chars()
183                .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
184                    if c == '/' && acc.ends_with('/') {
185                        // skip duplicate slash
186                    } else {
187                        acc.push(c);
188                    }
189                    acc
190                });
191        let prefix = collapsed.trim_end_matches('/');
192        let result = if prefix.is_empty() {
193            String::new()
194        } else if prefix.starts_with('/') {
195            prefix.to_owned()
196        } else {
197            format!("/{prefix}")
198        };
199        // Reject characters that are unsafe in URL paths or HTML attributes.
200        if !result
201            .bytes()
202            .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
203        {
204            anyhow::bail!(
205                "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
206            );
207        }
208
209        if result.split('/').any(|seg| seg == "." || seg == "..") {
210            anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
211        }
212
213        Ok(result)
214    }
215
216    /// Apply all middleware layers to a router (request ID, tracing, timeout, body limit, CORS, rate limiting, error mapping, auth)
217    pub(crate) fn apply_middleware_stack(
218        &self,
219        mut router: Router,
220        authn_client: Option<Arc<dyn AuthNResolverClient>>,
221    ) -> Result<Router> {
222        // Build route policy once
223        let route_policy = self.build_route_policy_from_specs()?;
224
225        // IMPORTANT: `axum::Router::layer(...)` behaves like Tower layers: the **last** added layer
226        // becomes the **outermost** layer and therefore runs **first** on the request path.
227        //
228        // Desired request execution order (outermost -> innermost):
229        // SetRequestId -> PropagateRequestId -> Trace -> push_req_id
230        // -> HttpMetrics -> CatchPanic
231        // -> Timeout -> BodyLimit -> CORS -> MIME validation -> RateLimit -> ErrorMapping -> Auth -> License
232        // -> [Route matching] -> PropagateMatchedPath -> Handler
233        //
234        // Therefore we must add layers in the reverse order (innermost -> outermost) below.
235        // Due future refactoring, this order must be maintained.
236
237        // 14) Propagate MatchedPath to response extensions (route_layer — innermost).
238        // This copies MatchedPath from the request (populated by Axum route matching)
239        // into the response so outer layer() middleware (metrics) can read it.
240        router = router.route_layer(from_fn(middleware::http_metrics::propagate_matched_path));
241
242        let config = self.get_cached_config();
243
244        // Collect specs once; used by MIME validation + rate limiting maps.
245        let specs: Vec<_> = self
246            .openapi_registry
247            .operation_specs
248            .iter()
249            .map(|e| e.value().clone())
250            .collect();
251
252        // 13) License validation
253        let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
254
255        router = router.layer(from_fn(
256            move |req: axum::extract::Request, next: axum::middleware::Next| {
257                let map = license_map.clone();
258                middleware::license_validation::license_validation_middleware(map, req, next)
259            },
260        ));
261
262        // 12) Auth
263        if config.auth_disabled {
264            // Build security contexts for compatibility during migration
265            let default_security_context = SecurityContext::builder()
266                .subject_id(DEFAULT_SUBJECT_ID)
267                .subject_tenant_id(DEFAULT_TENANT_ID)
268                .build()?;
269
270            tracing::warn!(
271                "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
272                 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
273                 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
274            );
275            router = router.layer(from_fn(
276                move |mut req: axum::extract::Request, next: axum::middleware::Next| {
277                    let sec_context = default_security_context.clone();
278                    async move {
279                        req.extensions_mut().insert(sec_context);
280                        next.run(req).await
281                    }
282                },
283            ));
284        } else if let Some(client) = authn_client {
285            let auth_state = auth::AuthState {
286                authn_client: client,
287                route_policy,
288            };
289            router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
290        } else {
291            return Err(anyhow::anyhow!(
292                "auth is enabled but no AuthN Resolver client is available; \
293                 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
294            ));
295        }
296
297        // 11) Error mapping (outer to auth so it can translate auth/handler errors)
298        router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
299
300        // 10) Per-route rate limiting & in-flight limits
301        let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
302
303        router = router.layer(from_fn(
304            move |req: axum::extract::Request, next: axum::middleware::Next| {
305                let map = rate_map.clone();
306                middleware::rate_limit::rate_limit_middleware(map, req, next)
307            },
308        ));
309
310        // 9) MIME type validation
311        let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
312        router = router.layer(from_fn(
313            move |req: axum::extract::Request, next: axum::middleware::Next| {
314                let map = mime_map.clone();
315                middleware::mime_validation::mime_validation_middleware(map, req, next)
316            },
317        ));
318
319        // 8) CORS (must be outer to auth/limits so OPTIONS preflight short-circuits)
320        if config.cors_enabled {
321            router = router.layer(crate::cors::build_cors_layer(&config));
322        }
323
324        // 7) Body limit
325        router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
326        router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
327
328        // 6) Timeout
329        router = router.layer(TimeoutLayer::with_status_code(
330            axum::http::StatusCode::GATEWAY_TIMEOUT,
331            Duration::from_secs(30),
332        ));
333
334        // 5) CatchPanic (converts panics to 500 before metrics sees them)
335        router = router.layer(CatchPanicLayer::new());
336
337        // 4) HTTP metrics (layer — captures all middleware responses including auth/rate-limit/timeout)
338        let http_metrics = Arc::new(middleware::http_metrics::HttpMetrics::new(
339            Self::MODULE_NAME,
340            &config.metrics.prefix,
341        ));
342        router = router.layer(from_fn_with_state(
343            http_metrics,
344            middleware::http_metrics::http_metrics_middleware,
345        ));
346
347        // 3.5) Structured access log (runs after push_req_id populates XRequestId extension)
348        router = router.layer(from_fn(middleware::access_log::access_log_middleware));
349
350        // 3) Record request_id into span + extensions (requires span to exist first => must be inner to Trace)
351        router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
352
353        // 2) Trace (outer to push_req_id_to_extensions)
354        router = router.layer({
355            use modkit_http::otel;
356            use tower_http::trace::TraceLayer;
357            use tracing::field::Empty;
358
359            TraceLayer::new_for_http()
360                .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
361                    let hdr = middleware::request_id::header();
362                    let rid = req
363                        .headers()
364                        .get(&hdr)
365                        .and_then(|v| v.to_str().ok())
366                        .unwrap_or("n/a");
367
368                    let span = tracing::info_span!(
369                        "http_request",
370                        method = %req.method(),
371                        uri = %req.uri().path(),
372                        version = ?req.version(),
373                        module = "api_gateway",
374                        endpoint = %req.uri().path(),
375                        request_id = %rid,
376                        status = Empty,
377                        latency_ms = Empty,
378                        // OpenTelemetry semantic conventions
379                        "http.method" = %req.method(),
380                        "http.target" = %req.uri().path(),
381                        "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
382                        "http.host" = req.headers().get("host")
383                            .and_then(|h| h.to_str().ok())
384                            .unwrap_or("unknown"),
385                        "user_agent.original" = req.headers().get("user-agent")
386                            .and_then(|h| h.to_str().ok())
387                            .unwrap_or("unknown"),
388                        // Trace context placeholders (for log correlation)
389                        trace_id = Empty,
390                        parent.trace_id = Empty
391                    );
392
393                    // Set parent OTel trace context (W3C traceparent), if any
394                    // This also populates trace_id and parent.trace_id from headers
395                    otel::set_parent_from_headers(&span, req.headers());
396
397                    span
398                })
399                .on_response(
400                    |res: &axum::http::Response<axum::body::Body>,
401                     latency: std::time::Duration,
402                     span: &tracing::Span| {
403                        let ms = latency.as_millis();
404                        span.record("status", res.status().as_u16());
405                        span.record("latency_ms", ms);
406                    },
407                )
408        });
409
410        // 1) Request ID handling (outermost)
411        let x_request_id = crate::middleware::request_id::header();
412        // If missing, generate x-request-id first; then propagate it to the response.
413        router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
414        router = router.layer(SetRequestIdLayer::new(
415            x_request_id,
416            crate::middleware::request_id::MakeReqId,
417        ));
418
419        Ok(router)
420    }
421
422    /// Build the HTTP router from registered routes and operations.
423    ///
424    /// # Errors
425    /// Returns an error if router building or middleware setup fails.
426    pub fn build_router(&self) -> Result<Router> {
427        // If the cached router is currently held elsewhere (e.g., by the running server),
428        // return it without rebuilding to avoid unnecessary allocations.
429        let cached_router = self.router_cache.load();
430        if Arc::strong_count(&cached_router) > 1 {
431            tracing::debug!("Using cached router");
432            return Ok((*cached_router).clone());
433        }
434
435        tracing::debug!("Building new router (standalone/fallback mode)");
436        // In standalone mode (no REST pipeline), register both health endpoints here.
437        // In normal operation, rest_prepare() registers these instead.
438        let mut router = Router::new()
439            .route("/health", get(web::health_check))
440            .route("/healthz", get(|| async { "ok" }));
441
442        // Apply all middleware layers including auth, above the router
443        let authn_client = self.authn_client.lock().clone();
444        router = self.apply_middleware_stack(router, authn_client)?;
445
446        let config = self.get_cached_config();
447        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
448        router = Self::apply_prefix_nesting(router, &prefix);
449
450        // Cache the built router for future use
451        self.router_cache.store(router.clone());
452
453        Ok(router)
454    }
455
456    /// Build `OpenAPI` specification from registered routes and components.
457    ///
458    /// # Errors
459    /// Returns an error if `OpenAPI` specification building fails.
460    pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
461        let config = self.get_cached_config();
462        let info = modkit::api::OpenApiInfo {
463            title: config.openapi.title.clone(),
464            version: config.openapi.version.clone(),
465            description: config.openapi.description,
466        };
467        self.openapi_registry.build_openapi(&info)
468    }
469
470    /// Parse bind address from configuration string.
471    fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
472        bind_addr
473            .parse()
474            .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
475    }
476
477    /// Get the finalized router or build a default one.
478    fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
479        let stored = { self.final_router.lock().take() };
480
481        if let Some(router) = stored {
482            tracing::debug!("Using router from REST phase");
483            Ok(router)
484        } else {
485            tracing::debug!("No router from REST phase, building default router");
486            self.build_router()
487        }
488    }
489
490    /// Background HTTP server: bind, notify ready, serve until cancelled.
491    ///
492    /// This method is the lifecycle entry-point generated by the macro
493    /// (`#[modkit::module(..., lifecycle(...))]`).
494    pub(crate) async fn serve(
495        self: Arc<Self>,
496        cancel: CancellationToken,
497        ready: ReadySignal,
498    ) -> anyhow::Result<()> {
499        let cfg = self.get_cached_config();
500        let addr = Self::parse_bind_address(&cfg.bind_addr)?;
501        let router = self.get_or_build_router()?;
502
503        // Bind the socket, only now consider the service "ready"
504        let listener = tokio::net::TcpListener::bind(addr).await?;
505        tracing::info!("HTTP server bound on {}", addr);
506        ready.notify(); // Starting -> Running
507
508        // Graceful shutdown on cancel
509        let shutdown = {
510            let cancel = cancel.clone();
511            async move {
512                cancel.cancelled().await;
513                tracing::info!("HTTP server shutting down gracefully (cancellation)");
514            }
515        };
516
517        axum::serve(
518            listener,
519            router.into_make_service_with_connect_info::<SocketAddr>(),
520        )
521        .with_graceful_shutdown(shutdown)
522        .await
523        .map_err(|e| anyhow::anyhow!(e))
524    }
525
526    /// Check if `handler_id` is already registered (returns true if duplicate)
527    fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
528        if self
529            .registered_handlers
530            .insert(spec.handler_id.clone(), ())
531            .is_some()
532        {
533            tracing::error!(
534                handler_id = %spec.handler_id,
535                method = %spec.method.as_str(),
536                path = %spec.path,
537                "Duplicate handler_id detected; ignoring subsequent registration"
538            );
539            return true;
540        }
541        false
542    }
543
544    /// Check if route (method, path) is already registered (returns true if duplicate)
545    fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
546        let route_key = (spec.method.clone(), spec.path.clone());
547        if self.registered_routes.insert(route_key, ()).is_some() {
548            tracing::error!(
549                method = %spec.method.as_str(),
550                path = %spec.path,
551                "Duplicate (method, path) detected; ignoring subsequent registration"
552            );
553            return true;
554        }
555        false
556    }
557
558    /// Log successful operation registration
559    fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
560        let current_count = self.openapi_registry.operation_specs.len();
561        tracing::debug!(
562            handler_id = %spec.handler_id,
563            method = %spec.method.as_str(),
564            path = %spec.path,
565            summary = %spec.summary.as_deref().unwrap_or("No summary"),
566            total_operations = current_count,
567            "Registered API operation"
568        );
569    }
570
571    /// Add `OpenAPI` documentation routes to the router
572    fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
573        // Build once, serve as static JSON (no per-request parsing)
574        let op_count = self.openapi_registry.operation_specs.len();
575        tracing::info!(
576            "rest_finalize: emitting OpenAPI with {} operations",
577            op_count
578        );
579
580        let openapi_doc = Arc::new(self.build_openapi()?);
581        let config = self.get_cached_config();
582        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
583        let html_doc = web::serve_docs(&prefix);
584
585        router = router
586            .route(
587                "/openapi.json",
588                get({
589                    use axum::{http::header, response::IntoResponse};
590                    let doc = openapi_doc;
591                    move || async move {
592                        let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
593                            Ok(json) => json,
594                            Err(e) => {
595                                tracing::error!("Failed to serialize OpenAPI doc: {}", e);
596                                return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
597                            }
598                        };
599                        (
600                            [
601                                (header::CONTENT_TYPE, "application/json"),
602                                (header::CACHE_CONTROL, "no-store"),
603                            ],
604                            json_string,
605                        )
606                            .into_response()
607                    }
608                }),
609            )
610            .route("/docs", get(move || async move { html_doc }));
611
612        #[cfg(feature = "embed_elements")]
613        {
614            router = router.route(
615                "/docs/assets/{*file}",
616                get(crate::assets::serve_elements_asset),
617            );
618        }
619
620        Ok(router)
621    }
622}
623
624// Manual implementation of Module trait with config loading
625#[async_trait]
626impl modkit::Module for ApiGateway {
627    async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
628        let cfg = ctx.config::<crate::config::ApiGatewayConfig>()?;
629        self.config.store(Arc::new(cfg.clone()));
630
631        debug!(
632            "Effective api_gateway configuration:\n{:#?}",
633            self.config.load()
634        );
635
636        if cfg.auth_disabled {
637            tracing::info!(
638                tenant_id = %DEFAULT_TENANT_ID,
639                "Auth-disabled mode enabled with default tenant"
640            );
641        } else {
642            // Resolve AuthN Resolver client from ClientHub
643            let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
644            *self.authn_client.lock() = Some(authn_client);
645            tracing::info!("AuthN Resolver client resolved from ClientHub");
646        }
647
648        Ok(())
649    }
650}
651
652// REST host role: prepare/finalize the router, but do not start the server here.
653impl modkit::contracts::ApiGatewayCapability for ApiGateway {
654    fn rest_prepare(
655        &self,
656        _ctx: &modkit::context::ModuleCtx,
657        router: axum::Router,
658    ) -> anyhow::Result<axum::Router> {
659        // Add health check endpoints:
660        // - /health: detailed JSON response with status and timestamp
661        // - /healthz: simple "ok" liveness probe (Kubernetes-style)
662        let router = router
663            .route("/health", get(web::health_check))
664            .route("/healthz", get(|| async { "ok" }));
665
666        // You may attach global middlewares here (trace, compression, cors), but do not start server.
667        tracing::debug!("REST host prepared base router with health check endpoints");
668        Ok(router)
669    }
670
671    fn rest_finalize(
672        &self,
673        _ctx: &modkit::context::ModuleCtx,
674        mut router: axum::Router,
675    ) -> anyhow::Result<axum::Router> {
676        let config = self.get_cached_config();
677
678        if config.enable_docs {
679            router = self.add_openapi_routes(router)?;
680        }
681
682        // Apply middleware stack (including auth) to the final router
683        tracing::debug!("Applying middleware stack to finalized router");
684        let authn_client = self.authn_client.lock().clone();
685        router = self.apply_middleware_stack(router, authn_client)?;
686
687        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
688        router = Self::apply_prefix_nesting(router, &prefix);
689
690        // Keep the finalized router to be used by `serve()`
691        *self.final_router.lock() = Some(router.clone());
692
693        tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
694        Ok(router)
695    }
696
697    fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
698        self
699    }
700}
701
702impl modkit::contracts::RestApiCapability for ApiGateway {
703    fn register_rest(
704        &self,
705        _ctx: &modkit::context::ModuleCtx,
706        router: axum::Router,
707        _openapi: &dyn modkit::contracts::OpenApiRegistry,
708    ) -> anyhow::Result<axum::Router> {
709        // This module acts as both rest_host and rest, but actual REST endpoints
710        // are handled in the host methods above.
711        Ok(router)
712    }
713}
714
715impl OpenApiRegistry for ApiGateway {
716    fn register_operation(&self, spec: &modkit::api::OperationSpec) {
717        // Reject duplicates with "first wins" policy (second registration = programmer error).
718        if self.check_duplicate_handler(spec) {
719            return;
720        }
721
722        if self.check_duplicate_route(spec) {
723            return;
724        }
725
726        // Delegate to the internal registry
727        self.openapi_registry.register_operation(spec);
728        self.log_operation_registration(spec);
729    }
730
731    fn ensure_schema_raw(
732        &self,
733        root_name: &str,
734        schemas: Vec<(
735            String,
736            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
737        )>,
738    ) -> String {
739        // Delegate to the internal registry
740        self.openapi_registry.ensure_schema_raw(root_name, schemas)
741    }
742
743    fn as_any(&self) -> &dyn std::any::Any {
744        self
745    }
746}
747
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// `build_openapi` must yield a structurally valid document carrying the configured
    /// `info` fields even when no operations have been registered.
    #[test]
    fn test_openapi_generation() {
        let mut config = ApiGatewayConfig::default();
        config.openapi.title = "Test API".to_owned();
        config.openapi.version = "1.0.0".to_owned();
        config.openapi.description = Some("Test Description".to_owned());

        let doc = ApiGateway::new(config).build_openapi().unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        // Top-level OpenAPI document structure.
        for key in ["openapi", "info", "paths"] {
            assert!(json.get(key).is_some(), "missing top-level `{key}`");
        }

        // Info section mirrors the configuration.
        let info = &json["info"];
        assert_eq!(info["title"], "Test API");
        assert_eq!(info["version"], "1.0.0");
        assert_eq!(info["description"], "Test Description");
    }
}

#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    use super::*;

    /// Asserts that `input` is accepted and normalizes to exactly `expected`.
    fn assert_normalizes_to(input: &str, expected: &str) {
        let normalized = ApiGateway::normalize_prefix_path(input).unwrap();
        assert_eq!(normalized, expected);
    }

    /// Asserts that `input` is rejected by the normalizer.
    fn assert_rejected(input: &str) {
        let outcome = ApiGateway::normalize_prefix_path(input);
        assert!(outcome.is_err());
    }

    // --- inputs that collapse to "no prefix" ---

    #[test]
    fn empty_string_returns_empty() {
        assert_normalizes_to("", "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_normalizes_to("/", "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_normalizes_to("///", "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_normalizes_to("   ", "");
    }

    // --- slash and whitespace normalization ---

    #[test]
    fn simple_prefix_preserved() {
        assert_normalizes_to("/cf", "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_normalizes_to("/cf/", "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_normalizes_to("cf", "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_normalizes_to("//cf", "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_normalizes_to("/api//v1", "/api/v1");
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_normalizes_to("///api///v1///", "/api/v1");
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_normalizes_to("  /cf  ", "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_normalizes_to("/api/v1", "/api/v1");
    }

    #[test]
    fn dot_in_path_allowed() {
        assert_normalizes_to("/api/v1.0", "/api/v1.0");
    }

    // --- inputs that must be refused ---

    #[test]
    fn rejects_html_injection() {
        assert_rejected(r#""><script>alert(1)</script>"#);
    }

    #[test]
    fn rejects_spaces_in_path() {
        assert_rejected("/my path");
    }

    #[test]
    fn rejects_query_string_chars() {
        assert_rejected("/api?foo=bar");
    }
}
878
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    /// Trivial JSON handler; it is only registered, never invoked.
    async fn dummy_handler() -> Json<Value> {
        Json(serde_json::json!({"ok": true}))
    }

    #[tokio::test]
    async fn openapi_includes_problem_schema_and_response() {
        let api = ApiGateway::default();
        let base_router = axum::Router::new();

        // `problem_response` registers the Problem schema and sets the
        // application/problem+json content type for the 400 response.
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
            .public()
            .summary("Problem demo")
            .problem_response(&api, http::StatusCode::BAD_REQUEST, "Bad Request")
            .handler(dummy_handler)
            .register(base_router, &api);

        let document = api.build_openapi().expect("openapi");
        let tree = serde_json::to_value(&document).expect("json");

        // The Problem component must be a concrete schema, not a self-reference.
        let problem_schema = tree
            .pointer("/components/schemas/Problem")
            .expect("Problem schema missing");
        assert!(
            problem_schema.get("$ref").is_none(),
            "Problem must be a real object, not a self-ref"
        );

        // The 400 response must advertise application/problem+json ...
        let bad_request = tree
            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
            .expect("400 response missing");
        let content_map = bad_request.get("content").expect("content object missing");
        let problem_media = content_map.get("application/problem+json");
        assert!(
            problem_media.is_some(),
            "application/problem+json content missing. Available content: {}",
            serde_json::to_string_pretty(content_map).unwrap()
        );
        let media = problem_media.expect("application/problem+json content missing");

        // ... whose schema references the shared Problem component.
        let schema_ref = media
            .get("schema")
            .and_then(|schema| schema.get("$ref"))
            .and_then(Value::as_str)
            .unwrap_or("");
        assert_eq!(schema_ref, "#/components/schemas/Problem");
    }
}
940
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod sse_openapi_tests {
    // OpenAPI generation tests covering SSE responses, path-parameter
    // templating (including Axum wildcards) and multipart uploads.
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    #[derive(Clone)]
    #[modkit_macros::api_dto(request, response)]
    struct UserEvent {
        id: u32,
        message: String,
    }

    /// SSE handler over `UserEvent`s; these tests register it but never invoke it.
    async fn sse_handler() -> axum::response::sse::Sse<
        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    > {
        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
        b.sse_response()
    }

    #[tokio::test]
    async fn openapi_has_sse_content() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
            .summary("Demo SSE")
            .handler(sse_handler)
            .public()
            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        // schema is materialized
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());

        // content is text/event-stream with $ref to our schema
        let refp = v
            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
            .and_then(|x| x.as_str())
            .unwrap_or_default();
        assert_eq!(refp, "#/components/schemas/UserEvent");
    }

    #[tokio::test]
    async fn openapi_sse_additional_response() {
        async fn mixed_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
            .summary("Mixed responses")
            .public()
            .handler(mixed_handler)
            .json_response(http::StatusCode::OK, "Success response")
            .sse_json::<UserEvent>(&api, "Additional SSE stream")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        // Check that both response types are present
        let responses = v
            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
            .expect("responses");

        // JSON response exists
        assert!(responses.get("200").is_some());

        // SSE response exists (could be another 200 or different status)
        let response_content = responses.get("200").and_then(|r| r.get("content"));
        assert!(response_content.is_some());

        // UserEvent schema is registered
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
    }

    #[tokio::test]
    async fn test_axum_to_openapi_path_conversion() {
        // Define a route with path parameters using Axum 0.8+ style {id}
        async fn user_handler() -> Json<Value> {
            Json(serde_json::json!({"user_id": "123"}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
            .summary("Get user by ID")
            .public()
            .path_param("id", "User ID")
            .handler(user_handler)
            .json_response(http::StatusCode::OK, "User details")
            .register(router, &api);

        // Verify the operation was stored with {id} path (same for Axum 0.8 and OpenAPI)
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/users/{id}");

        // Verify OpenAPI doc also has {id} (no conversion needed for regular params)
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/users/{id}").is_some(),
            "OpenAPI should use {{id}} placeholder"
        );
    }

    #[tokio::test]
    async fn test_multiple_path_params_conversion() {
        async fn item_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get(
            "/tests/v1/projects/{project_id}/items/{item_id}",
        )
        .summary("Get project item")
        .public()
        .path_param("project_id", "Project ID")
        .path_param("item_id", "Item ID")
        .handler(item_handler)
        .json_response(http::StatusCode::OK, "Item details")
        .register(router, &api);

        // Verify storage and OpenAPI both use {param} syntax
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        // Pin the count first so a registration failure reports a clear
        // mismatch instead of an index-out-of-bounds panic on `ops[0]`.
        assert_eq!(ops.len(), 1);
        assert_eq!(
            ops[0].path,
            "/tests/v1/projects/{project_id}/items/{item_id}"
        );

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths
                .get("/tests/v1/projects/{project_id}/items/{item_id}")
                .is_some()
        );
    }

    #[tokio::test]
    async fn test_wildcard_path_conversion() {
        async fn static_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        // Axum 0.8 uses {*path} for wildcards
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
            .summary("Serve static files")
            .public()
            .handler(static_handler)
            .json_response(http::StatusCode::OK, "File content")
            .register(router, &api);

        // Verify internal storage keeps Axum wildcard syntax {*path}
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        // Pin the count first so a registration failure reports a clear
        // mismatch instead of an index-out-of-bounds panic on `ops[0]`.
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");

        // Verify OpenAPI converts wildcard to {path} (without asterisk)
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/static/{path}").is_some(),
            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
        );
        // The negative check must use the same full path that was registered;
        // checking a path that was never registered would pass unconditionally.
        assert!(
            paths.get("/tests/v1/static/{*path}").is_none(),
            "OpenAPI should not have Axum-style {{*path}}"
        );
    }

    #[tokio::test]
    async fn test_multipart_file_upload_openapi() {
        async fn upload_handler() -> Json<Value> {
            Json(serde_json::json!({"uploaded": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
            .operation_id("upload_file")
            .public()
            .summary("Upload a file")
            .multipart_file_request("file", Some("File to upload"))
            .handler(upload_handler)
            .json_response(http::StatusCode::OK, "Upload successful")
            .register(router, &api);

        // Build OpenAPI and verify multipart schema
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        let upload_path = paths
            .get("/tests/v1/files/upload")
            .expect("/tests/v1/files/upload path");
        let post_op = upload_path.get("post").expect("POST operation");

        // Verify request body exists
        let request_body = post_op.get("requestBody").expect("requestBody");
        let content = request_body.get("content").expect("content");
        let multipart = content
            .get("multipart/form-data")
            .expect("multipart/form-data content type");

        // Verify schema structure
        let schema = multipart.get("schema").expect("schema");
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be of type object"
        );

        // Verify properties
        let properties = schema.get("properties").expect("properties");
        let file_prop = properties.get("file").expect("file property");
        assert_eq!(
            file_prop.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "File field should be of type string"
        );
        assert_eq!(
            file_prop.get("format").and_then(|v| v.as_str()),
            Some("binary"),
            "File field should have format binary"
        );

        // Verify required fields
        let required = schema.get("required").expect("required");
        let required_arr = required.as_array().expect("required should be array");
        assert_eq!(required_arr.len(), 1);
        assert_eq!(required_arr[0].as_str(), Some("file"));
    }
}