1use crate::expressions::Expression;
38use std::collections::{HashMap, VecDeque};
39
40pub type NodeId = usize;
42
/// Describes how a node is attached to its parent in a [`TreeContext`].
#[derive(Debug, Clone)]
pub struct ParentInfo {
    /// Id of the parent node, or `None` for the root of the walked tree.
    pub parent_id: Option<NodeId>,
    /// Key naming the slot of the parent that holds this node
    /// (empty for the root).
    pub arg_key: String,
    /// Position inside the parent's list slot, or `None` when the node was
    /// reached through a single-valued slot (or is the root).
    pub index: Option<usize>,
}
53
/// Side table that records, for every node visited while walking an
/// expression tree, which node is its parent and through which slot it was
/// reached.
#[derive(Debug, Default)]
pub struct TreeContext {
    /// Parent information for every visited node, keyed by node id.
    nodes: HashMap<NodeId, ParentInfo>,
    /// The id that will be handed to the next visited node.
    next_id: NodeId,
    /// Stack of `(parent id, slot key, list index)` entries describing the
    /// slot currently being descended into while the tree is walked.
    path: Vec<(NodeId, String, Option<usize>)>,
}
76
77impl TreeContext {
78 pub fn new() -> Self {
80 Self::default()
81 }
82
83 pub fn build(root: &Expression) -> Self {
85 let mut ctx = Self::new();
86 ctx.visit_expr(root);
87 ctx
88 }
89
90 fn visit_expr(&mut self, expr: &Expression) -> NodeId {
92 let id = self.next_id;
93 self.next_id += 1;
94
95 let parent_info = if let Some((parent_id, arg_key, index)) = self.path.last() {
97 ParentInfo {
98 parent_id: Some(*parent_id),
99 arg_key: arg_key.clone(),
100 index: *index,
101 }
102 } else {
103 ParentInfo {
104 parent_id: None,
105 arg_key: String::new(),
106 index: None,
107 }
108 };
109 self.nodes.insert(id, parent_info);
110
111 for (key, child) in iter_children(expr) {
113 self.path.push((id, key.to_string(), None));
114 self.visit_expr(child);
115 self.path.pop();
116 }
117
118 for (key, children) in iter_children_lists(expr) {
120 for (idx, child) in children.iter().enumerate() {
121 self.path.push((id, key.to_string(), Some(idx)));
122 self.visit_expr(child);
123 self.path.pop();
124 }
125 }
126
127 id
128 }
129
130 pub fn get(&self, id: NodeId) -> Option<&ParentInfo> {
132 self.nodes.get(&id)
133 }
134
135 pub fn depth_of(&self, id: NodeId) -> usize {
137 let mut depth = 0;
138 let mut current = id;
139 while let Some(info) = self.nodes.get(¤t) {
140 if let Some(parent_id) = info.parent_id {
141 depth += 1;
142 current = parent_id;
143 } else {
144 break;
145 }
146 }
147 depth
148 }
149
150 pub fn ancestors_of(&self, id: NodeId) -> Vec<NodeId> {
152 let mut ancestors = Vec::new();
153 let mut current = id;
154 while let Some(info) = self.nodes.get(¤t) {
155 if let Some(parent_id) = info.parent_id {
156 ancestors.push(parent_id);
157 current = parent_id;
158 } else {
159 break;
160 }
161 }
162 ancestors
163 }
164}
165
166fn iter_children(expr: &Expression) -> Vec<(&'static str, &Expression)> {
170 let mut children = Vec::new();
171
172 match expr {
173 Expression::Select(s) => {
174 if let Some(from) = &s.from {
175 for source in &from.expressions {
176 children.push(("from", source));
177 }
178 }
179 for join in &s.joins {
180 children.push(("join_this", &join.this));
181 if let Some(on) = &join.on {
182 children.push(("join_on", on));
183 }
184 if let Some(match_condition) = &join.match_condition {
185 children.push(("join_match_condition", match_condition));
186 }
187 for pivot in &join.pivots {
188 children.push(("join_pivot", pivot));
189 }
190 }
191 for lateral_view in &s.lateral_views {
192 children.push(("lateral_view", &lateral_view.this));
193 }
194 if let Some(prewhere) = &s.prewhere {
195 children.push(("prewhere", prewhere));
196 }
197 if let Some(where_clause) = &s.where_clause {
198 children.push(("where", &where_clause.this));
199 }
200 if let Some(group_by) = &s.group_by {
201 for e in &group_by.expressions {
202 children.push(("group_by", e));
203 }
204 }
205 if let Some(having) = &s.having {
206 children.push(("having", &having.this));
207 }
208 if let Some(qualify) = &s.qualify {
209 children.push(("qualify", &qualify.this));
210 }
211 if let Some(order_by) = &s.order_by {
212 for ordered in &order_by.expressions {
213 children.push(("order_by", &ordered.this));
214 }
215 }
216 if let Some(distribute_by) = &s.distribute_by {
217 for e in &distribute_by.expressions {
218 children.push(("distribute_by", e));
219 }
220 }
221 if let Some(cluster_by) = &s.cluster_by {
222 for ordered in &cluster_by.expressions {
223 children.push(("cluster_by", &ordered.this));
224 }
225 }
226 if let Some(sort_by) = &s.sort_by {
227 for ordered in &sort_by.expressions {
228 children.push(("sort_by", &ordered.this));
229 }
230 }
231 if let Some(limit) = &s.limit {
232 children.push(("limit", &limit.this));
233 }
234 if let Some(offset) = &s.offset {
235 children.push(("offset", &offset.this));
236 }
237 if let Some(limit_by) = &s.limit_by {
238 for e in limit_by {
239 children.push(("limit_by", e));
240 }
241 }
242 if let Some(fetch) = &s.fetch {
243 if let Some(count) = &fetch.count {
244 children.push(("fetch", count));
245 }
246 }
247 if let Some(top) = &s.top {
248 children.push(("top", &top.this));
249 }
250 if let Some(with) = &s.with {
251 for cte in &with.ctes {
252 children.push(("with_cte", &cte.this));
253 }
254 if let Some(search) = &with.search {
255 children.push(("with_search", search));
256 }
257 }
258 if let Some(sample) = &s.sample {
259 children.push(("sample_size", &sample.size));
260 if let Some(seed) = &sample.seed {
261 children.push(("sample_seed", seed));
262 }
263 if let Some(offset) = &sample.offset {
264 children.push(("sample_offset", offset));
265 }
266 if let Some(bucket_numerator) = &sample.bucket_numerator {
267 children.push(("sample_bucket_numerator", bucket_numerator));
268 }
269 if let Some(bucket_denominator) = &sample.bucket_denominator {
270 children.push(("sample_bucket_denominator", bucket_denominator));
271 }
272 if let Some(bucket_field) = &sample.bucket_field {
273 children.push(("sample_bucket_field", bucket_field));
274 }
275 }
276 if let Some(connect) = &s.connect {
277 if let Some(start) = &connect.start {
278 children.push(("connect_start", start));
279 }
280 children.push(("connect", &connect.connect));
281 }
282 if let Some(into) = &s.into {
283 children.push(("into", &into.this));
284 }
285 for lock in &s.locks {
286 for e in &lock.expressions {
287 children.push(("lock_expression", e));
288 }
289 if let Some(wait) = &lock.wait {
290 children.push(("lock_wait", wait));
291 }
292 if let Some(key) = &lock.key {
293 children.push(("lock_key", key));
294 }
295 if let Some(update) = &lock.update {
296 children.push(("lock_update", update));
297 }
298 }
299 for e in &s.for_xml {
300 children.push(("for_xml", e));
301 }
302 }
303 Expression::With(with) => {
304 for cte in &with.ctes {
305 children.push(("cte", &cte.this));
306 }
307 if let Some(search) = &with.search {
308 children.push(("search", search));
309 }
310 }
311 Expression::Cte(cte) => {
312 children.push(("this", &cte.this));
313 }
314 Expression::Insert(insert) => {
315 if let Some(query) = &insert.query {
316 children.push(("query", query));
317 }
318 if let Some(with) = &insert.with {
319 for cte in &with.ctes {
320 children.push(("with_cte", &cte.this));
321 }
322 if let Some(search) = &with.search {
323 children.push(("with_search", search));
324 }
325 }
326 if let Some(on_conflict) = &insert.on_conflict {
327 children.push(("on_conflict", on_conflict));
328 }
329 if let Some(replace_where) = &insert.replace_where {
330 children.push(("replace_where", replace_where));
331 }
332 if let Some(source) = &insert.source {
333 children.push(("source", source));
334 }
335 if let Some(function_target) = &insert.function_target {
336 children.push(("function_target", function_target));
337 }
338 if let Some(partition_by) = &insert.partition_by {
339 children.push(("partition_by", partition_by));
340 }
341 if let Some(output) = &insert.output {
342 for column in &output.columns {
343 children.push(("output_column", column));
344 }
345 if let Some(into_table) = &output.into_table {
346 children.push(("output_into_table", into_table));
347 }
348 }
349 for row in &insert.values {
350 for value in row {
351 children.push(("value", value));
352 }
353 }
354 for (_, value) in &insert.partition {
355 if let Some(value) = value {
356 children.push(("partition_value", value));
357 }
358 }
359 for returning in &insert.returning {
360 children.push(("returning", returning));
361 }
362 for setting in &insert.settings {
363 children.push(("setting", setting));
364 }
365 }
366 Expression::Update(update) => {
367 if let Some(from_clause) = &update.from_clause {
368 for source in &from_clause.expressions {
369 children.push(("from", source));
370 }
371 }
372 for join in &update.table_joins {
373 children.push(("table_join_this", &join.this));
374 if let Some(on) = &join.on {
375 children.push(("table_join_on", on));
376 }
377 }
378 for join in &update.from_joins {
379 children.push(("from_join_this", &join.this));
380 if let Some(on) = &join.on {
381 children.push(("from_join_on", on));
382 }
383 }
384 for (_, value) in &update.set {
385 children.push(("set_value", value));
386 }
387 if let Some(where_clause) = &update.where_clause {
388 children.push(("where", &where_clause.this));
389 }
390 if let Some(output) = &update.output {
391 for column in &output.columns {
392 children.push(("output_column", column));
393 }
394 if let Some(into_table) = &output.into_table {
395 children.push(("output_into_table", into_table));
396 }
397 }
398 if let Some(with) = &update.with {
399 for cte in &with.ctes {
400 children.push(("with_cte", &cte.this));
401 }
402 if let Some(search) = &with.search {
403 children.push(("with_search", search));
404 }
405 }
406 if let Some(limit) = &update.limit {
407 children.push(("limit", limit));
408 }
409 if let Some(order_by) = &update.order_by {
410 for ordered in &order_by.expressions {
411 children.push(("order_by", &ordered.this));
412 }
413 }
414 for returning in &update.returning {
415 children.push(("returning", returning));
416 }
417 }
418 Expression::Delete(delete) => {
419 if let Some(with) = &delete.with {
420 for cte in &with.ctes {
421 children.push(("with_cte", &cte.this));
422 }
423 if let Some(search) = &with.search {
424 children.push(("with_search", search));
425 }
426 }
427 if let Some(where_clause) = &delete.where_clause {
428 children.push(("where", &where_clause.this));
429 }
430 if let Some(output) = &delete.output {
431 for column in &output.columns {
432 children.push(("output_column", column));
433 }
434 if let Some(into_table) = &output.into_table {
435 children.push(("output_into_table", into_table));
436 }
437 }
438 if let Some(limit) = &delete.limit {
439 children.push(("limit", limit));
440 }
441 if let Some(order_by) = &delete.order_by {
442 for ordered in &order_by.expressions {
443 children.push(("order_by", &ordered.this));
444 }
445 }
446 for returning in &delete.returning {
447 children.push(("returning", returning));
448 }
449 for join in &delete.joins {
450 children.push(("join_this", &join.this));
451 if let Some(on) = &join.on {
452 children.push(("join_on", on));
453 }
454 }
455 }
456 Expression::Join(join) => {
457 children.push(("this", &join.this));
458 if let Some(on) = &join.on {
459 children.push(("on", on));
460 }
461 if let Some(match_condition) = &join.match_condition {
462 children.push(("match_condition", match_condition));
463 }
464 for pivot in &join.pivots {
465 children.push(("pivot", pivot));
466 }
467 }
468 Expression::Alias(a) => {
469 children.push(("this", &a.this));
470 }
471 Expression::Cast(c) => {
472 children.push(("this", &c.this));
473 }
474 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
475 children.push(("this", &u.this));
476 }
477 Expression::Paren(p) => {
478 children.push(("this", &p.this));
479 }
480 Expression::IsNull(i) => {
481 children.push(("this", &i.this));
482 }
483 Expression::Exists(e) => {
484 children.push(("this", &e.this));
485 }
486 Expression::Subquery(s) => {
487 children.push(("this", &s.this));
488 }
489 Expression::Where(w) => {
490 children.push(("this", &w.this));
491 }
492 Expression::Having(h) => {
493 children.push(("this", &h.this));
494 }
495 Expression::Qualify(q) => {
496 children.push(("this", &q.this));
497 }
498 Expression::And(op)
499 | Expression::Or(op)
500 | Expression::Add(op)
501 | Expression::Sub(op)
502 | Expression::Mul(op)
503 | Expression::Div(op)
504 | Expression::Mod(op)
505 | Expression::Eq(op)
506 | Expression::Neq(op)
507 | Expression::Lt(op)
508 | Expression::Lte(op)
509 | Expression::Gt(op)
510 | Expression::Gte(op)
511 | Expression::BitwiseAnd(op)
512 | Expression::BitwiseOr(op)
513 | Expression::BitwiseXor(op)
514 | Expression::Concat(op) => {
515 children.push(("left", &op.left));
516 children.push(("right", &op.right));
517 }
518 Expression::Like(op) | Expression::ILike(op) => {
519 children.push(("left", &op.left));
520 children.push(("right", &op.right));
521 }
522 Expression::Between(b) => {
523 children.push(("this", &b.this));
524 children.push(("low", &b.low));
525 children.push(("high", &b.high));
526 }
527 Expression::In(i) => {
528 children.push(("this", &i.this));
529 }
530 Expression::Case(c) => {
531 if let Some(ref operand) = &c.operand {
532 children.push(("operand", operand));
533 }
534 }
535 Expression::WindowFunction(wf) => {
536 children.push(("this", &wf.this));
537 }
538 Expression::Union(u) => {
539 children.push(("left", &u.left));
540 children.push(("right", &u.right));
541 if let Some(with) = &u.with {
542 for cte in &with.ctes {
543 children.push(("with_cte", &cte.this));
544 }
545 if let Some(search) = &with.search {
546 children.push(("with_search", search));
547 }
548 }
549 if let Some(order_by) = &u.order_by {
550 for ordered in &order_by.expressions {
551 children.push(("order_by", &ordered.this));
552 }
553 }
554 if let Some(limit) = &u.limit {
555 children.push(("limit", limit));
556 }
557 if let Some(offset) = &u.offset {
558 children.push(("offset", offset));
559 }
560 if let Some(distribute_by) = &u.distribute_by {
561 for e in &distribute_by.expressions {
562 children.push(("distribute_by", e));
563 }
564 }
565 if let Some(sort_by) = &u.sort_by {
566 for ordered in &sort_by.expressions {
567 children.push(("sort_by", &ordered.this));
568 }
569 }
570 if let Some(cluster_by) = &u.cluster_by {
571 for ordered in &cluster_by.expressions {
572 children.push(("cluster_by", &ordered.this));
573 }
574 }
575 for e in &u.on_columns {
576 children.push(("on_column", e));
577 }
578 }
579 Expression::Intersect(i) => {
580 children.push(("left", &i.left));
581 children.push(("right", &i.right));
582 if let Some(with) = &i.with {
583 for cte in &with.ctes {
584 children.push(("with_cte", &cte.this));
585 }
586 if let Some(search) = &with.search {
587 children.push(("with_search", search));
588 }
589 }
590 if let Some(order_by) = &i.order_by {
591 for ordered in &order_by.expressions {
592 children.push(("order_by", &ordered.this));
593 }
594 }
595 if let Some(limit) = &i.limit {
596 children.push(("limit", limit));
597 }
598 if let Some(offset) = &i.offset {
599 children.push(("offset", offset));
600 }
601 if let Some(distribute_by) = &i.distribute_by {
602 for e in &distribute_by.expressions {
603 children.push(("distribute_by", e));
604 }
605 }
606 if let Some(sort_by) = &i.sort_by {
607 for ordered in &sort_by.expressions {
608 children.push(("sort_by", &ordered.this));
609 }
610 }
611 if let Some(cluster_by) = &i.cluster_by {
612 for ordered in &cluster_by.expressions {
613 children.push(("cluster_by", &ordered.this));
614 }
615 }
616 for e in &i.on_columns {
617 children.push(("on_column", e));
618 }
619 }
620 Expression::Except(e) => {
621 children.push(("left", &e.left));
622 children.push(("right", &e.right));
623 if let Some(with) = &e.with {
624 for cte in &with.ctes {
625 children.push(("with_cte", &cte.this));
626 }
627 if let Some(search) = &with.search {
628 children.push(("with_search", search));
629 }
630 }
631 if let Some(order_by) = &e.order_by {
632 for ordered in &order_by.expressions {
633 children.push(("order_by", &ordered.this));
634 }
635 }
636 if let Some(limit) = &e.limit {
637 children.push(("limit", limit));
638 }
639 if let Some(offset) = &e.offset {
640 children.push(("offset", offset));
641 }
642 if let Some(distribute_by) = &e.distribute_by {
643 for expr in &distribute_by.expressions {
644 children.push(("distribute_by", expr));
645 }
646 }
647 if let Some(sort_by) = &e.sort_by {
648 for ordered in &sort_by.expressions {
649 children.push(("sort_by", &ordered.this));
650 }
651 }
652 if let Some(cluster_by) = &e.cluster_by {
653 for ordered in &cluster_by.expressions {
654 children.push(("cluster_by", &ordered.this));
655 }
656 }
657 for expr in &e.on_columns {
658 children.push(("on_column", expr));
659 }
660 }
661 Expression::Merge(merge) => {
662 children.push(("this", &merge.this));
663 children.push(("using", &merge.using));
664 if let Some(on) = &merge.on {
665 children.push(("on", on));
666 }
667 if let Some(using_cond) = &merge.using_cond {
668 children.push(("using_cond", using_cond));
669 }
670 if let Some(whens) = &merge.whens {
671 children.push(("whens", whens));
672 }
673 if let Some(with_) = &merge.with_ {
674 children.push(("with_", with_));
675 }
676 if let Some(returning) = &merge.returning {
677 children.push(("returning", returning));
678 }
679 }
680 Expression::Ordered(o) => {
681 children.push(("this", &o.this));
682 }
683 Expression::Interval(i) => {
684 if let Some(ref this) = i.this {
685 children.push(("this", this));
686 }
687 }
688 _ => {}
689 }
690
691 children
692}
693
694fn iter_children_lists(expr: &Expression) -> Vec<(&'static str, &[Expression])> {
698 let mut lists = Vec::new();
699
700 match expr {
701 Expression::Select(s) => lists.push(("expressions", s.expressions.as_slice())),
702 Expression::Function(f) => {
703 lists.push(("args", f.args.as_slice()));
704 }
705 Expression::AggregateFunction(f) => {
706 lists.push(("args", f.args.as_slice()));
707 }
708 Expression::From(f) => {
709 lists.push(("expressions", f.expressions.as_slice()));
710 }
711 Expression::GroupBy(g) => {
712 lists.push(("expressions", g.expressions.as_slice()));
713 }
714 Expression::In(i) => {
717 lists.push(("expressions", i.expressions.as_slice()));
718 }
719 Expression::Array(a) => {
720 lists.push(("expressions", a.expressions.as_slice()));
721 }
722 Expression::Tuple(t) => {
723 lists.push(("expressions", t.expressions.as_slice()));
724 }
725 Expression::Coalesce(c) => {
727 lists.push(("expressions", c.expressions.as_slice()));
728 }
729 Expression::Greatest(g) | Expression::Least(g) => {
730 lists.push(("expressions", g.expressions.as_slice()));
731 }
732 _ => {}
733 }
734
735 lists
736}
737
/// Depth-first iterator over an expression tree.
///
/// Nodes are yielded parent-first; see the `Iterator` implementation for the
/// order in which siblings are produced.
pub struct DfsIter<'a> {
    /// Nodes still to be yielded; the next node is popped from the end.
    stack: Vec<&'a Expression>,
}

