use selene_core::{DbString, Value};
use crate::Literal;
use crate::plan::{IndexKey, IndexTarget, TypedIndexBounds, optimize::IndexCatalog};
/// Divisor used by `parameter_range_cost`: when a range bound is not a
/// literal, the estimate is the population divided by this value.
const RANGE_DIVISOR: u64 = 3;
/// Baseline row estimate for `label` on `target`.
///
/// Uses the catalog's per-label cardinality when one is reported and falls
/// back to the target's total row count otherwise. Returns `None` when the
/// catalog can supply neither figure.
pub fn linear_baseline(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    label: DbString,
) -> Option<u64> {
    let per_label = catalog.label_cardinality(target, label);
    // The total is only queried when no per-label figure exists.
    per_label.or_else(|| catalog.total_rows(target))
}
/// Row estimate for scanning `label` on `target`.
///
/// Forwards directly to the catalog's `label_cardinality`; unlike
/// `linear_baseline` there is no fallback to the total row count, so the
/// result is `None` whenever the catalog has no figure for the label.
pub fn label_scan_cost(
catalog: &dyn IndexCatalog,
target: IndexTarget,
label: DbString,
) -> Option<u64> {
catalog.label_cardinality(target, label)
}
/// Row estimate for probing a typed index on `label.property` with `bounds`.
///
/// * Equality bounds are handled by `equality_cost`.
/// * One-sided comparisons are handled by `single_ended_range_cost`.
/// * Two-sided ranges ask the catalog for a range cardinality when both ends
///   are literals; if either end is not a usable literal the estimate comes
///   from `parameter_range_cost` instead.
///
/// Returns `None` when the catalog cannot provide the required statistic.
///
/// NOTE(review): two-sided bounds are forwarded without checking that the
/// lower end does not exceed the upper end — confirm the catalog tolerates
/// inverted ranges.
pub fn typed_index_cost(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    label: DbString,
    property: DbString,
    bounds: &TypedIndexBounds,
) -> Option<u64> {
    match bounds {
        TypedIndexBounds::Equality(key) => equality_cost(catalog, target, label, property, key),
        TypedIndexBounds::Range {
            lo,
            lo_inclusive,
            hi,
            hi_inclusive,
        } => match (literal_value(lo), literal_value(hi)) {
            (Some(lo_val), Some(hi_val)) => {
                let lower = bound(lo_val, *lo_inclusive, BoundEnd::Lower);
                let upper = bound(hi_val, *hi_inclusive, BoundEnd::Upper);
                catalog.range_cardinality(target, label, property, (lower, upper))
            }
            // At least one end is not a usable literal.
            _ => parameter_range_cost(catalog, target, label),
        },
        TypedIndexBounds::GreaterThan(key)
        | TypedIndexBounds::GreaterEqual(key)
        | TypedIndexBounds::LessThan(key)
        | TypedIndexBounds::LessEqual(key) => {
            single_ended_range_cost(catalog, target, label, property, bounds, key)
        }
    }
}
/// Row estimate for an equality probe on `label.property`.
///
/// A literal key is looked up through the catalog's `equality_cardinality`;
/// any key that yields no literal value falls back to the catalog's
/// `typed_avg_bucket` figure for the property.
fn equality_cost(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    label: DbString,
    property: DbString,
    key: &IndexKey,
) -> Option<u64> {
    let Some(value) = literal_value(key) else {
        return catalog.typed_avg_bucket(target, label, property);
    };
    catalog.equality_cardinality(target, label, property, &value)
}
/// Row estimate for a one-sided comparison (`>`, `>=`, `<`, `<=`).
///
/// When `key` yields no literal value the estimate comes from
/// `parameter_range_cost`, regardless of the shape of `bounds`. Otherwise the
/// literal becomes the bounded end of a half-open range handed to the
/// catalog's `range_cardinality`. `Equality` and `Range` bounds are not
/// one-sided and yield `None`.
fn single_ended_range_cost(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    label: DbString,
    property: DbString,
    bounds: &TypedIndexBounds,
    key: &IndexKey,
) -> Option<u64> {
    use std::ops::Bound::{Excluded, Included, Unbounded};

    let value = match literal_value(key) {
        Some(value) => value,
        None => return parameter_range_cost(catalog, target, label),
    };
    let (lower, upper) = match bounds {
        TypedIndexBounds::GreaterThan(_) => (Excluded(value), Unbounded),
        TypedIndexBounds::GreaterEqual(_) => (Included(value), Unbounded),
        TypedIndexBounds::LessThan(_) => (Unbounded, Excluded(value)),
        TypedIndexBounds::LessEqual(_) => (Unbounded, Included(value)),
        TypedIndexBounds::Equality(_) | TypedIndexBounds::Range { .. } => return None,
    };
    catalog.range_cardinality(target, label, property, (lower, upper))
}
/// Row estimate for a range whose bounds are not known literals.
///
/// The population is the label cardinality, or the target's total row count
/// when the label has no figure; `None` if neither is available. The estimate
/// is `population / RANGE_DIVISOR`, raised to at least 1 and then capped at
/// the population itself, so an empty population yields 0.
fn parameter_range_cost(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    label: DbString,
) -> Option<u64> {
    let population = match catalog.label_cardinality(target, label) {
        Some(rows) => rows,
        None => catalog.total_rows(target)?,
    };
    let at_least_one = std::cmp::max(population / RANGE_DIVISOR, 1);
    // Capping afterwards keeps the estimate at 0 for an empty population.
    Some(std::cmp::min(at_least_one, population))
}
/// Row estimate for an `IN`-list probe on `label.property`.
///
/// Each key contributes its own estimate: `equality_cardinality` for a
/// literal value, `typed_avg_bucket` otherwise. Contributions are added with
/// saturating arithmetic. The result is `None` as soon as any key has no
/// estimate; an empty `keys` slice yields `Some(0)`.
pub fn in_list_cost(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    label: DbString,
    property: DbString,
    keys: &[IndexKey],
) -> Option<u64> {
    keys.iter().try_fold(0u64, |running, key| {
        let estimate = match literal_value(key) {
            Some(value) => {
                catalog.equality_cardinality(target, label.clone(), property.clone(), &value)
            }
            None => catalog.typed_avg_bucket(target, label.clone(), property.clone()),
        };
        // A missing estimate for any key aborts the whole fold.
        Some(running.saturating_add(estimate?))
    })
}
/// Row estimate for a composite-index probe over `properties`.
///
/// When every key yields a literal value, the catalog's
/// `composite_cardinality` is consulted with those values in key order. As
/// soon as one key yields no literal, the remaining keys are not inspected
/// and the catalog's `composite_avg_bucket` figure is returned instead.
pub fn composite_cost(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    label: DbString,
    properties: &[DbString],
    keys: &[IndexKey],
) -> Option<u64> {
    let mut literal_keys: Vec<Value> = Vec::with_capacity(keys.len());
    for key in keys {
        let Some(value) = literal_value(key) else {
            return catalog.composite_avg_bucket(target, label, properties);
        };
        literal_keys.push(value);
    }
    catalog.composite_cardinality(target, label, properties, &literal_keys)
}
/// Row estimate for a disjunction over several labels.
///
/// The per-label cardinalities are simply added together using saturating
/// arithmetic. Returns `None` as soon as one label has no cardinality; an
/// empty `labels` slice yields `Some(0)`.
pub fn disjunctive_cost(
    catalog: &dyn IndexCatalog,
    target: IndexTarget,
    labels: &[DbString],
) -> Option<u64> {
    labels.iter().try_fold(0u64, |running, label| {
        catalog
            .label_cardinality(target, label.clone())
            .map(|rows| running.saturating_add(rows))
    })
}
/// Decides whether an index plan should be rejected in favour of the baseline.
///
/// The index is declined when it is not strictly cheaper than a non-zero
/// `baseline`. A zero baseline never declines the index, whatever its cost.
pub fn should_decline_index(index_cost: u64, baseline: u64) -> bool {
    if baseline == 0 {
        return false;
    }
    index_cost >= baseline
}
/// Tags which end of a two-sided range a bound belongs to.
///
/// Passed to `bound`, which currently ignores it.
enum BoundEnd {
// The lower end of the range.
Lower,
// The upper end of the range.
Upper,
}
/// Wraps `value` as an inclusive or exclusive range bound.
///
/// `_end` identifies which side of the range the bound is for but does not
/// affect the result.
fn bound(value: Value, inclusive: bool, _end: BoundEnd) -> std::ops::Bound<Value> {
    use std::ops::Bound::{Excluded, Included};

    match inclusive {
        true => Included(value),
        false => Excluded(value),
    }
}
/// Extracts a concrete `Value` from an index key, if it has one.
///
/// Parameters and parameter lists are not known at this point and yield
/// `None`; a literal key yields whatever `literal_to_value` produces for it.
fn literal_value(key: &IndexKey) -> Option<Value> {
    match key {
        IndexKey::Parameter { .. } => None,
        IndexKey::ParameterList { .. } => None,
        IndexKey::Literal(literal) => literal_to_value(literal),
    }
}
/// Converts a parsed literal into the corresponding `Value`.
///
/// Every literal kind maps onto the `Value` variant of the same name, with
/// plain and radix integers both becoming `Value::Int`. A null literal has no
/// value and yields `None`.
fn literal_to_value(literal: &Literal) -> Option<Value> {
    let converted = match literal {
        Literal::Null(_) => return None,
        Literal::Bool(flag, _) => Value::Bool(*flag),
        Literal::Integer(int, _) | Literal::RadixInteger(int, _, _) => Value::Int(*int),
        Literal::Decimal(decimal, _, _) => Value::Decimal(*decimal),
        Literal::Float(float, _, _) => Value::Float(*float),
        Literal::String(text, _, _) => Value::String(text.clone()),
        Literal::Bytes(bytes, _) => Value::Bytes(bytes.clone()),
        Literal::Uuid(uuid, _, _) => Value::Uuid(*uuid),
        Literal::ZonedDateTime(stamp, _, _) => Value::ZonedDateTime(stamp.clone()),
        Literal::LocalDateTime(stamp, _, _) => Value::LocalDateTime(*stamp),
        Literal::Date(date, _, _) => Value::Date(*date),
        Literal::ZonedTime(time, _, _) => Value::ZonedTime(time.clone()),
        Literal::LocalTime(time, _, _) => Value::LocalTime(*time),
        Literal::Duration(span, _, _) => Value::Duration(span.clone()),
    };
    Some(converted)
}
// Unit tests for the pure decision helper; the catalog-backed estimators are
// not exercised here.
#[cfg(test)]
mod tests {
use super::*;
// Pins the decision table of `should_decline_index`.
#[test]
fn should_decline_index_semantics() {
// A strictly cheaper index is kept.
assert!(!should_decline_index(5, 10));
// Ties and more expensive indexes are declined.
assert!(should_decline_index(10, 10), "equal cost is not better");
assert!(should_decline_index(11, 10), "more expensive");
// A zero baseline never declines the index.
assert!(!should_decline_index(0, 0), "empty graph keeps the index");
}
}