Skip to main content

api_gateway/
module.rs

1//! API Gateway Module definition
2//!
3//! Contains the `ApiGateway` module struct and its trait implementations.
4
5use async_trait::async_trait;
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use dashmap::DashMap;
10
11use anyhow::Result;
12use axum::http::Method;
13use axum::middleware::from_fn_with_state;
14use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
15use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
16use modkit::lifecycle::ReadySignal;
17use parking_lot::Mutex;
18use std::net::SocketAddr;
19use std::time::Duration;
20use tokio_util::sync::CancellationToken;
21use tower_http::{
22    limit::RequestBodyLimitLayer,
23    request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
24    timeout::TimeoutLayer,
25};
26use tracing::debug;
27
28use authn_resolver_sdk::AuthNResolverClient;
29
30use crate::config::ApiGatewayConfig;
31use crate::middleware::auth;
32use modkit_security::SecurityContext;
33use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
34
35use crate::middleware;
36use crate::router_cache::RouterCache;
37use crate::web;
38
/// Main API Gateway module — owns the HTTP server (`rest_host`) and collects
/// typed operation specs to emit a single `OpenAPI` document.
///
/// The module depends on the `grpc-hub` and `authn-resolver` modules. Its lifecycle
/// entry point is the `serve` method (`stop_timeout = "30s"`, `await_ready`), which
/// binds the listener, signals readiness, and serves until cancelled.
#[modkit::module(
    name = "api-gateway",
    capabilities = [rest_host, rest, stateful],
    deps = ["grpc-hub", "authn-resolver"],
    lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
)]
pub struct ApiGateway {
    // Lock-free config using arc-swap for read-mostly access; replaced wholesale in `init`.
    pub(crate) config: ArcSwap<ApiGatewayConfig>,
    // OpenAPI registry holding operation specs and component schemas for the emitted document.
    pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
    // Router stored by `build_router` / `rebuild_and_cache_router`; read via `get_cached_router`.
    pub(crate) router_cache: RouterCache<axum::Router>,
    // Router produced by `rest_finalize`; taken (once) by `serve` via `get_or_build_router`.
    pub(crate) final_router: Mutex<Option<axum::Router>>,
    // AuthN Resolver client (resolved during init, None when auth_disabled)
    pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,