impl<'a> DfsIter<'a> {
    /// Creates an iterator that starts at `root` (which is yielded first).
    pub fn new(root: &'a Expression) -> Self {
        Self { stack: vec![root] }
    }
}
756
757impl<'a> Iterator for DfsIter<'a> {
758 type Item = &'a Expression;
759
760 fn next(&mut self) -> Option<Self::Item> {
761 let expr = self.stack.pop()?;
762
763 let children: Vec<_> = iter_children(expr).into_iter().map(|(_, e)| e).collect();
765 for child in children.into_iter().rev() {
766 self.stack.push(child);
767 }
768
769 let lists: Vec<_> = iter_children_lists(expr)
770 .into_iter()
771 .flat_map(|(_, es)| es.iter())
772 .collect();
773 for child in lists.into_iter().rev() {
774 self.stack.push(child);
775 }
776
777 Some(expr)
778 }
779}
780
781pub struct BfsIter<'a> {
789 queue: VecDeque<&'a Expression>,
790}
791
792impl<'a> BfsIter<'a> {
793 pub fn new(root: &'a Expression) -> Self {
795 let mut queue = VecDeque::new();
796 queue.push_back(root);
797 Self { queue }
798 }
799}
800
801impl<'a> Iterator for BfsIter<'a> {
802 type Item = &'a Expression;
803
804 fn next(&mut self) -> Option<Self::Item> {
805 let expr = self.queue.pop_front()?;
806
807 for (_, child) in iter_children(expr) {
809 self.queue.push_back(child);
810 }
811
812 for (_, children) in iter_children_lists(expr) {
813 for child in children {
814 self.queue.push_back(child);
815 }
816 }
817
818 Some(expr)
819 }
820}
821
/// Traversal and search helpers for expression trees.
pub trait ExpressionWalk {
    /// Returns a depth-first iterator over this node and its descendants.
    fn dfs(&self) -> DfsIter<'_>;

