1use crate::expressions::Expression;
38use std::collections::{HashMap, VecDeque};
39
/// Identifier that [`TreeContext`] assigns to a node; ids are dense and start at 0.
pub type NodeId = usize;
42
/// Where a node sits relative to its parent, as recorded by `TreeContext`.
#[derive(Debug, Clone)]
pub struct ParentInfo {
    /// Id of the parent node; `None` for the root.
    pub parent_id: Option<NodeId>,
    /// Name of the parent's slot that holds this node (e.g. `"this"`, `"left"`,
    /// `"expressions"`); empty for the root.
    pub arg_key: String,
    /// Position inside the parent's list slot, or `None` when the node sits in a
    /// single-expression slot (or is the root).
    pub index: Option<usize>,
}
53
/// Parent/slot bookkeeping for every node of an expression tree.
///
/// Filled in by `TreeContext::build`, which numbers nodes with [`NodeId`]s in
/// depth-first pre-order (the root gets `0`) and stores a [`ParentInfo`] per node.
/// Only ids and slot names are kept; no references into the tree are stored.
#[derive(Debug, Default)]
pub struct TreeContext {
    /// Parent information keyed by node id.
    nodes: HashMap<NodeId, ParentInfo>,
    /// Id to hand out to the next visited node.
    next_id: NodeId,
    /// Traversal stack used while building: one `(parent id, slot name, list index)`
    /// frame per ancestor of the node being visited. Push/pop are balanced, so it is
    /// empty again once building finishes.
    path: Vec<(NodeId, String, Option<usize>)>,
}
76
77impl TreeContext {
78 pub fn new() -> Self {
80 Self::default()
81 }
82
83 pub fn build(root: &Expression) -> Self {
85 let mut ctx = Self::new();
86 ctx.visit_expr(root);
87 ctx
88 }
89
90 fn visit_expr(&mut self, expr: &Expression) -> NodeId {
92 let id = self.next_id;
93 self.next_id += 1;
94
95 let parent_info = if let Some((parent_id, arg_key, index)) = self.path.last() {
97 ParentInfo {
98 parent_id: Some(*parent_id),
99 arg_key: arg_key.clone(),
100 index: *index,
101 }
102 } else {
103 ParentInfo {
104 parent_id: None,
105 arg_key: String::new(),
106 index: None,
107 }
108 };
109 self.nodes.insert(id, parent_info);
110
111 for (key, child) in iter_children(expr) {
113 self.path.push((id, key.to_string(), None));
114 self.visit_expr(child);
115 self.path.pop();
116 }
117
118 for (key, children) in iter_children_lists(expr) {
120 for (idx, child) in children.iter().enumerate() {
121 self.path.push((id, key.to_string(), Some(idx)));
122 self.visit_expr(child);
123 self.path.pop();
124 }
125 }
126
127 id
128 }
129
130 pub fn get(&self, id: NodeId) -> Option<&ParentInfo> {
132 self.nodes.get(&id)
133 }
134
135 pub fn depth_of(&self, id: NodeId) -> usize {
137 let mut depth = 0;
138 let mut current = id;
139 while let Some(info) = self.nodes.get(¤t) {
140 if let Some(parent_id) = info.parent_id {
141 depth += 1;
142 current = parent_id;
143 } else {
144 break;
145 }
146 }
147 depth
148 }
149
150 pub fn ancestors_of(&self, id: NodeId) -> Vec<NodeId> {
152 let mut ancestors = Vec::new();
153 let mut current = id;
154 while let Some(info) = self.nodes.get(¤t) {
155 if let Some(parent_id) = info.parent_id {
156 ancestors.push(parent_id);
157 current = parent_id;
158 } else {
159 break;
160 }
161 }
162 ancestors
163 }
164}
165
166fn iter_children(expr: &Expression) -> Vec<(&'static str, &Expression)> {
170 let mut children = Vec::new();
171
172 match expr {
173 Expression::Alias(a) => {
174 children.push(("this", &a.this));
175 }
176 Expression::Cast(c) => {
177 children.push(("this", &c.this));
178 }
179 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
180 children.push(("this", &u.this));
181 }
182 Expression::Paren(p) => {
183 children.push(("this", &p.this));
184 }
185 Expression::IsNull(i) => {
186 children.push(("this", &i.this));
187 }
188 Expression::Exists(e) => {
189 children.push(("this", &e.this));
190 }
191 Expression::Subquery(s) => {
192 children.push(("this", &s.this));
193 }
194 Expression::Where(w) => {
195 children.push(("this", &w.this));
196 }
197 Expression::Having(h) => {
198 children.push(("this", &h.this));
199 }
200 Expression::Qualify(q) => {
201 children.push(("this", &q.this));
202 }
203 Expression::And(op)
204 | Expression::Or(op)
205 | Expression::Add(op)
206 | Expression::Sub(op)
207 | Expression::Mul(op)
208 | Expression::Div(op)
209 | Expression::Mod(op)
210 | Expression::Eq(op)
211 | Expression::Neq(op)
212 | Expression::Lt(op)
213 | Expression::Lte(op)
214 | Expression::Gt(op)
215 | Expression::Gte(op)
216 | Expression::BitwiseAnd(op)
217 | Expression::BitwiseOr(op)
218 | Expression::BitwiseXor(op)
219 | Expression::Concat(op) => {
220 children.push(("left", &op.left));
221 children.push(("right", &op.right));
222 }
223 Expression::Like(op) | Expression::ILike(op) => {
224 children.push(("left", &op.left));
225 children.push(("right", &op.right));
226 }
227 Expression::Between(b) => {
228 children.push(("this", &b.this));
229 children.push(("low", &b.low));
230 children.push(("high", &b.high));
231 }
232 Expression::In(i) => {
233 children.push(("this", &i.this));
234 }
235 Expression::Case(c) => {
236 if let Some(ref operand) = &c.operand {
237 children.push(("operand", operand));
238 }
239 }
240 Expression::WindowFunction(wf) => {
241 children.push(("this", &wf.this));
242 }
243 Expression::Union(u) => {
244 children.push(("left", &u.left));
245 children.push(("right", &u.right));
246 }
247 Expression::Intersect(i) => {
248 children.push(("left", &i.left));
249 children.push(("right", &i.right));
250 }
251 Expression::Except(e) => {
252 children.push(("left", &e.left));
253 children.push(("right", &e.right));
254 }
255 Expression::Ordered(o) => {
256 children.push(("this", &o.this));
257 }
258 Expression::Interval(i) => {
259 if let Some(ref this) = i.this {
260 children.push(("this", this));
261 }
262 }
263 _ => {}
264 }
265
266 children
267}
268
269fn iter_children_lists(expr: &Expression) -> Vec<(&'static str, &[Expression])> {
273 let mut lists = Vec::new();
274
275 match expr {
276 Expression::Select(s) => {
277 lists.push(("expressions", s.expressions.as_slice()));
278 }
280 Expression::Function(f) => {
281 lists.push(("args", f.args.as_slice()));
282 }
283 Expression::AggregateFunction(f) => {
284 lists.push(("args", f.args.as_slice()));
285 }
286 Expression::From(f) => {
287 lists.push(("expressions", f.expressions.as_slice()));
288 }
289 Expression::GroupBy(g) => {
290 lists.push(("expressions", g.expressions.as_slice()));
291 }
292 Expression::In(i) => {
295 lists.push(("expressions", i.expressions.as_slice()));
296 }
297 Expression::Array(a) => {
298 lists.push(("expressions", a.expressions.as_slice()));
299 }
300 Expression::Tuple(t) => {
301 lists.push(("expressions", t.expressions.as_slice()));
302 }
303 Expression::Coalesce(c) => {
305 lists.push(("expressions", c.expressions.as_slice()));
306 }
307 Expression::Greatest(g) | Expression::Least(g) => {
308 lists.push(("expressions", g.expressions.as_slice()));
309 }
310 _ => {}
311 }
312
313 lists
314}
315
316pub struct DfsIter<'a> {
325 stack: Vec<&'a Expression>,
326}
327
328impl<'a> DfsIter<'a> {
329 pub fn new(root: &'a Expression) -> Self {
331 Self { stack: vec![root] }
332 }
333}
334
335impl<'a> Iterator for DfsIter<'a> {
336 type Item = &'a Expression;
337
338 fn next(&mut self) -> Option<Self::Item> {
339 let expr = self.stack.pop()?;
340
341 let children: Vec<_> = iter_children(expr).into_iter().map(|(_, e)| e).collect();
343 for child in children.into_iter().rev() {
344 self.stack.push(child);
345 }
346
347 let lists: Vec<_> = iter_children_lists(expr)
348 .into_iter()
349 .flat_map(|(_, es)| es.iter())
350 .collect();
351 for child in lists.into_iter().rev() {
352 self.stack.push(child);
353 }
354
355 Some(expr)
356 }
357}
358
359pub struct BfsIter<'a> {
367 queue: VecDeque<&'a Expression>,
368}
369
370impl<'a> BfsIter<'a> {
371 pub fn new(root: &'a Expression) -> Self {
373 let mut queue = VecDeque::new();
374 queue.push_back(root);
375 Self { queue }
376 }
377}
378
379impl<'a> Iterator for BfsIter<'a> {
380 type Item = &'a Expression;
381
382 fn next(&mut self) -> Option<Self::Item> {
383 let expr = self.queue.pop_front()?;
384
385 for (_, child) in iter_children(expr) {
387 self.queue.push_back(child);
388 }
389
390 for (_, children) in iter_children_lists(expr) {
391 for child in children {
392 self.queue.push_back(child);
393 }
394 }
395
396 Some(expr)
397 }
398}
399
/// Read-only traversal and search helpers for expression trees, plus an owning
/// rewrite entry point (`transform_owned`).
///
/// Traversals only descend into the child slots this module knows about (see
/// `iter_children` / `iter_children_lists`); variants without listed slots are leaves.
pub trait ExpressionWalk {
    /// Returns a depth-first, pre-order iterator that starts with `self`.
    fn dfs(&self) -> DfsIter<'_>;

