use async_trait::async_trait;
use std::sync::Arc;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use anyhow::Result;
use axum::http::Method;
use axum::middleware::from_fn_with_state;
use axum::{Router, extract::DefaultBodyLimit, middleware::from_fn, routing::get};
use modkit::api::{OpenApiRegistry, OpenApiRegistryImpl};
use modkit::lifecycle::ReadySignal;
use parking_lot::Mutex;
use std::net::SocketAddr;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tower_http::{
catch_panic::CatchPanicLayer,
limit::RequestBodyLimitLayer,
request_id::{PropagateRequestIdLayer, SetRequestIdLayer},
timeout::TimeoutLayer,
};
use tracing::debug;
use authn_resolver_sdk::AuthNResolverClient;
use crate::config::ApiGatewayConfig;
use crate::middleware::auth;
use modkit_security::SecurityContext;
use modkit_security::constants::{DEFAULT_SUBJECT_ID, DEFAULT_TENANT_ID};
use crate::middleware;
use crate::router_cache::RouterCache;
use crate::web;
/// The API gateway module: the REST host that owns the HTTP server, the OpenAPI
/// registry and the gateway middleware stack (see `apply_middleware_stack`).
///
/// All state uses interior mutability (`ArcSwap`, `Mutex`, `DashMap`) because the
/// lifecycle hooks in this file (`init`, `rest_finalize`, `register_operation`)
/// only receive `&self`.
///
/// NOTE(review): the `lifecycle(entry = "serve", ..., await_ready)` options presumably
/// make the framework run `serve` and wait for its `ReadySignal`; confirm against
/// the `modkit::module` macro docs.
#[modkit::module(
name = "api-gateway",
capabilities = [rest_host, rest, stateful],
deps = ["grpc-hub", "authn-resolver"],
lifecycle(entry = "serve", stop_timeout = "30s", await_ready)
)]
pub struct ApiGateway {
/// Active configuration; replaced wholesale in `init` and read via snapshots.
pub(crate) config: ArcSwap<ApiGatewayConfig>,
/// Registry of operation specs and schemas used for OpenAPI and per-route policies.
pub(crate) openapi_registry: Arc<OpenApiRegistryImpl>,
/// Router cache written by `build_router`; seeded with an empty `Router` at construction.
pub(crate) router_cache: RouterCache<axum::Router>,
/// Router produced by `rest_finalize`; taken (moved out) by `serve`.
pub(crate) final_router: Mutex<Option<axum::Router>>,
/// AuthN resolver client, resolved in `init` only when auth is enabled.
pub(crate) authn_client: Mutex<Option<Arc<dyn AuthNResolverClient>>>,
/// Set of already registered `(method, path)` pairs, used to reject duplicates.
pub(crate) registered_routes: DashMap<(Method, String), ()>,
/// Set of already registered `handler_id`s, used to reject duplicates.
pub(crate) registered_handlers: DashMap<String, ()>,
}
impl Default for ApiGateway {
    /// Creates a gateway that uses `ApiGatewayConfig::default()`.
    ///
    /// Identical to `ApiGateway::new(ApiGatewayConfig::default())`: empty registries,
    /// an empty cached router, and no AuthN client until `init` runs.
    fn default() -> Self {
        Self::new(ApiGatewayConfig::default())
    }
}
impl ApiGateway {
/// Mounts `router` under `prefix`, keeping `/health` and `/healthz` reachable at the root.
///
/// An empty `prefix` means "no nesting" and returns `router` unchanged. Otherwise the
/// result holds the root-level health endpoints merged with `router` nested at `prefix`.
/// The root-level health routes are outside whatever layers `router` carries.
fn apply_prefix_nesting(router: Router, prefix: &str) -> Router {
    if prefix.is_empty() {
        return router;
    }
    let nested = Router::new().nest(prefix, router);
    Router::new()
        .route("/health", get(web::health_check))
        .route("/healthz", get(|| async { "ok" }))
        .merge(nested)
}
/// Creates a gateway from an explicit configuration.
///
/// The router cache starts out holding an empty `Router`. No REST-phase router or
/// AuthN client is present until the lifecycle hooks (`rest_finalize`, `init`) run.
#[must_use]
pub fn new(config: ApiGatewayConfig) -> Self {
    let registry = Arc::new(OpenApiRegistryImpl::new());
    let cache = RouterCache::new(Router::new());
    Self {
        config: ArcSwap::from_pointee(config),
        openapi_registry: registry,
        router_cache: cache,
        final_router: Mutex::new(None),
        authn_client: Mutex::new(None),
        registered_routes: DashMap::new(),
        registered_handlers: DashMap::new(),
    }
}
pub fn get_config(&self) -> ApiGatewayConfig {
(**self.config.load()).clone()
}
/// Returns an owned snapshot of the active configuration (same result as `get_config`).
pub fn get_cached_config(&self) -> ApiGatewayConfig {
    let snapshot = self.config.load_full();
    (*snapshot).clone()
}
pub fn get_cached_router(&self) -> Arc<Router> {
self.router_cache.load()
}
/// Builds the standalone router via `build_router` and stores it in the router cache.
///
/// # Errors
/// Propagates any error from `build_router`.
pub fn rebuild_and_cache_router(&self) -> Result<()> {
    let fresh_router = self.build_router()?;
    self.router_cache.store(fresh_router);
    Ok(())
}
/// Derives the gateway's authentication route policy from the registered operation specs.
///
/// The public set is seeded with the gateway's own endpoints that must work without
/// credentials: health checks and API docs. With the `embed_elements` feature this also
/// includes the docs UI assets, which the browser fetches without an `Authorization`
/// header. Each registered operation is then added to the authenticated and/or public
/// set according to its `authenticated` / `is_public` flags.
///
/// # Errors
/// Propagates any error returned by `auth::build_route_policy`.
fn build_route_policy_from_specs(&self) -> Result<auth::GatewayRoutePolicy> {
    let mut authenticated_routes = std::collections::HashSet::new();
    let mut public_routes = std::collections::HashSet::new();
    public_routes.insert((Method::GET, "/health".to_owned()));
    public_routes.insert((Method::GET, "/healthz".to_owned()));
    public_routes.insert((Method::GET, "/docs".to_owned()));
    public_routes.insert((Method::GET, "/openapi.json".to_owned()));
    // `add_openapi_routes` registers this route under the same feature. Without a public
    // entry, `/docs` itself is reachable but the page's static assets would be rejected
    // whenever auth is required by default.
    #[cfg(feature = "embed_elements")]
    {
        public_routes.insert((Method::GET, "/docs/assets/{*file}".to_owned()));
    }

    for entry in &self.openapi_registry.operation_specs {
        let spec = entry.value();
        let route_key = (spec.method.clone(), spec.path.clone());
        if spec.authenticated {
            authenticated_routes.insert(route_key.clone());
        }
        if spec.is_public {
            public_routes.insert(route_key);
        }
    }

    let config = self.get_cached_config();
    let requirements_count = authenticated_routes.len();
    let public_routes_count = public_routes.len();
    let route_policy = auth::build_route_policy(&config, authenticated_routes, public_routes)?;
    tracing::info!(
        auth_disabled = config.auth_disabled,
        require_auth_by_default = config.require_auth_by_default,
        requirements_count = requirements_count,
        public_routes_count = public_routes_count,
        "Route policy built from operation specs"
    );
    Ok(route_policy)
}
/// Normalizes a configured `prefix_path` into either `""` or `/seg[/seg...]`.
///
/// Steps: trim surrounding whitespace, collapse runs of `/` into one, drop trailing
/// slashes, and ensure a single leading slash. A prefix made only of slashes or
/// whitespace normalizes to the empty string (meaning "no prefix").
///
/// # Errors
/// Fails when the normalized prefix contains bytes outside `[a-zA-Z0-9/_\-.]`, or when
/// any path segment is exactly `.` or `..`.
fn normalize_prefix_path(raw: &str) -> Result<String> {
    let trimmed = raw.trim();

    // Collapse consecutive '/' characters into a single one.
    let mut collapsed = String::with_capacity(trimmed.len());
    for ch in trimmed.chars() {
        if ch == '/' && collapsed.ends_with('/') {
            continue;
        }
        collapsed.push(ch);
    }

    let result = match collapsed.trim_end_matches('/') {
        "" => String::new(),
        rooted if rooted.starts_with('/') => rooted.to_owned(),
        relative => format!("/{relative}"),
    };

    let has_invalid_byte = result.bytes().any(|b| {
        !(b.is_ascii_alphanumeric() || matches!(b, b'/' | b'_' | b'-' | b'.'))
    });
    if has_invalid_byte {
        anyhow::bail!(
            "prefix_path contains invalid characters (must match [a-zA-Z0-9/_\\-.]): {raw:?}"
        );
    }

    let has_dot_segment = result
        .split('/')
        .any(|segment| matches!(segment, "." | ".."));
    if has_dot_segment {
        anyhow::bail!("prefix_path must not contain '.' or '..' segments: {raw:?}");
    }

    Ok(result)
}
/// Wraps `router` in the gateway's full middleware stack and returns the result.
///
/// Layer order matters. With axum's `Router::layer`, each `.layer(...)` call wraps all
/// layers added before it, so the LAST layer added below is the FIRST to see a request.
/// Request path, outermost to innermost:
/// `SetRequestIdLayer` -> `PropagateRequestIdLayer` -> `TraceLayer` ->
/// `push_req_id_to_extensions` -> HTTP metrics -> `CatchPanicLayer` ->
/// `TimeoutLayer` (30 s, responds 504) -> `DefaultBodyLimit` -> `RequestBodyLimitLayer`
/// -> CORS (only when `cors_enabled`) -> MIME validation -> rate limiting ->
/// error mapping -> authentication (or the auth-disabled default `SecurityContext`) ->
/// license validation -> `propagate_matched_path` (a `route_layer`, so it only runs
/// for requests that matched a route) -> handler.
///
/// `authn_client` must be `Some` unless `auth_disabled` is set in the config.
///
/// # Errors
/// Returns an error if the route policy or the rate-limiter map cannot be built, if the
/// default `SecurityContext` cannot be built in auth-disabled mode, or if auth is
/// enabled but `authn_client` is `None`.
pub(crate) fn apply_middleware_stack(
&self,
mut router: Router,
authn_client: Option<Arc<dyn AuthNResolverClient>>,
) -> Result<Router> {
// Built from the registered operation specs; consumed by the authn layer below.
let route_policy = self.build_route_policy_from_specs()?;
// Innermost layer. `route_layer` only wraps matched routes; presumably this exposes the
// matched route pattern to the outer metrics middleware (confirm in http_metrics).
router = router.route_layer(from_fn(middleware::http_metrics::propagate_matched_path));
// Snapshot of config and all registered specs; the per-route maps below are built once here.
let config = self.get_cached_config();
let specs: Vec<_> = self
.openapi_registry
.operation_specs
.iter()
.map(|e| e.value().clone())
.collect();
let license_map = middleware::license_validation::LicenseRequirementMap::from_specs(&specs);
// License checks sit inside (after) the auth layer added next, so they run once the
// request has been authenticated.
router = router.layer(from_fn(
move |req: axum::extract::Request, next: axum::middleware::Next| {
let map = license_map.clone();
middleware::license_validation::license_validation_middleware(map, req, next)
},
));
// Exactly one auth mode is installed: a fixed default SecurityContext when auth is
// disabled, otherwise the AuthN middleware. Auth enabled without a client is a hard error.
if config.auth_disabled {
let default_security_context = SecurityContext::builder()
.subject_id(DEFAULT_SUBJECT_ID)
.subject_tenant_id(DEFAULT_TENANT_ID)
.build()?;
tracing::warn!(
"API Gateway auth is DISABLED: all requests will run with default tenant SecurityContext. \
This mode bypasses authentication and is intended ONLY for single-user on-premises deployments without an IdP. \
Permission checks and secure ORM still apply. DO NOT use this mode in multi-tenant or production environments."
);
// Every request gets a clone of the same pre-built context in its extensions.
router = router.layer(from_fn(
move |mut req: axum::extract::Request, next: axum::middleware::Next| {
let sec_context = default_security_context.clone();
async move {
req.extensions_mut().insert(sec_context);
next.run(req).await
}
},
));
} else if let Some(client) = authn_client {
let auth_state = auth::AuthState {
authn_client: client,
route_policy,
};
router = router.layer(from_fn_with_state(auth_state, auth::authn_middleware));
} else {
return Err(anyhow::anyhow!(
"auth is enabled but no AuthN Resolver client is available; \
ensure `authn_resolver` module is loaded or set `auth_disabled: true`"
));
}
// Wraps the auth and license layers above, so their responses pass through it.
router = router.layer(from_fn(modkit::api::error_layer::error_mapping_middleware));
let rate_map = middleware::rate_limit::RateLimiterMap::from_specs(&specs, &config)?;
router = router.layer(from_fn(
move |req: axum::extract::Request, next: axum::middleware::Next| {
let map = rate_map.clone();
middleware::rate_limit::rate_limit_middleware(map, req, next)
},
));
let mime_map = middleware::mime_validation::build_mime_validation_map(&specs);
router = router.layer(from_fn(
move |req: axum::extract::Request, next: axum::middleware::Next| {
let map = mime_map.clone();
middleware::mime_validation::mime_validation_middleware(map, req, next)
},
));
// CORS sits outside validation, rate limiting and auth, so it sees requests before them.
if config.cors_enabled {
router = router.layer(crate::cors::build_cors_layer(&config));
}
// Both body-limit mechanisms use the same configured byte limit.
router = router.layer(RequestBodyLimitLayer::new(config.defaults.body_limit_bytes));
router = router.layer(DefaultBodyLimit::max(config.defaults.body_limit_bytes));
// Hard-coded 30 s request timeout, answered with 504 Gateway Timeout.
router = router.layer(TimeoutLayer::with_status_code(
axum::http::StatusCode::GATEWAY_TIMEOUT,
Duration::from_secs(30),
));
router = router.layer(CatchPanicLayer::new());
// Metrics wrap the panic/timeout layers, so those outcomes are also observed.
let http_metrics = Arc::new(middleware::http_metrics::HttpMetrics::new(
Self::MODULE_NAME,
&config.metrics.prefix,
));
router = router.layer(from_fn_with_state(
http_metrics,
middleware::http_metrics::http_metrics_middleware,
));
router = router.layer(from_fn(middleware::request_id::push_req_id_to_extensions));
// Per-request tracing span. It reads the request-id header set by the outer
// SetRequestIdLayer and links to an upstream trace via `otel::set_parent_from_headers`.
router = router.layer({
use modkit_http::otel;
use tower_http::trace::TraceLayer;
use tracing::field::Empty;
TraceLayer::new_for_http()
.make_span_with(move |req: &axum::http::Request<axum::body::Body>| {
let hdr = middleware::request_id::header();
let rid = req
.headers()
.get(&hdr)
.and_then(|v| v.to_str().ok())
.unwrap_or("n/a");
let span = tracing::info_span!(
"http_request",
method = %req.method(),
uri = %req.uri().path(),
version = ?req.version(),
module = "api_gateway",
endpoint = %req.uri().path(),
request_id = %rid,
status = Empty,
latency_ms = Empty,
"http.method" = %req.method(),
"http.target" = %req.uri().path(),
"http.scheme" = req.uri().scheme_str().unwrap_or("http"),
"http.host" = req.headers().get("host")
.and_then(|h| h.to_str().ok())
.unwrap_or("unknown"),
"user_agent.original" = req.headers().get("user-agent")
.and_then(|h| h.to_str().ok())
.unwrap_or("unknown"),
trace_id = Empty,
parent.trace_id = Empty
);
otel::set_parent_from_headers(&span, req.headers());
span
})
// Fill the span's empty `status` and `latency_ms` (milliseconds) fields once
// the response is produced.
.on_response(
|res: &axum::http::Response<axum::body::Body>,
latency: std::time::Duration,
span: &tracing::Span| {
let ms = latency.as_millis();
span.record("status", res.status().as_u16());
span.record("latency_ms", ms);
},
)
});
// Outermost two layers: SetRequestIdLayer (added last, runs first) assigns a request id
// when the request lacks one; PropagateRequestIdLayer copies it onto the response.
let x_request_id = crate::middleware::request_id::header();
router = router.layer(PropagateRequestIdLayer::new(x_request_id.clone()));
router = router.layer(SetRequestIdLayer::new(
x_request_id,
crate::middleware::request_id::MakeReqId,
));
Ok(router)
}
/// Builds the standalone/fallback router and stores it in the router cache.
///
/// The router contains `/health` and `/healthz`, wrapped in the full middleware stack
/// and nested under the configured `prefix_path` (with root-level health endpoints
/// kept when a prefix is set).
///
/// This always builds a fresh router. The cache is seeded with an empty placeholder
/// `Router` at construction, and a reference-count check on a freshly loaded handle
/// cannot tell that placeholder apart from a real router. The previous
/// "return the cached router" shortcut could therefore serve an empty router with no
/// routes and no middleware, and it also made `rebuild_and_cache_router` a no-op.
///
/// # Errors
/// Propagates errors from `apply_middleware_stack` (e.g. auth enabled without an
/// AuthN client) and from `normalize_prefix_path` (invalid `prefix_path`).
pub fn build_router(&self) -> Result<Router> {
    tracing::debug!("Building new router (standalone/fallback mode)");
    let base = Router::new()
        .route("/health", get(web::health_check))
        .route("/healthz", get(|| async { "ok" }));

    // The mutex guard is a temporary of this statement and is released immediately.
    let authn_client = self.authn_client.lock().clone();
    let layered = self.apply_middleware_stack(base, authn_client)?;

    let config = self.get_cached_config();
    let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
    let router = Self::apply_prefix_nesting(layered, &prefix);

    self.router_cache.store(router.clone());
    Ok(router)
}
/// Generates the OpenAPI document from all registered operations and schemas.
///
/// The document's `info` (title, version, description) comes from the `openapi`
/// section of the active configuration.
///
/// # Errors
/// Propagates any error from the underlying registry's `build_openapi`.
pub fn build_openapi(&self) -> Result<utoipa::openapi::OpenApi> {
    let openapi_cfg = self.get_cached_config().openapi;
    let info = modkit::api::OpenApiInfo {
        title: openapi_cfg.title,
        version: openapi_cfg.version,
        description: openapi_cfg.description,
    };
    self.openapi_registry.build_openapi(&info)
}
/// Parses a literal `ip:port` bind address. Host names are not resolved.
///
/// # Errors
/// Returns an error naming the offending address when it is not a valid `SocketAddr`.
fn parse_bind_address(bind_addr: &str) -> anyhow::Result<SocketAddr> {
    match bind_addr.parse::<SocketAddr>() {
        Ok(addr) => Ok(addr),
        Err(err) => Err(anyhow::anyhow!("Invalid bind address '{bind_addr}': {err}")),
    }
}
/// Returns the router finalized during the REST phase, or builds a fallback one.
///
/// The REST-phase router is moved out of `final_router`, so a second call falls back
/// to `build_router`.
fn get_or_build_router(self: &Arc<Self>) -> anyhow::Result<Router> {
    // Take in a separate statement so the mutex guard is dropped before any fallback build.
    let from_rest_phase = self.final_router.lock().take();
    match from_rest_phase {
        Some(router) => {
            tracing::debug!("Using router from REST phase");
            Ok(router)
        }
        None => {
            tracing::debug!("No router from REST phase, building default router");
            self.build_router()
        }
    }
}
/// Lifecycle entry point: binds the HTTP listener and serves until `cancel` fires.
///
/// `ready` is signalled only after the listener is bound, so readiness implies the port
/// is accepting connections. On cancellation the server shuts down gracefully.
///
/// # Errors
/// Fails on an invalid `bind_addr`, a router build failure, a bind failure, or a
/// server I/O error.
pub(crate) async fn serve(
    self: Arc<Self>,
    cancel: CancellationToken,
    ready: ReadySignal,
) -> anyhow::Result<()> {
    let bind_addr = {
        let cfg = self.get_cached_config();
        Self::parse_bind_address(&cfg.bind_addr)?
    };
    let app = self.get_or_build_router()?;

    let listener = tokio::net::TcpListener::bind(bind_addr).await?;
    tracing::info!("HTTP server bound on {}", bind_addr);
    ready.notify();

    let graceful = async move {
        cancel.cancelled().await;
        tracing::info!("HTTP server shutting down gracefully (cancellation)");
    };

    axum::serve(listener, app)
        .with_graceful_shutdown(graceful)
        .await
        .map_err(|err| anyhow::anyhow!(err))
}
/// Reserves `spec.handler_id`; returns `true` (and logs an error) if it was already taken.
///
/// The id is inserted into `registered_handlers` in either case.
fn check_duplicate_handler(&self, spec: &modkit::api::OperationSpec) -> bool {
    let previous = self
        .registered_handlers
        .insert(spec.handler_id.clone(), ());
    let is_duplicate = previous.is_some();
    if is_duplicate {
        tracing::error!(
            handler_id = %spec.handler_id,
            method = %spec.method.as_str(),
            path = %spec.path,
            "Duplicate handler_id detected; ignoring subsequent registration"
        );
    }
    is_duplicate
}
/// Reserves `(spec.method, spec.path)`; returns `true` (and logs an error) if it was
/// already registered.
fn check_duplicate_route(&self, spec: &modkit::api::OperationSpec) -> bool {
    let already_present = self
        .registered_routes
        .insert((spec.method.clone(), spec.path.clone()), ())
        .is_some();
    if already_present {
        tracing::error!(
            method = %spec.method.as_str(),
            path = %spec.path,
            "Duplicate (method, path) detected; ignoring subsequent registration"
        );
    }
    already_present
}
/// Emits a debug event describing a newly registered operation and the running total.
fn log_operation_registration(&self, spec: &modkit::api::OperationSpec) {
    let summary = spec.summary.as_deref().unwrap_or("No summary");
    let total_operations = self.openapi_registry.operation_specs.len();
    tracing::debug!(
        handler_id = %spec.handler_id,
        method = %spec.method.as_str(),
        path = %spec.path,
        summary = %summary,
        total_operations = total_operations,
        "Registered API operation"
    );
}
/// Adds the documentation endpoints to `router`: `/openapi.json`, `/docs`, and (with
/// the `embed_elements` feature) `/docs/assets/{*file}`.
///
/// The OpenAPI document is built and serialized exactly once, here. It does not change
/// after finalization, so every `/openapi.json` request serves the same pre-serialized
/// bytes (a cheap reference-counted clone) instead of re-running pretty-printing
/// serialization per request. A serialization failure is reported at startup rather
/// than as repeated 500 responses.
///
/// # Errors
/// Fails if the OpenAPI document cannot be built or serialized, or if `prefix_path`
/// is invalid.
fn add_openapi_routes(&self, mut router: axum::Router) -> anyhow::Result<axum::Router> {
    use axum::body::Bytes;
    use axum::http::header;

    let op_count = self.openapi_registry.operation_specs.len();
    tracing::info!(
        "rest_finalize: emitting OpenAPI with {} operations",
        op_count
    );

    let openapi_doc = self.build_openapi()?;
    let openapi_json = Bytes::from(
        serde_json::to_string_pretty(&openapi_doc)
            .map_err(|e| anyhow::anyhow!("Failed to serialize OpenAPI doc: {e}"))?,
    );

    let config = self.get_cached_config();
    let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
    let html_doc = web::serve_docs(&prefix);

    router = router
        .route(
            "/openapi.json",
            get(move || async move {
                // The header parts overwrite the body's default content type.
                (
                    [
                        (header::CONTENT_TYPE, "application/json"),
                        (header::CACHE_CONTROL, "no-store"),
                    ],
                    openapi_json,
                )
            }),
        )
        .route("/docs", get(move || async move { html_doc }));

    #[cfg(feature = "embed_elements")]
    {
        router = router.route(
            "/docs/assets/{*file}",
            get(crate::assets::serve_elements_asset),
        );
    }

    Ok(router)
}
}
#[async_trait]
impl modkit::Module for ApiGateway {
    /// Loads the module configuration and, when auth is enabled, resolves the AuthN client.
    ///
    /// The loaded configuration replaces the active one before anything else happens.
    ///
    /// # Errors
    /// Fails if the configuration cannot be loaded, or if auth is enabled and no
    /// `AuthNResolverClient` is registered in the client hub.
    async fn init(&self, ctx: &modkit::context::ModuleCtx) -> anyhow::Result<()> {
        let cfg = ctx.config::<crate::config::ApiGatewayConfig>()?;
        let auth_disabled = cfg.auth_disabled;
        self.config.store(Arc::new(cfg));
        debug!(
            "Effective api_gateway configuration:\n{:#?}",
            self.config.load()
        );

        if !auth_disabled {
            let resolved = ctx.client_hub().get::<dyn AuthNResolverClient>()?;
            *self.authn_client.lock() = Some(resolved);
            tracing::info!("AuthN Resolver client resolved from ClientHub");
        } else {
            tracing::info!(
                tenant_id = %DEFAULT_TENANT_ID,
                "Auth-disabled mode enabled with default tenant"
            );
        }
        Ok(())
    }
}
impl modkit::contracts::ApiGatewayCapability for ApiGateway {
    /// Seeds the REST host router with the `/health` and `/healthz` endpoints.
    fn rest_prepare(
        &self,
        _ctx: &modkit::context::ModuleCtx,
        router: axum::Router,
    ) -> anyhow::Result<axum::Router> {
        let prepared = router
            .route("/health", get(web::health_check))
            .route("/healthz", get(|| async { "ok" }));
        tracing::debug!("REST host prepared base router with health check endpoints");
        Ok(prepared)
    }