    /// Returns a breadth-first iterator over this node and its descendants.
    fn bfs(&self) -> BfsIter<'_>;

    /// Returns the first node for which `predicate` returns `true`, or
    /// `None` when no node matches.
    fn find<F>(&self, predicate: F) -> Option<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Returns every node for which `predicate` returns `true`.
    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Returns `true` when at least one node satisfies `predicate`.
    fn contains<F>(&self, predicate: F) -> bool
    where
        F: Fn(&Expression) -> bool;

    /// Returns the number of nodes that satisfy `predicate`.
    fn count<F>(&self, predicate: F) -> usize
    where
        F: Fn(&Expression) -> bool;

    /// Returns the direct children of this node.
    fn children(&self) -> Vec<&Expression>;

    /// Returns the depth of the tree below this node.
    fn tree_depth(&self) -> usize;

    /// Consumes the node and rewrites it with `fun`.
    ///
    /// `fun` may return `Ok(None)` for a node; how that is handled is up to
    /// the implementation.
    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
    where
        F: Fn(Expression) -> crate::Result<Option<Expression>>,
        Self: Sized;
}
884
885impl ExpressionWalk for Expression {
886 fn dfs(&self) -> DfsIter<'_> {
887 DfsIter::new(self)
888 }
889
890 fn bfs(&self) -> BfsIter<'_> {
891 BfsIter::new(self)
892 }
893
894 fn find<F>(&self, predicate: F) -> Option<&Expression>
895 where
896 F: Fn(&Expression) -> bool,
897 {
898 self.dfs().find(|e| predicate(e))
899 }
900
901 fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
902 where
903 F: Fn(&Expression) -> bool,
904 {
905 self.dfs().filter(|e| predicate(e)).collect()
906 }
907
908 fn contains<F>(&self, predicate: F) -> bool
909 where
910 F: Fn(&Expression) -> bool,
911 {
912 self.dfs().any(|e| predicate(e))
913 }
914
915 fn count<F>(&self, predicate: F) -> usize
916 where
917 F: Fn(&Expression) -> bool,
918 {
919 self.dfs().filter(|e| predicate(e)).count()
920 }
921
922 fn children(&self) -> Vec<&Expression> {
923 let mut result: Vec<&Expression> = Vec::new();
924 for (_, child) in iter_children(self) {
925 result.push(child);
926 }
927 for (_, children_list) in iter_children_lists(self) {
928 for child in children_list {
929 result.push(child);
930 }
931 }
932 result
933 }
934
935 fn tree_depth(&self) -> usize {
936 let mut max_depth = 0;
937
938 for (_, child) in iter_children(self) {
939 let child_depth = child.tree_depth();
940 if child_depth + 1 > max_depth {
941 max_depth = child_depth + 1;
942 }
943 }
944
945 for (_, children) in iter_children_lists(self) {
946 for child in children {
947 let child_depth = child.tree_depth();
948 if child_depth + 1 > max_depth {
949 max_depth = child_depth + 1;
950 }
951 }
952 }
953
954 max_depth
955 }
956
957 fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
958 where
959 F: Fn(Expression) -> crate::Result<Option<Expression>>,
960 {
961 transform(self, &fun)
962 }
963}
964
965pub fn transform<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
986where
987 F: Fn(Expression) -> crate::Result<Option<Expression>>,
988{
989 crate::dialects::transform_recursive(expr, &|e| match fun(e)? {
990 Some(transformed) => Ok(transformed),
991 None => Ok(Expression::Null(crate::expressions::Null)),
992 })
993}
994
/// Rewrites `expr` by applying `fun` to its nodes, where `fun` always
/// produces a replacement node.
///
/// This is a thin wrapper: traversal is performed entirely by
/// `crate::dialects::transform_recursive`, and errors from `fun` are returned
/// through its `crate::Result`.
pub fn transform_map<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
where
    F: Fn(Expression) -> crate::Result<Expression>,
{
    crate::dialects::transform_recursive(expr, fun)
}
1021
/// Returns `true` if `expr` is an `Expression::Column`.
pub fn is_column(expr: &Expression) -> bool {
    matches!(expr, Expression::Column(_))
}
1032
/// Returns `true` if `expr` is a literal value: an `Expression::Literal`,
/// an `Expression::Boolean` or an `Expression::Null`.
pub fn is_literal(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
    )
}
1040
/// Returns `true` if `expr` is an `Expression::Function` or an
/// `Expression::AggregateFunction`.
///
/// Other call-like variants (for example `Expression::WindowFunction`) are
/// not matched.
pub fn is_function(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Function(_) | Expression::AggregateFunction(_)
    )
}
1048
/// Returns `true` if `expr` is an `Expression::Subquery`.
pub fn is_subquery(expr: &Expression) -> bool {
    matches!(expr, Expression::Subquery(_))
}
1053
/// Returns `true` if `expr` is an `Expression::Select`.
pub fn is_select(expr: &Expression) -> bool {
    matches!(expr, Expression::Select(_))
}
1058
/// Returns `true` if `expr` is an `Expression::AggregateFunction`.
///
/// NOTE(review): dedicated variants such as `Expression::Count` or
/// `Expression::Sum` are not matched here — confirm that this is intended.
pub fn is_aggregate(expr: &Expression) -> bool {
    matches!(expr, Expression::AggregateFunction(_))
}
1063
/// Returns `true` if `expr` is an `Expression::WindowFunction`.
pub fn is_window_function(expr: &Expression) -> bool {
    matches!(expr, Expression::WindowFunction(_))
}
1068
/// Returns every `Expression::Column` node found by a depth-first walk of
/// `expr` (including `expr` itself).
pub fn get_columns(expr: &Expression) -> Vec<&Expression> {
    expr.find_all(is_column)
}
1075
/// Returns every `Expression::Table` node found by a depth-first walk of
/// `expr` (including `expr` itself).
pub fn get_tables(expr: &Expression) -> Vec<&Expression> {
    expr.find_all(|e| matches!(e, Expression::Table(_)))
}
1082
/// Returns `true` if `expr` or any node reachable from it satisfies
/// [`is_aggregate`].
///
/// NOTE(review): because [`is_aggregate`] only matches
/// `Expression::AggregateFunction`, dedicated variants such as
/// `Expression::Count` are not detected — confirm that this is intended.
pub fn contains_aggregate(expr: &Expression) -> bool {
    expr.contains(is_aggregate)
}
1087
/// Returns `true` if `expr` or any node reachable from it is an
/// `Expression::WindowFunction`.
pub fn contains_window_function(expr: &Expression) -> bool {
    expr.contains(is_window_function)
}
1092
/// Returns `true` if `expr` or any node reachable from it is an
/// `Expression::Subquery`.
pub fn contains_subquery(expr: &Expression) -> bool {
    expr.contains(is_subquery)
}
1097
// Generates a public predicate `fn $name(expr: &Expression) -> bool` that
// returns `true` when `expr` matches any of the given patterns.
//
// Usage: `is_type!(is_eq, Expression::Eq(_));` — one or more comma-separated
// patterns may be supplied, with an optional trailing comma.
macro_rules! is_type {
    ($name:ident, $($variant:pat),+ $(,)?) => {
        pub fn $name(expr: &Expression) -> bool {
            // The patterns are combined into a single `a | b | ...` alternative.
            matches!(expr, $($variant)|+)
        }
    };
}
1111
// Statements and set operations.
is_type!(is_insert, Expression::Insert(_));
is_type!(is_update, Expression::Update(_));
is_type!(is_delete, Expression::Delete(_));
is_type!(is_union, Expression::Union(_));
is_type!(is_intersect, Expression::Intersect(_));
is_type!(is_except, Expression::Except(_));

