use crate::ast::{JoinQuery, JoinType, QueryExpr};
use crate::sql_lowering::{effective_table_filter, effective_vector_filter};
/// A single rewrite step that a `QueryOptimizer` can run over a query.
///
/// Every implementor must also be `Send + Sync` (supertrait bounds).
pub trait OptimizationPass: Send + Sync {
    /// Name of the pass.
    ///
    /// `QueryOptimizer::optimize` reports this name for passes that changed
    /// the query, and `OptimizationHints::disabled_passes` is matched against it.
    fn name(&self) -> &str;

    /// Consumes `query` and returns the (possibly rewritten) expression.
    fn apply(&self, query: QueryExpr) -> QueryExpr;

    /// Priority weight of the pass; `QueryOptimizer::add_pass` orders the
    /// registered passes by descending benefit.
    fn benefit(&self) -> u32;
}
/// Runs an ordered list of [`OptimizationPass`]es over a query expression.
pub struct QueryOptimizer {
    /// Registered passes, stored in the order in which they are applied.
    passes: Vec<Box<dyn OptimizationPass>>,
    /// Initialised to `true` by `QueryOptimizer::new`.
    // NOTE(review): no code visible in this file reads this flag; presumably
    // reserved for a cost-based mode. Confirm or remove.
    cost_based: bool,
}
impl QueryOptimizer {
    /// Builds an optimizer preloaded with the built-in passes.
    ///
    /// The passes are ordered by descending [`OptimizationPass::benefit`],
    /// which is the same ordering [`QueryOptimizer::add_pass`] maintains.
    /// Sorting here keeps the pipeline order from shifting the first time a
    /// custom pass is registered.
    pub fn new() -> Self {
        let mut passes: Vec<Box<dyn OptimizationPass>> = vec![
            Box::new(PredicatePushdownPass),
            Box::new(ProjectionPushdownPass),
            Box::new(JoinReorderingPass),
            Box::new(IndexSelectionPass),
            Box::new(LimitPushdownPass),
        ];
        // `sort_by_key` is stable: passes with equal benefit keep the order above.
        passes.sort_by_key(|pass| std::cmp::Reverse(pass.benefit()));
        Self {
            passes,
            cost_based: true,
        }
    }

    /// Registers `pass` and re-sorts all passes by descending benefit.
    ///
    /// The sort is stable, so among passes with the same benefit the ones
    /// registered earlier run first.
    pub fn add_pass(&mut self, pass: Box<dyn OptimizationPass>) {
        self.passes.push(pass);
        self.passes
            .sort_by_key(|registered| std::cmp::Reverse(registered.benefit()));
    }

    /// Runs every pass in order and returns the rewritten query together with
    /// the names of the passes that changed it.
    ///
    /// A pass counts as "applied" when the `Debug` rendering of the query
    /// differs before and after the pass ran.
    pub fn optimize(&self, query: QueryExpr) -> (QueryExpr, Vec<String>) {
        let mut optimized = query;
        let mut applied_passes = Vec::new();
        // The rendering produced after one pass doubles as the baseline for the
        // next one, so the query is formatted once per pass instead of twice.
        let mut rendered = format!("{:?}", optimized);
        for pass in &self.passes {
            optimized = pass.apply(optimized);
            let after = format!("{:?}", optimized);
            if after != rendered {
                applied_passes.push(pass.name().to_string());
                rendered = after;
            }
        }
        (optimized, applied_passes)
    }