    // Duplicate detection (per (method, path) and per handler id); used by
    // `register_operation` with a "first wins" policy. The `()` values carry no data.
    pub(crate) registered_routes: DashMap<(Method, String), ()>,
    pub(crate) registered_handlers: DashMap<String, ()>,
}
63
64impl Default for ApiGateway {
65    fn default() -> Self {
66        let default_router = Router::new();
67        Self {
68            config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
69            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
70            router_cache: RouterCache::new(default_router),
71            final_router: Mutex::new(None),
72            authn_client: Mutex::new(None),
73            registered_routes: DashMap::new(),
74            registered_handlers: DashMap::new(),
75        }
76    }
77}
78
79impl ApiGateway {
80    fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
81        if prefix.is_empty() {
82            return router;
83        }
84
85        let top = Router::new()
86            .route("/health", get(web::health_check))
87            .route("/healthz", get(|| async { "ok" }));
88
89        router = Router::new().nest(prefix, router);
90        top.merge(router)
91    }
92
93    /// Create a new `ApiGateway` instance with the given configuration
94    #[must_use]
95    pub fn new(config: ApiGatewayConfig) -> Self {
96        let default_router = Router::new();
97        Self {
98            config: ArcSwap::from_pointee(config),
99            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
100            router_cache: RouterCache::new(default_router),
101            final_router: Mutex::new(None),
102            authn_client: Mutex::new(None),
103            registered_routes: DashMap::new(),
104            registered_handlers: DashMap::new(),
105        }
106    }
107
108    /// Get the current configuration (cheap clone from `ArcSwap`)
109    pub fn get_config(&self) -> ApiGatewayConfig {
110        (**self.config.load()).clone()
111    }
112
113    /// Get cached configuration (lock-free with `ArcSwap`)
114    pub fn get_cached_config(&self) -> ApiGatewayConfig {
115        (**self.config.load()).clone()
116    }
117
118    /// Get the cached router without rebuilding (useful for performance-critical paths)
119    pub fn get_cached_router(&self) -> Arc<Router> {
120        self.router_cache.load()
121    }
122
123    /// Force rebuild and cache of the router.
124    ///
125    /// # Errors
126    /// Returns an error if router building fails.
127    pub fn rebuild_and_cache_router(&self) -> Result<()> {
128        let new_router = self.build_router()?;
129        self.router_cache.store(new_router);
130        Ok(())
131    }
132
133    /// Build route policy from operation specs.
134    fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
135        let mut authenticated_routes = std::collections::HashSet::new();
136        let mut public_routes = std::collections::HashSet::new();
137
138        // Always mark built-in health check routes as public
139        public_routes.insert((Method::GET, "/health".to_owned()));
140        public_routes.insert((Method::GET, "/healthz".to_owned()));
141
142        public_routes.insert((Method::GET, "/docs".to_owned()));
143        public_routes.insert((Method::GET, "/openapi.json".to_owned()));
144
145        for spec in &self.openapi_registry.operation_specs {
146            let spec = spec.value();
147
148            let route_key = (spec.method.clone(), spec.path.clone());
149
150            if spec.authenticated {
151                authenticated_routes.insert(route_key.clone());
152            }
153
154            if spec.is_public {
155                public_routes.insert(route_key);
156            }
157        }
158
159        let config = self.get_cached_config();
160        let requirements_count = authenticated_routes.len();
161        let public_routes_count = public_routes.len();
162
163        let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
164
165        tracing::info!(
166            auth_disabled = config.auth_disabled,
167            require_auth_by_default = config.require_auth_by_default,
168            requirements_count = requirements_count,
169            public_routes_count = public_routes_count,
170            "Route policy built from operation specs"
171        );
172
173        Ok(route_policy)
174    }
175
176    fn normalize_prefix_path(raw: &str) -> Result<String> {
177        let trimmed = raw.trim();
178        // Collapse consecutive slashes then strip trailing slash(es).
179        let collapsed: String =
180            trimmed
181                .chars()
182                .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
183                    if c == '/' && acc.ends_with('/') {
184                        // skip duplicate slash
185                    } else {
186                        acc.push(c);
187                    }
188                    acc
189                });
190        let prefix = collapsed.trim_end_matches('/');
191        let result = if prefix.is_empty() {
192            String::new()
193        } else if prefix.starts_with('/') {
194            prefix.to_owned()
195        } else {
196            format!("/{prefix}")
197        };
198        // Reject characters that are unsafe in URL paths or HTML attributes.
199        if !result
200            .bytes()
201            .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
202        {
203            anyhow::bail!(
204                "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
205            );
206        }
207
208        if result.split('/').any(|seg| seg == "." || seg == "..") {
209            anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
210        }
211
212        Ok(result)
213    }
214
215    /// Apply all middleware layers to a router (request ID, tracing, timeout, body limit, CORS, rate limiting, error mapping, auth)
216    pub(crate) fn apply_middleware_stack(
217        &self,
218        mut router: Router,
219        authn_client: Option<Arc<dyn AuthNResolverClient>>,
220    ) -> Result<Router> {
221        // Build route policy once
222        let route_policy = self.build_route_policy_from_specs()?;
223
224        // IMPORTANT: `axum::Router::layer(...)` behaves like Tower layers: the **last** added layer
225        // becomes the **outermost** layer and therefore runs **first** on the request path.
226        //
227        // Desired request execution order (outermost -> innermost):
228        // SetRequestId -> PropagateRequestId -> Trace -> push_req_id_to_extensions
229        // -> Timeout -> BodyLimit -> CORS -> MIME validation -> RateLimit -> ErrorMapping -> Auth -> Router
230        //
231        // Therefore we must add layers in the reverse order (innermost -> outermost) below.
232        // Due future refactoring, this order must be maintained.
233
234        let config = self.get_cached_config();
235
236        // Collect specs once; used by MIME validation + rate limiting maps.
237        let specs: Vec<_> = self
238            .openapi_registry
239            .operation_specs
240            .iter()
241            .map(|e| e.value().clone())
242            .collect();
243
244        // 11) License validation
245        let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
246
247        router = router.layer(from_fn(
248            move |req: axum::extract::Request, next: axum::middleware::Next| {
249                let map = license_map.clone();
250                middleware::license_validation::license_validation_middleware(map, req, next)
251            },
252        ));
253
254        // 10) Auth
255        if config.auth_disabled {
256            // Build security contexts for compatibility during migration
257            let default_security_context = SecurityContext::builder()
258                .subject_id(DEFAULT_SUBJECT_ID)
259                .subject_tenant_id(DEFAULT_TENANT_ID)
260                .build()?;
261
262            tracing::warn!(
263                "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
264                 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
265                 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
266            );
267            router = router.layer(from_fn(
268                move |mut req: axum::extract::Request, next: axum::middleware::Next| {
269                    let sec_context = default_security_context.clone();
270                    async move {
271                        req.extensions_mut().insert(sec_context);
272                        next.run(req).await
273                    }
274                },
275            ));
276        } else if let Some(client) = authn_client {
277            let auth_state = auth::AuthState {
278                authn_client: client,
279                route_policy,
280            };
281            router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
282        } else {
283            return Err(anyhow::anyhow!(
284                "auth is enabled but no AuthN Resolver client is available; \
285                 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
286            ));
287        }
288
289        // 9) Error mapping (outer to auth so it can translate auth/handler errors)
290        router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
291
292        // 8) Per-route rate limiting & in-flight limits
293        let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
294
295        router = router.layer(from_fn(
296            move |req: axum::extract::Request, next: axum::middleware::Next| {
297                let map = rate_map.clone();
298                middleware::rate_limit::rate_limit_middleware(map, req, next)
299            },
300        ));
301
302        // 7) MIME type validation
303        let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
304        router = router.layer(from_fn(
305            move |req: axum::extract::Request, next: axum::middleware::Next| {
306                let map = mime_map.clone();
307                middleware::mime_validation::mime_validation_middleware(map, req, next)
308            },
309        ));
310
311        // 6) CORS (must be outer to auth/limits so OPTIONS preflight short-circuits)
312        if config.cors_enabled {
313            router = router.layer(crate::cors::build_cors_layer(&config));
314        }
315
316        // 5) Body limit
317        router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
318        router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
319
320        // 4) Timeout
321        router = router.layer(TimeoutLayer::with_status_code(
322            axum::http::StatusCode::GATEWAY_TIMEOUT,
323            Duration::from_secs(30),
324        ));
325
326        // 3) Record request_id into span + extensions (requires span to exist first => must be inner to Trace)
327        router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
328
329        // 2) Trace (outer to push_req_id_to_extensions)
330        router = router.layer({
331            use modkit_http::otel;
332            use tower_http::trace::TraceLayer;
333            use tracing::field::Empty;
334
335            TraceLayer::new_for_http()
336                .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
337                    let hdr = middleware::request_id::header();
338                    let rid = req
339                        .headers()
340                        .get(&hdr)
341                        .and_then(|v| v.to_str().ok())
342                        .unwrap_or("n/a");
343
344                    let span = tracing::info_span!(
345                        "http_request",
346                        method = %req.method(),
347                        uri = %req.uri().path(),
348                        version = ?req.version(),
349                        module = "api_gateway",
350                        endpoint = %req.uri().path(),
351                        request_id = %rid,
352                        status = Empty,
353                        latency_ms = Empty,
354                        // OpenTelemetry semantic conventions
355                        "http.method" = %req.method(),
356                        "http.target" = %req.uri().path(),
357                        "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
358                        "http.host" = req.headers().get("host")
359                            .and_then(|h| h.to_str().ok())
360                            .unwrap_or("unknown"),
361                        "user_agent.original" = req.headers().get("user-agent")
362                            .and_then(|h| h.to_str().ok())
363                            .unwrap_or("unknown"),
364                        // Trace context placeholders (for log correlation)
365                        trace_id = Empty,
366                        parent.trace_id = Empty
367                    );
368
369                    // Set parent OTel trace context (W3C traceparent), if any
370                    // This also populates trace_id and parent.trace_id from headers
371                    otel::set_parent_from_headers(&span, req.headers());
372
373                    span
374                })
375                .on_response(
376                    |res: &axum::http::Response<axum::body::Body>,
377                     latency: std::time::Duration,
378                     span: &tracing::Span| {
379                        let ms = latency.as_millis();
380                        span.record("status", res.status().as_u16());
381                        span.record("latency_ms", ms);
382                    },
383                )
384        });
385
386        // 1) Request ID handling
387        let x_request_id = crate::middleware::request_id::header();
388        // If missing, generate x-request-id first; then propagate it to the response.
389        router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
390        router = router.layer(SetRequestIdLayer::new(
391            x_request_id,
392            crate::middleware::request_id::MakeReqId,
393        ));
394
395        Ok(router)
396    }
397
398    /// Build the HTTP router from registered routes and operations.
399    ///
400    /// # Errors
401    /// Returns an error if router building or middleware setup fails.
402    pub fn build_router(&self) -> Result<Router> {
403        // If the cached router is currently held elsewhere (e.g., by the running server),
404        // return it without rebuilding to avoid unnecessary allocations.
405        let cached_router = self.router_cache.load();
406        if Arc::strong_count(&cached_router) > 1 {
407            tracing::debug!("Using cached router");
408            return Ok((*cached_router).clone());
409        }
410
411        tracing::debug!("Building new router (standalone/fallback mode)");
412        // In standalone mode (no REST pipeline), register both health endpoints here.
413        // In normal operation, rest_prepare() registers these instead.
414        let mut router = Router::new()
415            .route("/health", get(web::health_check))
416            .route("/healthz", get(|| async { "ok" }));
417
418        // Apply all middleware layers including auth, above the router
419        let authn_client = self.authn_client.lock().clone();
420        router = self.apply_middleware_stack(router, authn_client)?;
421
422        let config = self.get_cached_config();
423        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
424        router = Self::apply_prefix_nesting(router, &prefix);
425
426        // Cache the built router for future use
427        self.router_cache.store(router.clone());
428
429        Ok(router)
430    }
431
432    /// Build `OpenAPI` specification from registered routes and components.
433    ///
434    /// # Errors
435    /// Returns an error if `OpenAPI` specification building fails.
436    pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
437        let config = self.get_cached_config();
438        let info = modkit::api::OpenApiInfo {
439            title: config.openapi.title.clone(),
440            version: config.openapi.version.clone(),
441            description: config.openapi.description,
442        };
443        self.openapi_registry.build_openapi(&info)
444    }
445
446    /// Parse bind address from configuration string.
447    fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
448        bind_addr
449            .parse()
450            .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
451    }
452
453    /// Get the finalized router or build a default one.
454    fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
455        let stored = { self.final_router.lock().take() };
456
457        if let Some(router) = stored {
458            tracing::debug!("Using router from REST phase");
459            Ok(router)
460        } else {
461            tracing::debug!("No router from REST phase, building default router");
462            self.build_router()
463        }
464    }
465
466    /// Background HTTP server: bind, notify ready, serve until cancelled.
467    ///
468    /// This method is the lifecycle entry-point generated by the macro
469    /// (`#[modkit::module(..., lifecycle(...))]`).
470    pub(crate) async fn serve(
471        self: Arc<Self>,
472        cancel: CancellationToken,
473        ready: ReadySignal,
474    ) -> anyhow::Result<()> {
475        let cfg = self.get_cached_config();
476        let addr = Self::parse_bind_address(&cfg.bind_addr)?;
477        let router = self.get_or_build_router()?;
478
479        // Bind the socket, only now consider the service "ready"
480        let listener = tokio::net::TcpListener::bind(addr).await?;
481        tracing::info!("HTTP server bound on {}", addr);
482        ready.notify(); // Starting -> Running
483
484        // Graceful shutdown on cancel
485        let shutdown = {
486            let cancel = cancel.clone();
487            async move {
488                cancel.cancelled().await;
489                tracing::info!("HTTP server shutting down gracefully (cancellation)");
490            }
491        };
492
493        axum::serve(listener, router)
494            .with_graceful_shutdown(shutdown)
495            .await
496            .map_err(|e| anyhow::anyhow!(e))
497    }
498
499    /// Check if `handler_id` is already registered (returns true if duplicate)
500    fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
501        if self
502            .registered_handlers
503            .insert(spec.handler_id.clone(), ())
504            .is_some()
505        {
506            tracing::error!(
507                handler_id = %spec.handler_id,
508                method = %spec.method.as_str(),
509                path = %spec.path,
510                "Duplicate handler_id detected; ignoring subsequent registration"
511            );
512            return true;
513        }
514        false
515    }
516
517    /// Check if route (method, path) is already registered (returns true if duplicate)
518    fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
519        let route_key = (spec.method.clone(), spec.path.clone());
520        if self.registered_routes.insert(route_key, ()).is_some() {
521            tracing::error!(
522                method = %spec.method.as_str(),
523                path = %spec.path,
524                "Duplicate (method, path) detected; ignoring subsequent registration"
525            );
526            return true;
527        }
528        false
529    }
530
531    /// Log successful operation registration
532    fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
533        let current_count = self.openapi_registry.operation_specs.len();
534        tracing::debug!(
535            handler_id = %spec.handler_id,
536            method = %spec.method.as_str(),
537            path = %spec.path,
538            summary = %spec.summary.as_deref().unwrap_or("No summary"),
539            total_operations = current_count,
540            "Registered API operation"
541        );
542    }
543
544    /// Add `OpenAPI` documentation routes to the router
545    fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
546        // Build once, serve as static JSON (no per-request parsing)
547        let op_count = self.openapi_registry.operation_specs.len();
548        tracing::info!(
549            "rest_finalize: emitting OpenAPI with {} operations",
550            op_count
551        );
552
553        let openapi_doc = Arc::new(self.build_openapi()?);
554        let config = self.get_cached_config();
555        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
556        let html_doc = web::serve_docs(&prefix);
557
558        router = router
559            .route(
560                "/openapi.json",
561                get({
562                    use axum::{http::header, response::IntoResponse};
563                    let doc = openapi_doc;
564                    move || async move {
565                        let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
566                            Ok(json) => json,
567                            Err(e) => {
568                                tracing::error!("Failed to serialize OpenAPI doc: {}", e);
569                                return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
570                            }
571                        };
572                        (
573                            [
574                                (header::CONTENT_TYPE, "application/json"),
575                                (header::CACHE_CONTROL, "no-store"),
576                            ],
577                            json_string,
578                        )
579                            .into_response()
580                    }
581                }),
582            )
583            .route("/docs", get(move || async move { html_doc }));
584
585        #[cfg(feature = "embed_elements")]
586        {
587            router = router.route(
588                "/docs/assets/{*file}",
589                get(crate::assets::serve_elements_asset),
590            );
591        }
592
593        Ok(router)
594    }
595}
596
597// Manual implementation of Module trait with config loading
598#[async_trait]
599impl modkit::Module for ApiGateway {
600    async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
601        let cfg = ctx.config::<crate::config::ApiGatewayConfig>()?;
602        self.config.store(Arc::new(cfg.clone()));
603
604        debug!(
605            "Effective api_gateway configuration:\n{:#?}",
606            self.config.load()
607        );
608
609        if cfg.auth_disabled {
610            tracing::info!(
611                tenant_id = %DEFAULT_TENANT_ID,
612                "Auth-disabled mode enabled with default tenant"
613            );
614        } else {
615            // Resolve AuthN Resolver client from ClientHub
616            let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
617            *self.authn_client.lock() = Some(authn_client);
618            tracing::info!("AuthN Resolver client resolved from ClientHub");
619        }
620
621        Ok(())
622    }
623}
624
625// REST host role: prepare/finalize the router, but do not start the server here.
626impl modkit::contracts::ApiGatewayCapability for ApiGateway {
627    fn rest_prepare(
628        &self,
629        _ctx: &modkit::context::ModuleCtx,
630        router: axum::Router,
631    ) -> anyhow::Result<axum::Router> {
632        // Add health check endpoints:
633        // - /health: detailed JSON response with status and timestamp
634        // - /healthz: simple "ok" liveness probe (Kubernetes-style)
635        let router = router
636            .route("/health", get(web::health_check))
637            .route("/healthz", get(|| async { "ok" }));
638
639        // You may attach global middlewares here (trace, compression, cors), but do not start server.
640        tracing::debug!("REST host prepared base router with health check endpoints");
641        Ok(router)
642    }
643
644    fn rest_finalize(
645        &self,
646        _ctx: &modkit::context::ModuleCtx,
647        mut router: axum::Router,
648    ) -> anyhow::Result<axum::Router> {
649        let config = self.get_cached_config();
650
651        if config.enable_docs {
652            router = self.add_openapi_routes(router)?;
653        }
654
655        // Apply middleware stack (including auth) to the final router
656        tracing::debug!("Applying middleware stack to finalized router");
657        let authn_client = self.authn_client.lock().clone();
658        router = self.apply_middleware_stack(router, authn_client)?;
659
660        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
661        router = Self::apply_prefix_nesting(router, &prefix);
662
663        // Keep the finalized router to be used by `serve()`
664        *self.final_router.lock() = Some(router.clone());
665
666        tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
667        Ok(router)
668    }
669
670    fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
671        self
672    }
673}
674
675impl modkit::contracts::RestApiCapability for ApiGateway {
676    fn register_rest(
677        &self,
678        _ctx: &modkit::context::ModuleCtx,
679        router: axum::Router,
680        _openapi: &dyn modkit::contracts::OpenApiRegistry,
681    ) -> anyhow::Result<axum::Router> {
682        // This module acts as both rest_host and rest, but actual REST endpoints
683        // are handled in the host methods above.
684        Ok(router)
685    }
686}
687
688impl OpenApiRegistry for ApiGateway {
689    fn register_operation(&self, spec: &modkit::api::OperationSpec) {
690        // Reject duplicates with "first wins" policy (second registration = programmer error).
691        if self.check_duplicate_handler(spec) {
692            return;
693        }
694
695        if self.check_duplicate_route(spec) {
696            return;
697        }
698
699        // Delegate to the internal registry
700        self.openapi_registry.register_operation(spec);
701        self.log_operation_registration(spec);
702    }
703
704    fn ensure_schema_raw(
705        &self,
706        root_name: &str,
707        schemas: Vec<(
708            String,
709            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
710        )>,
711    ) -> String {
712        // Delegate to the internal registry
713        self.openapi_registry.ensure_schema_raw(root_name, schemas)
714    }
715
716    fn as_any(&self) -> &dyn std::any::Any {
717        self
718    }
719}
720
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// With no registered operations, the generated document must still contain the
    /// mandatory top-level sections and echo the configured `info` values.
    #[test]
    fn test_openapi_generation() {
        let config = {
            let mut cfg = ApiGatewayConfig::default();
            cfg.openapi.title = "Test API".to_owned();
            cfg.openapi.version = "1.0.0".to_owned();
            cfg.openapi.description = Some("Test Description".to_owned());
            cfg
        };
        let gateway = ApiGateway::new(config);

        let document = gateway.build_openapi().unwrap();
        let json = serde_json::to_value(&document).unwrap();

        for section in ["openapi", "info", "paths"] {
            assert!(json.get(section).is_some());
        }

        let info = &json["info"];
        assert_eq!(info["title"], "Test API");
        assert_eq!(info["version"], "1.0.0");
        assert_eq!(info["description"], "Test Description");
    }
}
750
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    use super::*;

    #[test]
    fn empty_string_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("").unwrap(), "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("/").unwrap(), "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("///").unwrap(), "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("   ").unwrap(), "");
    }

    #[test]
    fn simple_prefix_preserved() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf").unwrap(), "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf/").unwrap(), "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_eq!(ApiGateway::normalize_prefix_path("cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_eq!(ApiGateway::normalize_prefix_path("//cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api//v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("///api///v1///").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_eq!(ApiGateway::normalize_prefix_path("  /cf  ").unwrap(), "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn dot_in_path_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1.0").unwrap(),
            "/api/v1.0"
        );
    }

    #[test]
    fn dots_inside_segment_allowed() {
        // Only whole `.` / `..` segments are rejected, not dots within a segment.
        assert_eq!(ApiGateway::normalize_prefix_path("/a..b").unwrap(), "/a..b");
    }

    #[test]
    fn rejects_html_injection() {
        let result = ApiGateway::normalize_prefix_path(r#""><script>alert(1)</script>"#);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_spaces_in_path() {
        let result = ApiGateway::normalize_prefix_path("/my path");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_query_string_chars() {
        let result = ApiGateway::normalize_prefix_path("/api?foo=bar");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_percent_encoded_chars() {
        let result = ApiGateway::normalize_prefix_path("/api%2e%2e");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_dot_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/api/../v1");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_single_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/./api");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_bare_dot_dot() {
        let result = ApiGateway::normalize_prefix_path("..");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_trailing_dot_dot_segment() {
        // Stripping the trailing slash must not hide a `..` segment.
        let result = ApiGateway::normalize_prefix_path("/api/../");
        assert!(result.is_err());
    }
}
851
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
    //! Verifies that `problem_response` registers the shared `Problem` component schema and
    //! attaches an `application/problem+json` response that references it.

    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    /// Trivial JSON handler; only the operation spec it is registered with matters here.
    async fn dummy_handler() -> Json<Value> {
        Json(serde_json::json!({"ok": true}))
    }

    #[tokio::test]
    async fn openapi_includes_problem_schema_and_response() {
        let api = ApiGateway::default();

        // `problem_response` registers `Problem` and sets the problem+json content type.
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
            .public()
            .summary("Problem demo")
            .problem_response(&api, http::StatusCode::BAD_REQUEST, "Bad Request")
            .handler(dummy_handler)
            .register(axum::Router::new(), &api);

        let doc = api.build_openapi().expect("openapi");
        let spec = serde_json::to_value(&doc).expect("json");

        // 1) `Problem` is materialized as a concrete schema, not a self-referencing `$ref`.
        let problem_schema = spec
            .pointer("/components/schemas/Problem")
            .expect("Problem schema missing");
        assert!(
            problem_schema.get("$ref").is_none(),
            "Problem must be a real object, not a self-ref"
        );

        // 2) The 400 response exposes application/problem+json whose schema points at Problem.
        //    JSON Pointer escapes '/' in the path key as `~1`.
        let bad_request = spec
            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
            .expect("400 response missing");
        let media_types = bad_request
            .get("content")
            .expect("content object missing");
        let problem_media = media_types
            .get("application/problem+json")
            .unwrap_or_else(|| {
                panic!(
                    "application/problem+json content missing. Available content: {}",
                    serde_json::to_string_pretty(media_types).unwrap()
                )
            });

        let schema_ref = problem_media
            .pointer("/schema/$ref")
            .and_then(Value::as_str)
            .unwrap_or_default();
        assert_eq!(schema_ref, "#/components/schemas/Problem");
    }
}
913
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod sse_openapi_tests {
    //! OpenAPI generation tests covering SSE responses, path-parameter and wildcard
    //! conversion from Axum route syntax, and multipart file-upload request bodies.

    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    /// Event payload used by the SSE tests; expected to appear under `components/schemas`.
    #[derive(Clone)]
    #[modkit_macros::api_dto(request, response)]
    struct UserEvent {
        id: u32,
        message: String,
    }

    /// SSE handler built from a fresh `SseBroadcaster` that never publishes anything;
    /// only the operation spec registered for it is inspected by the tests.
    async fn sse_handler() -> axum::response::sse::Sse<
        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    > {
        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
        b.sse_response()
    }

    /// `sse_json` registers the event schema and a `text/event-stream` 200 response
    /// referencing it.
    #[tokio::test]
    async fn openapi_has_sse_content() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
            .summary("Demo SSE")
            .handler(sse_handler)
            .public()
            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        // schema is materialized
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());

        // content is text/event-stream with $ref to our schema
        let refp = v
            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
            .and_then(|x| x.as_str())
            .unwrap_or_default();
        assert_eq!(refp, "#/components/schemas/UserEvent");
    }

    /// Combining `json_response` with `sse_json` on one operation still yields a 200 response
    /// with content and registers the SSE event schema.
    #[tokio::test]
    async fn openapi_sse_additional_response() {
        async fn mixed_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
            .summary("Mixed responses")
            .public()
            .handler(mixed_handler)
            .json_response(http::StatusCode::OK, "Success response")
            .sse_json::<UserEvent>(&api, "Additional SSE stream")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let responses = v
            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
            .expect("responses");

        // A 200 response is present.
        assert!(responses.get("200").is_some());

        // The 200 response carries a content map. NOTE(review): this does not pin down
        // whether the SSE media type replaces or sits alongside the JSON one under 200.
        let response_content = responses.get("200").and_then(|r| r.get("content"));
        assert!(response_content.is_some());

        // UserEvent schema is registered
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
    }

    /// A single `{id}` path parameter is stored and emitted unchanged, because Axum 0.8
    /// and OpenAPI share the `{param}` syntax.
    #[tokio::test]
    async fn test_axum_to_openapi_path_conversion() {
        async fn user_handler() -> Json<Value> {
            Json(serde_json::json!({"user_id": "123"}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
            .summary("Get user by ID")
            .public()
            .path_param("id", "User ID")
            .handler(user_handler)
            .json_response(http::StatusCode::OK, "User details")
            .register(router, &api);

        // Verify the operation was stored with {id} path (same for Axum 0.8 and OpenAPI)
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/users/{id}");

        // Verify OpenAPI doc also has {id} (no conversion needed for regular params)
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/users/{id}").is_some(),
            "OpenAPI should use {{id}} placeholder"
        );
    }

    /// Multiple path parameters keep `{param}` syntax in both storage and the OpenAPI doc.
    #[tokio::test]
    async fn test_multiple_path_params_conversion() {
        async fn item_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get(
            "/tests/v1/projects/{project_id}/items/{item_id}",
        )
        .summary("Get project item")
        .public()
        .path_param("project_id", "Project ID")
        .path_param("item_id", "Item ID")
        .handler(item_handler)
        .json_response(http::StatusCode::OK, "Item details")
        .register(router, &api);

        // Verify storage and OpenAPI both use {param} syntax. Check the count first so a
        // missing registration fails with a clear assertion, not an index panic.
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(
            ops[0].path,
            "/tests/v1/projects/{project_id}/items/{item_id}"
        );

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths
                .get("/tests/v1/projects/{project_id}/items/{item_id}")
                .is_some()
        );
    }

    /// Axum wildcard `{*path}` is kept in storage but emitted as `{path}` in OpenAPI.
    #[tokio::test]
    async fn test_wildcard_path_conversion() {
        async fn static_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        // Axum 0.8 uses {*path} for wildcards
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
            .summary("Serve static files")
            .public()
            .handler(static_handler)
            .json_response(http::StatusCode::OK, "File content")
            .register(router, &api);

        // Verify internal storage keeps Axum wildcard syntax {*path}
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");

        // Verify OpenAPI converts wildcard to {path} (without asterisk)
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/static/{path}").is_some(),
            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
        );
        // Check the exact registered route. The former `/static/{*path}` key could never
        // exist, so that assertion passed unconditionally.
        assert!(
            paths.get("/tests/v1/static/{*path}").is_none(),
            "OpenAPI should not have Axum-style {{*path}}"
        );
    }

    /// `multipart_file_request` produces a `multipart/form-data` request body whose schema is
    /// an object with a required binary `file` property.
    #[tokio::test]
    async fn test_multipart_file_upload_openapi() {
        async fn upload_handler() -> Json<Value> {
            Json(serde_json::json!({"uploaded": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
            .operation_id("upload_file")
            .public()
            .summary("Upload a file")
            .multipart_file_request("file", Some("File to upload"))
            .handler(upload_handler)
            .json_response(http::StatusCode::OK, "Upload successful")
            .register(router, &api);

        // Build OpenAPI and verify multipart schema
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        let upload_path = paths
            .get("/tests/v1/files/upload")
            .expect("/tests/v1/files/upload path");
        let post_op = upload_path.get("post").expect("POST operation");

        // Verify request body exists
        let request_body = post_op.get("requestBody").expect("requestBody");
        let content = request_body.get("content").expect("content");
        let multipart = content
            .get("multipart/form-data")
            .expect("multipart/form-data content type");

        // Verify schema structure
        let schema = multipart.get("schema").expect("schema");
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be of type object"
        );

        // Verify properties
        let properties = schema.get("properties").expect("properties");
        let file_prop = properties.get("file").expect("file property");
        assert_eq!(
            file_prop.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "File field should be of type string"
        );
        assert_eq!(
            file_prop.get("format").and_then(|v| v.as_str()),
            Some("binary"),
            "File field should have format binary"
        );

        // Verify required fields
        let required = schema.get("required").expect("required");
        let required_arr = required.as_array().expect("required should be array");
        assert_eq!(required_arr.len(), 1);
        assert_eq!(required_arr[0].as_str(), Some("file"));
    }
}