Skip to main content

api_gateway/
module.rs

//! API Gateway Module definition
//!
//! Contains the `ApiGateway` module struct and its trait implementations.

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use authn_resolver_sdk::AuthNResolverClient;
use axum::http::Method;
use axum::middleware::from_fn_with_state;
use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
use dashmap::DashMap;
use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
use modkit::lifecycle::ReadySignal;
use modkit_security::SecurityContext;
use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
use parking_lot::Mutex;
use tokio_util::sync::CancellationToken;
use tower_http::{
    catch_panic::CatchPanicLayer,
    limit::RequestBodyLimitLayer,
    request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
    timeout::TimeoutLayer,
};
use tracing::debug;

use crate::config::ApiGatewayConfig;
use crate::middleware;
use crate::middleware::auth;
use crate::router_cache::RouterCache;
use crate::web;

40/// Main API Gateway module — owns the HTTP server (`rest_host`) and collects
41/// typed operation specs to emit a single `OpenAPI` document.
42#[modkit::module(
43	name = "api-gateway",
44	capabilities = [rest_host, rest, stateful],
45    deps = ["grpc-hub", "authn-resolver"],
46	lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
47)]
48pub struct ApiGateway {
49    // Lock-free config using arc-swap for read-mostly access
50    pub(crate) config: ArcSwap<ApiGatewayConfig>,
51    // OpenAPI registry for operations and schemas
52    pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
53    // Built router cache for zero-lock hot path access
54    pub(crate) router_cache: RouterCache<axum::Router>,
55    // Store the finalized router from REST phase for serving
56    pub(crate) final_router: Mutex<Option<axum::Router>>,
57    // AuthN Resolver client (resolved during init, None when auth_disabled)
58    pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,
59
60    // Duplicate detection (per (method, path) and per handler id)
61    pub(crate) registered_routes: DashMap<(Method, String), ()>,
62    pub(crate) registered_handlers: DashMap<String, ()>,
63}
64
65impl Default for ApiGateway {
66    fn default() -> Self {
67        let default_router = Router::new();
68        Self {
69            config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
70            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
71            router_cache: RouterCache::new(default_router),
72            final_router: Mutex::new(None),
73            authn_client: Mutex::new(None),
74            registered_routes: DashMap::new(),
75            registered_handlers: DashMap::new(),
76        }
77    }
78}
79
80impl ApiGateway {
81    fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
82        if prefix.is_empty() {
83            return router;
84        }
85
86        let top = Router::new()
87            .route("/health", get(web::health_check))
88            .route("/healthz", get(|| async { "ok" }));
89
90        router = Router::new().nest(prefix, router);
91        top.merge(router)
92    }
93
94    /// Create a new `ApiGateway` instance with the given configuration
95    #[must_use]
96    pub fn new(config: ApiGatewayConfig) -> Self {
97        let default_router = Router::new();
98        Self {
99            config: ArcSwap::from_pointee(config),
100            openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
101            router_cache: RouterCache::new(default_router),
102            final_router: Mutex::new(None),
103            authn_client: Mutex::new(None),
104            registered_routes: DashMap::new(),
105            registered_handlers: DashMap::new(),
106        }
107    }
108
109    /// Get the current configuration (cheap clone from `ArcSwap`)
110    pub fn get_config(&self) -> ApiGatewayConfig {
111        (**self.config.load()).clone()
112    }
113
114    /// Get cached configuration (lock-free with `ArcSwap`)
115    pub fn get_cached_config(&self) -> ApiGatewayConfig {
116        (**self.config.load()).clone()
117    }
118
119    /// Get the cached router without rebuilding (useful for performance-critical paths)
120    pub fn get_cached_router(&self) -> Arc<Router> {
121        self.router_cache.load()
122    }
123
124    /// Force rebuild and cache of the router.
125    ///
126    /// # Errors
127    /// Returns an error if router building fails.
128    pub fn rebuild_and_cache_router(&self) -> Result<()> {
129        let new_router = self.build_router()?;
130        self.router_cache.store(new_router);
131        Ok(())
132    }
133
134    /// Build route policy from operation specs.
135    fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
136        let mut authenticated_routes = std::collections::HashSet::new();
137        let mut public_routes = std::collections::HashSet::new();
138
139        // Always mark built-in health check routes as public
140        public_routes.insert((Method::GET, "/health".to_owned()));
141        public_routes.insert((Method::GET, "/healthz".to_owned()));
142
143        public_routes.insert((Method::GET, "/docs".to_owned()));
144        public_routes.insert((Method::GET, "/openapi.json".to_owned()));
145
146        for spec in &self.openapi_registry.operation_specs {
147            let spec = spec.value();
148
149            let route_key = (spec.method.clone(), spec.path.clone());
150
151            if spec.authenticated {
152                authenticated_routes.insert(route_key.clone());
153            }
154
155            if spec.is_public {
156                public_routes.insert(route_key);
157            }
158        }
159
160        let config = self.get_cached_config();
161        let requirements_count = authenticated_routes.len();
162        let public_routes_count = public_routes.len();
163
164        let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
165
166        tracing::info!(
167            auth_disabled = config.auth_disabled,
168            require_auth_by_default = config.require_auth_by_default,
169            requirements_count = requirements_count,
170            public_routes_count = public_routes_count,
171            "Route policy built from operation specs"
172        );
173
174        Ok(route_policy)
175    }
176
177    fn normalize_prefix_path(raw: &str) -> Result<String> {
178        let trimmed = raw.trim();
179        // Collapse consecutive slashes then strip trailing slash(es).
180        let collapsed: String =
181            trimmed
182                .chars()
183                .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
184                    if c == '/' && acc.ends_with('/') {
185                        // skip duplicate slash
186                    } else {
187                        acc.push(c);
188                    }
189                    acc
190                });
191        let prefix = collapsed.trim_end_matches('/');
192        let result = if prefix.is_empty() {
193            String::new()
194        } else if prefix.starts_with('/') {
195            prefix.to_owned()
196        } else {
197            format!("/{prefix}")
198        };
199        // Reject characters that are unsafe in URL paths or HTML attributes.
200        if !result
201            .bytes()
202            .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
203        {
204            anyhow::bail!(
205                "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
206            );
207        }
208
209        if result.split('/').any(|seg| seg == "." || seg == "..") {
210            anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
211        }
212
213        Ok(result)
214    }
215
216    /// Apply all middleware layers to a router (request ID, tracing, timeout, body limit, CORS, rate limiting, error mapping, auth)
217    pub(crate) fn apply_middleware_stack(
218        &self,
219        mut router: Router,
220        authn_client: Option<Arc<dyn AuthNResolverClient>>,
221    ) -> Result<Router> {
222        // Build route policy once
223        let route_policy = self.build_route_policy_from_specs()?;
224
225        // IMPORTANT: `axum::Router::layer(...)` behaves like Tower layers: the **last** added layer
226        // becomes the **outermost** layer and therefore runs **first** on the request path.
227        //
228        // Desired request execution order (outermost -> innermost):
229        // SetRequestId -> PropagateRequestId -> Trace -> push_req_id_to_extensions
230        // -> Timeout -> BodyLimit -> CORS -> MIME validation -> RateLimit -> ErrorMapping -> Auth -> ScopeEnforcement -> License -> Router
231        //
232        // Therefore we must add layers in the reverse order (innermost -> outermost) below.
233        // Due future refactoring, this order must be maintained.
234
235        // 14) Propagate MatchedPath to response extensions (route_layer — innermost).
236        // This copies MatchedPath from the request (populated by Axum route matching)
237        // into the response so outer layer() middleware (metrics) can read it.
238        router = router.route_layer(from_fn(middleware::http_metrics::propagate_matched_path));
239
240        let config = self.get_cached_config();
241
242        // Collect specs once; used by MIME validation + rate limiting maps.
243        let specs: Vec<_> = self
244            .openapi_registry
245            .operation_specs
246            .iter()
247            .map(|e| e.value().clone())
248            .collect();
249
250        // 12) License validation
251        let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
252
253        router = router.layer(from_fn(
254            move |req: axum::extract::Request, next: axum::middleware::Next| {
255                let map = license_map.clone();
256                middleware::license_validation::license_validation_middleware(map, req, next)
257            },
258        ));
259
260        // 11) Route Policy Enforcement (runs after auth, checks token_scopes against route requirements)
261        if config.route_policies.enabled {
262            // Reject invalid combination: route_policies requires authentication to work
263            if config.auth_disabled {
264                return Err(anyhow::anyhow!(
265                    "Invalid configuration: route_policies.enabled=true requires authentication. \
266                     Set auth_disabled=false or disable route_policies."
267                ));
268            }
269
270            let scope_rules = middleware::scope_enforcement::ScopeEnforcementRules::from_config(
271                &config.route_policies,
272            )?;
273            let scope_state =
274                middleware::scope_enforcement::ScopeEnforcementState { rules: scope_rules };
275            router = router.layer(from_fn_with_state(
276                scope_state,
277                middleware::scope_enforcement::scope_enforcement_middleware,
278            ));
279        }
280
281        // 10) Auth
282        if config.auth_disabled {
283            // Build security contexts for compatibility during migration
284            let default_security_context = SecurityContext::builder()
285                .subject_id(DEFAULT_SUBJECT_ID)
286                .subject_tenant_id(DEFAULT_TENANT_ID)
287                .build()?;
288
289            tracing::warn!(
290                "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
291                 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
292                 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
293            );
294            router = router.layer(from_fn(
295                move |mut req: axum::extract::Request, next: axum::middleware::Next| {
296                    let sec_context = default_security_context.clone();
297                    async move {
298                        req.extensions_mut().insert(sec_context);
299                        next.run(req).await
300                    }
301                },
302            ));
303        } else if let Some(client) = authn_client {
304            let auth_state = auth::AuthState {
305                authn_client: client,
306                route_policy,
307            };
308            router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
309        } else {
310            return Err(anyhow::anyhow!(
311                "auth is enabled but no AuthN Resolver client is available; \
312                 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
313            ));
314        }
315
316        // 11) Error mapping (outer to auth so it can translate auth/handler errors)
317        router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
318
319        // 10) Per-route rate limiting & in-flight limits
320        let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
321
322        router = router.layer(from_fn(
323            move |req: axum::extract::Request, next: axum::middleware::Next| {
324                let map = rate_map.clone();
325                middleware::rate_limit::rate_limit_middleware(map, req, next)
326            },
327        ));
328
329        // 9) MIME type validation
330        let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
331        router = router.layer(from_fn(
332            move |req: axum::extract::Request, next: axum::middleware::Next| {
333                let map = mime_map.clone();
334                middleware::mime_validation::mime_validation_middleware(map, req, next)
335            },
336        ));
337
338        // 8) CORS (must be outer to auth/limits so OPTIONS preflight short-circuits)
339        if config.cors_enabled {
340            router = router.layer(crate::cors::build_cors_layer(&config));
341        }
342
343        // 7) Body limit
344        router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
345        router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
346
347        // 6) Timeout
348        router = router.layer(TimeoutLayer::with_status_code(
349            axum::http::StatusCode::GATEWAY_TIMEOUT,
350            Duration::from_secs(30),
351        ));
352
353        // 5) CatchPanic (converts panics to 500 before metrics sees them)
354        router = router.layer(CatchPanicLayer::new());
355
356        // 4) HTTP metrics (layer — captures all middleware responses including auth/rate-limit/timeout)
357        let http_metrics = Arc::new(middleware::http_metrics::HttpMetrics::new(
358            Self::MODULE_NAME,
359            &config.metrics.prefix,
360        ));
361        router = router.layer(from_fn_with_state(
362            http_metrics,
363            middleware::http_metrics::http_metrics_middleware,
364        ));
365
366        // 3.5) Structured access log (runs after push_req_id populates XRequestId extension)
367        router = router.layer(from_fn(middleware::access_log::access_log_middleware));
368
369        // 3) Record request_id into span + extensions (requires span to exist first => must be inner to Trace)
370        router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
371
372        // 2) Trace (outer to push_req_id_to_extensions)
373        router = router.layer({
374            use modkit_http::otel;
375            use tower_http::trace::TraceLayer;
376            use tracing::field::Empty;
377
378            TraceLayer::new_for_http()
379                .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
380                    let hdr = middleware::request_id::header();
381                    let rid = req
382                        .headers()
383                        .get(&hdr)
384                        .and_then(|v| v.to_str().ok())
385                        .unwrap_or("n/a");
386
387                    let span = tracing::info_span!(
388                        "http_request",
389                        method = %req.method(),
390                        uri = %req.uri().path(),
391                        version = ?req.version(),
392                        module = "api_gateway",
393                        endpoint = %req.uri().path(),
394                        request_id = %rid,
395                        status = Empty,
396                        latency_ms = Empty,
397                        // OpenTelemetry semantic conventions
398                        "http.method" = %req.method(),
399                        "http.target" = %req.uri().path(),
400                        "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
401                        "http.host" = req.headers().get("host")
402                            .and_then(|h| h.to_str().ok())
403                            .unwrap_or("unknown"),
404                        "user_agent.original" = req.headers().get("user-agent")
405                            .and_then(|h| h.to_str().ok())
406                            .unwrap_or("unknown"),
407                        // Trace context placeholders (for log correlation)
408                        trace_id = Empty,
409                        parent.trace_id = Empty
410                    );
411
412                    // Set parent OTel trace context (W3C traceparent), if any
413                    // This also populates trace_id and parent.trace_id from headers
414                    otel::set_parent_from_headers(&span, req.headers());
415
416                    span
417                })
418                .on_response(
419                    |res: &axum::http::Response<axum::body::Body>,
420                     latency: std::time::Duration,
421                     span: &tracing::Span| {
422                        let ms = latency.as_millis();
423                        span.record("status", res.status().as_u16());
424                        span.record("latency_ms", ms);
425                    },
426                )
427        });
428
429        // 1) Request ID handling (outermost)
430        let x_request_id = crate::middleware::request_id::header();
431        // If missing, generate x-request-id first; then propagate it to the response.
432        router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
433        router = router.layer(SetRequestIdLayer::new(
434            x_request_id,
435            crate::middleware::request_id::MakeReqId,
436        ));
437
438        Ok(router)
439    }
440
441    /// Build the HTTP router from registered routes and operations.
442    ///
443    /// # Errors
444    /// Returns an error if router building or middleware setup fails.
445    pub fn build_router(&self) -> Result<Router> {
446        // If the cached router is currently held elsewhere (e.g., by the running server),
447        // return it without rebuilding to avoid unnecessary allocations.
448        let cached_router = self.router_cache.load();
449        if Arc::strong_count(&cached_router) > 1 {
450            tracing::debug!("Using cached router");
451            return Ok((*cached_router).clone());
452        }
453
454        tracing::debug!("Building new router (standalone/fallback mode)");
455        // In standalone mode (no REST pipeline), register both health endpoints here.
456        // In normal operation, rest_prepare() registers these instead.
457        let mut router = Router::new()
458            .route("/health", get(web::health_check))
459            .route("/healthz", get(|| async { "ok" }));
460
461        // Apply all middleware layers including auth, above the router
462        let authn_client = self.authn_client.lock().clone();
463        router = self.apply_middleware_stack(router, authn_client)?;
464
465        let config = self.get_cached_config();
466        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
467        router = Self::apply_prefix_nesting(router, &prefix);
468
469        // Cache the built router for future use
470        self.router_cache.store(router.clone());
471
472        Ok(router)
473    }
474
475    /// Build `OpenAPI` specification from registered routes and components.
476    ///
477    /// # Errors
478    /// Returns an error if `OpenAPI` specification building fails.
479    pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
480        let config = self.get_cached_config();
481        let info = modkit::api::OpenApiInfo {
482            title: config.openapi.title.clone(),
483            version: config.openapi.version.clone(),
484            description: config.openapi.description,
485        };
486        self.openapi_registry.build_openapi(&info)
487    }
488
489    /// Parse bind address from configuration string.
490    fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
491        bind_addr
492            .parse()
493            .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
494    }
495
496    /// Get the finalized router or build a default one.
497    fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
498        let stored = { self.final_router.lock().take() };
499
500        if let Some(router) = stored {
501            tracing::debug!("Using router from REST phase");
502            Ok(router)
503        } else {
504            tracing::debug!("No router from REST phase, building default router");
505            self.build_router()
506        }
507    }
508
509    /// Background HTTP server: bind, notify ready, serve until cancelled.
510    ///
511    /// This method is the lifecycle entry-point generated by the macro
512    /// (`#[modkit::module(..., lifecycle(...))]`).
513    pub(crate) async fn serve(
514        self: Arc<Self>,
515        cancel: CancellationToken,
516        ready: ReadySignal,
517    ) -> anyhow::Result<()> {
518        let cfg = self.get_cached_config();
519        let addr = Self::parse_bind_address(&cfg.bind_addr)?;
520        let router = self.get_or_build_router()?;
521
522        // Bind the socket, only now consider the service "ready"
523        let listener = tokio::net::TcpListener::bind(addr).await?;
524        tracing::info!("HTTP server bound on {}", addr);
525        ready.notify(); // Starting -> Running
526
527        // Graceful shutdown on cancel
528        let shutdown = {
529            let cancel = cancel.clone();
530            async move {
531                cancel.cancelled().await;
532                tracing::info!("HTTP server shutting down gracefully (cancellation)");
533            }
534        };
535
536        axum::serve(
537            listener,
538            router.into_make_service_with_connect_info::<SocketAddr>(),
539        )
540        .with_graceful_shutdown(shutdown)
541        .await
542        .map_err(|e| anyhow::anyhow!(e))
543    }
544
545    /// Check if `handler_id` is already registered (returns true if duplicate)
546    fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
547        if self
548            .registered_handlers
549            .insert(spec.handler_id.clone(), ())
550            .is_some()
551        {
552            tracing::error!(
553                handler_id = %spec.handler_id,
554                method = %spec.method.as_str(),
555                path = %spec.path,
556                "Duplicate handler_id detected; ignoring subsequent registration"
557            );
558            return true;
559        }
560        false
561    }
562
563    /// Check if route (method, path) is already registered (returns true if duplicate)
564    fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
565        let route_key = (spec.method.clone(), spec.path.clone());
566        if self.registered_routes.insert(route_key, ()).is_some() {
567            tracing::error!(
568                method = %spec.method.as_str(),
569                path = %spec.path,
570                "Duplicate (method, path) detected; ignoring subsequent registration"
571            );
572            return true;
573        }
574        false
575    }
576
577    /// Log successful operation registration
578    fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
579        let current_count = self.openapi_registry.operation_specs.len();
580        tracing::debug!(
581            handler_id = %spec.handler_id,
582            method = %spec.method.as_str(),
583            path = %spec.path,
584            summary = %spec.summary.as_deref().unwrap_or("No summary"),
585            total_operations = current_count,
586            "Registered API operation"
587        );
588    }
589
590    /// Add `OpenAPI` documentation routes to the router
591    fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
592        // Build once, serve as static JSON (no per-request parsing)
593        let op_count = self.openapi_registry.operation_specs.len();
594        tracing::info!(
595            "rest_finalize: emitting OpenAPI with {} operations",
596            op_count
597        );
598
599        let openapi_doc = Arc::new(self.build_openapi()?);
600        let config = self.get_cached_config();
601        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
602        let html_doc = web::serve_docs(&prefix);
603
604        router = router
605            .route(
606                "/openapi.json",
607                get({
608                    use axum::{http::header, response::IntoResponse};
609                    let doc = openapi_doc;
610                    move || async move {
611                        let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
612                            Ok(json) => json,
613                            Err(e) => {
614                                tracing::error!("Failed to serialize OpenAPI doc: {}", e);
615                                return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
616                            }
617                        };
618                        (
619                            [
620                                (header::CONTENT_TYPE, "application/json"),
621                                (header::CACHE_CONTROL, "no-store"),
622                            ],
623                            json_string,
624                        )
625                            .into_response()
626                    }
627                }),
628            )
629            .route("/docs", get(move || async move { html_doc }));
630
631        #[cfg(feature = "embed_elements")]
632        {
633            router = router.route(
634                "/docs/assets/{*file}",
635                get(crate::assets::serve_elements_asset),
636            );
637        }
638
639        Ok(router)
640    }
641}
642
643// Manual implementation of Module trait with config loading
644#[async_trait]
645impl modkit::Module for ApiGateway {
646    async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
647        let cfg = ctx.config_or_default::<crate::config::ApiGatewayConfig>()?;
648        self.config.store(Arc::new(cfg.clone()));
649
650        debug!(
651            "Effective api_gateway configuration:\n{:#?}",
652            self.config.load()
653        );
654
655        if cfg.auth_disabled {
656            tracing::info!(
657                tenant_id = %DEFAULT_TENANT_ID,
658                "Auth-disabled mode enabled with default tenant"
659            );
660        } else {
661            // Resolve AuthN Resolver client from ClientHub
662            let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
663            *self.authn_client.lock() = Some(authn_client);
664            tracing::info!("AuthN Resolver client resolved from ClientHub");
665        }
666
667        Ok(())
668    }
669}
670
671// REST host role: prepare/finalize the router, but do not start the server here.
672impl modkit::contracts::ApiGatewayCapability for ApiGateway {
673    fn rest_prepare(
674        &self,
675        _ctx: &modkit::context::ModuleCtx,
676        router: axum::Router,
677    ) -> anyhow::Result<axum::Router> {
678        // Add health check endpoints:
679        // - /health: detailed JSON response with status and timestamp
680        // - /healthz: simple "ok" liveness probe (Kubernetes-style)
681        let router = router
682            .route("/health", get(web::health_check))
683            .route("/healthz", get(|| async { "ok" }));
684
685        // You may attach global middlewares here (trace, compression, cors), but do not start server.
686        tracing::debug!("REST host prepared base router with health check endpoints");
687        Ok(router)
688    }
689
690    fn rest_finalize(
691        &self,
692        _ctx: &modkit::context::ModuleCtx,
693        mut router: axum::Router,
694    ) -> anyhow::Result<axum::Router> {
695        let config = self.get_cached_config();
696
697        if config.enable_docs {
698            router = self.add_openapi_routes(router)?;
699        }
700
701        // Apply middleware stack (including auth) to the final router
702        tracing::debug!("Applying middleware stack to finalized router");
703        let authn_client = self.authn_client.lock().clone();
704        router = self.apply_middleware_stack(router, authn_client)?;
705
706        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
707        router = Self::apply_prefix_nesting(router, &prefix);
708
709        // Keep the finalized router to be used by `serve()`
710        *self.final_router.lock() = Some(router.clone());
711
712        tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
713        Ok(router)
714    }
715
716    fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
717        self
718    }
719}
720
721impl modkit::contracts::RestApiCapability for ApiGateway {
722    fn register_rest(
723        &self,
724        _ctx: &modkit::context::ModuleCtx,
725        router: axum::Router,
726        _openapi: &dyn modkit::contracts::OpenApiRegistry,
727    ) -> anyhow::Result<axum::Router> {
728        // This module acts as both rest_host and rest, but actual REST endpoints
729        // are handled in the host methods above.
730        Ok(router)
731    }
732}
733
734impl OpenApiRegistry for ApiGateway {
735    fn register_operation(&self, spec: &modkit::api::OperationSpec) {
736        // Reject duplicates with "first wins" policy (second registration = programmer error).
737        if self.check_duplicate_handler(spec) {
738            return;
739        }
740
741        if self.check_duplicate_route(spec) {
742            return;
743        }
744
745        // Delegate to the internal registry
746        self.openapi_registry.register_operation(spec);
747        self.log_operation_registration(spec);
748    }
749
750    fn ensure_schema_raw(
751        &self,
752        root_name: &str,
753        schemas: Vec<(
754            String,
755            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
756        )>,
757    ) -> String {
758        // Delegate to the internal registry
759        self.openapi_registry.ensure_schema_raw(root_name, schemas)
760    }
761
762    fn as_any(&self) -> &dyn std::any::Any {
763        self
764    }
765}
766
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// A gateway with no registered operations still produces a structurally
    /// valid `OpenAPI` document whose `info` block mirrors the configuration.
    #[test]
    fn test_openapi_generation() {
        let config = {
            let mut cfg = ApiGatewayConfig::default();
            cfg.openapi.title = "Test API".to_owned();
            cfg.openapi.version = "1.0.0".to_owned();
            cfg.openapi.description = Some("Test Description".to_owned());
            cfg
        };
        let gateway = ApiGateway::new(config);

        // Building must succeed even with zero operations.
        let document = serde_json::to_value(&gateway.build_openapi().unwrap()).unwrap();

        // Top-level sections every document must carry.
        for section in ["openapi", "info", "paths"] {
            assert!(document.get(section).is_some());
        }

        // `info` reflects the configured metadata.
        let info = &document["info"];
        assert_eq!(info["title"], "Test API");
        assert_eq!(info["version"], "1.0.0");
        assert_eq!(info["description"], "Test Description");
    }
}
796
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    use super::*;

    /// Asserts that `input` is accepted and normalizes to exactly `expected`.
    #[track_caller]
    fn assert_normalizes(input: &str, expected: &str) {
        let normalized = ApiGateway::normalize_prefix_path(input).unwrap();
        assert_eq!(normalized, expected);
    }

    /// Asserts that `input` is rejected as an invalid prefix.
    #[track_caller]
    fn assert_rejected(input: &str) {
        assert!(ApiGateway::normalize_prefix_path(input).is_err());
    }

    // --- inputs that collapse to "no prefix" ---

    #[test]
    fn empty_string_returns_empty() {
        assert_normalizes("", "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_normalizes("/", "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_normalizes("///", "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_normalizes("   ", "");
    }

    // --- canonicalization of valid prefixes ---

    #[test]
    fn simple_prefix_preserved() {
        assert_normalizes("/cf", "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_normalizes("/cf/", "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_normalizes("cf", "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_normalizes("//cf", "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_normalizes("/api//v1", "/api/v1");
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_normalizes("///api///v1///", "/api/v1");
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_normalizes("  /cf  ", "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_normalizes("/api/v1", "/api/v1");
    }

    #[test]
    fn dot_in_path_allowed() {
        assert_normalizes("/api/v1.0", "/api/v1.0");
    }

    // --- inputs that must be rejected ---

    #[test]
    fn rejects_html_injection() {
        assert_rejected(r#""><script>alert(1)</script>"#);
    }

    #[test]
    fn rejects_spaces_in_path() {
        assert_rejected("/my path");
    }

    #[test]
    fn rejects_query_string_chars() {
        assert_rejected("/api?foo=bar");
    }
}
897
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    async fn dummy_handler() -> Json<Value> {
        Json(serde_json::json!({"ok": true}))
    }

    /// A `problem_response` registration must both materialize the `Problem`
    /// schema and wire the 400 response to it under the
    /// `application/problem+json` media type.
    #[tokio::test]
    async fn openapi_includes_problem_schema_and_response() {
        let gateway = ApiGateway::default();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
            .public()
            .summary("Problem demo")
            // Registers the `Problem` schema and sets the problem+json content type.
            .problem_response(&gateway, http::StatusCode::BAD_REQUEST, "Bad Request")
            .handler(dummy_handler)
            .register(axum::Router::new(), &gateway);

        let spec = serde_json::to_value(&gateway.build_openapi().expect("openapi")).expect("json");

        // The `Problem` schema exists in components.schemas as a concrete object.
        let problem_schema = spec
            .pointer("/components/schemas/Problem")
            .expect("Problem schema missing");
        assert!(
            problem_schema.get("$ref").is_none(),
            "Problem must be a real object, not a self-ref"
        );

        // The 400 response advertises problem+json content...
        let bad_request = spec
            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
            .expect("400 response missing");
        let media_types = bad_request.get("content").expect("content object missing");
        assert!(
            media_types.get("application/problem+json").is_some(),
            "application/problem+json content missing. Available content: {}",
            serde_json::to_string_pretty(media_types).unwrap()
        );

        // ...whose schema is a `$ref` to `Problem`.
        let problem_json = bad_request
            .pointer("/content/application~1problem+json")
            .expect("application/problem+json content missing");
        let schema_ref = problem_json
            .pointer("/schema/$ref")
            .and_then(Value::as_str)
            .unwrap_or("");
        assert_eq!(schema_ref, "#/components/schemas/Problem");
    }
}
959
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod sse_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    /// Minimal DTO used as the SSE event payload schema in these tests.
    #[derive(Clone)]
    #[modkit_macros::api_dto(request, response)]
    struct UserEvent {
        id: u32,
        message: String,
    }

    /// SSE handler backed by a small broadcaster. It is registered but never
    /// invoked by these tests; only the generated `OpenAPI` document is inspected.
    async fn sse_handler() -> axum::response::sse::Sse<
        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    > {
        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
        b.sse_response()
    }

    #[tokio::test]
    async fn openapi_has_sse_content() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
            .summary("Demo SSE")
            .handler(sse_handler)
            .public()
            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        // schema is materialized
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());

        // content is text/event-stream with $ref to our schema
        let refp = v
            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
            .and_then(|x| x.as_str())
            .unwrap_or_default();
        assert_eq!(refp, "#/components/schemas/UserEvent");
    }

    #[tokio::test]
    async fn openapi_sse_additional_response() {
        async fn mixed_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
            .summary("Mixed responses")
            .public()
            .handler(mixed_handler)
            .json_response(http::StatusCode::OK, "Success response")
            .sse_json::<UserEvent>(&api, "Additional SSE stream")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        // Check that both response types are present
        let responses = v
            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
            .expect("responses");

        // JSON response exists
        assert!(responses.get("200").is_some());

        // The 200 response carries a content map. This only checks that some
        // content exists; how the JSON and SSE media types are merged under the
        // shared status code is left to the builder.
        let response_content = responses.get("200").and_then(|r| r.get("content"));
        assert!(response_content.is_some());

        // UserEvent schema is registered
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
    }

    #[tokio::test]
    async fn test_axum_to_openapi_path_conversion() {
        // Define a route with path parameters using Axum 0.8+ style {id}
        async fn user_handler() -> Json<Value> {
            Json(serde_json::json!({"user_id": "123"}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
            .summary("Get user by ID")
            .public()
            .path_param("id", "User ID")
            .handler(user_handler)
            .json_response(http::StatusCode::OK, "User details")
            .register(router, &api);

        // Verify the operation was stored with {id} path (same for Axum 0.8 and OpenAPI)
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/users/{id}");

        // Verify OpenAPI doc also has {id} (no conversion needed for regular params)
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/users/{id}").is_some(),
            "OpenAPI should use {{id}} placeholder"
        );
    }

    #[tokio::test]
    async fn test_multiple_path_params_conversion() {
        async fn item_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get(
            "/tests/v1/projects/{project_id}/items/{item_id}",
        )
        .summary("Get project item")
        .public()
        .path_param("project_id", "Project ID")
        .path_param("item_id", "Item ID")
        .handler(item_handler)
        .json_response(http::StatusCode::OK, "Item details")
        .register(router, &api);

        // Verify storage and OpenAPI both use {param} syntax
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        // Pin the count first so a missing registration fails clearly instead
        // of panicking on the `ops[0]` index below.
        assert_eq!(ops.len(), 1);
        assert_eq!(
            ops[0].path,
            "/tests/v1/projects/{project_id}/items/{item_id}"
        );

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths
                .get("/tests/v1/projects/{project_id}/items/{item_id}")
                .is_some()
        );
    }

    #[tokio::test]
    async fn test_wildcard_path_conversion() {
        async fn static_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        // Axum 0.8 uses {*path} for wildcards
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
            .summary("Serve static files")
            .public()
            .handler(static_handler)
            .json_response(http::StatusCode::OK, "File content")
            .register(router, &api);

        // Verify internal storage keeps Axum wildcard syntax {*path}
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        // Pin the count first so a missing registration fails clearly instead
        // of panicking on the `ops[0]` index below.
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");

        // Verify OpenAPI converts wildcard to {path} (without asterisk)
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/static/{path}").is_some(),
            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
        );
        // Must check the full registered path: a bare "/static/{*path}" is
        // never registered, so asserting on it would pass vacuously.
        assert!(
            paths.get("/tests/v1/static/{*path}").is_none(),
            "OpenAPI should not have Axum-style {{*path}}"
        );
    }

    #[tokio::test]
    async fn test_multipart_file_upload_openapi() {
        async fn upload_handler() -> Json<Value> {
            Json(serde_json::json!({"uploaded": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
            .operation_id("upload_file")
            .public()
            .summary("Upload a file")
            .multipart_file_request("file", Some("File to upload"))
            .handler(upload_handler)
            .json_response(http::StatusCode::OK, "Upload successful")
            .register(router, &api);

        // Build OpenAPI and verify multipart schema
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        let upload_path = paths
            .get("/tests/v1/files/upload")
            .expect("/tests/v1/files/upload path");
        let post_op = upload_path.get("post").expect("POST operation");

        // Verify request body exists
        let request_body = post_op.get("requestBody").expect("requestBody");
        let content = request_body.get("content").expect("content");
        let multipart = content
            .get("multipart/form-data")
            .expect("multipart/form-data content type");

        // Verify schema structure
        let schema = multipart.get("schema").expect("schema");
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be of type object"
        );

        // Verify properties
        let properties = schema.get("properties").expect("properties");
        let file_prop = properties.get("file").expect("file property");
        assert_eq!(
            file_prop.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "File field should be of type string"
        );
        assert_eq!(
            file_prop.get("format").and_then(|v| v.as_str()),
            Some("binary"),
            "File field should have format binary"
        );

        // Verify required fields
        let required = schema.get("required").expect("required");
        let required_arr = required.as_array().expect("required should be array");
        assert_eq!(required_arr.len(), 1);
        assert_eq!(required_arr[0].as_str(), Some("file"));
    }
}