use super::*;
use std::net::SocketAddr;
use std::time::Duration;
use warp::{Filter, Reply};
/// Returns a loopback TCP port that the OS reported as unused.
///
/// A probe listener is bound to `127.0.0.1:0` so the OS assigns a port, and the
/// listener is released before this function returns. Another process could
/// take the port before the caller binds it, so treat the result as "probably
/// free" rather than reserved.
///
/// # Panics
/// Panics if binding the probe listener or reading its local address fails.
fn get_free_port() -> u16 {
    let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
    let bound = probe.local_addr().unwrap();
    bound.port()
}
async fn spawn_mock_subgraph(port: u16, behavior: &str) -> SocketAddr {
let behavior = behavior.to_string();
let route = warp::post().and(warp::path("graphql")).map(move || {
match behavior.as_str() {
"normal" => warp::reply::json(&serde_json::json!({
"data": { "users": [{ "id": "1", "name": "Test" }] }
}))
.into_response(),
"malformed" => {
warp::reply::with_status("Not JSON", warp::http::StatusCode::OK).into_response()
}
"error" => {
warp::reply::with_status(
"Internal Error",
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response()
}
"slow" => {
std::thread::sleep(Duration::from_millis(100));
warp::reply::json(&serde_json::json!({ "data": "delayed" })).into_response()
}
"huge" => {
let huge_data: Vec<String> = (0..10_000).map(|i| format!("user-{}", i)).collect();
warp::reply::json(&serde_json::json!({ "data": { "users": huge_data } }))
.into_response()
}
_ => warp::reply::json(&serde_json::json!({})).into_response(),
}
});
let port = if port == 0 { get_free_port() } else { port };
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let server = warp::serve(route).run(addr);
tokio::spawn(server);
addr
}
/// Sends a roughly 10 MB query (five million `a` selections) to a router with
/// no subgraphs and checks that the call returns instead of panicking.
#[tokio::test]
async fn test_security_input_fuzzing_massive_query() {
    let config = RouterConfig {
        port: 0,
        subgraphs: vec![],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    let massive_query = format!("query {{ {}}}", "a ".repeat(5_000_000));
    // The outcome is deliberately not asserted. Accepting or rejecting the
    // query are both fine; the property under test is that the router returns
    // at all, with no panic or stack overflow on oversized input.
    let _ = router
        .execute_scatter_gather(Some(&massive_query), None, None)
        .await;
}
#[tokio::test]
async fn test_security_ddos_concurrent_flooding() {
let ddos_config = DdosConfig {
global_rps: 100,
per_ip_rps: 10,
per_ip_burst: 10,
};
let protection = DdosProtection::new(ddos_config);
let protection = Arc::new(protection);
let ip: IpAddr = "192.168.1.100".parse().unwrap();
let mut tasks = Vec::new();
for _ in 0..100 {
let p = protection.clone();
tasks.push(tokio::spawn(async move { p.check(ip).await }));
}
let mut allowed_count = 0;
for task in tasks {
if task.await.unwrap() {
allowed_count += 1;
}
}
assert!(
allowed_count <= 20,
"DDoS protection failed to throttle concurrent flood. Allowed: {}",
allowed_count
);
}
/// A subgraph that answers 200 with a non-JSON body must not crash the router.
/// The router is expected to surface the raw body under the subgraph's name.
#[tokio::test]
async fn test_security_subgraph_response_validation() {
    let mock_addr = spawn_mock_subgraph(0, "malformed").await;
    let config = RouterConfig {
        port: 0,
        subgraphs: vec![SubgraphConfig {
            name: "bad_subgraph".into(),
            url: format!("http://{}/graphql", mock_addr),
            headers: std::collections::HashMap::new(),
            mtls: None,
        }],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    let result = router
        .execute_scatter_gather(Some("query { test }"), None, None)
        .await;
    let json = result.expect("Router should not crash on malformed subgraph response");
    let obj = json
        .as_object()
        .expect("scatter-gather result should be a JSON object");
    // Require the entry: without it the test would pass without checking anything.
    let val = obj
        .get("bad_subgraph")
        .expect("result should contain an entry for the malformed subgraph");
    assert_eq!(
        val.as_str(),
        Some("Not JSON"),
        "Should return raw body if parsing fails"
    );
}
/// A subgraph returning a 10 000-element array must be passed through intact.
#[tokio::test]
async fn test_security_huge_response_handling() {
    let mock_addr = spawn_mock_subgraph(0, "huge").await;
    let config = RouterConfig {
        port: 0,
        subgraphs: vec![SubgraphConfig {
            name: "huge_subgraph".into(),
            url: format!("http://{}/graphql", mock_addr),
            headers: std::collections::HashMap::new(),
            mtls: None,
        }],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    // `expect` (rather than `assert!(is_ok())`) prints the router error on failure.
    let json = router
        .execute_scatter_gather(Some("query { users }"), None, None)
        .await
        .expect("router should handle a large subgraph response");
    let users = json["huge_subgraph"]["data"]["users"]
        .as_array()
        .expect("huge_subgraph.data.users should be a JSON array");
    assert_eq!(users.len(), 10_000);
    // The mock emits "user-0" .. "user-9999"; check both ends survived unchanged.
    assert_eq!(users.first().and_then(|v| v.as_str()), Some("user-0"));
    assert_eq!(users.last().and_then(|v| v.as_str()), Some("user-9999"));
}
/// A query nested 500 selection sets deep must be handled without error.
/// This probes for unbounded recursion when the router processes the query.
#[tokio::test]
async fn test_security_deeply_nested_query() {
    let depth: usize = 500;
    // Builds "query { level0 { level1 { ... field } ... } }".
    let openers: String = (0..depth).map(|i| format!("level{} {{ ", i)).collect();
    let closers = " }".repeat(depth);
    let query = format!("query {{ {}field{} }}", openers, closers);

    let config = RouterConfig {
        port: 0,
        subgraphs: vec![],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    let result = router
        .execute_scatter_gather(Some(&query), None, None)
        .await;
    if let Err(err) = result {
        panic!(
            "router failed on a {}-level nested query: {:?}",
            depth, err
        );
    }
}
/// One fast and one slow (100 ms) subgraph are queried together.
/// Both results must come back, and the call must wait at least as long as the
/// slow mock's delay but finish well within 4 s.
#[tokio::test]
async fn test_isolate_slow_loris_subgraph() {
    let fast_addr = spawn_mock_subgraph(0, "normal").await;
    let slow_addr = spawn_mock_subgraph(0, "slow").await;
    let config = RouterConfig {
        port: 0,
        subgraphs: vec![
            SubgraphConfig {
                name: "fast".into(),
                url: format!("http://{}/graphql", fast_addr),
                headers: std::collections::HashMap::new(),
                mtls: None,
            },
            SubgraphConfig {
                name: "slow".into(),
                url: format!("http://{}/graphql", slow_addr),
                headers: std::collections::HashMap::new(),
                mtls: None,
            },
        ],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: None,
    };
    let router = GbpRouter::new(config);
    let start = std::time::Instant::now();
    let result = router
        .execute_scatter_gather(Some("query { ... }"), None, None)
        .await;
    let elapsed_ms = start.elapsed().as_millis();

    let json = result.expect("scatter-gather should succeed with one slow subgraph");
    // The slow mock sleeps 100 ms, so a correct gather cannot finish sooner.
    assert!(
        elapsed_ms >= 100,
        "finished in {} ms, before the slow subgraph could have replied",
        elapsed_ms
    );
    // Generous upper bound; it catches a gather that hangs or serialises badly.
    assert!(
        elapsed_ms < 4000,
        "took {} ms; the slow subgraph was not isolated",
        elapsed_ms
    );
    assert!(
        json["fast"]["data"].is_object(),
        "fast subgraph data missing: {}",
        json
    );
    assert!(
        json["slow"]["data"].is_string(),
        "slow subgraph data missing: {}",
        json
    );
}
/// Two failing calls against an always-500 subgraph should open the circuit
/// breaker. The next fail-fast call must then return a "Circuit breaker open"
/// error.
#[tokio::test]
async fn test_circuit_breaker_integration() {
    use crate::circuit_breaker::CircuitBreakerConfig;

    let mock_addr = spawn_mock_subgraph(0, "error").await;
    let breaker = CircuitBreakerConfig {
        failure_threshold: 2,
        recovery_timeout: Duration::from_secs(10),
        half_open_max_requests: 1,
    };
    let router = GbpRouter::new(RouterConfig {
        port: 0,
        subgraphs: vec![SubgraphConfig {
            name: "failing_service".into(),
            url: format!("http://{}/graphql", mock_addr),
            headers: std::collections::HashMap::new(),
            mtls: None,
        }],
        force_gbp: false,
        apq: None,
        request_collapsing: None,
        waf: None,
        query_cost: None,
        disable_introspection: false,
        circuit_breaker: Some(breaker),
    });

    // One call per unit of `failure_threshold`. Their results are irrelevant;
    // they exist only to record failures.
    for _ in 0..2 {
        let _ = router
            .execute_scatter_gather(Some("query { test }"), None, None)
            .await;
    }

    let outcome = router
        .execute_fail_fast(Some("query { test }"), None, None)
        .await;
    let err_str = match outcome {
        Ok(_) => panic!("Expected error when circuit is open"),
        Err(e) => e.to_string(),
    };
    assert!(
        err_str.contains("Circuit breaker open"),
        "Should contain circuit breaker error, got: {}",
        err_str
    );
}
#[tokio::test]
async fn test_security_ddos_zero_values_should_not_crash() {
let ddos_config = DdosConfig {
global_rps: 0,
per_ip_rps: 0,
per_ip_burst: 0,
};
let protection = DdosProtection::new(ddos_config);
let allowed = protection.check("1.2.3.4".parse().unwrap()).await;
assert!(!allowed, "Should block requests when RPS is 0");
}