use crate::storage::query::ast::{JoinQuery, JoinType, QueryExpr};
use crate::storage::query::sql_lowering::{effective_table_filter, effective_vector_filter};
/// A single rewrite step that `QueryOptimizer` can run over a query expression.
///
/// Implementors must be `Send + Sync` (supertrait bounds on the trait).
pub trait OptimizationPass: Send + Sync {
/// Name of the pass. `QueryOptimizer::optimize` reports this name when the pass
/// changed the query, and `OptimizationHints::disabled_passes` is matched against it.
fn name(&self) -> &str;
/// Consumes `query` and returns the (possibly rewritten) expression.
fn apply(&self, query: QueryExpr) -> QueryExpr;
/// Priority score. `QueryOptimizer::add_pass` orders passes by this value,
/// highest first.
fn benefit(&self) -> u32;
}
/// Runs a list of `OptimizationPass` objects over a query, one after another.
pub struct QueryOptimizer {
/// Registered passes, stored in the order in which they are applied.
passes: Vec<Box<dyn OptimizationPass>>,
/// Set to `true` by `QueryOptimizer::new`.
// NOTE(review): no code visible in this module reads this flag — confirm whether
// it is consumed elsewhere or is a placeholder.
cost_based: bool,
}
impl QueryOptimizer {
    /// Builds an optimizer preloaded with the built-in passes.
    ///
    /// The passes are registered in a fixed order: predicate pushdown, projection
    /// pushdown, join reordering, index selection, limit pushdown. That order is
    /// *not* sorted by [`OptimizationPass::benefit`]; it only becomes benefit-sorted
    /// once [`QueryOptimizer::add_pass`] is called.
    pub fn new() -> Self {
        let passes: Vec<Box<dyn OptimizationPass>> = vec![
            Box::new(PredicatePushdownPass),
            Box::new(ProjectionPushdownPass),
            Box::new(JoinReorderingPass),
            Box::new(IndexSelectionPass),
            Box::new(LimitPushdownPass),
        ];
        Self {
            passes,
            cost_based: true,
        }
    }

    /// Registers `pass` and re-sorts the whole pipeline by descending benefit.
    ///
    /// The sort covers every registered pass, so the built-in passes are reordered
    /// as well. The sort is stable: passes with equal benefit keep their relative
    /// order.
    pub fn add_pass(&mut self, pass: Box<dyn OptimizationPass>) {
        self.passes.push(pass);
        self.passes
            .sort_by_key(|registered| std::cmp::Reverse(registered.benefit()));
    }

    /// Applies every pass in order.
    ///
    /// Returns the rewritten query together with the names of the passes that
    /// changed it. A pass counts as "changed it" when the `Debug` rendering of the
    /// query differs before and after the pass ran.
    pub fn optimize(&self, query: QueryExpr) -> (QueryExpr, Vec<String>) {
        let mut optimized = query;
        let mut applied_passes = Vec::new();
        // The rendering taken after one pass is exactly the "before" rendering of the
        // next pass, so it is carried forward instead of being formatted a second time.
        let mut rendered = format!("{:?}", optimized);
        for pass in &self.passes {
            optimized = pass.apply(optimized);
            let rerendered = format!("{:?}", optimized);
            if rerendered != rendered {
                applied_passes.push(pass.name().to_string());
                rendered = rerendered;
            }
        }
        (optimized, applied_passes)
    }

    /// Applies every pass in order, skipping those whose name is listed in
    /// `hints.disabled_passes`.
    ///
    /// Only `disabled_passes` is consulted here; the remaining hint fields are not
    /// read by this method.
    pub fn optimize_with_hints(&self, query: QueryExpr, hints: &OptimizationHints) -> QueryExpr {
        self.passes
            .iter()
            // Compare as `&str` so no temporary `String` is allocated per pass.
            .filter(|pass| {
                !hints
                    .disabled_passes
                    .iter()
                    .any(|disabled| disabled.as_str() == pass.name())
            })
            .fold(query, |current, pass| pass.apply(current))
    }
}
impl Default for QueryOptimizer {
fn default() -> Self {
Self::new()
}
}
/// Caller-supplied hints for `QueryOptimizer::optimize_with_hints`.
///
/// Of these fields, only `disabled_passes` is read by the code visible in this file.
#[derive(Debug, Clone, Default)]
pub struct OptimizationHints {
/// Names of passes to skip, compared against `OptimizationPass::name`.
pub disabled_passes: Vec<String>,
// NOTE(review): presumably an explicit join order — not read in this file; confirm against callers.
pub join_order: Option<Vec<String>>,
// NOTE(review): presumably the name of an index to force — not read in this file; confirm against callers.
pub force_index: Option<String>,
// NOTE(review): presumably disables parallel execution — not read in this file; confirm against callers.
pub no_parallel: bool,
}
/// Pass registered under the name `"PredicatePushdown"`.
///
/// As written it only walks join trees; no predicate is moved by the code here.
struct PredicatePushdownPass;

