1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::schema::Schema;
11use crate::traversal::ExpressionWalk;
12
13use super::annotate_types::annotate_types;
14use super::canonicalize::canonicalize;
15use super::eliminate_ctes::eliminate_ctes;
16use super::normalize::normalize;
17use super::optimize_joins::optimize_joins;
18use super::pushdown_predicates::pushdown_predicates;
19use super::pushdown_projections::pushdown_projections;
20use super::qualify_columns::qualify_columns;
21use super::simplify::simplify;
22use super::subquery::{merge_subqueries, unnest_subqueries};
23
/// Settings shared by the optimizer passes dispatched from this module.
pub struct OptimizerConfig<'a> {
    /// Schema handed to column qualification and type annotation. When `None`, the
    /// `Qualify` rule returns its input unchanged.
    pub schema: Option<&'a dyn Schema>,
    /// NOTE(review): presumably the default database name; nothing visible in this
    /// file reads it — confirm where it is consumed.
    pub db: Option<String>,
    /// NOTE(review): presumably the default catalog name; nothing visible in this
    /// file reads it — confirm where it is consumed.
    pub catalog: Option<String>,
    /// SQL dialect forwarded to the dialect-aware passes (qualification, projection and
    /// predicate pushdown, type annotation, canonicalization, simplification).
    pub dialect: Option<DialectType>,
    /// Forwarded as the second argument of `merge_subqueries`.
    /// NOTE(review): exact effect is defined by that pass — confirm there.
    pub isolate_tables: bool,
    /// NOTE(review): presumably toggles identifier quoting; nothing visible in this
    /// file reads it — confirm where it is consumed.
    pub quote_identifiers: bool,
}
39
40impl<'a> Default for OptimizerConfig<'a> {
41 fn default() -> Self {
42 Self {
43 schema: None,
44 db: None,
45 catalog: None,
46 dialect: None,
47 isolate_tables: true,
48 quote_identifiers: false,
49 }
50 }
51}
52
/// Identifies one optimizer pass that can be scheduled by `optimize_with_rules`.
///
/// The type is a plain, field-less tag: it is `Copy`, comparable, and hashable so it
/// can be stored in slices, sets, and map keys alike.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationRule {
    /// Runs `qualify_columns` against the configured schema; without a schema the
    /// expression is left untouched.
    Qualify,
    /// Runs `pushdown_projections`.
    PushdownProjections,
    /// Runs `normalize`; if normalization fails the original expression is kept.
    Normalize,
    /// Runs `unnest_subqueries`.
    UnnestSubqueries,
    /// Runs `pushdown_predicates`.
    PushdownPredicates,
    /// Runs `optimize_joins`.
    OptimizeJoins,
    /// Runs `eliminate_subqueries`.
    EliminateSubqueries,
    /// Runs `merge_subqueries`.
    MergeSubqueries,
    /// Runs `eliminate_ctes`.
    EliminateCtes,
    /// Runs `annotate_types`; its result is discarded and the expression is returned
    /// unchanged.
    AnnotateTypes,
    /// Runs `canonicalize`.
    Canonicalize,
    /// Runs `simplify`.
    Simplify,
}
81
/// The rule sequence used by `optimize`, listed in the order the rules are applied.
///
/// `optimize_with_rules` compares the rule slice it receives against this list by
/// value; only an equal list is subject to the skip-everything and quick-path guards.
pub const DEFAULT_RULES: &[OptimizationRule] = &[
    OptimizationRule::Qualify,
    OptimizationRule::PushdownProjections,
    OptimizationRule::Normalize,
    OptimizationRule::UnnestSubqueries,
    OptimizationRule::PushdownPredicates,
    OptimizationRule::OptimizeJoins,
    OptimizationRule::EliminateSubqueries,
    OptimizationRule::MergeSubqueries,
    OptimizationRule::EliminateCtes,
    OptimizationRule::AnnotateTypes,
    OptimizationRule::Canonicalize,
    OptimizationRule::Simplify,
];
97
/// Reduced rule set: used by `quick_optimize`, and substituted for `DEFAULT_RULES`
/// when `should_use_quick_path` holds.
const QUICK_RULES: &[OptimizationRule] =
    &[OptimizationRule::Simplify, OptimizationRule::Canonicalize];