// Leaf values and names.
is_type!(is_boolean, Expression::Boolean(_));
is_type!(is_null_literal, Expression::Null(_));
is_type!(is_star, Expression::Star(_));
is_type!(is_identifier, Expression::Identifier(_));
is_type!(is_table, Expression::Table(_));

// Comparison operators.
is_type!(is_eq, Expression::Eq(_));
is_type!(is_neq, Expression::Neq(_));
is_type!(is_lt, Expression::Lt(_));
is_type!(is_lte, Expression::Lte(_));
is_type!(is_gt, Expression::Gt(_));
is_type!(is_gte, Expression::Gte(_));
is_type!(is_like, Expression::Like(_));
is_type!(is_ilike, Expression::ILike(_));

// Arithmetic and concatenation operators.
is_type!(is_add, Expression::Add(_));
is_type!(is_sub, Expression::Sub(_));
is_type!(is_mul, Expression::Mul(_));
is_type!(is_div, Expression::Div(_));
is_type!(is_mod, Expression::Mod(_));
is_type!(is_concat, Expression::Concat(_));

// Logical operators.
is_type!(is_and, Expression::And(_));
is_type!(is_or, Expression::Or(_));
is_type!(is_not, Expression::Not(_));

// Other predicates.
is_type!(is_in, Expression::In(_));
is_type!(is_between, Expression::Between(_));
is_type!(is_is_null, Expression::IsNull(_));
is_type!(is_exists, Expression::Exists(_));