    /// Returns a breadth-first (level-order) iterator that starts with `self`.
    fn bfs(&self) -> BfsIter<'_>;

    /// Returns the first node, in `dfs` order with `self` included, for which
    /// `predicate` returns `true`.
    fn find<F>(&self, predicate: F) -> Option<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Returns every node, in `dfs` order with `self` included, for which
    /// `predicate` returns `true`.
    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Returns `true` if any node (including `self`) satisfies `predicate`.
    fn contains<F>(&self, predicate: F) -> bool
    where
        F: Fn(&Expression) -> bool;

    /// Counts the nodes (including `self`) that satisfy `predicate`.
    fn count<F>(&self, predicate: F) -> usize
    where
        F: Fn(&Expression) -> bool;

    /// Returns the direct children of `self`: single-expression slots first, then
    /// the members of list slots.
    fn children(&self) -> Vec<&Expression>;

    /// Number of edges on the longest path from `self` down to a leaf (a leaf is 0).
    fn tree_depth(&self) -> usize;

    /// Consumes `self` and rewrites it with `fun` via the free function `transform`.
    ///
    /// When `fun` returns `Ok(None)`, the node is replaced by a `NULL` literal.
    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
    where
        F: Fn(Expression) -> crate::Result<Option<Expression>>,
        Self: Sized;
}
462
463impl ExpressionWalk for Expression {
464 fn dfs(&self) -> DfsIter<'_> {
465 DfsIter::new(self)
466 }
467
468 fn bfs(&self) -> BfsIter<'_> {
469 BfsIter::new(self)
470 }
471
472 fn find<F>(&self, predicate: F) -> Option<&Expression>
473 where
474 F: Fn(&Expression) -> bool,
475 {
476 self.dfs().find(|e| predicate(e))
477 }
478
479 fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
480 where
481 F: Fn(&Expression) -> bool,
482 {
483 self.dfs().filter(|e| predicate(e)).collect()
484 }
485
486 fn contains<F>(&self, predicate: F) -> bool
487 where
488 F: Fn(&Expression) -> bool,
489 {
490 self.dfs().any(|e| predicate(e))
491 }
492
493 fn count<F>(&self, predicate: F) -> usize
494 where
495 F: Fn(&Expression) -> bool,
496 {
497 self.dfs().filter(|e| predicate(e)).count()
498 }
499
500 fn children(&self) -> Vec<&Expression> {
501 let mut result: Vec<&Expression> = Vec::new();
502 for (_, child) in iter_children(self) {
503 result.push(child);
504 }
505 for (_, children_list) in iter_children_lists(self) {
506 for child in children_list {
507 result.push(child);
508 }
509 }
510 result
511 }
512
513 fn tree_depth(&self) -> usize {
514 let mut max_depth = 0;
515
516 for (_, child) in iter_children(self) {
517 let child_depth = child.tree_depth();
518 if child_depth + 1 > max_depth {
519 max_depth = child_depth + 1;
520 }
521 }
522
523 for (_, children) in iter_children_lists(self) {
524 for child in children {
525 let child_depth = child.tree_depth();
526 if child_depth + 1 > max_depth {
527 max_depth = child_depth + 1;
528 }
529 }
530 }
531
532 max_depth
533 }
534
535 fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
536 where
537 F: Fn(Expression) -> crate::Result<Option<Expression>>,
538 {
539 transform(self, &fun)
540 }
541}
542
543pub fn transform<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
564where
565 F: Fn(Expression) -> crate::Result<Option<Expression>>,
566{
567 crate::dialects::transform_recursive(expr, &|e| match fun(e)? {
568 Some(transformed) => Ok(transformed),
569 None => Ok(Expression::Null(crate::expressions::Null)),
570 })
571}
572
/// Rewrites `expr` by passing its nodes through `fun`, which always returns a node.
///
/// This is a thin wrapper over `crate::dialects::transform_recursive`. Return the
/// input unchanged to keep a node; an `Err` from `fun` is returned to the caller.
pub fn transform_map<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
where
    F: Fn(Expression) -> crate::Result<Expression>,
{
    crate::dialects::transform_recursive(expr, fun)
}
599
/// Returns `true` if `expr` is a column reference (`Expression::Column`).
pub fn is_column(expr: &Expression) -> bool {
    matches!(expr, Expression::Column(_))
}

