Skip to main content

oxigdal_query/optimizer/
mod.rs

1//! Query optimizer.
2
3pub mod cost_model;
4pub mod rules;
5
6use crate::error::{QueryError, Result};
7use crate::parser::ast::*;
8use cost_model::CostModel;
9use oxigdal_core::error::OxiGdalError;
10use rules::optimize_with_rules;
11use serde::{Deserialize, Serialize};
12
13/// Query optimizer.
14pub struct Optimizer {
15    /// Cost model for estimating query costs.
16    cost_model: CostModel,
17    /// Configuration.
18    _config: OptimizerConfig,
19}
20
21/// Optimizer configuration.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct OptimizerConfig {
24    /// Maximum optimization passes.
25    pub max_passes: usize,
26    /// Enable predicate pushdown.
27    pub enable_predicate_pushdown: bool,
28    /// Enable join reordering.
29    pub enable_join_reordering: bool,
30    /// Enable constant folding.
31    pub enable_constant_folding: bool,
32    /// Enable common subexpression elimination.
33    pub enable_cse: bool,
34    /// Enable filter fusion.
35    pub enable_filter_fusion: bool,
36    /// Enable projection pushdown.
37    pub enable_projection_pushdown: bool,
38}
39
40impl Default for OptimizerConfig {
41    fn default() -> Self {
42        Self {
43            max_passes: 10,
44            enable_predicate_pushdown: true,
45            enable_join_reordering: true,
46            enable_constant_folding: true,
47            enable_cse: true,
48            enable_filter_fusion: true,
49            enable_projection_pushdown: true,
50        }
51    }
52}
53
54impl Optimizer {
55    /// Create a new optimizer with default configuration.
56    pub fn new() -> Self {
57        Self::with_config(OptimizerConfig::default())
58    }
59
60    /// Create a new optimizer with custom configuration.
61    pub fn with_config(_config: OptimizerConfig) -> Self {
62        Self {
63            cost_model: CostModel::new(),
64            _config,
65        }
66    }
67
68    /// Get the cost model.
69    pub fn cost_model(&self) -> &CostModel {
70        &self.cost_model
71    }
72
73    /// Optimize a query.
74    pub fn optimize(&self, stmt: Statement) -> Result<OptimizedQuery> {
75        match stmt {
76            Statement::Select(select) => {
77                let original_cost = self.estimate_cost(&select);
78
79                // Validate original cost is reasonable
80                if !original_cost.total().is_finite() || original_cost.total() < 0.0 {
81                    return Err(QueryError::optimization(
82                        OxiGdalError::invalid_state_builder(
83                            "Invalid cost estimation for original query",
84                        )
85                        .with_operation("cost_estimation")
86                        .with_parameter("estimated_cost", original_cost.total().to_string())
87                        .with_suggestion("Query may be too complex or contain invalid operations")
88                        .build()
89                        .to_string(),
90                    ));
91                }
92
93                // Apply rule-based optimization
94                let optimized = optimize_with_rules(select)?;
95
96                let optimized_cost = self.estimate_cost(&optimized);
97
98                // Validate optimized cost
99                if !optimized_cost.total().is_finite() || optimized_cost.total() < 0.0 {
100                    return Err(QueryError::optimization(
101                        OxiGdalError::invalid_state_builder(
102                            "Invalid cost estimation after optimization",
103                        )
104                        .with_operation("optimization")
105                        .with_parameter("original_cost", original_cost.total().to_string())
106                        .with_parameter("optimized_cost", optimized_cost.total().to_string())
107                        .with_suggestion("Optimization may have introduced invalid transformations")
108                        .build()
109                        .to_string(),
110                    ));
111                }
112
113                Ok(OptimizedQuery {
114                    statement: Statement::Select(optimized),
115                    original_cost,
116                    optimized_cost,
117                })
118            }
119        }
120    }
121
122    /// Estimate the cost of executing a select statement.
123    fn estimate_cost(&self, stmt: &SelectStatement) -> cost_model::Cost {
124        let mut total_cost = cost_model::Cost::zero();
125
126        // Estimate FROM clause cost
127        if let Some(ref table_ref) = stmt.from {
128            total_cost = total_cost.add(&self.estimate_table_cost(table_ref));
129        }
130
131        // Estimate WHERE clause cost
132        if stmt.selection.is_some() {
133            // Add filter cost
134            total_cost = total_cost.add(&cost_model::Cost::new(1000.0, 0.0, 0.0, 0.0));
135        }
136
137        // Estimate GROUP BY cost
138        if !stmt.group_by.is_empty() {
139            total_cost = total_cost.add(&self.cost_model.aggregate_cost(1000, 100));
140        }
141
142        // Estimate ORDER BY cost
143        if !stmt.order_by.is_empty() {
144            total_cost = total_cost.add(&self.cost_model.sort_cost(1000));
145        }
146
147        total_cost
148    }
149
150    /// Estimate cost of a table reference.
151    fn estimate_table_cost(&self, table_ref: &TableReference) -> cost_model::Cost {
152        match table_ref {
153            TableReference::Table { name, .. } => self.cost_model.scan_cost(name),
154            TableReference::Join {
155                left,
156                right,
157                join_type,
158                ..
159            } => {
160                let left_cost = self.estimate_table_cost(left);
161                let right_cost = self.estimate_table_cost(right);
162                let join_cost = self.cost_model.join_cost(1000, 1000, *join_type);
163                left_cost.add(&right_cost).add(&join_cost)
164            }
165            TableReference::Subquery { query, .. } => self.estimate_cost(query),
166        }
167    }
168}
169
170impl Default for Optimizer {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176/// An optimized query with cost information.
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct OptimizedQuery {
179    /// The optimized statement.
180    pub statement: Statement,
181    /// Original cost estimate.
182    pub original_cost: cost_model::Cost,
183    /// Optimized cost estimate.
184    pub optimized_cost: cost_model::Cost,
185}
186
187impl OptimizedQuery {
188    /// Get the improvement ratio.
189    pub fn improvement_ratio(&self) -> f64 {
190        let original = self.original_cost.total();
191        let optimized = self.optimized_cost.total();
192        if original > 0.0 {
193            (original - optimized) / original
194        } else {
195            0.0
196        }
197    }
198
199    /// Get the speedup factor.
200    pub fn speedup_factor(&self) -> f64 {
201        let original = self.original_cost.total();
202        let optimized = self.optimized_cost.total();
203        if optimized > 0.0 {
204            original / optimized
205        } else {
206            1.0
207        }
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::sql::parse_sql;

    /// A freshly created optimizer carries a configuration with constant folding on.
    #[test]
    fn test_optimizer_creation() {
        let subject = Optimizer::new();
        let folding_enabled = subject._config.enable_constant_folding;
        assert!(folding_enabled);
    }

    /// Optimizing a simple parsed query succeeds and yields non-negative costs.
    #[test]
    fn test_optimize_simple_query() -> Result<()> {
        let parsed = parse_sql("SELECT id, name FROM users WHERE 1 = 1")?;
        let outcome = Optimizer::new().optimize(parsed)?;

        let before = outcome.original_cost.total();
        let after = outcome.optimized_cost.total();
        assert!(before >= 0.0);
        assert!(after >= 0.0);

        Ok(())
    }

    /// A bare `SELECT * FROM users` has a strictly positive estimated cost.
    #[test]
    fn test_cost_estimation() {
        let users_table = TableReference::Table {
            name: "users".to_string(),
            alias: None,
        };
        let select_all = SelectStatement {
            projection: vec![SelectItem::Wildcard],
            from: Some(users_table),
            selection: None,
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        };

        let estimate = Optimizer::new().estimate_cost(&select_all);
        assert!(estimate.total() > 0.0);
    }
}
256}