use std::collections::HashMap;
use serde::Serialize;
use serde_json::Value;
/// Upper bound on entity-resolution nesting: [`plan_entity_fetch`] rejects any
/// `current_depth` that is greater than or equal to this value.
pub(crate) const MAX_ENTITY_DEPTH: usize = 8;
/// Result of query planning: the list of subgraph fetches needed to answer a query.
#[derive(Debug, Clone, Serialize)]
pub struct QueryPlan {
/// Fetches making up the plan; [`plan_query`] produces one root fetch per
/// owning subgraph.
pub fetches: Vec<SubgraphFetch>,
}
/// A single GraphQL request addressed to one subgraph.
#[derive(Debug, Clone, Serialize)]
pub struct SubgraphFetch {
/// Name of the subgraph that should execute `query`.
pub subgraph: String,
/// GraphQL document text sent to the subgraph.
pub query: String,
/// Variables sent with `query`: [`plan_query`] uses an empty JSON object, and
/// [`plan_entity_fetch`] uses `{"representations": [...]}`.
pub variables: Value,
/// `true` for `_entities` lookups built by [`plan_entity_fetch`], `false` for
/// root fetches built by [`plan_query`].
pub is_entity_fetch: bool,
/// NOTE(review): presumably the index of another fetch in
/// [`QueryPlan::fetches`] that must complete first; the constructors visible
/// in this file always set it to `None` — confirm the intended semantics.
pub depends_on: Option<usize>,
}
/// Lookup table from a field name to the name of the subgraph that owns it.
#[derive(Debug, Clone, Default)]
pub struct FieldOwnership {
    // field name -> owning subgraph name
    entries: HashMap<String, String>,
}

impl FieldOwnership {
    /// Registers `subgraph` as the owner of `field`.
    ///
    /// Last writer wins: registering a field again replaces its previous owner.
    pub fn insert(&mut self, field: String, subgraph: String) {
        *self.entries.entry(field).or_default() = subgraph;
    }

    /// Returns the subgraph that owns `field`, or `None` if nobody registered it.
    pub fn owner(&self, field: &str) -> Option<&str> {
        self.entries
            .get(field)
            .map(|subgraph| subgraph.as_str())
    }
}
/// Reasons query planning can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PlanError {
    /// A requested root field has no registered owning subgraph.
    UnknownField { field: String },
    /// An entity fetch was requested at a `depth` that is not below `max`.
    DepthExceeded { depth: usize, max: usize },
    /// No root fields were supplied.
    EmptyQuery,
}

impl std::fmt::Display for PlanError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyQuery => f.write_str("Query is empty"),
            Self::UnknownField { field } => {
                f.write_fmt(format_args!("No subgraph owns root field '{field}'"))
            }
            Self::DepthExceeded { depth, max } => f.write_fmt(format_args!(
                "Entity resolution depth {depth} exceeds max {max}"
            )),
        }
    }
}