/// Returns `true` for literal values: `Literal`, `Boolean` and `Null` nodes.
pub fn is_literal(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
    )
}

/// Returns `true` for generic function-call nodes (`Function` and `AggregateFunction`).
///
/// Dedicated function variants such as `Count`, `Cast` or `Coalesce` are not matched.
pub fn is_function(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Function(_) | Expression::AggregateFunction(_)
    )
}

/// Returns `true` if `expr` is a subquery node (`Expression::Subquery`).
pub fn is_subquery(expr: &Expression) -> bool {
    matches!(expr, Expression::Subquery(_))
}

/// Returns `true` if `expr` is a `SELECT` node (`Expression::Select`).
pub fn is_select(expr: &Expression) -> bool {
    matches!(expr, Expression::Select(_))
}
636
637pub fn is_aggregate(expr: &Expression) -> bool {
639 matches!(expr, Expression::AggregateFunction(_))
640}
641
/// Returns `true` if `expr` is a window-function node (`Expression::WindowFunction`).
pub fn is_window_function(expr: &Expression) -> bool {
    matches!(expr, Expression::WindowFunction(_))
}

/// Collects every column node reachable from `expr` (including `expr`), in DFS order.
// NOTE(review): traversal only follows the slots exposed by `iter_children` /
// `iter_children_lists`; for a `Select` that appears to be the projection list only,
// so columns in other clauses may be missed — confirm.
pub fn get_columns(expr: &Expression) -> Vec<&Expression> {
    expr.find_all(is_column)
}

