use crate::analytics::{AnalyticsConfig, SharedQueryAnalytics};
use crate::cache::{CacheConfig, CacheLookupResult, SharedResponseCache};
use crate::circuit_breaker::{CircuitBreakerConfig, SharedCircuitBreakerRegistry};
use crate::compression::{create_compression_layer, CompressionConfig};
use crate::defer::{
extract_deferred_fragments, format_initial_part, format_subsequent_part, has_defer_directive,
strip_defer_directives, DeferConfig, DeferredExecution, DeferredPart, MULTIPART_CONTENT_TYPE,
};
use crate::error::{GraphQLError, Result};
use crate::grpc_client::GrpcClientPool;
use crate::health::{health_handler, readiness_handler, HealthState};
use crate::high_performance::{
pin_to_core, recommended_workers, FastJsonParser, HighPerfConfig, PerfMetrics,
ResponseTemplates, ShardedCache,
};
use crate::metrics::GatewayMetrics;
use crate::middleware::{Context, Middleware};
use crate::persisted_queries::{
process_apq_request, PersistedQueryConfig, PersistedQueryError, SharedPersistedQueryStore,
};
use crate::plugin::PluginRegistry;
use crate::query_whitelist::{QueryWhitelistConfig, SharedQueryWhitelist};
use crate::request_collapsing::{RequestCollapsingConfig, SharedRequestCollapsingRegistry};
use crate::schema::{DynamicSchema, GrpcResponseCache};
use async_graphql::{futures_util::stream::BoxStream, Data, ServerError};
use async_graphql_axum::{GraphQLProtocol, GraphQLRequest, GraphQLResponse, GraphQLWebSocket};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
http::HeaderMap,
response::{Html, IntoResponse, Json},
routing::{get, post},
Extension, Router,
};
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use std::any::TypeId;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Central GraphQL gateway router/executor.
///
/// Bundles the executable schema with every optional subsystem
/// (APQ, caching, analytics, compression, defer, high-performance
/// fast path, plugins). Subsystems left as `None` are disabled.
pub struct ServeMux {
    /// Executable GraphQL schema built at startup.
    schema: DynamicSchema,
    /// Request middlewares, run in registration order before execution.
    middlewares: Vec<Arc<dyn Middleware>>,
    /// Optional callback invoked with errors from failed executions.
    error_handler: Option<Arc<dyn Fn(Vec<GraphQLError>) + Send + Sync>>,
    /// Shared gRPC client pool (also used by the readiness check).
    client_pool: Option<GrpcClientPool>,
    /// Expose `/health` and `/ready` routes when true.
    health_checks_enabled: bool,
    /// Expose the `/metrics` route when true.
    metrics_enabled: bool,
    /// Serve the playground UI on GET /graphql when true.
    playground_enabled: bool,
    /// Automatic persisted query (APQ) store, if enabled.
    apq_store: Option<SharedPersistedQueryStore>,
    /// Per-service circuit breaker registry, if enabled.
    circuit_breaker: Option<SharedCircuitBreakerRegistry>,
    /// Full-response cache keyed by query/variables/vary-headers.
    response_cache: Option<SharedResponseCache>,
    /// HTTP response compression settings, if enabled.
    compression_config: Option<CompressionConfig>,
    /// Allow-list of permitted queries, if enabled.
    query_whitelist: Option<SharedQueryWhitelist>,
    /// Query analytics recorder, if enabled.
    analytics: Option<SharedQueryAnalytics>,
    /// Duplicate-request collapsing registry, if enabled.
    request_collapsing: Option<SharedRequestCollapsingRegistry>,
    /// High-performance fast-path settings; also switches POST routing.
    high_perf_config: Option<HighPerfConfig>,
    /// SIMD-accelerated JSON parser used by the fast path.
    json_parser: Arc<FastJsonParser>,
    /// Byte-level response cache used by the fast path.
    sharded_cache: Option<Arc<ShardedCache<Bytes>>>,
    /// Latency/hit-rate counters for the fast path.
    perf_metrics: Arc<PerfMetrics>,
    /// Pre-serialized error response bodies.
    response_templates: Arc<ResponseTemplates>,
    /// `@defer` incremental-delivery settings, if enabled.
    defer_config: Option<DeferConfig>,
    /// Plugin hooks run around every request/response.
    plugins: PluginRegistry,
}
/// Locked-down CSP for API-only responses: nothing is rendered, so
/// every content source is denied.
const STRICT_API_CSP: &str =
    "default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'none'";