    /// Runs every pass whose name is not listed in `hints.disabled_passes`.
    ///
    /// Only `disabled_passes` is consulted here; the other hint fields are not
    /// read by this method.
    pub fn optimize_with_hints(&self, query: QueryExpr, hints: &OptimizationHints) -> QueryExpr {
        let mut optimized = query;
        for pass in &self.passes {
            // Compare as `&str` to avoid allocating a `String` per pass.
            let disabled = hints
                .disabled_passes
                .iter()
                .any(|name| name.as_str() == pass.name());
            if !disabled {
                optimized = pass.apply(optimized);
            }
        }
        optimized
    }
}
impl Default for QueryOptimizer {
fn default() -> Self {
Self::new()
}
}
/// Caller-supplied hints accepted by `QueryOptimizer::optimize_with_hints`.
///
/// The derived `Default` leaves every field empty: no disabled passes, no
/// join order, no forced index, and `no_parallel == false`.
#[derive(Debug, Clone, Default)]
pub struct OptimizationHints {
    /// Names of passes to skip, compared against `OptimizationPass::name`.
    pub disabled_passes: Vec<String>,
    // NOTE(review): the three fields below are not read by any code visible in
    // this file; their meaning is inferred from their names only. Confirm
    // against the consumers.
    /// Presumably a caller-requested join order.
    pub join_order: Option<Vec<String>>,
    /// Presumably the name of an index the caller wants used.
    pub force_index: Option<String>,
    /// Presumably disables parallel execution when `true`.
    pub no_parallel: bool,
}
/// Pass registered under the name `"PredicatePushdown"` with benefit 100.
///
/// As written, the pass only walks join trees and rebuilds each join around
/// its recursively processed inputs; the visible code moves no predicate.
struct PredicatePushdownPass;

impl OptimizationPass for PredicatePushdownPass {
    fn name(&self) -> &str {
        "PredicatePushdown"
    }

    /// Descends into joins; every other expression is returned unchanged.
    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(join) => self.optimize_join(join),
            leaf => leaf,
        }
    }

    fn benefit(&self) -> u32 {
        100
    }
}

impl PredicatePushdownPass {
    /// Applies the pass to both inputs of `query` (left first, then right) and
    /// reassembles the join with all remaining fields carried over as-is.
    fn optimize_join(&self, query: JoinQuery) -> QueryExpr {
        let rewritten_left = Box::new(self.apply(*query.left));
        let rewritten_right = Box::new(self.apply(*query.right));
        QueryExpr::Join(JoinQuery {
            left: rewritten_left,
            right: rewritten_right,
            ..query
        })
    }
}
/// Pass registered under the name `"ProjectionPushdown"` with benefit 80.
///
/// The visible code does not change any projection: joins are rebuilt around
/// their recursively processed inputs and everything else passes through.
struct ProjectionPushdownPass;

impl OptimizationPass for ProjectionPushdownPass {
    fn name(&self) -> &str {
        "ProjectionPushdown"
    }

    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(join) => {
                // Left input is processed before the right one.
                let lhs = Box::new(self.apply(*join.left));
                let rhs = Box::new(self.apply(*join.right));
                QueryExpr::Join(JoinQuery {
                    left: lhs,
                    right: rhs,
                    ..join
                })
            }
            // Table queries and all remaining expressions are returned as-is.
            unchanged => unchanged,
        }
    }

    fn benefit(&self) -> u32 {
        80
    }
}
/// Pass registered under the name `"JoinReordering"` with benefit 90.
///
/// Join expressions are handed to `JoinReorderingPass::optimize_join_order`;
/// every other expression is returned untouched.
struct JoinReorderingPass;

impl OptimizationPass for JoinReorderingPass {
    fn name(&self) -> &str {
        "JoinReordering"
    }

    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(join) => self.optimize_join_order(join),
            non_join => non_join,
        }
    }

    fn benefit(&self) -> u32 {
        90
    }
}
impl JoinReorderingPass {
    /// Puts the input with the smaller estimated size on the left of an inner join.
    ///
    /// When the left estimate is strictly larger than the right estimate and
    /// the join is `JoinType::Inner`, the two inputs are exchanged and the `on`
    /// condition is mirrored with `swap_condition`; all other fields are carried
    /// over unchanged. Any other join is returned as-is.
    fn optimize_join_order(&self, query: JoinQuery) -> QueryExpr {
        let left_size = Self::estimate_size(&query.left);
        let right_size = Self::estimate_size(&query.right);
        let should_swap = left_size > right_size && query.join_type == JoinType::Inner;
        if !should_swap {
            return QueryExpr::Join(query);
        }
        QueryExpr::Join(JoinQuery {
            left: query.right,
            right: query.left,
            on: swap_condition(query.on),
            ..query
        })
    }

