pub mod cost_model;
pub mod rules;
use crate::error::{QueryError, Result};
use crate::parser::ast::*;
use cost_model::CostModel;
use oxigdal_core::error::OxiGdalError;
use rules::optimize_with_rules;
use serde::{Deserialize, Serialize};
/// Query optimizer that pairs a [`CostModel`] with an [`OptimizerConfig`].
///
/// Its `optimize` method estimates the cost of a statement, rewrites it via
/// `rules::optimize_with_rules`, and estimates the cost again.
pub struct Optimizer {
/// Cost model used to produce scan, join, aggregate and sort cost estimates.
cost_model: CostModel,
/// Configuration supplied at construction time.
/// NOTE(review): within this file the value is stored but never consulted by
/// `optimize`; the leading underscore presumably silences the unused-field
/// lint — confirm whether the rules are meant to honour it.
_config: OptimizerConfig,
}
/// Settings for an [`Optimizer`].
///
/// The `Default` implementation sets `max_passes` to `10` and every
/// `enable_*` flag to `true`.
/// NOTE(review): nothing in this file reads these fields outside of a unit
/// test; they presumably gate individual rewrites in `rules` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig {
/// Pass limit (default `10`).
/// NOTE(review): presumably caps rule-application iterations — confirm.
pub max_passes: usize,
/// Flag for predicate pushdown (default `true`).
pub enable_predicate_pushdown: bool,
/// Flag for join reordering (default `true`).
pub enable_join_reordering: bool,
/// Flag for constant folding (default `true`).
pub enable_constant_folding: bool,
/// Flag for CSE (default `true`).
/// NOTE(review): presumably common-subexpression elimination — confirm.
pub enable_cse: bool,
/// Flag for filter fusion (default `true`).
pub enable_filter_fusion: bool,
/// Flag for projection pushdown (default `true`).
pub enable_projection_pushdown: bool,
}
impl Default for OptimizerConfig {
fn default() -> Self {
Self {
max_passes: 10,
enable_predicate_pushdown: true,
enable_join_reordering: true,
enable_constant_folding: true,
enable_cse: true,
enable_filter_fusion: true,
enable_projection_pushdown: true,
}
}
}
impl Optimizer {
    /// Creates an optimizer that uses [`OptimizerConfig::default`].
    pub fn new() -> Self {
        Self::with_config(OptimizerConfig::default())
    }

    /// Creates an optimizer that stores the supplied configuration next to a
    /// freshly constructed [`CostModel`].
    pub fn with_config(_config: OptimizerConfig) -> Self {
        let cost_model = CostModel::new();
        Self {
            cost_model,
            _config,
        }
    }

    /// Returns a shared reference to the cost model backing the estimates.
    pub fn cost_model(&self) -> &CostModel {
        &self.cost_model
    }

    /// Rewrites `stmt` with the rule set and reports the estimated cost before
    /// and after the rewrite.
    ///
    /// The stored configuration is not consulted here; `optimize_with_rules`
    /// is invoked with the statement only.
    ///
    /// # Errors
    ///
    /// Returns a `QueryError::optimization` error when either the original or
    /// the optimized cost estimate is non-finite or negative, and propagates
    /// any error produced by `optimize_with_rules`.
    pub fn optimize(&self, stmt: Statement) -> Result<OptimizedQuery> {
        match stmt {
            Statement::Select(select) => {
                // Reject unusable estimates before doing any rewriting.
                let original_cost = self.estimate_cost(&select);
                if !Self::cost_is_valid(&original_cost) {
                    let detail = OxiGdalError::invalid_state_builder(
                        "Invalid cost estimation for original query",
                    )
                    .with_operation("cost_estimation")
                    .with_parameter("estimated_cost", original_cost.total().to_string())
                    .with_suggestion("Query may be too complex or contain invalid operations")
                    .build()
                    .to_string();
                    return Err(QueryError::optimization(detail));
                }

                let optimized = optimize_with_rules(select)?;

                // The rewritten statement must yield a usable estimate as well.
                let optimized_cost = self.estimate_cost(&optimized);
                if !Self::cost_is_valid(&optimized_cost) {
                    let detail = OxiGdalError::invalid_state_builder(
                        "Invalid cost estimation after optimization",
                    )
                    .with_operation("optimization")
                    .with_parameter("original_cost", original_cost.total().to_string())
                    .with_parameter("optimized_cost", optimized_cost.total().to_string())
                    .with_suggestion("Optimization may have introduced invalid transformations")
                    .build()
                    .to_string();
                    return Err(QueryError::optimization(detail));
                }

                Ok(OptimizedQuery {
                    statement: Statement::Select(optimized),
                    original_cost,
                    optimized_cost,
                })
            }
        }
    }

