1use std::collections::HashMap;
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18 crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19 .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31 if let Expression::Select(mut sel) = expr {
32 sel.expressions.extend(columns);
33 Expression::Select(sel)
34 } else {
35 expr
36 }
37}
38
39pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41 expr: Expression,
42 predicate: F,
43) -> Expression {
44 if let Expression::Select(mut sel) = expr {
45 sel.expressions.retain(|e| !predicate(e));
46 Expression::Select(sel)
47 } else {
48 expr
49 }
50}
51
52pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54 if let Expression::Select(mut sel) = expr {
55 sel.distinct = distinct;
56 Expression::Select(sel)
57 } else {
58 expr
59 }
60}
61
62pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72 if let Expression::Select(mut sel) = expr {
73 sel.where_clause = Some(match sel.where_clause.take() {
74 Some(existing) => {
75 let combined = if use_or {
76 Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77 } else {
78 Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79 };
80 Where { this: combined }
81 }
82 None => Where { this: condition },
83 });
84 Expression::Select(sel)
85 } else {
86 expr
87 }
88}
89
90pub fn remove_where(expr: Expression) -> Expression {
92 if let Expression::Select(mut sel) = expr {
93 sel.where_clause = None;
94 Expression::Select(sel)
95 } else {
96 expr
97 }
98}
99
100pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106 if let Expression::Select(mut sel) = expr {
107 sel.limit = Some(Limit {
108 this: Expression::number(limit as i64),
109 percent: false,
110 comments: Vec::new(),
111 });
112 Expression::Select(sel)
113 } else {
114 expr
115 }
116}
117
118pub fn set_offset(expr: Expression, offset: usize) -> Expression {
120 if let Expression::Select(mut sel) = expr {
121 sel.offset = Some(Offset {
122 this: Expression::number(offset as i64),
123 rows: None,
124 });
125 Expression::Select(sel)
126 } else {
127 expr
128 }
129}
130
131pub fn remove_limit_offset(expr: Expression) -> Expression {
133 if let Expression::Select(mut sel) = expr {
134 sel.limit = None;
135 sel.offset = None;
136 Expression::Select(sel)
137 } else {
138 expr
139 }
140}
141
142pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
151 xform(expr, |node| match node {
152 Expression::Column(mut col) => {
153 if let Some(new_name) = mapping.get(&col.name.name) {
154 col.name.name = new_name.clone();
155 }
156 Expression::Column(col)
157 }
158 other => other,
159 })
160}
161
162pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
164 xform(expr, |node| match node {
165 Expression::Table(mut tbl) => {
166 if let Some(new_name) = mapping.get(&tbl.name.name) {
167 tbl.name.name = new_name.clone();
168 }
169 Expression::Table(tbl)
170 }
171 Expression::Column(mut col) => {
172 if let Some(ref mut table_id) = col.table {
173 if let Some(new_name) = mapping.get(&table_id.name) {
174 table_id.name = new_name.clone();
175 }
176 }
177 Expression::Column(col)
178 }
179 other => other,
180 })
181}
182
183pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
187 let table = table_name.to_string();
188 xform(expr, move |node| match node {
189 Expression::Column(mut col) => {
190 if col.table.is_none() {
191 col.table = Some(Identifier::new(&table));
192 }
193 Expression::Column(col)
194 }
195 other => other,
196 })
197}
198
199pub fn replace_nodes<F: Fn(&Expression) -> bool>(
205 expr: Expression,
206 predicate: F,
207 replacement: Expression,
208) -> Expression {
209 xform(expr, |node| {
210 if predicate(&node) {
211 replacement.clone()
212 } else {
213 node
214 }
215 })
216}
217
218pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
220where
221 F: Fn(&Expression) -> bool,
222 R: Fn(Expression) -> Expression,
223{
224 xform(expr, |node| {
225 if predicate(&node) {
226 replacer(node)
227 } else {
228 node
229 }
230 })
231}
232
233pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
239 xform(expr, |node| {
240 if predicate(&node) {
241 Expression::Null(Null)
242 } else {
243 node
244 }
245 })
246}
247
248pub fn get_column_names(expr: &Expression) -> Vec<String> {
254 expr.find_all(|e| matches!(e, Expression::Column(_)))
255 .into_iter()
256 .filter_map(|e| {
257 if let Expression::Column(col) = e {
258 Some(col.name.name.clone())
259 } else {
260 None
261 }
262 })
263 .collect()
264}
265
266pub fn get_table_names(expr: &Expression) -> Vec<String> {
268 expr.find_all(|e| matches!(e, Expression::Table(_)))
269 .into_iter()
270 .filter_map(|e| {
271 if let Expression::Table(tbl) = e {
272 Some(tbl.name.name.clone())
273 } else {
274 None
275 }
276 })
277 .collect()
278}
279
280pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
282 expr.find_all(|e| matches!(e, Expression::Identifier(_)))
283}
284
285pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
287 expr.find_all(|e| {
288 matches!(
289 e,
290 Expression::Function(_) | Expression::AggregateFunction(_)
291 )
292 })
293}
294
295pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
297 expr.find_all(|e| {
298 matches!(
299 e,
300 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
301 )
302 })
303}
304
305pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
307 expr.find_all(|e| matches!(e, Expression::Subquery(_)))
308}
309
310pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
315 expr.find_all(|e| {
316 matches!(
317 e,
318 Expression::AggregateFunction(_)
319 | Expression::Count(_)
320 | Expression::Sum(_)
321 | Expression::Avg(_)
322 | Expression::Min(_)
323 | Expression::Max(_)
324 | Expression::ApproxDistinct(_)
325 | Expression::ArrayAgg(_)
326 | Expression::GroupConcat(_)
327 | Expression::StringAgg(_)
328 | Expression::ListAgg(_)
329 )
330 })
331}
332
333pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
335 expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
336}
337
338pub fn node_count(expr: &Expression) -> usize {
340 expr.dfs().count()
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    // Parse `sql` and return its first statement; panics on parse failure (test-only).
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    // The `use_or = true` branch was previously untested.
    #[test]
    fn test_add_where_combines_with_or() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, true);
        let sql = result.sql();
        assert!(sql.contains("OR"), "Expected OR in: {}", sql);
        assert!(!sql.contains("AND"), "Should not contain AND: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(
            !sql.contains("OFFSET"),
            "Should not contain OFFSET: {}",
            sql
        );
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let tables = crate::traversal::get_tables(&expr);
        let names = get_table_names(&expr);
        assert_eq!(
            names.len(),
            tables.len(),
            "get_table_names and get_tables should find same count"
        );
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(
            !sql.contains("old_name"),
            "Should not contain old_name: {}",
            sql
        );
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    // Turning DISTINCT off must drop it from the generated SQL.
    #[test]
    fn test_set_distinct_false() {
        let expr = parse_one("SELECT DISTINCT a FROM t");
        let result = set_distinct(expr, false);
        let sql = result.sql();
        assert!(!sql.contains("DISTINCT"), "Should not contain DISTINCT: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(
            sql.contains("a, b") || sql.contains("a,b"),
            "Expected a, b in: {}",
            sql
        );
    }

    // remove_select_columns previously had no coverage.
    #[test]
    fn test_remove_select_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = remove_select_columns(expr, |e| {
            matches!(e, Expression::Column(col) if col.name.name == "b")
        });
        let names = get_column_names(&result);
        assert!(names.contains(&"a".to_string()), "Expected a in: {:?}", names);
        assert!(!names.contains(&"b".to_string()), "Should not contain b: {:?}", names);
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    // Columns that already have a qualifier must keep it.
    #[test]
    fn test_qualify_columns_keeps_existing_qualifier() {
        let expr = parse_one("SELECT x.a FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("x.a"), "Expected x.a in: {}", sql);
        assert!(!sql.contains("t.a"), "Should not contain t.a: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        // Smoke test: how COUNT/UPPER are represented (generic Function vs dedicated
        // variants) depends on the parser, so only check that this does not panic.
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        let _ = funcs.len();

        // A query without any function call must yield nothing.
        let plain = parse_one("SELECT a FROM t");
        assert!(
            get_functions(&plain).is_empty(),
            "Expected no functions in plain SELECT"
        );
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(
            aggs.len() >= 2,
            "Expected at least 2 aggregates, got {}",
            aggs.len()
        );
    }
}