/// Collects every table node reachable from `expr` (including `expr`), in DFS order.
// NOTE(review): `Select` does not seem to expose its FROM clause to the traversal,
// so tables of a parsed SELECT may not be found — confirm.
pub fn get_tables(expr: &Expression) -> Vec<&Expression> {
    expr.find_all(|e| matches!(e, Expression::Table(_)))
}

/// Returns `true` if any reachable node satisfies `is_aggregate`.
pub fn contains_aggregate(expr: &Expression) -> bool {
    expr.contains(is_aggregate)
}

/// Returns `true` if any reachable node is a window function.
pub fn contains_window_function(expr: &Expression) -> bool {
    expr.contains(is_window_function)
}

/// Returns `true` if any reachable node (including `expr`) is a subquery.
pub fn contains_subquery(expr: &Expression) -> bool {
    expr.contains(is_subquery)
}
675
/// Defines `pub fn $fn_name(expr: &Expression) -> bool`, returning `true` when `expr`
/// matches any of the given patterns.
///
/// Usage: `is_type!(is_insert, Expression::Insert(_));`. Several patterns may be
/// listed (comma-separated, with an optional trailing comma); they are OR-ed together.
macro_rules! is_type {
    ($fn_name:ident, $($pattern:pat),+ $(,)?) => {
        #[doc = concat!("Returns `true` if the expression matches `", stringify!($($pattern)|+), "`.")]
        pub fn $fn_name(expr: &Expression) -> bool {
            matches!(expr, $($pattern)|+)
        }
    };
}
689
// Statement kinds and set operations.
is_type!(is_insert, Expression::Insert(_));
is_type!(is_update, Expression::Update(_));
is_type!(is_delete, Expression::Delete(_));
is_type!(is_union, Expression::Union(_));
is_type!(is_intersect, Expression::Intersect(_));
is_type!(is_except, Expression::Except(_));

