use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
/// An entity key declared by a subgraph: the type it applies to and the fields that form the key.
///
/// The planner only consults `type_name` (to find which subgraph declares a key for a type);
/// `fields` is carried as data and is not read by the planner or the resolver.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FederationKey {
    /// Name of the entity type this key belongs to.
    pub type_name: String,
    /// Field names making up the key (for example `["id"]`).
    pub fields: Vec<String>,
}
/// Description of one federated subgraph: its name, URL, the types it owns and the entity keys
/// it declares.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubGraph {
    /// Subgraph name; used as the step label and the step sort key when planning.
    pub name: String,
    /// Subgraph URL; recorded as the owner URL of each of its types by
    /// [`EntityResolver::register_subgraphs`].
    pub url: String,
    /// Type names this subgraph owns for planning and resolution purposes.
    pub types: Vec<String>,
    /// Entity keys declared by this subgraph.
    pub keys: Vec<FederationKey>,
}
impl SubGraph {
pub fn new(name: impl Into<String>, url: impl Into<String>) -> Self {
Self {
name: name.into(),
url: url.into(),
types: vec![],
keys: vec![],
}
}
pub fn with_type(mut self, type_name: impl Into<String>) -> Self {
self.types.push(type_name.into());
self
}
pub fn with_key(mut self, key: FederationKey) -> Self {
self.keys.push(key);
self
}
}
/// One unit of work in a [`FederationPlan`]: a query fragment targeting a single subgraph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationStep {
    /// Name of the subgraph this step targets.
    pub subgraph: String,
    /// Query fragment generated for this subgraph.
    pub query: String,
    /// Indices into [`FederationPlan::steps`] of the steps this one depends on; the latency
    /// estimate treats a step and its dependencies as sequential.
    pub depends_on: Vec<usize>,
    /// Comma-separated names of the types this step resolves (e.g. `"User, Account"`).
    pub resolves_type: String,
}
/// Result of [`FederationQueryPlanner::plan_query`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FederationPlan {
    /// Planned steps; the planner emits one step per involved subgraph.
    pub steps: Vec<FederationStep>,
    /// Estimated latency in milliseconds: dependency-chain depth times the configured base
    /// latency (0 when there are no steps).
    pub estimated_latency_ms: u64,
    /// `true` when no step depends on another step.
    pub is_parallelizable: bool,
}
impl FederationPlan {
    /// Reports whether the plan contains no steps.
    ///
    /// For plans produced by the planner this means no referenced type was owned by any of
    /// the supplied subgraphs.
    pub fn is_empty(&self) -> bool {
        self.steps.first().is_none()
    }
}
/// Tuning knobs for [`FederationQueryPlanner`].
#[derive(Debug, Clone)]
pub struct FederationPlannerConfig {
    /// Milliseconds charged per level of the step dependency chain when computing
    /// [`FederationPlan::estimated_latency_ms`].
    pub base_latency_ms: u64,
}
impl Default for FederationPlannerConfig {
fn default() -> Self {
Self {
base_latency_ms: 50,
}
}
}
/// Line-based planner that splits a GraphQL query string into per-subgraph
/// [`FederationStep`]s and estimates the resulting latency.
#[derive(Debug)]
pub struct FederationQueryPlanner {
    /// Configuration used for latency estimation.
    config: FederationPlannerConfig,
}
impl FederationQueryPlanner {
    /// Creates a planner using [`FederationPlannerConfig::default`].
    pub fn new() -> Self {
        Self {
            config: FederationPlannerConfig::default(),
        }
    }

    /// Creates a planner with an explicit configuration.
    pub fn with_config(config: FederationPlannerConfig) -> Self {
        Self { config }
    }

