use std::collections::VecDeque;
use crate::error::FederationError;
use crate::query_plan::QueryPlanCost;
/// One list of alternatives to pick from. A slot is reset to `None` once its option has been
/// taken, so the options still available in the list are exactly the `Some` slots.
type Choices<Item> = Vec<Option<Item>>;
/// A partially built plan waiting on the work deque of `generate_all_plans_and_find_best`.
struct Partial<Plan, Element> {
/// The plan built so far.
partial_plan: Plan,
/// Cost of `partial_plan`, or `None` for the initial plan and its re-queued copies (their cost
/// is never computed, so they are never pruned).
partial_cost: Option<QueryPlanCost>,
/// The lists of choices still to be applied, in order. In a re-queued copy, the first list has
/// the options already explored set to `None`.
remaining: std::vec::IntoIter<Choices<Element>>,
/// Whether this is the initial plan (or one of its re-queued copies). Only such partials move
/// their `index` hint on to the next option when they are re-queued.
is_root: bool,
/// Preferred index to pick in the first remaining list, if any. Partials with a hint are pushed
/// to the back of the deque (popped next); partials without one go to the front.
index: Option<usize>,
}
/// Callbacks used by `generate_all_plans_and_find_best` to build and evaluate candidate plans.
pub trait PlanBuilder<Plan, Element> {
/// Returns a new plan made of `plan` extended with `elem`. `plan` is only borrowed immutably,
/// because the search keeps reusing it to build the other alternatives.
fn add_to_plan(&mut self, plan: &Plan, elem: Element) -> Result<Plan, FederationError>;
/// Computes the cost of `plan`. The search compares these costs both to keep the cheapest
/// complete plan and to prune partial plans.
/// NOTE(review): takes `&mut Plan`, presumably so implementations can cache or finalize data on
/// the plan while costing it; confirm.
fn compute_plan_cost(&mut self, plan: &mut Plan) -> Result<QueryPlanCost, FederationError>;
/// Called once for every complete plan the search builds. It receives the plan's `cost` and the
/// best cost known before it (`prev_cost`), which is `None` for the first complete plan.
fn on_plan_generated(&self, plan: &Plan, cost: QueryPlanCost, prev_cost: Option<QueryPlanCost>);
}
/// Result of taking one option out of a list of choices.
struct Extracted<Element> {
/// The option that was taken.
extracted: Element,
/// `true` when, after this option was taken, the list has no available option left.
is_last: bool,
}
/// Explores the ways of completing `initial` by picking one element from each list of
/// `to_add` (lists are applied in order). Returns the cheapest complete plan with its cost.
///
/// Plans are extended with [`PlanBuilder::add_to_plan`] and costed with
/// [`PlanBuilder::compute_plan_cost`]. Every complete plan that gets built is reported through
/// [`PlanBuilder::on_plan_generated`], together with the best cost known before it.
///
/// Exploration order: the first list is tried at index 0, then 1, 2, and so on. Each branch
/// started from index `i` first tries index `i` in every following list that has such an index,
/// and otherwise the first available option. Other alternatives are pushed to the front of the
/// work deque, so they are only explored after those index-guided branches.
///
/// Pruning: once a complete plan is known, any partial plan whose cost is already greater than
/// or equal to the best complete cost is abandoned. This is only sound if adding an element
/// never lowers a plan's cost. On ties, the complete plan found first is kept.
///
/// If `to_add` is empty, `initial` is returned with its computed cost.
///
/// # Errors
///
/// - Returns an internal error if a list in `to_add` has no available (`Some`) option, because
///   no complete plan can exist in that case.
/// - Propagates any error returned by `plan_builder`.
/// - Returns an internal error if, unexpectedly, no plan was found.
#[cfg_attr(
    feature = "snapshot_tracing",
    tracing::instrument(skip_all, level = "trace")
)]
pub fn generate_all_plans_and_find_best<Plan, Element>(
    mut initial: Plan,
    to_add: Vec<Choices<Element>>,
    plan_builder: &mut impl PlanBuilder<Plan, Element>,
) -> Result<(Plan, QueryPlanCost), FederationError>
where
    Element: Clone,
{
    if to_add.is_empty() {
        let cost = plan_builder.compute_plan_cost(&mut initial)?;
        return Ok((initial, cost));
    }
    // A list with no available option can never be satisfied, so no complete plan exists.
    // Without this check, the very first branch would reach that list and panic inside
    // `pick_next`; report the problem as an error instead.
    if to_add.iter().any(|choices| choices.iter().all(Option::is_none)) {
        return Err(FederationError::internal(
            "Cannot generate plans: a list of choices has no available option",
        ));
    }

    let mut stack = VecDeque::new();
    stack.push_back(Partial {
        partial_plan: initial,
        partial_cost: None,
        remaining: to_add.into_iter(),
        is_root: true,
        index: Some(0),
    });
    // Best complete plan found so far, together with its cost.
    let mut min = None;

    while let Some(Partial {
        partial_plan,
        partial_cost,
        mut remaining,
        is_root,
        index,
    }) = stack.pop_back()
    {
        // A complete plan at least as cheap as this partial one is already known, so this
        // branch cannot produce a better plan.
        if let (Some((_, min_cost)), Some(partial_cost)) = (&min, &partial_cost)
            && partial_cost >= min_cost
        {
            continue;
        }

        // `remaining` is never empty here: a partial is only queued while lists remain.
        let next_choices = &mut remaining.as_mut_slice()[0];
        let picked_index = pick_next(index, next_choices);
        let Extracted { extracted, is_last } = extract(picked_index, next_choices);

        let mut new_partial_plan = plan_builder.add_to_plan(&partial_plan, extracted)?;
        let cost = plan_builder.compute_plan_cost(&mut new_partial_plan)?;

        if !is_last {
            // Other options remain in this list. Re-queue the current partial plan, with the
            // picked option now taken, so they are explored later. Only root copies move their
            // hint on to the next index (if there is one); other copies lose their hint.
            insert_in_stack(
                &mut stack,
                Partial {
                    partial_plan,
                    partial_cost,
                    is_root,
                    index: index.and_then(|i| {
                        let next = i + 1;
                        (is_root && next < next_choices.len()).then_some(next)
                    }),
                    remaining: remaining.clone(),
                },
            )
        }
        // Move past the list we just picked from.
        remaining.next();

        let previous_min_cost = min.as_ref().map(|&(_, cost)| cost);
        // `<=` means that on ties the previously found plan wins.
        let previous_min_is_better = previous_min_cost.is_some_and(|min| min <= cost);
        if remaining.as_slice().is_empty() {
            // Every list has been used, so `new_partial_plan` is a complete plan.
            plan_builder.on_plan_generated(&new_partial_plan, cost, previous_min_cost);
            if !previous_min_is_better {
                min = Some((new_partial_plan, cost))
            }
            continue;
        }

        // Continue this branch with the next list, unless it is already no cheaper than the
        // best complete plan. The child keeps the same index hint.
        if !previous_min_is_better {
            insert_in_stack(
                &mut stack,
                Partial {
                    partial_plan: new_partial_plan,
                    partial_cost: Some(cost),
                    remaining,
                    is_root: false,
                    index,
                },
            )
        }
    }
    min.ok_or_else(|| FederationError::internal("A plan should have been found"))
}
/// Queues `item` on the work deque, which is popped from the back.
///
/// A partial that still carries an index hint goes to the back, so it is explored next. A partial
/// without a hint goes to the front, so it is explored after everything already queued.
fn insert_in_stack<Plan, Element>(
    stack: &mut VecDeque<Partial<Plan, Element>>,
    item: Partial<Plan, Element>,
) {
    match item.index {
        Some(_) => stack.push_back(item),
        None => stack.push_front(item),
    }
}
fn pick_next<Element>(opt_index: Option<usize>, remaining: &Choices<Element>) -> usize {
if let Some(index) = opt_index
&& let Some(choice) = remaining.get(index)
{
assert!(choice.is_some(), "Invalid index {index}");
return index;
}
remaining
.iter()
.position(|choice| choice.is_some())
.expect("Passed a `remaining` with all `None`")
}
/// Takes the option at `index` out of `choices`, leaving `None` in its slot. Also reports
/// whether the list has any option left afterwards.
///
/// # Panics
///
/// Panics if `index` is out of bounds or the slot at `index` is already `None`.
fn extract<Element>(index: usize, choices: &mut Choices<Element>) -> Extracted<Element> {
    let extracted = std::mem::take(&mut choices[index]).unwrap();
    let is_last = !choices.iter().any(Option::is_some);
    Extracted { extracted, is_last }
}
#[cfg(test)]
mod tests {
    use super::*;