    /// Completes the REST host router once all modules have registered their routes.
    ///
    /// Steps, in order: optionally add the docs endpoints (`enable_docs`), wrap everything
    /// in the middleware stack, and nest under `prefix_path`. A copy is kept in
    /// `final_router` so `serve` can pick it up.
    fn rest_finalize(
        &self,
        _ctx: &modkit::context::ModuleCtx,
        router: axum::Router,
    ) -> anyhow::Result<axum::Router> {
        let config = self.get_cached_config();
        let router = if config.enable_docs {
            self.add_openapi_routes(router)?
        } else {
            router
        };

        tracing::debug!("Applying middleware stack to finalized router");
        let authn_client = self.authn_client.lock().clone();
        let layered = self.apply_middleware_stack(router, authn_client)?;

        let prefix = Self::normalize_prefix_path(&config.prefix_path)?;
        let finalized = Self::apply_prefix_nesting(layered, &prefix);

        *self.final_router.lock() = Some(finalized.clone());
        tracing::info!("REST host finalized router with OpenAPI endpoints and auth middleware");
        Ok(finalized)
    }

    /// Exposes the gateway itself as the OpenAPI registry used by other modules.
    fn as_registry(&self) -> &dyn modkit::contracts::OpenApiRegistry {
        self
    }
}
impl modkit::contracts::RestApiCapability for ApiGateway {
    /// The gateway registers no REST routes of its own in this phase: its endpoints are
    /// added in `rest_prepare` / `rest_finalize`. The router is returned unchanged.
    fn register_rest(
        &self,
        _ctx: &modkit::context::ModuleCtx,
        router: axum::Router,
        _openapi: &dyn modkit::contracts::OpenApiRegistry,
    ) -> anyhow::Result<axum::Router> {
        let passthrough = router;
        Ok(passthrough)
    }
}
impl OpenApiRegistry for ApiGateway {
    /// Registers an operation spec, rejecting duplicates.
    ///
    /// A spec is rejected (logged at error level and not forwarded to the underlying
    /// registry) when its `handler_id` or its `(method, path)` pair is already registered.
    /// A rejected spec must not keep any reservation. If the handler check passes but the
    /// route check fails, the freshly reserved `handler_id` is released again. Otherwise a
    /// later, valid registration using that `handler_id` would be wrongly rejected.
    fn register_operation(&self, spec: &modkit::api::OperationSpec) {
        if self.check_duplicate_handler(spec) {
            return;
        }
        if self.check_duplicate_route(spec) {
            // `check_duplicate_handler` returned `false`, so it just inserted this id.
            self.registered_handlers.remove(&spec.handler_id);
            return;
        }
        self.openapi_registry.register_operation(spec);
        self.log_operation_registration(spec);
    }