// Simple leaf-like values and references.
is_type!(is_boolean, Expression::Boolean(_));
is_type!(is_null_literal, Expression::Null(_));
is_type!(is_star, Expression::Star(_));
is_type!(is_identifier, Expression::Identifier(_));
is_type!(is_table, Expression::Table(_));

// Comparison operators (including LIKE / ILIKE).
is_type!(is_eq, Expression::Eq(_));
is_type!(is_neq, Expression::Neq(_));
is_type!(is_lt, Expression::Lt(_));
is_type!(is_lte, Expression::Lte(_));
is_type!(is_gt, Expression::Gt(_));
is_type!(is_gte, Expression::Gte(_));
is_type!(is_like, Expression::Like(_));
is_type!(is_ilike, Expression::ILike(_));

// Arithmetic and string concatenation.
is_type!(is_add, Expression::Add(_));
is_type!(is_sub, Expression::Sub(_));
is_type!(is_mul, Expression::Mul(_));
is_type!(is_div, Expression::Div(_));
is_type!(is_mod, Expression::Mod(_));
is_type!(is_concat, Expression::Concat(_));

// Boolean connectives.
is_type!(is_and, Expression::And(_));
is_type!(is_or, Expression::Or(_));
is_type!(is_not, Expression::Not(_));

// Predicates.
is_type!(is_in, Expression::In(_));
is_type!(is_between, Expression::Between(_));
is_type!(is_is_null, Expression::IsNull(_));
is_type!(is_exists, Expression::Exists(_));

// Dedicated function nodes, casts and CASE.
is_type!(is_count, Expression::Count(_));
is_type!(is_sum, Expression::Sum(_));
is_type!(is_avg, Expression::Avg(_));
is_type!(is_min_func, Expression::Min(_));
is_type!(is_max_func, Expression::Max(_));
is_type!(is_coalesce, Expression::Coalesce(_));
is_type!(is_null_if, Expression::NullIf(_));
is_type!(is_cast, Expression::Cast(_));
is_type!(is_try_cast, Expression::TryCast(_));
is_type!(is_safe_cast, Expression::SafeCast(_));
is_type!(is_case, Expression::Case(_));

// Query clauses and structural wrappers.
is_type!(is_from, Expression::From(_));
is_type!(is_join, Expression::Join(_));
is_type!(is_where, Expression::Where(_));
is_type!(is_group_by, Expression::GroupBy(_));
is_type!(is_having, Expression::Having(_));
is_type!(is_order_by, Expression::OrderBy(_));
is_type!(is_limit, Expression::Limit(_));
is_type!(is_offset, Expression::Offset(_));
is_type!(is_with, Expression::With(_));
is_type!(is_cte, Expression::Cte(_));
is_type!(is_alias, Expression::Alias(_));
is_type!(is_paren, Expression::Paren(_));
is_type!(is_ordered, Expression::Ordered(_));

