1use async_trait::async_trait;
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use dashmap::DashMap;
10
11use anyhow::Result;
12use axum::http::Method;
13use axum::middleware::from_fn_with_state;
14use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
15use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
16use modkit::lifecycle::ReadySignal;
17use parking_lot::Mutex;
18use std::net::SocketAddr;
19use std::time::Duration;
20use tokio_util::sync::CancellationToken;
21use tower_http::{
22 catch_panic::CatchPanicLayer,
23 limit::RequestBodyLimitLayer,
24 request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
25 timeout::TimeoutLayer,
26};
27use tracing::debug;
28
29use authn_resolver_sdk::AuthNResolverClient;
30
31use crate::config::ApiGatewayConfig;
32use crate::middleware::auth;
33use modkit_security::SecurityContext;
34use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
35
36use crate::middleware;
37use crate::router_cache::RouterCache;
38use crate::web;
39
/// The API gateway module: hosts the REST router, owns the OpenAPI registry
/// and serves HTTP traffic.
// Module registration for modkit: module name "api-gateway", REST host/provider
// and stateful capabilities, dependencies on "grpc-hub" and "authn-resolver",
// and `serve` as the lifecycle entry point with a 30 s stop timeout.
// NOTE(review): `await_ready` presumably makes the runtime wait for the
// `ReadySignal` passed to `serve` — confirm with the modkit macro docs.
#[modkit::module(
    name = "api-gateway",
    capabilities = [rest_host, rest, stateful],
    deps = ["grpc-hub", "authn-resolver"],
    lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
)]
pub struct ApiGateway {
    /// Current configuration; replaced in `init` and read as owned snapshots
    /// through `get_config` / `get_cached_config`.
    pub(crate) config: ArcSwap<ApiGatewayConfig>,
    /// Registry of operation specs and schemas; source of the OpenAPI document
    /// and of the per-route middleware maps built in `apply_middleware_stack`.
    pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
    /// Router cache read and written by `build_router` /
    /// `rebuild_and_cache_router`.
    pub(crate) router_cache: RouterCache<axum::Router>,
    /// Router produced by `rest_finalize`; taken (leaving `None`) when `serve`
    /// starts.
    pub(crate) final_router: Mutex<Option<axum::Router>>,
    /// AuthN resolver client; set in `init` only when auth is enabled.
    pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,

    /// `(method, path)` pairs seen by `register_operation`, used as a set to
    /// reject duplicate routes.
    pub(crate) registered_routes: DashMap<(Method, String), ()>,
    /// Handler ids seen by `register_operation`, used as a set to reject
    /// duplicate handlers.
    pub(crate) registered_handlers: DashMap<String, ()>,
}
64
65impl Default for ApiGateway {
66 fn default() -> Self {
67 let default_router = Router::new();
68 Self {
69 config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
70 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
71 router_cache: RouterCache::new(default_router),
72 final_router: Mutex::new(None),
73 authn_client: Mutex::new(None),
74 registered_routes: DashMap::new(),
75 registered_handlers: DashMap::new(),
76 }
77 }
78}
79
80impl ApiGateway {
81 fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
82 if prefix.is_empty() {
83 return router;
84 }
85
86 let top = Router::new()
87 .route("/health", get(web::health_check))
88 .route("/healthz", get(|| async { "ok" }));
89
90 router = Router::new().nest(prefix, router);
91 top.merge(router)
92 }
93
94 #[must_use]
96 pub fn new(config: ApiGatewayConfig) -> Self {
97 let default_router = Router::new();
98 Self {
99 config: ArcSwap::from_pointee(config),
100 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
101 router_cache: RouterCache::new(default_router),
102 final_router: Mutex::new(None),
103 authn_client: Mutex::new(None),
104 registered_routes: DashMap::new(),
105 registered_handlers: DashMap::new(),
106 }
107 }
108
109 pub fn get_config(&self) -> ApiGatewayConfig {
111 (**self.config.load()).clone()
112 }
113
114 pub fn get_cached_config(&self) -> ApiGatewayConfig {
116 (**self.config.load()).clone()
117 }
118
    /// Returns the router currently held by `router_cache` as a shared `Arc`.
    pub fn get_cached_router(&self) -> Arc<Router> {
        self.router_cache.load()
    }