    /// Returns a rough size estimate for `query`, used only to compare the two
    /// inputs of a join.
    ///
    /// Heuristics:
    /// - table: 1000 rows, reduced to 10% when the table has an effective
    ///   filter, then capped by `limit` when one is set;
    /// - graph: 100; path: 10;
    /// - join: product of both inputs scaled by 0.1;
    /// - vector: `k`, capped at 100 when an effective filter is present;
    /// - hybrid: the smaller of the structured estimate and the vector `k`,
    ///   capped by `limit` when one is set;
    /// - every other expression kind: 1.
    fn estimate_size(query: &QueryExpr) -> f64 {
        // Row count assumed for a table when nothing narrows it.
        const BASE_TABLE_ROWS: f64 = 1000.0;
        // Fraction of rows assumed to survive a table filter.
        const FILTER_SELECTIVITY: f64 = 0.1;

        match query {
            QueryExpr::Table(tq) => {
                let rows = if effective_table_filter(tq).is_some() {
                    BASE_TABLE_ROWS * FILTER_SELECTIVITY
                } else {
                    BASE_TABLE_ROWS
                };
                // LIMIT is an upper bound on the output: it must apply together
                // with a filter, and it can never raise the estimate above what
                // the table is assumed to produce.
                tq.limit.map_or(rows, |limit| rows.min(limit as f64))
            }
            QueryExpr::Graph(_) => 100.0,
            QueryExpr::Join(jq) => {
                Self::estimate_size(&jq.left) * Self::estimate_size(&jq.right) * 0.1
            }
            QueryExpr::Path(_) => 10.0,
            QueryExpr::Vector(vq) => {
                let k = vq.k as f64;
                if effective_vector_filter(vq).is_some() {
                    k.min(100.0)
                } else {
                    k
                }
            }
            QueryExpr::Hybrid(hq) => {
                let structured_size = Self::estimate_size(&hq.structured);
                let vector_size = hq.vector.k as f64;
                let base = structured_size.min(vector_size);
                hq.limit.map_or(base, |limit| base.min(limit as f64))
            }
            // Listed explicitly (no wildcard) so that adding a `QueryExpr`
            // variant forces a decision here.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::CreateCollection(_)
            | QueryExpr::CreateVector(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::CreateMetric(_)
            | QueryExpr::AlterMetric(_)
            | QueryExpr::CreateSlo(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::RankOf(_)
            | QueryExpr::ApproxRankOf(_)
            | QueryExpr::RankRange(_)
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::LintPolicy { .. }
            | QueryExpr::MigratePolicyMode { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => 1.0,
        }
    }
}
/// Pass registered under the name `"IndexSelection"` with benefit 70.
///
/// For a top-level table query with an effective filter, the pass stores the
/// hint computed by `IndexSelectionPass::analyze_filter` in `expand.index_hint`.
struct IndexSelectionPass;

impl OptimizationPass for IndexSelectionPass {
    fn name(&self) -> &str {
        "IndexSelection"
    }

    /// Only `QueryExpr::Table` is inspected; the pass does not descend into
    /// other expressions. When a hint is produced, `expand` is created with
    /// its default value if absent and `index_hint` is overwritten.
    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Table(mut table) => {
                let suggested = effective_table_filter(&table)
                    .as_ref()
                    .and_then(|filter| Self::analyze_filter(filter));
                if let Some(hint) = suggested {
                    table
                        .expand
                        .get_or_insert_with(Default::default)
                        .index_hint = Some(hint);
                }
                QueryExpr::Table(table)
            }
            untouched => untouched,
        }
    }

    fn benefit(&self) -> u32 {
        70
    }
}
impl IndexSelectionPass {
fn analyze_filter(filter: &crate::ast::Filter) -> Option<IndexHint> {
match filter {
crate::ast::Filter::Compare { field, op, .. } if *op == crate::ast::CompareOp::Eq => {
let col = Self::field_name(field);
Some(IndexHint {
method: IndexHintMethod::Hash,
column: col,
})
}
crate::ast::Filter::Compare {
field,
op:
crate::ast::CompareOp::Lt
| crate::ast::CompareOp::Le
| crate::ast::CompareOp::Gt
| crate::ast::CompareOp::Ge,
..
} => {
let col = Self::field_name(field);
Some(IndexHint {
method: IndexHintMethod::BTree,
column: col,
})
}
crate::ast::Filter::Between { field, .. } => {
let col = Self::field_name(field);
Some(IndexHint {
method: IndexHintMethod::BTree,
column: col,
})
}
crate::ast::Filter::In { field, values } if values.len() <= 10 => {
let col = Self::field_name(field);
Some(IndexHint {
method: IndexHintMethod::Bitmap,
column: col,
})
}
crate::ast::Filter::And(left, right) => {
Self::analyze_filter(left).or_else(|| Self::analyze_filter(right))
}
_ => None,
}
}
fn field_name(field: &crate::ast::FieldRef) -> String {
match field {
crate::ast::FieldRef::TableColumn { column, .. } => column.clone(),
crate::ast::FieldRef::NodeProperty { property, .. } => property.clone(),
crate::ast::FieldRef::EdgeProperty { property, .. } => property.clone(),
crate::ast::FieldRef::NodeId { alias } => {
format!("{}.id", alias)
}
}
}
}
pub use reddb_types::index_hint::{IndexHint, IndexHintMethod};
/// Pass registered under the name `"LimitPushdown"` with benefit 60.
///
/// The visible code moves no limit: joins are rebuilt around their recursively
/// processed inputs and every other expression is returned unchanged.
struct LimitPushdownPass;

impl OptimizationPass for LimitPushdownPass {
    fn name(&self) -> &str {
        "LimitPushdown"
    }

    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(join) => {
                // Tuple fields are evaluated left to right, so the left input
                // is still processed first.
                let (left, right) = (self.apply(*join.left), self.apply(*join.right));
                QueryExpr::Join(JoinQuery {
                    left: Box::new(left),
                    right: Box::new(right),
                    ..join
                })
            }
            passthrough => passthrough,
        }
    }

    fn benefit(&self) -> u32 {
        60
    }
}
/// Mirrors a join condition: the former right field becomes the left field
/// and the former left field becomes the right field.
fn swap_condition(condition: crate::ast::JoinCondition) -> crate::ast::JoinCondition {
    let crate::ast::JoinCondition {
        left_field: old_left,
        right_field: old_right,
    } = condition;
    crate::ast::JoinCondition {
        left_field: old_right,
        right_field: old_left,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{CompareOp, FieldRef, Filter, JoinCondition, Projection, TableQuery};
    use reddb_types::Value;

    /// Builds a table query over `name` (aliased to the same name) with no
    /// filter, no limit and a `Projection::All` column list.
    fn make_table(name: &str) -> TableQuery {
        TableQuery {
            table: name.to_string(),
            source: None,
            alias: Some(name.to_string()),
            select_items: Vec::new(),
            columns: vec![Projection::All],
            where_expr: None,
            filter: None,
            group_by_exprs: Vec::new(),
            group_by: Vec::new(),
            having_expr: None,
            having: None,
            order_by: vec![],
            limit: None,
            limit_param: None,
            offset: None,
            offset_param: None,
            expand: None,
            as_of: None,
            sessionize: None,
            distinct: false,
        }
    }

    /// Same as `make_table`, wrapped in `QueryExpr::Table`.
    fn make_table_query(name: &str) -> QueryExpr {
        QueryExpr::Table(make_table(name))
    }

    /// Builds an inner join of `left` and `right` on `large.id = small.id`.
    fn join_large_with_small(left: QueryExpr, right: QueryExpr) -> QueryExpr {
        QueryExpr::Join(JoinQuery {
            left: Box::new(left),
            right: Box::new(right),
            join_type: JoinType::Inner,
            on: JoinCondition {
                left_field: FieldRef::TableColumn {
                    table: "large".to_string(),
                    column: "id".to_string(),
                },
                right_field: FieldRef::TableColumn {
                    table: "small".to_string(),
                    column: "id".to_string(),
                },
            },
            filter: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
            return_items: Vec::new(),
            return_: Vec::new(),
        })
    }

    #[test]
    fn test_optimizer_applies_passes() {
        let optimizer = QueryOptimizer::new();
        let query = make_table_query("hosts");
        let (optimized, passes) = optimizer.optimize(query);
        assert!(matches!(optimized, QueryExpr::Table(_)));
        // These passes return non-join expressions unchanged, so none of them
        // may be reported for a bare table query.
        for join_only in [
            "PredicatePushdown",
            "ProjectionPushdown",
            "JoinReordering",
            "LimitPushdown",
        ] {
            assert!(
                !passes.iter().any(|pass| pass == join_only),
                "{join_only} must not report a change for a table query"
            );
        }
    }

    #[test]
    fn test_join_reordering() {
        let optimizer = QueryOptimizer::new();
        let mut small_table = make_table("small");
        small_table.limit = Some(10);
        let join = join_large_with_small(make_table_query("large"), QueryExpr::Table(small_table));

        let (optimized, passes) = optimizer.optimize(join);
        assert!(passes.iter().any(|pass| pass == "JoinReordering"));

        // Fail loudly if the shape is unexpected instead of skipping the checks.
        let QueryExpr::Join(join) = optimized else {
            panic!("expected join query");
        };
        let QueryExpr::Table(left) = join.left.as_ref() else {
            panic!("expected table on left side");
        };
        let QueryExpr::Table(right) = join.right.as_ref() else {
            panic!("expected table on right side");
        };
        assert_eq!(left.table, "small");
        assert_eq!(right.table, "large");
        assert!(matches!(
            &join.on.left_field,
            FieldRef::TableColumn { table, column } if table == "small" && column == "id"
        ));
        assert!(matches!(
            &join.on.right_field,
            FieldRef::TableColumn { table, column } if table == "large" && column == "id"
        ));
    }

    #[test]
    fn optimize_with_hints_can_disable_join_reordering() {
        let optimizer = QueryOptimizer::new();
        let large = make_table_query("large");
        let mut small_table = TableQuery::new("small");
        small_table.limit = Some(1);
        let small = QueryExpr::Table(small_table);
        let join = join_large_with_small(large, small);
        let hints = OptimizationHints {
            disabled_passes: vec!["JoinReordering".to_string()],
            ..OptimizationHints::default()
        };

        let optimized = optimizer.optimize_with_hints(join, &hints);
        let QueryExpr::Join(join) = optimized else {
            panic!("expected join query");
        };
        let QueryExpr::Table(left) = join.left.as_ref() else {
            panic!("expected table on left side");
        };
        assert_eq!(left.table, "large");
        assert!(matches!(
            &join.on.left_field,
            FieldRef::TableColumn { table, column } if table == "large" && column == "id"
        ));
    }

    #[test]
    fn optimizer_sets_index_hint_on_table_filters() {
        let optimizer = QueryOptimizer::new();
        let mut table = TableQuery::new("hosts");
        table.filter = Some(Filter::Compare {
            field: FieldRef::column("", "host_id"),
            op: CompareOp::Eq,
            value: Value::Integer(7),
        });

        let (optimized, passes) = optimizer.optimize(QueryExpr::Table(table));
        let QueryExpr::Table(table) = optimized else {
            panic!("expected table query");
        };
        let hint = table
            .expand
            .and_then(|expand| expand.index_hint)
            .expect("expected optimizer index hint");
        assert!(passes.iter().any(|pass| pass == "IndexSelection"));
        assert_eq!(hint.method, IndexHintMethod::Hash);
        assert_eq!(hint.column, "host_id");
    }

    #[test]
    fn index_selection_analyzes_supported_filter_shapes() {
        let range = IndexSelectionPass::analyze_filter(&Filter::Between {
            field: FieldRef::node_prop("n", "score"),
            low: Value::Integer(1),
            high: Value::Integer(9),
        })
        .expect("expected range hint");
        assert_eq!(range.method, IndexHintMethod::BTree);
        assert_eq!(range.column, "score");

        let bitmap = IndexSelectionPass::analyze_filter(&Filter::In {
            field: FieldRef::edge_prop("e", "kind"),
            values: vec![Value::text("http"), Value::text("ssh")],
        })
        .expect("expected bitmap hint");
        assert_eq!(bitmap.method, IndexHintMethod::Bitmap);
        assert_eq!(bitmap.column, "kind");

        // More than 10 values: no hint is produced.
        assert!(IndexSelectionPass::analyze_filter(&Filter::In {
            field: FieldRef::column("", "status"),
            values: (0..11).map(Value::Integer).collect(),
        })
        .is_none());

        let fallback_right = IndexSelectionPass::analyze_filter(&Filter::And(
            Box::new(Filter::IsNull(FieldRef::column("", "deleted_at"))),
            Box::new(Filter::Compare {
                field: FieldRef::node_id("n"),
                op: CompareOp::Eq,
                value: Value::Integer(1),
            }),
        ))
        .expect("expected right-side AND hint");
        assert_eq!(fallback_right.method, IndexHintMethod::Hash);
        assert_eq!(fallback_right.column, "n.id");
    }
}