    /// Reports whether a cost total is usable, i.e. finite and not negative.
    fn cost_is_valid(cost: &cost_model::Cost) -> bool {
        let total = cost.total();
        total.is_finite() && total >= 0.0
    }

    /// Estimates the cost of a `SELECT` statement.
    ///
    /// Only the `from`, `selection`, `group_by` and `order_by` parts
    /// contribute. The filter, aggregate and sort terms are computed from
    /// literal constants rather than from anything measured on the statement.
    fn estimate_cost(&self, stmt: &SelectStatement) -> cost_model::Cost {
        let mut running = cost_model::Cost::zero();

        if let Some(source) = &stmt.from {
            let source_cost = self.estimate_table_cost(source);
            running = running.add(&source_cost);
        }

        if stmt.selection.is_some() {
            // Flat charge for having a filter at all.
            let filter_cost = cost_model::Cost::new(1000.0, 0.0, 0.0, 0.0);
            running = running.add(&filter_cost);
        }

        if !stmt.group_by.is_empty() {
            let aggregate_cost = self.cost_model.aggregate_cost(1000, 100);
            running = running.add(&aggregate_cost);
        }

        if !stmt.order_by.is_empty() {
            let sort_cost = self.cost_model.sort_cost(1000);
            running = running.add(&sort_cost);
        }

        running
    }

    /// Estimates the cost of producing the rows of a table reference.
    ///
    /// Plain tables use the cost model's scan cost, subqueries recurse into
    /// [`Self::estimate_cost`], and joins add the cost of both inputs to a join
    /// cost computed from the literal arguments `1000` and `1000`.
    fn estimate_table_cost(&self, table_ref: &TableReference) -> cost_model::Cost {
        match table_ref {
            TableReference::Subquery { query, .. } => self.estimate_cost(query),
            TableReference::Table { name, .. } => self.cost_model.scan_cost(name),
            TableReference::Join {
                left,
                right,
                join_type,
                ..
            } => {
                let inputs = self
                    .estimate_table_cost(left)
                    .add(&self.estimate_table_cost(right));
                inputs.add(&self.cost_model.join_cost(1000, 1000, *join_type))
            }
        }
    }
}
impl Default for Optimizer {
fn default() -> Self {
Self::new()
}
}
/// Result of `Optimizer::optimize`: the rewritten statement together with the
/// cost estimates taken before and after the rewrite.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizedQuery {
/// The statement produced by the rule-based rewrite.
pub statement: Statement,
/// Cost estimated for the statement before it was rewritten.
pub original_cost: cost_model::Cost,
/// Cost estimated for the rewritten statement.
pub optimized_cost: cost_model::Cost,
}
impl OptimizedQuery {
    /// Fraction of the original cost total removed by optimization:
    /// `(original - optimized) / original`.
    ///
    /// Returns `0.0` when the original total is not strictly positive
    /// (this includes zero and NaN), which avoids dividing by zero.
    pub fn improvement_ratio(&self) -> f64 {
        match self.original_cost.total() {
            before if before > 0.0 => (before - self.optimized_cost.total()) / before,
            _ => 0.0,
        }
    }

    /// Ratio of the original cost total to the optimized cost total.
    ///
    /// Returns `1.0` when the optimized total is not strictly positive
    /// (this includes zero and NaN), which avoids dividing by zero.
    pub fn speedup_factor(&self) -> f64 {
        match self.optimized_cost.total() {
            after if after > 0.0 => self.original_cost.total() / after,
            _ => 1.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::sql::parse_sql;

    /// An optimizer built with `new` carries a configuration in which
    /// constant folding is enabled.
    #[test]
    fn test_optimizer_creation() {
        let subject = Optimizer::new();
        let folding_enabled = subject._config.enable_constant_folding;
        assert!(folding_enabled);
    }

    /// Optimizing a parsed query yields non-negative cost totals on both
    /// sides of the rewrite.
    #[test]
    fn test_optimize_simple_query() -> Result<()> {
        let statement = parse_sql("SELECT id, name FROM users WHERE 1 = 1")?;
        let outcome = Optimizer::new().optimize(statement)?;
        assert!(outcome.original_cost.total() >= 0.0);
        assert!(outcome.optimized_cost.total() >= 0.0);
        Ok(())
    }

    /// A bare wildcard select over one table has a strictly positive
    /// estimated cost.
    #[test]
    fn test_cost_estimation() {
        let users_table = TableReference::Table {
            name: String::from("users"),
            alias: None,
        };
        let query = SelectStatement {
            projection: vec![SelectItem::Wildcard],
            from: Some(users_table),
            selection: None,
            group_by: vec![],
            having: None,
            order_by: vec![],
            limit: None,
            offset: None,
        };

        let estimate = Optimizer::new().estimate_cost(&query);
        assert!(estimate.total() > 0.0);
    }
}