use crate::high_performance::{FastJsonParser, ShardedCache};
use crate::persisted_queries::{
create_apq_store, process_apq_request, PersistedQueryConfig, SharedPersistedQueryStore,
};
use crate::request_collapsing::{
create_request_collapsing_registry, CollapseResult, RequestCollapsingConfig, RequestKey,
SharedRequestCollapsingRegistry,
};
use crate::rest_connector::{HttpMethod, RestConnector, RestEndpoint};
use crate::Result;
use ahash::AHashMap;
use bytes::Bytes;
use futures::stream::{FuturesUnordered, StreamExt};
use governor::{Quota, RateLimiter};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::net::IpAddr;
use std::num::NonZeroU32;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
use crate::query_cost_analyzer::{QueryCostAnalyzer, QueryCostConfig};
use crate::waf::{is_introspection, validate_raw, WafConfig};
/// Rate-limiting knobs for the DDoS protection layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DdosConfig {
    /// Requests per second allowed across all clients combined
    /// (0 means "deny everything" — see `DdosProtection::check`).
    pub global_rps: u32,
    /// Steady-state requests per second allowed per client IP
    /// (0 means "deny everything" — see `DdosProtection::check`).
    pub per_ip_rps: u32,
    /// Short-term burst capacity per client IP.
    pub per_ip_burst: u32,
}
impl Default for DdosConfig {
    /// Balanced defaults: 10k req/s globally, 100 req/s (burst 200) per IP.
    fn default() -> Self {
        Self {
            global_rps: 10_000,
            per_ip_rps: 100,
            per_ip_burst: 200,
        }
    }
}
impl DdosConfig {
    /// Tight limits for hostile or public-facing environments.
    pub fn strict() -> Self {
        Self {
            global_rps: 5_000,
            per_ip_rps: 50,
            per_ip_burst: 100,
        }
    }

    /// Generous limits for trusted or internal traffic.
    pub fn relaxed() -> Self {
        Self {
            global_rps: 100_000,
            per_ip_rps: 1_000,
            per_ip_burst: 2_000,
        }
    }
}
/// A per-IP rate limiter paired with a last-seen timestamp so that idle
/// entries can be evicted by `DdosProtection::cleanup_stale_limiters`.
struct TrackedLimiter {
    limiter: Arc<
        RateLimiter<
            governor::state::NotKeyed,
            governor::state::InMemoryState,
            governor::clock::DefaultClock,
        >,
    >,
    // Unix seconds of creation / most recent `touch`; Relaxed access only,
    // as this is a heuristic timestamp rather than a synchronization point.
    last_seen_secs: AtomicU64,
}
impl TrackedLimiter {
    /// Wraps `limiter`, stamping the current wall-clock second as last-seen.
    fn new(
        limiter: Arc<
            RateLimiter<
                governor::state::NotKeyed,
                governor::state::InMemoryState,
                governor::clock::DefaultClock,
            >,
        >,
    ) -> Self {
        let now = Self::current_secs();
        Self {
            limiter,
            last_seen_secs: AtomicU64::new(now),
        }
    }

    /// Refreshes the last-seen timestamp to "now".
    fn touch(&self) {
        let now = Self::current_secs();
        self.last_seen_secs.store(now, Ordering::Relaxed);
    }

