Skip to main content

polyglot_sql/optimizer/
optimizer.rs

1//! Optimizer Orchestration Module
2//!
3//! This module provides the main entry point for SQL optimization,
4//! coordinating multiple optimization passes in the correct order.
5//!
6//! Ported from sqlglot's optimizer/optimizer.py
7
8use crate::dialects::DialectType;
9use crate::expressions::Expression;
10use crate::schema::Schema;
11use crate::traversal::ExpressionWalk;
12
13use super::annotate_types::annotate_types;
14use super::canonicalize::canonicalize;
15use super::eliminate_ctes::eliminate_ctes;
16use super::normalize::normalize;
17use super::optimize_joins::optimize_joins;
18use super::pushdown_predicates::pushdown_predicates;
19use super::pushdown_projections::pushdown_projections;
20use super::qualify_columns::qualify_columns;
21use super::simplify::simplify;
22use super::subquery::{merge_subqueries, unnest_subqueries};
23
24/// Optimizer configuration
25pub struct OptimizerConfig<'a> {
26    /// Database schema for type inference and column resolution
27    pub schema: Option<&'a dyn Schema>,
28    /// Default database name
29    pub db: Option<String>,
30    /// Default catalog name
31    pub catalog: Option<String>,
32    /// Dialect for dialect-specific optimizations
33    pub dialect: Option<DialectType>,
34    /// Whether to keep tables isolated (don't merge from multiple tables)
35    pub isolate_tables: bool,
36    /// Whether to quote identifiers
37    pub quote_identifiers: bool,
38}
39
40impl<'a> Default for OptimizerConfig<'a> {
41    fn default() -> Self {
42        Self {
43            schema: None,
44            db: None,
45            catalog: None,
46            dialect: None,
47            isolate_tables: true,
48            quote_identifiers: false,
49        }
50    }
51}
52
/// One optimization pass that [`optimize_with_rules`] can run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationRule {
    /// Qualify column references with their owning tables (needs a schema).
    Qualify,
    /// Push projections down so unused columns are dropped early.
    PushdownProjections,
    /// Normalize boolean expressions.
    Normalize,
    /// Unnest correlated subqueries into joins.
    UnnestSubqueries,
    /// Push predicates down so rows are filtered early.
    PushdownPredicates,
    /// Optimize join order and remove cross joins.
    OptimizeJoins,
    /// Eliminate derived tables by converting them to CTEs.
    EliminateSubqueries,
    /// Merge subqueries into their outer queries.
    MergeSubqueries,
    /// Remove CTEs that are never referenced.
    EliminateCtes,
    /// Annotate expressions with type information.
    AnnotateTypes,
    /// Rewrite expressions into canonical form.
    Canonicalize,
    /// Simplify expressions.
    Simplify,
}

/// The full pipeline used by [`optimize`], in execution order.
pub const DEFAULT_RULES: &[OptimizationRule] = &[
    OptimizationRule::Qualify,
    OptimizationRule::PushdownProjections,
    OptimizationRule::Normalize,
    OptimizationRule::UnnestSubqueries,
    OptimizationRule::PushdownPredicates,
    OptimizationRule::OptimizeJoins,
    OptimizationRule::EliminateSubqueries,
    OptimizationRule::MergeSubqueries,
    OptimizationRule::EliminateCtes,
    OptimizationRule::AnnotateTypes,
    OptimizationRule::Canonicalize,
    OptimizationRule::Simplify,
];

/// Reduced pass list used by `quick_optimize` and by the large-input fast
/// path of `optimize_with_rules`.
const QUICK_RULES: &[OptimizationRule] = &[
    OptimizationRule::Simplify,
    OptimizationRule::Canonicalize,
];

