use crate::error::{Error, Result};
use std::collections::{HashMap, HashSet};
/// How aggressively the optimizer rewrites a plan. Levels are ordered so a
/// higher level includes every rewrite enabled by the levels below it
/// (see `QueryOptimizer::optimize`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum OptimizationLevel {
    /// Leave the plan untouched.
    None,
    /// Predicate pushdown only.
    Basic,
    /// Adds projection pushdown and operation fusion. The default.
    #[default]
    Standard,
    /// Adds cost-based filter reordering.
    Aggressive,
}
/// Per-column statistics used to estimate how selective a filter on the
/// column will be (see `ColumnStats::estimate_selectivity`).
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name; also the key under which the stats are stored in
    /// `QueryPlan::column_stats`.
    pub name: String,
    /// Number of distinct values; drives equality/IN selectivity.
    pub distinct_count: usize,
    /// Smallest observed value, if known; drives range selectivity.
    pub min_value: Option<f64>,
    /// Largest observed value, if known; drives range selectivity.
    pub max_value: Option<f64>,
    /// Number of NULL entries; drives IsNull/IsNotNull selectivity.
    pub null_count: usize,
    /// Total row count of the column's table.
    pub row_count: usize,
    /// Average value length; not used by the estimation logic shown here.
    pub avg_length: Option<f64>,
}
impl ColumnStats {
    /// Builds statistics for `name`, conservatively assuming every row is
    /// distinct until real figures are filled in.
    pub fn new(name: String, row_count: usize) -> Self {
        ColumnStats {
            name,
            distinct_count: row_count,
            min_value: None,
            max_value: None,
            null_count: 0,
            row_count,
            avg_length: None,
        }
    }

    /// Estimated fraction of rows (0.0..=1.0) that satisfy `op`, falling
    /// back to fixed heuristic guesses when the relevant statistics are
    /// missing.
    pub fn estimate_selectivity(&self, op: &FilterOp) -> f64 {
        match op {
            FilterOp::Equals(_) => match self.distinct_count {
                0 => 0.1,
                d => 1.0 / d as f64,
            },
            FilterOp::NotEquals(_) => match self.distinct_count {
                0 => 0.9,
                d => 1.0 - 1.0 / d as f64,
            },
            FilterOp::LessThan(v) | FilterOp::LessOrEqual(v) => {
                self.range_fraction(|lo, hi| (v - lo) / (hi - lo), 0.33)
            }
            FilterOp::GreaterThan(v) | FilterOp::GreaterOrEqual(v) => {
                self.range_fraction(|lo, hi| (hi - v) / (hi - lo), 0.33)
            }
            FilterOp::Between(low, high) => {
                self.range_fraction(|lo, hi| (high - low) / (hi - lo), 0.25)
            }
            FilterOp::IsNull => match self.row_count {
                0 => 0.01,
                n => self.null_count as f64 / n as f64,
            },
            FilterOp::IsNotNull => match self.row_count {
                0 => 0.99,
                n => 1.0 - self.null_count as f64 / n as f64,
            },
            FilterOp::In(values) => {
                let wanted = values.len() as f64;
                if self.distinct_count == 0 {
                    (wanted * 0.1).min(1.0)
                } else {
                    (wanted / self.distinct_count as f64).min(1.0)
                }
            }
            FilterOp::Like(_) => 0.1,
            FilterOp::Custom(_) => 0.5,
        }
    }