/// Tree depth at or above which the default pipeline is skipped entirely.
/// NOTE(review): presumably guards recursive passes against stack exhaustion — confirm.
const FAST_PATH_MAX_DEPTH: usize = 768;
/// Number of AND/OR nodes at or above which the default pipeline falls back to
/// `QUICK_RULES`.
const FAST_PATH_MAX_CONNECTORS: usize = 10_000;
/// AND/OR chain depth at or above which the default pipeline is skipped entirely
/// (the same bound also appears in the quick-path check).
const FAST_PATH_MAX_CONNECTOR_DEPTH: usize = 1024;
/// Total node count at or above which the default pipeline falls back to `QUICK_RULES`.
const FAST_PATH_MAX_NODES: usize = 50_000;
/// Total node count at or above which `Qualify` and `Normalize` are skipped; both
/// clone the whole expression before running so they can fall back to it on failure.
const CLONE_HEAVY_RULE_SKIP_NODES: usize = 20_000;
105
/// Size and shape measurements of an expression tree, as produced by
/// `analyze_expression_complexity`.
#[derive(Debug, Clone, Copy)]
struct ExpressionComplexity {
    /// Total number of nodes visited.
    node_count: usize,
    /// Greatest distance from the root to any visited node (the root is at depth 0).
    max_depth: usize,
    /// Number of `And` / `Or` nodes.
    connector_count: usize,
    /// Longest run of nested `And` / `Or` nodes; parentheses do not break a run, any
    /// other node kind does.
    max_connector_depth: usize,
}
113
114pub fn optimize(expression: Expression, config: &OptimizerConfig<'_>) -> Expression {
126 optimize_with_rules(expression, config, DEFAULT_RULES)
127}
128
129pub fn optimize_with_rules(
139 mut expression: Expression,
140 config: &OptimizerConfig<'_>,
141 rules: &[OptimizationRule],
142) -> Expression {
143 let complexity = analyze_expression_complexity(&expression);
144 if rules == DEFAULT_RULES && should_skip_all_optimization(&complexity) {
145 return expression;
146 }
147
148 let active_rules = if rules == DEFAULT_RULES && should_use_quick_path(&complexity) {
149 QUICK_RULES
150 } else {
151 rules
152 };
153
154 for rule in active_rules {
155 if complexity.node_count >= CLONE_HEAVY_RULE_SKIP_NODES
156 && matches!(
157 rule,
158 OptimizationRule::Qualify | OptimizationRule::Normalize
159 )
160 {
161 continue;
162 }
163 expression = apply_rule(expression, *rule, config);
164 }
165 expression
166}
167
168fn should_skip_all_optimization(complexity: &ExpressionComplexity) -> bool {
169 complexity.max_depth >= FAST_PATH_MAX_DEPTH
170 || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
171}
172
173fn should_use_quick_path(complexity: &ExpressionComplexity) -> bool {
174 complexity.connector_count >= FAST_PATH_MAX_CONNECTORS
175 || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
176 || complexity.node_count >= FAST_PATH_MAX_NODES
177}
178
179fn analyze_expression_complexity(expression: &Expression) -> ExpressionComplexity {
180 let mut node_count = 0usize;
181 let mut max_depth = 0usize;
182 let mut connector_count = 0usize;
183 let mut max_connector_depth = 0usize;
184 let mut stack: Vec<(&Expression, usize, usize)> = vec![(expression, 0, 0)];
185
186 while let Some((node, depth, connector_depth)) = stack.pop() {
187 node_count += 1;
188 max_depth = max_depth.max(depth);
189
190 match node {
191 Expression::And(op) | Expression::Or(op) => {
192 connector_count += 1;
193 let next_connector_depth = connector_depth + 1;
194 max_connector_depth = max_connector_depth.max(next_connector_depth);
195 stack.push((&op.right, depth + 1, next_connector_depth));
196 stack.push((&op.left, depth + 1, next_connector_depth));
197 }
198 Expression::Paren(paren) => {
199 stack.push((&paren.this, depth + 1, connector_depth));
200 }
201 _ => {
202 for child in node.children().into_iter().rev() {
203 stack.push((child, depth + 1, 0));
204 }
205 }
206 }
207 }
208
209 ExpressionComplexity {
210 node_count,
211 max_depth,
212 connector_count,
213 max_connector_depth,
214 }
215}
216
217fn apply_rule(
219 expression: Expression,
220 rule: OptimizationRule,
221 config: &OptimizerConfig<'_>,
222) -> Expression {
223 match rule {
224 OptimizationRule::Qualify => {
225 if let Some(schema) = config.schema {
227 let options = super::qualify_columns::QualifyColumnsOptions {
228 dialect: config.dialect,
229 ..Default::default()
230 };
231 let original = expression.clone();
232 qualify_columns(expression, schema, &options).unwrap_or(original)
233 } else {
234 expression
236 }
237 }
238 OptimizationRule::PushdownProjections => {
239 pushdown_projections(expression, config.dialect, true)
240 }
241 OptimizationRule::Normalize => {
242 let original = expression.clone();
244 normalize(expression, false, super::normalize::DEFAULT_MAX_DISTANCE).unwrap_or(original)
245 }
246 OptimizationRule::UnnestSubqueries => unnest_subqueries(expression),
247 OptimizationRule::PushdownPredicates => pushdown_predicates(expression, config.dialect),
248 OptimizationRule::OptimizeJoins => optimize_joins(expression),
249 OptimizationRule::EliminateSubqueries => eliminate_subqueries_opt(expression),
250 OptimizationRule::MergeSubqueries => merge_subqueries(expression, config.isolate_tables),
251 OptimizationRule::EliminateCtes => eliminate_ctes(expression),
252 OptimizationRule::AnnotateTypes => {
253 let _ = annotate_types(&expression, config.schema, config.dialect);
256 expression
257 }
258 OptimizationRule::Canonicalize => canonicalize(expression, config.dialect),
259 OptimizationRule::Simplify => simplify(expression, config.dialect),
260 }
261}
262
263use super::subquery::eliminate_subqueries as eliminate_subqueries_opt;
265
266pub fn quick_optimize(expression: Expression, dialect: Option<DialectType>) -> Expression {
270 let config = OptimizerConfig {
271 dialect,
272 ..Default::default()
273 };
274
275 optimize_with_rules(expression, &config, QUICK_RULES)
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns a copy of its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parses `sql`, runs the default pipeline with the default config, and renders
    /// the result.
    fn optimize_to_sql(sql: &str) -> String {
        let parsed = parse(sql);
        let config = OptimizerConfig::default();
        to_sql(&optimize(parsed, &config))
    }

    /// Builds `c0 = 0 AND c1 = 1 AND ...` with `terms` predicates, nested to the left.
    fn left_deep_and_chain(terms: i64) -> Expression {
        let seed = Expression::Eq(Box::new(crate::expressions::BinaryOp::new(
            Expression::column("c0"),
            Expression::number(0),
        )));

        (1..terms).fold(seed, |chain, i| {
            let predicate = Expression::Eq(Box::new(crate::expressions::BinaryOp::new(
                Expression::column(format!("c{i}")),
                Expression::number(i),
            )));
            Expression::And(Box::new(crate::expressions::BinaryOp::new(chain, predicate)))
        })
    }

    #[test]
    fn test_optimize_simple() {
        assert!(optimize_to_sql("SELECT a FROM t").contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_where() {
        assert!(optimize_to_sql("SELECT a FROM t WHERE b = 1").contains("WHERE"));
    }

    #[test]
    fn test_optimize_with_join() {
        assert!(optimize_to_sql("SELECT t.a FROM t JOIN s ON t.id = s.id").contains("JOIN"));
    }

    #[test]
    fn test_quick_optimize() {
        let optimized = quick_optimize(parse("SELECT 1 + 0 FROM t"), None);
        assert!(to_sql(&optimized).contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_custom_rules() {
        let rules = [OptimizationRule::Simplify];
        let config = OptimizerConfig::default();
        let parsed = parse("SELECT a FROM t WHERE NOT NOT b = 1");
        let optimized = optimize_with_rules(parsed, &config, &rules);
        assert!(to_sql(&optimized).contains("SELECT"));
    }

    #[test]
    fn test_optimizer_config_default() {
        let OptimizerConfig {
            schema,
            dialect,
            isolate_tables,
            quote_identifiers,
            ..
        } = OptimizerConfig::default();

        assert!(schema.is_none());
        assert!(dialect.is_none());
        assert!(isolate_tables);
        assert!(!quote_identifiers);
    }

    #[test]
    fn test_default_rules() {
        assert!(!DEFAULT_RULES.is_empty());
        for required in [OptimizationRule::Simplify, OptimizationRule::Canonicalize] {
            assert!(DEFAULT_RULES.contains(&required));
        }
    }

    #[test]
    fn test_optimize_subquery() {
        assert!(optimize_to_sql("SELECT * FROM (SELECT a FROM t) AS sub").contains("SELECT"));
    }

    #[test]
    fn test_optimize_cte() {
        let sql = optimize_to_sql("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_optimize_preserves_semantics() {
        let sql = optimize_to_sql("SELECT a, b FROM t WHERE c > 1 ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_analyze_expression_complexity_deep_connector_chain() {
        let chain = left_deep_and_chain(1500);
        let complexity = analyze_expression_complexity(&chain);
        assert!(complexity.max_connector_depth >= 1499);
        assert!(complexity.connector_count >= 1499);
    }

    #[test]
    fn test_optimize_handles_deep_connector_chain() {
        let chain = left_deep_and_chain(2200);
        let config = OptimizerConfig::default();
        let sql = to_sql(&optimize(chain, &config));
        assert!(sql.contains("c2199 = 2199"), "{sql}");
    }
}