use crate::{
BinaryOp, EdgeMatch, IndexTarget, Literal,
plan::{
BindingDef, BindingElement, ExecutionPlan, FilterPredicate, IndexKey, JoinTree, ScanAccess,
ScanKind, TypedIndexBounds,
optimize::{OptimizeContext, Rule, Transformed, binding_refs, cost, walk},
},
};
use super::index_helpers::{compatible_value, single_label};
/// Optimizer rule that replaces linear scans and edge matches with a typed
/// index access (`ScanAccess::TypedIndexRange`) when the index catalog exposes
/// a typed index for the scanned label and one of the filtered properties.
pub struct RangeIndexScan;
impl Rule for RangeIndexScan {
    /// Identifier of this rule.
    fn name(&self) -> &'static str {
        "range_index_scan"
    }

    /// Applies the rule to the top-level pattern plan of `plan` and then to
    /// its nested sub-plans.
    ///
    /// Returns the plan unchanged when the context carries no index catalog.
    /// The `changed` flag of the result is set if either the pattern plan or
    /// any nested sub-plan was rewritten.
    fn rewrite(
        &self,
        mut plan: ExecutionPlan,
        ctx: &OptimizeContext<'_>,
    ) -> Transformed<ExecutionPlan> {
        // No catalog means no index can be looked up, so there is nothing to do.
        let Some(catalog) = ctx.index_catalog else {
            return Transformed::unchanged(plan);
        };

        let pattern_changed = match &mut plan.pattern_plan {
            Some(pattern) => rewrite_tree(&mut pattern.join_tree, &pattern.bindings, catalog),
            None => false,
        };

        // Sub-plans are handled by re-running this same rule through the walker.
        let Transformed {
            plan,
            changed: nested_changed,
        } = walk::recurse_rule_subplans(plan, self, ctx);

        Transformed {
            plan,
            changed: pattern_changed | nested_changed,
        }
    }
}
/// Walks a join tree and tries to switch every scan and edge match it reaches
/// to a typed index access.
///
/// Returns `true` if at least one scan or edge was rewritten. Both sides of a
/// two-child node are always visited, even when the first one already
/// reported a change. `WorstCaseOptimal` and `Subplan` nodes are not descended
/// into by this function.
fn rewrite_tree(
    tree: &mut JoinTree,
    bindings: &[BindingDef],
    catalog: &dyn crate::IndexCatalog,
) -> bool {
    match tree {
        // Nothing to rewrite at these nodes.
        JoinTree::Unit | JoinTree::WorstCaseOptimal { .. } | JoinTree::Subplan(_) => false,

        JoinTree::Scan(scan) => rewrite_scan(scan, bindings, catalog),

        JoinTree::DisjunctiveScan { branches, .. } => {
            let mut any_rewritten = false;
            for branch in branches.iter_mut() {
                // `|=` never short-circuits, so every branch gets its chance.
                any_rewritten |= rewrite_scan(branch, bindings, catalog);
            }
            any_rewritten
        }

        JoinTree::Expand { child, edge, .. } | JoinTree::Questioned { child, edge, .. } => {
            let child_rewritten = rewrite_tree(child, bindings, catalog);
            let edge_rewritten = rewrite_edge(edge, bindings, catalog);
            child_rewritten | edge_rewritten
        }

        JoinTree::HashJoin { left, right, .. } | JoinTree::Outer { left, right, .. } => {
            let left_rewritten = rewrite_tree(left, bindings, catalog);
            let right_rewritten = rewrite_tree(right, bindings, catalog);
            left_rewritten | right_rewritten
        }

        JoinTree::Repeat { child, .. } => rewrite_tree(child, bindings, catalog),

        JoinTree::PathSearch { child, .. }
        | JoinTree::PathModeFilter { child, .. }
        | JoinTree::MatchModeFilter { child, .. } => rewrite_tree(child, bindings, catalog),
    }
}
fn rewrite_scan(
scan: &mut crate::NodeOrEdgeScan,
bindings: &[BindingDef],
catalog: &dyn crate::IndexCatalog,
) -> bool {
if !matches!(scan.access, ScanAccess::Linear) {
return false;
}
let Some(label) = single_label(&scan.label_predicate) else {
return false;
};
let target = target_for_scan_kind(scan.kind);
let Some(candidate) = best_candidate(
&scan.property_predicates,
bindings,
catalog,
target,
label.clone(),
) else {
return false;
};
if let (Some(index_cost), Some(baseline)) = (
cost::typed_index_cost(
catalog,
target,
label.clone(),
candidate.property.clone(),
&candidate.bounds,
),
cost::linear_baseline(catalog, target, label),
) && cost::should_decline_index(index_cost, baseline)
{
return false;
}
remove_indices(&mut scan.property_predicates, &candidate.consumed_indices);
scan.access = ScanAccess::TypedIndexRange {
handle: candidate.handle,
property: candidate.property,
kind: candidate.kind,
bounds: candidate.bounds,
};
true
}
fn rewrite_edge(
edge: &mut EdgeMatch,
bindings: &[BindingDef],
catalog: &dyn crate::IndexCatalog,
) -> bool {
if !matches!(edge.access, ScanAccess::Linear) {
return false;
}
let Some(label) = single_label(&edge.label_predicate) else {
return false;
};
let Some(candidate) = best_candidate(
&edge.property_predicates,
bindings,
catalog,
IndexTarget::Edge,
label.clone(),
) else {
return false;
};
if let (Some(index_cost), Some(baseline)) = (
cost::typed_index_cost(
catalog,
IndexTarget::Edge,
label.clone(),
candidate.property.clone(),
&candidate.bounds,
),
cost::linear_baseline(catalog, IndexTarget::Edge, label),
) && cost::should_decline_index(index_cost, baseline)
{
return false;
}
remove_indices(&mut edge.property_predicates, &candidate.consumed_indices);
edge.access = ScanAccess::TypedIndexRange {
handle: candidate.handle,
property: candidate.property,
kind: candidate.kind,
bounds: candidate.bounds,
};
true
}
/// A usable index match for one scan or edge: which index to probe, the
/// bounds to probe it with, and which predicates those bounds replace.
struct Candidate {
/// Index handle taken from the catalog's `typed_index` lookup.
handle: crate::IndexHandle,
/// Property key that the index and the consumed predicates refer to.
property: selene_core::DbString,
/// Index kind taken from the catalog lookup; it is also the kind passed to
/// `compatible_value` when the bound values are converted.
kind: crate::IndexKind,
/// Equality or range bounds derived from the predicates.
bounds: TypedIndexBounds,
/// Positions in the predicate list that the bounds make redundant; sorted
/// and de-duplicated before the candidate is returned.
consumed_indices: Vec<usize>,
}
/// Finds the first predicate, in list order, that can drive a typed index.
///
/// A predicate qualifies when it matches the property-predicate shape, its
/// binding has the element kind that corresponds to `target`, the catalog has
/// a typed index for (`target`, `label`, property key), and usable bounds can
/// be derived for that property. Predicates that fail any of these checks are
/// skipped. Returns `None` when no predicate qualifies.
///
/// Note that no ranking happens here: the first qualifying predicate wins.
fn best_candidate(
    predicates: &[FilterPredicate],
    bindings: &[BindingDef],
    catalog: &dyn crate::IndexCatalog,
    target: IndexTarget,
    label: selene_core::DbString,
) -> Option<Candidate> {
    predicates
        .iter()
        .enumerate()
        .find_map(|(position, predicate)| {
            let matched = binding_refs::match_property_predicate(predicate, bindings)?;
            if !binding_is_target(bindings, matched.binding, target) {
                return None;
            }
            let lookup = catalog.typed_index(target, label.clone(), matched.key.clone())?;
            let (bounds, mut consumed_indices) = bounds_for_property(
                matched.key.clone(),
                predicates,
                bindings,
                position,
                lookup.kind,
            )?;
            // `remove_indices` binary-searches this list, so it must be sorted.
            consumed_indices.sort_unstable();
            consumed_indices.dedup();
            Some(Candidate {
                handle: lookup.handle,
                property: matched.key,
                kind: lookup.kind,
                bounds,
                consumed_indices,
            })
        })
}
fn bounds_for_property(
key: selene_core::DbString,
predicates: &[FilterPredicate],
bindings: &[BindingDef],
first_index: usize,
kind: crate::IndexKind,
) -> Option<(TypedIndexBounds, Vec<usize>)> {
let mut equality: Option<IndexKey> = None;
let mut lower: Option<(IndexKey, bool)> = None;
let mut upper: Option<(IndexKey, bool)> = None;
let mut consumed = Vec::new();
for (index, pred) in predicates.iter().enumerate().skip(first_index) {
let Some(matched) = binding_refs::match_property_predicate(pred, bindings) else {
continue;
};
if matched.key != key {
continue;
}
match matched.shape {
binding_refs::PropertyPredicateShape::Equality(value) => {
let key = compatible_value(value, kind)?;
equality = Some(key);
consumed = vec![index];
lower = None;
upper = None;
break;
}
binding_refs::PropertyPredicateShape::Comparison { op, value } => {
let key = compatible_value(value, kind)?;
let candidate = (key, matches!(op, BinaryOp::Ge | BinaryOp::Le));
let outcome = match op {
BinaryOp::Gt | BinaryOp::Ge => tighten_lower(lower.take(), candidate),
BinaryOp::Lt | BinaryOp::Le => tighten_upper(upper.take(), candidate),
_ => continue,
};
match outcome {
TightenOutcome::Bound(bound) => {
match op {
BinaryOp::Gt | BinaryOp::Ge => lower = Some(bound),
BinaryOp::Lt | BinaryOp::Le => upper = Some(bound),
_ => unreachable!("guarded above"),
}
consumed.push(index);
}
TightenOutcome::KeepExisting(existing) => {
match op {
BinaryOp::Gt | BinaryOp::Ge => lower = Some(existing),
BinaryOp::Lt | BinaryOp::Le => upper = Some(existing),
_ => unreachable!("guarded above"),
}
}
TightenOutcome::Reject => return None,
}
}
binding_refs::PropertyPredicateShape::InList(_)
| binding_refs::PropertyPredicateShape::InListExpression(_) => {}
}
}
if let Some(key) = equality {
return Some((TypedIndexBounds::Equality(key), consumed));
}
match (lower, upper) {
(Some((lo, lo_inclusive)), Some((hi, hi_inclusive))) => {
if let (IndexKey::Literal(lo_lit), IndexKey::Literal(hi_lit)) = (&lo, &hi)
&& !range_satisfiable(lo_lit, lo_inclusive, hi_lit, hi_inclusive)
{
return None;
}
Some((
TypedIndexBounds::Range {
lo,
lo_inclusive,
hi,
hi_inclusive,
},
consumed,
))
}
(Some((key, false)), None) => Some((TypedIndexBounds::GreaterThan(key), consumed)),
(Some((key, true)), None) => Some((TypedIndexBounds::GreaterEqual(key), consumed)),
(None, Some((key, false))) => Some((TypedIndexBounds::LessThan(key), consumed)),
(None, Some((key, true))) => Some((TypedIndexBounds::LessEqual(key), consumed)),
(None, None) => None,
}
}
/// Result of merging a new bound into the lower or upper bound collected so
/// far. Each bound is a `(value, inclusive)` pair.
enum TightenOutcome {
/// The bound to record after the merge. The caller also marks the new
/// predicate as consumed.
Bound((IndexKey, bool)),
/// The two bounds are not both literals, so they were not compared. The
/// existing bound is handed back and the new predicate is not consumed.
KeepExisting((IndexKey, bool)),
/// Both bounds are literals but `compare_literals` could not order them.
/// The caller gives up on the property.
Reject,
}
/// Merges a new lower bound into the one collected so far.
///
/// With no existing bound the candidate is adopted. When both bounds are
/// literals the larger one wins; on a tie the value is kept and the result is
/// inclusive only if both inputs were. Literals that cannot be ordered yield
/// `Reject`, and any non-literal pairing hands the existing bound back via
/// `KeepExisting`.
fn tighten_lower(
    existing: Option<(IndexKey, bool)>,
    candidate: (IndexKey, bool),
) -> TightenOutcome {
    use std::cmp::Ordering;

    let Some(current) = existing else {
        return TightenOutcome::Bound(candidate);
    };
    let (IndexKey::Literal(current_lit), IndexKey::Literal(candidate_lit)) =
        (&current.0, &candidate.0)
    else {
        return TightenOutcome::KeepExisting(current);
    };

    let tightest = match compare_literals(current_lit, candidate_lit) {
        None => return TightenOutcome::Reject,
        // For a lower bound, the greater value is the tighter one.
        Some(Ordering::Less) => candidate,
        Some(Ordering::Greater) => current,
        Some(Ordering::Equal) => {
            let inclusive = current.1 && candidate.1;
            (current.0, inclusive)
        }
    };
    TightenOutcome::Bound(tightest)
}
/// Merges a new upper bound into the one collected so far.
///
/// With no existing bound the candidate is adopted. When both bounds are
/// literals the smaller one wins; on a tie the value is kept and the result
/// is inclusive only if both inputs were. Literals that cannot be ordered
/// yield `Reject`, and any non-literal pairing hands the existing bound back
/// via `KeepExisting`.
fn tighten_upper(
    existing: Option<(IndexKey, bool)>,
    candidate: (IndexKey, bool),
) -> TightenOutcome {
    use std::cmp::Ordering;

    let Some(current) = existing else {
        return TightenOutcome::Bound(candidate);
    };
    let (IndexKey::Literal(current_lit), IndexKey::Literal(candidate_lit)) =
        (&current.0, &candidate.0)
    else {
        return TightenOutcome::KeepExisting(current);
    };

    let tightest = match compare_literals(current_lit, candidate_lit) {
        None => return TightenOutcome::Reject,
        // For an upper bound, the smaller value is the tighter one.
        Some(Ordering::Less) => current,
        Some(Ordering::Greater) => candidate,
        Some(Ordering::Equal) => {
            let inclusive = current.1 && candidate.1;
            (current.0, inclusive)
        }
    };
    TightenOutcome::Bound(tightest)
}
/// Orders two literals of the same family.
///
/// Plain and radix integers compare with each other; every other supported
/// variant only compares with the same variant. Returns `None` for any other
/// pairing, and for floats whose `partial_cmp` yields no ordering.
fn compare_literals(a: &Literal, b: &Literal) -> Option<std::cmp::Ordering> {
    let ordering = match (a, b) {
        (Literal::Integer(lhs, _), Literal::Integer(rhs, _))
        | (Literal::Integer(lhs, _), Literal::RadixInteger(rhs, _, _))
        | (Literal::RadixInteger(lhs, _, _), Literal::Integer(rhs, _))
        | (Literal::RadixInteger(lhs, _, _), Literal::RadixInteger(rhs, _, _)) => lhs.cmp(rhs),
        (Literal::Float(lhs, _, _), Literal::Float(rhs, _, _)) => lhs.partial_cmp(rhs)?,
        (Literal::Decimal(lhs, _, _), Literal::Decimal(rhs, _, _)) => lhs.cmp(rhs),
        (Literal::String(lhs, _, _), Literal::String(rhs, _, _)) => {
            lhs.as_str().cmp(rhs.as_str())
        }
        (Literal::Bytes(lhs, _), Literal::Bytes(rhs, _)) => lhs.as_ref().cmp(rhs.as_ref()),
        (Literal::Bool(lhs, _), Literal::Bool(rhs, _)) => lhs.cmp(rhs),
        (Literal::Date(lhs, _, _), Literal::Date(rhs, _, _)) => lhs.cmp(rhs),
        (Literal::LocalDateTime(lhs, _, _), Literal::LocalDateTime(rhs, _, _)) => lhs.cmp(rhs),
        (Literal::ZonedDateTime(lhs, _, _), Literal::ZonedDateTime(rhs, _, _)) => lhs.cmp(rhs),
        (Literal::LocalTime(lhs, _, _), Literal::LocalTime(rhs, _, _)) => lhs.cmp(rhs),
        (Literal::ZonedTime(lhs, _, _), Literal::ZonedTime(rhs, _, _)) => lhs.cmp(rhs),
        (Literal::Duration(lhs, _, _), Literal::Duration(rhs, _, _)) => {
            // Durations are ordered through their order key, not field by field.
            let lhs_key = selene_core::duration_order_key(lhs);
            let rhs_key = selene_core::duration_order_key(rhs);
            lhs_key.cmp(&rhs_key)
        }
        // Mixed or unsupported variants have no defined order here.
        _ => return None,
    };
    Some(ordering)
}
/// Reports whether the literal range from `lo` to `hi` can contain a value.
///
/// The range is satisfiable when `lo` orders before `hi`, or when the two are
/// equal and both ends are inclusive. Literals that cannot be ordered are
/// treated as unsatisfiable.
fn range_satisfiable(lo: &Literal, lo_inclusive: bool, hi: &Literal, hi_inclusive: bool) -> bool {
    use std::cmp::Ordering;

    match compare_literals(lo, hi) {
        Some(Ordering::Less) => true,
        Some(Ordering::Equal) => lo_inclusive && hi_inclusive,
        Some(Ordering::Greater) | None => false,
    }
}
/// Maps the kind of a scan to the index target of the same name.
fn target_for_scan_kind(kind: ScanKind) -> IndexTarget {
    let target = match kind {
        ScanKind::Edge => IndexTarget::Edge,
        ScanKind::Node => IndexTarget::Node,
    };
    target
}
/// Reports whether `binding_id` is declared in `bindings` with the element
/// kind that corresponds to `target` (node index → node binding, edge index →
/// edge binding).
///
/// Returns `false` when the binding is missing or has the other element kind.
fn binding_is_target(
    bindings: &[BindingDef],
    binding_id: crate::BindingId,
    target: IndexTarget,
) -> bool {
    let wanted = match target {
        IndexTarget::Edge => BindingElement::Edge,
        IndexTarget::Node => BindingElement::Node,
    };
    for def in bindings {
        if def.binding == binding_id && def.element == wanted {
            return true;
        }
    }
    false
}
/// Removes from `predicates` every element whose position is listed in
/// `indices`, keeping the remaining elements in their original order.
///
/// `indices` is searched with a binary search, so it must be sorted in
/// ascending order. Positions past the end of `predicates` are ignored.
fn remove_indices(predicates: &mut Vec<FilterPredicate>, indices: &[usize]) {
    // `retain` visits elements in order, so a running counter gives each
    // element's original position.
    let mut position = 0usize;
    predicates.retain(|_| {
        let keep = indices.binary_search(&position).is_err();
        position += 1;
        keep
    });
}