/// With the default rules, inputs nested at least this deep are returned
/// without any optimization.
const FAST_PATH_MAX_DEPTH: usize = 768;
/// With the default rules, inputs containing at least this many AND/OR nodes
/// only get the quick passes.
const FAST_PATH_MAX_CONNECTORS: usize = 10_000;
/// AND/OR nesting depth at which the default rules are skipped entirely.
const FAST_PATH_MAX_CONNECTOR_DEPTH: usize = 1024;
/// With the default rules, inputs with at least this many nodes only get the
/// quick passes.
const FAST_PATH_MAX_NODES: usize = 50_000;
/// Inputs with at least this many nodes skip `Qualify` and `Normalize`, which
/// clone the whole tree before running.
const CLONE_HEAVY_RULE_SKIP_NODES: usize = 20_000;
105
/// Size and nesting measurements of an expression tree, used to pick a fast
/// path before running any passes.
#[derive(Debug, Clone, Copy)]
struct ExpressionComplexity {
    /// Total number of nodes, root included.
    node_count: usize,
    /// Depth of the deepest node (root = 0).
    max_depth: usize,
    /// Number of AND/OR nodes.
    connector_count: usize,
    /// Longest run of directly nested AND/OR nodes.
    max_connector_depth: usize,
}
113
114/// Optimize a SQL expression using the default set of rules.
115///
116/// This function coordinates multiple optimization passes in the correct order
117/// to produce an optimized query plan.
118///
119/// # Arguments
120/// * `expression` - The expression to optimize
121/// * `config` - Optimizer configuration
122///
123/// # Returns
124/// The optimized expression
125pub fn optimize(expression: Expression, config: &OptimizerConfig<'_>) -> Expression {
126    optimize_with_rules(expression, config, DEFAULT_RULES)
127}
128
129/// Optimize a SQL expression using a custom set of rules.
130///
131/// # Arguments
132/// * `expression` - The expression to optimize
133/// * `config` - Optimizer configuration
134/// * `rules` - The optimization rules to apply
135///
136/// # Returns
137/// The optimized expression
138pub fn optimize_with_rules(
139    mut expression: Expression,
140    config: &OptimizerConfig<'_>,
141    rules: &[OptimizationRule],
142) -> Expression {
143    let complexity = analyze_expression_complexity(&expression);
144    if rules == DEFAULT_RULES && should_skip_all_optimization(&complexity) {
145        return expression;
146    }
147
148    let active_rules = if rules == DEFAULT_RULES && should_use_quick_path(&complexity) {
149        QUICK_RULES
150    } else {
151        rules
152    };
153
154    for rule in active_rules {
155        if complexity.node_count >= CLONE_HEAVY_RULE_SKIP_NODES
156            && matches!(
157                rule,
158                OptimizationRule::Qualify | OptimizationRule::Normalize
159            )
160        {
161            continue;
162        }
163        expression = apply_rule(expression, *rule, config);
164    }
165    expression
166}
167
168fn should_skip_all_optimization(complexity: &ExpressionComplexity) -> bool {
169    complexity.max_depth >= FAST_PATH_MAX_DEPTH
170        || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
171}
172
173fn should_use_quick_path(complexity: &ExpressionComplexity) -> bool {
174    complexity.connector_count >= FAST_PATH_MAX_CONNECTORS
175        || complexity.max_connector_depth >= FAST_PATH_MAX_CONNECTOR_DEPTH
176        || complexity.node_count >= FAST_PATH_MAX_NODES
177}
178
179fn analyze_expression_complexity(expression: &Expression) -> ExpressionComplexity {
180    let mut node_count = 0usize;
181    let mut max_depth = 0usize;
182    let mut connector_count = 0usize;
183    let mut max_connector_depth = 0usize;
184    let mut stack: Vec<(&Expression, usize, usize)> = vec![(expression, 0, 0)];
185
186    while let Some((node, depth, connector_depth)) = stack.pop() {
187        node_count += 1;
188        max_depth = max_depth.max(depth);
189
190        match node {
191            Expression::And(op) | Expression::Or(op) => {
192                connector_count += 1;
193                let next_connector_depth = connector_depth + 1;
194                max_connector_depth = max_connector_depth.max(next_connector_depth);
195                stack.push((&op.right, depth + 1, next_connector_depth));
196                stack.push((&op.left, depth + 1, next_connector_depth));
197            }
198            Expression::Paren(paren) => {
199                stack.push((&paren.this, depth + 1, connector_depth));
200            }
201            _ => {
202                for child in node.children().into_iter().rev() {
203                    stack.push((child, depth + 1, 0));
204                }
205            }
206        }
207    }
208
209    ExpressionComplexity {
210        node_count,
211        max_depth,
212        connector_count,
213        max_connector_depth,
214    }
215}
216
217/// Apply a single optimization rule
218fn apply_rule(
219    expression: Expression,
220    rule: OptimizationRule,
221    config: &OptimizerConfig<'_>,
222) -> Expression {
223    match rule {
224        OptimizationRule::Qualify => {
225            // Qualify columns with table references
226            if let Some(schema) = config.schema {
227                let options = super::qualify_columns::QualifyColumnsOptions {
228                    dialect: config.dialect,
229                    ..Default::default()
230                };
231                let original = expression.clone();
232                qualify_columns(expression, schema, &options).unwrap_or(original)
233            } else {
234                // Without schema, skip qualification
235                expression
236            }
237        }
238        OptimizationRule::PushdownProjections => {
239            pushdown_projections(expression, config.dialect, true)
240        }
241        OptimizationRule::Normalize => {
242            // Use CNF (dnf=false) with default max distance
243            let original = expression.clone();
244            normalize(expression, false, super::normalize::DEFAULT_MAX_DISTANCE).unwrap_or(original)
245        }
246        OptimizationRule::UnnestSubqueries => unnest_subqueries(expression),
247        OptimizationRule::PushdownPredicates => pushdown_predicates(expression, config.dialect),
248        OptimizationRule::OptimizeJoins => optimize_joins(expression),
249        OptimizationRule::EliminateSubqueries => eliminate_subqueries_opt(expression),
250        OptimizationRule::MergeSubqueries => merge_subqueries(expression, config.isolate_tables),
251        OptimizationRule::EliminateCtes => eliminate_ctes(expression),
252        OptimizationRule::AnnotateTypes => {
253            // annotate_types is used for type inference, not expression transformation
254            // For now, just return the expression unchanged
255            let _ = annotate_types(&expression, config.schema, config.dialect);
256            expression
257        }
258        OptimizationRule::Canonicalize => canonicalize(expression, config.dialect),
259        OptimizationRule::Simplify => simplify(expression, config.dialect),
260    }
261}
262
263// Re-import from subquery module with different name to avoid conflict
264use super::subquery::eliminate_subqueries as eliminate_subqueries_opt;
265
266/// Quick optimization that only applies essential passes.
267///
268/// This is faster than full optimization but may miss some opportunities.
269pub fn quick_optimize(expression: Expression, dialect: Option<DialectType>) -> Expression {
270    let config = OptimizerConfig {
271        dialect,
272        ..Default::default()
273    };
274
275    optimize_with_rules(expression, &config, QUICK_RULES)
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::BinaryOp;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    ///
    /// Named `to_sql` rather than `gen`, because `gen` is a reserved keyword
    /// in Rust 2024.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Builds the predicate `c{i} = {i}`.
    fn column_eq(i: usize) -> Expression {
        Expression::Eq(Box::new(BinaryOp::new(
            Expression::column(format!("c{i}")),
            Expression::number(i as i64),
        )))
    }

    /// Builds `c0 = 0 AND c1 = 1 AND ... AND c{last} = {last}` as a left-deep
    /// chain containing exactly `last` AND nodes.
    fn left_deep_and_chain(last: usize) -> Expression {
        (1..=last).fold(column_eq(0), |acc, i| {
            Expression::And(Box::new(BinaryOp::new(acc, column_eq(i))))
        })
    }

    #[test]
    fn test_optimize_simple() {
        let expr = parse("SELECT a FROM t");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_where() {
        let expr = parse("SELECT a FROM t WHERE b = 1");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = to_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_optimize_with_join() {
        let expr = parse("SELECT t.a FROM t JOIN s ON t.id = s.id");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_quick_optimize() {
        let expr = parse("SELECT 1 + 0 FROM t");
        let result = quick_optimize(expr, None);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_with_custom_rules() {
        let expr = parse("SELECT a FROM t WHERE NOT NOT b = 1");
        let config = OptimizerConfig::default();
        let rules = &[OptimizationRule::Simplify];
        let result = optimize_with_rules(expr, &config, rules);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimizer_config_default() {
        let config = OptimizerConfig::default();
        assert!(config.schema.is_none());
        assert!(config.dialect.is_none());
        assert!(config.db.is_none());
        assert!(config.catalog.is_none());
        assert!(config.isolate_tables);
        assert!(!config.quote_identifiers);
    }

    #[test]
    fn test_default_rules() {
        assert!(!DEFAULT_RULES.is_empty());
        assert_eq!(DEFAULT_RULES.first(), Some(&OptimizationRule::Qualify));
        assert!(DEFAULT_RULES.contains(&OptimizationRule::Simplify));
        assert!(DEFAULT_RULES.contains(&OptimizationRule::Canonicalize));
    }

    #[test]
    fn test_optimize_subquery() {
        let expr = parse("SELECT * FROM (SELECT a FROM t) AS sub");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = to_sql(&result);
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_optimize_cte() {
        let expr = parse("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = to_sql(&result);
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_optimize_preserves_semantics() {
        let expr = parse("SELECT a, b FROM t WHERE c > 1 ORDER BY a");
        let config = OptimizerConfig::default();
        let result = optimize(expr, &config);
        let sql = to_sql(&result);
        assert!(sql.contains("ORDER BY"));
    }

    #[test]
    fn test_fast_path_thresholds() {
        let small = ExpressionComplexity {
            node_count: 10,
            max_depth: 3,
            connector_count: 1,
            max_connector_depth: 1,
        };
        assert!(!should_skip_all_optimization(&small));
        assert!(!should_use_quick_path(&small));

        let deep = ExpressionComplexity {
            max_depth: FAST_PATH_MAX_DEPTH,
            ..small
        };
        assert!(should_skip_all_optimization(&deep));

        let many_connectors = ExpressionComplexity {
            connector_count: FAST_PATH_MAX_CONNECTORS,
            ..small
        };
        assert!(!should_skip_all_optimization(&many_connectors));
        assert!(should_use_quick_path(&many_connectors));

        let many_nodes = ExpressionComplexity {
            node_count: FAST_PATH_MAX_NODES,
            ..small
        };
        assert!(!should_skip_all_optimization(&many_nodes));
        assert!(should_use_quick_path(&many_nodes));
    }

    #[test]
    fn test_analyze_expression_complexity_deep_connector_chain() {
        let expr = left_deep_and_chain(1499);

        let complexity = analyze_expression_complexity(&expr);
        // A left-deep chain of N ANDs yields exactly N connectors, nested N deep.
        assert_eq!(complexity.connector_count, 1499);
        assert_eq!(complexity.max_connector_depth, 1499);
        assert!(complexity.max_depth >= 1499);
        assert!(should_skip_all_optimization(&complexity));
    }

    #[test]
    fn test_optimize_handles_deep_connector_chain() {
        let expr = left_deep_and_chain(2199);

        let config = OptimizerConfig::default();
        let optimized = optimize(expr, &config);
        let sql = to_sql(&optimized);
        assert!(sql.contains("c2199 = 2199"), "{sql}");
    }
}