use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap, HashSet};
/// Performance and cost profile of a federation source, consulted by the
/// planner when routing fields and estimating plans.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceStats {
/// Average response latency in milliseconds. Used as the latency of every
/// sub-plan sent to this source and, when `prefer_low_latency` is enabled,
/// to choose between sources that can serve the same field.
pub avg_latency_ms: u64,
/// Cost charged per requested field (multiplied by the number of fields in
/// a sub-plan).
pub cost_per_field: f64,
/// Whether several fields may be fetched from this source in one sub-plan.
pub supports_batching: bool,
/// Maximum number of fields per batched sub-plan; `0` means no limit.
pub max_batch_size: usize,
}
impl Default for SourceStats {
fn default() -> Self {
Self {
avg_latency_ms: 50,
cost_per_field: 1.0,
supports_batching: true,
max_batch_size: 0,
}
}
}
/// A data source taking part in federation, described by the types and
/// qualified fields it can serve plus its performance profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationSource {
/// Unique identifier; the planner keys its source map by this value, so
/// registering another source with the same id replaces the earlier one.
pub id: String,
/// Endpoint URL of the source. NOTE(review): not read by any of the planning
/// logic visible here — presumably consumed by the executor.
pub url: String,
/// Types owned as a whole: any field of these types can be routed here when
/// no source claims the exact qualified field.
pub owned_types: Vec<String>,
/// Qualified `"Type.field"` names owned explicitly; an exact match here takes
/// precedence over whole-type ownership during routing.
pub owned_fields: Vec<String>,
/// Latency/cost/batching profile used for routing and estimation.
pub stats: SourceStats,
}
impl FederationSource {
    /// Creates a source with the given id and endpoint that owns nothing yet
    /// and uses `SourceStats::default()`.
    pub fn new(id: impl Into<String>, url: impl Into<String>) -> Self {
        let (id, url) = (id.into(), url.into());
        Self {
            id,
            url,
            owned_types: Vec::new(),
            owned_fields: Vec::new(),
            stats: SourceStats::default(),
        }
    }

    /// Builder step: declares ownership of every field of `type_name`.
    pub fn with_type(self, type_name: impl Into<String>) -> Self {
        let mut source = self;
        source.owned_types.push(type_name.into());
        source
    }

    /// Builder step: declares ownership of one qualified field (`"Type.field"`).
    pub fn with_field(self, field: impl Into<String>) -> Self {
        let mut source = self;
        source.owned_fields.push(field.into());
        source
    }

    /// Builder step: replaces the source's statistics.
    pub fn with_stats(self, stats: SourceStats) -> Self {
        Self { stats, ..self }
    }

    /// Returns `true` if `type_name` was declared with `with_type`.
    pub fn owns_type(&self, type_name: &str) -> bool {
        self.owned_types
            .iter()
            .map(String::as_str)
            .any(|owned| owned == type_name)
    }

    /// Returns `true` if `field` is listed verbatim in `owned_fields`, or if
    /// the text before its first `.` is an owned type. A string without any
    /// `.` is therefore checked as a bare type name.
    pub fn owns_field(&self, field: &str) -> bool {
        if self.owned_fields.iter().any(|owned| owned == field) {
            return true;
        }
        let type_part = field.split_once('.').map_or(field, |(head, _)| head);
        self.owns_type(type_part)
    }
}
/// A single field the caller wants resolved.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FieldRequest {
/// Name of the type the field belongs to (e.g. `"User"`).
pub type_name: String,
/// Name of the field within that type (e.g. `"name"`).
pub field_name: String,
/// Source id supplied by the caller. Routing ignores it (`route_field` looks
/// only at the type and field names) and it is copied unchanged into the
/// sub-plan that ends up serving the field.
pub source_id: String,
}
impl FieldRequest {
    /// Creates a request for `type_name.field_name`, tagged with `source_id`.
    pub fn new(
        type_name: impl Into<String>,
        field_name: impl Into<String>,
        source_id: impl Into<String>,
    ) -> Self {
        let type_name = type_name.into();
        let field_name = field_name.into();
        let source_id = source_id.into();
        Self {
            type_name,
            field_name,
            source_id,
        }
    }

