1use async_trait::async_trait;
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use dashmap::DashMap;
10
11use anyhow::Result;
12use axum::http::Method;
13use axum::middleware::from_fn_with_state;
14use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
15use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
16use modkit::lifecycle::ReadySignal;
17use parking_lot::Mutex;
18use std::net::SocketAddr;
19use std::time::Duration;
20use tokio_util::sync::CancellationToken;
21use tower_http::{
22 catch_panic::CatchPanicLayer,
23 limit::RequestBodyLimitLayer,
24 request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
25 timeout::TimeoutLayer,
26};
27use tracing::debug;
28
29use authn_resolver_sdk::AuthNResolverClient;
30
31use crate::config::ApiGatewayConfig;
32use crate::middleware::auth;
33use modkit_security::SecurityContext;
34use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
35
36use crate::middleware;
37use crate::router_cache::RouterCache;
38use crate::web;
39
#[modkit::module(
    name = "api-gateway",
    capabilities = [rest_host, rest, stateful],
    deps = ["grpc-hub", "authn-resolver"],
    lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
)]
/// REST host module: owns the gateway configuration, the OpenAPI operation registry and the
/// HTTP router that is wrapped in the gateway middleware stack and served by `serve`.
pub struct ApiGateway {
    /// Current gateway configuration; replaced atomically during `init`.
    pub(crate) config: ArcSwap<ApiGatewayConfig>,
    /// Registry of OpenAPI operation specs and schemas. It feeds the OpenAPI document as well
    /// as the per-route middleware maps (auth policy, license, rate limit, MIME).
    pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
    /// Router cache, seeded with an empty router at construction and updated by `build_router`.
    pub(crate) router_cache: RouterCache<axum::Router>,
    /// Router produced by `rest_finalize`; taken (once) when `serve` starts.
    pub(crate) final_router: Mutex<Option<axum::Router>>,
    /// AuthN resolver client, resolved from the ClientHub in `init` when auth is enabled.
    pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,