// Dedicated function-like variants.
is_type!(is_count, Expression::Count(_));
is_type!(is_sum, Expression::Sum(_));
is_type!(is_avg, Expression::Avg(_));
is_type!(is_min_func, Expression::Min(_));
is_type!(is_max_func, Expression::Max(_));
is_type!(is_coalesce, Expression::Coalesce(_));
is_type!(is_null_if, Expression::NullIf(_));
is_type!(is_cast, Expression::Cast(_));
is_type!(is_try_cast, Expression::TryCast(_));
is_type!(is_safe_cast, Expression::SafeCast(_));
is_type!(is_case, Expression::Case(_));

// Query clauses and wrappers.
is_type!(is_from, Expression::From(_));
is_type!(is_join, Expression::Join(_));
is_type!(is_where, Expression::Where(_));
is_type!(is_group_by, Expression::GroupBy(_));
is_type!(is_having, Expression::Having(_));
is_type!(is_order_by, Expression::OrderBy(_));
is_type!(is_limit, Expression::Limit(_));
is_type!(is_offset, Expression::Offset(_));
is_type!(is_with, Expression::With(_));
is_type!(is_cte, Expression::Cte(_));
is_type!(is_alias, Expression::Alias(_));
is_type!(is_paren, Expression::Paren(_));
is_type!(is_ordered, Expression::Ordered(_));

