use std::sync::Arc;
use std::time::Duration;
use axum::{
Extension, Json, Router,
error_handling::HandleErrorLayer,
extract::DefaultBodyLimit,
middleware,
response::IntoResponse,
routing::{get, post},
};
use serde::Serialize;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::limit::ConcurrencyLimitLayer;
use tower::timeout::TimeoutLayer;
use tower_http::cors::{Any, CorsLayer};
use forge_core::cluster::NodeId;
use forge_core::config::McpConfig;
use forge_core::function::{JobDispatch, KvHandle, WorkflowDispatch};
#[cfg(feature = "otel")]
use opentelemetry::global;
#[cfg(feature = "otel")]
use opentelemetry::propagation::Extractor;
use tracing::Instrument;
#[cfg(feature = "otel")]
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::admin::{AdminState, admin_router};
use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
use super::multipart::{MultipartConfig, rpc_multipart_handler};
use super::response::{RpcError, RpcResponse};
use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
use super::sse::{
SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
sse_unsubscribe_handler, sse_workflow_subscribe_handler,
};
use super::tls::{TlsListenConfig, bind_listener};
use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
use crate::function::FunctionRegistry;
use crate::mcp::McpToolRegistry;
use crate::pg::{Database, PgNotifyBus};
use crate::realtime::{Reactor, ReactorConfig};
/// Default `GatewayConfig::max_json_body_bytes` (1 MiB): body limit for the JSON RPC routes.
const DEFAULT_MAX_JSON_BODY_SIZE: usize = 1024 * 1024;
/// Default `GatewayConfig::max_body_size_bytes` (20 MiB): body limit for multipart uploads.
const DEFAULT_MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024;
/// Default `GatewayConfig::max_file_size_bytes` (10 MiB): per-file limit inside a multipart upload.
const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024;
/// Concurrency limit installed on the `/rpc/{function}/upload` route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
/// Fallback visitor-ID hashing secret used when no `jwt_secret` is configured.
/// It is compiled into the binary, so visitor IDs derived from it are predictable.
const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret";
/// Pick the secret used to hash visitor IDs for the signals subsystem.
///
/// Returns a copy of the configured JWT secret when one is present. Otherwise
/// it logs a warning and falls back to [`DEFAULT_SIGNAL_SECRET`]. That value is
/// public, so visitor IDs hashed with it are predictable.
fn signal_visitor_secret(jwt_secret: &Option<String>) -> String {
    match jwt_secret {
        Some(configured) => configured.clone(),
        None => {
            tracing::warn!(
                "No jwt_secret configured; using default signal secret for visitor ID hashing. \
                 Visitor IDs will be predictable. Set [auth] jwt_secret in forge.toml."
            );
            DEFAULT_SIGNAL_SECRET.to_owned()
        }
    }
}
/// Runtime configuration consumed by [`GatewayServer`] when building and serving the router.
///
/// See `impl Default for GatewayConfig` for the default values.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
/// Port the gateway binds on `0.0.0.0` (see `GatewayServer::addr`).
pub port: u16,
/// Concurrency limit applied to the non-SSE routes built by `GatewayServer::router`.
pub max_connections: usize,
/// Maximum number of concurrent SSE sessions (`SseConfig::max_sessions`).
pub sse_max_sessions: usize,
/// Timeout, in seconds, applied to the non-SSE routes.
pub request_timeout_secs: u64,
/// When `false`, a default (non-configured) `CorsLayer` is installed.
pub cors_enabled: bool,
/// Allowed CORS origins. An entry equal to `"*"` switches to permissive wildcard CORS.
pub cors_origins: Vec<String>,
/// Authentication settings. Also the source of the HMAC token issuer and the
/// signals visitor-ID secret (`jwt_secret`).
pub auth: AuthConfig,
/// MCP endpoint settings (`enabled`, `path`, OAuth options).
pub mcp: McpConfig,
/// Request paths that skip span creation and HTTP metrics. Entries starting
/// with `/_api` also match with that prefix stripped.
pub quiet_paths: Vec<String>,
/// Access/refresh token lifetimes handed to the RPC handler and OAuth state.
pub token_ttl: forge_core::AuthTokenTtl,
/// Project name passed to the OAuth state.
pub project_name: String,
/// Body limit for multipart uploads. The route limit is raised to the largest
/// per-function `max_upload_size_bytes` if that is bigger.
pub max_body_size_bytes: usize,
/// Body limit for the JSON RPC routes; also the read limit of the JSON depth check.
pub max_json_body_bytes: usize,
/// Per-file size limit for multipart uploads.
pub max_file_size_bytes: usize,
/// Optional TLS settings passed to `bind_listener`.
pub tls: Option<TlsListenConfig>,
/// Maximum number of fields accepted in a multipart upload.
pub max_multipart_fields: usize,
/// Per-user SSE session cap.
pub max_sessions_per_user: usize,
/// Per-client-IP SSE session cap.
pub max_sessions_per_ip: usize,
/// Per-user cap on SSE subscriptions.
pub max_subscriptions_per_user: usize,
/// Reactor configuration. `realtime.max_subscriptions_per_session` also feeds the SSE config.
pub reactor_config: ReactorConfig,
/// Enables the security response headers (nosniff, frame options, CSP, ...).
pub security_headers: bool,
/// Adds `Strict-Transport-Security` (only when `security_headers` is also enabled).
pub hsts: bool,
/// Networks passed to `resolve_client_ip` when deriving the client IP.
pub trusted_proxies: Vec<ipnet::IpNet>,
/// Forwarded to `RpcHandler::set_max_jobs_per_request`.
pub max_jobs_per_request: usize,
/// Forwarded to `RpcHandler::set_max_result_size_bytes`.
pub max_result_size_bytes: usize,
/// Maximum JSON nesting depth for POST bodies on the RPC routes; `0` disables the check.
pub max_json_depth: usize,
}
impl Default for GatewayConfig {
    /// Defaults: port 9081, CORS and HSTS off, security headers on, a 30 s
    /// request timeout, 1 MiB JSON / 20 MiB multipart / 10 MiB per-file body
    /// limits, and JSON nesting capped at 64 levels.
    fn default() -> Self {
        Self {
            port: 9081,
            max_connections: 512,
            sse_max_sessions: 10_000,
            request_timeout_secs: 30,
            cors_enabled: false,
            cors_origins: vec![],
            auth: Default::default(),
            mcp: Default::default(),
            quiet_paths: vec![],
            token_ttl: Default::default(),
            project_name: String::from("forge-app"),
            max_body_size_bytes: DEFAULT_MAX_MULTIPART_BODY_SIZE,
            max_json_body_bytes: DEFAULT_MAX_JSON_BODY_SIZE,
            max_file_size_bytes: DEFAULT_MAX_FILE_SIZE,
            tls: None,
            max_multipart_fields: 20,
            max_sessions_per_user: 8,
            max_sessions_per_ip: 32,
            max_subscriptions_per_user: 500,
            reactor_config: Default::default(),
            security_headers: true,
            hsts: false,
            trusted_proxies: vec![],
            max_jobs_per_request: 10,
            // 10 MiB
            max_result_size_bytes: 10 << 20,
            max_json_depth: 64,
        }
    }
}
/// Middleware state holding the configured trusted proxy networks.
///
/// Built from `GatewayConfig::trusted_proxies` and handed to
/// `resolve_client_ip_middleware`, which passes the list to `resolve_client_ip`.
#[derive(Debug, Clone)]
pub struct TrustedProxies(pub Arc<Vec<ipnet::IpNet>>);
/// JSON body returned by the `/health` liveness endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
/// Always `"healthy"` when produced by `health_handler`.
pub status: String,
/// Crate version (`CARGO_PKG_VERSION`).
pub version: String,
}
/// JSON body returned by the `/ready` readiness endpoint.
///
/// `#[non_exhaustive]` prevents construction with a struct literal outside
/// this crate, so new checks can be added without breaking downstream code.
#[derive(Debug, Serialize)]
#[non_exhaustive]
pub struct ReadinessResponse {
/// Conjunction of all checks below (`cluster_registered` counts as passing when `None`).
pub ready: bool,
/// The primary database answered `SELECT 1`.
pub database: bool,
/// The reactor reports its listener as running.
pub reactor: bool,
/// Postgres NOTIFY queue usage is below `NOTIFY_QUEUE_FAIL_THRESHOLD`.
pub notify_queue_ok: bool,
/// At least the expected number of system migrations are recorded.
pub migrations_ok: bool,
/// `None` when no cluster node id is configured; otherwise whether this node
/// has an active `forge_nodes` row.
pub cluster_registered: Option<bool>,
/// Crate version (`CARGO_PKG_VERSION`).
pub version: String,
}
/// Shared state for `readiness_handler`.
#[derive(Clone)]
pub struct ReadinessState {
/// Primary database pool used by every readiness query.
db_pool: sqlx::PgPool,
/// Reactor whose listener status is reported.
reactor: Arc<Reactor>,
/// Cluster node id to look up in `forge_nodes`; `None` skips that check.
node_id: Option<uuid::Uuid>,
/// Number of bundled system migrations (`get_system_migrations().len()`).
expected_system_migrations: i64,
}
/// Readiness fails once `pg_notification_queue_usage()` (a 0..1 fraction) reaches this value.
const NOTIFY_QUEUE_FAIL_THRESHOLD: f64 = 0.75;
/// HTTP gateway: holds the configuration and collaborators needed to build
/// the axum router (`router`) and to serve it (`run`).
///
/// Optional collaborators are supplied through the `with_*` builder methods.
pub struct GatewayServer {
/// Gateway configuration.
config: GatewayConfig,
/// Registered functions served over RPC.
registry: FunctionRegistry,
/// Database handle (its primary pool backs readiness, OAuth, MCP, signals and admin).
db: Database,
/// Real-time reactor built in `new` and started in `run`.
reactor: Arc<Reactor>,
/// Optional job dispatcher forwarded to the RPC handler and MCP state.
job_dispatcher: Option<Arc<dyn JobDispatch>>,
/// Optional workflow dispatcher forwarded to the RPC handler and MCP state.
workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
/// Optional KV handle forwarded to the RPC handler and MCP state.
kv: Option<Arc<dyn KvHandle>>,
/// MCP tool registry (an empty default is used when unset and MCP is enabled).
mcp_registry: Option<McpToolRegistry>,
/// Copy of `config.token_ttl` taken in `new`.
token_ttl: forge_core::AuthTokenTtl,
/// Signals collector; when set, the `/signal` route is mounted.
signals_collector: Option<crate::signals::SignalsCollector>,
/// Passed to the signals endpoint state as `anonymize_ip`.
signals_anonymize_ip: bool,
/// Optional GeoIP resolver for the signals endpoint.
signals_geoip: Option<crate::signals::geoip::GeoIpResolver>,
/// Extra routes merged into the bounded (timeout + concurrency-limited) router.
custom_routes: Option<Router>,
/// Optional rate limiter forwarded to the RPC handler.
rate_limiter: Option<Arc<dyn forge_core::rate_limit::RateLimiterBackend>>,
/// Optional role resolver forwarded to the RPC handler.
role_resolver: Option<forge_core::SharedRoleResolver>,
/// Cluster node id checked by the readiness probe; `None` skips that check.
cluster_node_id: Option<uuid::Uuid>,
}
impl GatewayServer {
pub fn new(
config: GatewayConfig,
registry: FunctionRegistry,
db: Database,
notify_bus: Arc<PgNotifyBus>,
) -> Self {
let node_id = NodeId::new();
let reactor = Arc::new(Reactor::new(
node_id,
Arc::new(db.clone()),
registry.clone(),
config.reactor_config.clone(),
notify_bus,
));
let token_ttl = config.token_ttl.clone();
Self {
config,
registry,
db,
reactor,
job_dispatcher: None,
workflow_dispatcher: None,
kv: None,
mcp_registry: None,
token_ttl,
signals_collector: None,
signals_anonymize_ip: false,
signals_geoip: None,
custom_routes: None,
rate_limiter: None,
role_resolver: None,
cluster_node_id: None,
}
}
pub fn with_node_id(mut self, id: NodeId) -> Self {
self.cluster_node_id = Some(id.as_uuid());
self
}
pub fn with_rate_limiter(
mut self,
rate_limiter: Arc<dyn forge_core::rate_limit::RateLimiterBackend>,
) -> Self {
self.rate_limiter = Some(rate_limiter);
self
}
pub fn with_role_resolver(mut self, resolver: forge_core::SharedRoleResolver) -> Self {
self.role_resolver = Some(resolver);
self
}
pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
self.job_dispatcher = Some(dispatcher);
self
}
pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
self.workflow_dispatcher = Some(dispatcher);
self
}
pub fn with_kv(mut self, kv: Arc<dyn KvHandle>) -> Self {
self.kv = Some(kv);
self
}
pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
self.mcp_registry = Some(registry);
self
}
pub fn with_signals_collector(mut self, collector: crate::signals::SignalsCollector) -> Self {
crate::signals::install_global(Some(collector.clone()));
self.signals_collector = Some(collector);
self
}
pub fn with_signals_anonymize_ip(mut self, anonymize: bool) -> Self {
self.signals_anonymize_ip = anonymize;
self
}
pub fn with_signals_geoip(mut self, resolver: crate::signals::geoip::GeoIpResolver) -> Self {
self.signals_geoip = Some(resolver);
self
}
pub fn with_custom_routes(mut self, router: Router) -> Self {
self.custom_routes = Some(router);
self
}
pub fn reactor(&self) -> Arc<Reactor> {
self.reactor.clone()
}
pub fn tls(&self) -> Option<&TlsListenConfig> {
self.config.tls.as_ref()
}
#[cfg(feature = "mcp-oauth")]
pub fn oauth_router(&self) -> Option<(Router, Arc<super::oauth::OAuthState>)> {
if !self.config.mcp.oauth {
return None;
}
let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
.map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>)?;
let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
let jwt_secret = self.config.auth.jwt_secret.clone().unwrap_or_default();
let oauth_state = Arc::new(super::oauth::OAuthState::new(
self.db.primary().clone(),
auth_middleware_state,
token_issuer,
self.token_ttl.access_token_secs,
self.token_ttl.refresh_token_days,
self.config.auth.is_hmac(),
self.config.project_name.clone(),
jwt_secret,
self.config.auth.session_cookie_ttl_secs,
self.config.mcp.allow_unauthenticated_dcr,
));
let router = Router::new()
.route(
"/oauth/authorize",
get(super::oauth::oauth_authorize_get).post(super::oauth::oauth_authorize_post),
)
.route("/oauth/token", post(super::oauth::oauth_token))
.route("/oauth/register", post(super::oauth::oauth_register))
.with_state(oauth_state.clone());
Some((router, oauth_state))
}
#[cfg(not(feature = "mcp-oauth"))]
pub fn oauth_router(&self) -> Option<(Router, ())> {
None
}
pub fn router(&self) -> Router {
let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
.map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
let mut rpc = RpcHandler::with_dispatch_and_issuer(
self.registry.clone(),
self.db.clone(),
self.job_dispatcher.clone(),
self.workflow_dispatcher.clone(),
token_issuer,
);
rpc.set_token_ttl(self.token_ttl.clone());
rpc.set_max_jobs_per_request(self.config.max_jobs_per_request);
rpc.set_max_result_size_bytes(self.config.max_result_size_bytes);
if let Some(kv) = &self.kv {
rpc.set_kv(Arc::clone(kv));
}
if let Some(rate_limiter) = &self.rate_limiter {
rpc.set_rate_limiter(rate_limiter.clone());
}
if let Some(resolver) = &self.role_resolver {
rpc.set_role_resolver(resolver.clone());
}
if let Some(collector) = &self.signals_collector {
let secret = signal_visitor_secret(&self.config.auth.jwt_secret);
rpc.set_signals_collector(collector.clone(), secret);
}
let rpc_handler_state = Arc::new(rpc);
let cluster_cache = rpc_handler_state.router().cache();
drop(cluster_cache.spawn_cluster_invalidator(self.reactor.change_subscriber()));
let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
let cors = if self.config.cors_enabled {
if self.config.cors_origins.iter().any(|o| o == "*") {
tracing::warn!(
"CORS wildcard (`cors_origins = [\"*\"]`) is enabled. \
Credentialed requests will fail and any origin can \
reach the gateway. Set explicit origins for \
production deployments."
);
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
} else {
use axum::http::Method;
let origins: Vec<_> = self
.config
.cors_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
CorsLayer::new()
.allow_origin(origins)
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::PATCH,
Method::OPTIONS,
])
.allow_headers([
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
axum::http::header::ACCEPT,
axum::http::HeaderName::from_static("x-webhook-signature"),
axum::http::HeaderName::from_static("x-idempotency-key"),
axum::http::HeaderName::from_static("x-correlation-id"),
axum::http::HeaderName::from_static("x-session-id"),
axum::http::HeaderName::from_static("x-forge-platform"),
])
.allow_credentials(true)
.max_age(Duration::from_secs(86400))
}
} else {
CorsLayer::new()
};
let sse_state = Arc::new(SseState::with_config(
self.reactor.clone(),
auth_middleware_state.clone(),
super::sse::SseConfig {
max_sessions: self.config.sse_max_sessions,
max_subscriptions_per_session: self
.config
.reactor_config
.realtime
.max_subscriptions_per_session,
max_sessions_per_user: self.config.max_sessions_per_user,
max_sessions_per_ip: self.config.max_sessions_per_ip,
max_subscriptions_per_user: self.config.max_subscriptions_per_user,
..Default::default()
},
));
let expected_system_migrations = crate::pg::migration::get_system_migrations().len() as i64;
let readiness_state = Arc::new(ReadinessState {
db_pool: self.db.primary().clone(),
reactor: self.reactor.clone(),
node_id: self.cluster_node_id,
expected_system_migrations,
});
let json_depth_config = JsonDepthConfig {
max_depth: self.config.max_json_depth,
max_body_bytes: self.config.max_json_body_bytes,
};
let mut main_router = Router::new()
.route("/health", get(health_handler))
.route("/ready", get(readiness_handler).with_state(readiness_state))
.route("/rpc", post(rpc_handler))
.route("/rpc/{function}", post(rpc_function_handler))
.layer(DefaultBodyLimit::max(self.config.max_json_body_bytes))
.layer(middleware::from_fn_with_state(
json_depth_config,
json_depth_check_middleware,
))
.with_state(rpc_handler_state.clone());
let max_per_mutation = self
.registry
.functions()
.filter_map(|(_, entry)| entry.info().max_upload_size_bytes)
.max()
.unwrap_or(0);
let layer_limit = self.config.max_body_size_bytes.max(max_per_mutation);
let mp_config = MultipartConfig {
max_body_size_bytes: self.config.max_body_size_bytes,
max_file_size_bytes: self.config.max_file_size_bytes,
max_upload_fields: self.config.max_multipart_fields,
};
let multipart_router = Router::new()
.route("/rpc/{function}/upload", post(rpc_multipart_handler))
.layer(DefaultBodyLimit::max(layer_limit))
.layer(Extension(mp_config))
.layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
.with_state(rpc_handler_state.clone());
let sse_router = Router::new()
.route("/events", get(sse_handler))
.route("/subscribe", post(sse_subscribe_handler))
.route("/unsubscribe", post(sse_unsubscribe_handler))
.route("/subscribe-job", post(sse_job_subscribe_handler))
.route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
.with_state(sse_state);
let mut mcp_router = Router::new();
if self.config.mcp.enabled {
let path = self.config.mcp.path.clone();
let mut mcp_state = McpState::new(
self.config.mcp.clone(),
self.mcp_registry.clone().unwrap_or_default(),
self.db.primary().clone(),
self.job_dispatcher.clone(),
self.workflow_dispatcher.clone(),
Some(rpc_handler_state.router()),
);
if let Some(ref kv) = self.kv {
mcp_state = mcp_state.with_kv(Arc::clone(kv));
}
let mcp_state = Arc::new(mcp_state);
mcp_router = mcp_router.route(
&path,
post(mcp_post_handler)
.get(mcp_get_handler)
.with_state(mcp_state),
);
}
let mut signals_router = Router::new();
if let Some(collector) = &self.signals_collector {
let signals_state = Arc::new(crate::signals::endpoints::SignalsState {
collector: collector.clone(),
pool: self.db.primary().clone(),
server_secret: signal_visitor_secret(&self.config.auth.jwt_secret),
anonymize_ip: self.signals_anonymize_ip,
geoip: self.signals_geoip.clone(),
rate_limiter: Arc::new(crate::signals::rate_limit::SignalRateLimiter::new()),
});
signals_router = Router::new()
.route("/signal", post(crate::signals::endpoints::signal_handler))
.with_state(signals_state);
}
let admin_router = admin_router(AdminState {
db_pool: self.db.primary().clone(),
});
main_router = main_router
.merge(multipart_router)
.merge(mcp_router)
.merge(signals_router)
.merge(admin_router);
if let Some(custom) = &self.custom_routes {
main_router = main_router.merge(custom.clone());
}
let bounded_router = main_router.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(handle_middleware_error))
.layer(ConcurrencyLimitLayer::new(self.config.max_connections))
.layer(TimeoutLayer::new(Duration::from_secs(
self.config.request_timeout_secs,
))),
);
let full_router = bounded_router.merge(sse_router);
let security_config = Arc::new(SecurityHeadersConfig {
enabled: self.config.security_headers,
hsts: self.config.hsts,
});
let trusted_proxies = TrustedProxies(Arc::new(self.config.trusted_proxies.clone()));
let service_builder = ServiceBuilder::new()
.layer(cors.clone())
.layer(middleware::from_fn_with_state(
security_config,
security_headers_middleware,
))
.layer(middleware::from_fn(api_version_middleware))
.layer(middleware::from_fn_with_state(
trusted_proxies,
resolve_client_ip_middleware,
))
.layer(middleware::from_fn_with_state(
auth_middleware_state,
auth_middleware,
))
.layer(middleware::from_fn_with_state(
Arc::new(normalize_quiet_paths(&self.config.quiet_paths)),
tracing_middleware,
));
full_router.layer(service_builder)
}
pub fn addr(&self) -> std::net::SocketAddr {
std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
}
pub async fn run(self) -> Result<(), std::io::Error> {
let addr = self.addr();
let tls = self.config.tls.clone();
let service = self
.router()
.into_make_service_with_connect_info::<super::PeerAddr>();
self.reactor
.start()
.await
.map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
tracing::info!("Reactor started for real-time updates");
tracing::info!("Gateway server listening on {}", addr);
let listener = bind_listener(addr, tls.as_ref()).await?;
axum::serve(listener, service).await
}
}
/// `GET /health` liveness probe: always reports `"healthy"` with the crate version.
async fn health_handler() -> Json<HealthResponse> {
    let body = HealthResponse {
        status: String::from("healthy"),
        version: String::from(env!("CARGO_PKG_VERSION")),
    };
    Json(body)
}
/// `GET /ready` readiness probe.
///
/// Runs these checks sequentially against the primary pool:
/// 1. `SELECT 1` succeeds (`database`).
/// 2. The reactor reports `listener_running` (`reactor`).
/// 3. `pg_notification_queue_usage()` is below `NOTIFY_QUEUE_FAIL_THRESHOLD`.
/// 4. The count of `forge_system_migrations` rows matching `'__forge_v%'` is
///    at least `expected_system_migrations`.
/// 5. If a cluster node id is configured, an active `forge_nodes` row exists for it.
///
/// Checks 3-5 are skipped (and count as failed) when the database check fails.
/// Query errors are logged and count as failed. The response is 200 when every
/// check passes, otherwise 503; the body is always a [`ReadinessResponse`].
///
/// NOTE(review): these are compile-time-checked `sqlx` macros. Any edit to the
/// query text requires regenerating the offline query data.
async fn readiness_handler(
axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
// Basic connectivity to the primary database.
let db_ok = sqlx::query_scalar!("SELECT 1 as \"v!\"")
.fetch_one(&state.db_pool)
.await
.is_ok();
let reactor_stats = state.reactor.stats().await;
let reactor_ok = reactor_stats.listener_running;
// A nearly full NOTIFY queue would stall real-time delivery, so fail early.
let notify_queue_ok = if db_ok {
match sqlx::query_scalar!("SELECT pg_notification_queue_usage() AS \"usage!\"")
.fetch_one(&state.db_pool)
.await
{
Ok(usage) => usage < NOTIFY_QUEUE_FAIL_THRESHOLD,
Err(err) => {
tracing::warn!(error = %err, "pg_notification_queue_usage() failed; failing readiness probe");
false
}
}
} else {
false
};
// NOTE(review): in SQL `LIKE`, `_` matches any single character, so
// '__forge_v%' also matches versions that are not literally prefixed with
// "__forge_v". If a literal prefix is intended, escape the underscores.
let migrations_ok = if db_ok {
match sqlx::query_scalar!(
"SELECT COUNT(*) AS \"count!\" FROM forge_system_migrations WHERE version LIKE '__forge_v%'"
)
.fetch_one(&state.db_pool)
.await
{
Ok(applied) => applied >= state.expected_system_migrations,
Err(err) => {
tracing::warn!(error = %err, "forge_system_migrations count failed; failing readiness probe");
false
}
}
} else {
false
};
// None: no node id configured (check skipped). Some(false): database down,
// lookup failed, or no active row for this node.
let cluster_registered = match state.node_id {
Some(node_id) if db_ok => match sqlx::query_scalar!(
r#"SELECT EXISTS(
SELECT 1 FROM forge_nodes
WHERE id = $1 AND status = 'active'
) AS "found!""#,
node_id
)
.fetch_one(&state.db_pool)
.await
{
Ok(found) => Some(found),
Err(err) => {
tracing::warn!(error = %err, "forge_nodes lookup failed; failing readiness probe");
Some(false)
}
},
Some(_) => Some(false),
None => None,
};
// A skipped cluster check (None) does not block readiness.
let ready = db_ok
&& reactor_ok
&& notify_queue_ok
&& migrations_ok
&& cluster_registered.unwrap_or(true);
let status_code = if ready {
axum::http::StatusCode::OK
} else {
axum::http::StatusCode::SERVICE_UNAVAILABLE
};
(
status_code,
Json(ReadinessResponse {
ready,
database: db_ok,
reactor: reactor_ok,
notify_queue_ok,
migrations_ok,
cluster_registered,
version: env!("CARGO_PKG_VERSION").to_string(),
}),
)
}
async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
let rpc_err = if err.is::<tower::timeout::error::Elapsed>() {
RpcError::new("REQUEST_TIMEOUT", "Request timed out")
} else {
RpcError::new("SERVICE_UNAVAILABLE", "Server overloaded")
};
RpcResponse::error(rpc_err).into_response()
}
/// Echo the trace and request ids back on the response.
///
/// An id that is not a valid header value is silently skipped.
fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
    let headers = response.headers_mut();
    if let Ok(trace_value) = axum::http::HeaderValue::from_str(trace_id) {
        headers.insert(TRACE_ID_HEADER, trace_value);
    }
    if let Ok(request_value) = axum::http::HeaderValue::from_str(request_id) {
        headers.insert(REQUEST_ID_HEADER, request_value);
    }
}
/// Adapter that exposes an HTTP header map to the OpenTelemetry propagator.
#[cfg(feature = "otel")]
struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);