    /// Delegates schema registration to the underlying registry and returns its result.
    fn ensure_schema_raw(
        &self,
        root_name: &str,
        schemas: Vec<(
            String,
            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
        )>,
    ) -> String {
        self.openapi_registry.ensure_schema_raw(root_name, schemas)
    }

    /// Enables downcasting the registry trait object back to `ApiGateway`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // Configured OpenAPI metadata must end up in the generated document's `info` object.
    #[test]
    fn test_openapi_generation() {
        let mut config = ApiGatewayConfig::default();
        config.openapi.title = "Test API".to_owned();
        config.openapi.version = "1.0.0".to_owned();
        config.openapi.description = Some("Test Description".to_owned());

        let doc = ApiGateway::new(config).build_openapi().unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        for key in ["openapi", "info", "paths"] {
            assert!(json.get(key).is_some(), "missing top-level key `{key}`");
        }

        let info = json.get("info").unwrap();
        let expected = [
            ("title", "Test API"),
            ("version", "1.0.0"),
            ("description", "Test Description"),
        ];
        for (field, value) in expected {
            assert_eq!(info.get(field).unwrap(), value);
        }
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod normalize_prefix_path_tests {
    // Covers normalization (trimming, slash collapsing, leading/trailing slash handling)
    // and validation (allowed character set, `.` / `..` segment rejection).
    use super::*;

    #[test]
    fn empty_string_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("").unwrap(), "");
    }