// DDL statements.
is_type!(is_create_table, Expression::CreateTable(_));
is_type!(is_drop_table, Expression::DropTable(_));
is_type!(is_alter_table, Expression::AlterTable(_));
is_type!(is_create_index, Expression::CreateIndex(_));
is_type!(is_drop_index, Expression::DropIndex(_));
is_type!(is_create_view, Expression::CreateView(_));
is_type!(is_drop_view, Expression::DropView(_));
1192
/// Returns `true` if `expr` is a `Select`, `Insert`, `Update` or `Delete`
/// statement.
///
/// Set operations (`Union`, `Intersect`, `Except`) are not matched here; see
/// [`is_set_operation`].
pub fn is_query(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Select(_)
            | Expression::Insert(_)
            | Expression::Update(_)
            | Expression::Delete(_)
    )
}
1207
/// Returns `true` if `expr` is a `Union`, `Intersect` or `Except` node.
pub fn is_set_operation(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
    )
}
1215
/// Returns `true` if `expr` is a comparison: `Eq`, `Neq`, `Lt`, `Lte`, `Gt`,
/// `Gte`, `Like` or `ILike`.
pub fn is_comparison(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Eq(_)
            | Expression::Neq(_)
            | Expression::Lt(_)
            | Expression::Lte(_)
            | Expression::Gt(_)
            | Expression::Gte(_)
            | Expression::Like(_)
            | Expression::ILike(_)
    )
}
1230
/// Returns `true` if `expr` is an arithmetic operator: `Add`, `Sub`, `Mul`,
/// `Div` or `Mod`.
pub fn is_arithmetic(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Add(_)
            | Expression::Sub(_)
            | Expression::Mul(_)
            | Expression::Div(_)
            | Expression::Mod(_)
    )
}
1242
/// Returns `true` if `expr` is a logical operator: `And`, `Or` or `Not`.
pub fn is_logical(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::And(_) | Expression::Or(_) | Expression::Not(_)
    )
}
1250
/// Returns `true` if `expr` is one of the data-definition statements listed
/// below (create / drop / alter of tables, indexes, views, schemas,
/// databases, functions, procedures, sequences, triggers and types).
pub fn is_ddl(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::CreateTable(_)
            | Expression::DropTable(_)
            | Expression::AlterTable(_)
            | Expression::CreateIndex(_)
            | Expression::DropIndex(_)
            | Expression::CreateView(_)
            | Expression::DropView(_)
            | Expression::AlterView(_)
            | Expression::CreateSchema(_)
            | Expression::DropSchema(_)
            | Expression::CreateDatabase(_)
            | Expression::DropDatabase(_)
            | Expression::CreateFunction(_)
            | Expression::DropFunction(_)
            | Expression::CreateProcedure(_)
            | Expression::DropProcedure(_)
            | Expression::CreateSequence(_)
            | Expression::DropSequence(_)
            | Expression::AlterSequence(_)
            | Expression::CreateTrigger(_)
            | Expression::DropTrigger(_)
            | Expression::CreateType(_)
            | Expression::DropType(_)
    )
}
1280
1281pub fn find_parent<'a>(root: &'a Expression, target: &Expression) -> Option<&'a Expression> {
1288 fn search<'a>(node: &'a Expression, target: *const Expression) -> Option<&'a Expression> {
1289 for (_, child) in iter_children(node) {
1290 if std::ptr::eq(child, target) {
1291 return Some(node);
1292 }
1293 if let Some(found) = search(child, target) {
1294 return Some(found);
1295 }
1296 }
1297 for (_, children_list) in iter_children_lists(node) {
1298 for child in children_list {
1299 if std::ptr::eq(child, target) {
1300 return Some(node);
1301 }
1302 if let Some(found) = search(child, target) {
1303 return Some(found);
1304 }
1305 }
1306 }
1307 None
1308 }
1309
1310 search(root, target as *const Expression)
1311}
1312
1313pub fn find_ancestor<'a, F>(
1319 root: &'a Expression,
1320 target: &Expression,
1321 predicate: F,
1322) -> Option<&'a Expression>
1323where
1324 F: Fn(&Expression) -> bool,
1325{
1326 fn build_path<'a>(
1328 node: &'a Expression,
1329 target: *const Expression,
1330 path: &mut Vec<&'a Expression>,
1331 ) -> bool {
1332 if std::ptr::eq(node, target) {
1333 return true;
1334 }
1335 path.push(node);
1336 for (_, child) in iter_children(node) {
1337 if build_path(child, target, path) {
1338 return true;
1339 }
1340 }
1341 for (_, children_list) in iter_children_lists(node) {
1342 for child in children_list {
1343 if build_path(child, target, path) {
1344 return true;
1345 }
1346 }
1347 }
1348 path.pop();
1349 false
1350 }
1351
1352 let mut path = Vec::new();
1353 if !build_path(root, target as *const Expression, &mut path) {
1354 return None;
1355 }
1356
1357 for ancestor in path.iter().rev() {
1359 if predicate(ancestor) {
1360 return Some(ancestor);
1361 }
1362 }
1363 None
1364}
1365
1366#[cfg(test)]
1367mod tests {
1368 use super::*;
1369 use crate::expressions::{BinaryOp, Column, Identifier, Literal};
1370
    /// Builds an unquoted, unqualified column expression named `name`.
    fn make_column(name: &str) -> Expression {
        Expression::Column(Column {
            name: Identifier {
                name: name.to_string(),
                quoted: false,
                trailing_comments: vec![],
                span: None,
            },
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        })
    }
1386
    /// Builds a numeric literal expression holding the decimal text of `value`.
    fn make_literal(value: i64) -> Expression {
        Expression::Literal(Literal::Number(value.to_string()))
    }
1390
    /// `dfs` on `a = 1` yields the `Eq` node first, then its two operands.
    #[test]
    fn test_dfs_simple() {
        let left = make_column("a");
        let right = make_literal(1);
        let expr = Expression::Eq(Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        }));

        let nodes: Vec<_> = expr.dfs().collect();
        // Eq, Column, Literal
        assert_eq!(nodes.len(), 3);
        assert!(matches!(nodes[0], Expression::Eq(_)));
        assert!(matches!(nodes[1], Expression::Column(_)));
        assert!(matches!(nodes[2], Expression::Literal(_)));
    }
