1use async_trait::async_trait;
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use dashmap::DashMap;
10
11use anyhow::Result;
12use axum::http::Method;
13use axum::middleware::from_fn_with_state;
14use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
15use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
16use modkit::lifecycle::ReadySignal;
17use parking_lot::Mutex;
18use std::net::SocketAddr;
19use std::time::Duration;
20use tokio_util::sync::CancellationToken;
21use tower_http::{
22 catch_panic::CatchPanicLayer,
23 limit::RequestBodyLimitLayer,
24 request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
25 timeout::TimeoutLayer,
26};
27use tracing::debug;
28
29use authn_resolver_sdk::AuthNResolverClient;
30
31use crate::config::ApiGatewayConfig;
32use crate::middleware::auth;
33use modkit_security::SecurityContext;
34use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
35
36use crate::middleware;
37use crate::router_cache::RouterCache;
38use crate::web;
39
40#[modkit::module(
43 name = "api-gateway",
44 capabilities = [rest_host, rest, stateful],
45 deps = ["grpc-hub", "authn-resolver"],
46 lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
47)]
48pub struct ApiGateway {
49 pub(crate) config: ArcSwap<ApiGatewayConfig>,
51 pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
53 pub(crate) router_cache: RouterCache<axum::Router>,
55 pub(crate) final_router: Mutex<Option<axum::Router>>,
57 pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,
59
60 pub(crate) registered_routes: DashMap<(Method, String), ()>,
62 pub(crate) registered_handlers: DashMap<String, ()>,
63}
64
65impl Default for ApiGateway {
66 fn default() -> Self {
67 let default_router = Router::new();
68 Self {
69 config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
70 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
71 router_cache: RouterCache::new(default_router),
72 final_router: Mutex::new(None),
73 authn_client: Mutex::new(None),
74 registered_routes: DashMap::new(),
75 registered_handlers: DashMap::new(),
76 }
77 }
78}
79
80impl ApiGateway {
81 fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
82 if prefix.is_empty() {
83 return router;
84 }
85
86 let top = Router::new()
87 .route("/health", get(web::health_check))
88 .route("/healthz", get(|| async { "ok" }));
89
90 router = Router::new().nest(prefix, router);
91 top.merge(router)
92 }
93
94 #[must_use]
96 pub fn new(config: ApiGatewayConfig) -> Self {
97 let default_router = Router::new();
98 Self {
99 config: ArcSwap::from_pointee(config),
100 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
101 router_cache: RouterCache::new(default_router),
102 final_router: Mutex::new(None),
103 authn_client: Mutex::new(None),
104 registered_routes: DashMap::new(),
105 registered_handlers: DashMap::new(),
106 }
107 }
108
109 pub fn get_config(&self) -> ApiGatewayConfig {
111 (**self.config.load()).clone()
112 }
113
114 pub fn get_cached_config(&self) -> ApiGatewayConfig {
116 (**self.config.load()).clone()
117 }
118
119 pub fn get_cached_router(&self) -> Arc<Router> {
121 self.router_cache.load()
122 }
123
124 pub fn rebuild_and_cache_router(&self) -> Result<()> {
129 let new_router = self.build_router()?;
130 self.router_cache.store(new_router);
131 Ok(())
132 }
133
134 fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
136 let mut authenticated_routes = std::collections::HashSet::new();
137 let mut public_routes = std::collections::HashSet::new();
138
139 public_routes.insert((Method::GET, "/health".to_owned()));
141 public_routes.insert((Method::GET, "/healthz".to_owned()));
142
143 public_routes.insert((Method::GET, "/docs".to_owned()));
144 public_routes.insert((Method::GET, "/openapi.json".to_owned()));
145
146 for spec in &self.openapi_registry.operation_specs {
147 let spec = spec.value();
148
149 let route_key = (spec.method.clone(), spec.path.clone());
150
151 if spec.authenticated {
152 authenticated_routes.insert(route_key.clone());
153 }
154
155 if spec.is_public {
156 public_routes.insert(route_key);
157 }
158 }
159
160 let config = self.get_cached_config();
161 let requirements_count = authenticated_routes.len();
162 let public_routes_count = public_routes.len();
163
164 let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
165
166 tracing::info!(
167 auth_disabled = config.auth_disabled,
168 require_auth_by_default = config.require_auth_by_default,
169 requirements_count = requirements_count,
170 public_routes_count = public_routes_count,
171 "Route policy built from operation specs"
172 );
173
174 Ok(route_policy)
175 }
176
177 fn normalize_prefix_path(raw: &str) -> Result<String> {
178 let trimmed = raw.trim();
179 let collapsed: String =
181 trimmed
182 .chars()
183 .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
184 if c == '/' && acc.ends_with('/') {
185 } else {
187 acc.push(c);
188 }
189 acc
190 });
191 let prefix = collapsed.trim_end_matches('/');
192 let result = if prefix.is_empty() {
193 String::new()
194 } else if prefix.starts_with('/') {
195 prefix.to_owned()
196 } else {
197 format!("/{prefix}")
198 };
199 if !result
201 .bytes()
202 .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
203 {
204 anyhow::bail!(
205 "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
206 );
207 }
208
209 if result.split('/').any(|seg| seg == "." || seg == "..") {
210 anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
211 }
212
213 Ok(result)
214 }
215
216 pub(crate) fn apply_middleware_stack(
218 &self,
219 mut router: Router,
220 authn_client: Option<Arc<dyn AuthNResolverClient>>,
221 ) -> Result<Router> {
222 let route_policy = self.build_route_policy_from_specs()?;
224
225 router = router.route_layer(from_fn(middleware::http_metrics::propagate_matched_path));
241
242 let config = self.get_cached_config();
243
244 let specs: Vec<_> = self
246 .openapi_registry
247 .operation_specs
248 .iter()
249 .map(|e| e.value().clone())
250 .collect();
251
252 let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
254
255 router = router.layer(from_fn(
256 move |req: axum::extract::Request, next: axum::middleware::Next| {
257 let map = license_map.clone();
258 middleware::license_validation::license_validation_middleware(map, req, next)
259 },
260 ));
261
262 if config.auth_disabled {
264 let default_security_context = SecurityContext::builder()
266 .subject_id(DEFAULT_SUBJECT_ID)
267 .subject_tenant_id(DEFAULT_TENANT_ID)
268 .build()?;
269
270 tracing::warn!(
271 "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
272 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
273 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
274 );
275 router = router.layer(from_fn(
276 move |mut req: axum::extract::Request, next: axum::middleware::Next| {
277 let sec_context = default_security_context.clone();
278 async move {
279 req.extensions_mut().insert(sec_context);
280 next.run(req).await
281 }
282 },
283 ));
284 } else if let Some(client) = authn_client {
285 let auth_state = auth::AuthState {
286 authn_client: client,
287 route_policy,
288 };
289 router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
290 } else {
291 return Err(anyhow::anyhow!(
292 "auth is enabled but no AuthN Resolver client is available; \
293 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
294 ));
295 }
296
297 router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
299
300 let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
302
303 router = router.layer(from_fn(
304 move |req: axum::extract::Request, next: axum::middleware::Next| {
305 let map = rate_map.clone();
306 middleware::rate_limit::rate_limit_middleware(map, req, next)
307 },
308 ));
309
310 let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
312 router = router.layer(from_fn(
313 move |req: axum::extract::Request, next: axum::middleware::Next| {
314 let map = mime_map.clone();
315 middleware::mime_validation::mime_validation_middleware(map, req, next)
316 },
317 ));
318
319 if config.cors_enabled {
321 router = router.layer(crate::cors::build_cors_layer(&config));
322 }
323
324 router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
326 router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
327
328 router = router.layer(TimeoutLayer::with_status_code(
330 axum::http::StatusCode::GATEWAY_TIMEOUT,
331 Duration::from_secs(30),
332 ));
333
334 router = router.layer(CatchPanicLayer::new());
336
337 let http_metrics = Arc::new(middleware::http_metrics::HttpMetrics::new(
339 Self::MODULE_NAME,
340 &config.metrics.prefix,
341 ));
342 router = router.layer(from_fn_with_state(
343 http_metrics,
344 middleware::http_metrics::http_metrics_middleware,
345 ));
346
347 router = router.layer(from_fn(middleware::access_log::access_log_middleware));
349
350 router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
352
353 router = router.layer({
355 use modkit_http::otel;
356 use tower_http::trace::TraceLayer;
357 use tracing::field::Empty;
358
359 TraceLayer::new_for_http()
360 .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
361 let hdr = middleware::request_id::header();
362 let rid = req
363 .headers()
364 .get(&hdr)
365 .and_then(|v| v.to_str().ok())
366 .unwrap_or("n/a");
367
368 let span = tracing::info_span!(
369 "http_request",
370 method = %req.method(),
371 uri = %req.uri().path(),
372 version = ?req.version(),
373 module = "api_gateway",
374 endpoint = %req.uri().path(),
375 request_id = %rid,
376 status = Empty,
377 latency_ms = Empty,
378 "http.method" = %req.method(),
380 "http.target" = %req.uri().path(),
381 "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
382 "http.host" = req.headers().get("host")
383 .and_then(|h| h.to_str().ok())
384 .unwrap_or("unknown"),
385 "user_agent.original" = req.headers().get("user-agent")
386 .and_then(|h| h.to_str().ok())
387 .unwrap_or("unknown"),
388 trace_id = Empty,
390 parent.trace_id = Empty
391 );
392
393 otel::set_parent_from_headers(&span, req.headers());
396
397 span
398 })
399 .on_response(
400 |res: &axum::http::Response<axum::body::Body>,
401 latency: std::time::Duration,
402 span: &tracing::Span| {
403 let ms = latency.as_millis();
404 span.record("status", res.status().as_u16());
405 span.record("latency_ms", ms);
406 },
407 )
408 });
409
410 let x_request_id = crate::middleware::request_id::header();
412 router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
414 router = router.layer(SetRequestIdLayer::new(
415 x_request_id,
416 crate::middleware::request_id::MakeReqId,
417 ));
418
419 Ok(router)
420 }
421
422 pub fn build_router(&self) -> Result<Router> {
427 let cached_router = self.router_cache.load();
430 if Arc::strong_count(&cached_router) > 1 {
431 tracing::debug!("Using cached router");
432 return Ok((*cached_router).clone());
433 }
434
435 tracing::debug!("Building new router (standalone/fallback mode)");
436 let mut router = Router::new()
439 .route("/health", get(web::health_check))
440 .route("/healthz", get(|| async { "ok" }));
441
442 let authn_client = self.authn_client.lock().clone();
444 router = self.apply_middleware_stack(router, authn_client)?;
445
446 let config = self.get_cached_config();
447 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
448 router = Self::apply_prefix_nesting(router, &prefix);
449
450 self.router_cache.store(router.clone());
452
453 Ok(router)
454 }
455
456 pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
461 let config = self.get_cached_config();
462 let info = modkit::api::OpenApiInfo {
463 title: config.openapi.title.clone(),
464 version: config.openapi.version.clone(),
465 description: config.openapi.description,
466 };
467 self.openapi_registry.build_openapi(&info)
468 }
469
470 fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
472 bind_addr
473 .parse()
474 .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
475 }
476
477 fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
479 let stored = { self.final_router.lock().take() };
480
481 if let Some(router) = stored {
482 tracing::debug!("Using router from REST phase");
483 Ok(router)
484 } else {
485 tracing::debug!("No router from REST phase, building default router");
486 self.build_router()
487 }
488 }
489
490 pub(crate) async fn serve(
495 self: Arc<Self>,
496 cancel: CancellationToken,
497 ready: ReadySignal,
498 ) -> anyhow::Result<()> {
499 let cfg = self.get_cached_config();
500 let addr = Self::parse_bind_address(&cfg.bind_addr)?;
501 let router = self.get_or_build_router()?;
502
503 let listener = tokio::net::TcpListener::bind(addr).await?;
505 tracing::info!("HTTP server bound on {}", addr);
506 ready.notify(); let shutdown = {
510 let cancel = cancel.clone();
511 async move {
512 cancel.cancelled().await;
513 tracing::info!("HTTP server shutting down gracefully (cancellation)");
514 }
515 };
516
517 axum::serve(
518 listener,
519 router.into_make_service_with_connect_info::<SocketAddr>(),
520 )
521 .with_graceful_shutdown(shutdown)
522 .await
523 .map_err(|e| anyhow::anyhow!(e))
524 }
525
526 fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
528 if self
529 .registered_handlers
530 .insert(spec.handler_id.clone(), ())
531 .is_some()
532 {
533 tracing::error!(
534 handler_id = %spec.handler_id,
535 method = %spec.method.as_str(),
536 path = %spec.path,
537 "Duplicate handler_id detected; ignoring subsequent registration"
538 );
539 return true;
540 }
541 false
542 }
543
544 fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
546 let route_key = (spec.method.clone(), spec.path.clone());
547 if self.registered_routes.insert(route_key, ()).is_some() {
548 tracing::error!(
549 method = %spec.method.as_str(),
550 path = %spec.path,
551 "Duplicate (method, path) detected; ignoring subsequent registration"
552 );
553 return true;
554 }
555 false
556 }
557
558 fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
560 let current_count = self.openapi_registry.operation_specs.len();
561 tracing::debug!(
562 handler_id = %spec.handler_id,
563 method = %spec.method.as_str(),
564 path = %spec.path,
565 summary = %spec.summary.as_deref().unwrap_or("No summary"),
566 total_operations = current_count,
567 "Registered API operation"
568 );
569 }
570
571 fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
573 let op_count = self.openapi_registry.operation_specs.len();
575 tracing::info!(
576 "rest_finalize: emitting OpenAPI with {} operations",
577 op_count
578 );
579
580 let openapi_doc = Arc::new(self.build_openapi()?);
581 let config = self.get_cached_config();
582 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
583 let html_doc = web::serve_docs(&prefix);
584
585 router = router
586 .route(
587 "/openapi.json",
588 get({
589 use axum::{http::header, response::IntoResponse};
590 let doc = openapi_doc;
591 move || async move {
592 let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
593 Ok(json) => json,
594 Err(e) => {
595 tracing::error!("Failed to serialize OpenAPI doc: {}", e);
596 return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
597 }
598 };
599 (
600 [
601 (header::CONTENT_TYPE, "application/json"),
602 (header::CACHE_CONTROL, "no-store"),
603 ],
604 json_string,
605 )
606 .into_response()
607 }
608 }),
609 )
610 .route("/docs", get(move || async move { html_doc }));
611
612 #[cfg(feature = "embed_elements")]
613 {
614 router = router.route(
615 "/docs/assets/{*file}",
616 get(crate::assets::serve_elements_asset),
617 );
618 }
619
620 Ok(router)
621 }
622}
623
624#[async_trait]
626impl modkit::Module for ApiGateway {
627 async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
628 let cfg = ctx.config::<crate::config::ApiGatewayConfig>()?;
629 self.config.store(Arc::new(cfg.clone()));
630
631 debug!(
632 "Effective api_gateway configuration:\n{:#?}",
633 self.config.load()
634 );
635
636 if cfg.auth_disabled {
637 tracing::info!(
638 tenant_id = %DEFAULT_TENANT_ID,
639 "Auth-disabled mode enabled with default tenant"
640 );
641 } else {
642 let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
644 *self.authn_client.lock() = Some(authn_client);
645 tracing::info!("AuthN Resolver client resolved from ClientHub");
646 }
647
648 Ok(())
649 }
650}
651
652impl modkit::contracts::ApiGatewayCapability for ApiGateway {
654 fn rest_prepare(
655 &self,
656 _ctx: &modkit::context::ModuleCtx,
657 router: axum::Router,
658 ) -> anyhow::Result<axum::Router> {
659 let router = router
663 .route("/health", get(web::health_check))
664 .route("/healthz", get(|| async { "ok" }));
665
666 tracing::debug!("REST host prepared base router with health check endpoints");
668 Ok(router)
669 }
670
671 fn rest_finalize(
672 &self,
673 _ctx: &modkit::context::ModuleCtx,
674 mut router: axum::Router,
675 ) -> anyhow::Result<axum::Router> {
676 let config = self.get_cached_config();
677
678 if config.enable_docs {
679 router = self.add_openapi_routes(router)?;
680 }
681
682 tracing::debug!("Applying middleware stack to finalized router");
684 let authn_client = self.authn_client.lock().clone();
685 router = self.apply_middleware_stack(router, authn_client)?;
686
687 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
688 router = Self::apply_prefix_nesting(router, &prefix);
689
690 *self.final_router.lock() = Some(router.clone());
692
693 tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
694 Ok(router)
695 }
696
697 fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
698 self
699 }
700}
701
702impl modkit::contracts::RestApiCapability for ApiGateway {
703 fn register_rest(
704 &self,
705 _ctx: &modkit::context::ModuleCtx,
706 router: axum::Router,
707 _openapi: &dyn modkit::contracts::OpenApiRegistry,
708 ) -> anyhow::Result<axum::Router> {
709 Ok(router)
712 }
713}
714
715impl OpenApiRegistry for ApiGateway {
716 fn register_operation(&self, spec: &modkit::api::OperationSpec) {
717 if self.check_duplicate_handler(spec) {
719 return;
720 }
721
722 if self.check_duplicate_route(spec) {
723 return;
724 }
725
726 self.openapi_registry.register_operation(spec);
728 self.log_operation_registration(spec);
729 }
730
731 fn ensure_schema_raw(
732 &self,
733 root_name: &str,
734 schemas: Vec<(
735 String,
736 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
737 )>,
738 ) -> String {
739 self.openapi_registry.ensure_schema_raw(root_name, schemas)
741 }
742
743 fn as_any(&self) -> &dyn std::any::Any {
744 self
745 }
746}
747
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// The generated document has the top-level OpenAPI keys and carries the
    /// title, version and description from the configuration.
    #[test]
    fn test_openapi_generation() {
        let mut config = ApiGatewayConfig::default();
        config.openapi.title = "Test API".to_owned();
        config.openapi.version = "1.0.0".to_owned();
        config.openapi.description = Some("Test Description".to_owned());
        let gateway = ApiGateway::new(config);

        let document = gateway.build_openapi().unwrap();
        let rendered = serde_json::to_value(&document).unwrap();

        for top_level_key in ["openapi", "info", "paths"] {
            assert!(rendered.get(top_level_key).is_some());
        }

        let info = &rendered["info"];
        assert_eq!(info["title"], "Test API");
        assert_eq!(info["version"], "1.0.0");
        assert_eq!(info["description"], "Test Description");
    }
}
777
/// Tests for `ApiGateway::normalize_prefix_path`: normalization of slashes
/// and whitespace, plus rejection of unsafe characters and dot segments.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    use super::*;

    #[test]
    fn empty_string_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("").unwrap(), "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("/").unwrap(), "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("///").unwrap(), "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path(" ").unwrap(), "");
    }

    #[test]
    fn simple_prefix_preserved() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf").unwrap(), "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf/").unwrap(), "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_eq!(ApiGateway::normalize_prefix_path("cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_eq!(ApiGateway::normalize_prefix_path("//cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api//v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("///api///v1///").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_eq!(ApiGateway::normalize_prefix_path(" /cf ").unwrap(), "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn dot_in_path_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1.0").unwrap(),
            "/api/v1.0"
        );
    }

    #[test]
    fn rejects_html_injection() {
        let result = ApiGateway::normalize_prefix_path(r#""><script>alert(1)</script>"#);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_spaces_in_path() {
        let result = ApiGateway::normalize_prefix_path("/my path");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_query_string_chars() {
        let result = ApiGateway::normalize_prefix_path("/api?foo=bar");
        assert!(result.is_err());
    }

    // '.' passes the character check, so these exercise the separate
    // whole-segment rejection.
    #[test]
    fn rejects_single_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/./api");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_parent_dir_segment() {
        let result = ApiGateway::normalize_prefix_path("/api/../admin");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_trailing_parent_dir_segment() {
        let result = ApiGateway::normalize_prefix_path("/api/..");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_non_ascii_characters() {
        let result = ApiGateway::normalize_prefix_path("/caf\u{e9}");
        assert!(result.is_err());
    }
}
878
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    /// Trivial handler; only its registration matters to the test.
    async fn dummy_handler() -> Json<Value> {
        Json(serde_json::json!({"ok": true}))
    }

    /// Registering a problem response must add a concrete `Problem` schema
    /// and reference it from the operation's 400 response.
    #[tokio::test]
    async fn openapi_includes_problem_schema_and_response() {
        let gateway = ApiGateway::default();
        let base_router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
            .public()
            .summary("Problem demo")
            .problem_response(&gateway, http::StatusCode::BAD_REQUEST, "Bad Request")
            .handler(dummy_handler)
            .register(base_router, &gateway);

        let document = gateway.build_openapi().expect("openapi");
        let rendered = serde_json::to_value(&document).expect("json");

        // The schema must be defined inline, not as a reference to itself.
        let problem_schema = rendered
            .pointer("/components/schemas/Problem")
            .expect("Problem schema missing");
        assert!(
            problem_schema.get("$ref").is_none(),
            "Problem must be a real object, not a self-ref"
        );

        let bad_request = rendered
            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
            .expect("400 response missing");

        let content_by_type = bad_request.get("content").expect("content object missing");
        assert!(
            content_by_type.get("application/problem+json").is_some(),
            "application/problem+json content missing. Available content: {}",
            serde_json::to_string_pretty(content_by_type).unwrap()
        );

        // In a JSON pointer, '/' inside a key is escaped as "~1".
        let problem_content = bad_request
            .pointer("/content/application~1problem+json")
            .expect("application/problem+json content missing");
        let schema_ref = problem_content
            .pointer("/schema/$ref")
            .and_then(|r| r.as_str())
            .unwrap_or("");
        assert_eq!(schema_ref, "#/components/schemas/Problem");
    }
}
940
941#[cfg(test)]
942#[cfg_attr(coverage_nightly, coverage(off))]
943mod sse_openapi_tests {
944 use super::*;
945 use axum::Json;
946 use modkit::api::{Missing, OperationBuilder};
947 use serde_json::Value;
948
949 #[derive(Clone)]
950 #[modkit_macros::api_dto(request, response)]
951 struct UserEvent {
952 id: u32,
953 message: String,
954 }
955
956 async fn sse_handler() -> axum::response::sse::Sse<
957 impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
958 > {
959 let b = modkit::SseBroadcaster::<UserEvent>::new(4);
960 b.sse_response()
961 }
962
963 #[tokio::test]
964 async fn openapi_has_sse_content() {
965 let api = ApiGateway::default();
966 let router = axum::Router::new();
967
968 let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
969 .summary("Demo SSE")
970 .handler(sse_handler)
971 .public()
972 .sse_json::<UserEvent>(&api, "SSE of UserEvent")
973 .register(router, &api);
974
975 let doc = api.build_openapi().expect("openapi");
976 let v = serde_json::to_value(&doc).expect("json");
977
978 let schema = v
980 .pointer("/components/schemas/UserEvent")
981 .expect("UserEvent missing");
982 assert!(schema.get("$ref").is_none());
983
984 let refp = v
986 .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
987 .and_then(|x| x.as_str())
988 .unwrap_or_default();
989 assert_eq!(refp, "#/components/schemas/UserEvent");
990 }
991
992 #[tokio::test]
993 async fn openapi_sse_additional_response() {
994 async fn mixed_handler() -> Json<Value> {
995 Json(serde_json::json!({"ok": true}))
996 }
997
998 let api = ApiGateway::default();
999 let router = axum::Router::new();
1000
1001 let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
1002 .summary("Mixed responses")
1003 .public()
1004 .handler(mixed_handler)
1005 .json_response(http::StatusCode::OK, "Success response")
1006 .sse_json::<UserEvent>(&api, "Additional SSE stream")
1007 .register(router, &api);
1008
1009 let doc = api.build_openapi().expect("openapi");
1010 let v = serde_json::to_value(&doc).expect("json");
1011
1012 let responses = v
1014 .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
1015 .expect("responses");
1016
1017 assert!(responses.get("200").is_some());
1019
1020 let response_content = responses.get("200").and_then(|r| r.get("content"));
1022 assert!(response_content.is_some());
1023
1024 let schema = v
1026 .pointer("/components/schemas/UserEvent")
1027 .expect("UserEvent missing");
1028 assert!(schema.get("$ref").is_none());
1029 }
1030
1031 #[tokio::test]
1032 async fn test_axum_to_openapi_path_conversion() {
1033 async fn user_handler() -> Json<Value> {
1035 Json(serde_json::json!({"user_id": "123"}))
1036 }
1037
1038 let api = ApiGateway::default();
1039 let router = axum::Router::new();
1040
1041 let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
1042 .summary("Get user by ID")
1043 .public()
1044 .path_param("id", "User ID")
1045 .handler(user_handler)
1046 .json_response(http::StatusCode::OK, "User details")
1047 .register(router, &api);
1048
1049 let ops: Vec<_> = api
1051 .openapi_registry
1052 .operation_specs
1053 .iter()
1054 .map(|e| e.value().clone())
1055 .collect();
1056 assert_eq!(ops.len(), 1);
1057 assert_eq!(ops[0].path, "/tests/v1/users/{id}");
1058
1059 let doc = api.build_openapi().expect("openapi");
1061 let v = serde_json::to_value(&doc).expect("json");
1062
1063 let paths = v.get("paths").expect("paths");
1064 assert!(
1065 paths.get("/tests/v1/users/{id}").is_some(),
1066 "OpenAPI should use {{id}} placeholder"
1067 );
1068 }
1069
1070 #[tokio::test]
1071 async fn test_multiple_path_params_conversion() {
1072 async fn item_handler() -> Json<Value> {
1073 Json(serde_json::json!({"ok": true}))
1074 }
1075
1076 let api = ApiGateway::default();
1077 let router = axum::Router::new();
1078
1079 let _router = OperationBuilder::<Missing, Missing, ()>::get(
1080 "/tests/v1/projects/{project_id}/items/{item_id}",
1081 )
1082 .summary("Get project item")
1083 .public()
1084 .path_param("project_id", "Project ID")
1085 .path_param("item_id", "Item ID")
1086 .handler(item_handler)
1087 .json_response(http::StatusCode::OK, "Item details")
1088 .register(router, &api);
1089
1090 let ops: Vec<_> = api
1092 .openapi_registry
1093 .operation_specs
1094 .iter()
1095 .map(|e| e.value().clone())
1096 .collect();
1097 assert_eq!(
1098 ops[0].path,
1099 "/tests/v1/projects/{project_id}/items/{item_id}"
1100 );
1101
1102 let doc = api.build_openapi().expect("openapi");
1103 let v = serde_json::to_value(&doc).expect("json");
1104 let paths = v.get("paths").expect("paths");
1105 assert!(
1106 paths
1107 .get("/tests/v1/projects/{project_id}/items/{item_id}")
1108 .is_some()
1109 );
1110 }
1111
1112 #[tokio::test]
1113 async fn test_wildcard_path_conversion() {
1114 async fn static_handler() -> Json<Value> {
1115 Json(serde_json::json!({"ok": true}))
1116 }
1117
1118 let api = ApiGateway::default();
1119 let router = axum::Router::new();
1120
1121 let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
1123 .summary("Serve static files")
1124 .public()
1125 .handler(static_handler)
1126 .json_response(http::StatusCode::OK, "File content")
1127 .register(router, &api);
1128
1129 let ops: Vec<_> = api
1131 .openapi_registry
1132 .operation_specs
1133 .iter()
1134 .map(|e| e.value().clone())
1135 .collect();
1136 assert_eq!(ops[0].path, "/tests/v1/static/{*path}");
1137
1138 let doc = api.build_openapi().expect("openapi");
1140 let v = serde_json::to_value(&doc).expect("json");
1141 let paths = v.get("paths").expect("paths");
1142 assert!(
1143 paths.get("/tests/v1/static/{path}").is_some(),
1144 "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
1145 );
1146 assert!(
1147 paths.get("/static/{*path}").is_none(),
1148 "OpenAPI should not have Axum-style {{*path}}"
1149 );
1150 }
1151
1152 #[tokio::test]
1153 async fn test_multipart_file_upload_openapi() {
1154 async fn upload_handler() -> Json<Value> {
1155 Json(serde_json::json!({"uploaded": true}))
1156 }
1157
1158 let api = ApiGateway::default();
1159 let router = axum::Router::new();
1160
1161 let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
1162 .operation_id("upload_file")
1163 .public()
1164 .summary("Upload a file")
1165 .multipart_file_request("file", Some("File to upload"))
1166 .handler(upload_handler)
1167 .json_response(http::StatusCode::OK, "Upload successful")
1168 .register(router, &api);
1169
1170 let doc = api.build_openapi().expect("openapi");
1172 let v = serde_json::to_value(&doc).expect("json");
1173
1174 let paths = v.get("paths").expect("paths");
1175 let upload_path = paths
1176 .get("/tests/v1/files/upload")
1177 .expect("/tests/v1/files/upload path");
1178 let post_op = upload_path.get("post").expect("POST operation");
1179
1180 let request_body = post_op.get("requestBody").expect("requestBody");
1182 let content = request_body.get("content").expect("content");
1183 let multipart = content
1184 .get("multipart/form-data")
1185 .expect("multipart/form-data content type");
1186
1187 let schema = multipart.get("schema").expect("schema");
1189 assert_eq!(
1190 schema.get("type").and_then(|v| v.as_str()),
1191 Some("object"),
1192 "Schema should be of type object"
1193 );
1194
1195 let properties = schema.get("properties").expect("properties");
1197 let file_prop = properties.get("file").expect("file property");
1198 assert_eq!(
1199 file_prop.get("type").and_then(|v| v.as_str()),
1200 Some("string"),
1201 "File field should be of type string"
1202 );
1203 assert_eq!(
1204 file_prop.get("format").and_then(|v| v.as_str()),
1205 Some("binary"),
1206 "File field should have format binary"
1207 );
1208
1209 let required = schema.get("required").expect("required");
1211 let required_arr = required.as_array().expect("required should be array");
1212 assert_eq!(required_arr.len(), 1);
1213 assert_eq!(required_arr[0].as_str(), Some("file"));
1214 }
1215}