impl std::error::Error for PlanError {}
/// Splits a list of root fields into one root fetch per owning subgraph.
///
/// Each field in `root_fields` is looked up in `ownership`. Fields owned by the
/// same subgraph are batched into a single non-entity fetch that selects them
/// in their input order. Fetches appear in the order their subgraph is first
/// encountered in `root_fields`, so identical input always yields an identical
/// plan.
///
/// # Errors
///
/// * [`PlanError::EmptyQuery`] if `root_fields` is empty.
/// * [`PlanError::UnknownField`] for the first field that has no owner.
pub fn plan_query(
    root_fields: &[String],
    ownership: &FieldOwnership,
) -> Result<QueryPlan, PlanError> {
    if root_fields.is_empty() {
        return Err(PlanError::EmptyQuery);
    }
    // A Vec rather than a HashMap keeps subgraphs in first-seen order. HashMap
    // iteration order is randomized per process, which made the fetch order (and
    // thus the serialized plan) nondeterministic. Subgraph counts are small, so
    // the linear lookup is cheap.
    let mut groups: Vec<(String, Vec<String>)> = Vec::new();
    for field in root_fields {
        let owner = ownership.owner(field).ok_or_else(|| PlanError::UnknownField {
            field: field.clone(),
        })?;
        match groups.iter_mut().find(|(subgraph, _)| subgraph.as_str() == owner) {
            Some((_, fields)) => fields.push(field.clone()),
            None => groups.push((owner.to_string(), vec![field.clone()])),
        }
    }
    let fetches = groups
        .into_iter()
        .map(|(subgraph, fields)| {
            let selection = fields.join("\n ");
            let query = format!("{{\n {selection}\n}}");
            SubgraphFetch {
                subgraph,
                query,
                variables: Value::Object(serde_json::Map::new()),
                is_entity_fetch: false,
                depends_on: None,
            }
        })
        .collect();
    Ok(QueryPlan { fetches })
}
pub fn plan_entity_fetch(
subgraph: &str,
representations: &[Value],
selection: &str,
current_depth: usize,
) -> Result<SubgraphFetch, PlanError> {
if current_depth >= MAX_ENTITY_DEPTH {
return Err(PlanError::DepthExceeded {
depth: current_depth,
max: MAX_ENTITY_DEPTH,
});
}
let query = format!(
"query($representations: [_Any!]!) {{\n _entities(representations: $representations) {{\n ... on _ {{\n {selection}\n }}\n }}\n}}"
);
let variables = serde_json::json!({
"representations": representations,
});
Ok(SubgraphFetch {
subgraph: subgraph.to_string(),
query,
variables,
is_entity_fetch: true,
depends_on: None,
})
}
/// Extracts the names of the top-level (root) fields selected by the first
/// operation in a GraphQL document.
///
/// This is a lightweight scanner, not a full GraphQL parser. It understands
/// enough of the GraphQL lexical grammar to avoid the pitfalls of splitting on
/// whitespace:
///
/// * `#` comments and string literals (including `"""` block strings) are
///   skipped, so words or braces inside them are never mistaken for fields;
/// * punctuation needs no surrounding whitespace (`user{id}` yields `user`);
/// * for an aliased field (`me: user`) the underlying field name (`user`) is
///   returned, because subgraph ownership is keyed on field names;
/// * directives (`@include(...)`), fragment spreads (`...Frag`) and inline
///   fragments (`... on Type { ... }`) are not reported; fields selected inside
///   an inline fragment are not reported either;
/// * `fragment` definitions are skipped when locating the operation.
///
/// Fields are returned in document order, and duplicates are preserved. An empty
/// `Vec` is returned when the document has no operation selection set, or when
/// that selection set is never closed.
pub fn extract_root_fields(query: &str) -> Vec<String> {
    let tokens = lex_graphql(query);
    find_operation_selection(&tokens)
        .and_then(|open| collect_root_fields(&tokens[open + 1..]))
        .unwrap_or_default()
}

/// The subset of GraphQL lexical tokens that root-field extraction cares about.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GqlToken<'a> {
    /// A GraphQL `Name` (identifiers and keywords alike).
    Name(&'a str),
    /// A single ASCII punctuator such as `{`, `(`, `:` or `@`.
    Punct(char),
    /// The `...` spread punctuator.
    Spread,
    /// A string or numeric literal; its content is irrelevant here.
    Literal,
}

/// Tokenizes `src`, dropping whitespace, commas, comments and non-ASCII bytes
/// that appear outside string literals.
fn lex_graphql(src: &str) -> Vec<GqlToken<'_>> {
    let bytes = src.as_bytes();
    let mut tokens = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        let c = bytes[i];
        match c {
            b'#' => {
                // Comments run to the end of the line.
                while i < bytes.len() && bytes[i] != b'\n' && bytes[i] != b'\r' {
                    i += 1;
                }
            }
            b'"' => {
                i = skip_string_literal(bytes, i);
                tokens.push(GqlToken::Literal);
            }
            b'.' if bytes[i..].starts_with(b"...") => {
                tokens.push(GqlToken::Spread);
                i += 3;
            }
            b'_' | b'a'..=b'z' | b'A'..=b'Z' => {
                let start = i;
                while i < bytes.len() && (bytes[i] == b'_' || bytes[i].is_ascii_alphanumeric()) {
                    i += 1;
                }
                // Both ends are ASCII positions, hence valid char boundaries.
                tokens.push(GqlToken::Name(&src[start..i]));
            }
            b'-' | b'0'..=b'9' => {
                i += 1;
                while i < bytes.len() {
                    let d = bytes[i];
                    let exponent_sign =
                        matches!(d, b'+' | b'-') && matches!(bytes[i - 1], b'e' | b'E');
                    if d.is_ascii_alphanumeric() || d == b'.' || exponent_sign {
                        i += 1;
                    } else {
                        break;
                    }
                }
                tokens.push(GqlToken::Literal);
            }
            _ => {
                // Commas are insignificant in GraphQL; whitespace and non-ASCII
                // bytes are not graphic and are dropped as well.
                if c.is_ascii_graphic() && c != b',' {
                    tokens.push(GqlToken::Punct(char::from(c)));
                }
                i += 1;
            }
        }
    }
    tokens
}

