use std::sync::Arc;
use apollo_compiler::Name;
use apollo_compiler::Node;
use apollo_compiler::collections::IndexSet;
use apollo_compiler::executable;
use apollo_compiler::executable::VariableDefinition;
use super::QueryPathElement;
use super::conditions::ConditionKind;
use super::query_planner::SubgraphOperationCompression;
use crate::error::FederationError;
use crate::operation::DirectiveList;
use crate::operation::SelectionSet;
use crate::query_graph::QueryGraph;
use crate::query_graph::graph_path::operation::OpPathElement;
use crate::query_plan::ConditionNode;
use crate::query_plan::DeferNode;
use crate::query_plan::DeferredDeferBlock;
use crate::query_plan::DeferredDependency;
use crate::query_plan::ParallelNode;
use crate::query_plan::PlanNode;
use crate::query_plan::PrimaryDeferBlock;
use crate::query_plan::QueryPlanCost;
use crate::query_plan::SequenceNode;
use crate::query_plan::conditions::Conditions;
use crate::query_plan::fetch_dependency_graph::DeferredInfo;
use crate::query_plan::fetch_dependency_graph::FetchDependencyGraphNode;
/// Fixed cost charged for every fetch, added on top of the fetch node's own cost.
const FETCH_COST: QueryPlanCost = 1_000.0;
/// Per-stage weight for the later stages of a sequence: stage `i` (for `i >= 1`) has its
/// cost multiplied by `i * PIPELINING_COST`, while the first stage is weighted `1`.
const PIPELINING_COST: QueryPlanCost = 100.0;
/// Processor that turns a fetch dependency graph into a `PlanNode` tree.
pub(crate) struct FetchDependencyGraphToQueryPlanProcessor {
    /// Variable definitions passed to every `to_plan_node` call.
    variable_definitions: Arc<Vec<Node<VariableDefinition>>>,
    /// Operation-level directives passed to every `to_plan_node` call.
    operation_directives: DirectiveList,
    /// Compression state, passed mutably to every `to_plan_node` call.
    operation_compression: SubgraphOperationCompression,
    /// Name of the planned operation; when present it prefixes every generated
    /// subgraph operation name.
    operation_name: Option<Name>,
    /// Defer labels that are left out (`label: None`) of emitted `DeferredDeferBlock`s.
    assigned_defer_labels: IndexSet<String>,
    /// Numeric suffix for the next generated subgraph operation name.
    counter: u32,
}
/// Stateless processor that estimates the cost of the plan a fetch dependency graph
/// would produce, instead of building the plan itself.
#[derive(Copy, Clone)]
pub(crate) struct FetchDependencyGraphToCostProcessor;
/// Callbacks used to fold a fetch dependency graph into a single result.
///
/// `TProcessed` is the value produced for a fetch node or for a combination of them, and
/// `TDeferred` is the value produced for a deferred block. This file implements the trait
/// twice: to build a query plan (`Option<PlanNode>` / `DeferredDeferBlock`) and to
/// estimate a plan cost (`QueryPlanCost` for both).
pub(crate) trait FetchDependencyGraphProcessor<TProcessed, TDeferred> {
    /// Produces the value for a single fetch node.
    ///
    /// `handled_conditions` is forwarded to `FetchDependencyGraphNode::to_plan_node` by the
    /// query-plan implementation; presumably these are conditions already applied by an
    /// enclosing node — TODO confirm.
    fn on_node(
        &mut self,
        query_graph: &QueryGraph,
        node: &mut FetchDependencyGraphNode,
        handled_conditions: &Conditions,
    ) -> Result<TProcessed, FederationError>;
    /// Makes `value` subject to `conditions`.
    fn on_conditions(&mut self, conditions: &Conditions, value: TProcessed) -> TProcessed;
    /// Combines values that are to be treated as one parallel group.
    fn reduce_parallel(&mut self, values: impl IntoIterator<Item = TProcessed>) -> TProcessed;
    /// Combines values that are to be treated as ordered stages of a sequence.
    fn reduce_sequence(&mut self, values: impl IntoIterator<Item = TProcessed>) -> TProcessed;
    /// Turns the processed content of one deferred block (described by `defer_info`)
    /// into a `TDeferred`.
    fn reduce_deferred(
        &mut self,
        defer_info: &DeferredInfo,
        value: TProcessed,
    ) -> Result<TDeferred, FederationError>;
    /// Combines the primary (non-deferred) value, its `sub_selection`, and the deferred
    /// blocks into one value.
    fn reduce_defer(
        &mut self,
        main: TProcessed,
        sub_selection: &SelectionSet,
        deferred_blocks: Vec<TDeferred>,
    ) -> Result<TProcessed, FederationError>;
}
/// Lets a mutable borrow of a processor be used anywhere a processor is expected.
///
/// Every callback is forwarded unchanged to the borrowed processor `T`.
impl<TProcessed, TDeferred, T> FetchDependencyGraphProcessor<TProcessed, TDeferred> for &mut T
where
    T: FetchDependencyGraphProcessor<TProcessed, TDeferred>,
{
    fn on_node(
        &mut self,
        query_graph: &QueryGraph,
        node: &mut FetchDependencyGraphNode,
        handled_conditions: &Conditions,
    ) -> Result<TProcessed, FederationError> {
        T::on_node(&mut **self, query_graph, node, handled_conditions)
    }

    fn on_conditions(&mut self, conditions: &Conditions, value: TProcessed) -> TProcessed {
        T::on_conditions(&mut **self, conditions, value)
    }

    fn reduce_parallel(&mut self, values: impl IntoIterator<Item = TProcessed>) -> TProcessed {
        T::reduce_parallel(&mut **self, values)
    }

    fn reduce_sequence(&mut self, values: impl IntoIterator<Item = TProcessed>) -> TProcessed {
        T::reduce_sequence(&mut **self, values)
    }

    fn reduce_deferred(
        &mut self,
        defer_info: &DeferredInfo,
        value: TProcessed,
    ) -> Result<TDeferred, FederationError> {
        T::reduce_deferred(&mut **self, defer_info, value)
    }

    fn reduce_defer(
        &mut self,
        main: TProcessed,
        sub_selection: &SelectionSet,
        deferred_blocks: Vec<TDeferred>,
    ) -> Result<TProcessed, FederationError> {
        T::reduce_defer(&mut **self, main, sub_selection, deferred_blocks)
    }
}
/// Cost model: every callback folds child costs into a single `QueryPlanCost`.
impl FetchDependencyGraphProcessor<QueryPlanCost, QueryPlanCost>
    for FetchDependencyGraphToCostProcessor
{
    /// A fetch costs the fixed `FETCH_COST` plus the node's own reported cost.
    fn on_node(
        &mut self,
        _graph: &QueryGraph,
        node: &mut FetchDependencyGraphNode,
        _conditions: &Conditions,
    ) -> Result<QueryPlanCost, FederationError> {
        let fetch_cost = FETCH_COST + node.cost();
        Ok(fetch_cost)
    }

    /// Conditions do not change the estimate.
    fn on_conditions(&mut self, _conditions: &Conditions, cost: QueryPlanCost) -> QueryPlanCost {
        cost
    }

    /// Parallel groups are costed by `parallel_cost`.
    fn reduce_parallel(
        &mut self,
        costs: impl IntoIterator<Item = QueryPlanCost>,
    ) -> QueryPlanCost {
        parallel_cost(costs)
    }

    /// Sequences are costed by `sequence_cost`.
    fn reduce_sequence(
        &mut self,
        costs: impl IntoIterator<Item = QueryPlanCost>,
    ) -> QueryPlanCost {
        sequence_cost(costs)
    }

    /// A deferred block costs exactly what its content costs.
    fn reduce_deferred(
        &mut self,
        _info: &DeferredInfo,
        block_cost: QueryPlanCost,
    ) -> Result<QueryPlanCost, FederationError> {
        Ok(block_cost)
    }

    /// Costed as a two-stage sequence: the primary part first, then all deferred blocks
    /// combined with `parallel_cost`.
    fn reduce_defer(
        &mut self,
        main: QueryPlanCost,
        _sub_selection: &SelectionSet,
        deferred_blocks: Vec<QueryPlanCost>,
    ) -> Result<QueryPlanCost, FederationError> {
        let deferred_cost = parallel_cost(deferred_blocks);
        Ok(sequence_cost([main, deferred_cost]))
    }
}
/// Cost of a parallel group: the plain sum of its members' costs.
fn parallel_cost(values: impl IntoIterator<Item = QueryPlanCost>) -> QueryPlanCost {
    Iterator::sum(values.into_iter())
}
/// Cost of a sequence of stages.
///
/// The first stage keeps its cost as is; stage `i` (for `i >= 1`) has its cost multiplied
/// by `i * PIPELINING_COST`, so work pushed to later stages is penalized more heavily.
fn sequence_cost(values: impl IntoIterator<Item = QueryPlanCost>) -> QueryPlanCost {
    let stage_weight = |index: usize| (index as QueryPlanCost * PIPELINING_COST).max(1.0);
    values
        .into_iter()
        .enumerate()
        .map(|(index, stage_cost)| stage_cost * stage_weight(index))
        .sum()
}
impl FetchDependencyGraphToQueryPlanProcessor {
    /// Creates a query-plan processor from the planned operation's context.
    ///
    /// The operation-name counter starts at zero, so the first generated subgraph operation
    /// name (when `operation_name` is set) ends in `__0`.
    pub(crate) fn new(
        variable_definitions: Arc<Vec<Node<VariableDefinition>>>,
        operation_directives: DirectiveList,
        operation_compression: SubgraphOperationCompression,
        operation_name: Option<Name>,
        assigned_defer_labels: IndexSet<String>,
    ) -> Self {
        let counter = 0;
        Self {
            counter,
            operation_name,
            operation_compression,
            operation_directives,
            variable_definitions,
            assigned_defer_labels,
        }
    }
}
/// Builds the actual query plan: fetch nodes become `PlanNode`s, and groups become
/// `Parallel`/`Sequence`/`Condition`/`Defer` nodes.
impl FetchDependencyGraphProcessor<Option<PlanNode>, DeferredDeferBlock>
    for FetchDependencyGraphToQueryPlanProcessor
{
    /// Builds the plan node for a single fetch node.
    ///
    /// When the planned operation has a name, each generated subgraph operation is named
    /// `{operation}__{subgraph}__{counter}`. Here `subgraph` is the sanitized subgraph name
    /// and `counter` is incremented on every such call.
    fn on_node(
        &mut self,
        query_graph: &QueryGraph,
        node: &mut FetchDependencyGraphNode,
        handled_conditions: &Conditions,
    ) -> Result<Option<PlanNode>, FederationError> {
        let op_name = self.operation_name.as_ref().map(|name| {
            let counter = self.counter;
            self.counter += 1;
            // A subgraph name with no usable characters contributes an empty segment.
            let subgraph = to_valid_graphql_name(&node.subgraph_name).unwrap_or_default();
            // `name` is already a validated `Name`, and both `subgraph` and `counter` only
            // contain `[_0-9A-Za-z]`, so the concatenation is always a valid name.
            Name::new(&format!("{name}__{subgraph}__{counter}"))
                .expect("generated subgraph operation name is a valid GraphQL name")
        });
        node.to_plan_node(
            query_graph,
            handled_conditions,
            &self.variable_definitions,
            &self.operation_directives,
            &mut self.operation_compression,
            op_name,
        )
    }

    /// Makes `value` (if any) subject to `conditions`.
    ///
    /// A constant `true` keeps `value` unchanged and a constant `false` drops it. Variable
    /// conditions wrap `value` in one `ConditionNode` per variable, in iteration order, so
    /// the last variable ends up outermost. `Include` variables put the wrapped node in the
    /// `if` branch and `Skip` variables in the `else` branch.
    fn on_conditions(
        &mut self,
        conditions: &Conditions,
        value: Option<PlanNode>,
    ) -> Option<PlanNode> {
        let mut value = value?;
        match conditions {
            // `ConditionNode` is keyed by a `condition_variable`, so a constant condition is
            // resolved here rather than wrapped.
            Conditions::Boolean(condition) => condition.then_some(value),
            Conditions::Variables(variables) => {
                for (name, kind) in variables.iter() {
                    let (if_clause, else_clause) = match kind {
                        ConditionKind::Skip => (None, Some(Box::new(value))),
                        ConditionKind::Include => (Some(Box::new(value)), None),
                    };
                    value = PlanNode::from(ConditionNode {
                        condition_variable: name.clone(),
                        if_clause,
                        else_clause,
                    });
                }
                Some(value)
            }
        }
    }

    /// Combines the present nodes into a (flattened) `Parallel` node; see `flat_wrap_nodes`.
    fn reduce_parallel(
        &mut self,
        values: impl IntoIterator<Item = Option<PlanNode>>,
    ) -> Option<PlanNode> {
        flat_wrap_nodes(NodeKind::Parallel, values)
    }

    /// Combines the present nodes into a (flattened) `Sequence` node; see `flat_wrap_nodes`.
    fn reduce_sequence(
        &mut self,
        values: impl IntoIterator<Item = Option<PlanNode>>,
    ) -> Option<PlanNode> {
        flat_wrap_nodes(NodeKind::Sequence, values)
    }

    /// Builds the `DeferredDeferBlock` for one deferred block whose processed content is
    /// `node`.
    ///
    /// # Errors
    /// Propagates failures from serializing the block's sub-selection.
    fn reduce_deferred(
        &mut self,
        defer_info: &DeferredInfo,
        node: Option<PlanNode>,
    ) -> Result<DeferredDeferBlock, FederationError> {
        /// Converts an operation path into a response path. Fields contribute their
        /// response key; inline fragments contribute their type condition, or nothing
        /// when they have none.
        fn op_path_to_query_path(path: &[Arc<OpPathElement>]) -> Vec<QueryPathElement> {
            path.iter()
                .filter_map(|element| match &**element {
                    OpPathElement::Field(field) => Some(QueryPathElement::Field {
                        response_key: field.response_name().clone(),
                    }),
                    OpPathElement::InlineFragment(inline) => inline
                        .type_condition_position
                        .as_ref()
                        .map(|cond| QueryPathElement::InlineFragment {
                            type_condition: cond.type_name().clone(),
                        }),
                })
                .collect()
        }

        Ok(DeferredDeferBlock {
            depends: defer_info
                .dependencies
                .iter()
                .cloned()
                .map(|id| DeferredDependency { id })
                .collect(),
            // Labels in `assigned_defer_labels` are not surfaced in the plan.
            label: if self.assigned_defer_labels.contains(&defer_info.label) {
                None
            } else {
                Some(defer_info.label.clone())
            },
            query_path: op_path_to_query_path(&defer_info.path.full_path),
            // Only attached when `defer_info.deferred` is empty; presumably a block with
            // nested deferred blocks yields a `DeferNode` that carries its own primary
            // sub-selection — TODO confirm.
            sub_selection: if defer_info.deferred.is_empty() {
                serialize_sub_selection(&defer_info.sub_selection)?
            } else {
                None
            },
            node: node.map(Box::new),
        })
    }

    /// Wraps the primary content `main` and its `deferred` blocks into a `DeferNode`.
    ///
    /// # Errors
    /// Propagates failures from serializing `sub_selection`.
    fn reduce_defer(
        &mut self,
        main: Option<PlanNode>,
        sub_selection: &SelectionSet,
        deferred: Vec<DeferredDeferBlock>,
    ) -> Result<Option<PlanNode>, FederationError> {
        Ok(Some(PlanNode::Defer(DeferNode {
            primary: PrimaryDeferBlock {
                sub_selection: serialize_sub_selection(sub_selection)?,
                node: main.map(Box::new),
            },
            deferred,
        })))
    }
}

