polyglot_sql/optimizer/
optimize_joins.rs1use std::collections::{HashMap, HashSet};
11
12use crate::expressions::{BooleanLiteral, Expression, Join, JoinKind};
13use crate::helper::tsort;
14
15pub fn optimize_joins(expression: Expression) -> Expression {
32 let expression = optimize_cross_joins(expression);
33 let expression = reorder_joins(expression);
34 let expression = normalize_joins(expression);
35 expression
36}
37
38fn optimize_cross_joins(expression: Expression) -> Expression {
40 if let Expression::Select(select) = expression {
41 if select.joins.is_empty() || !is_reorderable(&select.joins) {
42 return Expression::Select(select);
43 }
44
45 let mut references: HashMap<String, Vec<usize>> = HashMap::new();
47 let mut cross_joins: Vec<(String, usize)> = Vec::new();
48
49 for (i, join) in select.joins.iter().enumerate() {
50 let tables = other_table_names(join);
51
52 if tables.is_empty() {
53 if let Some(name) = get_join_name(join) {
55 cross_joins.push((name, i));
56 }
57 } else {
58 for table in tables {
60 references.entry(table).or_insert_with(Vec::new).push(i);
61 }
62 }
63 }
64
65 for (name, cross_idx) in &cross_joins {
67 if let Some(ref_indices) = references.get(name) {
68 for &ref_idx in ref_indices {
69 let _ = (cross_idx, ref_idx);
73 }
74 }
75 }
76
77 Expression::Select(select)
78 } else {
79 expression
80 }
81}
82
83pub fn reorder_joins(expression: Expression) -> Expression {
85 if let Expression::Select(mut select) = expression {
86 if select.joins.is_empty() || !is_reorderable(&select.joins) {
87 return Expression::Select(select);
88 }
89
90 let mut joins_by_name: HashMap<String, Join> = HashMap::new();
92 let mut dag: HashMap<String, HashSet<String>> = HashMap::new();
93
94 for join in &select.joins {
95 if let Some(name) = get_join_name(join) {
96 joins_by_name.insert(name.clone(), join.clone());
97 dag.insert(name, other_table_names(join));
98 }
99 }
100
101 if let Ok(sorted) = tsort(dag) {
103 let from_name = select.from.as_ref()
105 .and_then(|f| f.expressions.first())
106 .and_then(|e| get_table_name(e));
107
108 let mut reordered: Vec<Join> = Vec::new();
110 for name in sorted {
111 if Some(&name) != from_name.as_ref() {
112 if let Some(join) = joins_by_name.remove(&name) {
113 reordered.push(join);
114 }
115 }
116 }
117
118 if !reordered.is_empty() && reordered.len() == select.joins.len() {
120 select.joins = reordered;
121 }
122 }
123
124 Expression::Select(select)
125 } else {
126 expression
127 }
128}
129
130pub fn normalize_joins(expression: Expression) -> Expression {
137 if let Expression::Select(mut select) = expression {
138 for join in &mut select.joins {
139 if join.kind == JoinKind::Cross {
141 join.on = None;
142 } else {
143 if join.kind == JoinKind::Inner {
145 join.use_inner_keyword = false;
146 }
147
148 join.use_outer_keyword = false;
150
151 if join.on.is_none() && join.using.is_empty() {
153 join.on = Some(Expression::Boolean(BooleanLiteral { value: true }));
154 }
155 }
156 }
157
158 Expression::Select(select)
159 } else {
160 expression
161 }
162}
163
164pub fn is_reorderable(joins: &[Join]) -> bool {
169 joins.iter().all(|j| {
170 matches!(j.kind, JoinKind::Inner | JoinKind::Cross | JoinKind::Natural)
171 })
172}
173
174fn other_table_names(join: &Join) -> HashSet<String> {
176 let mut tables = HashSet::new();
177
178 if let Some(ref on) = join.on {
179 collect_table_names(on, &mut tables);
180 }
181
182 if let Some(name) = get_join_name(join) {
184 tables.remove(&name);
185 }
186
187 tables
188}
189
190fn collect_table_names(expr: &Expression, tables: &mut HashSet<String>) {
192 match expr {
193 Expression::Column(col) => {
194 if let Some(ref table) = col.table {
195 tables.insert(table.name.clone());
196 }
197 }
198 Expression::And(bin) | Expression::Or(bin) => {
199 collect_table_names(&bin.left, tables);
200 collect_table_names(&bin.right, tables);
201 }
202 Expression::Eq(bin) | Expression::Neq(bin) | Expression::Lt(bin) |
203 Expression::Gt(bin) | Expression::Lte(bin) | Expression::Gte(bin) => {
204 collect_table_names(&bin.left, tables);
205 collect_table_names(&bin.right, tables);
206 }
207 Expression::Paren(p) => {
208 collect_table_names(&p.this, tables);
209 }
210 _ => {}
211 }
212}
213
214fn get_join_name(join: &Join) -> Option<String> {
216 get_table_name(&join.this)
217}
218
219fn get_table_name(expr: &Expression) -> Option<String> {
221 match expr {
222 Expression::Table(table) => {
223 if let Some(ref alias) = table.alias {
224 Some(alias.name.clone())
225 } else {
226 Some(table.name.name.clone())
227 }
228 }
229 Expression::Subquery(subquery) => {
230 subquery.alias.as_ref().map(|a| a.name.clone())
231 }
232 Expression::Alias(alias) => {
233 Some(alias.alias.name.clone())
234 }
235 _ => None,
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text, panicking if generation fails.
    ///
    /// Deliberately not called `gen`: that identifier is a reserved keyword
    /// in the Rust 2024 edition.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Returns the joins of a parsed `SELECT`.
    ///
    /// Panics for any other expression so that a test cannot pass vacuously
    /// when the parser produces something unexpected.
    fn joins_of(expr: &Expression) -> &[Join] {
        match expr {
            Expression::Select(select) => &select.joins[..],
            _ => panic!("expected a SELECT expression"),
        }
    }

    /// Returns the first join of a parsed `SELECT`, panicking if there is none.
    fn first_join(expr: &Expression) -> &Join {
        joins_of(expr)
            .first()
            .expect("expected the SELECT to contain at least one join")
    }

    #[test]
    fn test_optimize_joins_simple() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a");
        let result = optimize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_is_reorderable_true() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert!(is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_is_reorderable_false() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a");
        assert!(!is_reorderable(joins_of(&expr)));
    }

    #[test]
    fn test_normalize_inner_join() {
        let expr = parse("SELECT * FROM x INNER JOIN y ON x.a = y.a");
        let result = normalize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_normalize_cross_join() {
        let expr = parse("SELECT * FROM x CROSS JOIN y");
        let result = normalize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("CROSS"));
    }

    #[test]
    fn test_reorder_joins() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        let result = reorder_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("JOIN"));
    }

    #[test]
    fn test_other_table_names() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a AND x.b = z.b");
        let tables = other_table_names(first_join(&expr));
        assert!(tables.contains("x"));
        assert!(tables.contains("z"));
        // The joined table itself must never be reported as a dependency.
        assert!(!tables.contains("y"));
    }

    #[test]
    fn test_get_join_name_table() {
        let expr = parse("SELECT * FROM x JOIN y ON x.a = y.a");
        let name = get_join_name(first_join(&expr));
        assert_eq!(name, Some("y".to_string()));
    }

    #[test]
    fn test_get_join_name_alias() {
        let expr = parse("SELECT * FROM x JOIN y AS t ON x.a = t.a");
        let name = get_join_name(first_join(&expr));
        assert_eq!(name, Some("t".to_string()));
    }

    #[test]
    fn test_optimize_preserves_structure() {
        let expr = parse("SELECT a, b FROM x JOIN y ON x.a = y.a WHERE x.b > 1");
        let result = optimize_joins(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_left_join_not_reorderable() {
        let expr = parse("SELECT * FROM x LEFT JOIN y ON x.a = y.a JOIN z ON y.a = z.a");
        assert!(!is_reorderable(joins_of(&expr)));
    }
}