// polyglot_sql/optimizer/optimize_joins.rs

//! Join Optimization Module
//!
//! This module provides functionality for optimizing JOIN operations:
//! - Removing cross joins when possible (the predicate-moving step is not
//!   implemented yet, so cross joins are currently left in place)
//! - Reordering joins based on predicate dependencies
//! - Normalizing join syntax (removing unnecessary INNER/OUTER keywords)
//!
//! Ported from sqlglot's `optimizer/optimize_joins.py`.

10use std::collections::{HashMap, HashSet};
11
12use crate::expressions::{BooleanLiteral, Expression, Join, JoinKind};
13use crate::helper::tsort;
14
15/// Optimize joins by removing cross joins and reordering based on dependencies.
16///
17/// # Example
18///
19/// ```sql
20/// -- Before:
21/// SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a
22/// -- After:
23/// SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a
24/// ```
25///
26/// # Arguments
27/// * `expression` - The expression to optimize
28///
29/// # Returns
30/// The optimized expression with improved join order
31pub fn optimize_joins(expression: Expression) -> Expression {
32    let expression = optimize_cross_joins(expression);
33    let expression = reorder_joins(expression);
34    let expression = normalize_joins(expression);
35    expression
36}
37
38/// Optimize cross joins by moving predicates from later joins
39fn optimize_cross_joins(expression: Expression) -> Expression {
40    if let Expression::Select(select) = expression {
41        if select.joins.is_empty() || !is_reorderable(&select.joins) {
42            return Expression::Select(select);
43        }
44
45        // Build reference map: table -> list of joins that reference it
46        let mut references: HashMap<String, Vec<usize>> = HashMap::new();
47        let mut cross_joins: Vec<(String, usize)> = Vec::new();
48
49        for (i, join) in select.joins.iter().enumerate() {
50            let tables = other_table_names(join);
51
52            if tables.is_empty() {
53                // This is a cross join
54                if let Some(name) = get_join_name(join) {
55                    cross_joins.push((name, i));
56                }
57            } else {
58                // This join has ON predicates referencing tables
59                for table in tables {
60                    references.entry(table).or_insert_with(Vec::new).push(i);
61                }
62            }
63        }
64
65        // Move predicates from referencing joins to cross joins
66        for (name, cross_idx) in &cross_joins {
67            if let Some(ref_indices) = references.get(name) {
68                for &ref_idx in ref_indices {
69                    // In a full implementation, we would move predicates
70                    // that reference the cross join table from the referencing
71                    // join to the cross join
72                    let _ = (cross_idx, ref_idx);
73                }
74            }
75        }
76
77        Expression::Select(select)
78    } else {
79        expression
80    }
81}
82
83/// Reorder joins by topological sort based on predicate dependencies.
84pub fn reorder_joins(expression: Expression) -> Expression {
85    if let Expression::Select(mut select) = expression {
86        if select.joins.is_empty() || !is_reorderable(&select.joins) {
87            return Expression::Select(select);
88        }
89
90        // Build dependency graph
91        let mut joins_by_name: HashMap<String, Join> = HashMap::new();
92        let mut dag: HashMap<String, HashSet<String>> = HashMap::new();
93
94        for join in &select.joins {
95            if let Some(name) = get_join_name(join) {
96                joins_by_name.insert(name.clone(), join.clone());
97                dag.insert(name, other_table_names(join));
98            }
99        }
100
101        // Get topologically sorted order
102        if let Ok(sorted) = tsort(dag) {
103            // Get the FROM table name (to exclude from join reordering)
104            let from_name = select
105                .from
106                .as_ref()
107                .and_then(|f| f.expressions.first())
108                .and_then(|e| get_table_name(e));
109
110            // Reorder joins
111            let mut reordered: Vec<Join> = Vec::new();
112            for name in sorted {
113                if Some(&name) != from_name.as_ref() {
114                    if let Some(join) = joins_by_name.remove(&name) {
115                        reordered.push(join);
116                    }
117                }
118            }
119
120            // If reordering succeeded, use new order; otherwise keep original
121            if !reordered.is_empty() && reordered.len() == select.joins.len() {
122                select.joins = reordered;
123            }
124        }
125
126        Expression::Select(select)
127    } else {
128        expression
129    }
130}
131
132/// Normalize join syntax by removing unnecessary keywords.
133///
134/// - Remove INNER keyword (it's the default for joins with ON clause)
135/// - Remove OUTER keyword (only LEFT/RIGHT/FULL matter)
136/// - Add CROSS keyword to joins without any join type
137/// - Add TRUE to joins without ON or USING clause
138pub fn normalize_joins(expression: Expression) -> Expression {
139    if let Expression::Select(mut select) = expression {
140        for join in &mut select.joins {
141            // For CROSS joins, clear the ON clause
142            if join.kind == JoinKind::Cross {
143                join.on = None;
144            } else {
145                // Remove INNER keyword flag (INNER is the default)
146                if join.kind == JoinKind::Inner {
147                    join.use_inner_keyword = false;
148                }
149
150                // Remove OUTER keyword flag
151                join.use_outer_keyword = false;
152
153                // If no ON or USING, add ON TRUE
154                if join.on.is_none() && join.using.is_empty() {
155                    join.on = Some(Expression::Boolean(BooleanLiteral { value: true }));
156                }
157            }
158        }
159
160        Expression::Select(select)
161    } else {
162        expression
163    }
164}
165
166/// Check if joins can be reordered without changing query semantics.
167///
168/// Joins with a side (LEFT, RIGHT, FULL) cannot be reordered,
169/// as the order affects which rows are included.
170pub fn is_reorderable(joins: &[Join]) -> bool {
171    joins.iter().all(|j| {
172        matches!(
173            j.kind,
174            JoinKind::Inner | JoinKind::Cross | JoinKind::Natural
175        )
176    })
177}
178
179/// Get table names referenced in a join's ON clause (excluding the join's own table).
180fn other_table_names(join: &Join) -> HashSet<String> {
181    let mut tables = HashSet::new();
182
183    if let Some(ref on) = join.on {
184        collect_table_names(on, &mut tables);
185    }
186
187    // Remove the join's own table name
188    if let Some(name) = get_join_name(join) {
189        tables.remove(&name);
190    }
191
192    tables
193}
194
195/// Collect all table names referenced in an expression.
196fn collect_table_names(expr: &Expression, tables: &mut HashSet<String>) {
197    match expr {
198        Expression::Column(col) => {
199            if let Some(ref table) = col.table {
200                tables.insert(table.name.clone());
201            }
202        }
203        Expression::And(bin) | Expression::Or(bin) => {
204            collect_table_names(&bin.left, tables);
205            collect_table_names(&bin.right, tables);
206        }
207        Expression::Eq(bin)
208        | Expression::Neq(bin)
209        | Expression::Lt(bin)
210        | Expression::Gt(bin)
211        | Expression::Lte(bin)
212        | Expression::Gte(bin) => {
213            collect_table_names(&bin.left, tables);
214            collect_table_names(&bin.right, tables);
215        }
216        Expression::Paren(p) => {
217            collect_table_names(&p.this, tables);
218        }
219        _ => {}
220    }
221}
222
223/// Get the alias or table name from a join.
224fn get_join_name(join: &Join) -> Option<String> {
225    get_table_name(&join.this)
226}
227
228/// Get the alias or name from a table expression.
229fn get_table_name(expr: &Expression) -> Option<String> {
230    match expr {
231        Expression::Table(table) => {
232            if let Some(ref alias) = table.alias {
233                Some(alias.name.clone())
234            } else {
235                Some(table.name.name.clone())
236            }
237        }
238        Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
239        Expression::Alias(alias) => Some(alias.alias.name.clone()),
240        _ => None,
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Joins of a parsed SELECT. Panics (failing the test) for anything else, so
    /// assertions can never be skipped silently.
    fn joins_of(expr: &Expression) -> &[Join] {
        match expr {
            Expression::Select(select) => select.joins.as_slice(),
            _ => panic!("expected a SELECT expression"),
        }
    }

    /// Name of every join, in order.
    fn join_names(expr: &Expression) -> Vec<Option<String>> {
        joins_of(expr).iter().map(get_join_name).collect()
    }

    #[test]
    fn test_optimize_joins_simple() {
        let result = optimize_joins(parse("SELECT * FROM x JOIN y ON x.a = y.a"));
        assert!(gen(&result).contains("JOIN"));
        assert_eq!(join_names(&result), vec![Some("y".to_string())]);
    }

    #[test]
    fn test_is_reorderable_true() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert_eq!(joins_of(&expr).len(), 2);
        assert!(is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_is_reorderable_false() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a");
        assert!(!is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_is_reorderable_empty() {
        assert!(is_reorderable(&[]));
    }

    #[test]
    fn test_normalize_inner_join() {
        let result = normalize_joins(parse("SELECT * FROM x INNER JOIN y ON x.a = y.a"));
        let joins = joins_of(&result);
        assert_eq!(joins.len(), 1);
        // INNER is normalized away; the ON condition is kept.
        assert!(!joins[0].use_inner_keyword);
        assert!(joins[0].on.is_some());
        assert!(gen(&result).contains("JOIN"));
    }

    #[test]
    fn test_normalize_cross_join() {
        let result = normalize_joins(parse("SELECT * FROM x CROSS JOIN y"));
        let joins = joins_of(&result);
        assert_eq!(joins.len(), 1);
        assert!(joins[0].kind == JoinKind::Cross);
        assert!(joins[0].on.is_none());
        assert!(gen(&result).contains("CROSS"));
    }

    #[test]
    fn test_reorder_joins() {
        let result =
            reorder_joins(parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a"));
        assert!(gen(&result).contains("JOIN"));
        assert_eq!(
            join_names(&result),
            vec![Some("y".to_string()), Some("z".to_string())]
        );
    }

    #[test]
    fn test_reorder_joins_moves_dependency_first() {
        // z's predicate references y, so y has to be joined before z.
        let result =
            reorder_joins(parse("SELECT * FROM x JOIN z ON y.a = z.a JOIN y ON x.a = y.a"));
        assert_eq!(
            join_names(&result),
            vec![Some("y".to_string()), Some("z".to_string())]
        );
    }

    #[test]
    fn test_reorder_joins_keeps_sided_joins() {
        let sql = "SELECT * FROM x LEFT JOIN z ON y.a = z.a JOIN y ON x.a = y.a";
        let expected = gen(&parse(sql));
        assert_eq!(gen(&reorder_joins(parse(sql))), expected);
    }

    #[test]
    fn test_other_table_names() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a AND x.b = z.b");
        let tables = other_table_names(&joins_of(&expr)[0]);
        // The join's own table (y) is excluded.
        let expected: HashSet<String> = ["x", "z"].iter().map(|s| s.to_string()).collect();
        assert_eq!(tables, expected);
    }

    #[test]
    fn test_get_join_name_table() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a");
        assert_eq!(join_names(&expr), vec![Some("y".to_string())]);
    }

    #[test]
    fn test_get_join_name_alias() {
        let expr = parse("SELECT * FROM x JOIN y AS t ON x.a = t.a");
        assert_eq!(join_names(&expr), vec![Some("t".to_string())]);
    }

    #[test]
    fn test_optimize_preserves_structure() {
        let expr = parse("SELECT a, b FROM x JOIN y ON x.a = y.a WHERE x.b > 1");
        let result = optimize_joins(expr);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_left_join_not_reorderable() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert_eq!(joins_of(&expr).len(), 2);
        assert!(!is_reorderable(joins_of(&expr)));
    }
}