// DDL statements on tables, indexes and views (see `is_ddl` for the full DDL set).
is_type!(is_create_table, Expression::CreateTable(_));
is_type!(is_drop_table, Expression::DropTable(_));
is_type!(is_alter_table, Expression::AlterTable(_));
is_type!(is_create_index, Expression::CreateIndex(_));
is_type!(is_drop_index, Expression::DropIndex(_));
is_type!(is_create_view, Expression::CreateView(_));
is_type!(is_drop_view, Expression::DropView(_));
770
771pub fn is_query(expr: &Expression) -> bool {
777 matches!(
778 expr,
779 Expression::Select(_)
780 | Expression::Insert(_)
781 | Expression::Update(_)
782 | Expression::Delete(_)
783 )
784}
785
786pub fn is_set_operation(expr: &Expression) -> bool {
788 matches!(
789 expr,
790 Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
791 )
792}
793
794pub fn is_comparison(expr: &Expression) -> bool {
796 matches!(
797 expr,
798 Expression::Eq(_)
799 | Expression::Neq(_)
800 | Expression::Lt(_)
801 | Expression::Lte(_)
802 | Expression::Gt(_)
803 | Expression::Gte(_)
804 | Expression::Like(_)
805 | Expression::ILike(_)
806 )
807}
808
809pub fn is_arithmetic(expr: &Expression) -> bool {
811 matches!(
812 expr,
813 Expression::Add(_)
814 | Expression::Sub(_)
815 | Expression::Mul(_)
816 | Expression::Div(_)
817 | Expression::Mod(_)
818 )
819}
820
821pub fn is_logical(expr: &Expression) -> bool {
823 matches!(
824 expr,
825 Expression::And(_) | Expression::Or(_) | Expression::Not(_)
826 )
827}
828
829pub fn is_ddl(expr: &Expression) -> bool {
831 matches!(
832 expr,
833 Expression::CreateTable(_)
834 | Expression::DropTable(_)
835 | Expression::AlterTable(_)
836 | Expression::CreateIndex(_)
837 | Expression::DropIndex(_)
838 | Expression::CreateView(_)
839 | Expression::DropView(_)
840 | Expression::AlterView(_)
841 | Expression::CreateSchema(_)
842 | Expression::DropSchema(_)
843 | Expression::CreateDatabase(_)
844 | Expression::DropDatabase(_)
845 | Expression::CreateFunction(_)
846 | Expression::DropFunction(_)
847 | Expression::CreateProcedure(_)
848 | Expression::DropProcedure(_)
849 | Expression::CreateSequence(_)
850 | Expression::DropSequence(_)
851 | Expression::AlterSequence(_)
852 | Expression::CreateTrigger(_)
853 | Expression::DropTrigger(_)
854 | Expression::CreateType(_)
855 | Expression::DropType(_)
856 )
857}
858
859pub fn find_parent<'a>(root: &'a Expression, target: &Expression) -> Option<&'a Expression> {
866 fn search<'a>(node: &'a Expression, target: *const Expression) -> Option<&'a Expression> {
867 for (_, child) in iter_children(node) {
868 if std::ptr::eq(child, target) {
869 return Some(node);
870 }
871 if let Some(found) = search(child, target) {
872 return Some(found);
873 }
874 }
875 for (_, children_list) in iter_children_lists(node) {
876 for child in children_list {
877 if std::ptr::eq(child, target) {
878 return Some(node);
879 }
880 if let Some(found) = search(child, target) {
881 return Some(found);
882 }
883 }
884 }
885 None
886 }
887
888 search(root, target as *const Expression)
889}
890
891pub fn find_ancestor<'a, F>(
897 root: &'a Expression,
898 target: &Expression,
899 predicate: F,
900) -> Option<&'a Expression>
901where
902 F: Fn(&Expression) -> bool,
903{
904 fn build_path<'a>(
906 node: &'a Expression,
907 target: *const Expression,
908 path: &mut Vec<&'a Expression>,
909 ) -> bool {
910 if std::ptr::eq(node, target) {
911 return true;
912 }
913 path.push(node);
914 for (_, child) in iter_children(node) {
915 if build_path(child, target, path) {
916 return true;
917 }
918 }
919 for (_, children_list) in iter_children_lists(node) {
920 for child in children_list {
921 if build_path(child, target, path) {
922 return true;
923 }
924 }
925 }
926 path.pop();
927 false
928 }
929
930 let mut path = Vec::new();
931 if !build_path(root, target as *const Expression, &mut path) {
932 return None;
933 }
934
935 for ancestor in path.iter().rev() {
937 if predicate(ancestor) {
938 return Some(ancestor);
939 }
940 }
941 None
942}
943
#[cfg(test)]
mod tests {
    //! Unit tests for the traversal, search, tree-context and transform helpers.
    //!
    //! Small trees are assembled by hand with `make_column` / `make_literal`. The
    //! transform and `Select` tests instead round-trip SQL through the crate's parser
    //! and generator.
    use super::*;
    use crate::expressions::{BinaryOp, Column, Identifier, Literal};

    /// Builds an unqualified, unquoted column reference called `name`.
    fn make_column(name: &str) -> Expression {
        Expression::Column(Column {
            name: Identifier {
                name: name.to_string(),
                quoted: false,
                trailing_comments: vec![],
            },
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        })
    }

    /// Builds a numeric literal whose text is `value` in decimal.
    fn make_literal(value: i64) -> Expression {
        Expression::Literal(Literal::Number(value.to_string()))
    }

