Skip to main content

polyglot_sql/optimizer/
optimize_joins.rs

//! Join optimization module.
//!
//! This module provides the optimizer passes that operate on JOIN clauses:
//! - Analysing cross joins so that predicates from later joins can be moved onto them
//! - Reordering joins based on the tables their predicates depend on
//! - Normalizing join syntax (dropping redundant INNER/OUTER keywords)
//!
//! Ported from sqlglot's `optimizer/optimize_joins.py`.
9
10use std::collections::{HashMap, HashSet};
11
12use crate::expressions::{BooleanLiteral, Expression, Join, JoinKind};
13use crate::helper::tsort;
14
15/// Optimize joins by removing cross joins and reordering based on dependencies.
16///
17/// # Example
18///
19/// ```sql
20/// -- Before:
21/// SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a
22/// -- After:
23/// SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a
24/// ```
25///
26/// # Arguments
27/// * `expression` - The expression to optimize
28///
29/// # Returns
30/// The optimized expression with improved join order
31pub fn optimize_joins(expression: Expression) -> Expression {
32    let expression = optimize_cross_joins(expression);
33    let expression = reorder_joins(expression);
34    let expression = normalize_joins(expression);
35    expression
36}
37
38/// Optimize cross joins by moving predicates from later joins
39fn optimize_cross_joins(expression: Expression) -> Expression {
40    if let Expression::Select(select) = expression {
41        if select.joins.is_empty() || !is_reorderable(&select.joins) {
42            return Expression::Select(select);
43        }
44
45        // Build reference map: table -> list of joins that reference it
46        let mut references: HashMap<String, Vec<usize>> = HashMap::new();
47        let mut cross_joins: Vec<(String, usize)> = Vec::new();
48
49        for (i, join) in select.joins.iter().enumerate() {
50            let tables = other_table_names(join);
51
52            if tables.is_empty() {
53                // This is a cross join
54                if let Some(name) = get_join_name(join) {
55                    cross_joins.push((name, i));
56                }
57            } else {
58                // This join has ON predicates referencing tables
59                for table in tables {
60                    references.entry(table).or_insert_with(Vec::new).push(i);
61                }
62            }
63        }
64
65        // Move predicates from referencing joins to cross joins
66        for (name, cross_idx) in &cross_joins {
67            if let Some(ref_indices) = references.get(name) {
68                for &ref_idx in ref_indices {
69                    // In a full implementation, we would move predicates
70                    // that reference the cross join table from the referencing
71                    // join to the cross join
72                    let _ = (cross_idx, ref_idx);
73                }
74            }
75        }
76
77        Expression::Select(select)
78    } else {
79        expression
80    }
81}
82
83/// Reorder joins by topological sort based on predicate dependencies.
84pub fn reorder_joins(expression: Expression) -> Expression {
85    if let Expression::Select(mut select) = expression {
86        if select.joins.is_empty() || !is_reorderable(&select.joins) {
87            return Expression::Select(select);
88        }
89
90        // Build dependency graph
91        let mut joins_by_name: HashMap<String, Join> = HashMap::new();
92        let mut dag: HashMap<String, HashSet<String>> = HashMap::new();
93
94        for join in &select.joins {
95            if let Some(name) = get_join_name(join) {
96                joins_by_name.insert(name.clone(), join.clone());
97                dag.insert(name, other_table_names(join));
98            }
99        }
100
101        // Get topologically sorted order
102        if let Ok(sorted) = tsort(dag) {
103            // Get the FROM table name (to exclude from join reordering)
104            let from_name = select.from.as_ref()
105                .and_then(|f| f.expressions.first())
106                .and_then(|e| get_table_name(e));
107
108            // Reorder joins
109            let mut reordered: Vec<Join> = Vec::new();
110            for name in sorted {
111                if Some(&name) != from_name.as_ref() {
112                    if let Some(join) = joins_by_name.remove(&name) {
113                        reordered.push(join);
114                    }
115                }
116            }
117
118            // If reordering succeeded, use new order; otherwise keep original
119            if !reordered.is_empty() && reordered.len() == select.joins.len() {
120                select.joins = reordered;
121            }
122        }
123
124        Expression::Select(select)
125    } else {
126        expression
127    }
128}
129
130/// Normalize join syntax by removing unnecessary keywords.
131///
132/// - Remove INNER keyword (it's the default for joins with ON clause)
133/// - Remove OUTER keyword (only LEFT/RIGHT/FULL matter)
134/// - Add CROSS keyword to joins without any join type
135/// - Add TRUE to joins without ON or USING clause
136pub fn normalize_joins(expression: Expression) -> Expression {
137    if let Expression::Select(mut select) = expression {
138        for join in &mut select.joins {
139            // For CROSS joins, clear the ON clause
140            if join.kind == JoinKind::Cross {
141                join.on = None;
142            } else {
143                // Remove INNER keyword flag (INNER is the default)
144                if join.kind == JoinKind::Inner {
145                    join.use_inner_keyword = false;
146                }
147
148                // Remove OUTER keyword flag
149                join.use_outer_keyword = false;
150
151                // If no ON or USING, add ON TRUE
152                if join.on.is_none() && join.using.is_empty() {
153                    join.on = Some(Expression::Boolean(BooleanLiteral { value: true }));
154                }
155            }
156        }
157
158        Expression::Select(select)
159    } else {
160        expression
161    }
162}
163
164/// Check if joins can be reordered without changing query semantics.
165///
166/// Joins with a side (LEFT, RIGHT, FULL) cannot be reordered,
167/// as the order affects which rows are included.
168pub fn is_reorderable(joins: &[Join]) -> bool {
169    joins.iter().all(|j| {
170        matches!(j.kind, JoinKind::Inner | JoinKind::Cross | JoinKind::Natural)
171    })
172}
173
174/// Get table names referenced in a join's ON clause (excluding the join's own table).
175fn other_table_names(join: &Join) -> HashSet<String> {
176    let mut tables = HashSet::new();
177
178    if let Some(ref on) = join.on {
179        collect_table_names(on, &mut tables);
180    }
181
182    // Remove the join's own table name
183    if let Some(name) = get_join_name(join) {
184        tables.remove(&name);
185    }
186
187    tables
188}
189
190/// Collect all table names referenced in an expression.
191fn collect_table_names(expr: &Expression, tables: &mut HashSet<String>) {
192    match expr {
193        Expression::Column(col) => {
194            if let Some(ref table) = col.table {
195                tables.insert(table.name.clone());
196            }
197        }
198        Expression::And(bin) | Expression::Or(bin) => {
199            collect_table_names(&bin.left, tables);
200            collect_table_names(&bin.right, tables);
201        }
202        Expression::Eq(bin) | Expression::Neq(bin) | Expression::Lt(bin) |
203        Expression::Gt(bin) | Expression::Lte(bin) | Expression::Gte(bin) => {
204            collect_table_names(&bin.left, tables);
205            collect_table_names(&bin.right, tables);
206        }
207        Expression::Paren(p) => {
208            collect_table_names(&p.this, tables);
209        }
210        _ => {}
211    }
212}
213
214/// Get the alias or table name from a join.
215fn get_join_name(join: &Join) -> Option<String> {
216    get_table_name(&join.this)
217}
218
219/// Get the alias or name from a table expression.
220fn get_table_name(expr: &Expression) -> Option<String> {
221    match expr {
222        Expression::Table(table) => {
223            if let Some(ref alias) = table.alias {
224                Some(alias.name.clone())
225            } else {
226                Some(table.name.name.clone())
227            }
228        }
229        Expression::Subquery(subquery) => {
230            subquery.alias.as_ref().map(|a| a.name.clone())
231        }
232        Expression::Alias(alias) => {
233            Some(alias.alias.name.clone())
234        }
235        _ => None,
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render `expr` back to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Borrow the joins of a SELECT expression.
    ///
    /// Panics for any other expression so that a test can never pass without
    /// running its assertions.
    fn joins_of(expr: &Expression) -> &[Join] {
        match expr {
            Expression::Select(select) => select.joins.as_slice(),
            _ => panic!("expected the statement to be a SELECT"),
        }
    }

    /// Borrow the first join of a SELECT expression, panicking if there is none.
    fn first_join(expr: &Expression) -> &Join {
        joins_of(expr)
            .first()
            .expect("expected the SELECT to have at least one join")
    }

    #[test]
    fn test_optimize_joins_simple() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a");
        let result = optimize_joins(expr);
        let sql = gen(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_is_reorderable_true() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert!(is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_is_reorderable_false() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a");
        assert!(!is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_normalize_inner_join() {
        let expr = parse("SELECT * FROM x INNER JOIN y ON x.a = y.a");
        let result = normalize_joins(expr);
        // INNER should be normalized (removed): check the flag on the AST.
        assert!(joins_of(&result).iter().all(|join| !join.use_inner_keyword));
        let sql = gen(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_normalize_cross_join() {
        let expr = parse("SELECT * FROM x CROSS JOIN y");
        let result = normalize_joins(expr);
        let sql = gen(&result);
        assert!(sql.contains("CROSS"));
    }

    #[test]
    fn test_reorder_joins() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        let result = reorder_joins(expr);
        // Reordering must never drop or duplicate a join.
        assert_eq!(joins_of(&result).len(), 2);
        let sql = gen(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_other_table_names() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a AND x.b = z.b");
        let tables = other_table_names(first_join(&expr));
        assert!(tables.contains("x"));
        assert!(tables.contains("z"));
        // The joined table itself is excluded.
        assert!(!tables.contains("y"));
    }

    #[test]
    fn test_get_join_name_table() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a");
        let name = get_join_name(first_join(&expr));
        assert_eq!(name, Some("y".to_string()));
    }

    #[test]
    fn test_get_join_name_alias() {
        let expr = parse("SELECT * FROM x JOIN y AS t ON x.a = t.a");
        let name = get_join_name(first_join(&expr));
        assert_eq!(name, Some("t".to_string()));
    }

    #[test]
    fn test_optimize_preserves_structure() {
        let expr = parse("SELECT a, b FROM x JOIN y ON x.a = y.a WHERE x.b > 1");
        let result = optimize_joins(expr);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_left_join_not_reorderable() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert!(!is_reorderable(joins_of(&expr)));
    }
}
351}