123
    /// Obtains a router from `build_router` and stores it in `router_cache`.
    ///
    /// NOTE(review): `build_router` returns a clone of the cached router without
    /// rebuilding whenever the `Arc` it loads has a strong count above one, so
    /// this call may not actually rebuild anything. On the build path,
    /// `build_router` has also already stored the router, which makes the store
    /// below redundant. Confirm the intended semantics against `RouterCache`.
    ///
    /// # Errors
    /// Propagates any error returned by `build_router`.
    pub fn rebuild_and_cache_router(&self) -> Result<()> {
        let new_router = self.build_router()?;
        self.router_cache.store(new_router);
        Ok(())
    }
133
134 fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
136 let mut authenticated_routes = std::collections::HashSet::new();
137 let mut public_routes = std::collections::HashSet::new();
138
139 public_routes.insert((Method::GET, "/health".to_owned()));
141 public_routes.insert((Method::GET, "/healthz".to_owned()));
142
143 public_routes.insert((Method::GET, "/docs".to_owned()));
144 public_routes.insert((Method::GET, "/openapi.json".to_owned()));
145
146 for spec in &self.openapi_registry.operation_specs {
147 let spec = spec.value();
148
149 let route_key = (spec.method.clone(), spec.path.clone());
150
151 if spec.authenticated {
152 authenticated_routes.insert(route_key.clone());
153 }
154
155 if spec.is_public {
156 public_routes.insert(route_key);
157 }
158 }
159
160 let config = self.get_cached_config();
161 let requirements_count = authenticated_routes.len();
162 let public_routes_count = public_routes.len();
163
164 let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
165
166 tracing::info!(
167 auth_disabled = config.auth_disabled,
168 require_auth_by_default = config.require_auth_by_default,
169 requirements_count = requirements_count,
170 public_routes_count = public_routes_count,
171 "Route policy built from operation specs"
172 );
173
174 Ok(route_policy)
175 }
176
177 fn normalize_prefix_path(raw: &str) -> Result<String> {
178 let trimmed = raw.trim();
179 let collapsed: String =
181 trimmed
182 .chars()
183 .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
184 if c == '/' && acc.ends_with('/') {
185 } else {
187 acc.push(c);
188 }
189 acc
190 });
191 let prefix = collapsed.trim_end_matches('/');
192 let result = if prefix.is_empty() {
193 String::new()
194 } else if prefix.starts_with('/') {
195 prefix.to_owned()
196 } else {
197 format!("/{prefix}")
198 };
199 if !result
201 .bytes()
202 .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
203 {
204 anyhow::bail!(
205 "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
206 );
207 }
208
209 if result.split('/').any(|seg| seg == "." || seg == "..") {
210 anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
211 }
212
213 Ok(result)
214 }
215
    /// Wraps `router` with the gateway's complete middleware stack.
    ///
    /// Layers are attached with axum's `Router::layer`, where each newly added
    /// layer wraps everything added before it. A request therefore passes
    /// through them in the REVERSE of the order they are added below:
    /// 1. request-id assignment, then propagation;
    /// 2. the tracing span;
    /// 3. copying the request id into extensions;
    /// 4. HTTP metrics;
    /// 5. panic catching;
    /// 6. the 30 s timeout;
    /// 7. body limits;
    /// 8. CORS (when enabled);
    /// 9. MIME validation;
    /// 10. rate limiting;
    /// 11. error mapping;
    /// 12. authentication (or injection of the default context when auth is
    ///     disabled);
    /// 13. license validation, immediately before the handler.
    ///
    /// `propagate_matched_path` is attached with `route_layer`, so it runs only
    /// for requests that matched a route. The per-route maps (license, rate
    /// limit, MIME) are built from a snapshot of the registered operation specs.
    ///
    /// Because behavior depends on the exact statement order, reorder the calls
    /// below only deliberately.
    ///
    /// # Errors
    /// - `build_route_policy_from_specs` fails.
    /// - Auth is disabled and the default `SecurityContext` cannot be built.
    /// - Auth is enabled but `authn_client` is `None`.
    /// - `RateLimiterMap::from_specs` fails.
    pub(crate) fn apply_middleware_stack(
        &self,
        mut router: Router,
        authn_client: Option<Arc<dyn AuthNResolverClient>>,
    ) -> Result<Router> {
        // Built up front even when auth ends up disabled; it is only consumed by
        // the authenticated branch below.
        let route_policy = self.build_route_policy_from_specs()?;

        // Innermost: applies only to matched routes.
        router = router.route_layer(from_fn(middleware::http_metrics::propagate_matched_path));

        let config = self.get_cached_config();

        // Snapshot of all registered operation specs, shared by the per-route maps below.
        let specs: Vec<_> = self
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();

        let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);

        // Added before auth, so it runs after authentication on the request path.
        router = router.layer(from_fn(
            move |req: axum::extract::Request, next: axum::middleware::Next| {
                let map = license_map.clone();
                middleware::license_validation::license_validation_middleware(map, req, next)
            },
        ));

        if config.auth_disabled {
            // Every request gets this fixed context (default subject and tenant ids).
            let default_security_context = SecurityContext::builder()
                .subject_id(DEFAULT_SUBJECT_ID)
                .subject_tenant_id(DEFAULT_TENANT_ID)
                .build()?;

            tracing::warn!(
                "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
                 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
                 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
            );
            router = router.layer(from_fn(
                move |mut req: axum::extract::Request, next: axum::middleware::Next| {
                    let sec_context = default_security_context.clone();
                    async move {
                        req.extensions_mut().insert(sec_context);
                        next.run(req).await
                    }
                },
            ));
        } else if let Some(client) = authn_client {
            let auth_state = auth::AuthState {
                authn_client: client,
                route_policy,
            };
            router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
        } else {
            // Auth enabled but nothing to authenticate with: refuse to build an open router.
            return Err(anyhow::anyhow!(
                "auth is enabled but no AuthN Resolver client is available; \
                 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
            ));
        }

        // Wraps auth and license middleware.
        router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));

        let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;

        router = router.layer(from_fn(
            move |req: axum::extract::Request, next: axum::middleware::Next| {
                let map = rate_map.clone();
                middleware::rate_limit::rate_limit_middleware(map, req, next)
            },
        ));

        let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
        router = router.layer(from_fn(
            move |req: axum::extract::Request, next: axum::middleware::Next| {
                let map = mime_map.clone();
                middleware::mime_validation::mime_validation_middleware(map, req, next)
            },
        ));

        if config.cors_enabled {
            router = router.layer(crate::cors::build_cors_layer(&config));
        }

        // The same byte limit is enforced by tower-http's body-limit layer and by
        // axum's default extractor limit.
        router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
        router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));

        // Hard-coded 30 s request timeout, answered with 504 Gateway Timeout.
        router = router.layer(TimeoutLayer::with_status_code(
            axum::http::StatusCode::GATEWAY_TIMEOUT,
            Duration::from_secs(30),
        ));

        // Added after the timeout, so it wraps everything applied so far.
        router = router.layer(CatchPanicLayer::new());

        let http_metrics = Arc::new(middleware::http_metrics::HttpMetrics::new(
            Self::MODULE_NAME,
            &config.metrics.prefix,
        ));
        router = router.layer(from_fn_with_state(
            http_metrics,
            middleware::http_metrics::http_metrics_middleware,
        ));

        router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));

        router = router.layer({
            use modkit_http::otel;
            use tower_http::trace::TraceLayer;
            use tracing::field::Empty;

            TraceLayer::new_for_http()
                .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
                    // The request-id header has already been set by the outer
                    // SetRequestIdLayer when this span is created.
                    let hdr = middleware::request_id::header();
                    let rid = req
                        .headers()
                        .get(&hdr)
                        .and_then(|v| v.to_str().ok())
                        .unwrap_or("n/a");

                    // `status` / `latency_ms` / trace ids start Empty and are
                    // recorded later.
                    let span = tracing::info_span!(
                        "http_request",
                        method = %req.method(),
                        uri = %req.uri().path(),
                        version = ?req.version(),
                        module = "api_gateway",
                        endpoint = %req.uri().path(),
                        request_id = %rid,
                        status = Empty,
                        latency_ms = Empty,
                        "http.method" = %req.method(),
                        "http.target" = %req.uri().path(),
                        "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
                        "http.host" = req.headers().get("host")
                            .and_then(|h| h.to_str().ok())
                            .unwrap_or("unknown"),
                        "user_agent.original" = req.headers().get("user-agent")
                            .and_then(|h| h.to_str().ok())
                            .unwrap_or("unknown"),
                        trace_id = Empty,
                        parent.trace_id = Empty
                    );

                    // Attach the incoming trace context (if any) from the request
                    // headers as the span's parent.
                    otel::set_parent_from_headers(&span, req.headers());

                    span
                })
                .on_response(
                    |res: &axum::http::Response<axum::body::Body>,
                     latency: std::time::Duration,
                     span: &tracing::Span| {
                        let ms = latency.as_millis();
                        span.record("status", res.status().as_u16());
                        span.record("latency_ms", ms);
                    },
                )
        });

        // Outermost pair: SetRequestIdLayer (added last) runs first and assigns
        // the id; PropagateRequestIdLayer copies it to the response.
        let x_request_id = crate::middleware::request_id::header();
        router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
        router = router.layer(SetRequestIdLayer::new(
            x_request_id,
            crate::middleware::request_id::MakeReqId,
        ));

        Ok(router)
    }