#[cfg(feature = "otel")]
impl Extractor for HeaderExtractor<'_> {
    /// Header value for `key`, if present and representable as `&str`.
    fn get(&self, key: &str) -> Option<&str> {
        let raw = self.0.get(key)?;
        raw.to_str().ok()
    }

    /// All header names in the map.
    fn keys(&self) -> Vec<&str> {
        self.0.keys().map(axum::http::HeaderName::as_str).collect()
    }
}
async fn resolve_client_ip_middleware(
axum::extract::State(trusted): axum::extract::State<TrustedProxies>,
mut req: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
let peer_ip = req
.extensions()
.get::<axum::extract::connect_info::ConnectInfo<super::PeerAddr>>()
.map(|ci| ci.0.ip());
let ip = super::resolve_client_ip(req.headers(), peer_ip, &trusted.0);
req.extensions_mut().insert(super::ResolvedClientIp(ip));
next.run(req).await
}
/// State for `security_headers_middleware`.
#[derive(Debug, Clone)]
struct SecurityHeadersConfig {
/// Master switch; when `false`, no security headers are added.
enabled: bool,
/// Also add `Strict-Transport-Security` (requires `enabled`).
hsts: bool,
}
/// Add hardening headers to every response when enabled.
///
/// Sets `X-Content-Type-Options`, `X-Frame-Options`, `Referrer-Policy`,
/// `Permissions-Policy` and a deny-all `Content-Security-Policy`. Adds
/// `Strict-Transport-Security` when `hsts` is set. Existing values are overwritten.
async fn security_headers_middleware(
    axum::extract::State(config): axum::extract::State<Arc<SecurityHeadersConfig>>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    use axum::http::{HeaderName, HeaderValue, header};

    let mut response = next.run(req).await;
    if !config.enabled {
        return response;
    }

    let baseline: [(HeaderName, &'static str); 5] = [
        (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (header::X_FRAME_OPTIONS, "DENY"),
        (header::REFERRER_POLICY, "strict-origin-when-cross-origin"),
        (
            HeaderName::from_static("permissions-policy"),
            "camera=(), microphone=(), geolocation=()",
        ),
        (
            header::CONTENT_SECURITY_POLICY,
            "default-src 'none'; frame-ancestors 'none'",
        ),
    ];
    let headers = response.headers_mut();
    for (name, value) in baseline {
        headers.insert(name, HeaderValue::from_static(value));
    }
    if config.hsts {
        headers.insert(
            header::STRICT_TRANSPORT_SECURITY,
            HeaderValue::from_static("max-age=63072000; includeSubDomains"),
        );
    }
    response
}
/// Vendor media type that selects version 1 of the Forge RPC API.
const FORGE_API_V1: &str = "application/vnd.forge.v1+json";

/// Decide whether an `Accept` header value admits the Forge v1 media type.
///
/// The value is parsed as a comma-separated list of media ranges
/// (RFC 9110 §12.5.1). It is acceptable when it is blank, or when any range,
/// ignoring parameters and ASCII case, is `*/*`, `application/*` or
/// [`FORGE_API_V1`]. A range carrying `q=0` means "not acceptable" and is ignored.
fn accept_admits_forge_v1(accept: &str) -> bool {
    if accept.trim().is_empty() {
        return true;
    }
    accept.split(',').any(|range| {
        let mut pieces = range.split(';');
        let media_type = pieces.next().unwrap_or_default().trim();
        let refused = pieces.any(|param| {
            param.split_once('=').is_some_and(|(key, value)| {
                key.trim().eq_ignore_ascii_case("q")
                    && value.trim().parse::<f32>().is_ok_and(|q| q <= 0.0)
            })
        });
        !refused
            && (media_type == "*/*"
                || media_type.eq_ignore_ascii_case("application/*")
                || media_type.eq_ignore_ascii_case(FORGE_API_V1))
    })
}

/// Reject RPC requests whose `Accept` header cannot be satisfied by the v1 API.
///
/// Only `/rpc` and paths under `/rpc/` are checked, on a path-segment
/// boundary so routes like `/rpc-docs` are unaffected. A missing header, or
/// one that is not visible ASCII, is allowed through. Before this change, a
/// list such as `application/json, text/plain, */*` (a common client default)
/// was rejected, because only the exact string `*/*` was accepted.
async fn api_version_middleware(
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    let path = req.uri().path();
    let is_rpc = path == "/rpc" || path.starts_with("/rpc/");
    if is_rpc && let Some(accept) = req.headers().get(axum::http::header::ACCEPT) {
        let accept_str = accept.to_str().unwrap_or("");
        if !accept_admits_forge_v1(accept_str) {
            return RpcResponse::error(RpcError::new(
                "UNSUPPORTED_API_VERSION",
                format!(
                    "Unsupported Accept header '{}'. Use '{}' or omit the header.",
                    accept_str, FORGE_API_V1
                ),
            ))
            .into_response();
        }
    }
    next.run(req).await
}
/// Per-request tracing, correlation headers and HTTP metrics.
///
/// - Uses the incoming trace-id header if it is readable, otherwise generates
///   a UUIDv4. The span-id header becomes the parent span of the `TracingState`.
/// - Inserts the `TracingState` into the request extensions, plus an
///   unauthenticated `AuthContext` if none is present yet.
/// - Quiet paths get only the correlation headers: no span, no logs, no metrics.
/// - All other requests run inside an `http.request` span and log one line
///   whose level depends on the status class. They also record a metric
///   labelled by the matched route template, or a normalized path when no
///   route matched.
///
/// The span's `http.route` is the route template, not the raw path, which
/// keeps it low-cardinality as OpenTelemetry HTTP conventions require. The
/// raw path is recorded as `url.path`.
async fn tracing_middleware(
    axum::extract::State(quiet_paths): axum::extract::State<Arc<std::collections::HashSet<String>>>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    let headers = req.headers();
    #[cfg(feature = "otel")]
    let parent_cx =
        global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
    let trace_id = headers
        .get(TRACE_ID_HEADER)
        .and_then(|v| v.to_str().ok())
        .map(String::from)
        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
    let parent_span_id = headers
        .get(SPAN_ID_HEADER)
        .and_then(|v| v.to_str().ok())
        .map(String::from);
    let method = req.method().to_string();
    let path = req.uri().path().to_string();
    let route_pattern = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map(|m| m.as_str().to_string())
        .unwrap_or_else(|| normalize_metric_path(&path));

    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
    if let Some(span_id) = parent_span_id {
        tracing_state = tracing_state.with_parent_span(span_id);
    }
    let mut req = req;
    req.extensions_mut().insert(tracing_state.clone());
    if req
        .extensions()
        .get::<forge_core::function::AuthContext>()
        .is_none()
    {
        req.extensions_mut()
            .insert(forge_core::function::AuthContext::unauthenticated());
    }

    if quiet_paths.contains(path.as_str()) {
        let mut response = next.run(req).await;
        set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
        return response;
    }

    let span = tracing::info_span!(
        "http.request",
        http.method = %method,
        http.route = %route_pattern,
        url.path = %path,
        http.status_code = tracing::field::Empty,
        trace_id = %trace_id,
        request_id = %tracing_state.request_id,
    );
    #[cfg(feature = "otel")]
    span.set_parent(parent_cx);

    let mut response = next.run(req).instrument(span.clone()).await;
    let status = response.status().as_u16();
    let elapsed = tracing_state.elapsed();
    span.record("http.status_code", status);
    let duration_ms = elapsed.as_millis() as u64;
    match status {
        500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
        400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
        200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
        _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
    }
    crate::observability::record_http_request(
        &method,
        &route_pattern,
        status,
        elapsed.as_secs_f64(),
    );
    set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
    response
}
/// Collapse identifier-like path segments into `{id}` so metric labels stay low-cardinality.
///
/// A non-empty segment that parses as a UUID or consists only of ASCII digits
/// becomes `{id}`. Other segments, including empty ones (from leading,
/// trailing or doubled slashes), are kept verbatim. An empty result becomes `/`.
fn normalize_metric_path(path: &str) -> String {
    let looks_like_id = |segment: &str| {
        !segment.is_empty()
            && (uuid::Uuid::try_parse(segment).is_ok()
                || segment.chars().all(|c| c.is_ascii_digit()))
    };
    let normalized = path
        .split('/')
        .map(|segment| if looks_like_id(segment) { "{id}" } else { segment })
        .collect::<Vec<&str>>()
        .join("/");
    if normalized.is_empty() {
        String::from("/")
    } else {
        normalized
    }
}
/// Expand configured quiet paths into a lookup set.
///
/// Each entry is kept as written, and also with a leading `/_api` prefix
/// removed, so it matches requests that arrive with or without that prefix.
fn normalize_quiet_paths(paths: &[String]) -> std::collections::HashSet<String> {
    paths
        .iter()
        .flat_map(|configured| {
            let without_prefix = configured.strip_prefix("/_api").unwrap_or(configured);
            [without_prefix.to_owned(), configured.to_owned()]
        })
        .collect()
}
/// Maximum `{`/`[` nesting depth in a (presumably JSON) byte buffer.
///
/// This is a lightweight scan, not a validating parser. Brackets inside string
/// literals are ignored (backslash escapes are honoured). Unbalanced closing
/// brackets never drive the depth below zero. Empty input yields `0`.
fn json_max_depth(bytes: &[u8]) -> usize {
    #[derive(Clone, Copy)]
    enum Scan {
        Structure,
        InString,
        Escaped,
    }

    let mut state = Scan::Structure;
    let mut current = 0usize;
    let mut peak = 0usize;
    for &byte in bytes {
        state = match (state, byte) {
            (Scan::Escaped, _) => Scan::InString,
            (Scan::InString, b'\\') => Scan::Escaped,
            (Scan::InString, b'"') => Scan::Structure,
            (Scan::InString, _) => Scan::InString,
            (Scan::Structure, b'"') => Scan::InString,
            (Scan::Structure, b'{' | b'[') => {
                current += 1;
                peak = peak.max(current);
                Scan::Structure
            }
            (Scan::Structure, b'}' | b']') => {
                current = current.saturating_sub(1);
                Scan::Structure
            }
            (Scan::Structure, _) => Scan::Structure,
        };
    }
    peak
}
/// State for `json_depth_check_middleware`.
#[derive(Debug, Clone, Copy)]
struct JsonDepthConfig {
/// Maximum allowed nesting depth; `0` disables the check.
max_depth: usize,
/// Maximum number of body bytes read before the check gives up with an error.
max_body_bytes: usize,
}
async fn json_depth_check_middleware(
axum::extract::State(config): axum::extract::State<JsonDepthConfig>,
req: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
use axum::body::Body;
if req.method() != axum::http::Method::POST || config.max_depth == 0 {
return next.run(req).await;
}
let (parts, body) = req.into_parts();
let bytes = match axum::body::to_bytes(body, config.max_body_bytes).await {
Ok(b) => b,
Err(_) => {
return super::response::RpcResponse::error(super::response::RpcError::new(
"BAD_REQUEST",
"Failed to read request body",
))
.into_response();
}
};
let depth = json_max_depth(&bytes);
if depth > config.max_depth {
return super::response::RpcResponse::error(super::response::RpcError::new(
"BAD_REQUEST",
format!(
"JSON nesting depth {} exceeds the maximum of {}",
depth, config.max_depth
),
))
.into_response();
}
let req = axum::extract::Request::from_parts(parts, Body::from(bytes));
next.run(req).await
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    //! Unit tests for the pure helpers and defaults in the gateway server module.
    use super::*;

    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 9081);
        assert_eq!(config.max_connections, 512);
        assert!(!config.cors_enabled);
    }

    #[test]
    fn gateway_config_default_body_limits_use_module_constants() {
        let config = GatewayConfig::default();
        assert_eq!(config.max_json_body_bytes, DEFAULT_MAX_JSON_BODY_SIZE);
        assert_eq!(config.max_body_size_bytes, DEFAULT_MAX_MULTIPART_BODY_SIZE);
        assert_eq!(config.max_file_size_bytes, DEFAULT_MAX_FILE_SIZE);
        assert_eq!(config.max_json_depth, 64);
    }

    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("healthy"));
    }

    #[test]
    fn json_max_depth_flat_object_is_one() {
        assert_eq!(json_max_depth(b"{\"a\":1}"), 1);
    }

    #[test]
    fn json_max_depth_flat_array_is_one() {
        assert_eq!(json_max_depth(b"[1,2,3]"), 1);
    }

    #[test]
    fn json_max_depth_nested_object_counts_levels() {
        assert_eq!(json_max_depth(b"{\"a\":{\"b\":{\"c\":1}}}"), 3);
    }

    #[test]
    fn json_max_depth_nested_array_counts_levels() {
        assert_eq!(json_max_depth(b"[[[[1]]]]"), 4);
    }

    #[test]
    fn json_max_depth_mixed_nesting_tracks_peak() {
        assert_eq!(json_max_depth(b"{\"a\":[{\"b\":[1]}]}"), 4);
    }

    #[test]
    fn json_max_depth_ignores_braces_inside_strings() {
        assert_eq!(json_max_depth(b"{\"k\":\"{{{[[[\"}"), 1);
    }

    #[test]
    fn json_max_depth_respects_escaped_quote_in_string() {
        assert_eq!(json_max_depth(b"{\"k\":\"a\\\"{b\"}"), 1);
    }

    #[test]
    fn json_max_depth_escaped_backslash_terminates_string() {
        // `"a\\"` ends at the quote after the escaped backslash, so the
        // following array is structural and raises the depth to 2.
        assert_eq!(json_max_depth(b"{\"k\":\"a\\\\\",\"n\":[1]}"), 2);
    }

    #[test]
    fn json_max_depth_empty_input_is_zero() {
        assert_eq!(json_max_depth(b""), 0);
    }

    #[test]
    fn json_max_depth_unbalanced_close_does_not_underflow() {
        assert_eq!(json_max_depth(b"}}}}"), 0);
        assert_eq!(json_max_depth(b"[1]]]]"), 1);
    }

    #[test]
    fn signal_visitor_secret_uses_jwt_secret_when_present() {
        let secret = Some("my-jwt-secret".to_string());
        assert_eq!(signal_visitor_secret(&secret), "my-jwt-secret");
    }

    #[test]
    fn signal_visitor_secret_falls_back_to_default_when_absent() {
        assert_eq!(signal_visitor_secret(&None), DEFAULT_SIGNAL_SECRET);
    }

    #[test]
    fn set_tracing_headers_inserts_both_headers() {
        let mut response = axum::response::Response::new(axum::body::Body::empty());
        set_tracing_headers(&mut response, "trace-abc", "req-xyz");
        assert_eq!(
            response.headers().get(TRACE_ID_HEADER).unwrap(),
            "trace-abc"
        );
        assert_eq!(
            response.headers().get(REQUEST_ID_HEADER).unwrap(),
            "req-xyz"
        );
    }

    #[test]
    fn set_tracing_headers_skips_invalid_header_values() {
        let mut response = axum::response::Response::new(axum::body::Body::empty());
        set_tracing_headers(&mut response, "bad\nvalue", "req-ok");
        assert!(response.headers().get(TRACE_ID_HEADER).is_none());
        assert_eq!(response.headers().get(REQUEST_ID_HEADER).unwrap(), "req-ok");
    }

    #[test]
    fn normalize_metric_path_replaces_numeric_and_uuid_segments() {
        assert_eq!(normalize_metric_path("/users/42"), "/users/{id}");
        assert_eq!(
            normalize_metric_path("/jobs/550e8400-e29b-41d4-a716-446655440000/logs"),
            "/jobs/{id}/logs"
        );
        assert_eq!(normalize_metric_path("/rpc/get_user"), "/rpc/get_user");
    }

    #[test]
    fn normalize_metric_path_handles_root_empty_and_trailing_slash() {
        assert_eq!(normalize_metric_path(""), "/");
        assert_eq!(normalize_metric_path("/"), "/");
        assert_eq!(normalize_metric_path("/rpc/"), "/rpc/");
    }

    #[test]
    fn normalize_quiet_paths_keeps_original_and_api_stripped_forms() {
        let set = normalize_quiet_paths(&["/_api/health".to_string(), "/metrics".to_string()]);
        assert!(set.contains("/_api/health"));
        assert!(set.contains("/health"));
        assert!(set.contains("/metrics"));
        assert_eq!(set.len(), 3);
    }
}