    /// Maps the column's [min, max] range through `fraction` and clamps the
    /// result to [0, 1]; 0.5 for a degenerate (max <= min) range, and
    /// `fallback` when either bound is unknown.
    fn range_fraction(&self, fraction: impl Fn(f64, f64) -> f64, fallback: f64) -> f64 {
        match (self.min_value, self.max_value) {
            (Some(lo), Some(hi)) if hi > lo => fraction(lo, hi).clamp(0.0, 1.0),
            (Some(_), Some(_)) => 0.5,
            _ => fallback,
        }
    }
}
/// A predicate over a single numeric column; comparison values are `f64`.
#[derive(Debug, Clone)]
pub enum FilterOp {
    Equals(f64),
    NotEquals(f64),
    LessThan(f64),
    LessOrEqual(f64),
    GreaterThan(f64),
    GreaterOrEqual(f64),
    /// Range test between `(low, high)`; exact bound inclusivity is decided
    /// by the executor — only selectivity is estimated here.
    Between(f64, f64),
    IsNull,
    IsNotNull,
    /// Membership in the given value set.
    In(Vec<f64>),
    /// Pattern match; semantics are executor-defined (selectivity is a
    /// fixed 0.1 guess).
    Like(String),
    /// Opaque predicate; selectivity is a fixed 0.5 guess.
    Custom(String),
}
/// A logical operation in a query pipeline that the optimizer can analyze
/// and rearrange.
#[derive(Debug, Clone)]
pub enum OptimizableOp {
    /// Keep only the named columns.
    Select(Vec<String>),
    /// Keep rows of `column` matching `op`; `selectivity` is the estimated
    /// fraction of rows that pass (0.0..=1.0).
    Filter {
        column: String,
        op: FilterOp,
        selectivity: f64,
    },
    /// Group by `group_by` and compute each `(column, function)` aggregate.
    Aggregate {
        group_by: Vec<String>,
        aggregates: Vec<(String, AggregateFunc)>,
    },
    /// Sort by `columns`; `ascending` gives the per-column direction.
    Sort {
        columns: Vec<String>,
        ascending: Vec<bool>,
    },
    /// Join against another source on `left_key` = `right_key`, bringing in
    /// `right_columns` from the right side.
    Join {
        right_columns: Vec<String>,
        left_key: String,
        right_key: String,
        join_type: JoinType,
    },
    /// Compute `output_column` from `input_columns`.
    Map {
        input_columns: Vec<String>,
        output_column: String,
    },
    /// Keep at most the first N rows.
    Limit(usize),
    /// Skip the first N rows.
    Offset(usize),
}
/// Aggregate functions available to `OptimizableOp::Aggregate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunc {
    Sum,
    Mean,
    Min,
    Max,
    Count,
    Std,
    Var,
    Median,
    First,
    Last,
}
/// Join semantics for `OptimizableOp::Join`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Outer,
}
/// An ordered pipeline of operations over a data source, plus the
/// statistics the optimizer uses to rewrite it.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Pipeline operations, executed front to back.
    pub operations: Vec<OptimizableOp>,
    /// Column statistics keyed by column name.
    pub column_stats: HashMap<String, ColumnStats>,
    /// Estimated row count of the source.
    pub estimated_rows: usize,
    /// Union of all columns referenced by `operations`; populated by the
    /// optimizer before its rewrite passes run.
    pub required_columns: HashSet<String>,
}
impl QueryPlan {
    /// Creates an empty plan for a source with the given row estimate.
    pub fn new(estimated_rows: usize) -> Self {
        QueryPlan {
            operations: Vec::new(),
            column_stats: HashMap::new(),
            estimated_rows,
            required_columns: HashSet::new(),
        }
    }

    /// Appends an operation to the end of the pipeline.
    pub fn add_operation(&mut self, op: OptimizableOp) {
        self.operations.push(op);
    }

    /// Registers (or replaces) the statistics for the column named in
    /// `stats`.
    pub fn set_column_stats(&mut self, stats: ColumnStats) {
        self.column_stats.insert(stats.name.clone(), stats);
    }

