1use async_trait::async_trait;
6use std::sync::Arc;
7
8use arc_swap::ArcSwap;
9use dashmap::DashMap;
10
11use anyhow::Result;
12use axum::http::Method;
13use axum::middleware::from_fn_with_state;
14use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
15use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
16use modkit::lifecycle::ReadySignal;
17use parking_lot::Mutex;
18use std::net::SocketAddr;
19use std::time::Duration;
20use tokio_util::sync::CancellationToken;
21use tower_http::{
22 limit::RequestBodyLimitLayer,
23 request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
24 timeout::TimeoutLayer,
25};
26use tracing::debug;
27
28use authn_resolver_sdk::AuthNResolverClient;
29
30use crate::config::ApiGatewayConfig;
31use crate::middleware::auth;
32use modkit_security::SecurityContext;
33use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
34
35use crate::middleware;
36use crate::router_cache::RouterCache;
37use crate::web;
38
/// API gateway module: hosts the HTTP server and REST router, collects the OpenAPI
/// operation specs registered by other modules, and applies the shared middleware
/// stack (license/rate-limit/MIME validation, authentication, CORS, body limits,
/// timeout, request-id propagation and tracing).
#[modkit::module(
    name = "api-gateway",
    capabilities = [rest_host, rest, stateful],
    deps = ["grpc-hub", "authn-resolver"],
    lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
)]
pub struct ApiGateway {
    /// Current gateway configuration; replaced atomically in `Module::init`.
    pub(crate) config: ArcSwap<ApiGatewayConfig>,
    /// Registry of operation specs and schemas used to build the OpenAPI document
    /// and the per-route middleware maps.
    pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
    /// Cache holding the router produced by `build_router` (starts as an empty router).
    pub(crate) router_cache: RouterCache<axum::Router>,
    /// Router produced by `rest_finalize`; taken (once) by `serve`.
    pub(crate) final_router: Mutex<Option<axum::Router>>,
    /// AuthN Resolver client resolved in `Module::init`; stays `None` when auth is disabled.
    pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,