impl OptimizationPass for PredicatePushdownPass {
    fn name(&self) -> &str {
        "PredicatePushdown"
    }

    fn benefit(&self) -> u32 {
        100
    }

    /// Joins are handed to `optimize_join`; every other expression is returned as-is.
    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(join) => self.optimize_join(join),
            passthrough => passthrough,
        }
    }
}
impl PredicatePushdownPass {
    /// Runs the pass over both inputs of `query` (left first, then right) and
    /// returns the join with those inputs replaced; all other fields are kept.
    fn optimize_join(&self, mut query: JoinQuery) -> QueryExpr {
        query.left = Box::new(self.apply(*query.left));
        query.right = Box::new(self.apply(*query.right));
        QueryExpr::Join(query)
    }
}
/// Pass registered under the name `"ProjectionPushdown"`.
///
/// As written it only recurses through join inputs; table queries and all other
/// expressions come back unchanged.
struct ProjectionPushdownPass;

impl OptimizationPass for ProjectionPushdownPass {
    fn name(&self) -> &str {
        "ProjectionPushdown"
    }

    fn benefit(&self) -> u32 {
        80
    }

    /// Rebuilds a join from its recursively processed inputs (left first, then
    /// right); anything that is not a join is returned untouched.
    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(join) => QueryExpr::Join(JoinQuery {
                left: Box::new(self.apply(*join.left)),
                right: Box::new(self.apply(*join.right)),
                ..join
            }),
            // Covers table queries too: they are passed through as they are.
            unchanged => unchanged,
        }
    }
}
/// Pass registered under the name `"JoinReordering"`.
struct JoinReorderingPass;

impl OptimizationPass for JoinReorderingPass {
    fn name(&self) -> &str {
        "JoinReordering"
    }

    fn benefit(&self) -> u32 {
        90
    }

    /// Only a top-level join is considered for reordering; every other expression
    /// is returned as-is.
    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(join) => self.optimize_join_order(join),
            passthrough => passthrough,
        }
    }
}
impl JoinReorderingPass {
    /// Puts the smaller input of an inner join on the left.
    ///
    /// When the left input's estimated size is strictly greater than the right's
    /// and the join is `JoinType::Inner`, the two inputs are exchanged and the join
    /// condition is mirrored with `swap_condition`. All other fields of the join
    /// are left as they were. Any other join is returned unchanged.
    fn optimize_join_order(&self, mut query: JoinQuery) -> QueryExpr {
        let left_size = Self::estimate_size(&query.left);
        let right_size = Self::estimate_size(&query.right);
        if left_size > right_size && query.join_type == JoinType::Inner {
            std::mem::swap(&mut query.left, &mut query.right);
            query.on = swap_condition(query.on);
        }
        QueryExpr::Join(query)
    }