    /// Returns the set of column names a single operation reads.
    fn get_required_columns(&self, op: &OptimizableOp) -> HashSet<String> {
        let mut needed = HashSet::new();
        match op {
            OptimizableOp::Select(columns) => {
                needed.extend(columns.iter().cloned());
            }
            OptimizableOp::Filter { column, .. } => {
                needed.insert(column.clone());
            }
            OptimizableOp::Aggregate {
                group_by,
                aggregates,
            } => {
                needed.extend(group_by.iter().cloned());
                needed.extend(aggregates.iter().map(|(col, _)| col.clone()));
            }
            OptimizableOp::Sort { columns, .. } => {
                needed.extend(columns.iter().cloned());
            }
            OptimizableOp::Join {
                left_key,
                right_key,
                right_columns,
                ..
            } => {
                needed.insert(left_key.clone());
                needed.insert(right_key.clone());
                needed.extend(right_columns.iter().cloned());
            }
            OptimizableOp::Map { input_columns, .. } => {
                needed.extend(input_columns.iter().cloned());
            }
            // Limit/Offset operate on row position only.
            OptimizableOp::Limit(_) | OptimizableOp::Offset(_) => {}
        }
        needed
    }
}
/// Rewrites `QueryPlan`s according to an `OptimizationLevel`, recording
/// what it did in `OptimizerStats`.
#[derive(Debug)]
pub struct QueryOptimizer {
    level: OptimizationLevel,
    stats: OptimizerStats,
}
/// Counters describing what the most recent `QueryOptimizer::optimize`
/// call did; reset at the start of every call.
#[derive(Debug, Default, Clone)]
pub struct OptimizerStats {
    /// Total individual rewrites applied across all passes.
    pub optimizations_applied: usize,
    /// Operation count of the plan before optimization.
    pub operations_before: usize,
    /// Operation count of the plan after optimization.
    pub operations_after: usize,
    /// Heuristic estimate of cost saved by filter ordering.
    pub estimated_cost_reduction: f64,
    /// Number of filter predicates considered for pushdown.
    pub predicates_pushed: usize,
    /// Number of early projections inserted.
    pub projections_pushed: usize,
    /// Number of adjacent operations fused away.
    pub operations_fused: usize,
}
impl QueryOptimizer {
    /// Creates an optimizer that applies every rewrite enabled at `level`.
    pub fn new(level: OptimizationLevel) -> Self {
        QueryOptimizer {
            level,
            stats: OptimizerStats::default(),
        }
    }

    /// Optimizes `plan`, applying passes cumulatively by level: Basic runs
    /// predicate pushdown; Standard adds projection pushdown and operation
    /// fusion; Aggressive adds cost-based filter reordering. Statistics
    /// about the run are available afterwards via [`Self::stats`].
    pub fn optimize(&mut self, mut plan: QueryPlan) -> Result<QueryPlan> {
        self.stats = OptimizerStats::default();
        self.stats.operations_before = plan.operations.len();
        if self.level == OptimizationLevel::None {
            self.stats.operations_after = plan.operations.len();
            return Ok(plan);
        }
        self.compute_required_columns(&mut plan);
        if self.level >= OptimizationLevel::Basic {
            plan = self.predicate_pushdown(plan)?;
        }
        if self.level >= OptimizationLevel::Standard {
            plan = self.projection_pushdown(plan)?;
            plan = self.fuse_operations(plan)?;
        }
        if self.level >= OptimizationLevel::Aggressive {
            plan = self.cost_based_reorder(plan)?;
        }
        self.stats.operations_after = plan.operations.len();
        Ok(plan)
    }

    /// Statistics collected by the most recent call to [`Self::optimize`].
    pub fn stats(&self) -> &OptimizerStats {
        &self.stats
    }

    /// Stores in the plan the union of every column any operation reads.
    fn compute_required_columns(&mut self, plan: &mut QueryPlan) {
        let mut required = HashSet::new();
        for op in &plan.operations {
            required.extend(plan.get_required_columns(op));
        }
        plan.required_columns = required;
    }