    /// Returns the dotted `"Type.field"` form, the same shape used by
    /// `FederationSource::owned_fields` entries.
    pub fn qualified_name(&self) -> String {
        [self.type_name.as_str(), self.field_name.as_str()].join(".")
    }
}
/// One request sent to a single source, covering one batch of fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchedSubPlan {
/// Id of the source this sub-plan is sent to.
pub source_id: String,
/// Fields fetched by this sub-plan.
pub fields: Vec<FieldRequest>,
/// Number of fields times the source's `cost_per_field`, plus any
/// cross-source penalty the planner attributed to this sub-plan.
pub estimated_cost: f64,
/// The source's `avg_latency_ms`.
pub estimated_latency_ms: u64,
/// Indices into the enclosing plan's `sub_plans` of sub-plans that must
/// finish before this one starts. Every index is smaller than this
/// sub-plan's own index, and the list is sorted ascending.
pub depends_on: Vec<usize>,
}
impl BatchedSubPlan {
    /// Number of fields fetched by this sub-plan.
    pub fn field_count(&self) -> usize {
        let Self { fields, .. } = self;
        fields.len()
    }
}
/// The result of planning: batched sub-plans plus aggregate estimates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedFederationPlan {
/// Sub-plans in execution-index order; `depends_on` entries refer to
/// positions in this vector.
pub sub_plans: Vec<BatchedSubPlan>,
/// Sum of every sub-plan's `estimated_cost`.
pub total_cost: f64,
/// Latency of the longest dependency chain, summing `estimated_latency_ms`
/// along the chain.
pub critical_path_latency_ms: u64,
/// `true` when no sub-plan depends on another.
pub is_fully_parallel: bool,
/// Distinct ids of the sources that received at least one sub-plan. Each id
/// appears once; callers should not rely on the order.
pub contributing_sources: Vec<String>,
}
impl EnhancedFederationPlan {
    /// Returns `true` when the plan contains no sub-plans at all.
    pub fn is_empty(&self) -> bool {
        self.sub_plans.first().is_none()
    }