    /// Seconds since the Unix epoch (0 if the clock reads before the epoch).
    fn current_secs() -> u64 {
        use std::time::SystemTime;
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0)
    }
}
/// Two-tier rate limiting: a single global token bucket plus a lazily-grown
/// map of per-IP buckets behind an async `RwLock`.
#[derive(Clone)]
pub struct DdosProtection {
    // Bucket shared by all clients; checked before any per-IP bucket.
    global_limiter: Arc<
        RateLimiter<
            governor::state::NotKeyed,
            governor::state::InMemoryState,
            governor::clock::DefaultClock,
        >,
    >,
    // Per-IP limiters, created on first sight of an address and pruned
    // by `cleanup_stale_limiters`.
    ip_limiters: Arc<RwLock<HashMap<IpAddr, Arc<TrackedLimiter>>>>,
    config: DdosConfig,
}
impl DdosProtection {
    /// Builds a protection layer with a global token bucket plus one
    /// lazily-created bucket per client IP.
    pub fn new(config: DdosConfig) -> Self {
        // `max(1)` keeps `NonZeroU32::new` from failing on a zero quota;
        // a configured 0 is still honored as "deny all" inside `check`.
        let safe_global_rps = NonZeroU32::new(config.global_rps.max(1)).unwrap();
        let global_quota = Quota::per_second(safe_global_rps);
        Self {
            global_limiter: Arc::new(RateLimiter::direct(global_quota)),
            ip_limiters: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    /// Returns `true` when the request from `ip` is allowed.
    ///
    /// Order of checks: zero-config kill switch, global limiter, then the
    /// per-IP limiter (created on first sight of the address).
    pub async fn check(&self, ip: IpAddr) -> bool {
        // An explicit 0 in either knob means "deny everything".
        if self.config.global_rps == 0 || self.config.per_ip_rps == 0 {
            return false;
        }
        if self.global_limiter.check().is_err() {
            tracing::warn!(
                "🛡️ Global rate limit exceeded (>{} req/sec)",
                self.config.global_rps
            );
            return false;
        }
        // Fast path: look the limiter up under the read lock only.
        let tracked = {
            let limiters = self.ip_limiters.read().await;
            limiters.get(&ip).cloned()
        };
        let tracked = match tracked {
            Some(t) => {
                t.touch();
                t
            }
            None => {
                let safe_rps = NonZeroU32::new(self.config.per_ip_rps.max(1)).unwrap();
                let safe_burst = NonZeroU32::new(self.config.per_ip_burst.max(1)).unwrap();
                let quota = Quota::per_second(safe_rps).allow_burst(safe_burst);
                let new_limiter = Arc::new(RateLimiter::direct(quota));
                let new_tracked = Arc::new(TrackedLimiter::new(new_limiter));
                // Re-check under the write lock: another task may have
                // inserted a limiter for this IP while we built ours.
                let mut limiters = self.ip_limiters.write().await;
                if let Some(existing) = limiters.get(&ip) {
                    // Fix: refresh the timestamp on the winner too; without
                    // this, a limiter observed only through this race path
                    // looks idle and can be evicted by
                    // `cleanup_stale_limiters` while still in active use.
                    existing.touch();
                    existing.clone()
                } else {
                    limiters.insert(ip, new_tracked.clone());
                    new_tracked
                }
            }
        };
        if tracked.limiter.check().is_err() {
            tracing::warn!(
                client_ip = %ip,
                limit = self.config.per_ip_rps,
                burst = self.config.per_ip_burst,
                "🛡️ Per-IP rate limit exceeded"
            );
            return false;
        }
        true
    }

    /// Drops per-IP limiters that have not been seen within `max_age_secs`
    /// seconds, bounding memory growth under address churn.
    pub async fn cleanup_stale_limiters(&self, max_age_secs: u64) {
        let mut limiters = self.ip_limiters.write().await;
        let before = limiters.len();
        let now = TrackedLimiter::current_secs();
        limiters.retain(|_, v| {
            let last_seen = v.last_seen_secs.load(Ordering::Relaxed);
            // saturating_sub guards against clock skew making `last_seen`
            // land in the future.
            now.saturating_sub(last_seen) < max_age_secs
        });
        let after = limiters.len();
        if before != after {
            tracing::debug!("🧹 Cleaned up stale IP limiters: {} -> {}", before, after);
        }
    }
}
/// Top-level router configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    /// Port the router is configured for.
    pub port: u16,
    /// Downstream subgraphs to federate.
    pub subgraphs: Vec<SubgraphConfig>,
    /// When true, subgraphs are asked for the binary GBP encoding
    /// via an `Accept: application/x-gbp` header.
    pub force_gbp: bool,
    /// Automatic persisted queries; runtime-only (never (de)serialized).
    #[serde(skip)]
    pub apq: Option<PersistedQueryConfig>,
    /// Request collapsing; runtime-only (never (de)serialized).
    #[serde(skip)]
    pub request_collapsing: Option<RequestCollapsingConfig>,
    /// Optional WAF rules applied to query text and variables.
    #[serde(default)]
    pub waf: Option<WafConfig>,
    /// Query-cost analysis limits; runtime-only (never (de)serialized).
    #[serde(skip)]
    pub query_cost: Option<QueryCostConfig>,
    /// Reject queries flagged by `is_introspection` when set.
    #[serde(default)]
    pub disable_introspection: bool,
    /// Circuit-breaker settings shared by every subgraph's breaker.
    #[serde(default)]
    pub circuit_breaker: Option<CircuitBreakerConfig>,
}
/// Configuration for a single downstream subgraph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubgraphConfig {
    /// Subgraph name; used as the key in the merged response object.
    pub name: String,
    /// Base URL of the subgraph endpoint.
    pub url: String,
    /// Extra headers sent on every request to this subgraph.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Optional mutual-TLS settings for zero-trust deployments.
    #[serde(default)]
    pub mtls: Option<crate::mtls::MtlsConfig>,
}
/// Federation router: fans GraphQL queries out to subgraphs over REST
/// connectors, with response caching, APQ, request collapsing, query-cost
/// limits and per-subgraph circuit breakers.
pub struct GbpRouter {
    config: RouterConfig,
    // One REST connector per subgraph, keyed by subgraph name.
    clients: AHashMap<String, RestConnector>,
    // Serialized successful responses, keyed by query+variables hash.
    cache: Arc<ShardedCache<Bytes>>,
    // Parses cached response bytes back into JSON values.
    json_parser: Arc<FastJsonParser>,
    // Lifetime counters backing `stats()`.
    request_count: AtomicU64,
    cache_hits: AtomicU64,
    cache_ttl: Duration,
    // Optional layers, present only when enabled in `RouterConfig`.
    apq_store: Option<SharedPersistedQueryStore>,
    collapsing_registry: Option<SharedRequestCollapsingRegistry>,
    cost_analyzer: Option<Arc<QueryCostAnalyzer>>,
    // One breaker per subgraph, keyed by subgraph name.
    circuit_breakers: HashMap<String, Arc<CircuitBreaker>>,
}
impl GbpRouter {
/// Creates a router with the default 60-second response-cache TTL.
pub fn new(config: RouterConfig) -> Self {
    Self::with_cache_ttl(config, Duration::from_secs(60))
}
/// Builds the router with an explicit response-cache TTL.
///
/// For each configured subgraph this creates a `RestConnector` exposing a
/// single POST "query" endpoint, applies per-subgraph headers, optionally
/// wires an mTLS client, and pairs the subgraph with its own circuit
/// breaker (all breakers share one config).
pub fn with_cache_ttl(config: RouterConfig, cache_ttl: Duration) -> Self {
    let mut clients = AHashMap::with_capacity(config.subgraphs.len());
    let mut circuit_breakers = HashMap::with_capacity(config.subgraphs.len());
    let cb_config = config.circuit_breaker.clone().unwrap_or_default();
    for subgraph in &config.subgraphs {
        let mut builder = RestConnector::builder()
            .base_url(&subgraph.url)
            .add_endpoint(
                RestEndpoint::new("query", "")
                    .method(HttpMethod::POST)
                    .header("Content-Type", "application/json"),
            );
        if config.force_gbp {
            // Ask subgraphs for the binary GBP encoding instead of JSON.
            builder = builder.default_header("Accept", "application/x-gbp");
            tracing::info!(
                subgraph = %subgraph.name,
                url = %subgraph.url,
                gbp = true,
                "🚀 High-performance subgraph configured"
            );
        }
        for (key, value) in &subgraph.headers {
            builder = builder.default_header(key, value);
        }
        if let Some(ref mtls_config) = subgraph.mtls {
            if mtls_config.enabled {
                match crate::mtls::MtlsProvider::new(mtls_config.clone()) {
                    Ok(provider) => {
                        // NOTE(review): creating a fresh tokio Runtime and
                        // calling block_on here panics if this constructor
                        // is ever invoked from within an async context —
                        // confirm all call sites are synchronous.
                        let rt = tokio::runtime::Runtime::new().ok();
                        let mtls_client = rt
                            .as_ref()
                            .and_then(|rt| rt.block_on(provider.build_client()).ok());
                        if let Some(client) = mtls_client {
                            builder = builder.with_client(client);
                            tracing::info!(
                                subgraph = %subgraph.name,
                                trust_domain = %mtls_config.trust_domain,
                                "🔒 mTLS enabled for subgraph (Zero-Trust mode)"
                            );
                        } else if mtls_config.allow_fallback {
                            tracing::warn!(
                                subgraph = %subgraph.name,
                                "⚠️ mTLS client build failed, falling back to plain HTTP"
                            );
                        } else {
                            // Fallback disabled: log loudly, but the
                            // subgraph is still registered with a plain
                            // (non-mTLS) connector below.
                            tracing::error!(
                                subgraph = %subgraph.name,
                                "❌ mTLS client build failed and fallback is disabled"
                            );
                        }
                    }
                    Err(e) => {
                        if mtls_config.allow_fallback {
                            tracing::warn!(
                                subgraph = %subgraph.name,
                                error = %e,
                                "⚠️ mTLS initialization failed, falling back to plain HTTP"
                            );
                        } else {
                            tracing::error!(
                                subgraph = %subgraph.name,
                                error = %e,
                                "❌ mTLS initialization failed and fallback is disabled"
                            );
                        }
                    }
                }
            }
        }
        clients.insert(
            subgraph.name.clone(),
            builder.build().expect("invalid connector config"),
        );
        circuit_breakers.insert(
            subgraph.name.clone(),
            Arc::new(CircuitBreaker::new(
                subgraph.name.clone(),
                cb_config.clone(),
            )),
        );
    }
    // Optional layers are only constructed when configured.
    let apq_store = config.apq.clone().map(create_apq_store);
    let collapsing_registry = config
        .request_collapsing
        .clone()
        .map(create_request_collapsing_registry);
    let cost_analyzer = config
        .query_cost
        .clone()
        .map(|c| Arc::new(QueryCostAnalyzer::new(c)));
    Self {
        config,
        clients,
        // 128 shards / 10k entries for the response cache.
        cache: Arc::new(ShardedCache::<Bytes>::new(128, 10_000)),
        json_parser: Arc::new(FastJsonParser::new(64)),
        request_count: AtomicU64::new(0),
        cache_hits: AtomicU64::new(0),
        cache_ttl,
        apq_store,
        collapsing_registry,
        cost_analyzer,
        circuit_breakers,
    }
}
/// Fans the query out to every subgraph concurrently and merges all
/// successful responses under their subgraph names.
///
/// Pipeline: APQ resolution → introspection gate → WAF → cost analysis →
/// response cache → request collapsing → scatter-gather. Subgraph failures
/// are logged and omitted from the merged object; the call still returns
/// `Ok` (fail-soft). Responses are cached only when no subgraph failed.
pub async fn execute_scatter_gather(
    &self,
    query: Option<&str>,
    variables: Option<&JsonValue>,
    extensions: Option<&JsonValue>,
) -> Result<JsonValue> {
    self.request_count.fetch_add(1, Ordering::Relaxed);
    let start = Instant::now();
    // Resolve the query text, via the APQ store when enabled.
    let query_string = if let Some(store) = &self.apq_store {
        match process_apq_request(store, query, extensions) {
            Ok(Some(q)) => q,
            Ok(None) => {
                return Err(crate::Error::Internal("No query provided".into()));
            }
            Err(e) => {
                return Err(crate::Error::Internal(e.to_string()));
            }
        }
    } else {
        query
            .map(String::from)
            .ok_or_else(|| crate::Error::Internal("No query provided".into()))?
    };
    if self.config.disable_introspection && is_introspection(&query_string) {
        tracing::warn!("Introspection query blocked");
        return Err(crate::Error::Validation("Introspection is disabled".into()));
    }
    if let Some(waf_config) = &self.config.waf {
        validate_raw(&query_string, variables, waf_config)
            .map_err(|e| crate::Error::Validation(e.to_string()))?;
    }
    if let Some(analyzer) = &self.cost_analyzer {
        analyzer
            .calculate_query_cost(&query_string)
            .await
            .map_err(crate::Error::Validation)?;
    }
    let cache_key = self.compute_cache_key(&query_string, variables);
    if let Some(cached) = self.cache.get(&cache_key) {
        self.cache_hits.fetch_add(1, Ordering::Relaxed);
        tracing::debug!(
            query_hash = %&cache_key[..16.min(cache_key.len())],
            latency_us = start.elapsed().as_micros(),
            "⚡ Cache HIT"
        );
        return self
            .json_parser
            .parse_bytes(&cached)
            .map_err(|e| crate::Error::Internal(format!("Cache parse error: {}", e)));
    }
    // Request collapsing: identical in-flight requests share one upstream
    // fan-out. Followers await the leader's broadcast and return early.
    let broadcaster = if let Some(registry) = &self.collapsing_registry {
        let req_key = RequestKey::new("router", "graphql", cache_key.as_bytes());
        match registry.try_collapse(req_key).await {
            CollapseResult::Follower(receiver) => {
                tracing::debug!("Request collapsed (follower)");
                let gql_val = receiver.recv().await.map_err(crate::Error::Internal)?;
                return serde_json::to_value(gql_val).map_err(|e| {
                    crate::Error::Internal(format!("Serialization error: {}", e))
                });
            }
            CollapseResult::Leader(b) => Some(b),
            CollapseResult::Passthrough => None,
        }
    } else {
        None
    };
    // Scatter: one concurrent request per subgraph.
    let mut futures = FuturesUnordered::new();
    for (name, client) in &self.clients {
        let name = name.clone();
        let client = client.clone();
        let force_gbp = self.config.force_gbp;
        let query_str = query_string.clone();
        let variables_clone = variables.cloned();
        let circuit_breaker = self.circuit_breakers.get(&name).cloned();
        futures.push(async move {
            // Respect an open circuit before issuing the request.
            if let Some(cb) = &circuit_breaker {
                if let Err(e) = cb.allow_request() {
                    return (
                        name,
                        Err(crate::Error::Internal(e.to_string())),
                        Duration::from_secs(0),
                        force_gbp,
                    );
                }
            }
            let mut args = HashMap::with_capacity(2);
            args.insert("query".to_string(), serde_json::json!(query_str));
            if let Some(vars) = variables_clone {
                args.insert("variables".to_string(), vars);
            }
            let req_start = Instant::now();
            let res = client.execute("query", args).await;
            let duration = req_start.elapsed();
            if let Some(cb) = &circuit_breaker {
                match &res {
                    Ok(val) => {
                        // A GraphQL response with errors and no data counts
                        // as a breaker failure; partial data does not.
                        let has_errors = val
                            .get("errors")
                            .and_then(|e| e.as_array())
                            .is_some_and(|a| !a.is_empty());
                        let has_data = val.get("data").is_some_and(|d| !d.is_null());
                        if has_errors && !has_data {
                            cb.record_failure();
                        } else {
                            cb.record_success();
                        }
                    }
                    Err(_) => cb.record_failure(),
                }
            }
            (name, res, duration, force_gbp)
        });
    }
    // Gather: merge successes keyed by subgraph name; log failures.
    let mut results = HashMap::with_capacity(self.clients.len());
    let mut errors = Vec::new();
    while let Some((name, res, duration, force_gbp)) = futures.next().await {
        match res {
            Ok(val) => {
                tracing::debug!(
                    subgraph = %name,
                    latency_ms = format!("{:.2}", duration.as_secs_f64() * 1000.0),
                    gbp = force_gbp,
                    "✓ Subgraph response"
                );
                results.insert(name, val);
            }
            Err(e) => {
                tracing::warn!(
                    subgraph = %name,
                    error = %e,
                    "✗ Subgraph failed"
                );
                errors.push((name, e.to_string()));
            }
        }
    }
    let response = serde_json::to_value(&results).unwrap();
    let result = Ok(response);
    // Leaders broadcast the outcome to any collapsed followers.
    if let Some(b) = broadcaster {
        let broadcast_res = match &result {
            Ok(val) => {
                serde_json::from_value(val.clone())
                    .map_err(|e| format!("Conversion error: {}", e))
            }
            Err(e) => Err(format!("{:?}", e)),
        };
        b.broadcast(broadcast_res);
    }
    // Only fully-successful responses are cached.
    if errors.is_empty() {
        if let Ok(val) = &result {
            if let Ok(bytes) = serde_json::to_vec(val) {
                self.cache
                    .insert(&cache_key, Bytes::from(bytes), self.cache_ttl);
            }
        }
    }
    let total_duration = start.elapsed();
    tracing::info!(
        subgraphs = results.len(),
        errors = errors.len(),
        latency_ms = format!("{:.2}", total_duration.as_secs_f64() * 1000.0),
        cached = false,
        "Federation query complete"
    );
    result
}
/// Like `execute_scatter_gather`, but aborts on the first subgraph error
/// and returns `Err` instead of a partial merged object.
///
/// Same front pipeline (APQ, introspection gate, WAF, cost analysis,
/// cache, collapsing). On success the merged response is always cached.
pub async fn execute_fail_fast(
    &self,
    query: Option<&str>,
    variables: Option<&JsonValue>,
    extensions: Option<&JsonValue>,
) -> Result<JsonValue> {
    self.request_count.fetch_add(1, Ordering::Relaxed);
    // Resolve the query text, via the APQ store when enabled.
    let query_string = if let Some(store) = &self.apq_store {
        match process_apq_request(store, query, extensions) {
            Ok(Some(q)) => q,
            Ok(None) => return Err(crate::Error::Internal("No query provided".into())),
            Err(e) => return Err(crate::Error::Internal(e.to_string())),
        }
    } else {
        query
            .map(String::from)
            .ok_or_else(|| crate::Error::Internal("No query provided".into()))?
    };
    if self.config.disable_introspection && is_introspection(&query_string) {
        return Err(crate::Error::Validation("Introspection is disabled".into()));
    }
    if let Some(waf_config) = &self.config.waf {
        validate_raw(&query_string, variables, waf_config)
            .map_err(|e| crate::Error::Validation(e.to_string()))?;
    }
    if let Some(analyzer) = &self.cost_analyzer {
        analyzer
            .calculate_query_cost(&query_string)
            .await
            .map_err(crate::Error::Validation)?;
    }
    let cache_key = self.compute_cache_key(&query_string, variables);
    if let Some(cached) = self.cache.get(&cache_key) {
        self.cache_hits.fetch_add(1, Ordering::Relaxed);
        return self
            .json_parser
            .parse_bytes(&cached[..])
            .map_err(|e| crate::Error::Internal(format!("Cache parse error: {}", e)));
    }
    // Collapse identical concurrent requests; followers wait for the
    // leader's broadcast instead of contacting subgraphs themselves.
    let broadcaster = if let Some(registry) = &self.collapsing_registry {
        let req_key = RequestKey::new("router", "graphql", cache_key.as_bytes());
        match registry.try_collapse(req_key).await {
            CollapseResult::Follower(receiver) => {
                let gql_val = receiver.recv().await.map_err(crate::Error::Internal)?;
                return serde_json::to_value(gql_val).map_err(|e| {
                    crate::Error::Internal(format!("Serialization error: {}", e))
                });
            }
            CollapseResult::Leader(b) => Some(b),
            CollapseResult::Passthrough => None,
        }
    } else {
        None
    };
    let mut futures = FuturesUnordered::new();
    for (name, client) in &self.clients {
        let name = name.clone();
        let client = client.clone();
        let query_str = query_string.clone();
        let variables_clone = variables.cloned();
        let circuit_breaker = self.circuit_breakers.get(&name).cloned();
        futures.push(async move {
            // Respect an open circuit before issuing the request.
            if let Some(cb) = &circuit_breaker {
                if let Err(e) = cb.allow_request() {
                    return (name, Err(crate::Error::Internal(e.to_string())));
                }
            }
            let mut args = HashMap::with_capacity(2);
            args.insert("query".to_string(), serde_json::json!(query_str));
            if let Some(vars) = variables_clone {
                args.insert("variables".to_string(), vars);
            }
            let res = client.execute("query", args).await;
            if let Some(cb) = &circuit_breaker {
                match &res {
                    Ok(val) => {
                        // A GraphQL response with errors and no data counts
                        // as a breaker failure; partial data does not.
                        let has_errors = val
                            .get("errors")
                            .and_then(|e| e.as_array())
                            .is_some_and(|a| !a.is_empty());
                        let has_data = val.get("data").is_some_and(|d| !d.is_null());
                        if has_errors && !has_data {
                            cb.record_failure();
                        } else {
                            cb.record_success();
                        }
                    }
                    Err(_) => cb.record_failure(),
                }
            }
            (name, res)
        });
    }
    let mut results = HashMap::with_capacity(self.clients.len());
    let mut final_res = Ok(serde_json::Value::Null);
    while let Some((name, res)) = futures.next().await {
        match res {
            Ok(val) => {
                results.insert(name, val);
            }
            Err(e) => {
                // Fail fast: the first error wins; remaining in-flight
                // futures are dropped when the loop breaks.
                final_res = Err(crate::Error::Internal(format!(
                    "Subgraph {} failed: {}",
                    name, e
                )));
                break;
            }
        }
    }
    if final_res.is_err() {
        // The leader must still notify followers so they don't hang.
        if let Some(b) = broadcaster {
            b.broadcast(Err(final_res.as_ref().err().unwrap().to_string()));
        }
        return final_res;
    }
    let response = serde_json::to_value(&results).unwrap();
    if let Some(b) = broadcaster {
        let broadcast_msg = serde_json::from_value(response.clone())
            .map_err(|e| format!("Conversion error: {}", e));
        b.broadcast(broadcast_msg);
    }
    if let Ok(bytes) = serde_json::to_vec(&response) {
        self.cache
            .insert(&cache_key, Bytes::from(bytes), self.cache_ttl);
    }
    Ok(response)
}
#[inline]
fn compute_cache_key(&self, query: &str, variables: Option<&JsonValue>) -> String {
use std::hash::{Hash, Hasher};
let mut hasher = ahash::AHasher::default();
query.hash(&mut hasher);
if let Some(v) = variables {
v.to_string().hash(&mut hasher);
}
format!("q:{:x}", hasher.finish())
}
/// Number of configured subgraph connectors.
#[inline]
pub fn subgraph_count(&self) -> usize {
    self.clients.len()
}
/// Whether subgraphs are asked for the binary GBP encoding.
#[inline]
pub fn is_gbp_enabled(&self) -> bool {
    self.config.force_gbp
}
/// Snapshots the lifetime counters into a `RouterStats` value.
pub fn stats(&self) -> RouterStats {
    let total_requests = self.request_count.load(Ordering::Relaxed);
    let cache_hits = self.cache_hits.load(Ordering::Relaxed);
    // Guard the division: report 0.0 before any traffic has been seen.
    let cache_hit_rate = match total_requests {
        0 => 0.0,
        n => cache_hits as f64 / n as f64,
    };
    RouterStats {
        total_requests,
        cache_hits,
        cache_hit_rate,
        subgraph_count: self.clients.len(),
        gbp_enabled: self.config.force_gbp,
    }
}
/// Evicts every entry from the response cache.
pub fn clear_cache(&self) {
    self.cache.clear();
    tracing::info!("Router cache cleared");
}
}
/// Snapshot of router counters, produced by `GbpRouter::stats`.
#[derive(Debug, Clone)]
pub struct RouterStats {
    /// Total queries handled since startup.
    pub total_requests: u64,
    /// Queries answered from the response cache.
    pub cache_hits: u64,
    /// `cache_hits / total_requests`, or 0.0 before any traffic.
    pub cache_hit_rate: f64,
    /// Number of configured subgraphs.
    pub subgraph_count: usize,
    /// Whether GBP encoding is requested from subgraphs.
    pub gbp_enabled: bool,
}
#[cfg(test)]
mod tests {
use super::*;
// Default profile: 10k global rps, 100 rps / 200 burst per IP.
#[test]
fn test_ddos_config_defaults() {
    let cfg = DdosConfig::default();
    assert_eq!(cfg.global_rps, 10_000);
    assert_eq!(cfg.per_ip_rps, 100);
    assert_eq!(cfg.per_ip_burst, 200);
}
// The strict profile halves the default global and per-IP rates.
#[test]
fn test_ddos_config_strict() {
    let cfg = DdosConfig::strict();
    assert_eq!(cfg.global_rps, 5_000);
    assert_eq!(cfg.per_ip_rps, 50);
}
// RouterConfig construction with a single subgraph entry.
#[test]
fn test_router_config() {
    let users = SubgraphConfig {
        name: "users".into(),
        url: "http://localhost:4002".into(),
        headers: HashMap::new(),
        mtls: None,
    };
    let cfg = RouterConfig {
        port: 4000,
        subgraphs: vec![users],
        force_gbp: true,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    assert_eq!(cfg.subgraphs.len(), 1);
    assert!(cfg.force_gbp);
}
// 50 requests sit well under the default 100 rps / 200 burst budget.
#[tokio::test]
async fn test_ddos_protection_allows_normal_traffic() {
    let guard = DdosProtection::new(DdosConfig::default());
    let client: IpAddr = "192.168.1.1".parse().unwrap();
    for _ in 0..50 {
        assert!(guard.check(client).await, "Normal traffic should be allowed");
    }
}
// With a burst of 20, roughly 20 of 25 rapid-fire requests get through.
#[tokio::test]
async fn test_ddos_protection_blocks_excessive_traffic() {
    let guard = DdosProtection::new(DdosConfig {
        global_rps: 1000,
        per_ip_rps: 10,
        per_ip_burst: 20,
    });
    let attacker: IpAddr = "10.0.0.1".parse().unwrap();
    let mut allowed = 0;
    for _ in 0..25 {
        if guard.check(attacker).await {
            allowed += 1;
        }
    }
    // A small tolerance window absorbs token refill during the loop.
    assert!(
        (15..=22).contains(&allowed),
        "Expected ~20 allowed, got {}",
        allowed
    );
}
// Exhausting one IP's budget must not penalize a different IP.
#[tokio::test]
async fn test_ddos_protection_isolates_ips() {
    let guard = DdosProtection::new(DdosConfig {
        global_rps: 1000,
        per_ip_rps: 5,
        per_ip_burst: 10,
    });
    let noisy: IpAddr = "1.1.1.1".parse().unwrap();
    let quiet: IpAddr = "2.2.2.2".parse().unwrap();
    // Burn well past the noisy IP's burst allowance.
    for _ in 0..15 {
        let _ = guard.check(noisy).await;
    }
    assert!(
        guard.check(quiet).await,
        "IP2 should not be affected by IP1's usage"
    );
}
// The relaxed profile raises every knob by an order of magnitude.
#[test]
fn test_ddos_config_relaxed() {
    let cfg = DdosConfig::relaxed();
    assert_eq!(cfg.global_rps, 100_000);
    assert_eq!(cfg.per_ip_rps, 1_000);
    assert_eq!(cfg.per_ip_burst, 2_000);
}
// Cleanup must keep recently-seen limiters and evict stale ones.
// (The original test asserted nothing — it only checked for no panic.)
#[tokio::test]
async fn test_ddos_cleanup_stale_limiters() {
    let ddos = DdosProtection::new(DdosConfig::default());
    let ip: IpAddr = "127.0.0.1".parse().unwrap();
    ddos.check(ip).await;
    assert_eq!(ddos.ip_limiters.read().await.len(), 1);
    // An entry seen within the window survives cleanup.
    ddos.cleanup_stale_limiters(60).await;
    assert_eq!(ddos.ip_limiters.read().await.len(), 1);
    // A zero-second window evicts everything (age 0 is not < 0).
    ddos.cleanup_stale_limiters(0).await;
    assert!(ddos.ip_limiters.read().await.is_empty());
}
// A GBP-enabled router reports the flag through both the accessor and stats.
#[test]
fn test_router_gbp_enabled() {
    let cfg = RouterConfig {
        port: 4000,
        subgraphs: vec![SubgraphConfig {
            name: "test".into(),
            url: "http://localhost:8080".into(),
            headers: HashMap::new(),
            mtls: None,
        }],
        force_gbp: true,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(cfg);
    assert!(router.is_gbp_enabled());
    assert_eq!(router.subgraph_count(), 1);
    let snapshot = router.stats();
    assert!(snapshot.gbp_enabled);
    assert_eq!(snapshot.subgraph_count, 1);
}
// A fresh router reports all-zero counters.
#[test]
fn test_router_stats_initial() {
    let cfg = RouterConfig {
        port: 4000,
        subgraphs: vec![],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let snapshot = GbpRouter::new(cfg).stats();
    assert_eq!(snapshot.total_requests, 0);
    assert_eq!(snapshot.cache_hits, 0);
    assert_eq!(snapshot.cache_hit_rate, 0.0);
}
// Smoke test: clearing an (empty) cache must not panic.
#[test]
fn test_router_clear_cache() {
    let config = RouterConfig {
        port: 4000,
        subgraphs: vec![],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    router.clear_cache();
}
// Scatter-gather is fail-soft: an unreachable subgraph yields an empty
// merged object and `Ok` rather than an error.
#[tokio::test]
async fn test_execute_scatter_gather_network_failure() {
    let config = RouterConfig {
        port: 4000,
        subgraphs: vec![SubgraphConfig {
            name: "failing".into(),
            // Nothing listens on this port, so the request must fail.
            url: "http://localhost:9999/graphql".into(),
            headers: HashMap::new(),
            mtls: None,
        }],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    let result = router
        .execute_scatter_gather(Some("{ hello }"), None, None)
        .await;
    assert!(result.is_ok());
    let val = result.unwrap();
    assert!(val.is_object());
    let obj = val.as_object().unwrap();
    // The failing subgraph is omitted, leaving an empty object.
    assert!(obj.is_empty());
    let stats = router.stats();
    assert_eq!(stats.total_requests, 1);
    assert_eq!(stats.cache_hits, 0);
}
// Fail-fast mode surfaces the same unreachable-subgraph condition as `Err`.
#[tokio::test]
async fn test_execute_fail_fast_network_failure() {
    let config = RouterConfig {
        port: 4000,
        subgraphs: vec![SubgraphConfig {
            name: "failing".into(),
            // Nothing listens on this port, so the request must fail.
            url: "http://localhost:9999/graphql".into(),
            headers: HashMap::new(),
            mtls: None,
        }],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    let result = router
        .execute_fail_fast(Some("{ hello }"), None, None)
        .await;
    assert!(result.is_err());
}
// End-to-end: APQ registration + request collapsing + response cache,
// against a minimal hand-rolled HTTP server that counts requests served.
#[tokio::test]
async fn test_integrated_apq_collapsing_cache() {
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_request_count = Arc::new(AtomicU64::new(0));
    let server_count_clone = server_request_count.clone();
    tokio::spawn(async move {
        loop {
            let (mut socket, _) = listener.accept().await.unwrap();
            let server_count = server_count_clone.clone();
            tokio::spawn(async move {
                use tokio::io::{AsyncReadExt, AsyncWriteExt};
                let mut buf = [0; 1024];
                let _ = socket.read(&mut buf).await;
                // Artificial latency keeps both client requests in flight
                // so the second one collapses onto the first.
                tokio::time::sleep(Duration::from_millis(100)).await;
                server_count.fetch_add(1, Ordering::Relaxed);
                let response_body = r#"{"data": {"hello": "world"}}"#;
                let response = format!(
                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
                    response_body.len(),
                    response_body
                );
                let _ = socket.write_all(response.as_bytes()).await;
            });
        }
    });
    let config = RouterConfig {
        port: 0,
        subgraphs: vec![SubgraphConfig {
            name: "test".into(),
            url: format!("http://{}", addr),
            headers: HashMap::new(),
            mtls: None,
        }],
        force_gbp: false,
        apq: Some(PersistedQueryConfig::default()),
        request_collapsing: Some(RequestCollapsingConfig::default()),
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    let query = "{ hello }";
    let hash = crate::persisted_queries::PersistedQueryStore::hash_query(query);
    let ext_with_query = serde_json::json!({
        "persistedQuery": {
            "version": 1,
            "sha256Hash": hash
        }
    });
    let ext_without_query = ext_with_query.clone();
    let router = Arc::new(router);
    let r1 = router.clone();
    let r2 = router.clone();
    let q1 = query.to_string();
    let e1 = ext_with_query.clone();
    let e2 = ext_without_query.clone();
    // Request 1 registers the APQ query; request 2 sends only the hash.
    let t1 = tokio::spawn(async move {
        r1.execute_scatter_gather(Some(&q1), None, Some(&e1)).await
    });
    let t2 = tokio::spawn(async move {
        // Yield so t1 starts first and t2 becomes the collapsed follower.
        tokio::task::yield_now().await;
        r2.execute_scatter_gather(None, None, Some(&e2)).await
    });
    let (res1, res2) = tokio::join!(t1, t2);
    let val1 = res1.unwrap().expect("Req1 failed");
    let val2 = res2.unwrap().expect("Req2 failed");
    assert_eq!(val1["test"]["data"]["hello"], "world");
    assert_eq!(val2["test"]["data"]["hello"], "world");
    assert_eq!(
        server_request_count.load(Ordering::Relaxed),
        1,
        "Expected 1 server request due to collapsing"
    );
    let stats = router.stats();
    assert_eq!(stats.total_requests, 2);
    // Third request (hash only) must be served from the response cache.
    let val3 = router
        .execute_scatter_gather(None, None, Some(&ext_without_query))
        .await
        .unwrap();
    assert_eq!(val3["test"]["data"]["hello"], "world");
    assert_eq!(
        server_request_count.load(Ordering::Relaxed),
        1,
        "Expected cache hit (no new server request)"
    );
    let stats_after = router.stats();
    assert_eq!(stats_after.cache_hits, 1, "Expected 1 cache hit");
    assert_eq!(stats_after.total_requests, 3);
}
#[tokio::test]
async fn test_router_caching_behavior() {
}
// Constructing with a custom cache TTL works for an empty subgraph set.
#[test]
fn test_router_with_custom_ttl() {
    let cfg = RouterConfig {
        port: 4000,
        subgraphs: vec![],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::with_cache_ttl(cfg, Duration::from_secs(123));
    assert_eq!(router.subgraph_count(), 0);
}
}
#[cfg(test)]
mod security_tests;