/// Serializes `selection_set`, after `without_empty_branches` pruning, to a single-line
/// GraphQL string. Returns `None` when `without_empty_branches` yields nothing.
///
/// # Errors
/// Propagates any failure converting the pruned selection set into an
/// `apollo_compiler` executable selection set.
fn serialize_sub_selection(
    selection_set: &SelectionSet,
) -> Result<Option<String>, FederationError> {
    Ok(selection_set
        .without_empty_branches()
        .map(|filtered| executable::SelectionSet::try_from(filtered.as_ref()))
        .transpose()?
        .map(|executable_set| executable_set.serialize().no_indent().to_string()))
}
/// Sanitizes a subgraph name so it can be embedded in a GraphQL name
/// (`[_A-Za-z][_0-9A-Za-z]*`).
///
/// - `-` and `_` both become `_`;
/// - ASCII letters and digits are kept;
/// - every other character (including non-ASCII) is dropped;
/// - if the result would start with a digit, it is prefixed with `_`.
///
/// Returns `None` when no character survives.
pub(crate) fn to_valid_graphql_name(subgraph_name: &str) -> Option<String> {
    let mut sanitized = String::with_capacity(subgraph_name.len() + 1);
    for ch in subgraph_name.chars() {
        let kept = match ch {
            '-' | '_' => '_',
            c if c.is_ascii_alphanumeric() => c,
            _ => continue,
        };
        // GraphQL names cannot start with a digit.
        if sanitized.is_empty() && kept.is_ascii_digit() {
            sanitized.push('_');
        }
        sanitized.push(kept);
    }
    if sanitized.is_empty() {
        None
    } else {
        Some(sanitized)
    }
}
/// Which wrapper node `flat_wrap_nodes` produces.
#[derive(Copy, Clone)]
enum NodeKind {
    /// Wrap in a `PlanNode::Parallel`.
    Parallel,
    /// Wrap in a `PlanNode::Sequence`.
    Sequence,
}
/// Combines the present (`Some`) plan nodes into a single node of the given `kind`.
///
/// - No present nodes: returns `None`.
/// - Exactly one present node: returns it unchanged, with no wrapper.
/// - Two or more: wraps them in a `Parallel`/`Sequence` node. Any input that is already a
///   node of that same kind is spliced in (its children are taken directly) instead of
///   being nested; this flattening goes one level deep.
fn flat_wrap_nodes(
    kind: NodeKind,
    nodes: impl IntoIterator<Item = Option<PlanNode>>,
) -> Option<PlanNode> {
    let mut iter = nodes.into_iter().flatten();
    let first = iter.next()?;
    let Some(second) = iter.next() else {
        return Some(first);
    };
    let mut nodes = Vec::new();
    for node in [first, second].into_iter().chain(iter) {
        match (kind, node) {
            // `inner` is owned here, so its children are moved rather than deep-cloned.
            (NodeKind::Parallel, PlanNode::Parallel(inner)) => nodes.extend(inner.nodes),
            (NodeKind::Sequence, PlanNode::Sequence(inner)) => nodes.extend(inner.nodes),
            (_, node) => nodes.push(node),
        }
    }
    Some(match kind {
        NodeKind::Parallel => PlanNode::Parallel(ParallelNode { nodes }),
        NodeKind::Sequence => PlanNode::Sequence(SequenceNode { nodes }),
    })
}