1use std::collections::{HashMap, HashSet};
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18 crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19 .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31 if let Expression::Select(mut sel) = expr {
32 sel.expressions.extend(columns);
33 Expression::Select(sel)
34 } else {
35 expr
36 }
37}
38
39pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41 expr: Expression,
42 predicate: F,
43) -> Expression {
44 if let Expression::Select(mut sel) = expr {
45 sel.expressions.retain(|e| !predicate(e));
46 Expression::Select(sel)
47 } else {
48 expr
49 }
50}
51
52pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54 if let Expression::Select(mut sel) = expr {
55 sel.distinct = distinct;
56 Expression::Select(sel)
57 } else {
58 expr
59 }
60}
61
62pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72 if let Expression::Select(mut sel) = expr {
73 sel.where_clause = Some(match sel.where_clause.take() {
74 Some(existing) => {
75 let combined = if use_or {
76 Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77 } else {
78 Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79 };
80 Where { this: combined }
81 }
82 None => Where { this: condition },
83 });
84 Expression::Select(sel)
85 } else {
86 expr
87 }
88}
89
90pub fn remove_where(expr: Expression) -> Expression {
92 if let Expression::Select(mut sel) = expr {
93 sel.where_clause = None;
94 Expression::Select(sel)
95 } else {
96 expr
97 }
98}
99
100pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106 if let Expression::Select(mut sel) = expr {
107 sel.limit = Some(Limit {
108 this: Expression::number(limit as i64),
109 percent: false,
110 comments: Vec::new(),
111 });
112 Expression::Select(sel)
113 } else {
114 expr
115 }
116}
117
118pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120 if let Expression::Select(mut sel) = expr {
121 sel.offset = Some(Offset {
122 this: Expression::number(offset as i64),
123 rows: None,
124 });
125 Expression::Select(sel)
126 } else {
127 expr
128 }
129}
130
131pub fn remove_limit_offset(expr: Expression) -> Expression {
133 if let Expression::Select(mut sel) = expr {
134 sel.limit = None;
135 sel.offset = None;
136 Expression::Select(sel)
137 } else {
138 expr
139 }
140}
141
142pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151 xform(expr, |node| match node {
152 Expression::Column(mut col) => {
153 if let Some(new_name) = mapping.get(&col.name.name) {
154 col.name.name = new_name.clone();
155 }
156 Expression::Column(col)
157 }
158 other => other,
159 })
160}
161
162pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
164 xform(expr, |node| match node {
165 Expression::Table(mut tbl) => {
166 if let Some(new_name) = mapping.get(&tbl.name.name) {
167 tbl.name.name = new_name.clone();
168 }
169 Expression::Table(tbl)
170 }
171 Expression::Column(mut col) => {
172 if let Some(ref mut table_id) = col.table {
173 if let Some(new_name) = mapping.get(&table_id.name) {
174 table_id.name = new_name.clone();
175 }
176 }
177 Expression::Column(col)
178 }
179 other => other,
180 })
181}
182
183pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
187 let table = table_name.to_string();
188 xform(expr, move |node| match node {
189 Expression::Column(mut col) => {
190 if col.table.is_none() {
191 col.table = Some(Identifier::new(&table));
192 }
193 Expression::Column(col)
194 }
195 other => other,
196 })
197}
198
199pub fn replace_nodes<F: Fn(&Expression) -> bool>(
205 expr: Expression,
206 predicate: F,
207 replacement: Expression,
208) -> Expression {
209 xform(expr, |node| {
210 if predicate(&node) {
211 replacement.clone()
212 } else {
213 node
214 }
215 })
216}
217
218pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
220where
221 F: Fn(&Expression) -> bool,
222 R: Fn(Expression) -> Expression,
223{
224 xform(expr, |node| {
225 if predicate(&node) {
226 replacer(node)
227 } else {
228 node
229 }
230 })
231}
232
233pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
239 xform(expr, |node| {
240 if predicate(&node) {
241 Expression::Null(Null)
242 } else {
243 node
244 }
245 })
246}
247
248pub fn get_column_names(expr: &Expression) -> Vec<String> {
254 expr.find_all(|e| matches!(e, Expression::Column(_)))
255 .into_iter()
256 .filter_map(|e| {
257 if let Expression::Column(col) = e {
258 Some(col.name.name.clone())
259 } else {
260 None
261 }
262 })
263 .collect()
264}
265
266pub fn get_table_names(expr: &Expression) -> Vec<String> {
268 fn collect_cte_aliases(with_clause: &With, aliases: &mut HashSet<String>) {
269 for cte in &with_clause.ctes {
270 aliases.insert(cte.alias.name.clone());
271 }
272 }
273
274 fn push_table_ref_name(
275 table: &TableRef,
276 cte_aliases: &HashSet<String>,
277 names: &mut Vec<String>,
278 ) {
279 let name = table.name.name.clone();
280 if !name.is_empty() && !cte_aliases.contains(&name) {
281 names.push(name);
282 }
283 }
284
285 let mut cte_aliases: HashSet<String> = HashSet::new();
286 for node in expr.dfs() {
287 match node {
288 Expression::Select(select) => {
289 if let Some(with) = &select.with {
290 collect_cte_aliases(with, &mut cte_aliases);
291 }
292 }
293 Expression::Insert(insert) => {
294 if let Some(with) = &insert.with {
295 collect_cte_aliases(with, &mut cte_aliases);
296 }
297 }
298 Expression::Update(update) => {
299 if let Some(with) = &update.with {
300 collect_cte_aliases(with, &mut cte_aliases);
301 }
302 }
303 Expression::Delete(delete) => {
304 if let Some(with) = &delete.with {
305 collect_cte_aliases(with, &mut cte_aliases);
306 }
307 }
308 Expression::Union(union) => {
309 if let Some(with) = &union.with {
310 collect_cte_aliases(with, &mut cte_aliases);
311 }
312 }
313 Expression::Intersect(intersect) => {
314 if let Some(with) = &intersect.with {
315 collect_cte_aliases(with, &mut cte_aliases);
316 }
317 }
318 Expression::Except(except) => {
319 if let Some(with) = &except.with {
320 collect_cte_aliases(with, &mut cte_aliases);
321 }
322 }
323 Expression::Merge(merge) => {
324 if let Some(with_) = &merge.with_ {
325 if let Expression::With(with_clause) = with_.as_ref() {
326 collect_cte_aliases(with_clause, &mut cte_aliases);
327 }
328 }
329 }
330 _ => {}
331 }
332 }
333
334 let mut names = Vec::new();
335 for node in expr.dfs() {
336 match node {
337 Expression::Table(tbl) => {
338 let name = tbl.name.name.clone();
339 if !name.is_empty() && !cte_aliases.contains(&name) {
340 names.push(name);
341 }
342 }
343 Expression::Insert(insert) => {
344 push_table_ref_name(&insert.table, &cte_aliases, &mut names);
345 }
346 Expression::Update(update) => {
347 push_table_ref_name(&update.table, &cte_aliases, &mut names);
348 for table in &update.extra_tables {
349 push_table_ref_name(table, &cte_aliases, &mut names);
350 }
351 }
352 Expression::Delete(delete) => {
353 push_table_ref_name(&delete.table, &cte_aliases, &mut names);
354 for table in &delete.using {
355 push_table_ref_name(table, &cte_aliases, &mut names);
356 }
357 for table in &delete.tables {
358 push_table_ref_name(table, &cte_aliases, &mut names);
359 }
360 }
361 _ => {}
362 }
363 }
364
365 names
366}
367
368pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
370 expr.find_all(|e| matches!(e, Expression::Identifier(_)))
371}
372
373pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
375 expr.find_all(|e| {
376 matches!(
377 e,
378 Expression::Function(_) | Expression::AggregateFunction(_)
379 )
380 })
381}
382
383pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
385 expr.find_all(|e| {
386 matches!(
387 e,
388 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
389 )
390 })
391}
392
393pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
395 expr.find_all(|e| matches!(e, Expression::Subquery(_)))
396}
397
398pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
403 expr.find_all(|e| {
404 matches!(
405 e,
406 Expression::AggregateFunction(_)
407 | Expression::Count(_)
408 | Expression::Sum(_)
409 | Expression::Avg(_)
410 | Expression::Min(_)
411 | Expression::Max(_)
412 | Expression::ApproxDistinct(_)
413 | Expression::ArrayAgg(_)
414 | Expression::GroupConcat(_)
415 | Expression::StringAgg(_)
416 | Expression::ListAgg(_)
417 )
418 })
419}
420
421pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
423 expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
424}
425
426pub fn node_count(expr: &Expression) -> usize {
428 expr.dfs().count()
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and returns its first statement; panics if parsing fails
    /// or no statement was produced.
    fn parse_one(sql: &str) -> Expression {
        let mut statements = Parser::parse_sql(sql).unwrap();
        statements.remove(0)
    }

    /// Builds the comparison `<column> = <value>` used by the WHERE tests.
    fn eq_filter(column: &str, value: i64) -> Expression {
        Expression::Eq(Box::new(BinaryOp::new(
            Expression::column(column),
            Expression::number(value),
        )))
    }

    #[test]
    fn test_add_where() {
        let query = parse_one("SELECT a FROM t");
        let sql = add_where(query, eq_filter("b", 1), false).sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let query = parse_one("SELECT a FROM t WHERE x = 1");
        let sql = add_where(query, eq_filter("y", 2), false).sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let query = parse_one("SELECT a FROM t WHERE x = 1");
        let sql = remove_where(query).sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let query = parse_one("SELECT a FROM t");
        let sql = set_limit(query, 10).sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let query = parse_one("SELECT a FROM t");
        let sql = set_offset(query, 5).sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let query = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let sql = remove_limit_offset(query).sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let query = parse_one("SELECT a, b, c FROM t");
        let found = get_column_names(&query);
        for wanted in ["a", "b", "c"].iter() {
            assert!(found.iter().any(|name| name == wanted));
        }
    }

    #[test]
    fn test_get_table_names() {
        let query = parse_one("SELECT a FROM users");
        assert_eq!(get_table_names(&query), vec!["users".to_string()]);
    }

    #[test]
    fn test_get_table_names_excludes_cte_aliases() {
        let query = parse_one(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte JOIN orders o ON cte.id = o.id",
        );
        let found = get_table_names(&query);
        let mentions = |wanted: &str| found.iter().any(|name| name == wanted);
        assert!(mentions("users"));
        assert!(mentions("orders"));
        assert!(!mentions("cte"));
    }

    #[test]
    fn test_get_table_names_includes_dml_targets() {
        let mentions =
            |found: &[String], wanted: &str| found.iter().any(|name| name == wanted);

        let insert_names = get_table_names(&parse_one("INSERT INTO users (id) VALUES (1)"));
        assert!(mentions(&insert_names, "users"));

        let update_names = get_table_names(&parse_one(
            "UPDATE users SET name = 'x' FROM accounts WHERE users.id = accounts.id",
        ));
        assert!(mentions(&update_names, "users"));
        assert!(mentions(&update_names, "accounts"));

        let delete_names = get_table_names(&parse_one(
            "DELETE FROM users USING accounts WHERE users.id = accounts.id",
        ));
        assert!(mentions(&delete_names, "users"));
        assert!(mentions(&delete_names, "accounts"));
    }

    #[test]
    fn test_node_count() {
        let query = parse_one("SELECT a FROM t");
        let total = node_count(&query);
        assert!(total > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let query = parse_one("SELECT old_name FROM t");
        let renames: HashMap<String, String> =
            std::iter::once((String::from("old_name"), String::from("new_name"))).collect();
        let sql = rename_columns(query, &renames).sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let query = parse_one("SELECT a FROM old_table");
        let renames: HashMap<String, String> =
            std::iter::once((String::from("old_table"), String::from("new_table"))).collect();
        let sql = rename_tables(query, &renames).sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let query = parse_one("SELECT a FROM t");
        let sql = set_distinct(query, true).sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let query = parse_one("SELECT a FROM t");
        let sql = add_select_columns(query, vec![Expression::column("b")]).sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    #[test]
    fn test_qualify_columns() {
        let query = parse_one("SELECT a, b FROM t");
        let sql = qualify_columns(query, "t").sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        // Smoke test only: it checks that the call completes. Which variants
        // the parser emits for COUNT/UPPER is not visible from this module,
        // so no count is asserted.
        let query = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let found = get_functions(&query);
        let _ = found.len();
    }

    #[test]
    fn test_get_aggregate_functions() {
        let query = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let found = get_aggregate_functions(&query);
        assert!(
            found.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            found.len()
        );
    }
}