    /// Total number of fields across all sub-plans.
    pub fn total_field_count(&self) -> usize {
        self.sub_plans
            .iter()
            .fold(0, |count, sub_plan| count + sub_plan.field_count())
    }
}
/// Tuning knobs for [`EnhancedFederationPlanner`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedPlannerConfig {
/// Extra cost charged when a plan involves more than one source; see
/// `plan_fields` and `estimate_cost` for exactly where it is added.
pub cross_source_penalty: f64,
/// When several sources can serve a field, prefer the one with the lowest
/// `avg_latency_ms`.
pub prefer_low_latency: bool,
/// Upper bound on a plan's total cost; `plan_fields` returns an error when
/// it is exceeded. `None` disables the check.
pub max_plan_cost: Option<f64>,
/// Global batching switch; when `false` every field becomes its own
/// sub-plan even if the source supports batching.
pub enable_batching: bool,
}
impl Default for EnhancedPlannerConfig {
fn default() -> Self {
Self {
cross_source_penalty: 1.5,
prefer_low_latency: true,
max_plan_cost: None,
enable_batching: true,
}
}
}
/// Routes field requests to registered [`FederationSource`]s and builds
/// batched, cost-estimated execution plans.
#[derive(Debug, Default)]
pub struct EnhancedFederationPlanner {
/// Planner tuning knobs.
config: EnhancedPlannerConfig,
/// Registered sources keyed by their `id`.
sources: HashMap<String, FederationSource>,
}
impl EnhancedFederationPlanner {
pub fn new() -> Self {
Self::default()
}
pub fn with_config(config: EnhancedPlannerConfig) -> Self {
Self {
config,
sources: HashMap::new(),
}
}
pub fn register_source(&mut self, source: FederationSource) {
self.sources.insert(source.id.clone(), source);
}
pub fn route_field(&self, type_name: &str, field_name: &str) -> Option<&FederationSource> {
let qualified = format!("{type_name}.{field_name}");
let exact = self
.sources
.values()
.filter(|s| s.owned_fields.iter().any(|f| f == &qualified))
.min_by_key(|s| {
if self.config.prefer_low_latency {
s.stats.avg_latency_ms
} else {
0
}
});
if exact.is_some() {
return exact;
}
self.sources
.values()
.filter(|s| s.owns_type(type_name))
.min_by_key(|s| {
if self.config.prefer_low_latency {
s.stats.avg_latency_ms
} else {
0
}
})
}
pub fn plan_fields(&self, requests: &[FieldRequest]) -> Result<EnhancedFederationPlan> {
if requests.is_empty() {
return Ok(EnhancedFederationPlan {
sub_plans: Vec::new(),
total_cost: 0.0,
critical_path_latency_ms: 0,
is_fully_parallel: true,
contributing_sources: Vec::new(),
});
}
let mut source_to_fields: BTreeMap<String, Vec<FieldRequest>> = BTreeMap::new();
for req in requests {
let source = self
.route_field(&req.type_name, &req.field_name)
.ok_or_else(|| {
anyhow!("No source owns field {}.{}", req.type_name, req.field_name)
})?;
source_to_fields
.entry(source.id.clone())
.or_default()
.push(req.clone());
}
let mut sub_plans: Vec<BatchedSubPlan> = Vec::new();
for (source_id, fields) in &source_to_fields {
let source = self
.sources
.get(source_id)
.ok_or_else(|| anyhow!("Source '{}' registered but missing from map", source_id))?;
let chunks = if self.config.enable_batching && source.stats.supports_batching {
let max = if source.stats.max_batch_size == 0 {
fields.len()
} else {
source.stats.max_batch_size
};
fields.chunks(max.max(1)).collect::<Vec<_>>()
} else {
fields.iter().map(std::slice::from_ref).collect::<Vec<_>>()
};
for chunk in chunks {
let n = chunk.len() as f64;
let cost = n * source.stats.cost_per_field
+ if sub_plans.is_empty() {
0.0
} else {
self.config.cross_source_penalty
};
sub_plans.push(BatchedSubPlan {
source_id: source_id.clone(),
fields: chunk.to_vec(),
estimated_cost: cost,
estimated_latency_ms: source.stats.avg_latency_ms,
depends_on: Vec::new(),
});
}
}
let source_order: Vec<String> = sub_plans.iter().map(|sp| sp.source_id.clone()).collect();
let mut seen_sources: HashSet<String> = HashSet::new();
for (plan_idx, plan) in sub_plans.iter_mut().enumerate() {
let plan_source = plan.source_id.clone();
let plan_fields: Vec<String> =
plan.fields.iter().map(|f| f.type_name.clone()).collect();
let mut new_deps: Vec<usize> = Vec::new();
for type_name in &plan_fields {
for (idx, src) in source_order.iter().enumerate() {
if idx < plan_idx
&& seen_sources.contains(src)
&& src != &plan_source
&& self
.sources
.get(src.as_str())
.is_some_and(|s| s.owns_type(type_name))
&& !new_deps.contains(&idx)
{
new_deps.push(idx);
}
}
}
for dep in new_deps {
if !plan.depends_on.contains(&dep) {
plan.depends_on.push(dep);
}
}
seen_sources.insert(plan_source);
}
for plan in &mut sub_plans {
plan.depends_on.sort_unstable();
plan.depends_on.dedup();
}
let total_cost: f64 = sub_plans.iter().map(|sp| sp.estimated_cost).sum();
if let Some(max_cost) = self.config.max_plan_cost {
if total_cost > max_cost {
return Err(anyhow!(
"Plan cost {total_cost:.2} exceeds configured maximum {max_cost:.2}"
));
}
}
let critical_path_latency_ms = self.compute_critical_path(&sub_plans);
let is_fully_parallel = sub_plans.iter().all(|sp| sp.depends_on.is_empty());
let contributing_sources: Vec<String> = sub_plans
.iter()
.map(|sp| sp.source_id.clone())
.collect::<HashSet<_>>()
.into_iter()
.collect();
Ok(EnhancedFederationPlan {
sub_plans,
total_cost,
critical_path_latency_ms,
is_fully_parallel,
contributing_sources,
})
}
fn compute_critical_path(&self, sub_plans: &[BatchedSubPlan]) -> u64 {
if sub_plans.is_empty() {
return 0;
}
let mut finish_time = vec![0u64; sub_plans.len()];
for (i, sp) in sub_plans.iter().enumerate() {
let earliest_start = sp
.depends_on
.iter()
.map(|&dep| finish_time[dep])
.max()
.unwrap_or(0);
finish_time[i] = earliest_start + sp.estimated_latency_ms;
}
finish_time.into_iter().max().unwrap_or(0)
}
pub fn sources(&self) -> impl Iterator<Item = &FederationSource> {
self.sources.values()
}
pub fn estimate_cost(&self, requests: &[FieldRequest]) -> f64 {
let mut total = 0.0;
let mut seen_sources: HashSet<&str> = HashSet::new();
for req in requests {
if let Some(source) = self.route_field(&req.type_name, &req.field_name) {
total += source.stats.cost_per_field;
if !seen_sources.contains(source.id.as_str()) {
if !seen_sources.is_empty() {
total += self.config.cross_source_penalty;
}
seen_sources.insert(source.id.as_str());
}
}
}
total
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn user_source() -> FederationSource {
        FederationSource::new("users", "https://users.example.com/graphql")
            .with_type("User")
            .with_field("User.id")
            .with_field("User.name")
            .with_field("User.email")
            .with_stats(SourceStats {
                avg_latency_ms: 20,
                cost_per_field: 1.0,
                supports_batching: true,
                max_batch_size: 0,
            })
    }

    fn product_source() -> FederationSource {
        FederationSource::new("products", "https://products.example.com/graphql")
            .with_type("Product")
            .with_field("Product.sku")
            .with_field("Product.price")
            .with_stats(SourceStats {
                avg_latency_ms: 30,
                cost_per_field: 2.0,
                supports_batching: true,
                max_batch_size: 10,
            })
    }

    fn review_source() -> FederationSource {
        FederationSource::new("reviews", "https://reviews.example.com/graphql")
            .with_type("Review")
            .with_field("Review.rating")
            .with_field("Review.body")
            .with_stats(SourceStats {
                avg_latency_ms: 40,
                cost_per_field: 1.5,
                supports_batching: false,
                max_batch_size: 0,
            })
    }

    fn make_planner() -> EnhancedFederationPlanner {
        let mut planner = EnhancedFederationPlanner::new();
        planner.register_source(user_source());
        planner.register_source(product_source());
        planner.register_source(review_source());
        planner
    }

    #[test]
    fn test_source_owns_type() {
        let src = user_source();
        assert!(src.owns_type("User"));
        assert!(!src.owns_type("Product"));
    }

    #[test]
    fn test_source_owns_field_exact() {
        let src = user_source();
        assert!(src.owns_field("User.name"));
        assert!(!src.owns_field("Product.unknown"));
    }

    #[test]
    fn test_source_owns_field_by_type() {
        let src = user_source();
        assert!(src.owns_field("User.anything"));
    }

    #[test]
    fn test_source_builder_chaining() {
        let src = FederationSource::new("test", "http://test")
            .with_type("Foo")
            .with_field("Foo.bar");
        assert_eq!(src.owned_types.len(), 1);
        assert_eq!(src.owned_fields.len(), 1);
    }

    #[test]
    fn test_field_request_qualified_name() {
        let req = FieldRequest::new("User", "name", "users");
        assert_eq!(req.qualified_name(), "User.name");
    }

    #[test]
    fn test_field_request_equality() {
        let a = FieldRequest::new("User", "id", "users");
        let b = FieldRequest::new("User", "id", "users");
        assert_eq!(a, b);
    }

    #[test]
    fn test_route_field_exact() {
        let planner = make_planner();
        let src = planner.route_field("User", "name");
        assert!(src.is_some());
        assert_eq!(src.expect("should succeed").id, "users");
    }

    #[test]
    fn test_route_field_type_fallback() {
        let planner = make_planner();
        let src = planner.route_field("User", "bio");
        assert!(src.is_some());
        assert_eq!(src.expect("should succeed").id, "users");
    }

    #[test]
    fn test_route_field_unknown_returns_none() {
        let planner = make_planner();
        let src = planner.route_field("Ghost", "field");
        assert!(src.is_none());
    }

    #[test]
    fn test_route_field_product() {
        let planner = make_planner();
        let src = planner.route_field("Product", "price");
        assert!(src.is_some());
        assert_eq!(src.expect("should succeed").id, "products");
    }

    #[test]
    fn test_route_field_exact_owner_beats_type_owner() {
        // An explicit field owner wins even when a type owner is faster.
        let mut planner = EnhancedFederationPlanner::new();
        planner.register_source(
            FederationSource::new("fast_owner", "http://fast")
                .with_type("User")
                .with_stats(SourceStats {
                    avg_latency_ms: 1,
                    ..Default::default()
                }),
        );
        planner.register_source(
            FederationSource::new("slow_exact", "http://slow")
                .with_field("User.name")
                .with_stats(SourceStats {
                    avg_latency_ms: 500,
                    ..Default::default()
                }),
        );
        let src = planner.route_field("User", "name").expect("should succeed");
        assert_eq!(src.id, "slow_exact");
    }

    #[test]
    fn test_plan_empty_requests() {
        let planner = make_planner();
        let plan = planner.plan_fields(&[]).expect("should succeed");
        assert!(plan.is_empty());
        assert_eq!(plan.total_cost, 0.0);
    }

    #[test]
    fn test_plan_single_source() {
        let planner = make_planner();
        let requests = vec![
            FieldRequest::new("User", "id", "users"),
            FieldRequest::new("User", "name", "users"),
        ];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert_eq!(plan.sub_plans.len(), 1);
        assert_eq!(plan.sub_plans[0].source_id, "users");
        assert_eq!(plan.sub_plans[0].field_count(), 2);
        // Two fields at 1.0 each, no cross-source penalty.
        assert!((plan.total_cost - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_plan_multiple_sources() {
        let planner = make_planner();
        let requests = vec![
            FieldRequest::new("User", "name", "users"),
            FieldRequest::new("Product", "price", "products"),
        ];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert_eq!(plan.sub_plans.len(), 2);
        let mut sources = plan.contributing_sources.clone();
        sources.sort();
        assert_eq!(sources, vec!["products", "users"]);
    }

    #[test]
    fn test_plan_total_field_count() {
        let planner = make_planner();
        let requests = vec![
            FieldRequest::new("User", "id", "users"),
            FieldRequest::new("User", "name", "users"),
            FieldRequest::new("Product", "sku", "products"),
        ];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert_eq!(plan.total_field_count(), 3);
    }

    #[test]
    fn test_plan_cost_nonzero() {
        let planner = make_planner();
        let requests = vec![FieldRequest::new("User", "id", "users")];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert!(plan.total_cost > 0.0);
    }

    #[test]
    fn test_plan_latency_nonzero() {
        let planner = make_planner();
        let requests = vec![FieldRequest::new("User", "id", "users")];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert!(plan.critical_path_latency_ms > 0);
    }

    #[test]
    fn test_plan_independent_sources_fully_parallel() {
        let planner = make_planner();
        let requests = vec![
            FieldRequest::new("User", "id", "users"),
            FieldRequest::new("Product", "sku", "products"),
        ];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert!(plan.is_fully_parallel);
    }

    #[test]
    fn test_plan_extension_depends_on_type_owner() {
        // "reviews" extends User with an explicit field, so its sub-plan must
        // wait for the source that owns the User type.
        let mut planner = EnhancedFederationPlanner::new();
        planner.register_source(
            FederationSource::new("accounts", "http://accounts")
                .with_type("User")
                .with_stats(SourceStats {
                    avg_latency_ms: 20,
                    ..Default::default()
                }),
        );
        planner.register_source(
            FederationSource::new("reviews", "http://reviews")
                .with_field("User.reviews")
                .with_stats(SourceStats {
                    avg_latency_ms: 40,
                    ..Default::default()
                }),
        );
        let requests = vec![
            FieldRequest::new("User", "id", "accounts"),
            FieldRequest::new("User", "reviews", "reviews"),
        ];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert_eq!(plan.sub_plans.len(), 2);
        assert_eq!(plan.sub_plans[0].source_id, "accounts");
        assert!(plan.sub_plans[0].depends_on.is_empty());
        assert_eq!(plan.sub_plans[1].source_id, "reviews");
        assert_eq!(plan.sub_plans[1].depends_on, vec![0]);
        assert!(!plan.is_fully_parallel);
        assert_eq!(plan.critical_path_latency_ms, 60);
    }

    #[test]
    fn test_plan_max_cost_enforced() {
        let config = EnhancedPlannerConfig {
            max_plan_cost: Some(0.5),
            ..Default::default()
        };
        let mut planner = EnhancedFederationPlanner::with_config(config);
        planner.register_source(user_source());
        let requests = vec![
            FieldRequest::new("User", "id", "users"),
            FieldRequest::new("User", "name", "users"),
        ];
        let result = planner.plan_fields(&requests);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds"));
    }

    #[test]
    fn test_plan_unknown_field_returns_error() {
        let planner = make_planner();
        let requests = vec![FieldRequest::new("Phantom", "field", "???")];
        let result = planner.plan_fields(&requests);
        assert!(result.is_err());
    }

    #[test]
    fn test_batching_respects_max_batch_size() {
        let mut planner = EnhancedFederationPlanner::new();
        let src = FederationSource::new("src", "http://src")
            .with_type("Item")
            .with_stats(SourceStats {
                avg_latency_ms: 10,
                cost_per_field: 1.0,
                supports_batching: true,
                max_batch_size: 2,
            });
        planner.register_source(src);
        let requests: Vec<FieldRequest> = (0..5)
            .map(|i| FieldRequest::new("Item", format!("field{i}"), "src"))
            .collect();
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert_eq!(plan.sub_plans.len(), 3);
        let sizes: Vec<usize> = plan
            .sub_plans
            .iter()
            .map(BatchedSubPlan::field_count)
            .collect();
        assert_eq!(sizes, vec![2, 2, 1]);
    }

    #[test]
    fn test_batching_disabled_when_source_does_not_support_it() {
        let mut planner = EnhancedFederationPlanner::new();
        let src = FederationSource::new("no_batch", "http://nb")
            .with_type("Foo")
            .with_stats(SourceStats {
                avg_latency_ms: 10,
                cost_per_field: 1.0,
                supports_batching: false,
                max_batch_size: 0,
            });
        planner.register_source(src);
        let requests = vec![
            FieldRequest::new("Foo", "a", "no_batch"),
            FieldRequest::new("Foo", "b", "no_batch"),
        ];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert_eq!(plan.sub_plans.len(), 2);
    }

    #[test]
    fn test_estimate_cost_single_source() {
        let planner = make_planner();
        let requests = vec![
            FieldRequest::new("User", "id", "users"),
            FieldRequest::new("User", "name", "users"),
        ];
        let cost = planner.estimate_cost(&requests);
        assert!((cost - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_estimate_cost_cross_source_adds_penalty() {
        let planner = make_planner();
        let requests = vec![
            FieldRequest::new("User", "id", "users"),
            FieldRequest::new("Product", "sku", "products"),
        ];
        let cost = planner.estimate_cost(&requests);
        // 1.0 (users) + 2.0 (products) + 1.5 default cross-source penalty.
        assert!((cost - 4.5).abs() < 1e-9);
    }

    #[test]
    fn test_sources_iterator() {
        let planner = make_planner();
        let ids: Vec<&str> = planner.sources().map(|s| s.id.as_str()).collect();
        assert!(ids.contains(&"users"));
        assert!(ids.contains(&"products"));
        assert!(ids.contains(&"reviews"));
    }

    #[test]
    fn test_config_prefer_low_latency() {
        let config = EnhancedPlannerConfig {
            prefer_low_latency: true,
            ..Default::default()
        };
        let mut planner = EnhancedFederationPlanner::with_config(config);
        let slow = FederationSource::new("slow", "http://slow")
            .with_type("Shared")
            .with_stats(SourceStats {
                avg_latency_ms: 100,
                ..Default::default()
            });
        let fast = FederationSource::new("fast", "http://fast")
            .with_type("Shared")
            .with_stats(SourceStats {
                avg_latency_ms: 10,
                ..Default::default()
            });
        planner.register_source(slow);
        planner.register_source(fast);
        let src = planner.route_field("Shared", "field");
        assert!(src.is_some());
        assert_eq!(src.expect("should succeed").id, "fast");
    }

    #[test]
    fn test_plan_is_not_empty() {
        let planner = make_planner();
        let requests = vec![FieldRequest::new("User", "id", "users")];
        let plan = planner.plan_fields(&requests).expect("should succeed");
        assert!(!plan.is_empty());
    }
}