    /// `(method, path)` pairs already registered; used to reject duplicate routes.
    pub(crate) registered_routes: DashMap<(Method, String), ()>,
    /// `handler_id`s already registered; used to reject duplicate handlers.
    pub(crate) registered_handlers: DashMap<String, ()>,
}
64
65impl Default for ApiGateway {
66 fn default() -> Self {
67 let default_router = Router::new();
68 Self {
69 config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
70 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
71 router_cache: RouterCache::new(default_router),
72 final_router: Mutex::new(None),
73 authn_client: Mutex::new(None),
74 registered_routes: DashMap::new(),
75 registered_handlers: DashMap::new(),
76 }
77 }
78}
79
80impl ApiGateway {
81 fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
82 if prefix.is_empty() {
83 return router;
84 }
85
86 let top = Router::new()
87 .route("/health", get(web::health_check))
88 .route("/healthz", get(|| async { "ok" }));
89
90 router = Router::new().nest(prefix, router);
91 top.merge(router)
92 }
93
94 #[must_use]
96 pub fn new(config: ApiGatewayConfig) -> Self {
97 let default_router = Router::new();
98 Self {
99 config: ArcSwap::from_pointee(config),
100 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
101 router_cache: RouterCache::new(default_router),
102 final_router: Mutex::new(None),
103 authn_client: Mutex::new(None),
104 registered_routes: DashMap::new(),
105 registered_handlers: DashMap::new(),
106 }
107 }
108
109 pub fn get_config(&self) -> ApiGatewayConfig {
111 (**self.config.load()).clone()
112 }
113
114 pub fn get_cached_config(&self) -> ApiGatewayConfig {
116 (**self.config.load()).clone()
117 }
118
119 pub fn get_cached_router(&self) -> Arc<Router> {
121 self.router_cache.load()
122 }
123
124 pub fn rebuild_and_cache_router(&self) -> Result<()> {
129 let new_router = self.build_router()?;
130 self.router_cache.store(new_router);
131 Ok(())
132 }
133
134 fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
136 let mut authenticated_routes = std::collections::HashSet::new();
137 let mut public_routes = std::collections::HashSet::new();
138
139 public_routes.insert((Method::GET, "/health".to_owned()));
141 public_routes.insert((Method::GET, "/healthz".to_owned()));
142
143 public_routes.insert((Method::GET, "/docs".to_owned()));
144 public_routes.insert((Method::GET, "/openapi.json".to_owned()));
145
146 for spec in &self.openapi_registry.operation_specs {
147 let spec = spec.value();
148
149 let route_key = (spec.method.clone(), spec.path.clone());
150
151 if spec.authenticated {
152 authenticated_routes.insert(route_key.clone());
153 }
154
155 if spec.is_public {
156 public_routes.insert(route_key);
157 }
158 }
159
160 let config = self.get_cached_config();
161 let requirements_count = authenticated_routes.len();
162 let public_routes_count = public_routes.len();
163
164 let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
165
166 tracing::info!(
167 auth_disabled = config.auth_disabled,
168 require_auth_by_default = config.require_auth_by_default,
169 requirements_count = requirements_count,
170 public_routes_count = public_routes_count,
171 "Route policy built from operation specs"
172 );
173
174 Ok(route_policy)
175 }
176
177 fn normalize_prefix_path(raw: &str) -> Result<String> {
178 let trimmed = raw.trim();
179 let collapsed: String =
181 trimmed
182 .chars()
183 .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
184 if c == '/' && acc.ends_with('/') {
185 } else {
187 acc.push(c);
188 }
189 acc
190 });
191 let prefix = collapsed.trim_end_matches('/');
192 let result = if prefix.is_empty() {
193 String::new()
194 } else if prefix.starts_with('/') {
195 prefix.to_owned()
196 } else {
197 format!("/{prefix}")
198 };
199 if !result
201 .bytes()
202 .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
203 {
204 anyhow::bail!(
205 "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
206 );
207 }
208
209 if result.split('/').any(|seg| seg == "." || seg == "..") {
210 anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
211 }
212
213 Ok(result)
214 }
215
216 pub(crate) fn apply_middleware_stack(
218 &self,
219 mut router: Router,
220 authn_client: Option<Arc<dyn AuthNResolverClient>>,
221 ) -> Result<Router> {
222 let route_policy = self.build_route_policy_from_specs()?;
224
225 router = router.route_layer(from_fn(middleware::http_metrics::propagate_matched_path));
239
240 let config = self.get_cached_config();
241
242 let specs: Vec<_> = self
244 .openapi_registry
245 .operation_specs
246 .iter()
247 .map(|e| e.value().clone())
248 .collect();
249
250 let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
252
253 router = router.layer(from_fn(
254 move |req: axum::extract::Request, next: axum::middleware::Next| {
255 let map = license_map.clone();
256 middleware::license_validation::license_validation_middleware(map, req, next)
257 },
258 ));
259
260 if config.route_policies.enabled {
262 if config.auth_disabled {
264 return Err(anyhow::anyhow!(
265 "Invalid configuration: route_policies.enabled=true requires authentication. \
266 Set auth_disabled=false or disable route_policies."
267 ));
268 }
269
270 let scope_rules = middleware::scope_enforcement::ScopeEnforcementRules::from_config(
271 &config.route_policies,
272 )?;
273 let scope_state =
274 middleware::scope_enforcement::ScopeEnforcementState { rules: scope_rules };
275 router = router.layer(from_fn_with_state(
276 scope_state,
277 middleware::scope_enforcement::scope_enforcement_middleware,
278 ));
279 }
280
281 if config.auth_disabled {
283 let default_security_context = SecurityContext::builder()
285 .subject_id(DEFAULT_SUBJECT_ID)
286 .subject_tenant_id(DEFAULT_TENANT_ID)
287 .build()?;
288
289 tracing::warn!(
290 "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
291 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
292 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
293 );
294 router = router.layer(from_fn(
295 move |mut req: axum::extract::Request, next: axum::middleware::Next| {
296 let sec_context = default_security_context.clone();
297 async move {
298 req.extensions_mut().insert(sec_context);
299 next.run(req).await
300 }
301 },
302 ));
303 } else if let Some(client) = authn_client {
304 let auth_state = auth::AuthState {
305 authn_client: client,
306 route_policy,
307 };
308 router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
309 } else {
310 return Err(anyhow::anyhow!(
311 "auth is enabled but no AuthN Resolver client is available; \
312 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
313 ));
314 }
315
316 router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
318
319 let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
321
322 router = router.layer(from_fn(
323 move |req: axum::extract::Request, next: axum::middleware::Next| {
324 let map = rate_map.clone();
325 middleware::rate_limit::rate_limit_middleware(map, req, next)
326 },
327 ));
328
329 let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
331 router = router.layer(from_fn(
332 move |req: axum::extract::Request, next: axum::middleware::Next| {
333 let map = mime_map.clone();
334 middleware::mime_validation::mime_validation_middleware(map, req, next)
335 },
336 ));
337
338 if config.cors_enabled {
340 router = router.layer(crate::cors::build_cors_layer(&config));
341 }
342
343 router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
345 router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
346
347 router = router.layer(TimeoutLayer::with_status_code(
349 axum::http::StatusCode::GATEWAY_TIMEOUT,
350 Duration::from_secs(30),
351 ));
352
353 router = router.layer(CatchPanicLayer::new());
355
356 let http_metrics = Arc::new(middleware::http_metrics::HttpMetrics::new(
358 Self::MODULE_NAME,
359 &config.metrics.prefix,
360 ));
361 router = router.layer(from_fn_with_state(
362 http_metrics,
363 middleware::http_metrics::http_metrics_middleware,
364 ));
365
366 router = router.layer(from_fn(middleware::access_log::access_log_middleware));
368
369 router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
371
372 router = router.layer({
374 use modkit_http::otel;
375 use tower_http::trace::TraceLayer;
376 use tracing::field::Empty;
377
378 TraceLayer::new_for_http()
379 .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
380 let hdr = middleware::request_id::header();
381 let rid = req
382 .headers()
383 .get(&hdr)
384 .and_then(|v| v.to_str().ok())
385 .unwrap_or("n/a");
386
387 let span = tracing::info_span!(
388 "http_request",
389 method = %req.method(),
390 uri = %req.uri().path(),
391 version = ?req.version(),
392 module = "api_gateway",
393 endpoint = %req.uri().path(),
394 request_id = %rid,
395 status = Empty,
396 latency_ms = Empty,
397 "http.method" = %req.method(),
399 "http.target" = %req.uri().path(),
400 "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
401 "http.host" = req.headers().get("host")
402 .and_then(|h| h.to_str().ok())
403 .unwrap_or("unknown"),
404 "user_agent.original" = req.headers().get("user-agent")
405 .and_then(|h| h.to_str().ok())
406 .unwrap_or("unknown"),
407 trace_id = Empty,
409 parent.trace_id = Empty
410 );
411
412 otel::set_parent_from_headers(&span, req.headers());
415
416 span
417 })
418 .on_response(
419 |res: &axum::http::Response<axum::body::Body>,
420 latency: std::time::Duration,
421 span: &tracing::Span| {
422 let ms = latency.as_millis();
423 span.record("status", res.status().as_u16());
424 span.record("latency_ms", ms);
425 },
426 )
427 });
428
429 let x_request_id = crate::middleware::request_id::header();
431 router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
433 router = router.layer(SetRequestIdLayer::new(
434 x_request_id,
435 crate::middleware::request_id::MakeReqId,
436 ));
437
438 Ok(router)
439 }
440
441 pub fn build_router(&self) -> Result<Router> {
446 let cached_router = self.router_cache.load();
449 if Arc::strong_count(&cached_router) > 1 {
450 tracing::debug!("Using cached router");
451 return Ok((*cached_router).clone());
452 }
453
454 tracing::debug!("Building new router (standalone/fallback mode)");
455 let mut router = Router::new()
458 .route("/health", get(web::health_check))
459 .route("/healthz", get(|| async { "ok" }));
460
461 let authn_client = self.authn_client.lock().clone();
463 router = self.apply_middleware_stack(router, authn_client)?;
464
465 let config = self.get_cached_config();
466 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
467 router = Self::apply_prefix_nesting(router, &prefix);
468
469 self.router_cache.store(router.clone());
471
472 Ok(router)
473 }
474
475 pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
480 let config = self.get_cached_config();
481 let info = modkit::api::OpenApiInfo {
482 title: config.openapi.title.clone(),
483 version: config.openapi.version.clone(),
484 description: config.openapi.description,
485 };
486 self.openapi_registry.build_openapi(&info)
487 }
488
489 fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
491 bind_addr
492 .parse()
493 .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
494 }
495
496 fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
498 let stored = { self.final_router.lock().take() };
499
500 if let Some(router) = stored {
501 tracing::debug!("Using router from REST phase");
502 Ok(router)
503 } else {
504 tracing::debug!("No router from REST phase, building default router");
505 self.build_router()
506 }
507 }
508
509 pub(crate) async fn serve(
514 self: Arc<Self>,
515 cancel: CancellationToken,
516 ready: ReadySignal,
517 ) -> anyhow::Result<()> {
518 let cfg = self.get_cached_config();
519 let addr = Self::parse_bind_address(&cfg.bind_addr)?;
520 let router = self.get_or_build_router()?;
521
522 let listener = tokio::net::TcpListener::bind(addr).await?;
524 tracing::info!("HTTP server bound on {}", addr);
525 ready.notify(); let shutdown = {
529 let cancel = cancel.clone();
530 async move {
531 cancel.cancelled().await;
532 tracing::info!("HTTP server shutting down gracefully (cancellation)");
533 }
534 };
535
536 axum::serve(
537 listener,
538 router.into_make_service_with_connect_info::<SocketAddr>(),
539 )
540 .with_graceful_shutdown(shutdown)
541 .await
542 .map_err(|e| anyhow::anyhow!(e))
543 }
544
545 fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
547 if self
548 .registered_handlers
549 .insert(spec.handler_id.clone(), ())
550 .is_some()
551 {
552 tracing::error!(
553 handler_id = %spec.handler_id,
554 method = %spec.method.as_str(),
555 path = %spec.path,
556 "Duplicate handler_id detected; ignoring subsequent registration"
557 );
558 return true;
559 }
560 false
561 }
562
563 fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
565 let route_key = (spec.method.clone(), spec.path.clone());
566 if self.registered_routes.insert(route_key, ()).is_some() {
567 tracing::error!(
568 method = %spec.method.as_str(),
569 path = %spec.path,
570 "Duplicate (method, path) detected; ignoring subsequent registration"
571 );
572 return true;
573 }
574 false
575 }
576
577 fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
579 let current_count = self.openapi_registry.operation_specs.len();
580 tracing::debug!(
581 handler_id = %spec.handler_id,
582 method = %spec.method.as_str(),
583 path = %spec.path,
584 summary = %spec.summary.as_deref().unwrap_or("No summary"),
585 total_operations = current_count,
586 "Registered API operation"
587 );
588 }
589
590 fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
592 let op_count = self.openapi_registry.operation_specs.len();
594 tracing::info!(
595 "rest_finalize: emitting OpenAPI with {} operations",
596 op_count
597 );
598
599 let openapi_doc = Arc::new(self.build_openapi()?);
600 let config = self.get_cached_config();
601 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
602 let html_doc = web::serve_docs(&prefix);
603
604 router = router
605 .route(
606 "/openapi.json",
607 get({
608 use axum::{http::header, response::IntoResponse};
609 let doc = openapi_doc;
610 move || async move {
611 let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
612 Ok(json) => json,
613 Err(e) => {
614 tracing::error!("Failed to serialize OpenAPI doc: {}", e);
615 return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
616 }
617 };
618 (
619 [
620 (header::CONTENT_TYPE, "application/json"),
621 (header::CACHE_CONTROL, "no-store"),
622 ],
623 json_string,
624 )
625 .into_response()
626 }
627 }),
628 )
629 .route("/docs", get(move || async move { html_doc }));
630
631 #[cfg(feature = "embed_elements")]
632 {
633 router = router.route(
634 "/docs/assets/{*file}",
635 get(crate::assets::serve_elements_asset),
636 );
637 }
638
639 Ok(router)
640 }
641}
642
643#[async_trait]
645impl modkit::Module for ApiGateway {
646 async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
647 let cfg = ctx.config_or_default::<crate::config::ApiGatewayConfig>()?;
648 self.config.store(Arc::new(cfg.clone()));
649
650 debug!(
651 "Effective api_gateway configuration:\n{:#?}",
652 self.config.load()
653 );
654
655 if cfg.auth_disabled {
656 tracing::info!(
657 tenant_id = %DEFAULT_TENANT_ID,
658 "Auth-disabled mode enabled with default tenant"
659 );
660 } else {
661 let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
663 *self.authn_client.lock() = Some(authn_client);
664 tracing::info!("AuthN Resolver client resolved from ClientHub");
665 }
666
667 Ok(())
668 }
669}
670
671impl modkit::contracts::ApiGatewayCapability for ApiGateway {
673 fn rest_prepare(
674 &self,
675 _ctx: &modkit::context::ModuleCtx,
676 router: axum::Router,
677 ) -> anyhow::Result<axum::Router> {
678 let router = router
682 .route("/health", get(web::health_check))
683 .route("/healthz", get(|| async { "ok" }));
684
685 tracing::debug!("REST host prepared base router with health check endpoints");
687 Ok(router)
688 }
689
690 fn rest_finalize(
691 &self,
692 _ctx: &modkit::context::ModuleCtx,
693 mut router: axum::Router,
694 ) -> anyhow::Result<axum::Router> {
695 let config = self.get_cached_config();
696
697 if config.enable_docs {
698 router = self.add_openapi_routes(router)?;
699 }
700
701 tracing::debug!("Applying middleware stack to finalized router");
703 let authn_client = self.authn_client.lock().clone();
704 router = self.apply_middleware_stack(router, authn_client)?;
705
706 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
707 router = Self::apply_prefix_nesting(router, &prefix);
708
709 *self.final_router.lock() = Some(router.clone());
711
712 tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
713 Ok(router)
714 }
715
716 fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
717 self
718 }
719}
720
721impl modkit::contracts::RestApiCapability for ApiGateway {
722 fn register_rest(
723 &self,
724 _ctx: &modkit::context::ModuleCtx,
725 router: axum::Router,
726 _openapi: &dyn modkit::contracts::OpenApiRegistry,
727 ) -> anyhow::Result<axum::Router> {
728 Ok(router)
731 }
732}
733
734impl OpenApiRegistry for ApiGateway {
735 fn register_operation(&self, spec: &modkit::api::OperationSpec) {
736 if self.check_duplicate_handler(spec) {
738 return;
739 }
740
741 if self.check_duplicate_route(spec) {
742 return;
743 }
744
745 self.openapi_registry.register_operation(spec);
747 self.log_operation_registration(spec);
748 }
749
750 fn ensure_schema_raw(
751 &self,
752 root_name: &str,
753 schemas: Vec<(
754 String,
755 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
756 )>,
757 ) -> String {
758 self.openapi_registry.ensure_schema_raw(root_name, schemas)
760 }
761
762 fn as_any(&self) -> &dyn std::any::Any {
763 self
764 }
765}
766
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// The generated document carries the top-level OpenAPI sections, and its `info` block
    /// reflects the configured title, version and description.
    #[test]
    fn test_openapi_generation() {
        let mut config = ApiGatewayConfig::default();
        config.openapi.title = "Test API".to_owned();
        config.openapi.version = "1.0.0".to_owned();
        config.openapi.description = Some("Test Description".to_owned());

        let doc = ApiGateway::new(config).build_openapi().unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        for section in ["openapi", "info", "paths"] {
            assert!(json.get(section).is_some());
        }

        let info = json.get("info").unwrap();
        let expected = [
            ("title", "Test API"),
            ("version", "1.0.0"),
            ("description", "Test Description"),
        ];
        for (field, value) in expected {
            assert_eq!(info.get(field).unwrap(), value);
        }
    }
}
796
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    use super::*;

    #[test]
    fn empty_string_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("").unwrap(), "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("/").unwrap(), "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("///").unwrap(), "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path(" ").unwrap(), "");
    }

    #[test]
    fn simple_prefix_preserved() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf").unwrap(), "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf/").unwrap(), "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_eq!(ApiGateway::normalize_prefix_path("cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_eq!(ApiGateway::normalize_prefix_path("//cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api//v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("///api///v1///").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_eq!(ApiGateway::normalize_prefix_path(" /cf ").unwrap(), "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn dot_in_path_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1.0").unwrap(),
            "/api/v1.0"
        );
    }

    /// Dots are only forbidden as whole segments, not inside a segment.
    #[test]
    fn dots_within_segment_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api..v1").unwrap(),
            "/api..v1"
        );
    }

    #[test]
    fn rejects_dot_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/api/../admin");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_bare_dot_dot() {
        let result = ApiGateway::normalize_prefix_path("..");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_single_dot_segment() {
        let result = ApiGateway::normalize_prefix_path("/api/./v1");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_html_injection() {
        let result = ApiGateway::normalize_prefix_path(r#""><script>alert(1)</script>"#);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_spaces_in_path() {
        let result = ApiGateway::normalize_prefix_path("/my path");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_query_string_chars() {
        let result = ApiGateway::normalize_prefix_path("/api?foo=bar");
        assert!(result.is_err());
    }
}
897
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    async fn dummy_handler() -> Json<Value> {
        Json(serde_json::json!({"ok": true}))
    }

    /// A `problem_response` declaration yields a concrete `Problem` component schema and a
    /// `400` response whose `application/problem+json` body references it.
    #[tokio::test]
    async fn openapi_includes_problem_schema_and_response() {
        let api = ApiGateway::default();
        let base = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
            .public()
            .summary("Problem demo")
            .problem_response(&api, http::StatusCode::BAD_REQUEST, "Bad Request")
            .handler(dummy_handler)
            .register(base, &api);

        let document = api.build_openapi().expect("openapi");
        let spec = serde_json::to_value(&document).expect("json");

        let problem_schema = spec
            .pointer("/components/schemas/Problem")
            .expect("Problem schema missing");
        assert!(
            problem_schema.get("$ref").is_none(),
            "Problem must be a real object, not a self-ref"
        );

        let bad_request = spec
            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
            .expect("400 response missing");

        let content_map = bad_request.get("content").expect("content object missing");
        assert!(
            content_map.get("application/problem+json").is_some(),
            "application/problem+json content missing. Available content: {}",
            serde_json::to_string_pretty(content_map).unwrap()
        );

        let problem_body = bad_request
            .pointer("/content/application~1problem+json")
            .expect("application/problem+json content missing");
        let schema_ref = problem_body
            .pointer("/schema/$ref")
            .and_then(Value::as_str)
            .unwrap_or("");
        assert_eq!(schema_ref, "#/components/schemas/Problem");
    }
}
959
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod sse_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    #[derive(Clone)]
    #[modkit_macros::api_dto(request, response)]
    struct UserEvent {
        id: u32,
        message: String,
    }

    async fn sse_handler() -> axum::response::sse::Sse<
        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    > {
        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
        b.sse_response()
    }

    #[tokio::test]
    async fn openapi_has_sse_content() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
            .summary("Demo SSE")
            .handler(sse_handler)
            .public()
            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());

        let refp = v
            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
            .and_then(|x| x.as_str())
            .unwrap_or_default();
        assert_eq!(refp, "#/components/schemas/UserEvent");
    }

    #[tokio::test]
    async fn openapi_sse_additional_response() {
        async fn mixed_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
            .summary("Mixed responses")
            .public()
            .handler(mixed_handler)
            .json_response(http::StatusCode::OK, "Success response")
            .sse_json::<UserEvent>(&api, "Additional SSE stream")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let responses = v
            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
            .expect("responses");

        assert!(responses.get("200").is_some());

        let response_content = responses.get("200").and_then(|r| r.get("content"));
        assert!(response_content.is_some());

        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
    }

    #[tokio::test]
    async fn test_axum_to_openapi_path_conversion() {
        async fn user_handler() -> Json<Value> {
            Json(serde_json::json!({"user_id": "123"}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
            .summary("Get user by ID")
            .public()
            .path_param("id", "User ID")
            .handler(user_handler)
            .json_response(http::StatusCode::OK, "User details")
            .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/users/{id}");

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/users/{id}").is_some(),
            "OpenAPI should use {{id}} placeholder"
        );
    }

    #[tokio::test]
    async fn test_multiple_path_params_conversion() {
        async fn item_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get(
            "/tests/v1/projects/{project_id}/items/{item_id}",
        )
        .summary("Get project item")
        .public()
        .path_param("project_id", "Project ID")
        .path_param("item_id", "Item ID")
        .handler(item_handler)
        .json_response(http::StatusCode::OK, "Item details")
        .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        // Guard the indexing below so an empty registry fails with a clear assertion.
        assert_eq!(ops.len(), 1);
        assert_eq!(
            ops[0].path,
            "/tests/v1/projects/{project_id}/items/{item_id}"
        );

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths
                .get("/tests/v1/projects/{project_id}/items/{item_id}")
                .is_some()
        );
    }

    #[tokio::test]
    async fn test_wildcard_path_conversion() {
        async fn static_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
            .summary("Serve static files")
            .public()
            .handler(static_handler)
            .json_response(http::StatusCode::OK, "File content")
            .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/static/{path}").is_some(),
            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
        );
        // Check the path that was actually registered; a path that was never registered
        // would make this assertion vacuous.
        assert!(
            paths.get("/tests/v1/static/{*path}").is_none(),
            "OpenAPI should not have Axum-style {{*path}}"
        );
    }

    #[tokio::test]
    async fn test_multipart_file_upload_openapi() {
        async fn upload_handler() -> Json<Value> {
            Json(serde_json::json!({"uploaded": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
            .operation_id("upload_file")
            .public()
            .summary("Upload a file")
            .multipart_file_request("file", Some("File to upload"))
            .handler(upload_handler)
            .json_response(http::StatusCode::OK, "Upload successful")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        let upload_path = paths
            .get("/tests/v1/files/upload")
            .expect("/tests/v1/files/upload path");
        let post_op = upload_path.get("post").expect("POST operation");

        let request_body = post_op.get("requestBody").expect("requestBody");
        let content = request_body.get("content").expect("content");
        let multipart = content
            .get("multipart/form-data")
            .expect("multipart/form-data content type");

        let schema = multipart.get("schema").expect("schema");
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be of type object"
        );

        let properties = schema.get("properties").expect("properties");
        let file_prop = properties.get("file").expect("file property");
        assert_eq!(
            file_prop.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "File field should be of type string"
        );
        assert_eq!(
            file_prop.get("format").and_then(|v| v.as_str()),
            Some("binary"),
            "File field should have format binary"
        );

        let required = schema.get("required").expect("required");
        let required_arr = required.as_array().expect("required should be array");
        assert_eq!(required_arr.len(), 1);
        assert_eq!(required_arr[0].as_str(), Some("file"));
    }
}