    // `a = 1`: DFS yields the operator first, then `left`, then `right`.
    #[test]
    fn test_dfs_simple() {
        let left = make_column("a");
        let right = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let nodes: Vec<_> = expr.dfs().collect();
        // The root plus its two operands.
        assert_eq!(nodes.len(), 3);
        assert!(matches!(nodes[0], Expression::Eq(_)));
        assert!(matches!(nodes[1], Expression::Column(_)));
        assert!(matches!(nodes[2], Expression::Literal(_)));
    }

    // `find` locates both the column and the literal operand of `a = 1`.
    #[test]
    fn test_find() {
        let left = make_column("a");
        let right = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let column = expr.find(is_column);
        assert!(column.is_some());
        assert!(matches!(column.unwrap(), Expression::Column(_)));

        let literal = expr.find(is_literal);
        assert!(literal.is_some());
        assert!(matches!(literal.unwrap(), Expression::Literal(_)));
    }

    // `a AND b` contains exactly two columns.
    #[test]
    fn test_find_all() {
        let col1 = make_column("a");
        let col2 = make_column("b");
        let expr = Expression::And(Box::new(BinaryOp {
            left: col1,
            right: col2,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let columns = expr.find_all(is_column);
        assert_eq!(columns.len(), 2);
    }

    // `contains` is true for present node kinds and false for absent ones.
    #[test]
    fn test_contains() {
        let col = make_column("a");
        let lit = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left: col,
            right: lit,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        assert!(expr.contains(is_column));
        assert!(expr.contains(is_literal));
        assert!(!expr.contains(is_subquery));
    }

    // `a = (b + 1)`: counting reaches into the nested `Add`.
    #[test]
    fn test_count() {
        let col1 = make_column("a");
        let col2 = make_column("b");
        let lit = make_literal(1);

        let inner = Expression::Add(Box::new(BinaryOp {
            left: col2,
            right: lit,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let expr = Expression::Eq(Box::new(BinaryOp {
            left: col1,
            right: inner,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        assert_eq!(expr.count(is_column), 2);
        assert_eq!(expr.count(is_literal), 1);
    }

    #[test]
    fn test_tree_depth() {
        let lit = make_literal(1);
        // A bare literal is a leaf: depth 0.
        assert_eq!(lit.tree_depth(), 0);

        let col = make_column("a");
        // `a = 1`: one level of children below the root.
        let expr = Expression::Eq(Box::new(BinaryOp {
            left: col,
            right: lit.clone(),
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));
        assert_eq!(expr.tree_depth(), 1);

        let inner = Expression::Add(Box::new(BinaryOp {
            // `a = (b + 1)`: the nested `Add` adds one more level.
            left: make_column("b"),
            right: lit,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));
        let outer = Expression::Eq(Box::new(BinaryOp {
            left: make_column("a"),
            right: inner,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));
        assert_eq!(outer.tree_depth(), 2);
    }

    // Ids are assigned in pre-order: root 0, `left` 1, `right` 2.
    #[test]
    fn test_tree_context() {
        let col = make_column("a");
        let lit = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left: col,
            right: lit,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let ctx = TreeContext::build(&expr);

        let root_info = ctx.get(0).unwrap();
        // The root has no parent.
        assert!(root_info.parent_id.is_none());

        let left_info = ctx.get(1).unwrap();
        // The left operand hangs off the root's "left" slot.
        assert_eq!(left_info.parent_id, Some(0));
        assert_eq!(left_info.arg_key, "left");

        let right_info = ctx.get(2).unwrap();
        assert_eq!(right_info.parent_id, Some(0));
        assert_eq!(right_info.arg_key, "right");
    }

    // Renaming column `a` to `alpha` leaves column `b` intact.
    #[test]
    fn test_transform_rename_columns() {
        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                if c.name.name == "a" {
                    return Ok(Expression::Column(Column {
                        name: Identifier::new("alpha"),
                        table: c.table.clone(),
                        join_mark: false,
                        trailing_comments: vec![],
                    }));
                }
            }
            Ok(e)
        })
        .unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert!(sql.contains("alpha"), "Expected 'alpha' in: {}", sql);
        assert!(sql.contains("b"), "Expected 'b' in: {}", sql);
    }

    // An identity transform must regenerate the same SQL.
    #[test]
    fn test_transform_noop() {
        let ast = crate::parser::Parser::parse_sql("SELECT 1 + 2").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr.clone(), &|e| Ok(e)).unwrap();
        let sql1 = crate::generator::Generator::sql(&expr).unwrap();
        let sql2 = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql1, sql2);
    }

    // Columns nested inside a binary operator are rewritten too.
    #[test]
    fn test_transform_nested() {
        let ast = crate::parser::Parser::parse_sql("SELECT a + b FROM t").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                return Ok(Expression::Literal(Literal::Number(
                    if c.name.name == "a" { "1" } else { "2" }.to_string(),
                )));
            }
            Ok(e)
        })
        .unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql, "SELECT 1 + 2 FROM t");
    }

    // An error returned by the callback surfaces from `transform_map`.
    #[test]
    fn test_transform_error() {
        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                if c.name.name == "a" {
                    return Err(crate::error::Error::Parse("test error".to_string()));
                }
            }
            Ok(e)
        });
        assert!(result.is_err());
    }