/// Relaxed CSP required by the playground UI: CDN-hosted scripts,
/// inline styles, and WebSocket connections back to the gateway.
const PLAYGROUND_CSP: &str = "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://unpkg.com; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws: wss:; object-src 'none'; base-uri 'self'; frame-ancestors 'none'; form-action 'self'";
/// Validate a configured CORS origin and turn it into a header value.
///
/// Accepts either the literal wildcard `"*"` or an absolute http(s)
/// origin with a host and no credentials, path, query, or fragment.
/// Anything else (including unparsable input) yields `None`.
fn parse_cors_allowed_origin(raw: Option<&str>) -> Option<axum::http::HeaderValue> {
    let origin = raw?.trim();
    if origin.is_empty() {
        return None;
    }
    // The wildcard skips URL validation entirely.
    if origin != "*" {
        let url = reqwest::Url::parse(origin).ok()?;
        let scheme_ok = matches!(url.scheme(), "http" | "https");
        let authority_ok =
            url.host_str().is_some() && url.username().is_empty() && url.password().is_none();
        // An origin must be "bare": root path only, nothing after it.
        let bare_origin =
            url.path() == "/" && url.query().is_none() && url.fragment().is_none();
        if !(scheme_ok && authority_ok && bare_origin) {
            return None;
        }
    }
    axum::http::HeaderValue::from_str(origin).ok()
}
/// Read `CORS_ALLOWED_ORIGIN` from the environment and validate it
/// via [`parse_cors_allowed_origin`].
fn configured_cors_allow_origin() -> Option<axum::http::HeaderValue> {
    let configured = std::env::var("CORS_ALLOWED_ORIGIN").ok();
    parse_cors_allowed_origin(configured.as_deref())
}
/// Pick the CSP header value: relaxed when the playground UI is
/// served, strict deny-all otherwise.
fn content_security_policy(playground_enabled: bool) -> &'static str {
    match playground_enabled {
        true => PLAYGROUND_CSP,
        false => STRICT_API_CSP,
    }
}
/// Snapshot of HTTP headers associated with a WebSocket session,
/// stored in the subscription session `Data` so per-operation
/// execution can see them.
#[derive(Clone, Debug)]
struct WebSocketSessionHeaders {
    headers: HeaderMap,
}
/// Marker type placed in request data to flag requests that arrived
/// through the live-query path.
#[derive(Clone, Copy, Debug)]
struct LiveQueryRequestMarker;
/// Whether a header name from a WebSocket `connection_init` payload
/// must NOT be copied into the session headers.
///
/// Hop-by-hop / handshake headers and proxy-set client-identity
/// headers are rejected so clients cannot spoof them. Comparison is
/// ASCII-case-insensitive, matching HTTP header semantics.
///
/// Improvement over the previous version: `eq_ignore_ascii_case`
/// avoids allocating a lowercased `String` for every header checked.
fn is_forbidden_ws_connection_header(name: &str) -> bool {
    // Keep sorted for readability; lookup is a linear scan over 12 entries.
    const FORBIDDEN: [&str; 12] = [
        "connection",
        "content-length",
        "forwarded",
        "host",
        "origin",
        "sec-websocket-extensions",
        "sec-websocket-key",
        "sec-websocket-protocol",
        "sec-websocket-version",
        "upgrade",
        "x-forwarded-for",
        "x-real-ip",
    ];
    FORBIDDEN
        .iter()
        .any(|forbidden| name.eq_ignore_ascii_case(forbidden))
}
fn json_value_to_header_string(value: &serde_json::Value) -> Option<String> {
match value {
serde_json::Value::String(s) => Some(s.clone()),
serde_json::Value::Number(n) => Some(n.to_string()),
serde_json::Value::Bool(b) => Some(b.to_string()),
_ => None,
}
}
/// Copy client-supplied key/value pairs into `headers`, skipping
/// forbidden names, non-scalar values, and anything that is not a
/// valid HTTP header name/value. Existing entries are overwritten.
fn merge_ws_connection_header_values(
    headers: &mut HeaderMap,
    values: &serde_json::Map<String, serde_json::Value>,
) {
    for (key, value) in values {
        // Never let clients set hop-by-hop or identity headers.
        if is_forbidden_ws_connection_header(key) {
            continue;
        }
        let parsed_name = axum::http::HeaderName::from_bytes(key.as_bytes()).ok();
        let parsed_value = json_value_to_header_string(value)
            .and_then(|text| axum::http::HeaderValue::from_str(&text).ok());
        if let (Some(name), Some(value)) = (parsed_name, parsed_value) {
            headers.insert(name, value);
        }
    }
}
/// Merge a WebSocket `connection_init` payload into a copy of the
/// handshake headers.
///
/// A nested `"headers"` object is applied first, then the top-level
/// payload keys, so top-level entries win on conflicts. A non-object
/// payload leaves the base headers untouched.
fn merge_ws_connection_init_headers(
    base_headers: &HeaderMap,
    payload: &serde_json::Value,
) -> HeaderMap {
    let mut merged = base_headers.clone();
    if let Some(payload_obj) = payload.as_object() {
        if let Some(nested) = payload_obj.get("headers").and_then(|v| v.as_object()) {
            merge_ws_connection_header_values(&mut merged, nested);
        }
        merge_ws_connection_header_values(&mut merged, payload_obj);
    }
    merged
}
fn ws_session_headers(session_data: Option<&Arc<Data>>) -> HeaderMap {
session_data
.and_then(|data| data.get(&TypeId::of::<WebSocketSessionHeaders>()))
.and_then(|value| value.downcast_ref::<WebSocketSessionHeaders>())
.map(|session| session.headers.clone())
.unwrap_or_default()
}
/// Adapter that lets the WebSocket layer execute operations through
/// the full `ServeMux` pipeline (APQ/WAF/whitelist, middlewares,
/// plugins) instead of hitting the schema directly.
#[derive(Clone)]
struct SubscriptionExecutor {
    mux: Arc<ServeMux>,
}
impl SubscriptionExecutor {
    /// Wrap a shared `ServeMux` for use as an async-graphql executor.
    fn new(mux: Arc<ServeMux>) -> Self {
        Self { mux }
    }
}
#[async_trait::async_trait]
impl async_graphql::Executor for SubscriptionExecutor {
    /// One-shot execution (queries/mutations sent over the socket);
    /// goes straight to the schema with no extra pipeline.
    async fn execute(&self, request: async_graphql::Request) -> async_graphql::Response {
        self.mux.schema.execute(request).await
    }
    /// Streaming execution for subscriptions. Runs the same
    /// preprocessing as HTTP requests before delegating to the
    /// schema's stream executor.
    fn execute_stream(
        &self,
        request: async_graphql::Request,
        session_data: Option<Arc<Data>>,
    ) -> BoxStream<'static, async_graphql::Response> {
        let mux = self.mux.clone();
        Box::pin(async_stream::stream! {
            // Headers captured during connection_init, if any.
            let headers = ws_session_headers(session_data.as_ref());
            // APQ/WAF/whitelist preprocessing; Err is already a
            // complete GraphQL error response.
            let request = match mux.prepare_graphql_request(request) {
                Ok(request) => request,
                Err(response) => {
                    yield response;
                    return;
                }
            };
            // Build the middleware Context; a middleware failure ends
            // the stream with a single error response.
            let ctx = match mux.prepare_execution_context(&headers).await {
                Ok(ctx) => ctx,
                Err(err) => {
                    yield async_graphql::Response::from_errors(vec![ServerError::new(err.to_string(), None)]);
                    return;
                }
            };
            if let Err(err) = mux.plugins.on_request(&ctx, &request).await {
                yield async_graphql::Response::from_errors(vec![ServerError::new(err.to_string(), None)]);
                return;
            }
            // Attach per-request data the resolvers expect.
            let request = request
                .data(ctx.clone())
                .data(mux.plugins.clone())
                .data(GrpcResponseCache::default());
            let schema_executor = mux.schema.executor();
            let stream =
                async_graphql::Executor::execute_stream(&schema_executor, request, session_data);
            futures::pin_mut!(stream);
            // Let plugins inspect every streamed item; a plugin error
            // yields an error response and terminates the stream.
            while let Some(response) = stream.next().await {
                if let Err(err) = mux.plugins.on_response(&ctx, &response).await {
                    yield async_graphql::Response::from_errors(vec![ServerError::new(err.to_string(), None)]);
                    break;
                }
                yield response;
            }
        })
    }
}
impl ServeMux {
    /// Create a mux with every optional subsystem disabled.
    ///
    /// The playground is the one exception: it can be pre-enabled via
    /// the `ENABLE_GRAPHQL_PLAYGROUND` env var ("true" or "1").
    pub fn new(schema: DynamicSchema) -> Self {
        Self {
            schema,
            middlewares: Vec::new(),
            error_handler: None,
            client_pool: None,
            health_checks_enabled: false,
            metrics_enabled: false,
            playground_enabled: std::env::var("ENABLE_GRAPHQL_PLAYGROUND")
                .map(|v| v == "true" || v == "1")
                .unwrap_or(false),
            apq_store: None,
            circuit_breaker: None,
            response_cache: None,
            compression_config: None,
            query_whitelist: None,
            analytics: None,
            request_collapsing: None,
            high_perf_config: None,
            json_parser: Arc::new(FastJsonParser::default()),
            sharded_cache: None,
            perf_metrics: Arc::new(PerfMetrics::default()),
            response_templates: Arc::new(ResponseTemplates::new()),
            defer_config: None,
            plugins: PluginRegistry::new(),
        }
    }
    /// Inject the shared gRPC client pool (also used by readiness checks).
    pub fn set_client_pool(&mut self, pool: GrpcClientPool) {
        self.client_pool = Some(pool);
    }
    /// Expose `/health` and `/ready` routes on the built router.
    pub fn enable_health_checks(&mut self) {
        self.health_checks_enabled = true;
    }
    /// Expose the `/metrics` route on the built router.
    pub fn enable_metrics(&mut self) {
        self.metrics_enabled = true;
    }
    /// Serve the playground UI on GET /graphql.
    pub fn enable_playground(&mut self) {
        self.playground_enabled = true;
    }
    /// Turn on automatic persisted queries (APQ).
    pub fn enable_persisted_queries(&mut self, config: PersistedQueryConfig) {
        self.apq_store = Some(crate::persisted_queries::create_apq_store(config));
    }
    /// Turn on per-service circuit breaking.
    pub fn enable_circuit_breaker(&mut self, config: CircuitBreakerConfig) {
        self.circuit_breaker = Some(crate::circuit_breaker::create_circuit_breaker_registry(
            config,
        ));
    }
    /// Circuit-breaker registry, if enabled.
    pub fn circuit_breaker(&self) -> Option<&SharedCircuitBreakerRegistry> {
        self.circuit_breaker.as_ref()
    }
    /// Turn on full-response caching.
    pub fn enable_response_cache(&mut self, config: CacheConfig) {
        self.response_cache = Some(crate::cache::create_response_cache(config));
    }
    /// Response cache, if enabled.
    pub fn response_cache(&self) -> Option<&SharedResponseCache> {
        self.response_cache.as_ref()
    }
    /// Turn on HTTP response compression (applied in `into_router`).
    pub fn enable_compression(&mut self, config: CompressionConfig) {
        self.compression_config = Some(config);
    }
    /// Compression settings, if enabled.
    pub fn compression_config(&self) -> Option<&CompressionConfig> {
        self.compression_config.as_ref()
    }
    /// Restrict execution to an allow-list of queries.
    pub fn enable_query_whitelist(&mut self, config: QueryWhitelistConfig) {
        self.query_whitelist = Some(Arc::new(crate::query_whitelist::QueryWhitelist::new(
            config,
        )));
    }
    /// Query whitelist, if enabled.
    pub fn query_whitelist(&self) -> Option<&SharedQueryWhitelist> {
        self.query_whitelist.as_ref()
    }
    /// Turn on query analytics recording and the /analytics routes.
    pub fn enable_analytics(&mut self, config: AnalyticsConfig) {
        self.analytics = Some(crate::analytics::create_analytics(config));
    }
    /// Analytics recorder, if enabled.
    pub fn analytics(&self) -> Option<&SharedQueryAnalytics> {
        self.analytics.as_ref()
    }
    /// Turn on duplicate-request collapsing.
    pub fn enable_request_collapsing(&mut self, config: RequestCollapsingConfig) {
        self.request_collapsing =
            Some(crate::request_collapsing::create_request_collapsing_registry(config));
    }
    /// Request-collapsing registry, if enabled.
    pub fn request_collapsing(&self) -> Option<&SharedRequestCollapsingRegistry> {
        self.request_collapsing.as_ref()
    }
    /// Turn on the high-performance fast path: SIMD JSON parsing, a
    /// sharded byte cache, and (optionally) CPU pinning. Storing the
    /// config also switches POST /graphql to the fast handler.
    pub fn enable_high_performance(&mut self, config: HighPerfConfig) {
        self.json_parser = Arc::new(FastJsonParser::new(config.buffer_pool_size));
        self.sharded_cache = Some(Arc::new(ShardedCache::new(
            config.cache_shards,
            config.max_entries_per_shard,
        )));
        if config.cpu_affinity {
            let num_cores = recommended_workers();
            // NOTE(review): this pins the *current* thread to each core
            // in turn, leaving it on the last one; presumably worker
            // threads re-pin themselves — confirm intent.
            for i in 0..num_cores {
                let _ = pin_to_core(i);
            }
        }
        self.high_perf_config = Some(config);
    }
    /// Fast-path latency/hit counters.
    pub fn perf_metrics(&self) -> &PerfMetrics {
        &self.perf_metrics
    }
    /// Turn on @defer incremental delivery.
    pub fn enable_defer(&mut self, config: DeferConfig) {
        self.defer_config = Some(config);
    }
    /// Defer settings, if enabled.
    pub fn defer_config(&self) -> Option<&DeferConfig> {
        self.defer_config.as_ref()
    }
    /// Append a middleware; middlewares run in registration order.
    pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
        self.middlewares.push(middleware);
    }
    /// Builder-style variant of [`Self::add_middleware`].
    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
        self.add_middleware(middleware);
        self
    }
    /// Install an error callback (pre-wrapped in `Arc`).
    pub fn set_error_handler_arc(&mut self, handler: Arc<dyn Fn(Vec<GraphQLError>) + Send + Sync>) {
        self.error_handler = Some(handler);
    }
    /// Install an error callback from any compatible closure.
    pub fn set_error_handler<F>(&mut self, handler: F)
    where
        F: Fn(Vec<GraphQLError>) + Send + Sync + 'static,
    {
        self.set_error_handler_arc(Arc::new(handler));
    }
    /// Replace the plugin registry wholesale.
    pub fn set_plugins(&mut self, plugins: PluginRegistry) {
        self.plugins = plugins;
    }
    /// Build the per-request middleware [`Context`] (request id and
    /// client IP derived from headers) and run every registered
    /// middleware against it in order.
    ///
    /// # Errors
    /// Propagates the first middleware failure.
    async fn prepare_execution_context(&self, headers: &HeaderMap) -> Result<Context> {
        let mut ctx = Context {
            headers: headers.clone(),
            extensions: std::collections::HashMap::new(),
            request_start: std::time::Instant::now(),
            // Reuse the caller's X-Request-ID when present; otherwise
            // mint a fresh UUID.
            request_id: headers
                .get("x-request-id")
                .and_then(|v| v.to_str().ok())
                .map(String::from)
                .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
            // First hop of X-Forwarded-For, falling back to X-Real-IP;
            // candidates must parse as an IP address to be accepted.
            // NOTE(review): both headers are client-settable unless a
            // trusted proxy overwrites them — confirm deployment topology.
            client_ip: headers
                .get("x-forwarded-for")
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.split(',').next())
                .map(|s| s.trim().to_string())
                .filter(|ip| ip.parse::<std::net::IpAddr>().is_ok())
                .or_else(|| {
                    headers
                        .get("x-real-ip")
                        .and_then(|v| v.to_str().ok())
                        .map(String::from)
                        .filter(|ip| ip.parse::<std::net::IpAddr>().is_ok())
                }),
            encryption_key: None,
        };
        // Middlewares may mutate the context or abort the request.
        for middleware in &self.middlewares {
            middleware.call(&mut ctx).await?;
        }
        Ok(ctx)
    }
    /// Shared request preprocessing: APQ hash resolution, WAF
    /// validation, and query-whitelist enforcement, in that order.
    ///
    /// Returns the (possibly rewritten) request, or a complete GraphQL
    /// error response the caller should return as-is.
    fn prepare_graphql_request(
        &self,
        request: async_graphql::Request,
    ) -> std::result::Result<async_graphql::Request, async_graphql::Response> {
        // APQ first: a hash-only request must be resolved to its full
        // query text before any further validation.
        let processed_request = if let Some(ref apq_store) = self.apq_store {
            match self.process_apq_request(apq_store, request) {
                Ok(req) => req,
                Err(apq_err) => {
                    return Err(self.apq_error_response(apq_err));
                }
            }
        } else {
            request
        };
        if let Err(err) = crate::waf::validate_request(&processed_request) {
            tracing::warn!("WAF blocked request: {}", err);
            let mut server_err = ServerError::new(err.to_string(), None);
            server_err.extensions = Some({
                let mut ext = async_graphql::ErrorExtensionValues::default();
                ext.set("code", "VALIDATION_ERROR");
                ext
            });
            return Err(async_graphql::Response::from_errors(vec![server_err]));
        }
        if let Some(ref whitelist) = self.query_whitelist {
            // Optional client-supplied operationId from request extensions.
            let operation_id = processed_request
                .extensions
                .get("operationId")
                .and_then(|v| serde_json::to_value(v).ok())
                .and_then(|v| v.as_str().map(String::from));
            if let Err(err) =
                whitelist.validate_query(&processed_request.query, operation_id.as_deref())
            {
                tracing::warn!("Query whitelist validation failed: {}", err);
                let mut server_err = ServerError::new(err.to_string(), None);
                server_err.extensions = Some({
                    let mut ext = async_graphql::ErrorExtensionValues::default();
                    ext.set("code", "QUERY_NOT_WHITELISTED");
                    ext
                });
                return Err(async_graphql::Response::from_errors(vec![server_err]));
            }
        }
        Ok(processed_request)
    }