    /// `(method, path)` pairs already registered; used to reject duplicate routes.
    pub(crate) registered_routes: DashMap<(Method, String), ()>,
    /// `handler_id`s already registered; used to reject duplicate handlers.
    pub(crate) registered_handlers: DashMap<String, ()>,
}
63
64impl Default for ApiGateway {
65 fn default() -> Self {
66 let default_router = Router::new();
67 Self {
68 config: ArcSwap::from_pointee(ApiGatewayConfig::default()),
69 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
70 router_cache: RouterCache::new(default_router),
71 final_router: Mutex::new(None),
72 authn_client: Mutex::new(None),
73 registered_routes: DashMap::new(),
74 registered_handlers: DashMap::new(),
75 }
76 }
77}
78
79impl ApiGateway {
80 fn apply_prefix_nesting(mut router: Router, prefix: &str) -> Router {
81 if prefix.is_empty() {
82 return router;
83 }
84
85 let top = Router::new()
86 .route("/health", get(web::health_check))
87 .route("/healthz", get(|| async { "ok" }));
88
89 router = Router::new().nest(prefix, router);
90 top.merge(router)
91 }
92
93 #[must_use]
95 pub fn new(config: ApiGatewayConfig) -> Self {
96 let default_router = Router::new();
97 Self {
98 config: ArcSwap::from_pointee(config),
99 openapi_registry: Arc::new(OpenApiRegistryImpl::new()),
100 router_cache: RouterCache::new(default_router),
101 final_router: Mutex::new(None),
102 authn_client: Mutex::new(None),
103 registered_routes: DashMap::new(),
104 registered_handlers: DashMap::new(),
105 }
106 }
107
108 pub fn get_config(&self) -> ApiGatewayConfig {
110 (**self.config.load()).clone()
111 }
112
113 pub fn get_cached_config(&self) -> ApiGatewayConfig {
115 (**self.config.load()).clone()
116 }
117
118 pub fn get_cached_router(&self) -> Arc<Router> {
120 self.router_cache.load()
121 }
122
123 pub fn rebuild_and_cache_router(&self) -> Result<()> {
128 let new_router = self.build_router()?;
129 self.router_cache.store(new_router);
130 Ok(())
131 }
132
133 fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
135 let mut authenticated_routes = std::collections::HashSet::new();
136 let mut public_routes = std::collections::HashSet::new();
137
138 public_routes.insert((Method::GET, "/health".to_owned()));
140 public_routes.insert((Method::GET, "/healthz".to_owned()));
141
142 public_routes.insert((Method::GET, "/docs".to_owned()));
143 public_routes.insert((Method::GET, "/openapi.json".to_owned()));
144
145 for spec in &self.openapi_registry.operation_specs {
146 let spec = spec.value();
147
148 let route_key = (spec.method.clone(), spec.path.clone());
149
150 if spec.authenticated {
151 authenticated_routes.insert(route_key.clone());
152 }
153
154 if spec.is_public {
155 public_routes.insert(route_key);
156 }
157 }
158
159 let config = self.get_cached_config();
160 let requirements_count = authenticated_routes.len();
161 let public_routes_count = public_routes.len();
162
163 let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
164
165 tracing::info!(
166 auth_disabled = config.auth_disabled,
167 require_auth_by_default = config.require_auth_by_default,
168 requirements_count = requirements_count,
169 public_routes_count = public_routes_count,
170 "Route policy built from operation specs"
171 );
172
173 Ok(route_policy)
174 }
175
176 fn normalize_prefix_path(raw: &str) -> Result<String> {
177 let trimmed = raw.trim();
178 let collapsed: String =
180 trimmed
181 .chars()
182 .fold(String::with_capacity(trimmed.len()), |mut acc, c| {
183 if c == '/' && acc.ends_with('/') {
184 } else {
186 acc.push(c);
187 }
188 acc
189 });
190 let prefix = collapsed.trim_end_matches('/');
191 let result = if prefix.is_empty() {
192 String::new()
193 } else if prefix.starts_with('/') {
194 prefix.to_owned()
195 } else {
196 format!("/{prefix}")
197 };
198 if !result
200 .bytes()
201 .all(|b| b.is_ascii_alphanumeric() || b == b'/' || b == b'_' || b == b'-' || b == b'.')
202 {
203 anyhow::bail!(
204 "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
205 );
206 }
207
208 if result.split('/').any(|seg| seg == "." || seg == "..") {
209 anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
210 }
211
212 Ok(result)
213 }
214
215 pub(crate) fn apply_middleware_stack(
217 &self,
218 mut router: Router,
219 authn_client: Option<Arc<dyn AuthNResolverClient>>,
220 ) -> Result<Router> {
221 let route_policy = self.build_route_policy_from_specs()?;
223
224 let config = self.get_cached_config();
235
236 let specs: Vec<_> = self
238 .openapi_registry
239 .operation_specs
240 .iter()
241 .map(|e| e.value().clone())
242 .collect();
243
244 let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
246
247 router = router.layer(from_fn(
248 move |req: axum::extract::Request, next: axum::middleware::Next| {
249 let map = license_map.clone();
250 middleware::license_validation::license_validation_middleware(map, req, next)
251 },
252 ));
253
254 if config.auth_disabled {
256 let default_security_context = SecurityContext::builder()
258 .subject_id(DEFAULT_SUBJECT_ID)
259 .subject_tenant_id(DEFAULT_TENANT_ID)
260 .build()?;
261
262 tracing::warn!(
263 "API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
264 This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
265 Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
266 );
267 router = router.layer(from_fn(
268 move |mut req: axum::extract::Request, next: axum::middleware::Next| {
269 let sec_context = default_security_context.clone();
270 async move {
271 req.extensions_mut().insert(sec_context);
272 next.run(req).await
273 }
274 },
275 ));
276 } else if let Some(client) = authn_client {
277 let auth_state = auth::AuthState {
278 authn_client: client,
279 route_policy,
280 };
281 router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
282 } else {
283 return Err(anyhow::anyhow!(
284 "auth is enabled but no AuthN Resolver client is available; \
285 ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
286 ));
287 }
288
289 router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
291
292 let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
294
295 router = router.layer(from_fn(
296 move |req: axum::extract::Request, next: axum::middleware::Next| {
297 let map = rate_map.clone();
298 middleware::rate_limit::rate_limit_middleware(map, req, next)
299 },
300 ));
301
302 let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
304 router = router.layer(from_fn(
305 move |req: axum::extract::Request, next: axum::middleware::Next| {
306 let map = mime_map.clone();
307 middleware::mime_validation::mime_validation_middleware(map, req, next)
308 },
309 ));
310
311 if config.cors_enabled {
313 router = router.layer(crate::cors::build_cors_layer(&config));
314 }
315
316 router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
318 router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
319
320 router = router.layer(TimeoutLayer::with_status_code(
322 axum::http::StatusCode::GATEWAY_TIMEOUT,
323 Duration::from_secs(30),
324 ));
325
326 router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
328
329 router = router.layer({
331 use modkit_http::otel;
332 use tower_http::trace::TraceLayer;
333 use tracing::field::Empty;
334
335 TraceLayer::new_for_http()
336 .make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
337 let hdr = middleware::request_id::header();
338 let rid = req
339 .headers()
340 .get(&hdr)
341 .and_then(|v| v.to_str().ok())
342 .unwrap_or("n/a");
343
344 let span = tracing::info_span!(
345 "http_request",
346 method = %req.method(),
347 uri = %req.uri().path(),
348 version = ?req.version(),
349 module = "api_gateway",
350 endpoint = %req.uri().path(),
351 request_id = %rid,
352 status = Empty,
353 latency_ms = Empty,
354 "http.method" = %req.method(),
356 "http.target" = %req.uri().path(),
357 "http.scheme" = req.uri().scheme_str().unwrap_or("http"),
358 "http.host" = req.headers().get("host")
359 .and_then(|h| h.to_str().ok())
360 .unwrap_or("unknown"),
361 "user_agent.original" = req.headers().get("user-agent")
362 .and_then(|h| h.to_str().ok())
363 .unwrap_or("unknown"),
364 trace_id = Empty,
366 parent.trace_id = Empty
367 );
368
369 otel::set_parent_from_headers(&span, req.headers());
372
373 span
374 })
375 .on_response(
376 |res: &axum::http::Response<axum::body::Body>,
377 latency: std::time::Duration,
378 span: &tracing::Span| {
379 let ms = latency.as_millis();
380 span.record("status", res.status().as_u16());
381 span.record("latency_ms", ms);
382 },
383 )
384 });
385
386 let x_request_id = crate::middleware::request_id::header();
388 router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
390 router = router.layer(SetRequestIdLayer::new(
391 x_request_id,
392 crate::middleware::request_id::MakeReqId,
393 ));
394
395 Ok(router)
396 }
397
398 pub fn build_router(&self) -> Result<Router> {
403 let cached_router = self.router_cache.load();
406 if Arc::strong_count(&cached_router) > 1 {
407 tracing::debug!("Using cached router");
408 return Ok((*cached_router).clone());
409 }
410
411 tracing::debug!("Building new router (standalone/fallback mode)");
412 let mut router = Router::new()
415 .route("/health", get(web::health_check))
416 .route("/healthz", get(|| async { "ok" }));
417
418 let authn_client = self.authn_client.lock().clone();
420 router = self.apply_middleware_stack(router, authn_client)?;
421
422 let config = self.get_cached_config();
423 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
424 router = Self::apply_prefix_nesting(router, &prefix);
425
426 self.router_cache.store(router.clone());
428
429 Ok(router)
430 }
431
432 pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
437 let config = self.get_cached_config();
438 let info = modkit::api::OpenApiInfo {
439 title: config.openapi.title.clone(),
440 version: config.openapi.version.clone(),
441 description: config.openapi.description,
442 };
443 self.openapi_registry.build_openapi(&info)
444 }
445
446 fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
448 bind_addr
449 .parse()
450 .map_err(|e| anyhow::anyhow!("Invalid bind address '{bind_addr}': {e}"))
451 }
452
453 fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
455 let stored = { self.final_router.lock().take() };
456
457 if let Some(router) = stored {
458 tracing::debug!("Using router from REST phase");
459 Ok(router)
460 } else {
461 tracing::debug!("No router from REST phase, building default router");
462 self.build_router()
463 }
464 }
465
466 pub(crate) async fn serve(
471 self: Arc<Self>,
472 cancel: CancellationToken,
473 ready: ReadySignal,
474 ) -> anyhow::Result<()> {
475 let cfg = self.get_cached_config();
476 let addr = Self::parse_bind_address(&cfg.bind_addr)?;
477 let router = self.get_or_build_router()?;
478
479 let listener = tokio::net::TcpListener::bind(addr).await?;
481 tracing::info!("HTTP server bound on {}", addr);
482 ready.notify(); let shutdown = {
486 let cancel = cancel.clone();
487 async move {
488 cancel.cancelled().await;
489 tracing::info!("HTTP server shutting down gracefully (cancellation)");
490 }
491 };
492
493 axum::serve(listener, router)
494 .with_graceful_shutdown(shutdown)
495 .await
496 .map_err(|e| anyhow::anyhow!(e))
497 }
498
499 fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
501 if self
502 .registered_handlers
503 .insert(spec.handler_id.clone(), ())
504 .is_some()
505 {
506 tracing::error!(
507 handler_id = %spec.handler_id,
508 method = %spec.method.as_str(),
509 path = %spec.path,
510 "Duplicate handler_id detected; ignoring subsequent registration"
511 );
512 return true;
513 }
514 false
515 }
516
517 fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
519 let route_key = (spec.method.clone(), spec.path.clone());
520 if self.registered_routes.insert(route_key, ()).is_some() {
521 tracing::error!(
522 method = %spec.method.as_str(),
523 path = %spec.path,
524 "Duplicate (method, path) detected; ignoring subsequent registration"
525 );
526 return true;
527 }
528 false
529 }
530
531 fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
533 let current_count = self.openapi_registry.operation_specs.len();
534 tracing::debug!(
535 handler_id = %spec.handler_id,
536 method = %spec.method.as_str(),
537 path = %spec.path,
538 summary = %spec.summary.as_deref().unwrap_or("No summary"),
539 total_operations = current_count,
540 "Registered API operation"
541 );
542 }
543
544 fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
546 let op_count = self.openapi_registry.operation_specs.len();
548 tracing::info!(
549 "rest_finalize: emitting OpenAPI with {} operations",
550 op_count
551 );
552
553 let openapi_doc = Arc::new(self.build_openapi()?);
554 let config = self.get_cached_config();
555 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
556 let html_doc = web::serve_docs(&prefix);
557
558 router = router
559 .route(
560 "/openapi.json",
561 get({
562 use axum::{http::header, response::IntoResponse};
563 let doc = openapi_doc;
564 move || async move {
565 let json_string = match serde_json::to_string_pretty(doc.as_ref()) {
566 Ok(json) => json,
567 Err(e) => {
568 tracing::error!("Failed to serialize OpenAPI doc: {}", e);
569 return (http::StatusCode::INTERNAL_SERVER_ERROR).into_response();
570 }
571 };
572 (
573 [
574 (header::CONTENT_TYPE, "application/json"),
575 (header::CACHE_CONTROL, "no-store"),
576 ],
577 json_string,
578 )
579 .into_response()
580 }
581 }),
582 )
583 .route("/docs", get(move || async move { html_doc }));
584
585 #[cfg(feature = "embed_elements")]
586 {
587 router = router.route(
588 "/docs/assets/{*file}",
589 get(crate::assets::serve_elements_asset),
590 );
591 }
592
593 Ok(router)
594 }
595}
596
597#[async_trait]
599impl modkit::Module for ApiGateway {
600 async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
601 let cfg = ctx.config::<crate::config::ApiGatewayConfig>()?;
602 self.config.store(Arc::new(cfg.clone()));
603
604 debug!(
605 "Effective api_gateway configuration:\n{:#?}",
606 self.config.load()
607 );
608
609 if cfg.auth_disabled {
610 tracing::info!(
611 tenant_id = %DEFAULT_TENANT_ID,
612 "Auth-disabled mode enabled with default tenant"
613 );
614 } else {
615 let authn_client = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
617 *self.authn_client.lock() = Some(authn_client);
618 tracing::info!("AuthN Resolver client resolved from ClientHub");
619 }
620
621 Ok(())
622 }
623}
624
625impl modkit::contracts::ApiGatewayCapability for ApiGateway {
627 fn rest_prepare(
628 &self,
629 _ctx: &modkit::context::ModuleCtx,
630 router: axum::Router,
631 ) -> anyhow::Result<axum::Router> {
632 let router = router
636 .route("/health", get(web::health_check))
637 .route("/healthz", get(|| async { "ok" }));
638
639 tracing::debug!("REST host prepared base router with health check endpoints");
641 Ok(router)
642 }
643
644 fn rest_finalize(
645 &self,
646 _ctx: &modkit::context::ModuleCtx,
647 mut router: axum::Router,
648 ) -> anyhow::Result<axum::Router> {
649 let config = self.get_cached_config();
650
651 if config.enable_docs {
652 router = self.add_openapi_routes(router)?;
653 }
654
655 tracing::debug!("Applying middleware stack to finalized router");
657 let authn_client = self.authn_client.lock().clone();
658 router = self.apply_middleware_stack(router, authn_client)?;
659
660 let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
661 router = Self::apply_prefix_nesting(router, &prefix);
662
663 *self.final_router.lock() = Some(router.clone());
665
666 tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
667 Ok(router)
668 }
669
670 fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
671 self
672 }
673}
674
impl modkit::contracts::RestApiCapability for ApiGateway {
    /// Returns `router` unchanged: the gateway adds no routes in this phase.
    ///
    /// Its own endpoints are added in `rest_prepare` (health) and `rest_finalize`
    /// (docs/OpenAPI).
    fn register_rest(
        &self,
        _ctx: &modkit::context::ModuleCtx,
        router: axum::Router,
        _openapi: &dyn modkit::contracts::OpenApiRegistry,
    ) -> anyhow::Result<axum::Router> {
        Ok(router)
    }
}
687
688impl OpenApiRegistry for ApiGateway {
689 fn register_operation(&self, spec: &modkit::api::OperationSpec) {
690 if self.check_duplicate_handler(spec) {
692 return;
693 }
694
695 if self.check_duplicate_route(spec) {
696 return;
697 }
698
699 self.openapi_registry.register_operation(spec);
701 self.log_operation_registration(spec);
702 }
703
704 fn ensure_schema_raw(
705 &self,
706 root_name: &str,
707 schemas: Vec<(
708 String,
709 utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
710 )>,
711 ) -> String {
712 self.openapi_registry.ensure_schema_raw(root_name, schemas)
714 }
715
716 fn as_any(&self) -> &dyn std::any::Any {
717 self
718 }
719}
720
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // The configured OpenAPI title/version/description must be copied into the
    // generated document's `info` object. The document must also contain the
    // top-level `openapi`, `info` and `paths` keys.
    #[test]
    fn test_openapi_generation() {
        let mut config = ApiGatewayConfig::default();
        config.openapi.title = "Test API".to_owned();
        config.openapi.version = "1.0.0".to_owned();
        config.openapi.description = Some("Test Description".to_owned());
        let api = ApiGateway::new(config);

        let doc = api.build_openapi().unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        // Top-level structure.
        assert!(json.get("openapi").is_some());
        assert!(json.get("info").is_some());
        assert!(json.get("paths").is_some());

        // `info` reflects the configuration.
        let info = json.get("info").unwrap();
        assert_eq!(info.get("title").unwrap(), "Test API");
        assert_eq!(info.get("version").unwrap(), "1.0.0");
        assert_eq!(info.get("description").unwrap(), "Test Description");
    }
}
750
// Unit tests for `ApiGateway::normalize_prefix_path`.
//
// The normalization and rejection rules covered here are:
// - empty, slash-only or whitespace-only input yields "" (no prefix);
// - slashes are collapsed, trailing slashes stripped, and a leading slash added;
// - characters outside [a-zA-Z0-9/_\-.] are rejected.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    use super::*;

    // --- inputs that normalize to "no prefix" ---

    #[test]
    fn empty_string_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("").unwrap(), "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("/").unwrap(), "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("///").unwrap(), "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path(" ").unwrap(), "");
    }

    // --- normalization of valid prefixes ---

    #[test]
    fn simple_prefix_preserved() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf").unwrap(), "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf/").unwrap(), "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_eq!(ApiGateway::normalize_prefix_path("cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_eq!(ApiGateway::normalize_prefix_path("//cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api//v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("///api///v1///").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_eq!(ApiGateway::normalize_prefix_path(" /cf ").unwrap(), "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1").unwrap(),
            "/api/v1"
        );
    }

    // A '.' inside a segment is allowed; only whole "." / ".." segments are rejected.
    #[test]
    fn dot_in_path_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1.0").unwrap(),
            "/api/v1.0"
        );
    }

    // --- rejected inputs ---

    // The prefix is embedded in the docs page, so markup characters must be refused.
    #[test]
    fn rejects_html_injection() {
        let result = ApiGateway::normalize_prefix_path(r#""><script>alert(1)</script>"#);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_spaces_in_path() {
        let result = ApiGateway::normalize_prefix_path("/my path");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_query_string_chars() {
        let result = ApiGateway::normalize_prefix_path("/api?foo=bar");
        assert!(result.is_err());
    }
}
851
// Checks that `problem_response` emits an `application/problem+json` response
// referencing a concrete `Problem` component schema.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    async fn dummy_handler() -> Json<Value> {
        Json(serde_json::json!({"ok": true}))
    }

    #[tokio::test]
    async fn openapi_includes_problem_schema_and_response() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        // Registration goes through the gateway's `OpenApiRegistry` impl; the
        // returned router is not needed here.
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
            .public()
            .summary("Problem demo")
            .problem_response(&api, http::StatusCode::BAD_REQUEST, "Bad Request")
            .handler(dummy_handler)
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        // The component must be a real schema object, not a `$ref` to itself.
        let problem = v
            .pointer("/components/schemas/Problem")
            .expect("Problem schema missing");
        assert!(
            problem.get("$ref").is_none(),
            "Problem must be a real object, not a self-ref"
        );

        // JSON pointer escaping: '/' inside a key is written as "~1".
        let path_obj = v
            .pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
            .expect("400 response missing");

        let content_obj = path_obj.get("content").expect("content object missing");
        assert!(
            content_obj.get("application/problem+json").is_some(),
            "application/problem+json content missing. Available content: {}",
            serde_json::to_string_pretty(content_obj).unwrap()
        );

        let content = path_obj
            .pointer("/content/application~1problem+json")
            .expect("application/problem+json content missing");
        let schema_ref = content
            .pointer("/schema/$ref")
            .and_then(|r| r.as_str())
            .unwrap_or("");
        assert_eq!(schema_ref, "#/components/schemas/Problem");
    }
}
913
// OpenAPI generation tests covering the following:
// - SSE responses;
// - mixed JSON + SSE responses;
// - Axum -> OpenAPI path placeholder conversion;
// - multipart file uploads.
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod sse_openapi_tests {
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    // Event payload used as the SSE item schema.
    #[derive(Clone)]
    #[modkit_macros::api_dto(request, response)]
    struct UserEvent {
        id: u32,
        message: String,
    }

    async fn sse_handler() -> axum::response::sse::Sse<
        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    > {
        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
        b.sse_response()
    }

    // An SSE operation must document `text/event-stream` content that references
    // a concrete `UserEvent` component.
    #[tokio::test]
    async fn openapi_has_sse_content() {
        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
            .summary("Demo SSE")
            .handler(sse_handler)
            .public()
            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());

        let refp = v
            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
            .and_then(|x| x.as_str())
            .unwrap_or_default();
        assert_eq!(refp, "#/components/schemas/UserEvent");
    }

    // A JSON response combined with an SSE stream still yields a 200 response with
    // content, and registers the `UserEvent` schema.
    #[tokio::test]
    async fn openapi_sse_additional_response() {
        async fn mixed_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
            .summary("Mixed responses")
            .public()
            .handler(mixed_handler)
            .json_response(http::StatusCode::OK, "Success response")
            .sse_json::<UserEvent>(&api, "Additional SSE stream")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let responses = v
            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
            .expect("responses");

        assert!(responses.get("200").is_some());

        let response_content = responses.get("200").and_then(|r| r.get("content"));
        assert!(response_content.is_some());

        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
    }

    // A single `{id}` path parameter is kept verbatim in both the spec and the document.
    #[tokio::test]
    async fn test_axum_to_openapi_path_conversion() {
        async fn user_handler() -> Json<Value> {
            Json(serde_json::json!({"user_id": "123"}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
            .summary("Get user by ID")
            .public()
            .path_param("id", "User ID")
            .handler(user_handler)
            .json_response(http::StatusCode::OK, "User details")
            .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/users/{id}");

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/users/{id}").is_some(),
            "OpenAPI should use {{id}} placeholder"
        );
    }

    // Multiple path parameters are all preserved.
    #[tokio::test]
    async fn test_multiple_path_params_conversion() {
        async fn item_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get(
            "/tests/v1/projects/{project_id}/items/{item_id}",
        )
        .summary("Get project item")
        .public()
        .path_param("project_id", "Project ID")
        .path_param("item_id", "Item ID")
        .handler(item_handler)
        .json_response(http::StatusCode::OK, "Item details")
        .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(
            ops[0].path,
            "/tests/v1/projects/{project_id}/items/{item_id}"
        );

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths
                .get("/tests/v1/projects/{project_id}/items/{item_id}")
                .is_some()
        );
    }

    // Axum wildcard `{*path}` stays in the spec but becomes `{path}` in the document.
    #[tokio::test]
    async fn test_wildcard_path_conversion() {
        async fn static_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
            .summary("Serve static files")
            .public()
            .handler(static_handler)
            .json_response(http::StatusCode::OK, "File content")
            .register(router, &api);

        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/static/{path}").is_some(),
            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
        );
        // Must check the actually registered path. A path without the `/tests/v1`
        // prefix would never exist, so that check would pass vacuously.
        assert!(
            paths.get("/tests/v1/static/{*path}").is_none(),
            "OpenAPI should not have Axum-style {{*path}}"
        );
    }

    // A multipart file request must be documented as an object with one required
    // binary string property.
    #[tokio::test]
    async fn test_multipart_file_upload_openapi() {
        async fn upload_handler() -> Json<Value> {
            Json(serde_json::json!({"uploaded": true}))
        }

        let api = ApiGateway::default();
        let router = axum::Router::new();

        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
            .operation_id("upload_file")
            .public()
            .summary("Upload a file")
            .multipart_file_request("file", Some("File to upload"))
            .handler(upload_handler)
            .json_response(http::StatusCode::OK, "Upload successful")
            .register(router, &api);

        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");

        let paths = v.get("paths").expect("paths");
        let upload_path = paths
            .get("/tests/v1/files/upload")
            .expect("/tests/v1/files/upload path");
        let post_op = upload_path.get("post").expect("POST operation");

        let request_body = post_op.get("requestBody").expect("requestBody");
        let content = request_body.get("content").expect("content");
        let multipart = content
            .get("multipart/form-data")
            .expect("multipart/form-data content type");

        let schema = multipart.get("schema").expect("schema");
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be of type object"
        );

        let properties = schema.get("properties").expect("properties");
        let file_prop = properties.get("file").expect("file property");
        assert_eq!(
            file_prop.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "File field should be of type string"
        );
        assert_eq!(
            file_prop.get("format").and_then(|v| v.as_str()),
            Some("binary"),
            "File field should have format binary"
        );

        let required = schema.get("required").expect("required");
        let required_arr = required.as_array().expect("required should be array");
        assert_eq!(required_arr.len(), 1);
        assert_eq!(required_arr[0].as_str(), Some("file"));
    }
}