    /// Builds a [`FederationPlan`] for `query` against `subgraphs`.
    ///
    /// * Type references are found with a line-based heuristic (see
    ///   `extract_type_references`); each distinct type is considered once.
    /// * Each type is routed to the subgraph listing it in [`SubGraph::types`]; when several
    ///   subgraphs list the same type, the last one in `subgraphs` wins. Types owned by no
    ///   subgraph are ignored.
    /// * One [`FederationStep`] is emitted per involved subgraph, ordered by subgraph name so
    ///   step indices (and therefore `depends_on`) are deterministic.
    /// * A step depends on another step when one of its types has a [`FederationKey`] declared
    ///   by that other subgraph (again, the last declaration wins).
    ///
    /// Returns an empty plan (no steps, zero latency, parallelizable) when nothing matches.
    pub fn plan_query(&self, query: &str, subgraphs: &[SubGraph]) -> FederationPlan {
        // type name -> owning subgraph name; later subgraphs overwrite earlier ones.
        let type_owner: HashMap<&str, &str> = subgraphs
            .iter()
            .flat_map(|sg| {
                sg.types
                    .iter()
                    .map(move |t| (t.as_str(), sg.name.as_str()))
            })
            .collect();

        // subgraph name -> referenced types it owns, in query order.
        let mut subgraph_fields: HashMap<&str, Vec<String>> = HashMap::new();
        for type_name in Self::extract_type_references(query) {
            if let Some(&owner) = type_owner.get(type_name.as_str()) {
                subgraph_fields.entry(owner).or_default().push(type_name);
            }
        }

        // type name -> subgraph declaring a key for it; later subgraphs overwrite earlier ones.
        let key_owners: HashMap<&str, &str> = subgraphs
            .iter()
            .flat_map(|sg| {
                sg.keys
                    .iter()
                    .map(move |k| (k.type_name.as_str(), sg.name.as_str()))
            })
            .collect();

        let mut sg_names: Vec<&str> = subgraph_fields.keys().copied().collect();
        sg_names.sort_unstable();
        let name_to_idx: HashMap<&str, usize> = sg_names
            .iter()
            .enumerate()
            .map(|(idx, &name)| (name, idx))
            .collect();

        let steps: Vec<FederationStep> = sg_names
            .iter()
            .map(|&sg_name| {
                let fields = &subgraph_fields[sg_name];
                let mut depends_on: Vec<usize> = fields
                    .iter()
                    .filter_map(|type_name| {
                        let &key_owner = key_owners.get(type_name.as_str())?;
                        // A key declared by this very subgraph needs no other step.
                        if key_owner == sg_name {
                            None
                        } else {
                            name_to_idx.get(key_owner).copied()
                        }
                    })
                    .collect();
                depends_on.sort_unstable();
                depends_on.dedup();
                FederationStep {
                    subgraph: sg_name.to_string(),
                    query: Self::build_query_fragment(fields),
                    depends_on,
                    resolves_type: fields.join(", "),
                }
            })
            .collect();

        let estimated_latency_ms = self.estimate_latency(&steps);
        let is_parallelizable = steps.iter().all(|s| s.depends_on.is_empty());
        FederationPlan {
            steps,
            estimated_latency_ms,
            is_parallelizable,
        }
    }

    /// Heuristically extracts type names referenced by `query`, one line at a time.
    ///
    /// This is not a GraphQL parser. A line contributes a type when it is
    /// * an inline fragment `... on Type` (also `...on Type`), taking the type condition, or
    /// * any other line ending in `{`, taking its leading name.
    ///
    /// Names are cut at whitespace, `{`, `(` or `@`, so directives and field arguments do not
    /// leak into the type name. Only names starting with an uppercase letter are kept. Each
    /// type is returned once, in first-seen order.
    fn extract_type_references(query: &str) -> Vec<String> {
        let mut seen: HashSet<&str> = HashSet::new();
        let mut types = Vec::new();
        for line in query.lines() {
            let trimmed = line.trim();
            let candidate = match Self::inline_fragment_target(trimmed) {
                Some(type_name) => type_name,
                // Outside inline fragments, only lines opening a selection set can name a type.
                None if trimmed.ends_with('{') => Self::leading_name(trimmed),
                None => continue,
            };
            if candidate.starts_with(char::is_uppercase) && seen.insert(candidate) {
                types.push(candidate.to_string());
            }
        }
        types
    }

    /// Returns the type condition of an inline-fragment line (`... on Type ...`), or `None`
    /// when the line is not an inline fragment with a type condition.
    fn inline_fragment_target(line: &str) -> Option<&str> {
        let rest = line.strip_prefix("...")?.trim_start().strip_prefix("on")?;
        // `...onUser` is a spread of a fragment named `onUser`, not a type condition.
        if !rest.starts_with(char::is_whitespace) {
            return None;
        }
        Some(Self::leading_name(rest.trim_start()))
    }

