polyglot_sql/optimizer/
optimize_joins.rs1use std::collections::{HashMap, HashSet};
11
12use crate::expressions::{BooleanLiteral, Expression, Join, JoinKind};
13use crate::helper::tsort;
14
15pub fn optimize_joins(expression: Expression) -> Expression {
32 let expression = optimize_cross_joins(expression);
33 let expression = reorder_joins(expression);
34 let expression = normalize_joins(expression);
35 expression
36}
37
38fn optimize_cross_joins(expression: Expression) -> Expression {
40 if let Expression::Select(select) = expression {
41 if select.joins.is_empty() || !is_reorderable(&select.joins) {
42 return Expression::Select(select);
43 }
44
45 let mut references: HashMap<String, Vec<usize>> = HashMap::new();
47 let mut cross_joins: Vec<(String, usize)> = Vec::new();
48
49 for (i, join) in select.joins.iter().enumerate() {
50 let tables = other_table_names(join);
51
52 if tables.is_empty() {
53 if let Some(name) = get_join_name(join) {
55 cross_joins.push((name, i));
56 }
57 } else {
58 for table in tables {
60 references.entry(table).or_insert_with(Vec::new).push(i);
61 }
62 }
63 }
64
65 for (name, cross_idx) in &cross_joins {
67 if let Some(ref_indices) = references.get(name) {
68 for &ref_idx in ref_indices {
69 let _ = (cross_idx, ref_idx);
73 }
74 }
75 }
76
77 Expression::Select(select)
78 } else {
79 expression
80 }
81}
82
83pub fn reorder_joins(expression: Expression) -> Expression {
85 if let Expression::Select(mut select) = expression {
86 if select.joins.is_empty() || !is_reorderable(&select.joins) {
87 return Expression::Select(select);
88 }
89
90 let mut joins_by_name: HashMap<String, Join> = HashMap::new();
92 let mut dag: HashMap<String, HashSet<String>> = HashMap::new();
93
94 for join in &select.joins {
95 if let Some(name) = get_join_name(join) {
96 joins_by_name.insert(name.clone(), join.clone());
97 dag.insert(name, other_table_names(join));
98 }
99 }
100
101 if let Ok(sorted) = tsort(dag) {
103 let from_name = select
105 .from
106 .as_ref()
107 .and_then(|f| f.expressions.first())
108 .and_then(|e| get_table_name(e));
109
110 let mut reordered: Vec<Join> = Vec::new();
112 for name in sorted {
113 if Some(&name) != from_name.as_ref() {
114 if let Some(join) = joins_by_name.remove(&name) {
115 reordered.push(join);
116 }
117 }
118 }
119
120 if !reordered.is_empty() && reordered.len() == select.joins.len() {
122 select.joins = reordered;
123 }
124 }
125
126 Expression::Select(select)
127 } else {
128 expression
129 }
130}
131
132pub fn normalize_joins(expression: Expression) -> Expression {
139 if let Expression::Select(mut select) = expression {
140 for join in &mut select.joins {
141 if join.kind == JoinKind::Cross {
143 join.on = None;
144 } else {
145 if join.kind == JoinKind::Inner {
147 join.use_inner_keyword = false;
148 }
149
150 join.use_outer_keyword = false;
152
153 if join.on.is_none() && join.using.is_empty() {
155 join.on = Some(Expression::Boolean(BooleanLiteral { value: true }));
156 }
157 }
158 }
159
160 Expression::Select(select)
161 } else {
162 expression
163 }
164}
165
166pub fn is_reorderable(joins: &[Join]) -> bool {
171 joins.iter().all(|j| {
172 matches!(
173 j.kind,
174 JoinKind::Inner | JoinKind::Cross | JoinKind::Natural
175 )
176 })
177}
178
179fn other_table_names(join: &Join) -> HashSet<String> {
181 let mut tables = HashSet::new();
182
183 if let Some(ref on) = join.on {
184 collect_table_names(on, &mut tables);
185 }
186
187 if let Some(name) = get_join_name(join) {
189 tables.remove(&name);
190 }
191
192 tables
193}
194
195fn collect_table_names(expr: &Expression, tables: &mut HashSet<String>) {
197 match expr {
198 Expression::Column(col) => {
199 if let Some(ref table) = col.table {
200 tables.insert(table.name.clone());
201 }
202 }
203 Expression::And(bin) | Expression::Or(bin) => {
204 collect_table_names(&bin.left, tables);
205 collect_table_names(&bin.right, tables);
206 }
207 Expression::Eq(bin)
208 | Expression::Neq(bin)
209 | Expression::Lt(bin)
210 | Expression::Gt(bin)
211 | Expression::Lte(bin)
212 | Expression::Gte(bin) => {
213 collect_table_names(&bin.left, tables);
214 collect_table_names(&bin.right, tables);
215 }
216 Expression::Paren(p) => {
217 collect_table_names(&p.this, tables);
218 }
219 _ => {}
220 }
221}
222
223fn get_join_name(join: &Join) -> Option<String> {
225 get_table_name(&join.this)
226}
227
228fn get_table_name(expr: &Expression) -> Option<String> {
230 match expr {
231 Expression::Table(table) => {
232 if let Some(ref alias) = table.alias {
233 Some(alias.name.clone())
234 } else {
235 Some(table.name.name.clone())
236 }
237 }
238 Expression::Subquery(subquery) => subquery.alias.as_ref().map(|a| a.name.clone()),
239 Expression::Alias(alias) => Some(alias.alias.name.clone()),
240 _ => None,
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    /// (Not named `gen`: that identifier is reserved in Rust 2024.)
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Returns the joins of a parsed SELECT. It panics on any other statement
    /// kind, so a test can never pass by silently skipping its assertions.
    fn joins_of(expr: &Expression) -> &[Join] {
        match expr {
            Expression::Select(select) => &select.joins,
            _ => panic!("expected a SELECT statement"),
        }
    }

    /// Returns the first join of a parsed SELECT, panicking if there is none.
    fn first_join(expr: &Expression) -> &Join {
        joins_of(expr).first().expect("expected at least one join")
    }

    #[test]
    fn test_optimize_joins_simple() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a");
        let result = optimize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_is_reorderable_true() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert!(is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_is_reorderable_false() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a");
        assert!(!is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_normalize_inner_join() {
        let expr = parse("SELECT * FROM x INNER JOIN y ON x.a = y.a");
        let result = normalize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_normalize_cross_join() {
        let expr = parse("SELECT * FROM x CROSS JOIN y");
        let result = normalize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("CROSS"));
    }

    #[test]
    fn test_reorder_joins() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        let result = reorder_joins(expr);
        // Reordering may permute joins but must never drop any.
        assert_eq!(joins_of(&result).len(), 2);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_other_table_names() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a AND x.b = z.b");
        let tables = other_table_names(first_join(&expr));
        assert!(tables.contains("x"));
        assert!(tables.contains("z"));
        // The join's own table is not a dependency of itself.
        assert!(!tables.contains("y"));
    }

    #[test]
    fn test_get_join_name_table() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a");
        assert_eq!(get_join_name(first_join(&expr)), Some("y".to_string()));
    }

    #[test]
    fn test_get_join_name_alias() {
        let expr = parse("SELECT * FROM x JOIN y AS t ON x.a = t.a");
        assert_eq!(get_join_name(first_join(&expr)), Some("t".to_string()));
    }

    #[test]
    fn test_optimize_preserves_structure() {
        let expr = parse("SELECT a, b FROM x JOIN y ON x.a = y.a WHERE x.b > 1");
        let result = optimize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_left_join_not_reorderable() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert!(!is_reorderable(joins_of(&expr)));
    }
}