    #[test]
    fn sole_slash_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("/").unwrap(), "");
    }

    #[test]
    fn multiple_slashes_return_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path("///").unwrap(), "");
    }

    #[test]
    fn whitespace_only_returns_empty() {
        assert_eq!(ApiGateway::normalize_prefix_path(" ").unwrap(), "");
    }

    #[test]
    fn simple_prefix_preserved() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf").unwrap(), "/cf");
    }

    #[test]
    fn trailing_slash_stripped() {
        assert_eq!(ApiGateway::normalize_prefix_path("/cf/").unwrap(), "/cf");
    }

    #[test]
    fn leading_slash_prepended_when_missing() {
        assert_eq!(ApiGateway::normalize_prefix_path("cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_leading_slashes_collapsed() {
        assert_eq!(ApiGateway::normalize_prefix_path("//cf").unwrap(), "/cf");
    }

    #[test]
    fn consecutive_slashes_mid_path_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api//v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn many_consecutive_slashes_collapsed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("///api///v1///").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn surrounding_whitespace_trimmed() {
        assert_eq!(ApiGateway::normalize_prefix_path(" /cf ").unwrap(), "/cf");
    }

    #[test]
    fn nested_path_preserved() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1").unwrap(),
            "/api/v1"
        );
    }

    #[test]
    fn dot_in_path_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/api/v1.0").unwrap(),
            "/api/v1.0"
        );
    }

    #[test]
    fn underscore_and_hyphen_allowed() {
        assert_eq!(
            ApiGateway::normalize_prefix_path("/my_api-v2").unwrap(),
            "/my_api-v2"
        );
    }

    #[test]
    fn double_dot_inside_segment_allowed() {
        // Only whole `.` / `..` segments are rejected, not dots within a segment.
        assert_eq!(
            ApiGateway::normalize_prefix_path("/v1..2").unwrap(),
            "/v1..2"
        );
    }

    #[test]
    fn rejects_parent_dir_segment() {
        assert!(ApiGateway::normalize_prefix_path("/api/../v1").is_err());
    }

    #[test]
    fn rejects_bare_parent_dir() {
        assert!(ApiGateway::normalize_prefix_path("..").is_err());
    }

    #[test]
    fn rejects_current_dir_segment() {
        assert!(ApiGateway::normalize_prefix_path("/./api").is_err());
    }

    #[test]
    fn rejects_trailing_current_dir_segment() {
        assert!(ApiGateway::normalize_prefix_path("/api/v1/./").is_err());
    }

    #[test]
    fn rejects_html_injection() {
        let result = ApiGateway::normalize_prefix_path(r#""><script>alert(1)</script>"#);
        assert!(result.is_err());
    }

    #[test]
    fn rejects_spaces_in_path() {
        let result = ApiGateway::normalize_prefix_path("/my path");
        assert!(result.is_err());
    }

    #[test]
    fn rejects_query_string_chars() {
        let result = ApiGateway::normalize_prefix_path("/api?foo=bar");
        assert!(result.is_err());
    }
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod problem_openapi_tests {
// Verifies that a `problem_response` registration produces a concrete `Problem` component
// schema and an `application/problem+json` response that references it.
use super::*;
use axum::Json;
use modkit::api::{Missing, OperationBuilder};
use serde_json::Value;
// Minimal handler; only the OpenAPI side of the registration is under test.
async fn dummy_handler() -> Json<Value> {
Json(serde_json::json!({"ok": true}))
}
#[tokio::test]
async fn openapi_includes_problem_schema_and_response() {
let api = ApiGateway::default();
let router = axum::Router::new();
// Builder call order is kept as-is: the builder is type-parameterized, so the call
// sequence is part of what compiles.
let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/problem-demo")
.public()
.summary("Problem demo")
.problem_response(&api, http::StatusCode::BAD_REQUEST, "Bad Request") .handler(dummy_handler)
.register(router, &api);
let doc = api.build_openapi().expect("openapi");
let v = serde_json::to_value(&doc).expect("json");
// The component must be an actual schema object, not a `$ref` pointing at itself.
let problem = v
.pointer("/components/schemas/Problem")
.expect("Problem schema missing");
assert!(
problem.get("$ref").is_none(),
"Problem must be a real object, not a self-ref"
);
// JSON Pointer escapes '/' inside a key as `~1` (RFC 6901).
let path_obj = v
.pointer("/paths/~1tests~1v1~1problem-demo/get/responses/400")
.expect("400 response missing");
// Check the content map first so a failure prints the content types actually present.
let content_obj = path_obj.get("content").expect("content object missing");
assert!(
content_obj.get("application/problem+json").is_some(),
"application/problem+json content missing. Available content: {}",
serde_json::to_string_pretty(content_obj).unwrap()
);
let content = path_obj
.pointer("/content/application~1problem+json")
.expect("application/problem+json content missing");
// The response schema must reference the shared component rather than inline it.
let schema_ref = content
.pointer("/schema/$ref")
.and_then(|r| r.as_str())
.unwrap_or("");
assert_eq!(schema_ref, "#/components/schemas/Problem");
}
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod sse_openapi_tests {
    // OpenAPI generation tests for SSE responses, Axum-to-OpenAPI path-parameter
    // conversion, and multipart file-upload request bodies.
    use super::*;
    use axum::Json;
    use modkit::api::{Missing, OperationBuilder};
    use serde_json::Value;

    // Event payload used as the SSE schema in these tests.
    #[derive(Clone)]
    #[modkit_macros::api_dto(request, response)]
    struct UserEvent {
        id: u32,
        message: String,
    }

    // Returns an SSE response from a fresh broadcaster (the argument `4` is presumably the
    // channel capacity); only the OpenAPI metadata matters here.
    async fn sse_handler() -> axum::response::sse::Sse<
        impl futures_core::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>,
    > {
        let b = modkit::SseBroadcaster::<UserEvent>::new(4);
        b.sse_response()
    }

    #[tokio::test]
    async fn openapi_has_sse_content() {
        let api = ApiGateway::default();
        let router = axum::Router::new();
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/sse")
            .summary("Demo SSE")
            .handler(sse_handler)
            .public()
            .sse_json::<UserEvent>(&api, "SSE of UserEvent")
            .register(router, &api);
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
        // `~1` is the JSON Pointer escape for '/' inside a key (RFC 6901).
        let refp = v
            .pointer("/paths/~1tests~1v1~1demo~1sse/get/responses/200/content/text~1event-stream/schema/$ref")
            .and_then(|x| x.as_str())
            .unwrap_or_default();
        assert_eq!(refp, "#/components/schemas/UserEvent");
    }

    #[tokio::test]
    async fn openapi_sse_additional_response() {
        async fn mixed_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }
        let api = ApiGateway::default();
        let router = axum::Router::new();
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/demo/mixed")
            .summary("Mixed responses")
            .public()
            .handler(mixed_handler)
            .json_response(http::StatusCode::OK, "Success response")
            .sse_json::<UserEvent>(&api, "Additional SSE stream")
            .register(router, &api);
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let responses = v
            .pointer("/paths/~1tests~1v1~1demo~1mixed/get/responses")
            .expect("responses");
        assert!(responses.get("200").is_some());
        let response_content = responses.get("200").and_then(|r| r.get("content"));
        assert!(response_content.is_some());
        let schema = v
            .pointer("/components/schemas/UserEvent")
            .expect("UserEvent missing");
        assert!(schema.get("$ref").is_none());
    }

    #[tokio::test]
    async fn test_axum_to_openapi_path_conversion() {
        async fn user_handler() -> Json<Value> {
            Json(serde_json::json!({"user_id": "123"}))
        }
        let api = ApiGateway::default();
        let router = axum::Router::new();
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/users/{id}")
            .summary("Get user by ID")
            .public()
            .path_param("id", "User ID")
            .handler(user_handler)
            .json_response(http::StatusCode::OK, "User details")
            .register(router, &api);
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        assert_eq!(ops[0].path, "/tests/v1/users/{id}");
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/users/{id}").is_some(),
            "OpenAPI should use {{id}} placeholder"
        );
    }

    #[tokio::test]
    async fn test_multiple_path_params_conversion() {
        async fn item_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }
        let api = ApiGateway::default();
        let router = axum::Router::new();
        let _router = OperationBuilder::<Missing, Missing, ()>::get(
            "/tests/v1/projects/{project_id}/items/{item_id}",
        )
        .summary("Get project item")
        .public()
        .path_param("project_id", "Project ID")
        .path_param("item_id", "Item ID")
        .handler(item_handler)
        .json_response(http::StatusCode::OK, "Item details")
        .register(router, &api);
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        // Guard the index below with a readable failure instead of an index panic.
        assert_eq!(ops.len(), 1);
        assert_eq!(
            ops[0].path,
            "/tests/v1/projects/{project_id}/items/{item_id}"
        );
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths
                .get("/tests/v1/projects/{project_id}/items/{item_id}")
                .is_some()
        );
    }

    #[tokio::test]
    async fn test_wildcard_path_conversion() {
        async fn static_handler() -> Json<Value> {
            Json(serde_json::json!({"ok": true}))
        }
        let api = ApiGateway::default();
        let router = axum::Router::new();
        let _router = OperationBuilder::<Missing, Missing, ()>::get("/tests/v1/static/{*path}")
            .summary("Serve static files")
            .public()
            .handler(static_handler)
            .json_response(http::StatusCode::OK, "File content")
            .register(router, &api);
        let ops: Vec<_> = api
            .openapi_registry
            .operation_specs
            .iter()
            .map(|e| e.value().clone())
            .collect();
        assert_eq!(ops.len(), 1);
        // The registry keeps the Axum-style path; only the OpenAPI output is converted.
        assert_eq!(ops[0].path, "/tests/v1/static/{*path}");
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        assert!(
            paths.get("/tests/v1/static/{path}").is_some(),
            "Wildcard {{*path}} should be converted to {{path}} in OpenAPI"
        );
        // Must check the full registered path. A prefix-less path would never be present,
        // so asserting its absence would pass even without any conversion.
        assert!(
            paths.get("/tests/v1/static/{*path}").is_none(),
            "OpenAPI should not have Axum-style {{*path}}"
        );
    }

    #[tokio::test]
    async fn test_multipart_file_upload_openapi() {
        async fn upload_handler() -> Json<Value> {
            Json(serde_json::json!({"uploaded": true}))
        }
        let api = ApiGateway::default();
        let router = axum::Router::new();
        let _router = OperationBuilder::<Missing, Missing, ()>::post("/tests/v1/files/upload")
            .operation_id("upload_file")
            .public()
            .summary("Upload a file")
            .multipart_file_request("file", Some("File to upload"))
            .handler(upload_handler)
            .json_response(http::StatusCode::OK, "Upload successful")
            .register(router, &api);
        let doc = api.build_openapi().expect("openapi");
        let v = serde_json::to_value(&doc).expect("json");
        let paths = v.get("paths").expect("paths");
        let upload_path = paths
            .get("/tests/v1/files/upload")
            .expect("/tests/v1/files/upload path");
        let post_op = upload_path.get("post").expect("POST operation");
        let request_body = post_op.get("requestBody").expect("requestBody");
        let content = request_body.get("content").expect("content");
        let multipart = content
            .get("multipart/form-data")
            .expect("multipart/form-data content type");
        let schema = multipart.get("schema").expect("schema");
        assert_eq!(
            schema.get("type").and_then(|v| v.as_str()),
            Some("object"),
            "Schema should be of type object"
        );
        let properties = schema.get("properties").expect("properties");
        let file_prop = properties.get("file").expect("file property");
        assert_eq!(
            file_prop.get("type").and_then(|v| v.as_str()),
            Some("string"),
            "File field should be of type string"
        );
        assert_eq!(
            file_prop.get("format").and_then(|v| v.as_str()),
            Some("binary"),
            "File field should have format binary"
        );
        let required = schema.get("required").expect("required");
        let required_arr = required.as_array().expect("required should be array");
        assert_eq!(required_arr.len(), 1);
        assert_eq!(required_arr[0].as_str(), Some("file"));
    }
}