Skip to main content

polyglot_sql/optimizer/
optimizer.rs

1//! Optimizer Orchestration Module
2//!
3//! This module provides the main entry point for SQL optimization,
4//! coordinating multiple optimization passes in the correct order.
5//!
6//! Ported from sqlglot's optimizer/optimizer.py
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::schema::Schema;
11use crate::traversal::ExpressionWalk;
12
13use super::annotate_types::annotate_types;
14use super::canonicalize::canonicalize;
15use super::eliminate_ctes::eliminate_ctes;
16use super::eliminate_joins::eliminate_joins;
17use super::normalize::normalize;
18use super::optimize_joins::optimize_joins;
19use super::pushdown_predicates::pushdown_predicates;
20use super::pushdown_projections::pushdown_projections;
21use super::qualify_columns::{qualify_columns, quote_identifiers};
22use super::simplify::simplify;
23use super::subquery::{merge_subqueries, unnest_subqueries};
24
25/// Optimizer configuration
26pub struct OptimizerConfig<'a> {
27    /// Database schema for type inference and column resolution
28    pub schema: Option<&'a dyn Schema>,
29    /// Default database name
30    pub db: Option<String>,
31    /// Default catalog name
32    pub catalog: Option<String>,
33    /// Dialect for dialect-specific optimizations
34    pub dialect: Option<DialectType>,
35    /// Whether to keep tables isolated (don't merge from multiple tables)
36    pub isolate_tables: bool,
37    /// Whether to quote identifiers
38    pub quote_identifiers: bool,
39}
40
41impl<'a> Default for OptimizerConfig<'a> {
42    fn default() -> Self {
43        Self {
44            schema: None,
45            db: None,
46            catalog: None,
47            dialect: None,
48            isolate_tables: true,
49            quote_identifiers: false,
50        }
51    }
52}
53
/// A single optimization pass that the optimizer knows how to run.
///
/// Variants are listed in the order used by the default pipeline.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationRule {
    /// Qualify columns and tables with their full names (requires a schema)
    Qualify,
    /// Push projections down so unused columns are dropped early
    PushdownProjections,
    /// Normalize boolean expressions
    Normalize,
    /// Rewrite correlated subqueries as joins
    UnnestSubqueries,
    /// Push predicates down so data is filtered early
    PushdownPredicates,
    /// Optimize join order and remove cross joins
    OptimizeJoins,
    /// Eliminate derived tables by converting them to CTEs
    EliminateSubqueries,
    /// Merge subqueries into their outer queries
    MergeSubqueries,
    /// Drop joins left unused after join optimization and subquery merging
    EliminateJoins,
    /// Drop CTEs that are never referenced
    EliminateCtes,
    /// Quote identifiers that need quoting in the target dialect
    /// (only when enabled in the configuration)
    QuoteIdentifiers,
    /// Attach type information to expressions
    AnnotateTypes,
    /// Rewrite expressions into canonical form
    Canonicalize,
    /// Simplify expressions
    Simplify,
}
86
87/// Default optimization rules in order of execution
88pub const DEFAULT_RULES: &[OptimizationRule] = &[
89    OptimizationRule::Qualify,
90    OptimizationRule::PushdownProjections,
91    OptimizationRule::Normalize,
92    OptimizationRule::UnnestSubqueries,
93    OptimizationRule::PushdownPredicates,
94    OptimizationRule::OptimizeJoins,
95    OptimizationRule::EliminateSubqueries,
96    OptimizationRule::MergeSubqueries,
97    OptimizationRule::EliminateJoins,
98    OptimizationRule::EliminateCtes,
99    OptimizationRule::QuoteIdentifiers,
100    OptimizationRule::AnnotateTypes,
101    OptimizationRule::Canonicalize,
102    OptimizationRule::Simplify,
103];
104
105const QUICK_RULES: &[OptimizationRule] =
106    &[OptimizationRule::Simplify, OptimizationRule::Canonicalize];
/// Tree depth at or above which the default pipeline is skipped entirely.
const FAST_PATH_MAX_DEPTH: usize = 768;
/// Total number of AND/OR nodes at or above which the default pipeline is
/// replaced by the quick rules.
const FAST_PATH_MAX_CONNECTORS: usize = 10_000;
/// Length of a nested AND/OR chain at or above which the default pipeline
/// is skipped entirely.
const FAST_PATH_MAX_CONNECTOR_DEPTH: usize = 1024;
/// Node count at or above which the default pipeline is replaced by the
/// quick rules.
const FAST_PATH_MAX_NODES: usize = 50_000;
/// Node count at or above which `Qualify` and `Normalize` are skipped; both
/// clone the whole expression before running.
const CLONE_HEAVY_RULE_SKIP_NODES: usize = 20_000;
112
/// Size and shape measurements of an expression tree, used to decide
/// whether the full pipeline is safe to run.
#[derive(Debug, Clone, Copy)]
struct ExpressionComplexity {
    /// Total number of nodes visited
    node_count: usize,
    /// Depth of the deepest node, with the root at depth 0
    max_depth: usize,
    /// Number of AND/OR nodes in the tree
    connector_count: usize,
    /// Length of the longest chain of directly nested AND/OR nodes
    /// (parentheses do not break a chain)
    max_connector_depth: usize,
}
120
121/// Optimize a SQL expression using the default set of rules.
122///
123/// This function coordinates multiple optimization passes in the correct order
124/// to produce an optimized query plan.
125///
126/// # Arguments
127/// * `expression` - The expression to optimize
128/// * `config` - Optimizer configuration
129///
130/// # Returns
131/// The optimized expression
132pub fn optimize(expression: Expression, config: &OptimizerConfig<'_>) -> Expression {
133    optimize_with_rules(expression, config, DEFAULT_RULES)
134}
135
136/// Optimize a SQL expression using a custom set of rules.
137///
138/// # Arguments
139/// * `expression` - The expression to optimize
140/// * `config` - Optimizer configuration
141/// * `rules` - The optimization rules to apply
142///
143/// # Returns
144/// The optimized expression
145pub fn optimize_with_rules(
146    mut expression: Expression,
147    config: &OptimizerConfig<'_>,
148    rules: &[OptimizationRule],
149) -> Expression {
150    let complexity = analyze_expression_complexity(&expression);
151    if rules == DEFAULT_RULES && should_skip_all_optimization(&complexity) {
152        return expression;
153    }
154
155    let active_rules = if rules == DEFAULT_RULES && should_use_quick_path(&complexity) {
156        QUICK_RULES
157    } else {
158        rules
159    };
160
161    for rule in active_rules {
162        if complexity.node_count >= CLONE_HEAVY_RULE_SKIP_NODES
163            && matches!(
164                rule,
165                OptimizationRule::Qualify | OptimizationRule::Normalize
166            )
167        {
168            continue;
169        }
170        expression = apply_rule(expression, *rule, config);
171    }
172    expression
173}
174
175fn should_skip_all_optimization(complexity: &ExpressionComplexity) -> bool {
176    complexity.max_depth >= FAST_PATH_MAX_DEPTH
177        || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
178}
179
180fn should_use_quick_path(complexity: &ExpressionComplexity) -> bool {
181    complexity.connector_count >= FAST_PATH_MAX_CONNECTORS
182        || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
183        || complexity.node_count >= FAST_PATH_MAX_NODES
184}
185
186fn analyze_expression_complexity(expression: &Expression) -> ExpressionComplexity {
187    let mut node_count = 0usize;
188    let mut max_depth = 0usize;
189    let mut connector_count = 0usize;
190    let mut max_connector_depth = 0usize;
191    let mut stack: Vec<(&Expression, usize, usize)> = vec![(expression, 0, 0)];
192
193    while let Some((node, depth, connector_depth)) = stack.pop() {
194        node_count += 1;
195        max_depth = max_depth.max(depth);
196
197        match node {
198            Expression::And(op) | Expression::Or(op) => {
199                connector_count += 1;
200                let next_connector_depth = connector_depth + 1;
201                max_connector_depth = max_connector_depth.max(next_connector_depth);
202                stack.push((&op.right, depth + 1, next_connector_depth));
203                stack.push((&op.left, depth + 1, next_connector_depth));
204            }
205            Expression::Paren(paren) => {
206                stack.push((&paren.this, depth + 1, connector_depth));
207            }
208            _ => {
209                for child in node.children().into_iter().rev() {
210                    stack.push((child, depth + 1, 0));
211                }
212            }
213        }
214    }
215
216    ExpressionComplexity {
217        node_count,
218        max_depth,
219        connector_count,
220        max_connector_depth,
221    }
222}
223
224/// Apply a single optimization rule
225fn apply_rule(
226    expression: Expression,
227    rule: OptimizationRule,
228    config: &OptimizerConfig<'_>,
229) -> Expression {
230    match rule {
231        OptimizationRule::Qualify => {
232            // Qualify columns with table references
233            if let Some(schema) = config.schema {
234                let options = super::qualify_columns::QualifyColumnsOptions {
235                    dialect: config.dialect,
236                    ..Default::default()
237                };
238                let original = expression.clone();
239                qualify_columns(expression, schema, &options).unwrap_or(original)
240            } else {
241                // Without schema, skip qualification
242                expression
243            }
244        }
245        OptimizationRule::PushdownProjections => {
246            pushdown_projections(expression, config.dialect, true)
247        }
248        OptimizationRule::Normalize => {
249            // Use CNF (dnf=false) with default max distance
250            let original = expression.clone();
251            normalize(expression, false, super::normalize::DEFAULT_MAX_DISTANCE).unwrap_or(original)
252        }
253        OptimizationRule::UnnestSubqueries => unnest_subqueries(expression),
254        OptimizationRule::PushdownPredicates => pushdown_predicates(expression, config.dialect),
255        OptimizationRule::OptimizeJoins => optimize_joins(expression),
256        OptimizationRule::EliminateSubqueries => eliminate_subqueries_opt(expression),
257        OptimizationRule::MergeSubqueries => merge_subqueries(expression, config.isolate_tables),
258        OptimizationRule::EliminateJoins => eliminate_joins(expression),
259        OptimizationRule::EliminateCtes => eliminate_ctes(expression),
260        OptimizationRule::QuoteIdentifiers => {
261            if config.quote_identifiers {
262                quote_identifiers(expression, config.dialect)
263            } else {
264                expression
265            }
266        }
267        OptimizationRule::AnnotateTypes => {
268            let mut expr = expression;
269            annotate_types(&mut expr, config.schema, config.dialect);
270            expr
271        }
272        OptimizationRule::Canonicalize => canonicalize(expression, config.dialect),
273        OptimizationRule::Simplify => simplify(expression, config.dialect),
274    }
275}
276
277// Re-import from subquery module with different name to avoid conflict
278use super::subquery::eliminate_subqueries as eliminate_subqueries_opt;
279
280/// Quick optimization that only applies essential passes.
281///
282/// This is faster than full optimization but may miss some opportunities.
283pub fn quick_optimize(expression: Expression, dialect: Option<DialectType>) -> Expression {
284    let config = OptimizerConfig {
285        dialect,
286        ..Default::default()
287    };
288
289    optimize_with_rules(expression, &config, QUICK_RULES)
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render an expression back to SQL text with a default generator.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword
    /// from the Rust 2024 edition onwards.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Parse `sql`, run the default pipeline with the default configuration,
    /// and render the result back to SQL.
    fn optimize_default(sql: &str) -> String {
        let config = OptimizerConfig::default();
        to_sql(&optimize(parse(sql), &config))
    }

    /// Build `c0 = 0 AND c1 = 1 AND ... AND c{n-1} = {n-1}` as a left-deep
    /// tree, i.e. with `predicates - 1` nested AND nodes.
    fn left_deep_and_chain(predicates: i64) -> Expression {
        use crate::expressions::BinaryOp;

        let mut expr = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("c0"),
            Expression::number(0),
        )));

        for i in 1..predicates {
            let predicate = Expression::Eq(Box::new(BinaryOp::new(
                Expression::column(format!("c{i}")),
                Expression::number(i),
            )));
            expr = Expression::And(Box::new(BinaryOp::new(expr, predicate)));
        }

        expr
    }

    #[test]
    fn test_optimize_simple() {
        let sql = optimize_default("SELECT a FROM t");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_where() {
        let sql = optimize_default("SELECT a FROM t WHERE b = 1");
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_optimize_with_join() {
        let sql = optimize_default("SELECT t.a FROM t JOIN s ON t.id = s.id");
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_quick_optimize() {
        let expr = parse("SELECT 1 + 0 FROM t");
        let result = quick_optimize(expr, None);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_custom_rules() {
        let expr = parse("SELECT a FROM t WHERE NOT NOT b = 1");
        let config = OptimizerConfig::default();
        let rules = &[OptimizationRule::Simplify];
        let result = optimize_with_rules(expr, &config, rules);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimizer_config_default() {
        let config = OptimizerConfig::default();
        assert!(config.schema.is_none());
        assert!(config.dialect.is_none());
        assert!(config.isolate_tables);
        assert!(!config.quote_identifiers);
    }

    #[test]
    fn test_default_rules() {
        // Pins the public pipeline order so that reordering is a deliberate change.
        assert_eq!(
            DEFAULT_RULES,
            &[
                OptimizationRule::Qualify,
                OptimizationRule::PushdownProjections,
                OptimizationRule::Normalize,
                OptimizationRule::UnnestSubqueries,
                OptimizationRule::PushdownPredicates,
                OptimizationRule::OptimizeJoins,
                OptimizationRule::EliminateSubqueries,
                OptimizationRule::MergeSubqueries,
                OptimizationRule::EliminateJoins,
                OptimizationRule::EliminateCtes,
                OptimizationRule::QuoteIdentifiers,
                OptimizationRule::AnnotateTypes,
                OptimizationRule::Canonicalize,
                OptimizationRule::Simplify,
            ]
        );
    }

    #[test]
    fn test_quote_identifiers_rule_respects_config_flag() {
        // Rename the column and table to reserved words after parsing, so
        // the input tree holds identifiers that need quoting.
        let mut expr = parse("SELECT a FROM t");
        if let Expression::Select(ref mut select) = expr {
            if let Expression::Column(ref mut col) = select.expressions[0] {
                col.name.name = "select".to_string();
            } else {
                panic!("expected column projection");
            }
            if let Some(ref mut from) = select.from {
                if let Expression::Table(ref mut table) = from.expressions[0] {
                    table.name.name = "from".to_string();
                } else {
                    panic!("expected table reference");
                }
            } else {
                panic!("expected FROM clause");
            }
        } else {
            panic!("expected select expression");
        }
        let config = OptimizerConfig {
            quote_identifiers: true,
            dialect: Some(DialectType::PostgreSQL),
            ..Default::default()
        };
        let result = optimize_with_rules(expr, &config, &[OptimizationRule::QuoteIdentifiers]);
        let sql = to_sql(&result);
        assert!(sql.contains("\"select\""), "{sql}");
        assert!(sql.contains("\"from\""), "{sql}");
    }

    #[test]
    fn test_quote_identifiers_rule_noop_by_default() {
        let expr = parse("SELECT a FROM t");
        let config = OptimizerConfig::default();
        let result =
            optimize_with_rules(expr.clone(), &config, &[OptimizationRule::QuoteIdentifiers]);
        assert_eq!(to_sql(&result), to_sql(&expr));
    }

    #[test]
    fn test_optimize_subquery() {
        let sql = optimize_default("SELECT * FROM (SELECT a FROM t) AS sub");
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_cte() {
        let sql = optimize_default("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_optimize_preserves_semantics() {
        let sql = optimize_default("SELECT a, b FROM t WHERE c > 1 ORDER BY a");
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_analyze_expression_complexity_deep_connector_chain() {
        // 1500 predicates joined by 1499 nested ANDs.
        let expr = left_deep_and_chain(1500);

        let complexity = analyze_expression_complexity(&expr);
        assert!(complexity.max_connector_depth >= 1499);
        assert!(complexity.connector_count >= 1499);
    }

    #[test]
    fn test_optimize_handles_deep_connector_chain() {
        // Deep enough to cross the skip-all thresholds of the default pipeline.
        let expr = left_deep_and_chain(2200);

        let config = OptimizerConfig::default();
        let optimized = optimize(expr, &config);
        let sql = to_sql(&optimized);
        assert!(sql.contains("c2199 = 2199"), "{sql}");
    }
}
491}