/// Returns the byte index just past the string literal starting at `start`
/// (which points at a `"`). It handles both `"..."` and `"""..."""` forms.
/// An unterminated regular string ends at its line break; an unterminated
/// block string runs to the end of the input.
fn skip_string_literal(bytes: &[u8], start: usize) -> usize {
    if bytes[start..].starts_with(b"\"\"\"") {
        let mut i = start + 3;
        while i < bytes.len() {
            if bytes[i..].starts_with(b"\\\"\"\"") {
                i += 4; // escaped triple quote inside a block string
            } else if bytes[i..].starts_with(b"\"\"\"") {
                return i + 3;
            } else {
                i += 1;
            }
        }
        return bytes.len();
    }
    let mut i = start + 1;
    while i < bytes.len() {
        match bytes[i] {
            b'\\' => i += 2,
            b'"' => return i + 1,
            b'\n' | b'\r' => return i,
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Returns the index of the `{` that opens the first operation's selection set.
/// It skips `fragment` definitions, and braces nested inside parentheses (e.g.
/// object-valued variable defaults).
fn find_operation_selection(tokens: &[GqlToken<'_>]) -> Option<usize> {
    let mut in_fragment = false;
    let mut at_definition_start = true;
    let mut parens = 0usize;
    let mut i = 0;
    while i < tokens.len() {
        match tokens[i] {
            GqlToken::Name("fragment") if at_definition_start => in_fragment = true,
            GqlToken::Punct('(') => parens += 1,
            GqlToken::Punct(')') => parens = parens.saturating_sub(1),
            GqlToken::Punct('{') if parens == 0 => {
                if !in_fragment {
                    return Some(i);
                }
                // Jump past the fragment body; the next token starts a new definition.
                i = matching_close_brace(tokens, i)? + 1;
                in_fragment = false;
                at_definition_start = true;
                continue;
            }
            _ => {}
        }
        at_definition_start = false;
        i += 1;
    }
    None
}

/// Returns the index of the `}` that closes the `{` at `open`, if any.
fn matching_close_brace(tokens: &[GqlToken<'_>], open: usize) -> Option<usize> {
    let mut depth = 0usize;
    for (offset, token) in tokens[open..].iter().enumerate() {
        match token {
            GqlToken::Punct('{') => depth += 1,
            GqlToken::Punct('}') => {
                depth -= 1; // cannot underflow: tokens[open] is '{'
                if depth == 0 {
                    return Some(open + offset);
                }
            }
            _ => {}
        }
    }
    None
}

/// Collects the field names at nesting level zero of a selection set. `tokens`
/// starts just after the opening `{`. Returns `None` when the closing `}` is
/// missing.
fn collect_root_fields(tokens: &[GqlToken<'_>]) -> Option<Vec<String>> {
    let mut fields = Vec::new();
    let mut braces = 0usize;
    let mut parens = 0usize;
    let mut i = 0;
    while i < tokens.len() {
        let at_root = braces == 0 && parens == 0;
        match tokens[i] {
            GqlToken::Punct('{') => braces += 1,
            GqlToken::Punct('}') if braces == 0 => return Some(fields),
            GqlToken::Punct('}') => braces -= 1,
            GqlToken::Punct('(') => parens += 1,
            GqlToken::Punct(')') => parens = parens.saturating_sub(1),
            GqlToken::Name(name) if at_root => {
                if tokens.get(i + 1) == Some(&GqlToken::Punct(':')) {
                    // `alias: field`: ownership is keyed on the field, not the alias.
                    if let Some(GqlToken::Name(field)) = tokens.get(i + 2) {
                        fields.push((*field).to_string());
                        i += 2;
                    }
                } else {
                    fields.push(name.to_string());
                }
            }
            GqlToken::Punct('@') if at_root => {
                // Skip the directive name; its arguments are covered by `parens`.
                if matches!(tokens.get(i + 1), Some(GqlToken::Name(_))) {
                    i += 1;
                }
            }
            GqlToken::Spread if at_root => match (tokens.get(i + 1), tokens.get(i + 2)) {
                // Inline fragment `... on Type`; its body is nested, so it is skipped.
                (Some(GqlToken::Name("on")), Some(GqlToken::Name(_))) => i += 2,
                // Named fragment spread `...Frag`.
                (Some(GqlToken::Name(_)), _) => i += 1,
                // `... {` or `... @dir {`: nothing to skip here.
                _ => {}
            },
            _ => {}
        }
        i += 1;
    }
    None
}