async fn execute_with_middlewares(
&self,
headers: HeaderMap,
request: async_graphql::Request,
) -> Result<async_graphql::Response> {
let ctx = self.prepare_execution_context(&headers).await?;
self.plugins.on_request(&ctx, &request).await?;
let mut gql_request = request;
gql_request = gql_request.data(ctx.clone());
gql_request = gql_request.data(self.plugins.clone());
gql_request = gql_request.data(GrpcResponseCache::default());
let response = self.schema.execute(gql_request).await;
self.plugins.on_response(&ctx, &response).await?;
Ok(response)
}
    /// Full HTTP execution path: preprocessing (APQ/WAF/whitelist),
    /// @live handling, response-cache lookup, middleware/plugin
    /// execution, analytics recording, and cache population or
    /// invalidation.
    pub async fn handle_http(
        &self,
        headers: HeaderMap,
        request: async_graphql::Request,
    ) -> async_graphql::Response {
        let mut processed_request = match self.prepare_graphql_request(request) {
            Ok(request) => request,
            Err(response) => return response,
        };
        // A request is "live" if it carries @live or the live-query
        // marker injected by the WebSocket live endpoint.
        let is_live_query = crate::live_query::has_live_directive(&processed_request.query)
            || processed_request
                .data
                .contains_key(&TypeId::of::<LiveQueryRequestMarker>());
        if crate::live_query::has_live_directive(&processed_request.query) {
            // @live is not a schema directive, so it must be stripped
            // before execution; rebuild the request around the
            // stripped query, carrying every other field over.
            let stripped_query = crate::live_query::strip_live_directive(&processed_request.query);
            tracing::debug!(
                is_live = is_live_query,
                "Live query detected, stripping @live directive"
            );
            let mut rebuilt_request = async_graphql::Request::new(stripped_query);
            rebuilt_request.operation_name = processed_request.operation_name;
            rebuilt_request.variables = processed_request.variables;
            rebuilt_request.uploads = processed_request.uploads;
            rebuilt_request.data = processed_request.data;
            rebuilt_request.extensions = processed_request.extensions;
            rebuilt_request.introspection_mode = processed_request.introspection_mode;
            processed_request = rebuilt_request;
        }
        // Live queries must always see fresh data.
        let bypass_response_cache = is_live_query;
        let is_mutation = crate::cache::is_mutation(&processed_request.query);
        let operation_type = if is_mutation { "mutation" } else { "query" };
        // Clone identifying fields up front: the request is consumed
        // by execution but analytics still needs them afterwards.
        let analytics_query = processed_request.query.clone();
        let analytics_op_name = processed_request.operation_name.clone();
        let request_start = Instant::now();
        // "header:value" pairs for each configured Vary header; part
        // of the cache key so variants are cached separately.
        let vary_header_values = if let Some(ref cache) = self.response_cache {
            cache
                .config
                .vary_headers
                .iter()
                .map(|h| {
                    let val = headers.get(h).and_then(|v| v.to_str().ok()).unwrap_or("");
                    format!("{}:{}", h, val)
                })
                .collect::<Vec<_>>()
        } else {
            Vec::new()
        };
        // Cache lookup (queries only). Hits and stale hits are served
        // directly; both record a cache hit plus a query event.
        if !is_mutation && !bypass_response_cache {
            if let Some(ref cache) = self.response_cache {
                let cache_key = crate::cache::ResponseCache::generate_cache_key(
                    &processed_request.query,
                    Some(&serde_json::to_value(&processed_request.variables).unwrap_or_default()),
                    processed_request.operation_name.as_deref(),
                    &vary_header_values,
                );
                match cache.get(&cache_key).await {
                    CacheLookupResult::Hit(cached) => {
                        tracing::debug!("Response cache hit");
                        if let Some(ref analytics) = self.analytics {
                            analytics.record_cache_access(true);
                            analytics.record_query(
                                &analytics_query,
                                analytics_op_name.as_deref(),
                                operation_type,
                                request_start.elapsed(),
                                false,
                                None,
                            );
                        }
                        return self.cached_to_response(cached.data);
                    }
                    CacheLookupResult::Stale(cached) => {
                        // Stale entries are still served as-is here.
                        tracing::debug!("Response cache stale hit");
                        if let Some(ref analytics) = self.analytics {
                            analytics.record_cache_access(true);
                            analytics.record_query(
                                &analytics_query,
                                analytics_op_name.as_deref(),
                                operation_type,
                                request_start.elapsed(),
                                false,
                                None,
                            );
                        }
                        return self.cached_to_response(cached.data);
                    }
                    CacheLookupResult::Miss => {
                        if let Some(ref analytics) = self.analytics {
                            analytics.record_cache_access(false);
                        }
                    }
                }
            }
        }
        // Keep what is needed to populate the cache after execution
        // (the request itself is consumed below).
        let cache_query_info =
            if self.response_cache.is_some() && !is_mutation && !bypass_response_cache {
                Some((
                    processed_request.query.clone(),
                    serde_json::to_value(&processed_request.variables).unwrap_or_default(),
                    processed_request.operation_name.clone(),
                ))
            } else {
                None
            };
        match self
            .execute_with_middlewares(headers, processed_request)
            .await
        {
            Ok(resp) => {
                let duration = request_start.elapsed();
                let had_error = !resp.errors.is_empty();
                if let Some(ref analytics) = self.analytics {
                    // First error's code/message, defaulting to GRAPHQL_ERROR.
                    let error_details = if had_error {
                        resp.errors.first().map(|e| {
                            let code = e
                                .extensions
                                .as_ref()
                                .and_then(|ext| ext.get("code"))
                                .map(|c| c.to_string())
                                .unwrap_or_else(|| "GRAPHQL_ERROR".to_string());
                            (code, e.message.clone())
                        })
                    } else {
                        None
                    };
                    analytics.record_query(
                        &analytics_query,
                        analytics_op_name.as_deref(),
                        operation_type,
                        duration,
                        had_error,
                        error_details
                            .as_ref()
                            .map(|(c, m)| (c.as_str(), m.as_str())),
                    );
                }
                if is_mutation {
                    // Mutations invalidate affected response-cache
                    // entries and wipe the fast-path byte cache.
                    if let Some(ref cache) = self.response_cache {
                        if let Ok(resp_json) = serde_json::to_value(&resp) {
                            cache.invalidate_for_mutation(&resp_json).await;
                        }
                    }
                    if let Some(ref sharded) = self.sharded_cache {
                        sharded.clear();
                    }
                } else if let Some((query, vars, op_name)) = cache_query_info {
                    // Successful query: store the response in the
                    // response cache (typed + entity-indexed) and the
                    // fast-path byte cache (60s TTL).
                    if let Some(ref cache) = self.response_cache {
                        let cache_key = crate::cache::ResponseCache::generate_cache_key(
                            &query,
                            Some(&vars),
                            op_name.as_deref(),
                            &vary_header_values,
                        );
                        if let Ok(resp_json) = serde_json::to_value(&resp) {
                            let types = extract_types_from_response(&resp_json);
                            let entities = extract_entities_from_response(&resp_json);
                            cache
                                .put(cache_key.clone(), resp_json.clone(), types, entities)
                                .await;
                            cache.put_all_entities(&resp_json, None).await;
                            if let Some(ref sharded) = self.sharded_cache {
                                if let Ok(resp_bytes) = self.json_parser.serialize(&resp) {
                                    sharded.insert(&cache_key, resp_bytes, Duration::from_secs(60));
                                }
                            }
                        }
                    }
                }
                resp
            }
            Err(err) => {
                // Middleware/plugin failure: record it, notify the
                // error handler, and answer with a GraphQL error.
                let duration = request_start.elapsed();
                let gql_err: GraphQLError = err.into();
                if let Some(ref analytics) = self.analytics {
                    analytics.record_query(
                        &analytics_query,
                        analytics_op_name.as_deref(),
                        operation_type,
                        duration,
                        true,
                        Some(("INTERNAL_ERROR", &gql_err.message)),
                    );
                }
                if let Some(handler) = &self.error_handler {
                    handler(vec![gql_err.clone()]);
                }
                let server_err = ServerError::new(gql_err.message.clone(), None);
                async_graphql::Response::from_errors(vec![server_err])
            }
        }
    }
    /// Fast-path POST handler: SIMD JSON parse of the raw body, a
    /// byte-level cache probe, and fallback to the full
    /// [`Self::handle_http`] pipeline on a miss.
    pub async fn handle_fast(&self, headers: HeaderMap, body: Bytes) -> axum::response::Response {
        let start = Instant::now();
        let request_val = match self.json_parser.parse_bytes(&body) {
            Ok(v) => v,
            Err(err) => {
                // Malformed JSON: 400 with a pre-serialized error body.
                tracing::warn!("Failed to parse JSON with SIMD: {}", err);
                return (
                    axum::http::StatusCode::BAD_REQUEST,
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    self.response_templates
                        .errors
                        .get("PARSE_ERROR")
                        .cloned()
                        .unwrap_or_else(|| {
                            Bytes::from(r#"{"errors":[{"message":"Invalid JSON"}]}"#)
                        }),
                )
                    .into_response();
            }
        };
        let query = request_val["query"].as_str().unwrap_or("");
        let variables = &request_val["variables"];
        let operation_name = request_val["operationName"].as_str();
        // Serve straight from the sharded byte cache when possible —
        // no GraphQL execution or response re-serialization.
        if let Some(ref sharded) = self.sharded_cache {
            // Same vary-header key scheme as the response cache, so
            // the two caches agree on entry identity.
            let vary_header_values: Vec<String> = if let Some(ref cache) = self.response_cache {
                cache
                    .config
                    .vary_headers
                    .iter()
                    .map(|h| {
                        format!(
                            "{}:{}",
                            h,
                            headers.get(h).and_then(|v| v.to_str().ok()).unwrap_or("")
                        )
                    })
                    .collect()
            } else {
                Vec::new()
            };
            let cache_key = crate::cache::ResponseCache::generate_cache_key(
                query,
                Some(variables),
                operation_name,
                &vary_header_values,
            );
            if let Some(cached_bytes) = sharded.get(&cache_key) {
                self.perf_metrics
                    .record(start.elapsed().as_nanos() as u64, true);
                return (
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    cached_bytes,
                )
                    .into_response();
            }
        }
        // Cache miss: rebuild a typed request and run the full pipeline.
        let mut gql_req = async_graphql::Request::new(query);
        if !variables.is_null() {
            if let Ok(vars) = serde_json::from_value(variables.clone()) {
                gql_req = gql_req.variables(vars);
            }
        }
        if let Some(op) = operation_name {
            gql_req = gql_req.operation_name(op);
        }
        let resp = self.handle_http(headers, gql_req).await;
        self.perf_metrics
            .record(start.elapsed().as_nanos() as u64, false);
        GraphQLResponse::from(resp).into_response()
    }
fn cached_to_response(&self, data: serde_json::Value) -> async_graphql::Response {
match serde_json::from_value::<async_graphql::Response>(data.clone()) {
Ok(resp) => resp,
Err(_) => {
async_graphql::Response::new(
serde_json::from_value::<async_graphql::Value>(data)
.unwrap_or(async_graphql::Value::Null),
)
}
}
}
fn process_apq_request(
&self,
store: &SharedPersistedQueryStore,
mut request: async_graphql::Request,
) -> std::result::Result<async_graphql::Request, PersistedQueryError> {
let query = if request.query.is_empty() {
None
} else {
Some(request.query.as_str())
};
let extensions_value = if request.extensions.is_empty() {
None
} else {
serde_json::to_value(&request.extensions).ok()
};
match process_apq_request(store, query, extensions_value.as_ref())? {
Some(resolved_query) => {
request.query = resolved_query;
Ok(request)
}
None => {
Err(PersistedQueryError::NotFound)
}
}
}
fn apq_error_response(&self, err: PersistedQueryError) -> async_graphql::Response {
let error_extensions = err.to_extensions();
let code = error_extensions
.get("code")
.and_then(|v| v.as_str())
.unwrap_or("PERSISTED_QUERY_ERROR");
let mut server_err = ServerError::new(err.to_string(), None);
server_err.extensions = Some({
let mut ext = async_graphql::ErrorExtensionValues::default();
ext.set("code", code);
ext
});
async_graphql::Response::from_errors(vec![server_err])
}
    /// Consume the mux and build the axum router: GraphQL routes,
    /// body-size limit, CORS + security-header middleware, optional
    /// compression, and the optional health/metrics/analytics routes.
    pub fn into_router(self) -> Router {
        use axum::middleware as axum_mw;
        use axum::response::Response;
        // Snapshot everything needed after `self` moves into `state`.
        let health_checks_enabled = self.health_checks_enabled;
        let metrics_enabled = self.metrics_enabled;
        let analytics_enabled = self.analytics.is_some();
        let client_pool = self.client_pool.clone();
        let compression_config = self.compression_config.clone();
        let playground_enabled = self.playground_enabled;
        let cors_allow_origin = configured_cors_allow_origin();
        let content_security_policy = content_security_policy(playground_enabled);
        let state = Arc::new(self);
        let use_fast_path = state.high_perf_config.is_some();
        let router = Router::new();
        // POST /graphql routes to the fast handler only when the
        // high-performance config is present.
        let router = if use_fast_path {
            router.route("/graphql", post(handle_graphql_fast_or_defer))
        } else {
            router.route("/graphql", post(handle_graphql_post))
        };
        let router = if playground_enabled {
            router.route("/graphql", get(graphql_playground))
        } else {
            router
        };
        let mut router = router
            .route("/graphql/ws", get(handle_graphql_ws))
            .route("/graphql/live", get(handle_live_query_ws))
            .route("/graphql/defer", post(handle_graphql_defer))
            .layer(Extension(state.schema.executor()))
            .with_state(state.clone());
        // 1 MiB request-body cap.
        router = router.layer(axum::extract::DefaultBodyLimit::max(1024 * 1024));
        // CORS + security-header middleware.
        router = router.layer(axum_mw::from_fn(
            move |req: axum::http::Request<axum::body::Body>, next: axum_mw::Next| {
                let cors_allow_origin = cors_allow_origin.clone();
                async move {
                    // CORS preflight: answer 204 directly, with the
                    // CORS headers only when an origin is configured.
                    if req.method() == axum::http::Method::OPTIONS {
                        let mut response = Response::builder()
                            .status(axum::http::StatusCode::NO_CONTENT)
                            .body(axum::body::Body::empty())
                            .unwrap();
                        if let Some(origin) = &cors_allow_origin {
                            let headers = response.headers_mut();
                            headers.insert(
                                axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
                                origin.clone(),
                            );
                            headers.insert(
                                axum::http::header::ACCESS_CONTROL_ALLOW_METHODS,
                                axum::http::HeaderValue::from_static("GET, POST, OPTIONS"),
                            );
                            headers.insert(
                                axum::http::header::ACCESS_CONTROL_ALLOW_HEADERS,
                                axum::http::HeaderValue::from_static(
                                    "Content-Type, Authorization, X-Request-ID",
                                ),
                            );
                            headers.insert(
                                axum::http::header::ACCESS_CONTROL_MAX_AGE,
                                axum::http::HeaderValue::from_static("86400"),
                            );
                        }
                        return response;
                    }
                    // All other responses get the security-header set.
                    let mut response = next.run(req).await;
                    let headers = response.headers_mut();
                    headers.insert(
                        axum::http::header::X_CONTENT_TYPE_OPTIONS,
                        axum::http::HeaderValue::from_static("nosniff"),
                    );
                    headers.insert(
                        axum::http::header::X_FRAME_OPTIONS,
                        axum::http::HeaderValue::from_static("DENY"),
                    );
                    headers.insert(
                        axum::http::header::STRICT_TRANSPORT_SECURITY,
                        axum::http::HeaderValue::from_static("max-age=31536000; includeSubDomains"),
                    );
                    headers.insert(
                        axum::http::header::CACHE_CONTROL,
                        axum::http::HeaderValue::from_static("no-store, no-cache, must-revalidate"),
                    );
                    headers.insert(
                        axum::http::header::HeaderName::from_static("x-xss-protection"),
                        axum::http::HeaderValue::from_static("1; mode=block"),
                    );
                    headers.insert(
                        axum::http::header::CONTENT_SECURITY_POLICY,
                        axum::http::HeaderValue::from_static(content_security_policy),
                    );
                    headers.insert(
                        axum::http::header::REFERRER_POLICY,
                        axum::http::HeaderValue::from_static("strict-origin-when-cross-origin"),
                    );
                    headers.insert(
                        axum::http::header::HeaderName::from_static("permissions-policy"),
                        axum::http::HeaderValue::from_static(
                            "camera=(), microphone=(), geolocation=(), browsing-topics=(), payment=()",
                        ),
                    );
                    headers.insert(
                        axum::http::header::HeaderName::from_static("x-dns-prefetch-control"),
                        axum::http::HeaderValue::from_static("off"),
                    );
                    headers.insert(
                        axum::http::header::HeaderName::from_static("cross-origin-opener-policy"),
                        axum::http::HeaderValue::from_static("same-origin"),
                    );
                    headers.insert(
                        axum::http::header::HeaderName::from_static("cross-origin-embedder-policy"),
                        axum::http::HeaderValue::from_static("require-corp"),
                    );
                    headers.insert(
                        axum::http::header::HeaderName::from_static("cross-origin-resource-policy"),
                        axum::http::HeaderValue::from_static("same-origin"),
                    );
                    if let Some(origin) = cors_allow_origin {
                        headers.insert(axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
                    }
                    response
                }
            },
        ));
        if let Some(ref config) = compression_config {
            if config.enabled {
                router = router.layer(create_compression_layer(config));
                if config.lz4_enabled() || config.gbp_lz4_enabled() {
                    router = router.layer(axum::middleware::from_fn(
                        crate::lz4_compression::lz4_compression_middleware,
                    ));
                }
            }
        }
        if health_checks_enabled {
            let health_state = Arc::new(HealthState::new(client_pool.unwrap_or_default()));
            router = router
                .route("/health", get(health_handler))
                .route("/ready", get(readiness_handler).with_state(health_state));
        }
        if metrics_enabled {
            router = router.route("/metrics", get(metrics_handler));
        }
        if analytics_enabled {
            router = router
                .route("/analytics", get(analytics_dashboard_handler))
                .route(
                    "/analytics/api",
                    get(analytics_api_handler).with_state(state.clone()),
                )
                .route(
                    "/analytics/reset",
                    post(analytics_reset_handler).with_state(state),
                );
        }
        router
    }
}
// Field-by-field clone; every field is cheap to clone (Arc, Option,
// or a plain flag). NOTE(review): kept as a manual impl rather than
// `#[derive(Clone)]` — presumably a field type blocks the derive;
// confirm before converting.
impl Clone for ServeMux {
    fn clone(&self) -> Self {
        Self {
            schema: self.schema.clone(),
            middlewares: self.middlewares.clone(),
            error_handler: self.error_handler.clone(),
            client_pool: self.client_pool.clone(),
            health_checks_enabled: self.health_checks_enabled,
            metrics_enabled: self.metrics_enabled,
            playground_enabled: self.playground_enabled,
            apq_store: self.apq_store.clone(),
            circuit_breaker: self.circuit_breaker.clone(),
            response_cache: self.response_cache.clone(),
            compression_config: self.compression_config.clone(),
            query_whitelist: self.query_whitelist.clone(),
            analytics: self.analytics.clone(),
            request_collapsing: self.request_collapsing.clone(),
            high_perf_config: self.high_perf_config.clone(),
            json_parser: self.json_parser.clone(),
            sharded_cache: self.sharded_cache.clone(),
            perf_metrics: self.perf_metrics.clone(),
            response_templates: self.response_templates.clone(),
            defer_config: self.defer_config.clone(),
            plugins: self.plugins.clone(),
        }
    }
}
/// Collect every `__typename` value appearing anywhere in a response,
/// for type-based cache invalidation.
fn extract_types_from_response(response: &serde_json::Value) -> HashSet<String> {
    let mut collected = HashSet::new();
    extract_types_recursive(response, &mut collected);
    collected
}
fn extract_types_recursive(value: &serde_json::Value, types: &mut HashSet<String>) {
match value {
serde_json::Value::Object(map) => {
if let Some(serde_json::Value::String(type_name)) = map.get("__typename") {
types.insert(type_name.clone());
}
for v in map.values() {
extract_types_recursive(v, types);
}
}
serde_json::Value::Array(arr) => {
for item in arr {
extract_types_recursive(item, types);
}
}
_ => {}
}
}
/// Collect every `Type#id` entity key appearing anywhere in a
/// response, for entity-based cache invalidation.
fn extract_entities_from_response(response: &serde_json::Value) -> HashSet<String> {
    let mut collected = HashSet::new();
    extract_entities_recursive(response, &mut collected);
    collected
}
fn extract_entities_recursive(value: &serde_json::Value, entities: &mut HashSet<String>) {
match value {
serde_json::Value::Object(map) => {
let type_name = map.get("__typename").and_then(|t| t.as_str());
let id = map
.get("id")
.and_then(|i| i.as_str())
.or_else(|| map.get("_id").and_then(|i| i.as_str()));
if let (Some(tn), Some(id_val)) = (type_name, id) {
entities.insert(format!("{}#{}", tn, id_val));
}
for v in map.values() {
extract_entities_recursive(v, entities);
}
}
serde_json::Value::Array(arr) => {
for item in arr {
extract_entities_recursive(item, entities);
}
}
_ => {}
}
}
/// Standard POST /graphql handler. When the client accepts
/// `multipart/mixed`, @defer is enabled, and the query uses @defer,
/// the response is streamed incrementally; otherwise it falls through
/// to a single JSON response via `handle_http`.
async fn handle_graphql_post(
    State(mux): State<Arc<ServeMux>>,
    headers: HeaderMap,
    request: GraphQLRequest,
) -> axum::response::Response {
    let gql_request = request.into_inner();
    let accepts_multipart = headers
        .get("accept")
        .and_then(|v| v.to_str().ok())
        .is_some_and(|v| v.contains("multipart/mixed"));
    if accepts_multipart && has_defer_directive(&gql_request.query) {
        if let Some(config) = mux.defer_config().cloned() {
            if config.enabled {
                let query = gql_request.query.clone();
                let fragments = extract_deferred_fragments(&query);
                // Too many deferred fragments: fall through to the
                // plain single-response path below.
                if fragments.len() <= config.max_deferred_fragments {
                    // Execute the whole query eagerly (directives
                    // stripped); the deferred execution then slices
                    // the full result into multipart parts.
                    let stripped_query = strip_defer_directives(&query);
                    let mut eager_request = async_graphql::Request::new(stripped_query)
                        .variables(gql_request.variables);
                    if let Some(op_name) = gql_request.operation_name {
                        eager_request = eager_request.operation_name(op_name);
                    }
                    let full_response = mux.handle_http(headers, eager_request).await;
                    let full_json = serde_json::to_value(&full_response).unwrap_or_else(|_| {
                        serde_json::json!({"data": null, "errors": [{"message": "Serialization failed"}]})
                    });
                    let boundary = config.multipart_boundary.clone();
                    let (exec, mut rx) = DeferredExecution::new(config, fragments);
                    // Producer task feeds parts through the channel.
                    tokio::spawn(async move {
                        if let Err(e) = exec.execute(full_json).await {
                            tracing::warn!(error = %e, "Deferred execution failed");
                        }
                    });
                    // Format each part as a multipart chunk; the
                    // stream ends after the part with has_next=false.
                    let stream = async_stream::stream! {
                        while let Some(part) = rx.recv().await {
                            match part {
                                DeferredPart::Initial(payload) => {
                                    yield Ok::<_, std::convert::Infallible>(
                                        format_initial_part(&payload, &boundary)
                                    );
                                }
                                DeferredPart::Subsequent(payload) => {
                                    let is_last = !payload.has_next;
                                    yield Ok::<_, std::convert::Infallible>(
                                        format_subsequent_part(&payload, &boundary)
                                    );
                                    if is_last {
                                        break;
                                    }
                                }
                            }
                        }
                    };
                    let body = axum::body::Body::from_stream(stream);
                    return axum::response::Response::builder()
                        .header("Content-Type", MULTIPART_CONTENT_TYPE)
                        .header("Transfer-Encoding", "chunked")
                        .header("Cache-Control", "no-cache")
                        .body(body)
                        .unwrap_or_else(|_| {
                            axum::response::Response::builder()
                                .status(500)
                                .body(axum::body::Body::from("Internal Server Error"))
                                .unwrap()
                        });
                }
            }
        }
    }
    // Non-deferred path: single JSON response.
    GraphQLResponse::from(mux.handle_http(headers, gql_request).await).into_response()
}
/// Raw-body POST handler routing directly to the fast path
/// (`ServeMux::handle_fast`); currently unused in routing (the fast
/// route uses `handle_graphql_fast_or_defer`).
#[allow(dead_code)]
async fn handle_graphql_fast(
    State(mux): State<Arc<ServeMux>>,
    headers: HeaderMap,
    body: Bytes,
) -> impl IntoResponse {
    mux.handle_fast(headers, body).await
}
async fn handle_graphql_fast_or_defer(
State(mux): State<Arc<ServeMux>>,
headers: HeaderMap,
body: Bytes,
) -> axum::response::Response {
let accepts_multipart = headers
.get("accept")
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.contains("multipart/mixed"));
if accepts_multipart && mux.defer_config().is_some_and(|c| c.enabled) {
if let Ok(gql_request) = serde_json::from_slice::<async_graphql::Request>(&body) {
if has_defer_directive(&gql_request.query) {
let config = mux.defer_config().unwrap().clone();
let query = gql_request.query.clone();
let fragments = extract_deferred_fragments(&query);
if fragments.len() <= config.max_deferred_fragments {
let stripped_query = strip_defer_directives(&query);
let mut eager_request = async_graphql::Request::new(stripped_query)
.variables(gql_request.variables);
if let Some(op_name) = gql_request.operation_name {
eager_request = eager_request.operation_name(op_name);
}
let full_response = mux.handle_http(headers, eager_request).await;
let full_json = serde_json::to_value(&full_response).unwrap_or_else(|_| {
serde_json::json!({"data": null, "errors": [{"message": "Serialization failed"}]})
});
let boundary = config.multipart_boundary.clone();
let (exec, mut rx) = DeferredExecution::new(config, fragments);
tokio::spawn(async move {
if let Err(e) = exec.execute(full_json).await {
tracing::warn!(error = %e, "Deferred execution failed");
}
});
let stream = async_stream::stream! {
while let Some(part) = rx.recv().await {
match part {
DeferredPart::Initial(payload) => {
yield Ok::<_, std::convert::Infallible>(
format_initial_part(&payload, &boundary)
);
}
DeferredPart::Subsequent(payload) => {
let is_last = !payload.has_next;
yield Ok::<_, std::convert::Infallible>(
format_subsequent_part(&payload, &boundary)
);
if is_last {
break;
}
}
}
}
};
let body = axum::body::Body::from_stream(stream);
return axum::response::Response::builder()
.header("Content-Type", MULTIPART_CONTENT_TYPE)
.header("Transfer-Encoding", "chunked")
.header("Cache-Control", "no-cache")
.body(body)
.unwrap_or_else(|_| {
axum::response::Response::builder()
.status(500)
.body(axum::body::Body::from("Internal Server Error"))
.unwrap()
});
}
}
}
}
mux.handle_fast(headers, body).await.into_response()
}
/// GET /graphql/ws — upgrades the connection and serves GraphQL subscriptions.
///
/// The HTTP handshake headers are captured and later merged with any headers
/// supplied in the `connection_init` payload; the merged set is stored in the
/// per-connection `Data` as `WebSocketSessionHeaders`.
async fn handle_graphql_ws(
    protocol: GraphQLProtocol,
    ws: WebSocketUpgrade,
    State(mux): State<Arc<ServeMux>>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let executor = SubscriptionExecutor::new(mux);
    let handshake_headers = headers.clone();
    ws.protocols(async_graphql::http::ALL_WEBSOCKET_PROTOCOLS)
        .on_upgrade(move |stream| async move {
            GraphQLWebSocket::new(stream, executor, protocol)
                .on_connection_init(move |payload| {
                    // Clone per init call so the returned future owns its copy.
                    let base = handshake_headers.clone();
                    async move {
                        let merged = merge_ws_connection_init_headers(&base, &payload);
                        let mut data = Data::default();
                        data.insert(WebSocketSessionHeaders { headers: merged });
                        Ok(data)
                    }
                })
                .serve()
                .await;
        })
}
/// POST /graphql handler with first-class `@defer` support.
///
/// Queries without `@defer` (or with defer disabled) are answered as a normal
/// single JSON response. Deferred queries are executed eagerly once with the
/// directives stripped, then streamed back as `multipart/mixed`
/// incremental-delivery parts produced by a background `DeferredExecution`.
async fn handle_graphql_defer(
State(mux): State<Arc<ServeMux>>,
headers: HeaderMap,
request: GraphQLRequest,
) -> axum::response::Response {
let gql_request = request.into_inner();
let query = gql_request.query.clone();
let defer_config = mux.defer_config().cloned();
let is_deferred = has_defer_directive(&query);
// Fast path: no @defer present, or the feature is disabled/unconfigured.
if !is_deferred || defer_config.as_ref().is_none_or(|c| !c.enabled) {
let resp = mux.handle_http(headers, gql_request).await;
return GraphQLResponse::from(resp).into_response();
}
// Safe: the early return above guarantees Some(enabled config) here.
let config = defer_config.unwrap();
let fragments = extract_deferred_fragments(&query);
// Unlike the fast-path handler, exceeding the limit is a hard error here
// rather than a silent fallback.
if fragments.len() > config.max_deferred_fragments {
let err = ServerError::new(
format!(
"Too many @defer fragments ({}/{})",
fragments.len(),
config.max_deferred_fragments
),
None,
);
let resp = async_graphql::Response::from_errors(vec![err]);
return GraphQLResponse::from(resp).into_response();
}
// Execute the whole query once with @defer stripped; the deferred
// execution later carves the fragment payloads out of the full result.
let stripped_query = strip_defer_directives(&query);
let mut eager_request =
async_graphql::Request::new(stripped_query).variables(gql_request.variables);
if let Some(op_name) = gql_request.operation_name {
eager_request = eager_request.operation_name(op_name);
}
tracing::debug!(
deferred_fragments = fragments.len(),
"Executing @defer query with eager resolution"
);
let full_response = mux.handle_http(headers, eager_request).await;
let full_json = serde_json::to_value(&full_response).unwrap_or_else(
|_| serde_json::json!({"data": null, "errors": [{"message": "Serialization failed"}]}),
);
let boundary = config.multipart_boundary.clone();
let (exec, mut rx) = DeferredExecution::new(config, fragments);
// Producer task: emits Initial then Subsequent parts through `rx`.
tokio::spawn(async move {
if let Err(e) = exec.execute(full_json).await {
tracing::warn!(error = %e, "Deferred execution failed");
}
});
// Consumer stream: formats each part as a multipart chunk; the part with
// has_next == false terminates the stream.
let stream = async_stream::stream! {
while let Some(part) = rx.recv().await {
match part {
DeferredPart::Initial(payload) => {
yield Ok::<_, std::convert::Infallible>(
format_initial_part(&payload, &boundary)
);
}
DeferredPart::Subsequent(payload) => {
let is_last = !payload.has_next;
yield Ok::<_, std::convert::Infallible>(
format_subsequent_part(&payload, &boundary)
);
if is_last {
break;
}
}
}
}
};
let body = axum::body::Body::from_stream(stream);
axum::response::Response::builder()
.header("Content-Type", MULTIPART_CONTENT_TYPE)
.header("Transfer-Encoding", "chunked")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(body)
// Builder only fails on invalid header values; degrade to a plain 500.
.unwrap_or_else(|_| {
axum::response::Response::builder()
.status(500)
.body(axum::body::Body::from("Internal Server Error"))
.unwrap()
})
}
/// GET handler serving the GraphQL Playground IDE, wired to this gateway's
/// `/graphql` endpoint and `/graphql/ws` subscription endpoint.
async fn graphql_playground() -> impl IntoResponse {
    let config = async_graphql::http::GraphQLPlaygroundConfig::new("/graphql")
        .subscription_endpoint("/graphql/ws");
    Html(async_graphql::http::playground_source(config))
}
/// GET /metrics — renders gateway metrics as plain text, gated behind the
/// `METRICS_API_KEY` environment variable.
///
/// Returns 403 when no key is configured, 401 when the `x-metrics-key` header
/// does not match (compared in constant time), and 200 with the metrics body
/// otherwise.
async fn metrics_handler(headers: HeaderMap) -> axum::response::Response {
    // The endpoint is opt-in: without a configured key it is fully disabled.
    let required_key = match std::env::var("METRICS_API_KEY") {
        Ok(key) if !key.is_empty() => key,
        _ => {
            return (
                axum::http::StatusCode::FORBIDDEN,
                "Metrics endpoint requires METRICS_API_KEY to be configured",
            )
                .into_response();
        }
    };
    let provided = headers
        .get("x-metrics-key")
        .and_then(|value| value.to_str().ok())
        .unwrap_or_default();
    if !constant_time_eq(provided.as_bytes(), required_key.as_bytes()) {
        return (
            axum::http::StatusCode::UNAUTHORIZED,
            "Unauthorized: Valid x-metrics-key header required",
        )
            .into_response();
    }
    let body = GatewayMetrics::global().render();
    (
        [(
            axum::http::header::CONTENT_TYPE,
            "text/plain; charset=utf-8",
        )],
        body,
    )
        .into_response()
}
/// GET /analytics — serves the analytics dashboard HTML, gated behind the
/// `ANALYTICS_API_KEY` environment variable.
///
/// Returns 403 when no key is configured and 401 when the `x-analytics-key`
/// header does not match (compared in constant time).
async fn analytics_dashboard_handler(headers: HeaderMap) -> axum::response::Response {
    // Opt-in endpoint: disabled entirely unless a non-empty key is set.
    let required_key = match std::env::var("ANALYTICS_API_KEY") {
        Ok(key) if !key.is_empty() => key,
        _ => {
            return (
                axum::http::StatusCode::FORBIDDEN,
                "Analytics endpoint requires ANALYTICS_API_KEY to be configured",
            )
                .into_response();
        }
    };
    let provided = headers
        .get("x-analytics-key")
        .and_then(|value| value.to_str().ok())
        .unwrap_or_default();
    if constant_time_eq(provided.as_bytes(), required_key.as_bytes()) {
        Html(crate::analytics::analytics_dashboard_html()).into_response()
    } else {
        (
            axum::http::StatusCode::UNAUTHORIZED,
            "Unauthorized: Valid x-analytics-key header required",
        )
            .into_response()
    }
}
async fn analytics_api_handler(
State(mux): State<Arc<ServeMux>>,
headers: HeaderMap,
) -> impl IntoResponse {
let required_key = match std::env::var("ANALYTICS_API_KEY") {
Ok(k) if !k.is_empty() => k,
_ => {
return Json(serde_json::json!({
"error": "Forbidden",
"message": "Analytics endpoint requires ANALYTICS_API_KEY to be configured"
}));
}
};
let provided_key = headers
.get("x-analytics-key")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !constant_time_eq(provided_key.as_bytes(), required_key.as_bytes()) {
return Json(serde_json::json!({
"error": "Unauthorized",
"message": "Valid x-analytics-key header required"
}));
}
if let Some(ref analytics) = mux.analytics {
let snapshot = analytics.get_snapshot();
Json(
serde_json::to_value(snapshot)
.unwrap_or_else(|_| serde_json::json!({"error": "Failed to serialize analytics"})),
)
} else {
Json(serde_json::json!({"error": "Analytics not enabled"}))
}
}
async fn analytics_reset_handler(
State(mux): State<Arc<ServeMux>>,
headers: HeaderMap,
) -> impl IntoResponse {
let required_key = match std::env::var("ANALYTICS_API_KEY") {
Ok(k) if !k.is_empty() => k,
_ => {
return Json(serde_json::json!({
"error": "Forbidden",
"message": "Analytics endpoint requires ANALYTICS_API_KEY to be configured"
}));
}
};
let provided_key = headers
.get("x-analytics-key")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !constant_time_eq(provided_key.as_bytes(), required_key.as_bytes()) {
return Json(serde_json::json!({
"error": "Unauthorized",
"message": "Valid x-analytics-key header required"
}));
}
if let Some(ref analytics) = mux.analytics {
analytics.reset();
Json(serde_json::json!({"status": "ok", "message": "Analytics reset successfully"}))
} else {
Json(serde_json::json!({"error": "Analytics not enabled"}))
}
}
/// Compares two byte slices for equality in time independent of where they
/// first differ, so key comparison does not leak matching-prefix length via
/// timing. The length check may short-circuit, which reveals only the
/// lengths, never the contents.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Accumulate the XOR of every byte pair; the OR is zero iff all pairs match.
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Builds a std TCP listener tuned for high-throughput accept loops:
/// SO_REUSEADDR (plus SO_REUSEPORT on Linux/macOS), TCP_NODELAY,
/// non-blocking mode, and a 4096-deep accept backlog.
///
/// Socket options are set before `bind` — keep this ordering, as the reuse
/// options must be in place before the address is bound.
///
/// # Errors
/// Returns `InvalidInput` for an unparsable address, or any I/O error from
/// socket creation, option setting, bind, or listen.
pub fn build_tcp_listener_tuned(addr: &str) -> std::io::Result<std::net::TcpListener> {
use std::net::SocketAddr;
let addr: SocketAddr = addr.parse().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("invalid addr: {e}"),
)
})?;
// Match the socket domain to the parsed address family.
let socket = socket2::Socket::new(
if addr.is_ipv6() {
socket2::Domain::IPV6
} else {
socket2::Domain::IPV4
},
socket2::Type::STREAM,
None,
)?;
socket.set_reuse_address(true)?;
// SO_REUSEPORT lets several workers bind the same port (see
// serve_reuseport); the cfg attribute removes the call on platforms
// without it.
#[cfg(any(target_os = "linux", target_os = "macos"))]
socket.set_reuse_port(true)?;
socket.set_nodelay(true)?;
socket.set_nonblocking(true)?;
socket.bind(&addr.into())?;
socket.listen(4096)?;
Ok(socket.into())
}
/// Spawns `num_workers` independent accept loops, each bound to the same
/// `addr` via SO_REUSEPORT, and serves a clone of `app` on every listener.
/// The kernel load-balances incoming connections across the workers.
///
/// Resolves only after every worker task finishes; individual worker panics
/// or server errors are logged/ignored rather than propagated.
///
/// # Errors
/// Returns `Error::Internal` if any worker's listener cannot be bound or
/// registered with tokio. Workers spawned before the failure keep running.
pub async fn serve_reuseport(
addr: &str,
num_workers: usize,
app: axum::Router,
) -> crate::error::Result<()> {
let mut handles = Vec::with_capacity(num_workers);
let addr_owned = addr.to_string();
for i in 0..num_workers {
let std_listener = build_tcp_listener_tuned(addr).map_err(|e| {
crate::error::Error::Internal(format!(
"Failed to bind worker {i} listener on {addr}: {e}"
))
})?;
// Redundant with build_tcp_listener_tuned, but harmless: tokio requires
// the fd to be non-blocking before from_std.
std_listener
.set_nonblocking(true)
.map_err(|e| crate::error::Error::Internal(format!("set_nonblocking failed: {e}")))?;
let listener = tokio::net::TcpListener::from_std(std_listener)
.map_err(|e| crate::error::Error::Internal(format!("from_std failed: {e}")))?;
let app = app.clone();
let addr_clone = addr_owned.clone();
let handle = tokio::spawn(async move {
tracing::info!("Worker {i} accepting on {addr_clone} (SO_REUSEPORT)");
if let Err(e) = axum::serve(listener, app).await {
tracing::error!("Worker {i} server error: {e}");
}
});
handles.push(handle);
}
// Wait for all workers; join errors (panics/cancellation) are discarded.
for handle in handles {
let _ = handle.await;
}
Ok(())
}
/// GET handler that upgrades the connection to a WebSocket speaking the
/// `graphql-transport-ws` protocol and hands it to the live-query loop.
async fn handle_live_query_ws(
    ws: WebSocketUpgrade,
    State(mux): State<Arc<ServeMux>>,
) -> impl IntoResponse {
    let upgrade = ws.protocols(["graphql-transport-ws"]);
    upgrade.on_upgrade(move |socket| handle_live_socket(socket, mux))
}
/// Per-connection loop for live (@live) query WebSockets.
///
/// Implements a graphql-transport-ws-style protocol: `connection_init` /
/// `connection_ack`, `ping`/`pong`, `subscribe`, and `complete`. Live queries
/// (those carrying the @live directive) are registered with trigger patterns
/// and re-executed whenever a matching invalidation event arrives; plain
/// queries get a single `next` followed by `complete`.
///
/// Two background tasks run alongside the receive loop:
/// - a forwarder that serializes outgoing messages (optionally GBP+LZ4
///   compressed) onto the socket, and
/// - an invalidation listener that re-executes matching live subscriptions.
/// Both are aborted when the receive loop ends.
async fn handle_live_socket(socket: WebSocket, mux: Arc<ServeMux>) {
use std::collections::HashMap;
use tokio::sync::mpsc;
let (sender, mut receiver) = socket.split();
// Incoming client frame shape.
#[derive(serde::Deserialize)]
struct WsMessage {
#[serde(rename = "type")]
msg_type: String,
id: Option<String>,
payload: Option<serde_json::Value>,
}
// Outgoing frame shape; `id`/`payload` are omitted when absent.
#[derive(serde::Serialize, Clone)]
struct WsResponse {
#[serde(rename = "type")]
msg_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
payload: Option<serde_json::Value>,
}
// State kept per registered live subscription so it can be re-executed.
#[derive(Clone)]
struct LiveSubscription {
id: String,
query: String,
variables: Option<serde_json::Value>,
operation_name: Option<String>,
triggers: Vec<String>,
is_live: bool,
}
// Set to true when the client requests "gbp"/"gbp-lz4" in connection_init.
let use_gbp_compression = Arc::new(parking_lot::RwLock::new(false));
let mut connection_initialized = false;
let active_subscriptions: Arc<parking_lot::RwLock<HashMap<String, LiveSubscription>>> =
Arc::new(parking_lot::RwLock::new(HashMap::new()));
let (ws_tx, mut ws_rx) = mpsc::channel::<WsResponse>(100);
let live_query_store = crate::live_query::create_live_query_store();
let mut invalidation_rx = live_query_store.subscribe_invalidations();
let _ws_tx_clone = ws_tx.clone();
let sender = Arc::new(tokio::sync::Mutex::new(sender));
let sender_clone = sender.clone();
let use_compression_clone = use_gbp_compression.clone();
// Forwarder task: drains ws_rx and writes frames to the socket. With GBP
// compression on, each payload-bearing message becomes two frames: a JSON
// text envelope ({type, id, compressed: true}) followed by a binary frame
// with the LZ4-compressed payload. Any send failure ends the task.
let forward_task = tokio::spawn(async move {
while let Some(msg) = ws_rx.recv().await {
let mut sender = sender_clone.lock().await;
if *use_compression_clone.read() {
if let Some(payload) = &msg.payload {
match crate::gbp::GbpEncoder::new().encode_lz4(payload) {
Ok(compressed) => {
let envelope = serde_json::json!({
"type": msg.msg_type,
"id": msg.id,
"compressed": true
});
let envelope_json = serde_json::to_string(&envelope).unwrap();
if sender
.send(Message::Text(envelope_json.into()))
.await
.is_err()
{
break;
}
if sender
.send(Message::Binary(compressed.into()))
.await
.is_err()
{
break;
}
}
Err(e) => {
// Compression failure degrades to plain JSON for this message.
tracing::warn!(
"Failed to compress with GBP, falling back to JSON: {}",
e
);
let json = serde_json::to_string(&msg).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
} else {
// Compression enabled but this message has no payload
// (e.g. pong/complete): send it uncompressed.
let json = serde_json::to_string(&msg).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
} else {
let json = serde_json::to_string(&msg).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
});
let subscriptions_clone = active_subscriptions.clone();
let mux_clone = mux.clone();
let ws_tx_for_invalidation = ws_tx.clone();
// Invalidation task: on each "<Type>.<action>" event, re-executes every
// live subscription whose trigger list matches (exact, "Type.*", "*.action",
// or "*.*") and pushes the fresh result as a "next" message.
let invalidation_task = tokio::spawn(async move {
loop {
match invalidation_rx.recv().await {
Ok(event) => {
let trigger_pattern = format!("{}.{}", event.type_name, event.action);
// Clone the matches out so the read lock is not held across awaits.
let matching_subs: Vec<LiveSubscription> = {
let subs = subscriptions_clone.read();
subs.values()
.filter(|sub| {
sub.is_live
&& sub.triggers.iter().any(|t| {
t == &trigger_pattern
|| t == &format!("{}.*", event.type_name)
|| t == &format!("*.{}", event.action)
|| t == "*.*"
})
})
.cloned()
.collect()
};
for sub in matching_subs {
tracing::info!(
subscription_id = %sub.id,
trigger = %trigger_pattern,
"Re-executing live query due to invalidation"
);
let mut gql_request = async_graphql::Request::new(&sub.query);
if let Some(vars) = &sub.variables {
if let Ok(variables) = serde_json::from_value(vars.clone()) {
gql_request = gql_request.variables(variables);
}
}
if let Some(op_name) = &sub.operation_name {
gql_request = gql_request.operation_name(op_name);
}
if sub.is_live {
gql_request = gql_request.data(LiveQueryRequestMarker);
}
// Re-executed without the original handshake headers.
// NOTE(review): middleware that depends on auth headers will
// see an empty HeaderMap here — confirm this is intended.
let response = mux_clone.handle_http(HeaderMap::new(), gql_request).await;
let response_json = serde_json::to_value(&response).unwrap_or_default();
let update = WsResponse {
msg_type: "next".to_string(),
id: Some(sub.id.clone()),
payload: Some(response_json),
};
if ws_tx_for_invalidation.send(update).await.is_err() {
break;
}
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
// Missed some events under load; keep listening for new ones.
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
});
// Main receive loop: dispatch on the protocol message type.
while let Some(msg) = receiver.next().await {
let msg = match msg {
Ok(Message::Text(text)) => text,
Ok(Message::Close(_)) => break,
_ => continue,
};
let parsed: WsMessage = match serde_json::from_str(&msg) {
Ok(m) => m,
Err(e) => {
tracing::warn!("Failed to parse WebSocket message: {}", e);
continue;
}
};
match parsed.msg_type.as_str() {
"connection_init" => {
connection_initialized = true;
// Opt into GBP compression if the client asked for it.
if let Some(payload) = &parsed.payload {
if let Some(compression) = payload.get("compression").and_then(|c| c.as_str()) {
if compression == "gbp-lz4" || compression == "gbp" {
*use_gbp_compression.write() = true;
tracing::info!("GBP compression enabled for live query connection");
}
}
}
// Echo the negotiated compression back in the ack payload.
let mut ack_payload = serde_json::json!({});
if *use_gbp_compression.read() {
ack_payload["compression"] = serde_json::json!("gbp-lz4");
ack_payload["compressionInfo"] = serde_json::json!({
"algorithm": "GBP Ultra + LZ4",
"expectedReduction": "90-99%",
"format": "binary"
});
}
let ack = WsResponse {
msg_type: "connection_ack".to_string(),
id: None,
payload: if ack_payload.as_object().unwrap().is_empty() {
None
} else {
Some(ack_payload)
},
};
if ws_tx.send(ack).await.is_err() {
break;
}
}
"ping" => {
let pong = WsResponse {
msg_type: "pong".to_string(),
id: None,
payload: None,
};
let _ = ws_tx.send(pong).await;
}
"subscribe" => {
// Protocol requires connection_init first.
if !connection_initialized {
tracing::warn!("Received subscribe before connection_init");
continue;
}
// Bound the per-connection subscription count to limit memory.
const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 100;
if active_subscriptions.read().len() >= MAX_SUBSCRIPTIONS_PER_CONNECTION {
tracing::warn!(
max = MAX_SUBSCRIPTIONS_PER_CONNECTION,
"WebSocket subscription limit reached, rejecting new subscription"
);
let err_msg = WsResponse {
msg_type: "error".to_string(),
id: parsed.id.clone(),
payload: Some(serde_json::json!({
"message": format!(
"Subscription limit ({}) exceeded for this connection",
MAX_SUBSCRIPTIONS_PER_CONNECTION
)
})),
};
let _ = ws_tx.send(err_msg).await;
continue;
}
let id = parsed.id.clone().unwrap_or_default();
let query = parsed
.payload
.as_ref()
.and_then(|p| p.get("query"))
.and_then(|q| q.as_str())
.unwrap_or("");
let is_live = crate::live_query::has_live_directive(query);
// Live queries arrive as "subscription" operations; strip @live and
// rewrite the first "subscription" keyword to "query" so they can be
// executed as plain queries.
let clean_query = if is_live {
let stripped = crate::live_query::strip_live_directive(query);
if stripped.trim_start().starts_with("subscription") {
stripped.replacen("subscription", "query", 1)
} else {
stripped
}
} else {
query.to_string()
};
let variables = parsed
.payload
.as_ref()
.and_then(|p| p.get("variables"))
.cloned();
let operation_name = parsed
.payload
.as_ref()
.and_then(|p| p.get("operationName"))
.and_then(|n| n.as_str())
.map(|s| s.to_string());
tracing::info!(
subscription_id = %id,
is_live = is_live,
"Live query subscription started"
);
let mut gql_request = async_graphql::Request::new(&clean_query);
if let Some(vars) = &variables {
if let Ok(v) = serde_json::from_value(vars.clone()) {
gql_request = gql_request.variables(v);
}
}
if let Some(ref op_name) = operation_name {
gql_request = gql_request.operation_name(op_name);
}
if is_live {
gql_request = gql_request.data(LiveQueryRequestMarker);
}
// Execute once immediately and push the initial result.
let response = mux.handle_http(HeaderMap::new(), gql_request).await;
let response_json = serde_json::to_value(&response).unwrap_or_default();
let next_msg = WsResponse {
msg_type: "next".to_string(),
id: Some(id.clone()),
payload: Some(response_json),
};
if ws_tx.send(next_msg).await.is_err() {
break;
}
if is_live {
// Collect trigger patterns from the schema's live-query configs
// whose operation name appears in the query text.
let configs = mux.schema.live_query_configs();
let mut triggers = std::collections::HashSet::new();
if !configs.is_empty() {
for (op_name, config) in configs {
// Manual whole-word substring search: op_name must occur in
// the query with non-identifier characters (or string edges)
// on both sides.
let matched = {
let mut found = false;
let query_bytes = clean_query.as_bytes();
let name_bytes = op_name.as_bytes();
let qlen = query_bytes.len();
let nlen = name_bytes.len();
if nlen > 0 && qlen >= nlen {
for i in 0..=(qlen - nlen) {
if &query_bytes[i..i + nlen] == name_bytes {
// Precedence: `&&` binds tighter than `||`, so this
// reads "at string edge, or (not alphanumeric and
// not underscore)".
let left_ok = i == 0
|| !query_bytes[i - 1].is_ascii_alphanumeric()
&& query_bytes[i - 1] != b'_';
let right_ok = i + nlen == qlen
|| !query_bytes[i + nlen].is_ascii_alphanumeric()
&& query_bytes[i + nlen] != b'_';
if left_ok && right_ok {
found = true;
break;
}
}
}
}
found
};
if matched {
for trigger in &config.triggers {
triggers.insert(trigger.clone());
}
tracing::info!(
operation = %op_name,
found_triggers = ?config.triggers,
"Configured live query triggers found"
);
}
}
}
// No configured triggers matched: fall back to defaults, including
// the catch-all "*.*" so the query still refreshes on any event.
if triggers.is_empty() {
tracing::debug!("No configured triggers found, using defaults");
triggers.insert("User.create".to_string());
triggers.insert("User.update".to_string());
triggers.insert("User.delete".to_string());
triggers.insert("*.*".to_string());
}
let subscription = LiveSubscription {
id: id.clone(),
query: clean_query,
variables,
operation_name,
triggers: triggers.into_iter().collect(),
is_live: true,
};
active_subscriptions
.write()
.insert(id.clone(), subscription);
tracing::info!(subscription_id = %id, "Live subscription registered for updates");
} else {
// Non-live queries are one-shot: signal completion immediately.
let complete_msg = WsResponse {
msg_type: "complete".to_string(),
id: Some(id),
payload: None,
};
if ws_tx.send(complete_msg).await.is_err() {
break;
}
}
}
"complete" => {
// Client-initiated unsubscribe.
if let Some(id) = parsed.id {
active_subscriptions.write().remove(&id);
tracing::debug!(subscription_id = %id, "Client completed subscription");
}
}
_ => {
tracing::debug!("Unknown message type: {}", parsed.msg_type);
}
}
}
// Receive loop ended (close frame or socket error): tear down helpers.
forward_task.abort();
invalidation_task.abort();
tracing::debug!("Live query WebSocket connection closed");
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{
body::{to_bytes, Body},
http::{Request, StatusCode},
};
use tower::ServiceExt;
// Compiled gRPC descriptor used to build a real schema for router tests.
const GREETER_DESCRIPTOR: &[u8] = include_bytes!("generated/greeter_descriptor.bin");
// Builds a minimal ServeMux around the greeter schema; panics on failure
// because every test below depends on it.
fn build_router_mux() -> ServeMux {
let schema = crate::schema::SchemaBuilder::new()
.with_descriptor_set_bytes(GREETER_DESCRIPTOR)
.build(&crate::grpc_client::GrpcClientPool::new())
.expect("schema builds");
ServeMux::new(schema)
}
// GET /graphql with the playground enabled returns the playground HTML,
// wired to the /graphql/ws subscription endpoint.
#[tokio::test]
async fn playground_served_on_get() {
let mut mux = build_router_mux();
mux.enable_playground();
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.uri("/graphql")
.body(Body::empty())
.expect("build request"),
)
.await
.expect("receive response");
assert_eq!(response.status(), StatusCode::OK);
let body = to_bytes(response.into_body(), 1024 * 1024)
.await
.expect("read body");
let body_str = String::from_utf8(body.to_vec()).expect("utf8 body");
assert!(
body_str.contains("GraphQL Playground"),
"playground HTML should be returned"
);
assert!(
body_str.contains("/graphql/ws"),
"websocket endpoint should be linked"
);
}
// A fresh ServeMux has no optional features enabled.
#[tokio::test]
async fn test_servemux_new() {
let mux = build_router_mux();
assert!(mux.circuit_breaker().is_none());
assert!(mux.response_cache().is_none());
}
// Cloning a ServeMux yields an independent instance; both build routers.
#[tokio::test]
async fn test_servemux_clone() {
let mux = build_router_mux();
let cloned = mux.clone();
let _router1 = mux.into_router();
let _router2 = cloned.into_router();
}
// Health checks require a client pool; enabling must not panic.
#[tokio::test]
async fn test_enable_health_checks() {
let mut mux = build_router_mux();
mux.set_client_pool(crate::grpc_client::GrpcClientPool::new());
mux.enable_health_checks();
}
// Enabling metrics must not panic.
#[tokio::test]
async fn test_enable_metrics() {
let mut mux = build_router_mux();
mux.enable_metrics();
}
// Each enable_* setter below is verified via its corresponding getter.
#[tokio::test]
async fn test_enable_circuit_breaker() {
let mut mux = build_router_mux();
let config = crate::circuit_breaker::CircuitBreakerConfig::default();
mux.enable_circuit_breaker(config);
assert!(mux.circuit_breaker().is_some());
}
#[tokio::test]
async fn test_enable_response_cache() {
let mut mux = build_router_mux();
let config = crate::cache::CacheConfig::default();
mux.enable_response_cache(config);
assert!(mux.response_cache().is_some());
}
#[tokio::test]
async fn test_enable_compression() {
let mut mux = build_router_mux();
let config = crate::compression::CompressionConfig::default();
mux.enable_compression(config);
assert!(mux.compression_config().is_some());
}
#[tokio::test]
async fn test_enable_query_whitelist() {
let mut mux = build_router_mux();
let config = crate::query_whitelist::QueryWhitelistConfig::warn();
mux.enable_query_whitelist(config);
assert!(mux.query_whitelist().is_some());
}
#[tokio::test]
async fn test_enable_analytics() {
let mut mux = build_router_mux();
let config = crate::analytics::AnalyticsConfig::default();
mux.enable_analytics(config);
assert!(mux.analytics().is_some());
}
#[tokio::test]
async fn test_enable_request_collapsing() {
let mut mux = build_router_mux();
let config = crate::request_collapsing::RequestCollapsingConfig::default();
mux.enable_request_collapsing(config);
assert!(mux.request_collapsing().is_some());
}
// High-performance mode has no getter; just verify enabling doesn't panic.
#[tokio::test]
async fn test_enable_high_performance() {
let mut mux = build_router_mux();
let config = crate::high_performance::HighPerfConfig::default();
mux.enable_high_performance(config);
}
// perf_metrics() is callable on a default mux.
#[tokio::test]
async fn test_perf_metrics() {
let mux = build_router_mux();
let _metrics = mux.perf_metrics();
}
// /health returns 200 once health checks are enabled.
#[tokio::test]
async fn test_health_endpoint() {
let mut mux = build_router_mux();
mux.set_client_pool(crate::grpc_client::GrpcClientPool::new());
mux.enable_health_checks();
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.uri("/health")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// /ready may legitimately be 2xx or 5xx depending on backend reachability.
#[tokio::test]
async fn test_readiness_endpoint() {
let mut mux = build_router_mux();
mux.set_client_pool(crate::grpc_client::GrpcClientPool::new());
mux.enable_health_checks();
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.uri("/ready")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert!(response.status().is_success() || response.status().is_server_error());
}
// /metrics: 403 without a configured key, 200 with a matching x-metrics-key.
// NOTE(review): env vars are process-global; concurrent tests touching
// METRICS_API_KEY could race — confirm tests run serially or keys differ.
#[tokio::test]
async fn test_metrics_endpoint() {
let mut mux = build_router_mux();
mux.enable_metrics();
let app = mux.into_router();
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/metrics")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
unsafe { std::env::set_var("METRICS_API_KEY", "test-metrics-secret") };
let response = app
.oneshot(
Request::builder()
.uri("/metrics")
.header("x-metrics-key", "test-metrics-secret")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
unsafe { std::env::remove_var("METRICS_API_KEY") };
assert_eq!(response.status(), StatusCode::OK);
}
// /analytics: same gating behavior as /metrics, with ANALYTICS_API_KEY.
#[tokio::test]
async fn test_analytics_endpoint() {
let mut mux = build_router_mux();
mux.enable_analytics(crate::analytics::AnalyticsConfig::default());
let app = mux.into_router();
let response = app
.clone()
.oneshot(
Request::builder()
.uri("/analytics")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::FORBIDDEN);
unsafe { std::env::set_var("ANALYTICS_API_KEY", "test-analytics-secret") };
let response = app
.oneshot(
Request::builder()
.uri("/analytics")
.header("x-analytics-key", "test-analytics-secret")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
unsafe { std::env::remove_var("ANALYTICS_API_KEY") };
assert_eq!(response.status(), StatusCode::OK);
}
// A basic introspection query over POST returns 200.
#[tokio::test]
async fn test_graphql_post_query() {
let mux = build_router_mux();
let app = mux.into_router();
let query = r#"{ __schema { queryType { name } } }"#;
let request_body = serde_json::json!({
"query": query
});
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/graphql")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// Full type-list introspection also succeeds.
#[tokio::test]
async fn test_graphql_post_introspection() {
let mux = build_router_mux();
let app = mux.into_router();
let query = r#"{ __schema { types { name } } }"#;
let request_body = serde_json::json!({
"query": query
});
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/graphql")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// An empty POST body is rejected with a non-success status.
#[tokio::test]
async fn test_graphql_post_empty_body() {
let mux = build_router_mux();
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/graphql")
.header("content-type", "application/json")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert!(!response.status().is_success());
}
// Malformed JSON is rejected with a non-success status.
#[tokio::test]
async fn test_graphql_post_invalid_json() {
let mux = build_router_mux();
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/graphql")
.header("content-type", "application/json")
.body(Body::from("{invalid json"))
.unwrap(),
)
.await
.unwrap();
assert!(!response.status().is_success());
}
// Variables are accepted and the request succeeds.
#[tokio::test]
async fn test_graphql_with_variables() {
let mux = build_router_mux();
let app = mux.into_router();
let query = r#"query Test($name: String!) { __type(name: $name) { name } }"#;
let request_body = serde_json::json!({
"query": query,
"variables": { "name": "String" }
});
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/graphql")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// operationName selects one of several named operations in the document.
#[tokio::test]
async fn test_graphql_with_operation_name() {
let mux = build_router_mux();
let app = mux.into_router();
let query = r#"
query First { __schema { queryType { name } } }
query Second { __schema { mutationType { name } } }
"#;
let request_body = serde_json::json!({
"query": query,
"operationName": "First"
});
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/graphql")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
// A default mux builds a router without middlewares attached.
#[tokio::test]
async fn test_with_middleware_builder() {
let mux = build_router_mux();
let _router = mux.into_router();
}
// Several features enabled together still build a router; /metrics stays
// protected (403 or 401) since no METRICS_API_KEY is configured.
#[tokio::test]
async fn test_multiple_configurations() {
let mut mux = build_router_mux();
mux.enable_metrics();
mux.enable_playground();
mux.enable_analytics(crate::analytics::AnalyticsConfig::default());
mux.enable_compression(crate::compression::CompressionConfig::default());
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.uri("/metrics")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let status = response.status();
assert!(
status == StatusCode::FORBIDDEN || status == StatusCode::UNAUTHORIZED,
"Expected 403 or 401 but got {}",
status
);
}
// into_router takes the mux by value; this just asserts it compiles and runs.
#[tokio::test]
async fn test_into_router_consumes_mux() {
let mux = build_router_mux();
let _router = mux.into_router();
}
// Every hardening header the router is expected to attach is present with
// the exact expected value on an arbitrary response (/health).
#[tokio::test]
async fn test_security_headers_present() {
let mux = build_router_mux();
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.uri("/health")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let headers = response.headers();
assert_eq!(
headers
.get("x-content-type-options")
.unwrap()
.to_str()
.unwrap(),
"nosniff"
);
assert_eq!(
headers.get("x-frame-options").unwrap().to_str().unwrap(),
"DENY"
);
assert_eq!(
headers.get("x-xss-protection").unwrap().to_str().unwrap(),
"1; mode=block"
);
assert_eq!(
headers
.get("strict-transport-security")
.unwrap()
.to_str()
.unwrap(),
"max-age=31536000; includeSubDomains"
);
assert_eq!(
headers.get("cache-control").unwrap().to_str().unwrap(),
"no-store, no-cache, must-revalidate"
);
assert_eq!(
headers.get("referrer-policy").unwrap().to_str().unwrap(),
"strict-origin-when-cross-origin"
);
assert_eq!(
headers
.get("x-dns-prefetch-control")
.unwrap()
.to_str()
.unwrap(),
"off"
);
// Permissions-Policy and CSP are checked by substring since their full
// values are long and order-sensitive.
let p_policy = headers.get("permissions-policy").unwrap().to_str().unwrap();
assert!(p_policy.contains("camera=()"));
assert!(p_policy.contains("microphone=()"));
assert!(p_policy.contains("geolocation=()"));
let csp = headers
.get("content-security-policy")
.unwrap()
.to_str()
.unwrap();
assert!(csp.contains("default-src 'none'"));
assert!(csp.contains("base-uri 'none'"));
assert!(csp.contains("frame-ancestors 'none'"));
assert!(csp.contains("form-action 'none'"));
assert_eq!(
headers
.get("cross-origin-opener-policy")
.unwrap()
.to_str()
.unwrap(),
"same-origin"
);
assert_eq!(
headers
.get("cross-origin-embedder-policy")
.unwrap()
.to_str()
.unwrap(),
"require-corp"
);
assert_eq!(
headers
.get("cross-origin-resource-policy")
.unwrap()
.to_str()
.unwrap(),
"same-origin"
);
}
// CORS headers are only attached when explicitly configured.
#[tokio::test]
async fn test_cors_headers_absent_by_default() {
let mux = build_router_mux();
let app = mux.into_router();
let response = app
.oneshot(
Request::builder()
.uri("/health")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let headers = response.headers();
assert!(headers.get("access-control-allow-origin").is_none());
}
#[test]
fn test_parse_cors_allowed_origin_rejects_empty_values() {
    // A missing value is rejected, as are blank strings, non-http(s)
    // schemes, and origins that carry a path component.
    assert!(parse_cors_allowed_origin(None).is_none());
    let invalid_origins = ["", " ", "javascript:alert(1)", "https://app.example.com/path"];
    for origin in invalid_origins {
        assert!(
            parse_cors_allowed_origin(Some(origin)).is_none(),
            "expected rejection for {origin:?}"
        );
    }
}
#[test]
fn test_content_security_policy_switches_for_playground() {
    // API responses get a fully locked-down CSP; the playground page needs
    // a relaxed policy that allows its CDN-hosted assets to load.
    let api_policy = content_security_policy(false);
    assert!(api_policy.contains("default-src 'none'"));

    let playground_policy = content_security_policy(true);
    assert!(playground_policy.contains("default-src 'self'"));
    assert!(playground_policy.contains("https://cdn.jsdelivr.net"));
}
#[test]
fn test_merge_ws_connection_init_headers_rejects_forwarded_ip_overrides() {
    // Handshake headers form the baseline. The connection_init payload may
    // override authorization, but it must never be able to spoof client-IP
    // forwarding headers.
    let mut handshake_headers = HeaderMap::new();
    handshake_headers.insert(
        axum::http::header::AUTHORIZATION,
        axum::http::HeaderValue::from_static("Bearer handshake"),
    );

    let init_payload = serde_json::json!({
        "headers": {
            "authorization": "Bearer init-token",
            "x-forwarded-for": "203.0.113.10"
        }
    });
    let merged = merge_ws_connection_init_headers(&handshake_headers, &init_payload);

    // The init-supplied token replaces the handshake token...
    let authorization = merged
        .get(axum::http::header::AUTHORIZATION)
        .unwrap()
        .to_str()
        .unwrap();
    assert_eq!(authorization, "Bearer init-token");
    // ...but the forwarded-IP override is dropped entirely.
    assert!(merged.get("x-forwarded-for").is_none());
}
#[tokio::test]
async fn test_prepare_execution_context_runs_middlewares_for_ws_requests() {
    // WebSocket upgrades bypass the HTTP pipeline, so prepare_execution_context
    // must drive the middleware chain (here: bearer-token auth) on its own.
    let mut gateway = build_router_mux();
    let auth_middleware = crate::middleware::EnhancedAuthMiddleware::with_fn(
        crate::middleware::AuthConfig::required(),
        |token| {
            // Evaluate the token before the async move so the returned future
            // does not borrow `token`.
            let token_matches = token == "ws-secret";
            Box::pin(async move {
                if token_matches {
                    Ok(crate::middleware::AuthClaims {
                        sub: Some("user-1".to_string()),
                        ..Default::default()
                    })
                } else {
                    Err(crate::error::Error::Unauthorized("bad token".to_string()))
                }
            })
        },
    );
    gateway.add_middleware(std::sync::Arc::new(auth_middleware));

    let mut handshake_headers = HeaderMap::new();
    handshake_headers.insert(
        axum::http::header::AUTHORIZATION,
        axum::http::HeaderValue::from_static("Bearer ws-secret"),
    );
    let ctx = gateway
        .prepare_execution_context(&handshake_headers)
        .await
        .expect("middleware-authenticated websocket context");

    // The auth middleware must have marked the context and recorded the user.
    assert_eq!(
        ctx.get("auth.authenticated"),
        Some(&serde_json::json!(true))
    );
    assert_eq!(ctx.user_id().as_deref(), Some("user-1"));
}
#[tokio::test]
async fn test_live_queries_bypass_response_cache() {
    // Requests tagged as live queries must never populate the response cache,
    // even when caching is enabled on the gateway.
    let mut gateway = build_router_mux();
    gateway.enable_response_cache(crate::cache::CacheConfig::default());

    let query_text = "query { __schema { queryType { name } } }";
    let live_request = async_graphql::Request::new(query_text).data(LiveQueryRequestMarker);
    let response = gateway.handle_http(HeaderMap::new(), live_request).await;
    assert!(
        response.errors.is_empty(),
        "live query should execute successfully"
    );

    // Looking up the key this request would have produced must be a miss.
    let key = crate::cache::ResponseCache::generate_cache_key(
        query_text,
        Some(&serde_json::json!({})),
        None,
        &[],
    );
    let lookup = gateway
        .response_cache()
        .expect("response cache enabled")
        .get(&key)
        .await;
    assert!(
        matches!(lookup, crate::cache::CacheLookupResult::Miss),
        "live query responses must not be cached"
    );
}
}