use selene_core::DbString;
use crate::{
LabelExpr,
plan::{
BindingDef, ExecutionPlan, FilterPredicate, IndexCatalog, IndexTarget, JoinTree,
NodeOrEdgeScan, ScanAccess, ScanKind,
optimize::{OptimizeContext, Rule, Transformed, binding_refs, cost, walk},
},
};
use super::index_helpers::{equality_candidates, flat_disjunction_singles};
/// Optimizer rule that rewrites a node scan whose label predicate is a disjunction of single
/// labels into a [`JoinTree::DisjunctiveScan`] with one single-label branch per label, so each
/// branch can be considered for a label-specific index.
///
/// The rule only does work when the optimizer context carries an index catalog. It rewrites
/// the top-level pattern plan and then hands nested subplans to `walk::recurse_rule_subplans`.
pub struct DisjunctiveLabelExpansion;

impl Rule for DisjunctiveLabelExpansion {
    fn name(&self) -> &'static str {
        "disjunctive_label_expansion"
    }

    /// Expands eligible scans in the plan's pattern, then recurses into nested subplans.
    ///
    /// When `ctx` has no index catalog, the plan is returned unchanged and subplans are not
    /// visited either.
    fn rewrite(
        &self,
        mut plan: ExecutionPlan,
        ctx: &OptimizeContext<'_>,
    ) -> Transformed<ExecutionPlan> {
        let catalog = match ctx.index_catalog {
            Some(catalog) => catalog,
            // Every expansion decision is index-driven; without a catalog there is nothing to do.
            None => return Transformed::unchanged(plan),
        };

        let pattern_changed = match plan.pattern_plan.as_mut() {
            Some(pattern) => rewrite_tree(&mut pattern.join_tree, &pattern.bindings, catalog),
            None => false,
        };

        let Transformed {
            plan,
            changed: subplans_changed,
        } = walk::recurse_rule_subplans(plan, self, ctx);

        Transformed {
            plan,
            changed: pattern_changed || subplans_changed,
        }
    }
}
/// Walks `tree` and offers every `JoinTree::Scan` to [`maybe_expand_scan`].
///
/// Returns `true` if any scan in the tree was rewritten. Both sides of a binary join are always
/// visited (left first), even when the left side already reported a change.
fn rewrite_tree(tree: &mut JoinTree, bindings: &[BindingDef], catalog: &dyn IndexCatalog) -> bool {
    match tree {
        // Scans are the only candidates for expansion.
        JoinTree::Scan(_) => maybe_expand_scan(tree, bindings, catalog),
        // Single-child operators: descend into the child.
        JoinTree::Expand { child, .. }
        | JoinTree::Questioned { child, .. }
        | JoinTree::Repeat { child, .. }
        | JoinTree::PathSearch { child, .. }
        | JoinTree::PathModeFilter { child, .. }
        | JoinTree::MatchModeFilter { child, .. } => rewrite_tree(child, bindings, catalog),
        // Binary operators: rewrite both sides unconditionally.
        JoinTree::HashJoin { left, right, .. } | JoinTree::Outer { left, right, .. } => {
            let left_changed = rewrite_tree(left, bindings, catalog);
            let right_changed = rewrite_tree(right, bindings, catalog);
            left_changed || right_changed
        }
        // This rule does not descend into these variants. `DisjunctiveScan` is what this rule
        // produces, so skipping it avoids expanding a scan twice.
        JoinTree::Unit
        | JoinTree::WorstCaseOptimal { .. }
        | JoinTree::Subplan(_)
        | JoinTree::DisjunctiveScan { .. } => false,
    }
}
/// Replaces `tree` with a [`JoinTree::DisjunctiveScan`] when it is a node scan that can profit
/// from being split per label. Returns `true` if the rewrite happened, `false` otherwise.
///
/// The scan is expanded only if all of the following hold:
/// - it is a node scan (`ScanKind::Node`) with linear access (`ScanAccess::Linear`);
/// - `flat_disjunction_singles` accepts its label predicate and yields the label list
///   (presumably a flat disjunction of single labels; see that helper);
/// - at least one label has an index applicable to the scan's property predicates;
/// - the cost model does not decline the expansion. If either the disjunctive cost or the
///   total node row count is unavailable, the expansion proceeds.
///
/// Each branch is a copy of the original scan with its label predicate narrowed to one label.
/// The unmodified original scan is kept as `scan_anchor`.
fn maybe_expand_scan(
    tree: &mut JoinTree,
    bindings: &[BindingDef],
    catalog: &dyn IndexCatalog,
) -> bool {
    let JoinTree::Scan(scan) = &*tree else {
        return false;
    };
    if scan.kind != ScanKind::Node || !matches!(scan.access, ScanAccess::Linear) {
        return false;
    }
    let Some(labels) = flat_disjunction_singles(&scan.label_predicate) else {
        return false;
    };
    if !any_branch_has_applicable_index(&labels, &scan.property_predicates, bindings, catalog) {
        return false;
    }
    if let (Some(expand_cost), Some(baseline)) = (
        cost::disjunctive_cost(catalog, IndexTarget::Node, &labels),
        catalog.total_rows(IndexTarget::Node),
    ) && cost::should_decline_index(expand_cost, baseline)
    {
        return false;
    }

    // Move the scan out of the tree by value so it can become the anchor directly, instead of
    // cloning it once more and discarding the original. `Unit` is only a placeholder until the
    // assignment below overwrites it.
    let JoinTree::Scan(scan_anchor) = std::mem::replace(tree, JoinTree::Unit) else {
        unreachable!("tree was matched as JoinTree::Scan above");
    };
    let branches: Vec<NodeOrEdgeScan> = labels
        .iter()
        .map(|label| {
            let mut branch = scan_anchor.clone();
            branch.label_predicate = Some(LabelExpr::Single(label.clone()));
            branch
        })
        .collect();
    *tree = JoinTree::DisjunctiveScan {
        branches,
        scan_anchor,
    };
    true
}
/// Returns `true` if at least one label in `labels` has a node index that could serve the
/// scan's property predicates. A label qualifies when either of these holds:
/// - it has a typed (single-key) index on a key that
///   `binding_refs::match_property_predicate` extracts from one of `predicates`;
/// - it has a composite index over the equality-predicate keys reported by
///   `equality_candidates`. This is checked only when there are at least two such keys.
///
/// All lookups target `IndexTarget::Node`. Returns `false` when `labels` is empty.
fn any_branch_has_applicable_index(
    labels: &[DbString],
    predicates: &[FilterPredicate],
    bindings: &[BindingDef],
    catalog: &dyn IndexCatalog,
) -> bool {
    // Matching a predicate to a property key does not depend on the label, so do it once
    // instead of once per (label, predicate) pair.
    let typed_keys: Vec<_> = predicates
        .iter()
        .filter_map(|pred| binding_refs::match_property_predicate(pred, bindings))
        .map(|matched| matched.key)
        .collect();

    // Keys of equality predicates, used for composite-index lookups.
    let eq_candidates = equality_candidates(predicates, bindings);
    let eq_keys: Vec<DbString> = eq_candidates
        .iter()
        .map(|candidate| candidate.key.clone())
        .collect();

    labels.iter().any(|label| {
        let has_typed_index = typed_keys.iter().any(|key| {
            catalog
                .typed_index(IndexTarget::Node, label.clone(), key.clone())
                .is_some()
        });
        has_typed_index
            || (eq_keys.len() >= 2
                && catalog
                    .composite_index(IndexTarget::Node, label.clone(), &eq_keys)
                    .is_some())
    })
}