    /// Heuristic row-count estimate for `query`, used only to compare join inputs.
    ///
    /// * Table: 1000 rows, reduced to 10% when an effective filter is present,
    ///   then capped by `limit` if one is set.
    /// * Graph: 100. Path: 10.
    /// * Join: product of both inputs times 0.1.
    /// * Vector: `k`, capped at 100 when an effective filter is present.
    /// * Hybrid: the smaller of the structured estimate and `vector.k`, capped by
    ///   `limit` if one is set.
    /// * Every other expression: 1.
    fn estimate_size(query: &QueryExpr) -> f64 {
        // Assumed row count of a table when nothing else is known about it.
        const DEFAULT_TABLE_ROWS: f64 = 1000.0;

        match query {
            QueryExpr::Table(tq) => {
                let estimate = if effective_table_filter(tq).is_some() {
                    DEFAULT_TABLE_ROWS * 0.1
                } else {
                    DEFAULT_TABLE_ROWS
                };
                // A LIMIT can only lower the row count. Clamp with `min` (as the
                // Hybrid arm does) so the limit is honoured alongside a filter and
                // can never push the estimate above the table's own estimate.
                match tq.limit {
                    Some(limit) => estimate.min(limit as f64),
                    None => estimate,
                }
            }
            QueryExpr::Graph(_) => 100.0,
            QueryExpr::Join(jq) => {
                Self::estimate_size(&jq.left) * Self::estimate_size(&jq.right) * 0.1
            }
            QueryExpr::Path(_) => 10.0,
            QueryExpr::Vector(vq) => {
                let k = vq.k as f64;
                if effective_vector_filter(vq).is_some() {
                    k.min(100.0)
                } else {
                    k
                }
            }
            QueryExpr::Hybrid(hq) => {
                let structured_size = Self::estimate_size(&hq.structured);
                let vector_size = hq.vector.k as f64;
                let base = structured_size.min(vector_size);
                hq.limit.map(|l| base.min(l as f64)).unwrap_or(base)
            }
            // Listed explicitly (no `_` arm) so that adding a `QueryExpr` variant
            // forces a decision here at compile time.
            QueryExpr::Insert(_)
            | QueryExpr::Update(_)
            | QueryExpr::Delete(_)
            | QueryExpr::CreateTable(_)
            | QueryExpr::DropTable(_)
            | QueryExpr::DropGraph(_)
            | QueryExpr::DropVector(_)
            | QueryExpr::DropDocument(_)
            | QueryExpr::DropKv(_)
            | QueryExpr::DropCollection(_)
            | QueryExpr::Truncate(_)
            | QueryExpr::AlterTable(_)
            | QueryExpr::GraphCommand(_)
            | QueryExpr::SearchCommand(_)
            | QueryExpr::CreateIndex(_)
            | QueryExpr::DropIndex(_)
            | QueryExpr::ProbabilisticCommand(_)
            | QueryExpr::Ask(_)
            | QueryExpr::SetConfig { .. }
            | QueryExpr::ShowConfig { .. }
            | QueryExpr::SetSecret { .. }
            | QueryExpr::DeleteSecret { .. }
            | QueryExpr::ShowSecrets { .. }
            | QueryExpr::SetTenant(_)
            | QueryExpr::ShowTenant
            | QueryExpr::CreateTimeSeries(_)
            | QueryExpr::DropTimeSeries(_)
            | QueryExpr::CreateQueue(_)
            | QueryExpr::AlterQueue(_)
            | QueryExpr::DropQueue(_)
            | QueryExpr::QueueSelect(_)
            | QueryExpr::QueueCommand(_)
            | QueryExpr::KvCommand(_)
            | QueryExpr::ConfigCommand(_)
            | QueryExpr::CreateTree(_)
            | QueryExpr::DropTree(_)
            | QueryExpr::TreeCommand(_)
            | QueryExpr::ExplainAlter(_)
            | QueryExpr::TransactionControl(_)
            | QueryExpr::MaintenanceCommand(_)
            | QueryExpr::CreateSchema(_)
            | QueryExpr::DropSchema(_)
            | QueryExpr::CreateSequence(_)
            | QueryExpr::DropSequence(_)
            | QueryExpr::CopyFrom(_)
            | QueryExpr::CreateView(_)
            | QueryExpr::DropView(_)
            | QueryExpr::RefreshMaterializedView(_)
            | QueryExpr::CreatePolicy(_)
            | QueryExpr::DropPolicy(_)
            | QueryExpr::CreateServer(_)
            | QueryExpr::DropServer(_)
            | QueryExpr::CreateForeignTable(_)
            | QueryExpr::DropForeignTable(_)
            | QueryExpr::Grant(_)
            | QueryExpr::Revoke(_)
            | QueryExpr::AlterUser(_)
            | QueryExpr::CreateIamPolicy { .. }
            | QueryExpr::DropIamPolicy { .. }
            | QueryExpr::AttachPolicy { .. }
            | QueryExpr::DetachPolicy { .. }
            | QueryExpr::ShowPolicies { .. }
            | QueryExpr::ShowEffectivePermissions { .. }
            | QueryExpr::SimulatePolicy { .. }
            | QueryExpr::CreateMigration(_)
            | QueryExpr::ApplyMigration(_)
            | QueryExpr::RollbackMigration(_)
            | QueryExpr::ExplainMigration(_)
            | QueryExpr::EventsBackfill(_)
            | QueryExpr::EventsBackfillStatus { .. } => 1.0,
        }
    }
}
/// Pass registered under the name `"IndexSelection"`.
struct IndexSelectionPass;

