1use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::schema::Schema;
11use crate::traversal::ExpressionWalk;
12
13use super::annotate_types::annotate_types;
14use super::canonicalize::canonicalize;
15use super::eliminate_ctes::eliminate_ctes;
16use super::eliminate_joins::eliminate_joins;
17use super::normalize::normalize;
18use super::optimize_joins::optimize_joins;
19use super::pushdown_predicates::pushdown_predicates;
20use super::pushdown_projections::pushdown_projections;
21use super::qualify_columns::{qualify_columns, quote_identifiers};
22use super::simplify::simplify;
23use super::subquery::{merge_subqueries, unnest_subqueries};
24
25pub struct OptimizerConfig<'a> {
27 pub schema: Option<&'a dyn Schema>,
29 pub db: Option<String>,
31 pub catalog: Option<String>,
33 pub dialect: Option<DialectType>,
35 pub isolate_tables: bool,
37 pub quote_identifiers: bool,
39}
40
41impl<'a> Default for OptimizerConfig<'a> {
42 fn default() -> Self {
43 Self {
44 schema: None,
45 db: None,
46 catalog: None,
47 dialect: None,
48 isolate_tables: true,
49 quote_identifiers: false,
50 }
51 }
52}
53
/// One optimizer pass that `optimize_with_rules` can run.
///
/// Each variant is dispatched by `apply_rule` to the pass named in its
/// description. Declaration order matches the order of `DEFAULT_RULES`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationRule {
    /// Runs `qualify_columns`; skipped when no schema is configured, and the
    /// input is kept if the pass reports an error.
    Qualify,
    /// Runs `pushdown_projections`.
    PushdownProjections,
    /// Runs `normalize`; the input is kept if the pass reports an error.
    Normalize,
    /// Runs `unnest_subqueries`.
    UnnestSubqueries,
    /// Runs `pushdown_predicates`.
    PushdownPredicates,
    /// Runs `optimize_joins`.
    OptimizeJoins,
    /// Runs `eliminate_subqueries`.
    EliminateSubqueries,
    /// Runs `merge_subqueries` with the configured `isolate_tables` flag.
    MergeSubqueries,
    /// Runs `eliminate_joins`.
    EliminateJoins,
    /// Runs `eliminate_ctes`.
    EliminateCtes,
    /// Runs `quote_identifiers`, but only when the configuration enables it.
    QuoteIdentifiers,
    /// Runs `annotate_types` in place on the expression.
    AnnotateTypes,
    /// Runs `canonicalize`.
    Canonicalize,
    /// Runs `simplify`.
    Simplify,
}
86
87pub const DEFAULT_RULES: &[OptimizationRule] = &[
89 OptimizationRule::Qualify,
90 OptimizationRule::PushdownProjections,
91 OptimizationRule::Normalize,
92 OptimizationRule::UnnestSubqueries,
93 OptimizationRule::PushdownPredicates,
94 OptimizationRule::OptimizeJoins,
95 OptimizationRule::EliminateSubqueries,
96 OptimizationRule::MergeSubqueries,
97 OptimizationRule::EliminateJoins,
98 OptimizationRule::EliminateCtes,
99 OptimizationRule::QuoteIdentifiers,
100 OptimizationRule::AnnotateTypes,
101 OptimizationRule::Canonicalize,
102 OptimizationRule::Simplify,
103];
104
105const QUICK_RULES: &[OptimizationRule] =
106 &[OptimizationRule::Simplify, OptimizationRule::Canonicalize];
107const FAST_PATH_MAX_DEPTH: usize = 768;
108const FAST_PATH_MAX_CONNECTORS: usize = 10_000;
109const FAST_PATH_MAX_CONNECTOR_DEPTH: usize = 1024;
110const FAST_PATH_MAX_NODES: usize = 50_000;
111const CLONE_HEAVY_RULE_SKIP_NODES: usize = 20_000;
112
/// Size measurements of an expression tree, gathered in one traversal by
/// `analyze_expression_complexity`.
#[derive(Debug, Clone, Copy)]
struct ExpressionComplexity {
    /// Number of nodes visited.
    node_count: usize,
    /// Largest distance from the root to any visited node (the root is 0).
    max_depth: usize,
    /// Number of `And` / `Or` nodes visited.
    connector_count: usize,
    /// Longest run of nested `And` / `Or` nodes; parentheses do not break a
    /// run, any other node kind resets it.
    max_connector_depth: usize,
}
120
121pub fn optimize(expression: Expression, config: &OptimizerConfig<'_>) -> Expression {
133 optimize_with_rules(expression, config, DEFAULT_RULES)
134}
135
136pub fn optimize_with_rules(
146 mut expression: Expression,
147 config: &OptimizerConfig<'_>,
148 rules: &[OptimizationRule],
149) -> Expression {
150 let complexity = analyze_expression_complexity(&expression);
151 if rules == DEFAULT_RULES && should_skip_all_optimization(&complexity) {
152 return expression;
153 }
154
155 let active_rules = if rules == DEFAULT_RULES && should_use_quick_path(&complexity) {
156 QUICK_RULES
157 } else {
158 rules
159 };
160
161 for rule in active_rules {
162 if complexity.node_count >= CLONE_HEAVY_RULE_SKIP_NODES
163 && matches!(
164 rule,
165 OptimizationRule::Qualify | OptimizationRule::Normalize
166 )
167 {
168 continue;
169 }
170 expression = apply_rule(expression, *rule, config);
171 }
172 expression
173}
174
175fn should_skip_all_optimization(complexity: &ExpressionComplexity) -> bool {
176 complexity.max_depth >= FAST_PATH_MAX_DEPTH
177 || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
178}
179
180fn should_use_quick_path(complexity: &ExpressionComplexity) -> bool {
181 complexity.connector_count >= FAST_PATH_MAX_CONNECTORS
182 || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
183 || complexity.node_count >= FAST_PATH_MAX_NODES
184}
185
186fn analyze_expression_complexity(expression: &Expression) -> ExpressionComplexity {
187 let mut node_count = 0usize;
188 let mut max_depth = 0usize;
189 let mut connector_count = 0usize;
190 let mut max_connector_depth = 0usize;
191 let mut stack: Vec<(&Expression, usize, usize)> = vec![(expression, 0, 0)];
192
193 while let Some((node, depth, connector_depth)) = stack.pop() {
194 node_count += 1;
195 max_depth = max_depth.max(depth);
196
197 match node {
198 Expression::And(op) | Expression::Or(op) => {
199 connector_count += 1;
200 let next_connector_depth = connector_depth + 1;
201 max_connector_depth = max_connector_depth.max(next_connector_depth);
202 stack.push((&op.right, depth + 1, next_connector_depth));
203 stack.push((&op.left, depth + 1, next_connector_depth));
204 }
205 Expression::Paren(paren) => {
206 stack.push((&paren.this, depth + 1, connector_depth));
207 }
208 _ => {
209 for child in node.children().into_iter().rev() {
210 stack.push((child, depth + 1, 0));
211 }
212 }
213 }
214 }
215
216 ExpressionComplexity {
217 node_count,
218 max_depth,
219 connector_count,
220 max_connector_depth,
221 }
222}
223
224fn apply_rule(
226 expression: Expression,
227 rule: OptimizationRule,
228 config: &OptimizerConfig<'_>,
229) -> Expression {
230 match rule {
231 OptimizationRule::Qualify => {
232 if let Some(schema) = config.schema {
234 let options = super::qualify_columns::QualifyColumnsOptions {
235 dialect: config.dialect,
236 ..Default::default()
237 };
238 let original = expression.clone();
239 qualify_columns(expression, schema, &options).unwrap_or(original)
240 } else {
241 expression
243 }
244 }
245 OptimizationRule::PushdownProjections => {
246 pushdown_projections(expression, config.dialect, true)
247 }
248 OptimizationRule::Normalize => {
249 let original = expression.clone();
251 normalize(expression, false, super::normalize::DEFAULT_MAX_DISTANCE).unwrap_or(original)
252 }
253 OptimizationRule::UnnestSubqueries => unnest_subqueries(expression),
254 OptimizationRule::PushdownPredicates => pushdown_predicates(expression, config.dialect),
255 OptimizationRule::OptimizeJoins => optimize_joins(expression),
256 OptimizationRule::EliminateSubqueries => eliminate_subqueries_opt(expression),
257 OptimizationRule::MergeSubqueries => merge_subqueries(expression, config.isolate_tables),
258 OptimizationRule::EliminateJoins => eliminate_joins(expression),
259 OptimizationRule::EliminateCtes => eliminate_ctes(expression),
260 OptimizationRule::QuoteIdentifiers => {
261 if config.quote_identifiers {
262 quote_identifiers(expression, config.dialect)
263 } else {
264 expression
265 }
266 }
267 OptimizationRule::AnnotateTypes => {
268 let mut expr = expression;
269 annotate_types(&mut expr, config.schema, config.dialect);
270 expr
271 }
272 OptimizationRule::Canonicalize => canonicalize(expression, config.dialect),
273 OptimizationRule::Simplify => simplify(expression, config.dialect),
274 }
275}
276
277use super::subquery::eliminate_subqueries as eliminate_subqueries_opt;
279
280pub fn quick_optimize(expression: Expression, dialect: Option<DialectType>) -> Expression {
284 let config = OptimizerConfig {
285 dialect,
286 ..Default::default()
287 };
288
289 optimize_with_rules(expression, &config, QUICK_RULES)
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders an expression back to SQL text.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql`, runs the default pipeline with a default configuration,
    /// and renders the result.
    fn optimized_sql(sql: &str) -> String {
        let config = OptimizerConfig::default();
        to_sql(&optimize(parse(sql), &config))
    }

    /// Builds the predicate `c<index> = <index>`.
    fn eq_predicate(index: i64) -> Expression {
        Expression::Eq(Box::new(crate::expressions::BinaryOp::new(
            Expression::column(format!("c{index}")),
            Expression::number(index),
        )))
    }

    /// Builds a left-deep AND chain over the predicates `c0 = 0` up to
    /// `c<len - 1> = <len - 1>`.
    fn and_chain(len: i64) -> Expression {
        (1..len).fold(eq_predicate(0), |chain, index| {
            Expression::And(Box::new(crate::expressions::BinaryOp::new(
                chain,
                eq_predicate(index),
            )))
        })
    }

    #[test]
    fn test_optimize_simple() {
        let sql = optimized_sql("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_where() {
        let sql = optimized_sql("SELECT a FROM t WHERE b = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_optimize_with_join() {
        let sql = optimized_sql("SELECT t.a FROM t JOIN s ON t.id = s.id");
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_quick_optimize() {
        let optimized = quick_optimize(parse("SELECT 1 + 0 FROM t"), None);
        assert!(to_sql(&optimized).contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_custom_rules() {
        let config = OptimizerConfig::default();
        let optimized = optimize_with_rules(
            parse("SELECT a FROM t WHERE NOT NOT b = 1"),
            &config,
            &[OptimizationRule::Simplify],
        );
        assert!(to_sql(&optimized).contains("SELECT"));
    }

    #[test]
    fn test_optimizer_config_default() {
        let config = OptimizerConfig::default();
        assert!(config.schema.is_none());
        assert!(config.dialect.is_none());
        assert!(config.isolate_tables);
        assert!(!config.quote_identifiers);
    }

    #[test]
    fn test_default_rules() {
        // Pins both the membership and the order of the default pipeline.
        assert_eq!(
            DEFAULT_RULES,
            &[
                OptimizationRule::Qualify,
                OptimizationRule::PushdownProjections,
                OptimizationRule::Normalize,
                OptimizationRule::UnnestSubqueries,
                OptimizationRule::PushdownPredicates,
                OptimizationRule::OptimizeJoins,
                OptimizationRule::EliminateSubqueries,
                OptimizationRule::MergeSubqueries,
                OptimizationRule::EliminateJoins,
                OptimizationRule::EliminateCtes,
                OptimizationRule::QuoteIdentifiers,
                OptimizationRule::AnnotateTypes,
                OptimizationRule::Canonicalize,
                OptimizationRule::Simplify,
            ]
        );
    }

    #[test]
    fn test_quote_identifiers_rule_respects_config_flag() {
        // Rename the column and table to reserved words after parsing, so the
        // quoting rule has something it must quote.
        let mut expr = parse("SELECT a FROM t");
        match expr {
            Expression::Select(ref mut select) => {
                match select.expressions[0] {
                    Expression::Column(ref mut col) => col.name.name = "select".to_string(),
                    _ => panic!("expected column projection"),
                }
                match select.from {
                    Some(ref mut from) => match from.expressions[0] {
                        Expression::Table(ref mut table) => {
                            table.name.name = "from".to_string()
                        }
                        _ => panic!("expected table reference"),
                    },
                    None => panic!("expected FROM clause"),
                }
            }
            _ => panic!("expected select expression"),
        }

        let config = OptimizerConfig {
            quote_identifiers: true,
            dialect: Some(DialectType::PostgreSQL),
            ..Default::default()
        };
        let quoted = optimize_with_rules(expr, &config, &[OptimizationRule::QuoteIdentifiers]);
        let sql = to_sql(&quoted);
        assert!(sql.contains("\"select\""), "{sql}");
        assert!(sql.contains("\"from\""), "{sql}");
    }

    #[test]
    fn test_quote_identifiers_rule_noop_by_default() {
        let original = parse("SELECT a FROM t");
        let config = OptimizerConfig::default();
        let untouched = optimize_with_rules(
            original.clone(),
            &config,
            &[OptimizationRule::QuoteIdentifiers],
        );
        assert_eq!(to_sql(&untouched), to_sql(&original));
    }

    #[test]
    fn test_optimize_subquery() {
        let sql = optimized_sql("SELECT * FROM (SELECT a FROM t) AS sub");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_cte() {
        let sql = optimized_sql("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_optimize_preserves_semantics() {
        let sql = optimized_sql("SELECT a, b FROM t WHERE c > 1 ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_analyze_expression_complexity_deep_connector_chain() {
        let chain = and_chain(1500);
        let complexity = analyze_expression_complexity(&chain);
        assert!(complexity.max_connector_depth >= 1499);
        assert!(complexity.connector_count >= 1499);
    }

    #[test]
    fn test_optimize_handles_deep_connector_chain() {
        let chain = and_chain(2200);
        let config = OptimizerConfig::default();
        let optimized = optimize(chain, &config);
        let sql = to_sql(&optimized);
        assert!(sql.contains("c2199 = 2199"), "{sql}");
    }
}