    /// Returns the prefix of `text` up to the first whitespace, `{`, `(` or `@`.
    fn leading_name(text: &str) -> &str {
        text.split(|c: char| c.is_whitespace() || matches!(c, '{' | '(' | '@'))
            .next()
            .unwrap_or("")
    }

    /// Renders the selection for one subgraph: one ` Type { __typename id }` line per type,
    /// wrapped in `{` / `}`.
    ///
    /// NOTE(review): always selects `id`, regardless of the type's declared key fields.
    fn build_query_fragment(types: &[String]) -> String {
        let inner: Vec<String> = types
            .iter()
            .map(|t| format!(" {t} {{ __typename id }}"))
            .collect();
        format!("{{\n{}\n}}", inner.join("\n"))
    }

    /// Estimates latency as (longest dependency chain, in steps) × `base_latency_ms`.
    ///
    /// A step with no dependencies has depth 1. Dependencies may point to earlier or later
    /// steps. Out-of-range indices are ignored. Cycles are broken by treating a step already on
    /// the current path as depth 0, so the estimate always terminates. Returns 0 for no steps
    /// and saturates instead of overflowing.
    fn estimate_latency(&self, steps: &[FederationStep]) -> u64 {
        let mut memo: Vec<Option<usize>> = vec![None; steps.len()];
        let mut on_stack = vec![false; steps.len()];
        let critical_path = (0..steps.len())
            .map(|idx| Self::step_depth(idx, steps, &mut memo, &mut on_stack))
            .max()
            .unwrap_or(0);
        u64::try_from(critical_path)
            .unwrap_or(u64::MAX)
            .saturating_mul(self.config.base_latency_ms)
    }

    /// Memoized depth of step `idx` in the dependency graph (1 + deepest dependency).
    ///
    /// Recursion depth is bounded by `steps.len()` (one frame per step on the current path).
    fn step_depth(
        idx: usize,
        steps: &[FederationStep],
        memo: &mut [Option<usize>],
        on_stack: &mut [bool],
    ) -> usize {
        if let Some(depth) = memo[idx] {
            return depth;
        }
        if on_stack[idx] {
            // Cycle: stop here rather than recursing forever.
            return 0;
        }
        on_stack[idx] = true;
        let deepest_dependency = steps[idx]
            .depends_on
            .iter()
            .filter(|&&dep| dep < steps.len())
            .map(|&dep| Self::step_depth(dep, steps, memo, on_stack))
            .max()
            .unwrap_or(0);
        on_stack[idx] = false;
        let depth = deepest_dependency + 1;
        memo[idx] = Some(depth);
        depth
    }
}
impl Default for FederationQueryPlanner {
fn default() -> Self {
Self::new()
}
}
/// Registry mapping entity type names to the URL of the subgraph registered for them.
#[derive(Debug, Default)]
pub struct EntityResolver {
    /// Type name -> subgraph URL; later registrations of the same type overwrite earlier ones.
    type_to_url: HashMap<String, String>,
}
impl EntityResolver {
    /// Creates a resolver with no registered types.
    pub fn new() -> Self {
        Default::default()
    }

    /// Records `url` as the owner of `type_name`, replacing any earlier registration.
    pub fn register_type(&mut self, type_name: impl Into<String>, url: impl Into<String>) {
        let (type_name, url) = (type_name.into(), url.into());
        self.type_to_url.insert(type_name, url);
    }

    /// Registers every type of every subgraph under that subgraph's URL.
    ///
    /// Subgraphs are processed in order, so a type listed by several subgraphs ends up owned
    /// by the last one.
    pub fn register_subgraphs(&mut self, subgraphs: &[SubGraph]) {
        self.type_to_url.extend(subgraphs.iter().flat_map(|sg| {
            sg.types
                .iter()
                .map(move |type_name| (type_name.clone(), sg.url.clone()))
        }));
    }