1410
/// `find` must locate both the column operand and the literal operand of
/// `a = 1`.
#[test]
fn test_find() {
    let equality = Expression::Eq(Box::new(BinaryOp {
        left: make_column("a"),
        right: make_literal(1),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));

    let found_column = equality.find(is_column);
    assert!(found_column.is_some());
    assert!(matches!(found_column.unwrap(), Expression::Column(_)));

    let found_literal = equality.find(is_literal);
    assert!(found_literal.is_some());
    assert!(matches!(found_literal.unwrap(), Expression::Literal(_)));
}
1432
/// `find_all(is_column)` on `a AND b` must return both column operands.
#[test]
fn test_find_all() {
    let conjunction = BinaryOp {
        left: make_column("a"),
        right: make_column("b"),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    let expr = Expression::And(Box::new(conjunction));

    let matched = expr.find_all(is_column);
    assert_eq!(matched.len(), 2);
}
1449
/// `contains` must report the node kinds present in `a = 1` (column,
/// literal) and reject one that is absent (subquery).
#[test]
fn test_contains() {
    let equality = BinaryOp {
        left: make_column("a"),
        right: make_literal(1),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    let expr = Expression::Eq(Box::new(equality));

    assert!(expr.contains(is_column));
    assert!(expr.contains(is_literal));
    assert!(!expr.contains(is_subquery));
}
1467
/// `count` must tally matching nodes across nesting levels: `a = (b + 1)`
/// has two columns and one literal.
#[test]
fn test_count() {
    let first_column = make_column("a");
    let second_column = make_column("b");
    let one = make_literal(1);

    // Right-hand side: `b + 1`.
    let sum = Expression::Add(Box::new(BinaryOp {
        left: second_column,
        right: one,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));

    // Whole expression: `a = (b + 1)`.
    let comparison = Expression::Eq(Box::new(BinaryOp {
        left: first_column,
        right: sum,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));

    assert_eq!(comparison.count(is_column), 2);
    assert_eq!(comparison.count(is_literal), 1);
}
1495
/// `tree_depth` must be 0 for a leaf and grow by one per level of nesting.
#[test]
fn test_tree_depth() {
    // A lone literal is a leaf.
    let one = make_literal(1);
    assert_eq!(one.tree_depth(), 0);

    // `a = 1`: one level below the root.
    let flat = Expression::Eq(Box::new(BinaryOp {
        left: make_column("a"),
        right: one.clone(),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));
    assert_eq!(flat.tree_depth(), 1);

    // `a = (b + 1)`: two levels below the root.
    let sum = Expression::Add(Box::new(BinaryOp {
        left: make_column("b"),
        right: one,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));
    let nested = Expression::Eq(Box::new(BinaryOp {
        left: make_column("a"),
        right: sum,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));
    assert_eq!(nested.tree_depth(), 2);
}
1533
/// `TreeContext::build` on `a = 1` must record node 0 as the parentless
/// root and nodes 1 and 2 as its `left` and `right` children.
#[test]
fn test_tree_context() {
    let equality = BinaryOp {
        left: make_column("a"),
        right: make_literal(1),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    let expr = Expression::Eq(Box::new(equality));

    let ctx = TreeContext::build(&expr);

    // The first id handed out belongs to the root, which has no parent.
    let root = ctx.get(0).unwrap();
    assert!(root.parent_id.is_none());

    // Both operands hang off the root under their respective argument keys.
    for &(id, expected_key) in &[(1, "left"), (2, "right")] {
        let info = ctx.get(id).unwrap();
        assert_eq!(info.parent_id, Some(0));
        assert_eq!(info.arg_key, expected_key);
    }
}
1562
/// `transform_map` must let a callback replace column `a` with `alpha`
/// while leaving every other node untouched.
#[test]
fn test_transform_rename_columns() {
    let statements = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
    let select = statements[0].clone();

    let renamed = super::transform_map(select, &|node| match node {
        // Only column `a` is rewritten; its table qualifier is carried over.
        Expression::Column(ref col) if col.name.name == "a" => Ok(Expression::Column(Column {
            name: Identifier::new("alpha"),
            table: col.table.clone(),
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        })),
        other => Ok(other),
    })
    .unwrap();

    let sql = crate::generator::Generator::sql(&renamed).unwrap();
    assert!(sql.contains("alpha"), "Expected 'alpha' in: {}", sql);
    assert!(sql.contains("b"), "Expected 'b' in: {}", sql);
}
1589
/// An identity callback passed to `transform_map` must produce a tree that
/// renders to the same SQL as the input.
#[test]
fn test_transform_noop() {
    let statements = crate::parser::Parser::parse_sql("SELECT 1 + 2").unwrap();
    let original = statements[0].clone();

    let transformed = super::transform_map(original.clone(), &|node| Ok(node)).unwrap();

    let before = crate::generator::Generator::sql(&original).unwrap();
    let after = crate::generator::Generator::sql(&transformed).unwrap();
    assert_eq!(before, after);
}
1599
/// `transform_map` must reach columns nested inside a binary expression:
/// `a` becomes `1` and any other column becomes `2`.
#[test]
fn test_transform_nested() {
    let statements = crate::parser::Parser::parse_sql("SELECT a + b FROM t").unwrap();
    let select = statements[0].clone();

    let rewritten = super::transform_map(select, &|node| match node {
        Expression::Column(ref col) => {
            let digit = if col.name.name == "a" { "1" } else { "2" };
            Ok(Expression::Literal(Literal::Number(digit.to_string())))
        }
        other => Ok(other),
    })
    .unwrap();

    let sql = crate::generator::Generator::sql(&rewritten).unwrap();
    assert_eq!(sql, "SELECT 1 + 2 FROM t");
}
1616
/// An error returned by the callback for column `a` must surface as the
/// result of `transform_map`.
#[test]
fn test_transform_error() {
    let statements = crate::parser::Parser::parse_sql("SELECT a FROM t").unwrap();
    let select = statements[0].clone();

    let outcome = super::transform_map(select, &|node| match node {
        Expression::Column(ref col) if col.name.name == "a" => {
            Err(crate::error::Error::parse("test error", 0, 0, 0, 0))
        }
        other => Ok(other),
    });

    assert!(outcome.is_err());
}
1631
/// `transform_owned` with a callback that returns every node unchanged must
/// round-trip the statement to identical SQL.
#[test]
fn test_transform_owned_trait() {
    let statements = crate::parser::Parser::parse_sql("SELECT x FROM t").unwrap();
    let select = statements[0].clone();

    let unchanged = select.transform_owned(|node| Ok(Some(node))).unwrap();

    let rendered = crate::generator::Generator::sql(&unchanged).unwrap();
    assert_eq!(rendered, "SELECT x FROM t");
}
1640
/// A literal is a leaf: `children()` must be empty.
#[test]
fn test_children_leaf() {
    let leaf = make_literal(1);
    let kids = leaf.children();
    assert_eq!(kids.len(), 0);
}
1648
/// `children()` of `a = 1` must be exactly the left operand followed by the
/// right operand.
#[test]
fn test_children_binary_op() {
    let equality = BinaryOp {
        left: make_column("a"),
        right: make_literal(1),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    let expr = Expression::Eq(Box::new(equality));

    let operands = expr.children();
    assert_eq!(operands.len(), 2);
    assert!(matches!(operands[0], Expression::Column(_)));
    assert!(matches!(operands[1], Expression::Literal(_)));
}
1666
/// A parsed `SELECT a, b FROM t` must expose at least two direct children.
#[test]
fn test_children_select() {
    let statements = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
    let select = &statements[0];

    let direct_children = select.children();
    assert!(direct_children.len() >= 2);
}
1675
/// `children()` of a SELECT must include the table from the FROM clause and
/// the table from the JOIN clause.
#[test]
fn test_children_select_includes_from_and_join_sources() {
    let statements = crate::parser::Parser::parse_sql(
        "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id",
    )
    .unwrap();
    let select = &statements[0];
    let direct_children = select.children();

    // Keep only the table nodes and collect their names.
    let mut table_names: Vec<&str> = Vec::new();
    for child in direct_children.iter() {
        if let Expression::Table(table) = child {
            table_names.push(table.name.name.as_str());
        }
    }

    assert!(table_names.contains(&"users"));
    assert!(table_names.contains(&"orders"));
}
1696
/// `get_tables` on an `INSERT ... SELECT` must report the tables read by
/// the source query (`src` and `dim`).
#[test]
fn test_get_tables_includes_insert_query_sources() {
    let statements = crate::parser::Parser::parse_sql(
        "INSERT INTO dst (id) SELECT s.id FROM src s JOIN dim d ON s.id = d.id",
    )
    .unwrap();
    let insert = &statements[0];
    let found = get_tables(insert);

    // Reduce the returned nodes to the names of the table nodes among them.
    let mut table_names: Vec<&str> = Vec::new();
    for node in found.iter() {
        if let Expression::Table(table) = node {
            table_names.push(table.name.name.as_str());
        }
    }

    assert!(table_names.contains(&"src"));
    assert!(table_names.contains(&"dim"));
}
1716
/// `find_parent` must return the enclosing `Eq` node for the column operand
/// of `a = 1`.
#[test]
fn test_find_parent_binary() {
    let equality = BinaryOp {
        left: make_column("a"),
        right: make_literal(1),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    let root = Expression::Eq(Box::new(equality));

    // Take a reference to the column as it lives inside the tree.
    let column = root.find(is_column).unwrap();

    let owner = super::find_parent(&root, column);
    assert!(owner.is_some());
    assert!(matches!(owner.unwrap(), Expression::Eq(_)));
}
1738
/// Asking for the parent of the root itself must yield `None`.
#[test]
fn test_find_parent_root_has_none() {
    let root = make_literal(1);
    let owner = super::find_parent(&root, &root);
    assert!(owner.is_none());
}
1745
/// `find_ancestor` with `is_select` must walk up from a column named `a` to
/// the enclosing SELECT.
#[test]
fn test_find_ancestor_select() {
    let statements = crate::parser::Parser::parse_sql("SELECT a FROM t WHERE a > 1").unwrap();
    let select = &statements[0];

    // First column named `a` encountered in DFS order.
    let first_a = select
        .dfs()
        .find(|node| matches!(node, Expression::Column(col) if col.name.name == "a"));
    assert!(first_a.is_some());

    let enclosing = super::find_ancestor(select, first_a.unwrap(), is_select);
    assert!(enclosing.is_some());
    assert!(matches!(enclosing.unwrap(), Expression::Select(_)));
}
1768
/// `find_ancestor` must return `None` when no ancestor satisfies the
/// predicate: `a = 1` contains no SELECT.
#[test]
fn test_find_ancestor_no_match() {
    let equality = BinaryOp {
        left: make_column("a"),
        right: make_literal(1),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    let root = Expression::Eq(Box::new(equality));

    let column = root.find(is_column).unwrap();
    let enclosing = super::find_ancestor(&root, column, is_select);
    assert!(enclosing.is_none());
}
1786
/// `TreeContext::ancestors_of` must list ancestors from the nearest parent
/// up to the root. For `b = (a + 1)`, node 3 sits under node 2, which sits
/// under root 0.
#[test]
fn test_ancestors() {
    // Right-hand side: `a + 1`.
    let sum = Expression::Add(Box::new(BinaryOp {
        left: make_column("a"),
        right: make_literal(1),
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));
    // Whole expression: `b = (a + 1)`.
    let comparison = Expression::Eq(Box::new(BinaryOp {
        left: make_column("b"),
        right: sum,
        left_comments: Vec::new(),
        operator_comments: Vec::new(),
        trailing_comments: Vec::new(),
        inferred_type: None,
    }));

    let ctx = TreeContext::build(&comparison);

    let chain = ctx.ancestors_of(3);
    assert_eq!(chain, vec![2, 0]);
}
1820}