impl OptimizationPass for IndexSelectionPass {
    fn name(&self) -> &str {
        "IndexSelection"
    }

    fn benefit(&self) -> u32 {
        70
    }

    /// For a table query whose effective filter yields a hint from
    /// `analyze_filter`, stores that hint in `expand.index_hint`. `expand` is
    /// created with its default value if absent, and any previous hint is
    /// overwritten. Non-table expressions are returned as-is.
    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Table(mut tq) => {
                // Resolve the hint first so the filter is no longer borrowed when
                // `tq.expand` is mutated below.
                let suggested = effective_table_filter(&tq)
                    .as_ref()
                    .and_then(|filter| Self::analyze_filter(filter));
                if let Some(hint) = suggested {
                    tq.expand.get_or_insert_with(Default::default).index_hint = Some(hint);
                }
                QueryExpr::Table(tq)
            }
            passthrough => passthrough,
        }
    }
}
impl IndexSelectionPass {
    /// Picks an index hint for `filter`, if its shape suggests one.
    ///
    /// * equality comparison → `Hash`
    /// * `<`, `<=`, `>`, `>=` comparison or `Between` → `BTree`
    /// * `In` with at most 10 values → `Bitmap`
    /// * `And` → the hint of the left operand, falling back to the right operand
    /// * anything else → `None`
    fn analyze_filter(filter: &crate::storage::query::ast::Filter) -> Option<IndexHint> {
        use crate::storage::query::ast::{CompareOp, Filter};

        // First decide the method and which field it applies to; the hint itself
        // is assembled once at the end.
        let (method, field) = match filter {
            Filter::Compare {
                field,
                op: CompareOp::Eq,
                ..
            } => (IndexHintMethod::Hash, field),
            Filter::Compare {
                field,
                op: CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge,
                ..
            }
            | Filter::Between { field, .. } => (IndexHintMethod::BTree, field),
            Filter::In { field, values } if values.len() <= 10 => {
                (IndexHintMethod::Bitmap, field)
            }
            Filter::And(left, right) => {
                return Self::analyze_filter(left).or_else(|| Self::analyze_filter(right));
            }
            _ => return None,
        };

        Some(IndexHint {
            method,
            column: Self::field_name(field),
        })
    }

    /// Column name to put in a hint for `field`: the column or property name, or
    /// `"<alias>.id"` for a node id reference.
    fn field_name(field: &crate::storage::query::ast::FieldRef) -> String {
        use crate::storage::query::ast::FieldRef;

        match field {
            FieldRef::NodeId { alias } => format!("{}.id", alias),
            FieldRef::TableColumn { column: name, .. }
            | FieldRef::NodeProperty { property: name, .. }
            | FieldRef::EdgeProperty { property: name, .. } => name.clone(),
        }
    }
}
/// Index suggestion produced by `IndexSelectionPass` and stored in a table
/// query's `expand.index_hint`.
#[derive(Debug, Clone)]
pub struct IndexHint {
/// Kind of index suggested for the lookup.
pub method: IndexHintMethod,
/// Column (or property) name the hint refers to, as produced by
/// `IndexSelectionPass::field_name`.
pub column: String,
}
/// Kind of index an `IndexHint` suggests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexHintMethod {
/// Chosen by `IndexSelectionPass` for equality comparisons.
Hash,
/// Chosen by `IndexSelectionPass` for `<`, `<=`, `>`, `>=` comparisons and `Between` filters.
BTree,
/// Chosen by `IndexSelectionPass` for `In` filters with at most 10 values.
Bitmap,
// NOTE(review): never produced by the analysis code in this file — presumably
// set elsewhere; confirm.
Spatial,
}
/// Pass registered under the name `"LimitPushdown"`.
///
/// As written it only recurses through join inputs; no limit is moved by the
/// code here.
struct LimitPushdownPass;

impl OptimizationPass for LimitPushdownPass {
    fn name(&self) -> &str {
        "LimitPushdown"
    }

    fn benefit(&self) -> u32 {
        60
    }

