oxigdal_query/optimizer/
mod.rs1pub mod cost_model;
4pub mod rules;
5
6use crate::error::{QueryError, Result};
7use crate::parser::ast::*;
8use cost_model::CostModel;
9use oxigdal_core::error::OxiGdalError;
10use rules::optimize_with_rules;
11use serde::{Deserialize, Serialize};
12
13pub struct Optimizer {
15 cost_model: CostModel,
17 _config: OptimizerConfig,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct OptimizerConfig {
24 pub max_passes: usize,
26 pub enable_predicate_pushdown: bool,
28 pub enable_join_reordering: bool,
30 pub enable_constant_folding: bool,
32 pub enable_cse: bool,
34 pub enable_filter_fusion: bool,
36 pub enable_projection_pushdown: bool,
38}
39
40impl Default for OptimizerConfig {
41 fn default() -> Self {
42 Self {
43 max_passes: 10,
44 enable_predicate_pushdown: true,
45 enable_join_reordering: true,
46 enable_constant_folding: true,
47 enable_cse: true,
48 enable_filter_fusion: true,
49 enable_projection_pushdown: true,
50 }
51 }
52}
53
54impl Optimizer {
55 pub fn new() -> Self {
57 Self::with_config(OptimizerConfig::default())
58 }
59
60 pub fn with_config(_config: OptimizerConfig) -> Self {
62 Self {
63 cost_model: CostModel::new(),
64 _config,
65 }
66 }
67
68 pub fn cost_model(&self) -> &CostModel {
70 &self.cost_model
71 }
72
73 pub fn optimize(&self, stmt: Statement) -> Result<OptimizedQuery> {
75 match stmt {
76 Statement::Select(select) => {
77 let original_cost = self.estimate_cost(&select);
78
79 if !original_cost.total().is_finite() || original_cost.total() < 0.0 {
81 return Err(QueryError::optimization(
82 OxiGdalError::invalid_state_builder(
83 "Invalid cost estimation for original query",
84 )
85 .with_operation("cost_estimation")
86 .with_parameter("estimated_cost", original_cost.total().to_string())
87 .with_suggestion("Query may be too complex or contain invalid operations")
88 .build()
89 .to_string(),
90 ));
91 }
92
93 let optimized = optimize_with_rules(select)?;
95
96 let optimized_cost = self.estimate_cost(&optimized);
97
98 if !optimized_cost.total().is_finite() || optimized_cost.total() < 0.0 {
100 return Err(QueryError::optimization(
101 OxiGdalError::invalid_state_builder(
102 "Invalid cost estimation after optimization",
103 )
104 .with_operation("optimization")
105 .with_parameter("original_cost", original_cost.total().to_string())
106 .with_parameter("optimized_cost", optimized_cost.total().to_string())
107 .with_suggestion("Optimization may have introduced invalid transformations")
108 .build()
109 .to_string(),
110 ));
111 }
112
113 Ok(OptimizedQuery {
114 statement: Statement::Select(optimized),
115 original_cost,
116 optimized_cost,
117 })
118 }
119 }
120 }
121
122 fn estimate_cost(&self, stmt: &SelectStatement) -> cost_model::Cost {
124 let mut total_cost = cost_model::Cost::zero();
125
126 if let Some(ref table_ref) = stmt.from {
128 total_cost = total_cost.add(&self.estimate_table_cost(table_ref));
129 }
130
131 if stmt.selection.is_some() {
133 total_cost = total_cost.add(&cost_model::Cost::new(1000.0, 0.0, 0.0, 0.0));
135 }
136
137 if !stmt.group_by.is_empty() {
139 total_cost = total_cost.add(&self.cost_model.aggregate_cost(1000, 100));
140 }
141
142 if !stmt.order_by.is_empty() {
144 total_cost = total_cost.add(&self.cost_model.sort_cost(1000));
145 }
146
147 total_cost
148 }
149
150 fn estimate_table_cost(&self, table_ref: &TableReference) -> cost_model::Cost {
152 match table_ref {
153 TableReference::Table { name, .. } => self.cost_model.scan_cost(name),
154 TableReference::Join {
155 left,
156 right,
157 join_type,
158 ..
159 } => {
160 let left_cost = self.estimate_table_cost(left);
161 let right_cost = self.estimate_table_cost(right);
162 let join_cost = self.cost_model.join_cost(1000, 1000, *join_type);
163 left_cost.add(&right_cost).add(&join_cost)
164 }
165 TableReference::Subquery { query, .. } => self.estimate_cost(query),
166 }
167 }
168}
169
170impl Default for Optimizer {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct OptimizedQuery {
179 pub statement: Statement,
181 pub original_cost: cost_model::Cost,
183 pub optimized_cost: cost_model::Cost,
185}
186
187impl OptimizedQuery {
188 pub fn improvement_ratio(&self) -> f64 {
190 let original = self.original_cost.total();
191 let optimized = self.optimized_cost.total();
192 if original > 0.0 {
193 (original - optimized) / original
194 } else {
195 0.0
196 }
197 }
198
199 pub fn speedup_factor(&self) -> f64 {
201 let original = self.original_cost.total();
202 let optimized = self.optimized_cost.total();
203 if optimized > 0.0 {
204 original / optimized
205 } else {
206 1.0
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::sql::parse_sql;

    /// A freshly created optimizer carries the default configuration,
    /// which has constant folding enabled.
    #[test]
    fn test_optimizer_creation() {
        let subject = Optimizer::new();
        let folding_enabled = subject._config.enable_constant_folding;
        assert!(folding_enabled);
    }

    /// Optimizing a simple parsed query succeeds and yields non-negative
    /// cost estimates on both sides of the rewrite.
    #[test]
    fn test_optimize_simple_query() -> Result<()> {
        let parsed = parse_sql("SELECT id, name FROM users WHERE 1 = 1")?;
        let outcome = Optimizer::new().optimize(parsed)?;

        let before = outcome.original_cost.total();
        let after = outcome.optimized_cost.total();
        assert!(before >= 0.0);
        assert!(after >= 0.0);

        Ok(())
    }

    /// A bare `SELECT * FROM users` has a strictly positive estimated cost.
    #[test]
    fn test_cost_estimation() {
        let users_table = TableReference::Table {
            name: "users".to_string(),
            alias: None,
        };
        let select_all = SelectStatement {
            projection: vec![SelectItem::Wildcard],
            from: Some(users_table),
            selection: None,
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        };

        let estimate = Optimizer::new().estimate_cost(&select_all);
        assert!(estimate.total() > 0.0);
    }
}