use anyhow::{Context, Result};
use futures_util::future;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use super::config::FederationConfig;
use super::schema_stitcher::SchemaStitcher;
use crate::ast::{Document, Field, OperationDefinition, Selection, SelectionSet};
use crate::types::Schema;
/// Plans federated GraphQL queries: splits a parsed query into per-service
/// steps, orders them by dependency, and executes them against remote
/// endpoints before merging the results.
pub struct QueryPlanner {
// Kept for future schema-aware planning; routing currently relies on
// field-name prefixes only (see `determine_field_service`).
#[allow(dead_code)]
schema_stitcher: Arc<SchemaStitcher>,
// Endpoint registry (ids, URLs, namespaces, auth headers, timeouts) used for
// both field routing and step execution.
config: FederationConfig,
}
/// An executable federation plan: the per-service steps plus scheduling
/// metadata derived from them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
/// Steps to run, dependency-ordered when `optimize_step_order` succeeded.
pub steps: Vec<QueryStep>,
/// Heuristic cost (1.0 per step + 0.1 per field + 0.5 per dependency edge).
pub estimated_cost: f64,
/// True when no step depends on another, so all steps may run concurrently.
pub can_execute_parallel: bool,
}
/// One sub-query to be executed against a single federated endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryStep {
/// Id of the endpoint (from `FederationConfig`) that serves this fragment.
pub endpoint_id: String,
/// GraphQL query text sent to the endpoint, e.g. `{ fieldA fieldB }`.
pub query_fragment: String,
/// Indices of steps that must complete first. NOTE(review): these index the
/// step list as originally built; they are not rewritten after reordering.
pub dependencies: Vec<usize>,
/// Names of the fields this step is expected to resolve.
pub result_variables: Vec<String>,
}
impl QueryPlanner {
/// Creates a planner over the given schema stitcher and federation config.
pub fn new(schema_stitcher: Arc<SchemaStitcher>, config: FederationConfig) -> Self {
    Self { schema_stitcher, config }
}
pub async fn plan_query(&self, query: &Document, merged_schema: &Schema) -> Result<QueryPlan> {
let mut steps = Vec::new();
for definition in &query.definitions {
if let crate::ast::Definition::Operation(op) = definition {
let service_steps = self.analyze_operation(op, merged_schema).await?;
steps.extend(service_steps);
}
}
self.optimize_step_order(&mut steps);
let estimated_cost = self.calculate_execution_cost(&steps);
let can_execute_parallel = self.can_parallelize(&steps);
Ok(QueryPlan {
steps,
estimated_cost,
can_execute_parallel,
})
}
/// Turns a single operation definition into the query steps needed to run it.
///
/// Queries, mutations, and subscriptions are all planned identically today:
/// the operation's top-level selection set is analyzed field by field. The
/// three previously duplicated match arms are collapsed into one or-pattern;
/// the `match` stays exhaustive so adding a new `OperationType` variant
/// forces this routing decision to be revisited at compile time.
async fn analyze_operation(
    &self,
    operation: &OperationDefinition,
    schema: &Schema,
) -> Result<Vec<QueryStep>> {
    match operation.operation_type {
        crate::ast::OperationType::Query
        | crate::ast::OperationType::Mutation
        | crate::ast::OperationType::Subscription => {
            self.analyze_selection_set(&operation.selection_set, schema)
                .await
        }
    }
}
async fn analyze_selection_set(
&self,
selection_set: &SelectionSet,
schema: &Schema,
) -> Result<Vec<QueryStep>> {
let mut steps = Vec::new();
let mut service_queries: HashMap<String, Vec<String>> = HashMap::new();
for selection in &selection_set.selections {
match selection {
Selection::Field(field) => {
let service_id = self.determine_field_service(field, schema).await?;
service_queries
.entry(service_id.clone())
.or_default()
.push(field.name.clone());
if let Some(nested_selection_set) = &field.selection_set {
let nested_steps =
Box::pin(self.analyze_selection_set(nested_selection_set, schema))
.await?;
steps.extend(nested_steps);
}
}
Selection::InlineFragment(fragment) => {
let nested_steps =
Box::pin(self.analyze_selection_set(&fragment.selection_set, schema))
.await?;
steps.extend(nested_steps);
}
Selection::FragmentSpread(_) => {
}
}
}
for (service_id, fields) in service_queries {
let query_fragment = self.build_query_fragment(&fields);
steps.push(QueryStep {
endpoint_id: service_id,
query_fragment,
dependencies: Vec::new(),
result_variables: fields,
});
}
Ok(steps)
}
/// Resolves which federated service owns `field`, by namespace prefix.
///
/// A field named `<namespace>_rest` belongs to the endpoint whose namespace
/// (or id, when no namespace is configured) matches the prefix. Falls back
/// to `"local"` when no endpoint claims the field. The schema argument is
/// not yet consulted.
async fn determine_field_service(&self, field: &Field, _schema: &Schema) -> Result<String> {
    for endpoint in &self.config.endpoints {
        let namespace = endpoint.namespace.as_deref().unwrap_or(&endpoint.id);
        // strip_prefix + '_' check is equivalent to the old
        // `starts_with(format!("{namespace}_"))` but allocates nothing,
        // which matters since this runs per endpoint per field.
        let owned = field
            .name
            .strip_prefix(namespace)
            .map_or(false, |rest| rest.starts_with('_'));
        if owned {
            return Ok(endpoint.id.clone());
        }
    }
    Ok("local".to_string())
}
/// Renders a flat selection set, e.g. `["a", "b"]` -> `"{ a b }"`.
fn build_query_fragment(&self, fields: &[String]) -> String {
    format!("{{ {} }}", fields.join(" "))
}
/// Reorders `steps` so every step appears after all steps it depends on.
///
/// Repeatedly scans the remaining steps and emits the first one whose
/// `dependencies` (indices into the ORIGINAL `steps` order) have all been
/// emitted already — effectively Kahn's algorithm with a linear scan per
/// pick (O(n^2) worst case, fine for small plans).
///
/// If a full pass makes no progress (a dependency cycle, or a dependency
/// index that never resolves), the loop bails; the length check at the end
/// then rejects the partial result, so the original order is kept as a
/// fallback.
///
/// NOTE(review): after a successful reorder, each step's `dependencies`
/// still hold pre-sort indices and no longer match the new positions.
/// Current callers only test `dependencies.is_empty()`, so this is latent —
/// confirm before consuming the indices positionally.
fn optimize_step_order(&self, steps: &mut Vec<QueryStep>) {
let mut sorted_steps = Vec::new();
// (original_index, step) pairs still waiting to be placed.
let mut remaining_steps: VecDeque<_> = steps.iter().enumerate().collect();
// Original indices of steps already emitted into `sorted_steps`.
let mut resolved_indices = HashSet::new();
while !remaining_steps.is_empty() {
let mut made_progress = false;
for i in 0..remaining_steps.len() {
let (idx, step) = remaining_steps[i];
let dependencies_resolved = step
.dependencies
.iter()
.all(|dep_idx| resolved_indices.contains(dep_idx));
if dependencies_resolved {
sorted_steps.push(step.clone());
resolved_indices.insert(idx);
remaining_steps.remove(i);
made_progress = true;
break;
}
}
if !made_progress {
tracing::warn!("Unable to resolve all step dependencies, using original order");
break;
}
}
// Only commit when every step was placed; otherwise keep the input order.
if sorted_steps.len() == steps.len() {
*steps = sorted_steps;
}
}
/// Heuristic plan cost: 1.0 per step, plus 0.1 per selected field and 0.5
/// per dependency edge. The fold accumulates in the same left-to-right
/// order as the original loop, so the f64 result is bit-identical.
fn calculate_execution_cost(&self, steps: &[QueryStep]) -> f64 {
    steps.iter().fold(0.0, |cost, step| {
        cost + 1.0
            + step.result_variables.len() as f64 * 0.1
            + step.dependencies.len() as f64 * 0.5
    })
}
/// A plan may fan out concurrently only when no step waits on another's
/// output, i.e. every step's dependency list is empty.
fn can_parallelize(&self, steps: &[QueryStep]) -> bool {
    steps.iter().all(|step| step.dependencies.is_empty())
}
/// Executes every step of `plan` and merges the per-step `data` objects
/// into one JSON object.
///
/// When the plan has no inter-step dependencies, all steps are issued
/// concurrently via `try_join_all` (the first failure cancels the rest);
/// otherwise they run sequentially in plan order. Results are keyed by the
/// step's position in `plan.steps`.
pub async fn execute_plan(&self, plan: &QueryPlan) -> Result<serde_json::Value> {
let mut results: HashMap<usize, serde_json::Value> = HashMap::new();
if plan.can_execute_parallel {
let futures: Vec<_> = plan
.steps
.iter()
.enumerate()
.map(|(idx, step)| {
// Clone so each future owns its step; `self` stays borrowed,
// which is fine because the futures are awaited before return.
let step = step.clone();
async move {
let result = self.execute_step(&step).await?;
Ok::<(usize, serde_json::Value), anyhow::Error>((idx, result))
}
})
.collect();
let results_vec = future::try_join_all(futures).await?;
results.extend(results_vec);
} else {
// Sequential fallback. NOTE(review): dependency outputs are not yet
// injected into later steps' variables — steps merely run in order.
for (idx, step) in plan.steps.iter().enumerate() {
let result = self.execute_step(step).await?;
results.insert(idx, result);
}
}
self.merge_execution_results(&results)
}
/// Executes one plan step against its remote endpoint and returns the
/// response's `data` payload (or JSON null when absent).
///
/// # Errors
/// Fails when the endpoint id is unknown, the HTTP request fails or times
/// out, the response status is non-2xx, the body is not valid JSON, or the
/// GraphQL response carries a non-null `errors` value.
async fn execute_step(&self, step: &QueryStep) -> Result<serde_json::Value> {
    let endpoint = self
        .config
        .endpoints
        .iter()
        .find(|ep| ep.id == step.endpoint_id)
        .ok_or_else(|| anyhow::anyhow!("Endpoint not found: {}", step.endpoint_id))?;
    // One process-wide Client: every reqwest::Client owns its own connection
    // pool, so constructing a fresh one per step (as before) defeated HTTP
    // keep-alive and TLS session reuse on every request.
    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    let client = CLIENT.get_or_init(reqwest::Client::new);
    let mut request = client.post(&endpoint.url).json(&serde_json::json!({
        "query": step.query_fragment,
        "variables": {}
    }));
    if let Some(auth) = &endpoint.auth_header {
        request = request.header("Authorization", auth);
    }
    let response = request
        // Per-request timeout from the endpoint's configuration.
        .timeout(std::time::Duration::from_secs(endpoint.timeout_secs))
        .send()
        .await
        .context("Failed to execute federated query step")?;
    if !response.status().is_success() {
        return Err(anyhow::anyhow!(
            "Query step failed with status: {}",
            response.status()
        ));
    }
    let result: serde_json::Value = response
        .json()
        .await
        .context("Failed to parse query step response")?;
    // Some servers emit `"errors": null`; only a real value is a failure.
    if let Some(errors) = result.get("errors").filter(|e| !e.is_null()) {
        return Err(anyhow::anyhow!(
            "GraphQL errors in federated step: {}",
            serde_json::to_string_pretty(errors)?
        ));
    }
    Ok(result
        .get("data")
        .unwrap_or(&serde_json::Value::Null)
        .clone())
}
/// Shallow-merges the per-step `data` objects into one JSON object.
///
/// Merging now walks step indices in ascending order so that, when two
/// steps return the same top-level key, the LATER step deterministically
/// wins. The previous direct `HashMap` iteration made the winner depend on
/// hash order, i.e. nondeterministic across runs.
///
/// Non-object step results are skipped; nested objects are overwritten
/// wholesale (no deep merge).
fn merge_execution_results(
    &self,
    results: &HashMap<usize, serde_json::Value>,
) -> Result<serde_json::Value> {
    let mut indices: Vec<usize> = results.keys().copied().collect();
    indices.sort_unstable();
    let mut merged = serde_json::Map::new();
    for idx in indices {
        if let Some(serde_json::Value::Object(obj)) = results.get(&idx) {
            for (key, value) in obj {
                merged.insert(key.clone(), value.clone());
            }
        }
    }
    Ok(serde_json::Value::Object(merged))
}
}