    // `transform_owned` with `Some(e)` everywhere is an identity rewrite.
    #[test]
    fn test_transform_owned_trait() {
        let ast = crate::parser::Parser::parse_sql("SELECT x FROM t").unwrap();
        let expr = ast[0].clone();
        let result = expr.transform_owned(|e| Ok(Some(e))).unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql, "SELECT x FROM t");
    }

    // A literal has no children.
    #[test]
    fn test_children_leaf() {
        let lit = make_literal(1);
        assert_eq!(lit.children().len(), 0);
    }

    // A binary operator's children are `left` then `right`.
    #[test]
    fn test_children_binary_op() {
        let left = make_column("a");
        let right = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));
        let children = expr.children();
        assert_eq!(children.len(), 2);
        assert!(matches!(children[0], Expression::Column(_)));
        assert!(matches!(children[1], Expression::Literal(_)));
    }

    // A parsed SELECT exposes at least its two projection expressions.
    #[test]
    fn test_children_select() {
        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
        let expr = &ast[0];
        let children = expr.children();
        assert!(children.len() >= 2);
    }

    // The parent of the column in `a = 1` is the `Eq` node.
    #[test]
    fn test_find_parent_binary() {
        let left = make_column("a");
        let right = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let col = expr.find(is_column).unwrap();
        // `col` points into `expr`, so pointer-identity lookup succeeds.
        let parent = super::find_parent(&expr, col);
        assert!(parent.is_some());
        assert!(matches!(parent.unwrap(), Expression::Eq(_)));
    }

    // The root has no parent.
    #[test]
    fn test_find_parent_root_has_none() {
        let lit = make_literal(1);
        let parent = super::find_parent(&lit, &lit);
        assert!(parent.is_none());
    }

    // A column inside a SELECT has the SELECT as an ancestor.
    #[test]
    fn test_find_ancestor_select() {
        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t WHERE a > 1").unwrap();
        let expr = &ast[0];

        let where_col = expr.dfs().find(|e| {
            // First column named `a` in DFS order (not necessarily the WHERE one).
            if let Expression::Column(c) = e {
                c.name.name == "a"
            } else {
                false
            }
        });
        assert!(where_col.is_some());

        let ancestor = super::find_ancestor(expr, where_col.unwrap(), is_select);
        // The nearest SELECT ancestor is the statement itself.
        assert!(ancestor.is_some());
        assert!(matches!(ancestor.unwrap(), Expression::Select(_)));
    }

    // No SELECT above the column in a bare `a = 1`.
    #[test]
    fn test_find_ancestor_no_match() {
        let left = make_column("a");
        let right = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let col = expr.find(is_column).unwrap();
        let ancestor = super::find_ancestor(&expr, col, is_select);
        assert!(ancestor.is_none());
    }

    // `b = (a + 1)` numbered in pre-order: outer=0, b=1, inner=2, a=3, 1=4.
    // The ancestors of `a` (id 3) are inner then outer.
    #[test]
    fn test_ancestors() {
        let col = make_column("a");
        let lit = make_literal(1);
        let inner = Expression::Add(Box::new(BinaryOp {
            left: col,
            right: lit,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));
        let outer = Expression::Eq(Box::new(BinaryOp {
            left: make_column("b"),
            right: inner,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        }));

        let ctx = TreeContext::build(&outer);

        let ancestors = ctx.ancestors_of(3);
        assert_eq!(ancestors, vec![2, 0]);
    }
}