    /// Re-applies the pass to both inputs of a join (left first, then right) and
    /// puts the results back; anything that is not a join is returned untouched.
    fn apply(&self, query: QueryExpr) -> QueryExpr {
        match query {
            QueryExpr::Join(mut join) => {
                join.left = Box::new(self.apply(*join.left));
                join.right = Box::new(self.apply(*join.right));
                QueryExpr::Join(join)
            }
            passthrough => passthrough,
        }
    }
}
/// Mirrors a join condition: the left field becomes the right field and vice
/// versa. Used when the two inputs of a join are exchanged.
fn swap_condition(
    condition: crate::storage::query::ast::JoinCondition,
) -> crate::storage::query::ast::JoinCondition {
    let mut mirrored = condition;
    std::mem::swap(&mut mirrored.left_field, &mut mirrored.right_field);
    mirrored
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::query::ast::{
        DistanceMetric, FieldRef, FusionStrategy, JoinCondition, Projection, TableQuery,
    };

    /// Table query on `name` projecting `Projection::All`, with no alias, filter,
    /// ordering, limit or offset. Tests override individual fields with struct
    /// update syntax.
    fn base_table_query(name: &str) -> TableQuery {
        TableQuery {
            table: name.to_string(),
            source: None,
            alias: None,
            select_items: Vec::new(),
            columns: vec![Projection::All],
            where_expr: None,
            filter: None,
            group_by_exprs: Vec::new(),
            group_by: Vec::new(),
            having_expr: None,
            having: None,
            order_by: vec![],
            limit: None,
            limit_param: None,
            offset: None,
            offset_param: None,
            expand: None,
            as_of: None,
        }
    }

    /// Same as `base_table_query`, aliased to its own name and wrapped in a `QueryExpr`.
    fn make_table_query(name: &str) -> QueryExpr {
        QueryExpr::Table(TableQuery {
            alias: Some(name.to_string()),
            ..base_table_query(name)
        })
    }

    #[test]
    fn test_optimizer_applies_passes() {
        let optimizer = QueryOptimizer::new();
        let (optimized, passes) = optimizer.optimize(make_table_query("hosts"));
        assert!(matches!(optimized, QueryExpr::Table(_)));
        // These passes hand any non-join expression back untouched, so none of them
        // may be reported as having changed a plain table query.
        for join_only in [
            "PredicatePushdown",
            "ProjectionPushdown",
            "JoinReordering",
            "LimitPushdown",
        ] {
            assert!(
                !passes.iter().any(|p| p.as_str() == join_only),
                "{} must not report a change for a table query",
                join_only
            );
        }
    }

    #[test]
    fn test_join_reordering() {
        let optimizer = QueryOptimizer::new();
        let small = QueryExpr::Table(TableQuery {
            limit: Some(10),
            ..base_table_query("small")
        });
        let large = QueryExpr::Table(base_table_query("large"));
        let join = QueryExpr::Join(JoinQuery {
            left: Box::new(large),
            right: Box::new(small),
            join_type: JoinType::Inner,
            on: JoinCondition {
                left_field: FieldRef::TableColumn {
                    table: "large".to_string(),
                    column: "id".to_string(),
                },
                right_field: FieldRef::TableColumn {
                    table: "small".to_string(),
                    column: "id".to_string(),
                },
            },
            filter: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
            return_items: Vec::new(),
            return_: Vec::new(),
        });

        let (optimized, passes) = optimizer.optimize(join);
        assert!(
            passes.iter().any(|p| p.as_str() == "JoinReordering"),
            "JoinReordering should be reported, got {:?}",
            passes
        );

        // Fail loudly on an unexpected shape instead of silently skipping the checks.
        let jq = match optimized {
            QueryExpr::Join(jq) => jq,
            other => panic!("expected a join, got {:?}", other),
        };
        match *jq.left {
            QueryExpr::Table(left) => assert_eq!(left.table, "small"),
            other => panic!("expected a table on the left, got {:?}", other),
        }
        match *jq.right {
            QueryExpr::Table(right) => assert_eq!(right.table, "large"),
            other => panic!("expected a table on the right, got {:?}", other),
        }
        // The join condition must be mirrored together with the inputs.
        assert!(matches!(
            &jq.on.left_field,
            FieldRef::TableColumn { table, .. } if table.as_str() == "small"
        ));
        assert!(matches!(
            &jq.on.right_field,
            FieldRef::TableColumn { table, .. } if table.as_str() == "large"
        ));
    }
}