418
419 pub fn build_router(&self) -> Result<Router> {
424 let cached_router = self.router_cache.load();
427 if Arc::strong_count(&cached_router) > 1 {
428 tracing::debug!("Using cached router");
429 return Ok((*cached_router).clone());
430 }
431
432 tracing::debug!("Building new router (standalone/fallback mode)");
433 let mut router = Router::new()
436 .route("/health", get(web::health_check))
437 .route("/healthz", get(|| async { "ok" }));
438
439 let authn_client = self.authn_client.lock().clone();
441 router = self.apply_middleware_stack(router, authn_client)?;
442
443 let config = self.get_cached_config();
444 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
445 router = Self::apply_prefix_nesting(router, &prefix);
446
447 self.router_cache.store(router.clone());
449
450 Ok(router)
451 }
452
453 pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
458 let config = self.get_cached_config();
459 let info = modkit::api::OpenApiInfo {
460 title: config.openapi.title.clone(),
461 version: config.openapi.version.clone(),
462 description: config.openapi.description,
463 };
464 self.openapi_registry.build_openapi(&info)
465 }
466
467 fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
469 bind_addr
470 .parse()
471 .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
472 }
473
474 fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
476 let stored = { self.final_router.lock().take() };
477
478 if let Some(router) = stored {
479 tracing::debug!("Using router from REST phase");
480 Ok(router)
481 } else {
482 tracing::debug!("No router from REST phase, building default router");
483 self.build_router()
484 }
485 }
486
487 pub(crate) async fn serve(
492 self: Arc<Self>,
493 cancel: CancellationToken,
494 ready: ReadySignal,
495 ) -> anyhow::Result<()> {
496 let cfg = self.get_cached_config();
497 let addr = Self::parse_bind_address(&cfg.bind_addr)?;
498 let router = self.get_or_build_router()?;
499
500 let listener = tokio::net::TcpListener::bind(addr).await?;
502 tracing::info!("HTTP server bound on {}", addr);
503 ready.notify(); let shutdown = {
507 let cancel = cancel.clone();
508 async move {
509 cancel.cancelled().await;
510 tracing::info!("HTTP server shutting down gracefully (cancellation)");
511 }
512 };
513
514 axum::serve(listener, router)
515 .with_graceful_shutdown(shutdown)
516 .await
517 .map_err(|e| anyhow::anyhow!(e))
518 }
519
520 fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
522 if self
523 .registered_handlers
524 .insert(spec.handler_id.clone(), ())
525 .is_some()
526 {
527 tracing::error!(
528 handler_id = %spec.handler_id,
529 method = %spec.method.as_str(),
530 path = %spec.path,
531 "Duplicate handler_id detected; ignoring subsequent registration"
532 );
533 return true;
534 }
535 false
536 }
537
538 fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
540 let route_key = (spec.method.clone(), spec.path.clone());
541 if self.registered_routes.insert(route_key, ()).is_some() {
542 tracing::error!(
543 method = %spec.method.as_str(),
544 path = %spec.path,
545 "Duplicate (method, path) detected; ignoring subsequent registration"
546 );
547 return true;
548 }
549 false
550 }
551
552 fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
554 let current_count = self.openapi_registry.operation_specs.len();
555 tracing::debug!(
556 handler_id = %spec.handler_id,
557 method = %spec.method.as_str(),
558 path = %spec.path,
559 summary = %spec.summary.as_deref().unwrap_or("No summary"),
560 total_operations = current_count,
561 "Registered API operation"
562 );
563 }
564
565 fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
567 let op_count = self.openapi_registry.operation_specs.len();
569 tracing::info!(
570 "rest_finalize: emitting OpenAPI with {} operations",
571 op_count
572 );
573
574 let openapi_doc = Arc::new(self.build_openapi()?);
575 let config = self.get_cached_config();
576 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
577 let html_doc = web::serve_docs(&prefix);
578
579 router = router
580 .route(
581 "/openapi.json",
582 get({
583 use axum::{http::header, response::IntoResponse};
584 let doc = openapi_doc;
585 move || async move {
586 let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
587 Ok(json) => json,
588 Err(e) => {
589 tracing::error!("Failed to serialize OpenAPI doc: {}", e);
590 return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
591 }
592 };
593 (
594 [
595 (header::CONTENT_TYPE, "application/json"),
596 (header::CACHE_CONTROL, "no-store"),
597 ],
598 json_string,
599 )
600 .into_response()
601 }
602 }),
603 )
604 .route("/docs", get(move || async move { html_doc }));
605
606 #[cfg(feature = "embed_elements")]
607 {
608 router = router.route(
609 "/docs/assets/{*file}",
610 get(crate::assets::serve_elements_asset),
611 );
612 }
613
614 Ok(router)
615 }
616}
617
618#[async_trait]
620impl modkit::Module for ApiGateway {
621 async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
622 let cfg = ctx.config::<crate::config::ApiGatewayConfig>()?;
623 self.config.store(Arc::new(cfg.clone()));
624
625 debug!(
626 "Effective api_gateway configuration:\n{:#?}",
627 self.config.load()
628 );
629
630 if cfg.auth_disabled {
631 tracing::info!(
632 tenant_id = %DEFAULT_TENANT_ID,
633 "Auth-disabled mode enabled with default tenant"
634 );
635 } else {
636 let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
638 *self.authn_client.lock() = Some(authn_client);
639 tracing::info!("AuthN Resolver client resolved from ClientHub");
640 }
641
642 Ok(())
643 }
644}
645
646impl modkit::contracts::ApiGatewayCapability for ApiGateway {
648 fn rest_prepare(
649 &self,
650 _ctx: &modkit::context::ModuleCtx,
651 router: axum::Router,
652 ) -> anyhow::Result<axum::Router> {
653 let router = router
657 .route("/health", get(web::health_check))
658 .route("/healthz", get(|| async { "ok" }));
659
660 tracing::debug!("REST host prepared base router with health check endpoints");
662 Ok(router)
663 }
664
665 fn rest_finalize(
666 &self,
667 _ctx: &modkit::context::ModuleCtx,
668 mut router: axum::Router,
669 ) -> anyhow::Result<axum::Router> {
670 let config = self.get_cached_config();
671
672 if config.enable_docs {
673 router = self.add_openapi_routes(router)?;
674 }
675
676 tracing::debug!("Applying middleware stack to finalized router");
678 let authn_client = self.authn_client.lock().clone();
679 router = self.apply_middleware_stack(router, authn_client)?;
680
681 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
682 router = Self::apply_prefix_nesting(router, &prefix);
683
684 *self.final_router.lock() = Some(router.clone());
686
687 tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
688 Ok(router)
689 }
690
691 fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
692 self
693 }
694}
695
impl modkit::contracts::RestApiCapability for ApiGateway {
    /// The gateway registers no REST routes of its own here; the router is
    /// returned unchanged. Its own endpoints are added in `rest_prepare` and
    /// `rest_finalize`.
    fn register_rest(
        &self,
        _ctx: &modkit::context::ModuleCtx,
        router: axum::Router,
        _openapi: &dyn modkit::contracts::OpenApiRegistry,
    ) -> anyhow::Result<axum::Router> {
        Ok(router)
    }
}
708
709impl OpenApiRegistry for ApiGateway {
710 fn register_operation(&self, spec: &modkit::api::OperationSpec) {
711 if self.check_duplicate_handler(spec) {
713 return;
714 }
715
716 if self.check_duplicate_route(spec) {
717 return;
718 }
719
720 self.openapi_registry.register_operation(spec);
722 self.log_operation_registration(spec);
723 }
724
725 fn ensure_schema_raw(
726 &self,
727 root_name: &str,
728 schemas: Vec<(
729 String,
730 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
731 )>,
732 ) -> String {
733 self.openapi_registry.ensure_schema_raw(root_name, schemas)
735 }
736
737 fn as_any(&self) -> &dyn std::any::Any {
738 self
739 }
740}
741
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// `build_openapi` must carry the configured title, version and
    /// description into the document's `info` object.
    #[test]
    fn test_openapi_generation() {
        let mut config = ApiGatewayConfig::default();
        config.openapi.title = "Test API".to_owned();
        config.openapi.version = "1.0.0".to_owned();
        config.openapi.description = Some("Test Description".to_owned());
        let api = ApiGateway::new(config);

        let doc = api.build_openapi().unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        // Top-level OpenAPI sections must be present.
        assert!(json.get("openapi").is_some());
        assert!(json.get("info").is_some());
        assert!(json.get("paths").is_some());

        // `info` reflects the configuration passed to `new`.
        let info = json.get("info").unwrap();
        assert_eq!(info.get("title").unwrap(), "Test API");
        assert_eq!(info.get("version").unwrap(), "1.0.0");
        assert_eq!(info.get("description").unwrap(), "Test Description");
    }
}
771
/// Unit tests for `ApiGateway::normalize_prefix_path`.
///
/// Covers:
/// - empty and whitespace input;
/// - slash collapsing and trailing-slash stripping;
/// - leading-slash insertion;
/// - the character whitelist;
/// - rejection of `.`/`..` segments.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    use super::*;

    #[test]
    fn empty_string_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("").unwrap(), "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("/").unwrap(), "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("///").unwrap(), "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path(" ").unwrap(), "");
    }

    #[test]
    fn simple_prefix_preserved() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf").unwrap(), "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf/").unwrap(), "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_eq!(ApiGateway::normalize_prefix_path("cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_eq!(ApiGateway::normalize_prefix_path("//cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api//v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("///api///v1///").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_eq!(ApiGateway::normalize_prefix_path(" /cf ").unwrap(), "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn dot_in_path_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1.0").unwrap(),
            "/api/v1.0"
        );
    }

    /// Only whole `.`/`..` segments are rejected; dots inside a segment are fine.
    #[test]
    fn dots_inside_segment_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/v1..2").unwrap(),
            "/v1..2"
        );
    }

    #[test]
    fn rejects_html_injection() {
        let result = ApiGateway::normalize_prefix_path(r#""><script>alert(1)</script>"#);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_spaces_in_path() {
        let result = ApiGateway::normalize_prefix_path("/my path");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_query_string_chars() {
        let result = ApiGateway::normalize_prefix_path("/api?foo=bar");
        assert!(result.is_err());
    }

    /// Non-ASCII bytes fall outside the `[a-zA-Z0-9/_\-.]` whitelist.
    #[test]
    fn rejects_non_ascii_chars() {
        let result = ApiGateway::normalize_prefix_path("/café");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_dot_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/api/../admin");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_trailing_dot_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/api/..");
        assert!(result.is_err());
    }

    /// A missing leading slash is added before the segment check, so `..` is
    /// still caught.
    #[test]
    fn rejects_leading_dot_dot_without_slash() {
        let result = ApiGateway::normalize_prefix_path("../api");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_single_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/./api");
        assert!(result.is_err());
    }
}
872
/// Tests that RFC-7807-style `problem_response` declarations show up correctly
/// in the generated OpenAPI document.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    /// Trivial JSON handler used only to satisfy `OperationBuilder::handler`.
    async fn dummy_handler() -> Json<Value> {
        Json(serde_json::json!({"ok": true}))
    }

    /// Registering an operation with `problem_response` must:
    /// - emit a concrete `Problem` component schema;
    /// - reference that schema from the 400 response under
    ///   `application/problem+json`.
    #[tokio::test]
    async fn openapi_includes_problem_schema_and_response() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
            .public()
            .summary("Problem demo")
            .problem_response(&api, http::StatusCode::BAD_REQUEST, "Bad Request")
            .handler(dummy_handler)
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        // The component must be an actual schema object, not a `$ref` to itself.
        let problem = v
            .pointer("/components/schemas/Problem")
            .expect("Problem schema missing");
        assert!(
            problem.get("$ref").is_none(),
            "Problem must be a real object, not a self-ref"
        );

        // JSON Pointer escapes '/' inside keys as "~1".
        let path_obj = v
            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
            .expect("400 response missing");

        let content_obj = path_obj.get("content").expect("content object missing");
        assert!(
            content_obj.get("application/problem+json").is_some(),
            "application/problem+json content missing. Available content: {}",
            serde_json::to_string_pretty(content_obj).unwrap()
        );

        // The problem+json body must reference the shared `Problem` component.
        let content = path_obj
            .pointer("/content/application~1problem+json")
            .expect("application/problem+json content missing");
        let schema_ref = content
            .pointer("/schema/$ref")
            .and_then(|r| r.as_str())
            .unwrap_or("");
        assert_eq!(schema_ref, "#/components/schemas/Problem");
    }
}
934
/// OpenAPI generation tests covering:
/// - SSE responses;
/// - Axum → OpenAPI path-parameter and wildcard conversion;
/// - multipart file upload request bodies.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod sse_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    /// Event payload whose schema is expected under `components/schemas`.
    #[derive(Clone)]
    #[modkit_macros::api_dto(request, response)]
    struct UserEvent {
        id: u32,
        message: String,
    }

    /// SSE handler backed by a small broadcaster; only its type matters here.
    async fn sse_handler() -> axum::response::sse::Sse<
        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    > {
        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
        b.sse_response()
    }

    /// `sse_json` must register the event schema and reference it from the
    /// `text/event-stream` content of the 200 response.
    #[tokio::test]
    async fn openapi_has_sse_content() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
            .summary("Demo SSE")
            .handler(sse_handler)
            .public()
            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());

        // JSON Pointer escapes '/' inside keys as "~1".
        let refp = v
            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
            .and_then(|x| x.as_str())
            .unwrap_or_default();
        assert_eq!(refp, "#/components/schemas/UserEvent");
    }

    /// Combining a JSON 200 response with an additional SSE stream keeps the
    /// 200 response content and still registers the event schema.
    #[tokio::test]
    async fn openapi_sse_additional_response() {
        async fn mixed_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
            .summary("Mixed responses")
            .public()
            .handler(mixed_handler)
            .json_response(http::StatusCode::OK, "Success response")
            .sse_json::<UserEvent>(&api, "Additional SSE stream")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let responses = v
            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
            .expect("responses");

        assert!(responses.get("200").is_some());

        let response_content = responses.get("200").and_then(|r| r.get("content"));
        assert!(response_content.is_some());

        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
    }

    /// A single `{id}` path parameter is kept as-is in both the stored spec and
    /// the OpenAPI paths.
    #[tokio::test]
    async fn test_axum_to_openapi_path_conversion() {
        async fn user_handler() -> Json<Value> {
            Json(serde_json::json!({"user_id": "123"}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
            .summary("Get user by ID")
            .public()
            .path_param("id", "User ID")
            .handler(user_handler)
            .json_response(http::StatusCode::OK, "User details")
            .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/users/{id}");

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/users/{id}").is_some(),
            "OpenAPI should use {{id}} placeholder"
        );
    }

    /// Multiple path parameters survive unchanged into the OpenAPI paths.
    #[tokio::test]
    async fn test_multiple_path_params_conversion() {
        async fn item_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get(
            "/tests/v1/projects/{project_id}/items/{item_id}",
        )
        .summary("Get project item")
        .public()
        .path_param("project_id", "Project ID")
        .path_param("item_id", "Item ID")
        .handler(item_handler)
        .json_response(http::StatusCode::OK, "Item details")
        .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(
            ops[0].path,
            "/tests/v1/projects/{project_id}/items/{item_id}"
        );

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths
                .get("/tests/v1/projects/{project_id}/items/{item_id}")
                .is_some()
        );
    }

    /// An Axum catch-all `{*path}` is stored verbatim in the spec but must be
    /// rendered as `{path}` in the OpenAPI document.
    #[tokio::test]
    async fn test_wildcard_path_conversion() {
        async fn static_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
            .summary("Serve static files")
            .public()
            .handler(static_handler)
            .json_response(http::StatusCode::OK, "File content")
            .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/static/{path}").is_some(),
            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
        );
        // Look up the registered path in its Axum form. Previously this checked
        // "/static/{*path}", a path no operation registers, so it could never fail.
        assert!(
            paths.get("/tests/v1/static/{*path}").is_none(),
            "OpenAPI should not have Axum-style {{*path}}"
        );
    }

    /// `multipart_file_request` must produce a `multipart/form-data` request
    /// body whose schema is an object with a required binary `file` property.
    #[tokio::test]
    async fn test_multipart_file_upload_openapi() {
        async fn upload_handler() -> Json<Value> {
            Json(serde_json::json!({"uploaded": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
            .operation_id("upload_file")
            .public()
            .summary("Upload a file")
            .multipart_file_request("file", Some("File to upload"))
            .handler(upload_handler)
            .json_response(http::StatusCode::OK, "Upload successful")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        let upload_path = paths
            .get("/tests/v1/files/upload")
            .expect("/tests/v1/files/upload path");
        let post_op = upload_path.get("post").expect("POST operation");

        let request_body = post_op.get("requestBody").expect("requestBody");
        let content = request_body.get("content").expect("content");
        let multipart = content
            .get("multipart/form-data")
            .expect("multipart/form-data content type");

        let schema = multipart.get("schema").expect("schema");
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be of type object"
        );

        let properties = schema.get("properties").expect("properties");
        let file_prop = properties.get("file").expect("file property");
        assert_eq!(
            file_prop.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "File field should be of type string"
        );
        assert_eq!(
            file_prop.get("format").and_then(|v| v.as_str()),
            Some("binary"),
            "File field should have format binary"
        );

        let required = schema.get("required").expect("required");
        let required_arr = required.as_array().expect("required should be array");
        assert_eq!(required_arr.len(), 1);
        assert_eq!(required_arr[0].as_str(), Some("file"));
    }
}
1209}