    /// Annotates each representation with `__typename`, `_resolved` and, when `typename` is
    /// registered, `_owning_subgraph` (the registered URL).
    ///
    /// Object representations keep their existing fields (the three keys above are
    /// overwritten); non-object representations are replaced by an object holding only the
    /// annotations. The output has one entry per input, in the same order.
    pub fn resolve_entities(
        &self,
        typename: &str,
        representations: &[serde_json::Value],
    ) -> Vec<serde_json::Value> {
        // The owner is the same for every representation, so look it up once.
        let owner = self.type_to_url.get(typename);
        let mut resolved = Vec::with_capacity(representations.len());
        for rep in representations {
            let mut fields = rep.as_object().cloned().unwrap_or_default();
            fields.insert(
                "__typename".to_string(),
                serde_json::Value::String(typename.to_string()),
            );
            fields.insert(
                "_resolved".to_string(),
                serde_json::Value::Bool(owner.is_some()),
            );
            if let Some(url) = owner {
                fields.insert(
                    "_owning_subgraph".to_string(),
                    serde_json::Value::String(url.clone()),
                );
            }
            resolved.push(serde_json::Value::Object(fields));
        }
        resolved
    }

    /// Returns the URL registered for `typename`, if any.
    pub fn owner_url(&self, typename: &str) -> Option<&str> {
        self.type_to_url.get(typename).map(String::as_str)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures ------------------------------------------------------------------------

    /// Single-field entity key on `type_name`.
    fn key(type_name: &str, field: &str) -> FederationKey {
        FederationKey {
            type_name: type_name.to_string(),
            fields: vec![field.to_string()],
        }
    }

    /// Owns `User` and `Account`; declares a `User` key on `id`.
    fn accounts_sg() -> SubGraph {
        SubGraph::new("accounts", "https://accounts.example.com/graphql")
            .with_type("User")
            .with_type("Account")
            .with_key(key("User", "id"))
    }

    /// Owns `Product` and `Category`; declares a `Product` key on `sku`.
    fn products_sg() -> SubGraph {
        SubGraph::new("products", "https://products.example.com/graphql")
            .with_type("Product")
            .with_type("Category")
            .with_key(key("Product", "sku"))
    }

    /// Owns `Review`; declares a `Review` key on `id`.
    fn reviews_sg() -> SubGraph {
        SubGraph::new("reviews", "https://reviews.example.com/graphql")
            .with_type("Review")
            .with_key(key("Review", "id"))
    }

    /// Plans `query` with a default-configured planner.
    fn plan_default(query: &str, subgraphs: &[SubGraph]) -> FederationPlan {
        FederationQueryPlanner::new().plan_query(query, subgraphs)
    }

    const USER_QUERY: &str = "{\n User {\n id\n }\n}";
    const USER_AND_PRODUCT_QUERY: &str = "{\n User {\n id\n }\n Product {\n sku\n }\n}";

    // ---- data types ----------------------------------------------------------------------

    #[test]
    fn test_federation_key_construction() {
        let k = FederationKey {
            fields: vec!["id".to_string()],
            type_name: "User".to_string(),
        };
        assert_eq!(k.type_name, "User");
        assert_eq!(k.fields, vec!["id"]);
    }

    #[test]
    fn test_subgraph_builder() {
        let accounts = accounts_sg();
        assert_eq!(accounts.name, "accounts");
        assert!(accounts.types.iter().any(|t| t == "User"));
        assert!(!accounts.keys.is_empty());
    }

    #[test]
    fn test_plan_is_empty() {
        let nothing = FederationPlan {
            steps: Vec::new(),
            estimated_latency_ms: 0,
            is_parallelizable: true,
        };
        assert!(nothing.is_empty());
    }

    // ---- planner -------------------------------------------------------------------------

    #[test]
    fn test_plan_empty_query() {
        let plan = plan_default("{ }", &[accounts_sg(), products_sg()]);
        assert!(plan.steps.is_empty());
    }

    #[test]
    fn test_plan_single_subgraph() {
        let plan = plan_default(
            "{\n User {\n id name\n }\n}",
            &[accounts_sg(), products_sg()],
        );
        assert_eq!(plan.steps.len(), 1);
        assert_eq!(plan.steps[0].subgraph, "accounts");
    }

    #[test]
    fn test_plan_multiple_subgraphs() {
        let subgraphs = [accounts_sg(), products_sg()];
        let plan = plan_default(USER_AND_PRODUCT_QUERY, &subgraphs);
        assert_eq!(plan.steps.len(), 2);
        let targeted: Vec<&str> = plan.steps.iter().map(|s| s.subgraph.as_str()).collect();
        for expected in ["accounts", "products"] {
            assert!(targeted.contains(&expected));
        }
    }

    #[test]
    fn test_plan_parallelizable_when_no_deps() {
        let plan = plan_default(USER_AND_PRODUCT_QUERY, &[accounts_sg(), products_sg()]);
        assert!(plan.is_parallelizable);
    }

    #[test]
    fn test_plan_estimated_latency_nonzero() {
        let plan = plan_default(USER_QUERY, &[accounts_sg()]);
        assert!(plan.estimated_latency_ms > 0);
    }

    #[test]
    fn test_plan_step_contains_query_fragment() {
        let plan = plan_default("{\n Product {\n sku\n }\n}", &[products_sg()]);
        assert!(!plan.steps.is_empty());
        assert!(!plan.steps[0].query.is_empty());
    }

    #[test]
    fn test_plan_unrecognised_type_ignored() {
        let plan = plan_default("{\n Ghost {\n id\n }\n}", &[accounts_sg()]);
        assert!(plan.steps.is_empty());
    }

    #[test]
    fn test_plan_three_subgraphs() {
        let query =
            "{\n User {\n id\n }\n Product {\n sku\n }\n Review {\n id\n }\n}";
        let plan = plan_default(query, &[accounts_sg(), products_sg(), reviews_sg()]);
        assert_eq!(plan.steps.len(), 3);
    }

    #[test]
    fn test_planner_config_latency() {
        let planner = FederationQueryPlanner::with_config(FederationPlannerConfig {
            base_latency_ms: 100,
        });
        let plan = planner.plan_query(USER_QUERY, &[accounts_sg()]);
        assert_eq!(plan.estimated_latency_ms, 100);
    }

    #[test]
    fn test_step_resolves_type_field_populated() {
        let plan = plan_default(USER_QUERY, &[accounts_sg()]);
        assert!(!plan.steps[0].resolves_type.is_empty());
    }

    // ---- entity resolver -----------------------------------------------------------------

    #[test]
    fn test_entity_resolver_register_type() {
        let mut resolver = EntityResolver::new();
        resolver.register_type("User", "https://accounts.example.com/graphql");
        assert_eq!(
            resolver.owner_url("User"),
            Some("https://accounts.example.com/graphql")
        );
    }

    #[test]
    fn test_entity_resolver_register_subgraphs() {
        let mut resolver = EntityResolver::new();
        resolver.register_subgraphs(&[accounts_sg(), products_sg()]);
        for type_name in ["User", "Product"] {
            assert!(resolver.owner_url(type_name).is_some());
        }
    }

    #[test]
    fn test_resolve_entities_enriches_response() {
        let mut resolver = EntityResolver::new();
        resolver.register_type("User", "https://accounts.example.com/graphql");
        let reps = [serde_json::json!({"__typename": "User", "id": "1"})];
        let resolved = resolver.resolve_entities("User", &reps);
        assert_eq!(resolved.len(), 1);
        let obj = resolved[0].as_object().expect("object");
        assert_eq!(obj["__typename"], "User");
        assert_eq!(obj["_resolved"], true);
    }

    #[test]
    fn test_resolve_entities_unknown_type_not_resolved() {
        let resolver = EntityResolver::new();
        let reps = [serde_json::json!({"__typename": "Ghost", "id": "99"})];
        let resolved = resolver.resolve_entities("Ghost", &reps);
        let obj = resolved[0].as_object().expect("object");
        assert_eq!(obj["_resolved"], false);
    }

    #[test]
    fn test_resolve_entities_multiple_representations() {
        let mut resolver = EntityResolver::new();
        resolver.register_type("Product", "https://products.example.com/graphql");
        let reps: Vec<serde_json::Value> = ["ABC", "DEF", "GHI"]
            .iter()
            .map(|sku| serde_json::json!({"sku": sku}))
            .collect();
        let resolved = resolver.resolve_entities("Product", &reps);
        assert_eq!(resolved.len(), 3);
        assert!(resolved
            .iter()
            .all(|r| r["__typename"] == "Product" && r["_resolved"] == true));
    }

    #[test]
    fn test_owner_url_unregistered_type_returns_none() {
        assert!(EntityResolver::new().owner_url("Unknown").is_none());
    }
}