    type Element = &'static str;
    type Plan = Vec<Element>;

    /// Plan builder for the tests. A plan is the list of element names it contains, and its cost
    /// is the total length of those names. Every plan that reaches `target_len` elements (that
    /// is, every complete plan) is recorded in `generated`, in the order it was built.
    struct TestPlanBuilder<'a> {
        generated: &'a mut Vec<Plan>,
        target_len: usize,
    }

    impl PlanBuilder<Plan, Element> for TestPlanBuilder<'_> {
        fn add_to_plan(
            &mut self,
            partial_plan: &Plan,
            new_element: Element,
        ) -> Result<Plan, FederationError> {
            let mut extended = partial_plan.clone();
            extended.push(new_element);
            let is_complete = extended.len() == self.target_len;
            if is_complete {
                self.generated.push(extended.clone());
            }
            Ok(extended)
        }

        fn compute_plan_cost(&mut self, plan: &mut Plan) -> Result<QueryPlanCost, FederationError> {
            let total: QueryPlanCost = plan.iter().map(|name| name.len() as QueryPlanCost).sum();
            Ok(total)
        }

        fn on_plan_generated(
            &self,
            _plan: &Plan,
            _cost: QueryPlanCost,
            _prev_cost: Option<QueryPlanCost>,
        ) {
            // Complete plans are already captured in `add_to_plan`.
        }
    }

    /// Runs the search on `initial` + `choices`. Returns the best plan and every complete plan
    /// built, in build order.
    fn generate_test_plans(initial: Plan, choices: Vec<Vec<Option<Element>>>) -> (Plan, Vec<Plan>) {
        let target_len = initial.len() + choices.len();
        let mut generated = Vec::new();
        let (best, _cost) = {
            let mut plan_builder = TestPlanBuilder {
                generated: &mut generated,
                target_len,
            };
            generate_all_plans_and_find_best::<Plan, Element>(initial, choices, &mut plan_builder)
                .unwrap()
        };
        (best, generated)
    }

    #[test]
    fn pick_elements_at_same_index_first() {
        // All names have the same length, so every plan ties and the first one built is kept.
        // The second plan built follows index 1 in every list.
        let (best, generated) = generate_test_plans(
            vec!["I"],
            vec![
                vec![Some("A1"), Some("B1")],
                vec![Some("A2"), Some("B2")],
                vec![Some("A3"), Some("B3")],
            ],
        );
        assert_eq!(best, ["I", "A1", "A2", "A3"]);
        assert_eq!(generated[0], ["I", "A1", "A2", "A3"]);
        assert_eq!(generated[1], ["I", "B1", "B2", "B3"]);
    }

    #[test]
    fn bail_early_for_more_costly_elements() {
        // Once I-A1-A2-A3 is known, the branches through the long names are abandoned before
        // they are completed, so only two complete plans are ever built.
        let (best, generated) = generate_test_plans(
            vec!["I"],
            vec![
                vec![Some("A1"), Some("B1VeryCostly")],
                vec![Some("A2"), Some("B2Co")],
                vec![Some("A3"), Some("B3")],
            ],
        );
        assert_eq!(best, ["I", "A1", "A2", "A3"]);
        assert_eq!(generated.len(), 2);
        assert_eq!(generated[0], ["I", "A1", "A2", "A3"]);
        assert_eq!(generated[1], ["I", "A1", "A2", "B3"]);
    }

    #[test]
    fn handles_branches_of_various_sizes() {
        // The lists have different sizes. Every branch through `B2Costly` is pruned, and the
        // cheapest plan starts with the shorter `B1`.
        let (best, mut generated) = generate_test_plans(
            vec!["I"],
            vec![
                vec![Some("A1x"), Some("B1")],
                vec![Some("A2"), Some("B2Costly"), Some("C2")],
                vec![Some("A3")],
                vec![Some("A4"), Some("B4")],
            ],
        );
        assert_eq!(best, ["I", "B1", "A2", "A3", "A4"]);
        generated.sort();
        assert_eq!(
            generated,
            [
                vec!["I", "A1x", "A2", "A3", "A4"],
                vec!["I", "A1x", "A2", "A3", "B4"],
                vec!["I", "A1x", "C2", "A3", "A4"],
                vec!["I", "A1x", "C2", "A3", "B4"],
                vec!["I", "B1", "A2", "A3", "A4"],
                vec!["I", "B1", "A2", "A3", "B4"],
                vec!["I", "B1", "C2", "A3", "A4"],
                vec!["I", "B1", "C2", "A3", "B4"],
            ],
        );
    }
}