1use std::collections::HashMap;
11
12use crate::expressions::*;
13use crate::traversal::ExpressionWalk;
14
15fn xform<F: Fn(Expression) -> Expression>(expr: Expression, fun: F) -> Expression {
18 crate::traversal::transform(expr, &|node| Ok(Some(fun(node))))
19 .unwrap_or_else(|_| Expression::Null(Null))
20}
21
22pub fn add_select_columns(expr: Expression, columns: Vec<Expression>) -> Expression {
31 if let Expression::Select(mut sel) = expr {
32 sel.expressions.extend(columns);
33 Expression::Select(sel)
34 } else {
35 expr
36 }
37}
38
39pub fn remove_select_columns<F: Fn(&Expression) -> bool>(
41 expr: Expression,
42 predicate: F,
43) -> Expression {
44 if let Expression::Select(mut sel) = expr {
45 sel.expressions.retain(|e| !predicate(e));
46 Expression::Select(sel)
47 } else {
48 expr
49 }
50}
51
52pub fn set_distinct(expr: Expression, distinct: bool) -> Expression {
54 if let Expression::Select(mut sel) = expr {
55 sel.distinct = distinct;
56 Expression::Select(sel)
57 } else {
58 expr
59 }
60}
61
62pub fn add_where(expr: Expression, condition: Expression, use_or: bool) -> Expression {
72 if let Expression::Select(mut sel) = expr {
73 sel.where_clause = Some(match sel.where_clause.take() {
74 Some(existing) => {
75 let combined = if use_or {
76 Expression::Or(Box::new(BinaryOp::new(existing.this, condition)))
77 } else {
78 Expression::And(Box::new(BinaryOp::new(existing.this, condition)))
79 };
80 Where { this: combined }
81 }
82 None => Where { this: condition },
83 });
84 Expression::Select(sel)
85 } else {
86 expr
87 }
88}
89
90pub fn remove_where(expr: Expression) -> Expression {
92 if let Expression::Select(mut sel) = expr {
93 sel.where_clause = None;
94 Expression::Select(sel)
95 } else {
96 expr
97 }
98}
99
100pub fn set_limit(expr: Expression, limit: usize) -> Expression {
106 if let Expression::Select(mut sel) = expr {
107 sel.limit = Some(Limit {
108 this: Expression::number(limit as i64),
109 percent: false,
110 });
111 Expression::Select(sel)
112 } else {
113 expr
114 }
115}
116
117pub fn set_offset(expr: Expression, offset: usize) -> Expression {
119 if let Expression::Select(mut sel) = expr {
120 sel.offset = Some(Offset {
121 this: Expression::number(offset as i64),
122 rows: None,
123 });
124 Expression::Select(sel)
125 } else {
126 expr
127 }
128}
129
130pub fn remove_limit_offset(expr: Expression) -> Expression {
132 if let Expression::Select(mut sel) = expr {
133 sel.limit = None;
134 sel.offset = None;
135 Expression::Select(sel)
136 } else {
137 expr
138 }
139}
140
141pub fn rename_columns(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
150 xform(expr, |node| match node {
151 Expression::Column(mut col) => {
152 if let Some(new_name) = mapping.get(&col.name.name) {
153 col.name.name = new_name.clone();
154 }
155 Expression::Column(col)
156 }
157 other => other,
158 })
159}
160
161pub fn rename_tables(expr: Expression, mapping: &HashMap<String, String>) -> Expression {
163 xform(expr, |node| match node {
164 Expression::Table(mut tbl) => {
165 if let Some(new_name) = mapping.get(&tbl.name.name) {
166 tbl.name.name = new_name.clone();
167 }
168 Expression::Table(tbl)
169 }
170 Expression::Column(mut col) => {
171 if let Some(ref mut table_id) = col.table {
172 if let Some(new_name) = mapping.get(&table_id.name) {
173 table_id.name = new_name.clone();
174 }
175 }
176 Expression::Column(col)
177 }
178 other => other,
179 })
180}
181
182pub fn qualify_columns(expr: Expression, table_name: &str) -> Expression {
186 let table = table_name.to_string();
187 xform(expr, move |node| match node {
188 Expression::Column(mut col) => {
189 if col.table.is_none() {
190 col.table = Some(Identifier::new(&table));
191 }
192 Expression::Column(col)
193 }
194 other => other,
195 })
196}
197
198pub fn replace_nodes<F: Fn(&Expression) -> bool>(
204 expr: Expression,
205 predicate: F,
206 replacement: Expression,
207) -> Expression {
208 xform(expr, |node| {
209 if predicate(&node) {
210 replacement.clone()
211 } else {
212 node
213 }
214 })
215}
216
217pub fn replace_by_type<F, R>(expr: Expression, predicate: F, replacer: R) -> Expression
219where
220 F: Fn(&Expression) -> bool,
221 R: Fn(Expression) -> Expression,
222{
223 xform(expr, |node| {
224 if predicate(&node) {
225 replacer(node)
226 } else {
227 node
228 }
229 })
230}
231
232pub fn remove_nodes<F: Fn(&Expression) -> bool>(expr: Expression, predicate: F) -> Expression {
238 xform(expr, |node| {
239 if predicate(&node) {
240 Expression::Null(Null)
241 } else {
242 node
243 }
244 })
245}
246
247pub fn get_column_names(expr: &Expression) -> Vec<String> {
253 expr.find_all(|e| matches!(e, Expression::Column(_)))
254 .into_iter()
255 .filter_map(|e| {
256 if let Expression::Column(col) = e {
257 Some(col.name.name.clone())
258 } else {
259 None
260 }
261 })
262 .collect()
263}
264
265pub fn get_table_names(expr: &Expression) -> Vec<String> {
267 expr.find_all(|e| matches!(e, Expression::Table(_)))
268 .into_iter()
269 .filter_map(|e| {
270 if let Expression::Table(tbl) = e {
271 Some(tbl.name.name.clone())
272 } else {
273 None
274 }
275 })
276 .collect()
277}
278
279pub fn get_identifiers(expr: &Expression) -> Vec<&Expression> {
281 expr.find_all(|e| matches!(e, Expression::Identifier(_)))
282}
283
284pub fn get_functions(expr: &Expression) -> Vec<&Expression> {
286 expr.find_all(|e| {
287 matches!(
288 e,
289 Expression::Function(_) | Expression::AggregateFunction(_)
290 )
291 })
292}
293
294pub fn get_literals(expr: &Expression) -> Vec<&Expression> {
296 expr.find_all(|e| {
297 matches!(
298 e,
299 Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
300 )
301 })
302}
303
304pub fn get_subqueries(expr: &Expression) -> Vec<&Expression> {
306 expr.find_all(|e| matches!(e, Expression::Subquery(_)))
307}
308
309pub fn get_aggregate_functions(expr: &Expression) -> Vec<&Expression> {
314 expr.find_all(|e| {
315 matches!(
316 e,
317 Expression::AggregateFunction(_)
318 | Expression::Count(_)
319 | Expression::Sum(_)
320 | Expression::Avg(_)
321 | Expression::Min(_)
322 | Expression::Max(_)
323 | Expression::ApproxDistinct(_)
324 | Expression::ArrayAgg(_)
325 | Expression::GroupConcat(_)
326 | Expression::StringAgg(_)
327 | Expression::ListAgg(_)
328 )
329 })
330}
331
332pub fn get_window_functions(expr: &Expression) -> Vec<&Expression> {
334 expr.find_all(|e| matches!(e, Expression::WindowFunction(_)))
335}
336
337pub fn node_count(expr: &Expression) -> usize {
339 expr.dfs().count()
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and returns its first statement.
    ///
    /// Panics on a parse error, which in these tests means the fixture is wrong.
    fn parse_one(sql: &str) -> Expression {
        let mut exprs = Parser::parse_sql(sql).unwrap();
        exprs.remove(0)
    }

    #[test]
    fn test_add_where() {
        let expr = parse_one("SELECT a FROM t");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("b"),
            Expression::number(1),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("WHERE"), "Expected WHERE in: {}", sql);
        assert!(sql.contains("b = 1"), "Expected condition in: {}", sql);
    }

    #[test]
    fn test_add_where_combines_with_and() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let cond = Expression::Eq(Box::new(BinaryOp::new(
            Expression::column("y"),
            Expression::number(2),
        )));
        let result = add_where(expr, cond, false);
        let sql = result.sql();
        assert!(sql.contains("AND"), "Expected AND in: {}", sql);
    }

    #[test]
    fn test_remove_where() {
        let expr = parse_one("SELECT a FROM t WHERE x = 1");
        let result = remove_where(expr);
        let sql = result.sql();
        assert!(!sql.contains("WHERE"), "Should not contain WHERE: {}", sql);
    }

    #[test]
    fn test_set_limit() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_limit(expr, 10);
        let sql = result.sql();
        assert!(sql.contains("LIMIT 10"), "Expected LIMIT in: {}", sql);
    }

    #[test]
    fn test_set_offset() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_offset(expr, 5);
        let sql = result.sql();
        assert!(sql.contains("OFFSET 5"), "Expected OFFSET in: {}", sql);
    }

    #[test]
    fn test_remove_limit_offset() {
        let expr = parse_one("SELECT a FROM t LIMIT 10 OFFSET 5");
        let result = remove_limit_offset(expr);
        let sql = result.sql();
        assert!(!sql.contains("LIMIT"), "Should not contain LIMIT: {}", sql);
        assert!(!sql.contains("OFFSET"), "Should not contain OFFSET: {}", sql);
    }

    #[test]
    fn test_get_column_names() {
        let expr = parse_one("SELECT a, b, c FROM t");
        let names = get_column_names(&expr);
        assert!(names.contains(&"a".to_string()));
        assert!(names.contains(&"b".to_string()));
        assert!(names.contains(&"c".to_string()));
    }

    #[test]
    fn test_get_table_names() {
        let expr = parse_one("SELECT a FROM users");
        let tables = crate::traversal::get_tables(&expr);
        let names = get_table_names(&expr);
        assert_eq!(
            names.len(),
            tables.len(),
            "get_table_names and get_tables should find same count"
        );
    }

    #[test]
    fn test_node_count() {
        let expr = parse_one("SELECT a FROM t");
        let count = node_count(&expr);
        assert!(count > 0, "Expected non-zero node count");
    }

    #[test]
    fn test_rename_columns() {
        let expr = parse_one("SELECT old_name FROM t");
        let mut mapping = HashMap::new();
        mapping.insert("old_name".to_string(), "new_name".to_string());
        let result = rename_columns(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_name"), "Expected new_name in: {}", sql);
        assert!(!sql.contains("old_name"), "Should not contain old_name: {}", sql);
    }

    #[test]
    fn test_rename_tables() {
        let expr = parse_one("SELECT a FROM old_table");
        let mut mapping = HashMap::new();
        mapping.insert("old_table".to_string(), "new_table".to_string());
        let result = rename_tables(expr, &mapping);
        let sql = result.sql();
        assert!(sql.contains("new_table"), "Expected new_table in: {}", sql);
    }

    #[test]
    fn test_set_distinct() {
        let expr = parse_one("SELECT a FROM t");
        let result = set_distinct(expr, true);
        let sql = result.sql();
        assert!(sql.contains("DISTINCT"), "Expected DISTINCT in: {}", sql);
    }

    #[test]
    fn test_add_select_columns() {
        let expr = parse_one("SELECT a FROM t");
        let result = add_select_columns(expr, vec![Expression::column("b")]);
        let sql = result.sql();
        assert!(sql.contains("a, b") || sql.contains("a,b"), "Expected a, b in: {}", sql);
    }

    #[test]
    fn test_qualify_columns() {
        let expr = parse_one("SELECT a, b FROM t");
        let result = qualify_columns(expr, "t");
        let sql = result.sql();
        assert!(sql.contains("t.a"), "Expected t.a in: {}", sql);
        assert!(sql.contains("t.b"), "Expected t.b in: {}", sql);
    }

    #[test]
    fn test_get_functions() {
        let expr = parse_one("SELECT COUNT(*), UPPER(name) FROM t");
        let funcs = get_functions(&expr);
        // Well-known calls may be parsed into dedicated variants (the module
        // has e.g. `Expression::Count`), so how many *generic* function nodes
        // appear is parser-dependent. The invariant that must always hold is
        // that `get_functions` returns only the generic function variants.
        // This replaces `funcs.len() >= 0`, which is always true for `usize`
        // (and is rejected by clippy's `absurd_extreme_comparisons`).
        assert!(
            funcs.iter().all(|f| matches!(
                f,
                Expression::Function(_) | Expression::AggregateFunction(_)
            )),
            "get_functions returned a non-function node"
        );
    }

    #[test]
    fn test_get_aggregate_functions() {
        let expr = parse_one("SELECT COUNT(*), SUM(x) FROM t");
        let aggs = get_aggregate_functions(&expr);
        assert!(aggs.len() >= 2, "Expected at least 2 aggregates, got {}", aggs.len());
    }
}