Skip to main content

polyglot_sql/optimizer/
optimizer.rs

//! Optimizer Orchestration Module
//!
//! This module provides the main entry point for SQL optimization,
//! coordinating multiple optimization passes in the correct order.
//!
//! Ported from sqlglot's `optimizer/optimizer.py`.
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::schema::Schema;
11use crate::traversal::ExpressionWalk;
12
13use super::annotate_types::annotate_types;
14use super::canonicalize::canonicalize;
15use super::eliminate_ctes::eliminate_ctes;
16use super::normalize::normalize;
17use super::optimize_joins::optimize_joins;
18use super::pushdown_predicates::pushdown_predicates;
19use super::pushdown_projections::pushdown_projections;
20use super::qualify_columns::qualify_columns;
21use super::simplify::simplify;
22use super::subquery::{merge_subqueries, unnest_subqueries};
23
24/// Optimizer configuration
25pub struct OptimizerConfig<'a> {
26    /// Database schema for type inference and column resolution
27    pub schema: Option<&'a dyn Schema>,
28    /// Default database name
29    pub db: Option<String>,
30    /// Default catalog name
31    pub catalog: Option<String>,
32    /// Dialect for dialect-specific optimizations
33    pub dialect: Option<DialectType>,
34    /// Whether to keep tables isolated (don't merge from multiple tables)
35    pub isolate_tables: bool,
36    /// Whether to quote identifiers
37    pub quote_identifiers: bool,
38}
39
40impl<'a> Default for OptimizerConfig<'a> {
41    fn default() -> Self {
42        Self {
43            schema: None,
44            db: None,
45            catalog: None,
46            dialect: None,
47            isolate_tables: true,
48            quote_identifiers: false,
49        }
50    }
51}
52
/// A single optimization pass that [`optimize_with_rules`] can run.
///
/// Besides `Eq`, the type also derives `Hash`, so rule lists can be collected
/// into hash-based sets or maps (e.g. to de-duplicate a caller-supplied list).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationRule {
    /// Qualify columns and tables with their full names (needs a schema in the
    /// configuration; skipped otherwise)
    Qualify,
    /// Push projections down to eliminate unused columns early
    PushdownProjections,
    /// Normalize boolean expressions
    Normalize,
    /// Unnest correlated subqueries into joins
    UnnestSubqueries,
    /// Push predicates down to filter data early
    PushdownPredicates,
    /// Optimize join order and remove cross joins
    OptimizeJoins,
    /// Eliminate derived tables by converting to CTEs
    EliminateSubqueries,
    /// Merge subqueries into outer queries
    MergeSubqueries,
    /// Remove unused CTEs
    EliminateCtes,
    /// Annotate expressions with type information
    AnnotateTypes,
    /// Convert expressions to canonical form
    Canonicalize,
    /// Simplify expressions
    Simplify,
}
81
82/// Default optimization rules in order of execution
83pub const DEFAULT_RULES: &[OptimizationRule] = &[
84    OptimizationRule::Qualify,
85    OptimizationRule::PushdownProjections,
86    OptimizationRule::Normalize,
87    OptimizationRule::UnnestSubqueries,
88    OptimizationRule::PushdownPredicates,
89    OptimizationRule::OptimizeJoins,
90    OptimizationRule::EliminateSubqueries,
91    OptimizationRule::MergeSubqueries,
92    OptimizationRule::EliminateCtes,
93    OptimizationRule::AnnotateTypes,
94    OptimizationRule::Canonicalize,
95    OptimizationRule::Simplify,
96];
97
98const QUICK_RULES: &[OptimizationRule] =
99    &[OptimizationRule::Simplify, OptimizationRule::Canonicalize];
100const FAST_PATH_MAX_DEPTH: usize = 768;
101const FAST_PATH_MAX_CONNECTORS: usize = 10_000;
102const FAST_PATH_MAX_CONNECTOR_DEPTH: usize = 1024;
103const FAST_PATH_MAX_NODES: usize = 50_000;
104const CLONE_HEAVY_RULE_SKIP_NODES: usize = 20_000;
105
/// Size and nesting metrics gathered by `analyze_expression_complexity`.
#[derive(Clone, Copy, Debug)]
struct ExpressionComplexity {
    /// Total number of nodes visited.
    node_count: usize,
    /// Greatest depth reached, with the root at depth 0.
    max_depth: usize,
    /// Number of `And`/`Or` nodes.
    connector_count: usize,
    /// Longest nested run of `And`/`Or` nodes (a `Paren` does not end a run).
    max_connector_depth: usize,
}
113
114/// Optimize a SQL expression using the default set of rules.
115///
116/// This function coordinates multiple optimization passes in the correct order
117/// to produce an optimized query plan.
118///
119/// # Arguments
120/// * `expression` - The expression to optimize
121/// * `config` - Optimizer configuration
122///
123/// # Returns
124/// The optimized expression
125pub fn optimize(expression: Expression, config: &OptimizerConfig<'_>) -> Expression {
126    optimize_with_rules(expression, config, DEFAULT_RULES)
127}
128
129/// Optimize a SQL expression using a custom set of rules.
130///
131/// # Arguments
132/// * `expression` - The expression to optimize
133/// * `config` - Optimizer configuration
134/// * `rules` - The optimization rules to apply
135///
136/// # Returns
137/// The optimized expression
138pub fn optimize_with_rules(
139    mut expression: Expression,
140    config: &OptimizerConfig<'_>,
141    rules: &[OptimizationRule],
142) -> Expression {
143    let complexity = analyze_expression_complexity(&expression);
144    if rules == DEFAULT_RULES && should_skip_all_optimization(&complexity) {
145        return expression;
146    }
147
148    let active_rules = if rules == DEFAULT_RULES && should_use_quick_path(&complexity) {
149        QUICK_RULES
150    } else {
151        rules
152    };
153
154    for rule in active_rules {
155        if complexity.node_count >= CLONE_HEAVY_RULE_SKIP_NODES
156            && matches!(
157                rule,
158                OptimizationRule::Qualify | OptimizationRule::Normalize
159            )
160        {
161            continue;
162        }
163        expression = apply_rule(expression, *rule, config);
164    }
165    expression
166}
167
168fn should_skip_all_optimization(complexity: &ExpressionComplexity) -> bool {
169    complexity.max_depth >= FAST_PATH_MAX_DEPTH
170        || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
171}
172
173fn should_use_quick_path(complexity: &ExpressionComplexity) -> bool {
174    complexity.connector_count >= FAST_PATH_MAX_CONNECTORS
175        || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
176        || complexity.node_count >= FAST_PATH_MAX_NODES
177}
178
179fn analyze_expression_complexity(expression: &Expression) -> ExpressionComplexity {
180    let mut node_count = 0usize;
181    let mut max_depth = 0usize;
182    let mut connector_count = 0usize;
183    let mut max_connector_depth = 0usize;
184    let mut stack: Vec<(&Expression, usize, usize)> = vec![(expression, 0, 0)];
185
186    while let Some((node, depth, connector_depth)) = stack.pop() {
187        node_count += 1;
188        max_depth = max_depth.max(depth);
189
190        match node {
191            Expression::And(op) | Expression::Or(op) => {
192                connector_count += 1;
193                let next_connector_depth = connector_depth + 1;
194                max_connector_depth = max_connector_depth.max(next_connector_depth);
195                stack.push((&op.right, depth + 1, next_connector_depth));
196                stack.push((&op.left, depth + 1, next_connector_depth));
197            }
198            Expression::Paren(paren) => {
199                stack.push((&paren.this, depth + 1, connector_depth));
200            }
201            _ => {
202                for child in node.children().into_iter().rev() {
203                    stack.push((child, depth + 1, 0));
204                }
205            }
206        }
207    }
208
209    ExpressionComplexity {
210        node_count,
211        max_depth,
212        connector_count,
213        max_connector_depth,
214    }
215}
216
217/// Apply a single optimization rule
218fn apply_rule(
219    expression: Expression,
220    rule: OptimizationRule,
221    config: &OptimizerConfig<'_>,
222) -> Expression {
223    match rule {
224        OptimizationRule::Qualify => {
225            // Qualify columns with table references
226            if let Some(schema) = config.schema {
227                let options = super::qualify_columns::QualifyColumnsOptions {
228                    dialect: config.dialect,
229                    ..Default::default()
230                };
231                let original = expression.clone();
232                qualify_columns(expression, schema, &options).unwrap_or(original)
233            } else {
234                // Without schema, skip qualification
235                expression
236            }
237        }
238        OptimizationRule::PushdownProjections => {
239            pushdown_projections(expression, config.dialect, true)
240        }
241        OptimizationRule::Normalize => {
242            // Use CNF (dnf=false) with default max distance
243            let original = expression.clone();
244            normalize(expression, false, super::normalize::DEFAULT_MAX_DISTANCE).unwrap_or(original)
245        }
246        OptimizationRule::UnnestSubqueries => unnest_subqueries(expression),
247        OptimizationRule::PushdownPredicates => pushdown_predicates(expression, config.dialect),
248        OptimizationRule::OptimizeJoins => optimize_joins(expression),
249        OptimizationRule::EliminateSubqueries => eliminate_subqueries_opt(expression),
250        OptimizationRule::MergeSubqueries => merge_subqueries(expression, config.isolate_tables),
251        OptimizationRule::EliminateCtes => eliminate_ctes(expression),
252        OptimizationRule::AnnotateTypes => {
253            let mut expr = expression;
254            annotate_types(&mut expr, config.schema, config.dialect);
255            expr
256        }
257        OptimizationRule::Canonicalize => canonicalize(expression, config.dialect),
258        OptimizationRule::Simplify => simplify(expression, config.dialect),
259    }
260}
261
262// Re-import from subquery module with different name to avoid conflict
263use super::subquery::eliminate_subqueries as eliminate_subqueries_opt;
264
265/// Quick optimization that only applies essential passes.
266///
267/// This is faster than full optimization but may miss some opportunities.
268pub fn quick_optimize(expression: Expression, dialect: Option<DialectType>) -> Expression {
269    let config = OptimizerConfig {
270        dialect,
271        ..Default::default()
272    };
273
274    optimize_with_rules(expression, &config, QUICK_RULES)
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::BinaryOp;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL.
    ///
    /// Named `generate_sql` rather than `gen`: `gen` is a reserved keyword in
    /// the Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Builds the predicate `c{i} = {i}`.
    fn eq_predicate(i: usize) -> Expression {
        Expression::Eq(Box::new(BinaryOp::new(
            Expression::column(format!("c{i}")),
            Expression::number(i as i64),
        )))
    }

    /// Builds the left-deep chain `c0 = 0 AND c1 = 1 AND ... AND c{n-1} = {n-1}`.
    fn and_chain(n: usize) -> Expression {
        (1..n).fold(eq_predicate(0), |acc, i| {
            Expression::And(Box::new(BinaryOp::new(acc, eq_predicate(i))))
        })
    }

    #[test]
    fn test_optimize_simple() {
        let expr = parse("SELECT a FROM t");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_where() {
        let expr = parse("SELECT a FROM t WHERE b = 1");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = generate_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_optimize_with_join() {
        let expr = parse("SELECT t.a FROM t JOIN s ON t.id = s.id");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = generate_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_quick_optimize() {
        let expr = parse("SELECT 1 + 0 FROM t");
        let result = quick_optimize(expr, None);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_custom_rules() {
        let expr = parse("SELECT a FROM t WHERE NOT NOT b = 1");
        let config = OptimizerConfig::default();
        let rules = &[OptimizationRule::Simplify];
        let result = optimize_with_rules(expr, &config, rules);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimizer_config_default() {
        let config = OptimizerConfig::default();
        assert!(config.schema.is_none());
        assert!(config.dialect.is_none());
        assert!(config.isolate_tables);
        assert!(!config.quote_identifiers);
    }

    #[test]
    fn test_default_rules() {
        assert!(!DEFAULT_RULES.is_empty());
        assert!(DEFAULT_RULES.contains(&OptimizationRule::Simplify));
        assert!(DEFAULT_RULES.contains(&OptimizationRule::Canonicalize));
    }

    #[test]
    fn test_optimize_subquery() {
        let expr = parse("SELECT * FROM (SELECT a FROM t) AS sub");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = generate_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_cte() {
        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = generate_sql(&result);
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_optimize_preserves_semantics() {
        let expr = parse("SELECT a, b FROM t WHERE c > 1 ORDER BY a");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = generate_sql(&result);
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_analyze_expression_complexity_single_connector() {
        let complexity = analyze_expression_complexity(&and_chain(2));
        assert_eq!(complexity.connector_count, 1);
        assert_eq!(complexity.max_connector_depth, 1);
        assert!(complexity.node_count >= 3);
        assert!(complexity.max_depth >= 1);
    }

    #[test]
    fn test_analyze_expression_complexity_deep_connector_chain() {
        let complexity = analyze_expression_complexity(&and_chain(1500));
        assert!(complexity.max_connector_depth >= 1499);
        assert!(complexity.connector_count >= 1499);
    }

    #[test]
    fn test_fast_path_thresholds() {
        let small = ExpressionComplexity {
            node_count: 1,
            max_depth: 0,
            connector_count: 0,
            max_connector_depth: 0,
        };
        assert!(!should_skip_all_optimization(&small));
        assert!(!should_use_quick_path(&small));

        let deep = ExpressionComplexity {
            max_depth: FAST_PATH_MAX_DEPTH,
            ..small
        };
        assert!(should_skip_all_optimization(&deep));

        let long_chain = ExpressionComplexity {
            max_connector_depth: FAST_PATH_MAX_CONNECTOR_DEPTH,
            ..small
        };
        assert!(should_skip_all_optimization(&long_chain));

        let many_nodes = ExpressionComplexity {
            node_count: FAST_PATH_MAX_NODES,
            ..small
        };
        assert!(!should_skip_all_optimization(&many_nodes));
        assert!(should_use_quick_path(&many_nodes));

        let many_connectors = ExpressionComplexity {
            connector_count: FAST_PATH_MAX_CONNECTORS,
            ..small
        };
        assert!(!should_skip_all_optimization(&many_connectors));
        assert!(should_use_quick_path(&many_connectors));
    }

    #[test]
    fn test_optimize_handles_deep_connector_chain() {
        let config = OptimizerConfig::default();
        let optimized = optimize(and_chain(2200), &config);
        let sql = generate_sql(&optimized);
        assert!(sql.contains("c2199 = 2199"), "{sql}");
    }
}