    /// Moves filters as early in the pipeline as is safe, most selective
    /// first. A filter never moves *later* than its original position, and
    /// never crosses an operation that produces its column (a Map output,
    /// an aggregate output, a join's right side) or whose result depends on
    /// row count/order (Limit, Offset, Right/Outer joins).
    ///
    /// BUG FIX: the previous version hoisted filters past Limit/Offset and
    /// join right-side columns, which changes query results.
    fn predicate_pushdown(&mut self, mut plan: QueryPlan) -> Result<QueryPlan> {
        // Separate filters from the rest, remembering for each filter the
        // slot it occupied among the non-filter operations.
        let mut other_ops: Vec<OptimizableOp> = Vec::new();
        let mut filters: Vec<(usize, OptimizableOp)> = Vec::new();
        for op in plan.operations {
            match op {
                OptimizableOp::Filter { .. } => filters.push((other_ops.len(), op)),
                other => other_ops.push(other),
            }
        }
        self.stats.predicates_pushed = filters.len();
        // Lowest selectivity (fewest surviving rows) first, so the most
        // effective filters run earliest when several land in one slot.
        let selectivity_of = |op: &OptimizableOp| -> f64 {
            if let OptimizableOp::Filter { selectivity, .. } = op {
                *selectivity
            } else {
                1.0
            }
        };
        filters.sort_by(|(_, a), (_, b)| {
            selectivity_of(a)
                .partial_cmp(&selectivity_of(b))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        // buckets[s] holds the filters to run just before other_ops[s];
        // buckets[other_ops.len()] holds those that stay at the very end.
        let mut buckets: Vec<Vec<OptimizableOp>> = Vec::new();
        buckets.resize_with(other_ops.len() + 1, Vec::new);
        for (orig_slot, filter) in filters {
            let slot = self.earliest_filter_slot(&filter, &other_ops[..orig_slot]);
            if slot < orig_slot {
                // The filter actually moved earlier.
                self.stats.optimizations_applied += 1;
            }
            buckets[slot].push(filter);
        }
        // Reassemble: filters interleaved with the non-filter operations.
        let tail = other_ops.len();
        let mut new_ops = Vec::new();
        for (i, op) in other_ops.into_iter().enumerate() {
            new_ops.append(&mut buckets[i]);
            new_ops.push(op);
        }
        new_ops.append(&mut buckets[tail]);
        plan.operations = new_ops;
        Ok(plan)
    }

    /// Earliest slot (index into `preceding`) the filter may occupy: one
    /// past the last operation it is not allowed to cross, or 0 when
    /// nothing blocks it.
    fn earliest_filter_slot(&self, filter: &OptimizableOp, preceding: &[OptimizableOp]) -> usize {
        let filter_col = if let OptimizableOp::Filter { column, .. } = filter {
            column
        } else {
            return preceding.len();
        };
        let mut slot = 0;
        for (i, op) in preceding.iter().enumerate() {
            let blocks = match op {
                // An aggregate output does not exist before the aggregation.
                OptimizableOp::Aggregate { aggregates, .. } => {
                    aggregates.iter().any(|(col, _)| col == filter_col)
                }
                // A computed column does not exist before its Map.
                OptimizableOp::Map { output_column, .. } => output_column == filter_col,
                // Filtering before a row-count operator changes which rows
                // it keeps.
                OptimizableOp::Limit(_) | OptimizableOp::Offset(_) => true,
                // Right/Outer joins can null-extend rows, and right-side
                // columns only exist after the join.
                OptimizableOp::Join {
                    join_type,
                    right_key,
                    right_columns,
                    ..
                } => {
                    matches!(join_type, JoinType::Right | JoinType::Outer)
                        || right_key == filter_col
                        || right_columns.iter().any(|c| c == filter_col)
                }
                _ => false,
            };
            if blocks {
                slot = i + 1;
            }
        }
        slot
    }

    /// Computes the minimal column set the plan needs from its source and,
    /// when column statistics suggest the source is wider than that,
    /// prepends an early SELECT so unused columns are dropped immediately.
    fn projection_pushdown(&mut self, mut plan: QueryPlan) -> Result<QueryPlan> {
        // Walk the pipeline backwards, tracking the columns still needed by
        // everything downstream of the current operation.
        let mut needed = plan.required_columns.clone();
        for op in plan.operations.iter().rev() {
            match op {
                OptimizableOp::Select(cols) => {
                    // A projection caps what can flow through it.
                    needed = cols
                        .iter()
                        .filter(|c| needed.contains(*c))
                        .cloned()
                        .collect();
                }
                OptimizableOp::Filter { column, .. } => {
                    needed.insert(column.clone());
                }
                OptimizableOp::Aggregate {
                    group_by,
                    aggregates,
                } => {
                    needed.extend(group_by.iter().cloned());
                    needed.extend(aggregates.iter().map(|(col, _)| col.clone()));
                }
                OptimizableOp::Map {
                    input_columns,
                    output_column,
                } => {
                    // Map inputs matter only if its output is consumed.
                    if needed.contains(output_column) {
                        needed.extend(input_columns.iter().cloned());
                    }
                }
                OptimizableOp::Sort { columns, .. } => {
                    needed.extend(columns.iter().cloned());
                }
                OptimizableOp::Join {
                    left_key,
                    right_key,
                    right_columns,
                    ..
                } => {
                    needed.insert(left_key.clone());
                    needed.insert(right_key.clone());
                    needed.extend(right_columns.iter().cloned());
                }
                _ => {}
            }
        }
        // Only bother when statistics suggest the source has more columns
        // than the plan actually reads.
        if !needed.is_empty() && plan.column_stats.len() > needed.len() {
            plan.operations
                .insert(0, OptimizableOp::Select(needed.into_iter().collect()));
            self.stats.projections_pushed += 1;
            self.stats.optimizations_applied += 1;
        }
        Ok(plan)
    }

    /// Collapses adjacent operations that are provably redundant; the first
    /// of the pair is dropped, so the later operation's effect wins.
    ///
    /// BUG FIX: the previous version dropped the first of two filters on
    /// the same column even when their predicates differed (losing a
    /// condition), and "fused" Offset+Limit by deleting the Offset.
    fn fuse_operations(&mut self, mut plan: QueryPlan) -> Result<QueryPlan> {
        let mut i = 0;
        while i + 1 < plan.operations.len() {
            let redundant = match (&plan.operations[i], &plan.operations[i + 1]) {
                // An identical filter applied twice is a no-op; FilterOp has
                // no PartialEq, so compare the Debug rendering.
                (a @ OptimizableOp::Filter { .. }, b @ OptimizableOp::Filter { .. }) => {
                    format!("{:?}", a) == format!("{:?}", b)
                }
                // A sort is fully overridden by a following sort on the same
                // columns (stable sorting keeps tie order identical).
                (
                    OptimizableOp::Sort { columns: c1, .. },
                    OptimizableOp::Sort { columns: c2, .. },
                ) => c1 == c2,
                // Back-to-back projections: the second determines the output.
                (OptimizableOp::Select(_), OptimizableOp::Select(_)) => true,
                _ => false,
            };
            if redundant {
                plan.operations.remove(i);
                self.stats.operations_fused += 1;
                self.stats.optimizations_applied += 1;
            } else {
                i += 1;
            }
        }
        Ok(plan)
    }

    /// Reorders runs of *adjacent* filters so the cheapest (selectivity ×
    /// source row count) runs first. Filters are never moved across
    /// non-filter operations, so any placement constraints established by
    /// predicate pushdown are preserved.
    ///
    /// BUG FIX: the previous version sorted `filter_costs` before capturing
    /// the "original" order, so original and optimal order were always
    /// identical and the reorder branch was dead code.
    fn cost_based_reorder(&mut self, mut plan: QueryPlan) -> Result<QueryPlan> {
        // Per-operation cost aligned with `plan.operations`; None marks a
        // non-filter operation.
        let costs: Vec<Option<f64>> = plan
            .operations
            .iter()
            .map(|op| {
                if let OptimizableOp::Filter {
                    column, selectivity, ..
                } = op
                {
                    let rows = plan
                        .column_stats
                        .get(column)
                        .map(|s| s.row_count as f64)
                        .unwrap_or(1000.0);
                    Some(selectivity * rows)
                } else {
                    None
                }
            })
            .collect();
        let mut reordered_any = false;
        let mut start = 0;
        while start < plan.operations.len() {
            if costs[start].is_none() {
                start += 1;
                continue;
            }
            // [start, end) is a maximal run of consecutive filters.
            let mut end = start;
            while end < plan.operations.len() && costs[end].is_some() {
                end += 1;
            }
            let mut order: Vec<usize> = (start..end).collect();
            order.sort_by(|&a, &b| {
                costs[a]
                    .partial_cmp(&costs[b])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            if order.iter().enumerate().any(|(k, &src)| src != start + k) {
                let run: Vec<OptimizableOp> = order
                    .iter()
                    .map(|&src| plan.operations[src].clone())
                    .collect();
                for (k, op) in run.into_iter().enumerate() {
                    plan.operations[start + k] = op;
                }
                reordered_any = true;
            }
            start = end;
        }
        if reordered_any {
            self.stats.optimizations_applied += 1;
        }
        // Rough benefit estimate: each successive filter is assumed to see
        // a geometrically shrinking share of the rows.
        let mut sorted_costs: Vec<f64> = costs.iter().filter_map(|c| *c).collect();
        sorted_costs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        self.stats.estimated_cost_reduction = sorted_costs
            .iter()
            .enumerate()
            .map(|(i, cost)| cost * (1.0 - 0.5f64.powi(i as i32)))
            .sum();
        Ok(plan)
    }
}
/// Fluent builder that assembles a `QueryPlan` operation by operation.
#[derive(Debug)]
pub struct QueryPlanBuilder {
    plan: QueryPlan,
}
impl QueryPlanBuilder {
    /// Starts a builder for a source with the given estimated row count.
    pub fn new(estimated_rows: usize) -> Self {
        QueryPlanBuilder {
            plan: QueryPlan::new(estimated_rows),
        }
    }

    /// Appends a projection onto `columns`.
    pub fn select(mut self, columns: Vec<String>) -> Self {
        self.plan.add_operation(OptimizableOp::Select(columns));
        self
    }

    /// Appends a filter on `column`; its selectivity is estimated from any
    /// statistics registered for the column, defaulting to 0.5.
    pub fn filter(mut self, column: String, op: FilterOp) -> Self {
        let selectivity = match self.plan.column_stats.get(&column) {
            Some(stats) => stats.estimate_selectivity(&op),
            None => 0.5,
        };
        let filter = OptimizableOp::Filter {
            column,
            op,
            selectivity,
        };
        self.plan.add_operation(filter);
        self
    }

    /// Appends a grouped aggregation.
    pub fn aggregate(
        mut self,
        group_by: Vec<String>,
        aggregates: Vec<(String, AggregateFunc)>,
    ) -> Self {
        let agg = OptimizableOp::Aggregate {
            group_by,
            aggregates,
        };
        self.plan.add_operation(agg);
        self
    }

    /// Appends a sort; `ascending[i]` gives the direction for `columns[i]`.
    pub fn sort(mut self, columns: Vec<String>, ascending: Vec<bool>) -> Self {
        self.plan
            .add_operation(OptimizableOp::Sort { columns, ascending });
        self
    }

    /// Appends a LIMIT of `n` rows.
    pub fn limit(mut self, n: usize) -> Self {
        self.plan.add_operation(OptimizableOp::Limit(n));
        self
    }

    /// Registers column statistics used for selectivity estimation.
    pub fn with_stats(mut self, stats: ColumnStats) -> Self {
        self.plan.set_column_stats(stats);
        self
    }

    /// Finishes the builder and returns the assembled plan.
    pub fn build(self) -> QueryPlan {
        self.plan
    }
}
/// Human-readable rendering of a query plan.
pub trait Explainable {
    /// Renders the operation pipeline.
    fn explain(&self) -> String;
    /// Renders the pipeline plus the registered column statistics.
    fn explain_analyze(&self) -> String;
}
impl Explainable for QueryPlan {
fn explain(&self) -> String {
let mut output = String::new();
output.push_str("Query Plan:\n");
output.push_str(&format!(" Estimated rows: {}\n", self.estimated_rows));
output.push_str(" Operations:\n");
for (i, op) in self.operations.iter().enumerate() {
let op_str = match op {
OptimizableOp::Select(cols) => format!("SELECT [{}]", cols.join(", ")),
OptimizableOp::Filter {
column,
selectivity,
..
} => {
format!(
"FILTER {} (selectivity: {:.2}%)",
column,
selectivity * 100.0
)
}
OptimizableOp::Aggregate {
group_by,
aggregates,
} => {
let agg_str: Vec<String> = aggregates
.iter()
.map(|(col, func)| format!("{:?}({})", func, col))
.collect();
if group_by.is_empty() {
format!("AGGREGATE {}", agg_str.join(", "))
} else {
format!(
"AGGREGATE BY [{}] => {}",
group_by.join(", "),
agg_str.join(", ")
)
}
}
OptimizableOp::Sort { columns, ascending } => {
let sort_str: Vec<String> = columns
.iter()
.zip(ascending)
.map(|(c, asc)| format!("{} {}", c, if *asc { "ASC" } else { "DESC" }))
.collect();
format!("SORT BY {}", sort_str.join(", "))
}
OptimizableOp::Join {
join_type,
left_key,
right_key,
..
} => {
format!("{:?} JOIN ON {} = {}", join_type, left_key, right_key)
}
OptimizableOp::Map {
input_columns,
output_column,
} => {
format!("MAP [{}] -> {}", input_columns.join(", "), output_column)
}
OptimizableOp::Limit(n) => format!("LIMIT {}", n),
OptimizableOp::Offset(n) => format!("OFFSET {}", n),
};
output.push_str(&format!(" {}. {}\n", i + 1, op_str));
}
output
}
fn explain_analyze(&self) -> String {
let mut output = self.explain();
output.push_str("\nColumn Statistics:\n");
for (name, stats) in &self.column_stats {
output.push_str(&format!(" {}:\n", name));
output.push_str(&format!(" - Distinct: {}\n", stats.distinct_count));
output.push_str(&format!(" - Nulls: {}\n", stats.null_count));
if let Some(min) = stats.min_value {
output.push_str(&format!(" - Min: {}\n", min));
}
if let Some(max) = stats.max_value {
output.push_str(&format!(" - Max: {}\n", max));
}
}
output
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records operations in order and keeps the row estimate.
    #[test]
    fn test_query_plan_builder() {
        let plan = QueryPlanBuilder::new(1000)
            .select(vec!["a".to_string(), "b".to_string()])
            .filter("a".to_string(), FilterOp::GreaterThan(5.0))
            .aggregate(
                vec!["a".to_string()],
                vec![("b".to_string(), AggregateFunc::Sum)],
            )
            .build();
        assert_eq!(plan.operations.len(), 3);
        assert_eq!(plan.estimated_rows, 1000);
    }

    /// Pushdown should hoist the filter ahead of the select.
    #[test]
    fn test_predicate_pushdown() {
        let mut stats = ColumnStats::new("price".to_string(), 10000);
        stats.min_value = Some(0.0);
        stats.max_value = Some(1000.0);
        let plan = QueryPlanBuilder::new(10000)
            .with_stats(stats)
            .select(vec!["name".to_string(), "price".to_string()])
            .filter("price".to_string(), FilterOp::GreaterThan(500.0))
            .build();
        let mut optimizer = QueryOptimizer::new(OptimizationLevel::Aggressive);
        let optimized = optimizer.optimize(plan).expect("operation should succeed");
        assert!(matches!(
            optimized.operations[0],
            OptimizableOp::Filter { .. }
        ));
    }

    /// Equality selectivity ~ 1/distinct; a midpoint range test ~ 0.5.
    #[test]
    fn test_selectivity_estimation() {
        let mut stats = ColumnStats::new("category".to_string(), 1000);
        stats.distinct_count = 10;
        stats.min_value = Some(0.0);
        stats.max_value = Some(100.0);
        let eq_sel = stats.estimate_selectivity(&FilterOp::Equals(5.0));
        assert!(eq_sel < 0.2);
        let range_sel = stats.estimate_selectivity(&FilterOp::GreaterThan(50.0));
        assert!((range_sel - 0.5).abs() < 0.1);
    }

    /// Back-to-back selects collapse into a single operation.
    #[test]
    fn test_operation_fusion() {
        let plan = QueryPlanBuilder::new(1000)
            .select(vec!["a".to_string(), "b".to_string()])
            .select(vec!["a".to_string()])
            .build();
        let mut optimizer = QueryOptimizer::new(OptimizationLevel::Standard);
        let optimized = optimizer.optimize(plan).expect("operation should succeed");
        assert!(optimized.operations.len() < 2);
    }

    /// The rendered plan mentions every operation kind it contains.
    #[test]
    fn test_explain_plan() {
        let plan = QueryPlanBuilder::new(10000)
            .filter("price".to_string(), FilterOp::GreaterThan(100.0))
            .aggregate(
                vec!["category".to_string()],
                vec![("price".to_string(), AggregateFunc::Sum)],
            )
            .sort(vec!["price".to_string()], vec![false])
            .limit(10)
            .build();
        let explanation = plan.explain();
        assert!(explanation.contains("FILTER"));
        assert!(explanation.contains("AGGREGATE"));
        assert!(explanation.contains("SORT"));
        assert!(explanation.contains("LIMIT"));
    }

    /// At least one predicate should be counted during pushdown.
    #[test]
    fn test_optimizer_stats() {
        let plan = QueryPlanBuilder::new(1000)
            .filter("a".to_string(), FilterOp::Equals(1.0))
            .filter("b".to_string(), FilterOp::GreaterThan(10.0))
            .select(vec!["a".to_string(), "b".to_string()])
            .build();
        let mut optimizer = QueryOptimizer::new(OptimizationLevel::Aggressive);
        let _ = optimizer.optimize(plan).expect("operation should succeed");
        let stats = optimizer.stats();
        assert!(stats.predicates_pushed > 0);
    }

    /// The much more selective filter on `b` (1000 distinct values vs 2)
    /// should end up running before the filter on `a`.
    #[test]
    fn test_cost_based_reorder() {
        let mut stats_a = ColumnStats::new("a".to_string(), 10000);
        stats_a.distinct_count = 2;
        let mut stats_b = ColumnStats::new("b".to_string(), 10000);
        stats_b.distinct_count = 1000;
        let plan = QueryPlanBuilder::new(10000)
            .with_stats(stats_a)
            .with_stats(stats_b)
            .filter("a".to_string(), FilterOp::Equals(1.0))
            .filter("b".to_string(), FilterOp::Equals(500.0))
            .build();
        let mut optimizer = QueryOptimizer::new(OptimizationLevel::Aggressive);
        let optimized = optimizer.optimize(plan).expect("operation should succeed");
        if let OptimizableOp::Filter { column, .. } = &optimized.operations[0] {
            assert_eq!(column, "b");
        }
    }
}