use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::{Mutex as AsyncMutex, RwLock as AsyncRwLock, Semaphore};
use tracing::{debug, error, info, instrument, warn};
use crate::ast::OperationType;
use crate::performance::{OperationMetrics, PerformanceTracker};
#[derive(Debug, Clone)]
pub struct FederationGatewayConfig {
pub max_concurrent_requests: usize,
pub service_timeout: Duration,
pub enable_query_splitting: bool,
pub enable_cross_service_cache: bool,
pub enable_adaptive_load_balancing: bool,
pub enable_complexity_analysis: bool,
pub circuit_breaker_enabled: bool,
pub circuit_breaker_threshold: usize,
pub circuit_breaker_timeout: Duration,
pub max_retries: usize,
pub retry_delay: Duration,
pub health_check_interval: Duration,
pub enable_performance_monitoring: bool,
}
impl Default for FederationGatewayConfig {
fn default() -> Self {
Self {
max_concurrent_requests: 100,
service_timeout: Duration::from_secs(30),
enable_query_splitting: true,
enable_cross_service_cache: true,
enable_adaptive_load_balancing: true,
enable_complexity_analysis: true,
circuit_breaker_enabled: true,
circuit_breaker_threshold: 5,
circuit_breaker_timeout: Duration::from_secs(60),
max_retries: 3,
retry_delay: Duration::from_millis(100),
health_check_interval: Duration::from_secs(30),
enable_performance_monitoring: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceEndpoint {
pub id: String,
pub name: String,
pub url: String,
pub schema_url: Option<String>,
pub headers: HashMap<String, String>,
pub weight: f64,
pub priority: u8,
pub capabilities: HashSet<String>,
pub supported_operations: HashSet<String>,
}
#[derive(Debug, Clone)]
pub struct ServiceHealth {
pub is_healthy: bool,
pub last_check: SystemTime,
pub response_time: Duration,
pub error_count: usize,
pub success_rate: f64,
pub circuit_breaker_open: bool,
}
#[derive(Debug, Clone)]
pub struct QueryExecutionPlan {
pub query_id: String,
pub total_complexity: usize,
pub estimated_duration: Duration,
pub service_assignments: HashMap<String, ServiceQueryFragment>,
pub dependency_graph: HashMap<String, Vec<String>>,
pub execution_order: Vec<String>,
pub optimization_strategy: OptimizationStrategy,
}
#[derive(Debug, Clone)]
pub struct ServiceQueryFragment {
pub service_id: String,
pub fragment_query: String,
pub expected_fields: HashSet<String>,
pub complexity: usize,
pub dependencies: Vec<String>,
pub cache_key: Option<String>,
pub timeout: Duration,
}
#[derive(Debug, Clone)]
pub enum OptimizationStrategy {
Parallel,
Sequential,
Adaptive,
Pipelined,
Batch,
}
pub struct IntelligentFederationGateway {
config: FederationGatewayConfig,
services: Arc<AsyncRwLock<HashMap<String, ServiceEndpoint>>>,
service_health: Arc<AsyncRwLock<HashMap<String, ServiceHealth>>>,
query_cache: Arc<AsyncRwLock<HashMap<String, CachedQueryResult>>>,
schema_registry: Arc<AsyncRwLock<FederatedSchemaRegistry>>,
performance_tracker: Arc<PerformanceTracker>,
request_semaphore: Arc<Semaphore>,
load_balancer: Arc<AsyncMutex<AdaptiveLoadBalancer>>,
circuit_breakers: Arc<AsyncRwLock<HashMap<String, CircuitBreaker>>>,
query_planner: Arc<AsyncMutex<IntelligentQueryPlanner>>,
}
#[derive(Debug, Clone)]
pub struct CachedQueryResult {
pub result: serde_json::Value,
pub cached_at: SystemTime,
pub ttl: Duration,
pub cache_tags: HashSet<String>,
pub access_count: usize,
}
#[derive(Debug)]
pub struct FederatedSchemaRegistry {
pub schemas: HashMap<String, GraphQLServiceSchema>,
pub unified_schema: Option<String>,
pub schema_version: u64,
pub last_updated: SystemTime,
}
#[derive(Debug, Clone)]
pub struct GraphQLServiceSchema {
pub service_id: String,
pub schema_sdl: String,
pub types: HashSet<String>,
pub queries: HashSet<String>,
pub mutations: HashSet<String>,
pub subscriptions: HashSet<String>,
pub directives: HashSet<String>,
pub last_fetched: SystemTime,
}
#[derive(Debug)]
pub struct AdaptiveLoadBalancer {
pub services: Vec<String>,
pub weights: HashMap<String, f64>,
pub response_times: HashMap<String, VecDeque<Duration>>,
pub error_rates: HashMap<String, f64>,
pub current_loads: HashMap<String, usize>,
pub last_adjustment: SystemTime,
}
#[derive(Debug)]
pub struct CircuitBreaker {
pub service_id: String,
pub state: CircuitBreakerState,
pub failure_count: usize,
pub last_failure: Option<SystemTime>,
pub failure_threshold: usize,
pub timeout: Duration,
}
#[derive(Debug, Clone, PartialEq)]
pub enum CircuitBreakerState {
Closed,
Open,
HalfOpen,
}
#[derive(Debug)]
pub struct IntelligentQueryPlanner {
pub optimization_history: HashMap<String, QueryOptimizationResult>,
pub service_capabilities: HashMap<String, ServiceCapabilities>,
pub query_patterns: HashMap<String, QueryPattern>,
pub performance_baselines: HashMap<String, PerformanceBaseline>,
}
#[derive(Debug, Clone)]
pub struct ServiceCapabilities {
pub max_complexity: usize,
pub supported_features: HashSet<String>,
pub performance_characteristics: PerformanceCharacteristics,
pub rate_limits: Option<RateLimit>,
}
#[derive(Debug, Clone)]
pub struct PerformanceCharacteristics {
pub avg_response_time: Duration,
pub p95_response_time: Duration,
pub throughput: f64,
pub cpu_efficiency: f64,
pub memory_efficiency: f64,
}
#[derive(Debug, Clone)]
pub struct RateLimit {
pub requests_per_second: f64,
pub burst_capacity: usize,
pub current_tokens: f64,
pub last_refill: SystemTime,
}
#[derive(Debug, Clone)]
pub struct QueryOptimizationResult {
pub original_complexity: usize,
pub optimized_complexity: usize,
pub optimization_time: Duration,
pub strategy_used: OptimizationStrategy,
pub performance_improvement: f64,
}
#[derive(Debug, Clone)]
pub struct QueryPattern {
pub pattern_signature: String,
pub typical_complexity: usize,
pub optimal_strategy: OptimizationStrategy,
pub success_rate: f64,
pub avg_execution_time: Duration,
}
#[derive(Debug, Clone)]
pub struct PerformanceBaseline {
pub baseline_time: Duration,
pub baseline_complexity: usize,
pub measured_at: SystemTime,
pub confidence_level: f64,
}
impl IntelligentFederationGateway {
pub fn new(config: FederationGatewayConfig) -> Self {
let request_semaphore = Arc::new(Semaphore::new(config.max_concurrent_requests));
Self {
config,
services: Arc::new(AsyncRwLock::new(HashMap::new())),
service_health: Arc::new(AsyncRwLock::new(HashMap::new())),
query_cache: Arc::new(AsyncRwLock::new(HashMap::new())),
schema_registry: Arc::new(AsyncRwLock::new(FederatedSchemaRegistry {
schemas: HashMap::new(),
unified_schema: None,
schema_version: 0,
last_updated: SystemTime::now(),
})),
performance_tracker: Arc::new(PerformanceTracker::new()),
request_semaphore,
load_balancer: Arc::new(AsyncMutex::new(AdaptiveLoadBalancer {
services: Vec::new(),
weights: HashMap::new(),
response_times: HashMap::new(),
error_rates: HashMap::new(),
current_loads: HashMap::new(),
last_adjustment: SystemTime::now(),
})),
circuit_breakers: Arc::new(AsyncRwLock::new(HashMap::new())),
query_planner: Arc::new(AsyncMutex::new(IntelligentQueryPlanner {
optimization_history: HashMap::new(),
service_capabilities: HashMap::new(),
query_patterns: HashMap::new(),
performance_baselines: HashMap::new(),
})),
}
}
#[instrument(skip(self))]
pub async fn register_service(&self, endpoint: ServiceEndpoint) -> Result<()> {
info!("Registering service: {} at {}", endpoint.name, endpoint.url);
let health = ServiceHealth {
is_healthy: true,
last_check: SystemTime::now(),
response_time: Duration::from_millis(100),
error_count: 0,
success_rate: 1.0,
circuit_breaker_open: false,
};
let circuit_breaker = CircuitBreaker {
service_id: endpoint.id.clone(),
state: CircuitBreakerState::Closed,
failure_count: 0,
last_failure: None,
failure_threshold: self.config.circuit_breaker_threshold,
timeout: self.config.circuit_breaker_timeout,
};
{
let mut services = self.services.write().await;
services.insert(endpoint.id.clone(), endpoint.clone());
}
{
let mut health_map = self.service_health.write().await;
health_map.insert(endpoint.id.clone(), health);
}
{
let mut breakers = self.circuit_breakers.write().await;
breakers.insert(endpoint.id.clone(), circuit_breaker);
}
if let Err(e) = self.fetch_service_schema(&endpoint.id).await {
warn!("Failed to fetch schema for service {}: {}", endpoint.id, e);
}
{
let mut lb = self.load_balancer.lock().await;
lb.services.push(endpoint.id.clone());
lb.weights.insert(endpoint.id.clone(), endpoint.weight);
lb.response_times
.insert(endpoint.id.clone(), VecDeque::new());
lb.error_rates.insert(endpoint.id.clone(), 0.0);
lb.current_loads.insert(endpoint.id.clone(), 0);
}
info!("Successfully registered service: {}", endpoint.id);
Ok(())
}
#[instrument(skip(self, query))]
pub async fn execute_federated_query(
&self,
query: &str,
variables: Option<serde_json::Value>,
operation_name: Option<String>,
) -> Result<serde_json::Value> {
let start_time = Instant::now();
let query_id = format!("query_{}", uuid::Uuid::new_v4());
info!("Executing federated query: {}", query_id);
if self.config.enable_cross_service_cache {
let cache_key = self.generate_cache_key(query, &variables);
if let Some(cached_result) = self.get_cached_result(&cache_key).await? {
debug!("Cache hit for query: {}", query_id);
return Ok(cached_result.result);
}
}
let _permit = self
.request_semaphore
.acquire()
.await
.map_err(|e| anyhow!("Failed to acquire request permit: {}", e))?;
let execution_plan = self.plan_query_execution(query, &query_id).await?;
info!(
"Query execution plan created with {} service fragments",
execution_plan.service_assignments.len()
);
let result = match execution_plan.optimization_strategy {
OptimizationStrategy::Parallel => self.execute_parallel_query(&execution_plan).await?,
OptimizationStrategy::Sequential => {
self.execute_sequential_query(&execution_plan).await?
}
OptimizationStrategy::Adaptive => self.execute_adaptive_query(&execution_plan).await?,
OptimizationStrategy::Pipelined => {
self.execute_pipelined_query(&execution_plan).await?
}
OptimizationStrategy::Batch => self.execute_batch_query(&execution_plan).await?,
};
if self.config.enable_cross_service_cache {
let cache_key = self.generate_cache_key(query, &variables);
self.cache_query_result(&cache_key, &result, &execution_plan)
.await?;
}
let execution_time = start_time.elapsed();
let query_hash = {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
query.hash(&mut hasher);
hasher.finish()
};
self.performance_tracker.record_operation(OperationMetrics {
operation_name: Some("federated_query".to_string()),
operation_type: OperationType::Query,
query_hash,
execution_time,
parsing_time: Duration::from_millis(0),
validation_time: Duration::from_millis(0),
planning_time: Duration::from_millis(0),
field_count: execution_plan.service_assignments.len(),
depth: 1,
complexity_score: execution_plan.total_complexity,
cache_hit: false,
error_count: 0,
timestamp: SystemTime::now(),
client_info: Default::default(),
});
info!(
"Federated query completed in {:?}: {}",
execution_time, query_id
);
Ok(result)
}
async fn plan_query_execution(
&self,
query: &str,
query_id: &str,
) -> Result<QueryExecutionPlan> {
let _planner = self.query_planner.lock().await;
let complexity = self.calculate_query_complexity(query).await?;
let service_assignments = self.assign_query_fragments(query).await?;
let dependency_graph = self.build_dependency_graph(&service_assignments).await?;
let execution_order = self.topological_sort(&dependency_graph)?;
let optimization_strategy = self
.choose_optimization_strategy(&service_assignments, complexity)
.await?;
Ok(QueryExecutionPlan {
query_id: query_id.to_string(),
total_complexity: complexity,
estimated_duration: Duration::from_millis(100 * complexity as u64),
service_assignments,
dependency_graph,
execution_order,
optimization_strategy,
})
}
async fn execute_parallel_query(&self, plan: &QueryExecutionPlan) -> Result<serde_json::Value> {
debug!(
"Executing parallel query with {} fragments",
plan.service_assignments.len()
);
let mut handles = Vec::new();
for (fragment_id, fragment) in &plan.service_assignments {
let fragment_clone = fragment.clone();
let fragment_id_clone = fragment_id.clone();
let self_clone = self.clone_for_async().await;
let handle = tokio::spawn(async move {
self_clone
.execute_service_fragment(&fragment_id_clone, &fragment_clone)
.await
});
handles.push((fragment_id.clone(), handle));
}
let mut results = HashMap::new();
for (fragment_id, handle) in handles {
match handle.await {
Ok(Ok(result)) => {
results.insert(fragment_id, result);
}
Ok(Err(e)) => {
error!("Fragment execution failed for {}: {}", fragment_id, e);
return Err(e);
}
Err(e) => {
error!("Fragment join failed for {}: {}", fragment_id, e);
return Err(anyhow!("Fragment execution error: {}", e));
}
}
}
self.merge_fragment_results(results).await
}
async fn execute_sequential_query(
&self,
plan: &QueryExecutionPlan,
) -> Result<serde_json::Value> {
debug!(
"Executing sequential query with {} fragments",
plan.execution_order.len()
);
let mut results = HashMap::new();
for fragment_id in &plan.execution_order {
if let Some(fragment) = plan.service_assignments.get(fragment_id) {
let result = self.execute_service_fragment(fragment_id, fragment).await?;
results.insert(fragment_id.clone(), result);
}
}
self.merge_fragment_results(results).await
}
async fn execute_adaptive_query(&self, plan: &QueryExecutionPlan) -> Result<serde_json::Value> {
let current_load = self.get_current_system_load().await?;
if current_load < 0.5 {
self.execute_parallel_query(plan).await
} else {
self.execute_sequential_query(plan).await
}
}
async fn execute_pipelined_query(
&self,
plan: &QueryExecutionPlan,
) -> Result<serde_json::Value> {
debug!("Executing pipelined query");
self.execute_parallel_query(plan).await
}
async fn execute_batch_query(&self, plan: &QueryExecutionPlan) -> Result<serde_json::Value> {
debug!(
"Executing batch query with {} fragments",
plan.service_assignments.len()
);
let mut service_groups: HashMap<String, Vec<(String, &ServiceQueryFragment)>> =
HashMap::new();
for (fragment_id, fragment) in &plan.service_assignments {
service_groups
.entry(fragment.service_id.clone())
.or_default()
.push((fragment_id.clone(), fragment));
}
let mut all_results = HashMap::new();
for (service_id, fragments) in service_groups {
let batch_results = self.execute_service_batch(&service_id, fragments).await?;
all_results.extend(batch_results);
}
self.merge_fragment_results(all_results).await
}
async fn execute_service_fragment(
&self,
_fragment_id: &str,
fragment: &ServiceQueryFragment,
) -> Result<serde_json::Value> {
if self.is_circuit_breaker_open(&fragment.service_id).await? {
return Err(anyhow!(
"Circuit breaker open for service: {}",
fragment.service_id
));
}
if let Some(cache_key) = &fragment.cache_key {
if let Some(cached) = self.get_cached_result(cache_key).await? {
return Ok(cached.result);
}
}
let timeout_duration = fragment.timeout;
tokio::time::timeout(timeout_duration, async {
self.execute_service_request(&fragment.service_id, &fragment.fragment_query)
.await
})
.await
.map_err(|_| anyhow!("Service request timeout for {}", fragment.service_id))?
}
async fn execute_service_request(
&self,
service_id: &str,
query: &str,
) -> Result<serde_json::Value> {
let service = {
let services = self.services.read().await;
services
.get(service_id)
.cloned()
.ok_or_else(|| anyhow!("Service not found: {}", service_id))?
};
let client = reqwest::Client::builder()
.timeout(self.config.service_timeout)
.build()?;
let request_body = serde_json::json!({
"query": query
});
let response = client
.post(&service.url)
.headers(self.build_request_headers(&service.headers)?)
.json(&request_body)
.send()
.await?;
if response.status().is_success() {
let result: serde_json::Value = response.json().await?;
self.record_service_success(service_id).await?;
Ok(result)
} else {
self.record_service_failure(service_id).await?;
Err(anyhow!(
"Service request failed with status: {}",
response.status()
))
}
}
async fn execute_service_batch(
&self,
service_id: &str,
fragments: Vec<(String, &ServiceQueryFragment)>,
) -> Result<HashMap<String, serde_json::Value>> {
let service = {
let services = self.services.read().await;
services
.get(service_id)
.cloned()
.ok_or_else(|| anyhow!("Service not found: {}", service_id))?
};
let client = reqwest::Client::builder()
.timeout(self.config.service_timeout)
.build()?;
let mut batch_results = HashMap::new();
for (fragment_id, fragment) in fragments {
let request_body = serde_json::json!({
"query": fragment.fragment_query
});
let response = client
.post(&service.url)
.headers(self.build_request_headers(&service.headers)?)
.json(&request_body)
.send()
.await?;
if response.status().is_success() {
let result: serde_json::Value = response.json().await?;
batch_results.insert(fragment_id, result);
self.record_service_success(service_id).await?;
} else {
self.record_service_failure(service_id).await?;
return Err(anyhow!(
"Batch service request failed for fragment {} with status: {}",
fragment_id,
response.status()
));
}
}
Ok(batch_results)
}
async fn clone_for_async(&self) -> Self {
Self {
config: self.config.clone(),
services: Arc::clone(&self.services),
service_health: Arc::clone(&self.service_health),
query_cache: Arc::clone(&self.query_cache),
schema_registry: Arc::clone(&self.schema_registry),
performance_tracker: Arc::clone(&self.performance_tracker),
request_semaphore: Arc::clone(&self.request_semaphore),
load_balancer: Arc::clone(&self.load_balancer),
circuit_breakers: Arc::clone(&self.circuit_breakers),
query_planner: Arc::clone(&self.query_planner),
}
}
async fn fetch_service_schema(&self, _service_id: &str) -> Result<()> {
Ok(())
}
fn generate_cache_key(&self, query: &str, variables: &Option<serde_json::Value>) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
let normalized_query = query
.chars()
.filter(|c| !c.is_whitespace())
.collect::<String>()
.to_lowercase();
normalized_query.hash(&mut hasher);
if let Some(vars) = variables {
vars.to_string().hash(&mut hasher);
}
let hash = hasher.finish();
format!("gql_cache_{hash:x}")
}
async fn get_cached_result(&self, _cache_key: &str) -> Result<Option<CachedQueryResult>> {
Ok(None)
}
async fn cache_query_result(
&self,
_cache_key: &str,
_result: &serde_json::Value,
_plan: &QueryExecutionPlan,
) -> Result<()> {
Ok(())
}
async fn calculate_query_complexity(&self, query: &str) -> Result<usize> {
let mut complexity = 0;
complexity += query.matches('{').count() * 2;
complexity += query.matches("query").count() * 5;
complexity += query.matches("mutation").count() * 10;
complexity += query.matches("subscription").count() * 15;
let max_nesting = query
.chars()
.fold((0u32, 0u32), |(max_depth, current_depth), c| match c {
'{' => (max_depth.max(current_depth + 1), current_depth + 1),
'}' => (max_depth, current_depth.saturating_sub(1)),
_ => (max_depth, current_depth),
})
.0;
complexity += (max_nesting as usize) * (max_nesting as usize);
complexity += query.matches('(').count();
complexity += query.matches('$').count() * 2;
complexity += query.matches("fragment").count() * 3;
complexity += query.matches("...").count() * 2;
complexity += query.matches("join").count() * 8;
complexity += query.matches("@").count();
Ok(complexity.max(1))
}
async fn assign_query_fragments(
&self,
_query: &str,
) -> Result<HashMap<String, ServiceQueryFragment>> {
Ok(HashMap::new())
}
async fn build_dependency_graph(
&self,
_assignments: &HashMap<String, ServiceQueryFragment>,
) -> Result<HashMap<String, Vec<String>>> {
Ok(HashMap::new())
}
fn topological_sort(&self, graph: &HashMap<String, Vec<String>>) -> Result<Vec<String>> {
let mut in_degree: HashMap<String, usize> = HashMap::new();
let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();
for (node, neighbors) in graph {
adj_list.entry(node.clone()).or_default();
in_degree.entry(node.clone()).or_insert(0);
for neighbor in neighbors {
adj_list.entry(neighbor.clone()).or_default();
adj_list
.get_mut(node)
.expect("node should exist in adj_list after entry insert")
.push(neighbor.clone());
*in_degree.entry(neighbor.clone()).or_insert(0) += 1;
}
}
let mut queue: VecDeque<String> = in_degree
.iter()
.filter(|&(_, °ree)| degree == 0)
.map(|(node, _)| node.clone())
.collect();
let mut result = Vec::new();
while let Some(current) = queue.pop_front() {
result.push(current.clone());
if let Some(neighbors) = adj_list.get(¤t) {
for neighbor in neighbors {
if let Some(degree) = in_degree.get_mut(neighbor) {
*degree -= 1;
if *degree == 0 {
queue.push_back(neighbor.clone());
}
}
}
}
}
if result.len() != in_degree.len() {
return Err(anyhow!("Cycle detected in dependency graph"));
}
Ok(result)
}
async fn choose_optimization_strategy(
&self,
assignments: &HashMap<String, ServiceQueryFragment>,
complexity: usize,
) -> Result<OptimizationStrategy> {
let service_count = assignments.len();
let current_load = self.get_current_system_load().await?;
let has_dependencies = assignments
.values()
.any(|fragment| !fragment.dependencies.is_empty());
let cross_service_joins = assignments
.values()
.map(|fragment| fragment.dependencies.len())
.sum::<usize>();
match (complexity, service_count, current_load, has_dependencies) {
(c, s, l, _) if c > 500 && s > 5 && l > 0.8 => {
info!("Choosing sequential strategy due to high complexity and load");
Ok(OptimizationStrategy::Sequential)
}
(_, _, _, true) if cross_service_joins > 3 => {
info!("Choosing sequential strategy due to complex dependencies");
Ok(OptimizationStrategy::Sequential)
}
(c, s, l, _) if c > 200 && s > 2 && l > 0.5 => {
info!("Choosing batch strategy for medium complexity");
Ok(OptimizationStrategy::Batch)
}
(c, s, l, false) if c <= 200 || s <= 2 || l <= 0.3 => {
info!("Choosing parallel strategy for simple queries or low load");
Ok(OptimizationStrategy::Parallel)
}
_ => {
debug!("Using default parallel strategy");
Ok(OptimizationStrategy::Parallel)
}
}
}
async fn merge_fragment_results(
&self,
results: HashMap<String, serde_json::Value>,
) -> Result<serde_json::Value> {
if results.is_empty() {
return Ok(serde_json::json!({"data": null}));
}
if results.len() == 1 {
return Ok(results
.into_values()
.next()
.expect("results should not be empty when len == 1"));
}
let mut merged = serde_json::json!({
"data": {},
"errors": []
});
let mut data_map = serde_json::Map::new();
let mut all_errors = Vec::new();
for (service_id, result) in results {
if let Some(data) = result.get("data") {
if let Some(data_obj) = data.as_object() {
for (key, value) in data_obj {
if let Some(existing) = data_map.get(key) {
match (existing, value) {
(
serde_json::Value::Object(existing_obj),
serde_json::Value::Object(new_obj),
) => {
let mut merged_obj = existing_obj.clone();
for (k, v) in new_obj {
merged_obj.insert(k.clone(), v.clone());
}
data_map
.insert(key.clone(), serde_json::Value::Object(merged_obj));
}
(_, serde_json::Value::Null) => {
}
(serde_json::Value::Null, _) => {
data_map.insert(key.clone(), value.clone());
}
_ => {
data_map.insert(key.clone(), value.clone());
}
}
} else {
data_map.insert(key.clone(), value.clone());
}
}
}
}
if let Some(errors) = result.get("errors") {
if let Some(error_array) = errors.as_array() {
for error in error_array {
let mut enriched_error = error.clone();
if let Some(error_obj) = enriched_error.as_object_mut() {
error_obj.insert(
"service".to_string(),
serde_json::Value::String(service_id.clone()),
);
}
all_errors.push(enriched_error);
}
}
}
}
merged["data"] = serde_json::Value::Object(data_map);
if !all_errors.is_empty() {
merged["errors"] = serde_json::Value::Array(all_errors);
}
Ok(merged)
}
async fn get_current_system_load(&self) -> Result<f64> {
let performance_tracker = Arc::clone(&self.performance_tracker);
let mut load_factors = Vec::new();
let active_requests = self.request_semaphore.available_permits();
let max_requests = self.config.max_concurrent_requests;
let request_load = 1.0 - (active_requests as f64 / max_requests as f64);
load_factors.push(request_load * 0.4);
let services = self.services.read().await;
let total_services = services.len() as f64;
if total_services > 0.0 {
let service_health = self.service_health.read().await;
let healthy_services = service_health
.values()
.filter(|health| health.success_rate > 0.8)
.count() as f64;
let health_load = 1.0 - (healthy_services / total_services);
load_factors.push(health_load * 0.3); }
if let Ok(stats) = performance_tracker.get_stats() {
let avg_response_time = stats.avg_execution_time.as_millis() as f64;
let response_time_load = (avg_response_time / 1000.0).min(1.0);
load_factors.push(response_time_load * 0.2); }
let circuit_breakers = self.circuit_breakers.read().await;
if !circuit_breakers.is_empty() {
let open_breakers = circuit_breakers
.values()
.filter(|breaker| breaker.state == CircuitBreakerState::Open)
.count() as f64;
let total_breakers = circuit_breakers.len() as f64;
let breaker_load = open_breakers / total_breakers;
load_factors.push(breaker_load * 0.1); }
let total_load = load_factors.iter().sum::<f64>();
Ok(total_load.clamp(0.0, 1.0))
}
async fn is_circuit_breaker_open(&self, _service_id: &str) -> Result<bool> {
Ok(false)
}
fn build_request_headers(
&self,
_headers: &HashMap<String, String>,
) -> Result<reqwest::header::HeaderMap> {
Ok(reqwest::header::HeaderMap::new())
}
async fn record_service_success(&self, _service_id: &str) -> Result<()> {
Ok(())
}
async fn record_service_failure(&self, _service_id: &str) -> Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_gateway_creation() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
assert_eq!(gateway.config.max_concurrent_requests, 100);
}
#[tokio::test]
async fn test_service_registration() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let endpoint = ServiceEndpoint {
id: "test-service".to_string(),
name: "Test Service".to_string(),
url: "http://localhost:4000/graphql".to_string(),
schema_url: None,
headers: HashMap::new(),
weight: 1.0,
priority: 1,
capabilities: HashSet::new(),
supported_operations: HashSet::new(),
};
let result = gateway.register_service(endpoint).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_query_complexity_calculation() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let simple_query = "query { hello }";
let complexity = gateway
.calculate_query_complexity(simple_query)
.await
.expect("should succeed");
assert!(complexity > 0);
let complex_query = r#"
query GetUserPosts($userId: ID!) {
user(id: $userId) {
id
name
posts {
id
title
comments {
id
author {
name
}
}
}
}
}
"#;
let complex_complexity = gateway
.calculate_query_complexity(complex_query)
.await
.expect("should succeed");
assert!(complex_complexity > complexity);
}
#[tokio::test]
async fn test_topological_sort() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let mut graph = HashMap::new();
graph.insert("A".to_string(), vec!["B".to_string(), "C".to_string()]);
graph.insert("B".to_string(), vec!["D".to_string()]);
graph.insert("C".to_string(), vec!["D".to_string()]);
graph.insert("D".to_string(), vec![]);
let sorted = gateway.topological_sort(&graph).expect("should succeed");
assert_eq!(sorted.len(), 4);
let a_pos = sorted
.iter()
.position(|x| x == "A")
.expect("should succeed");
let d_pos = sorted
.iter()
.position(|x| x == "D")
.expect("should succeed");
assert!(a_pos < d_pos); }
#[tokio::test]
async fn test_cache_key_generation() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let query = "query { hello }";
let variables = Some(serde_json::json!({"test": "value"}));
let key1 = gateway.generate_cache_key(query, &variables);
let key2 = gateway.generate_cache_key(query, &variables);
assert_eq!(key1, key2);
assert!(key1.starts_with("gql_cache_"));
let different_variables = Some(serde_json::json!({"test": "different"}));
let key3 = gateway.generate_cache_key(query, &different_variables);
assert_ne!(key1, key3);
}
#[tokio::test]
async fn test_optimization_strategy_selection() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let assignments = HashMap::new();
let strategy = gateway
.choose_optimization_strategy(&assignments, 50)
.await
.expect("should succeed");
assert!(matches!(strategy, OptimizationStrategy::Parallel));
let mut high_complexity_assignments = HashMap::new();
for i in 0..7 {
high_complexity_assignments.insert(
format!("service{i}"),
ServiceQueryFragment {
service_id: format!("service{i}"),
fragment_query: "test query".to_string(),
expected_fields: HashSet::new(),
complexity: 100,
dependencies: Vec::new(),
cache_key: None,
timeout: Duration::from_secs(30),
},
);
}
let strategy = gateway
.choose_optimization_strategy(&high_complexity_assignments, 1000)
.await
.expect("should succeed");
assert!(matches!(
strategy,
OptimizationStrategy::Sequential
| OptimizationStrategy::Batch
| OptimizationStrategy::Parallel
));
}
#[tokio::test]
async fn test_result_merging() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let mut results = HashMap::new();
results.insert(
"service1".to_string(),
serde_json::json!({
"data": {"field1": "value1"},
"errors": []
}),
);
results.insert(
"service2".to_string(),
serde_json::json!({
"data": {"field2": "value2"},
"errors": []
}),
);
let merged = gateway
.merge_fragment_results(results)
.await
.expect("should succeed");
assert!(merged.get("data").is_some());
let data = merged["data"].as_object().expect("should succeed");
assert_eq!(data["field1"], "value1");
assert_eq!(data["field2"], "value2");
}
#[tokio::test]
async fn test_system_load_calculation() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let load = gateway
.get_current_system_load()
.await
.expect("should succeed");
assert!(load >= 0.0);
assert!(load <= 1.0);
}
#[tokio::test]
async fn test_topological_sort_cycle_detection() {
let config = FederationGatewayConfig::default();
let gateway = IntelligentFederationGateway::new(config);
let mut graph = HashMap::new();
graph.insert("A".to_string(), vec!["B".to_string()]);
graph.insert("B".to_string(), vec!["C".to_string()]);
graph.insert("C".to_string(), vec!["A".to_string()]);
let result = gateway.topological_sort(&graph);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Cycle detected"));
}
}