1use crate::expressions::{Expression, TableRef};
38use std::collections::{HashMap, VecDeque};
39
/// Identifier handed out to each node recorded in a [`TreeContext`].
pub type NodeId = usize;
42
/// Describes how a node is attached to its parent in the expression tree.
#[derive(Debug, Clone)]
pub struct ParentInfo {
    /// Id of the parent node, or `None` for the root.
    pub parent_id: Option<NodeId>,
    /// Key naming the parent slot that holds this node (e.g. `"this"`,
    /// `"expressions"`); empty for the root.
    pub arg_key: String,
    /// Position within the parent's list when the node comes from a
    /// list-valued slot, otherwise `None`.
    pub index: Option<usize>,
}
53
/// Parent/child bookkeeping for an expression tree, keyed by [`NodeId`].
#[derive(Debug, Default)]
pub struct TreeContext {
    /// Parent information for every visited node.
    nodes: HashMap<NodeId, ParentInfo>,
    /// Next id to hand out; starts at zero via `Default`.
    next_id: NodeId,
    /// Stack of `(parent id, slot key, list index)` entries describing the
    /// slot currently being descended into.
    path: Vec<(NodeId, String, Option<usize>)>,
}
76
77impl TreeContext {
78 pub fn new() -> Self {
80 Self::default()
81 }
82
83 pub fn build(root: &Expression) -> Self {
85 let mut ctx = Self::new();
86 ctx.visit_expr(root);
87 ctx
88 }
89
90 fn visit_expr(&mut self, expr: &Expression) -> NodeId {
92 let id = self.next_id;
93 self.next_id += 1;
94
95 let parent_info = if let Some((parent_id, arg_key, index)) = self.path.last() {
97 ParentInfo {
98 parent_id: Some(*parent_id),
99 arg_key: arg_key.clone(),
100 index: *index,
101 }
102 } else {
103 ParentInfo {
104 parent_id: None,
105 arg_key: String::new(),
106 index: None,
107 }
108 };
109 self.nodes.insert(id, parent_info);
110
111 for (key, child) in iter_children(expr) {
113 self.path.push((id, key.to_string(), None));
114 self.visit_expr(child);
115 self.path.pop();
116 }
117
118 for (key, children) in iter_children_lists(expr) {
120 for (idx, child) in children.iter().enumerate() {
121 self.path.push((id, key.to_string(), Some(idx)));
122 self.visit_expr(child);
123 self.path.pop();
124 }
125 }
126
127 id
128 }
129
130 pub fn get(&self, id: NodeId) -> Option<&ParentInfo> {
132 self.nodes.get(&id)
133 }
134
135 pub fn depth_of(&self, id: NodeId) -> usize {
137 let mut depth = 0;
138 let mut current = id;
139 while let Some(info) = self.nodes.get(¤t) {
140 if let Some(parent_id) = info.parent_id {
141 depth += 1;
142 current = parent_id;
143 } else {
144 break;
145 }
146 }
147 depth
148 }
149
150 pub fn ancestors_of(&self, id: NodeId) -> Vec<NodeId> {
152 let mut ancestors = Vec::new();
153 let mut current = id;
154 while let Some(info) = self.nodes.get(¤t) {
155 if let Some(parent_id) = info.parent_id {
156 ancestors.push(parent_id);
157 current = parent_id;
158 } else {
159 break;
160 }
161 }
162 ancestors
163 }
164}
165
166fn iter_children(expr: &Expression) -> Vec<(&'static str, &Expression)> {
170 let mut children = Vec::new();
171
172 match expr {
173 Expression::Select(s) => {
174 if let Some(from) = &s.from {
175 for source in &from.expressions {
176 children.push(("from", source));
177 }
178 }
179 for join in &s.joins {
180 children.push(("join_this", &join.this));
181 if let Some(on) = &join.on {
182 children.push(("join_on", on));
183 }
184 if let Some(match_condition) = &join.match_condition {
185 children.push(("join_match_condition", match_condition));
186 }
187 for pivot in &join.pivots {
188 children.push(("join_pivot", pivot));
189 }
190 }
191 for lateral_view in &s.lateral_views {
192 children.push(("lateral_view", &lateral_view.this));
193 }
194 if let Some(prewhere) = &s.prewhere {
195 children.push(("prewhere", prewhere));
196 }
197 if let Some(where_clause) = &s.where_clause {
198 children.push(("where", &where_clause.this));
199 }
200 if let Some(group_by) = &s.group_by {
201 for e in &group_by.expressions {
202 children.push(("group_by", e));
203 }
204 }
205 if let Some(having) = &s.having {
206 children.push(("having", &having.this));
207 }
208 if let Some(qualify) = &s.qualify {
209 children.push(("qualify", &qualify.this));
210 }
211 if let Some(order_by) = &s.order_by {
212 for ordered in &order_by.expressions {
213 children.push(("order_by", &ordered.this));
214 }
215 }
216 if let Some(distribute_by) = &s.distribute_by {
217 for e in &distribute_by.expressions {
218 children.push(("distribute_by", e));
219 }
220 }
221 if let Some(cluster_by) = &s.cluster_by {
222 for ordered in &cluster_by.expressions {
223 children.push(("cluster_by", &ordered.this));
224 }
225 }
226 if let Some(sort_by) = &s.sort_by {
227 for ordered in &sort_by.expressions {
228 children.push(("sort_by", &ordered.this));
229 }
230 }
231 if let Some(limit) = &s.limit {
232 children.push(("limit", &limit.this));
233 }
234 if let Some(offset) = &s.offset {
235 children.push(("offset", &offset.this));
236 }
237 if let Some(limit_by) = &s.limit_by {
238 for e in limit_by {
239 children.push(("limit_by", e));
240 }
241 }
242 if let Some(fetch) = &s.fetch {
243 if let Some(count) = &fetch.count {
244 children.push(("fetch", count));
245 }
246 }
247 if let Some(top) = &s.top {
248 children.push(("top", &top.this));
249 }
250 if let Some(with) = &s.with {
251 for cte in &with.ctes {
252 children.push(("with_cte", &cte.this));
253 }
254 if let Some(search) = &with.search {
255 children.push(("with_search", search));
256 }
257 }
258 if let Some(sample) = &s.sample {
259 children.push(("sample_size", &sample.size));
260 if let Some(seed) = &sample.seed {
261 children.push(("sample_seed", seed));
262 }
263 if let Some(offset) = &sample.offset {
264 children.push(("sample_offset", offset));
265 }
266 if let Some(bucket_numerator) = &sample.bucket_numerator {
267 children.push(("sample_bucket_numerator", bucket_numerator));
268 }
269 if let Some(bucket_denominator) = &sample.bucket_denominator {
270 children.push(("sample_bucket_denominator", bucket_denominator));
271 }
272 if let Some(bucket_field) = &sample.bucket_field {
273 children.push(("sample_bucket_field", bucket_field));
274 }
275 }
276 if let Some(connect) = &s.connect {
277 if let Some(start) = &connect.start {
278 children.push(("connect_start", start));
279 }
280 children.push(("connect", &connect.connect));
281 }
282 if let Some(into) = &s.into {
283 children.push(("into", &into.this));
284 }
285 for lock in &s.locks {
286 for e in &lock.expressions {
287 children.push(("lock_expression", e));
288 }
289 if let Some(wait) = &lock.wait {
290 children.push(("lock_wait", wait));
291 }
292 if let Some(key) = &lock.key {
293 children.push(("lock_key", key));
294 }
295 if let Some(update) = &lock.update {
296 children.push(("lock_update", update));
297 }
298 }
299 for e in &s.for_xml {
300 children.push(("for_xml", e));
301 }
302 }
303 Expression::With(with) => {
304 for cte in &with.ctes {
305 children.push(("cte", &cte.this));
306 }
307 if let Some(search) = &with.search {
308 children.push(("search", search));
309 }
310 }
311 Expression::Cte(cte) => {
312 children.push(("this", &cte.this));
313 }
314 Expression::Insert(insert) => {
315 if let Some(query) = &insert.query {
316 children.push(("query", query));
317 }
318 if let Some(with) = &insert.with {
319 for cte in &with.ctes {
320 children.push(("with_cte", &cte.this));
321 }
322 if let Some(search) = &with.search {
323 children.push(("with_search", search));
324 }
325 }
326 if let Some(on_conflict) = &insert.on_conflict {
327 children.push(("on_conflict", on_conflict));
328 }
329 if let Some(replace_where) = &insert.replace_where {
330 children.push(("replace_where", replace_where));
331 }
332 if let Some(source) = &insert.source {
333 children.push(("source", source));
334 }
335 if let Some(function_target) = &insert.function_target {
336 children.push(("function_target", function_target));
337 }
338 if let Some(partition_by) = &insert.partition_by {
339 children.push(("partition_by", partition_by));
340 }
341 if let Some(output) = &insert.output {
342 for column in &output.columns {
343 children.push(("output_column", column));
344 }
345 if let Some(into_table) = &output.into_table {
346 children.push(("output_into_table", into_table));
347 }
348 }
349 for row in &insert.values {
350 for value in row {
351 children.push(("value", value));
352 }
353 }
354 for (_, value) in &insert.partition {
355 if let Some(value) = value {
356 children.push(("partition_value", value));
357 }
358 }
359 for returning in &insert.returning {
360 children.push(("returning", returning));
361 }
362 for setting in &insert.settings {
363 children.push(("setting", setting));
364 }
365 }
366 Expression::Update(update) => {
367 if let Some(from_clause) = &update.from_clause {
368 for source in &from_clause.expressions {
369 children.push(("from", source));
370 }
371 }
372 for join in &update.table_joins {
373 children.push(("table_join_this", &join.this));
374 if let Some(on) = &join.on {
375 children.push(("table_join_on", on));
376 }
377 }
378 for join in &update.from_joins {
379 children.push(("from_join_this", &join.this));
380 if let Some(on) = &join.on {
381 children.push(("from_join_on", on));
382 }
383 }
384 for (_, value) in &update.set {
385 children.push(("set_value", value));
386 }
387 if let Some(where_clause) = &update.where_clause {
388 children.push(("where", &where_clause.this));
389 }
390 if let Some(output) = &update.output {
391 for column in &output.columns {
392 children.push(("output_column", column));
393 }
394 if let Some(into_table) = &output.into_table {
395 children.push(("output_into_table", into_table));
396 }
397 }
398 if let Some(with) = &update.with {
399 for cte in &with.ctes {
400 children.push(("with_cte", &cte.this));
401 }
402 if let Some(search) = &with.search {
403 children.push(("with_search", search));
404 }
405 }
406 if let Some(limit) = &update.limit {
407 children.push(("limit", limit));
408 }
409 if let Some(order_by) = &update.order_by {
410 for ordered in &order_by.expressions {
411 children.push(("order_by", &ordered.this));
412 }
413 }
414 for returning in &update.returning {
415 children.push(("returning", returning));
416 }
417 }
418 Expression::Delete(delete) => {
419 if let Some(with) = &delete.with {
420 for cte in &with.ctes {
421 children.push(("with_cte", &cte.this));
422 }
423 if let Some(search) = &with.search {
424 children.push(("with_search", search));
425 }
426 }
427 if let Some(where_clause) = &delete.where_clause {
428 children.push(("where", &where_clause.this));
429 }
430 if let Some(output) = &delete.output {
431 for column in &output.columns {
432 children.push(("output_column", column));
433 }
434 if let Some(into_table) = &output.into_table {
435 children.push(("output_into_table", into_table));
436 }
437 }
438 if let Some(limit) = &delete.limit {
439 children.push(("limit", limit));
440 }
441 if let Some(order_by) = &delete.order_by {
442 for ordered in &order_by.expressions {
443 children.push(("order_by", &ordered.this));
444 }
445 }
446 for returning in &delete.returning {
447 children.push(("returning", returning));
448 }
449 for join in &delete.joins {
450 children.push(("join_this", &join.this));
451 if let Some(on) = &join.on {
452 children.push(("join_on", on));
453 }
454 }
455 }
456 Expression::Join(join) => {
457 children.push(("this", &join.this));
458 if let Some(on) = &join.on {
459 children.push(("on", on));
460 }
461 if let Some(match_condition) = &join.match_condition {
462 children.push(("match_condition", match_condition));
463 }
464 for pivot in &join.pivots {
465 children.push(("pivot", pivot));
466 }
467 }
468 Expression::Alias(a) => {
469 children.push(("this", &a.this));
470 }
471 Expression::Cast(c) => {
472 children.push(("this", &c.this));
473 }
474 Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
475 children.push(("this", &u.this));
476 }
477 Expression::Paren(p) => {
478 children.push(("this", &p.this));
479 }
480 Expression::IsNull(i) => {
481 children.push(("this", &i.this));
482 }
483 Expression::Exists(e) => {
484 children.push(("this", &e.this));
485 }
486 Expression::Subquery(s) => {
487 children.push(("this", &s.this));
488 }
489 Expression::Where(w) => {
490 children.push(("this", &w.this));
491 }
492 Expression::Having(h) => {
493 children.push(("this", &h.this));
494 }
495 Expression::Qualify(q) => {
496 children.push(("this", &q.this));
497 }
498 Expression::And(op)
499 | Expression::Or(op)
500 | Expression::Add(op)
501 | Expression::Sub(op)
502 | Expression::Mul(op)
503 | Expression::Div(op)
504 | Expression::Mod(op)
505 | Expression::Eq(op)
506 | Expression::Neq(op)
507 | Expression::Lt(op)
508 | Expression::Lte(op)
509 | Expression::Gt(op)
510 | Expression::Gte(op)
511 | Expression::BitwiseAnd(op)
512 | Expression::BitwiseOr(op)
513 | Expression::BitwiseXor(op)
514 | Expression::Concat(op) => {
515 children.push(("left", &op.left));
516 children.push(("right", &op.right));
517 }
518 Expression::Like(op) | Expression::ILike(op) => {
519 children.push(("left", &op.left));
520 children.push(("right", &op.right));
521 }
522 Expression::Between(b) => {
523 children.push(("this", &b.this));
524 children.push(("low", &b.low));
525 children.push(("high", &b.high));
526 }
527 Expression::In(i) => {
528 children.push(("this", &i.this));
529 if let Some(ref query) = i.query {
530 children.push(("query", query));
531 }
532 if let Some(ref unnest) = i.unnest {
533 children.push(("unnest", unnest));
534 }
535 }
536 Expression::Case(c) => {
537 if let Some(ref operand) = &c.operand {
538 children.push(("operand", operand));
539 }
540 }
541 Expression::WindowFunction(wf) => {
542 children.push(("this", &wf.this));
543 }
544 Expression::Union(u) => {
545 children.push(("left", &u.left));
546 children.push(("right", &u.right));
547 if let Some(with) = &u.with {
548 for cte in &with.ctes {
549 children.push(("with_cte", &cte.this));
550 }
551 if let Some(search) = &with.search {
552 children.push(("with_search", search));
553 }
554 }
555 if let Some(order_by) = &u.order_by {
556 for ordered in &order_by.expressions {
557 children.push(("order_by", &ordered.this));
558 }
559 }
560 if let Some(limit) = &u.limit {
561 children.push(("limit", limit));
562 }
563 if let Some(offset) = &u.offset {
564 children.push(("offset", offset));
565 }
566 if let Some(distribute_by) = &u.distribute_by {
567 for e in &distribute_by.expressions {
568 children.push(("distribute_by", e));
569 }
570 }
571 if let Some(sort_by) = &u.sort_by {
572 for ordered in &sort_by.expressions {
573 children.push(("sort_by", &ordered.this));
574 }
575 }
576 if let Some(cluster_by) = &u.cluster_by {
577 for ordered in &cluster_by.expressions {
578 children.push(("cluster_by", &ordered.this));
579 }
580 }
581 for e in &u.on_columns {
582 children.push(("on_column", e));
583 }
584 }
585 Expression::Intersect(i) => {
586 children.push(("left", &i.left));
587 children.push(("right", &i.right));
588 if let Some(with) = &i.with {
589 for cte in &with.ctes {
590 children.push(("with_cte", &cte.this));
591 }
592 if let Some(search) = &with.search {
593 children.push(("with_search", search));
594 }
595 }
596 if let Some(order_by) = &i.order_by {
597 for ordered in &order_by.expressions {
598 children.push(("order_by", &ordered.this));
599 }
600 }
601 if let Some(limit) = &i.limit {
602 children.push(("limit", limit));
603 }
604 if let Some(offset) = &i.offset {
605 children.push(("offset", offset));
606 }
607 if let Some(distribute_by) = &i.distribute_by {
608 for e in &distribute_by.expressions {
609 children.push(("distribute_by", e));
610 }
611 }
612 if let Some(sort_by) = &i.sort_by {
613 for ordered in &sort_by.expressions {
614 children.push(("sort_by", &ordered.this));
615 }
616 }
617 if let Some(cluster_by) = &i.cluster_by {
618 for ordered in &cluster_by.expressions {
619 children.push(("cluster_by", &ordered.this));
620 }
621 }
622 for e in &i.on_columns {
623 children.push(("on_column", e));
624 }
625 }
626 Expression::Except(e) => {
627 children.push(("left", &e.left));
628 children.push(("right", &e.right));
629 if let Some(with) = &e.with {
630 for cte in &with.ctes {
631 children.push(("with_cte", &cte.this));
632 }
633 if let Some(search) = &with.search {
634 children.push(("with_search", search));
635 }
636 }
637 if let Some(order_by) = &e.order_by {
638 for ordered in &order_by.expressions {
639 children.push(("order_by", &ordered.this));
640 }
641 }
642 if let Some(limit) = &e.limit {
643 children.push(("limit", limit));
644 }
645 if let Some(offset) = &e.offset {
646 children.push(("offset", offset));
647 }
648 if let Some(distribute_by) = &e.distribute_by {
649 for expr in &distribute_by.expressions {
650 children.push(("distribute_by", expr));
651 }
652 }
653 if let Some(sort_by) = &e.sort_by {
654 for ordered in &sort_by.expressions {
655 children.push(("sort_by", &ordered.this));
656 }
657 }
658 if let Some(cluster_by) = &e.cluster_by {
659 for ordered in &cluster_by.expressions {
660 children.push(("cluster_by", &ordered.this));
661 }
662 }
663 for expr in &e.on_columns {
664 children.push(("on_column", expr));
665 }
666 }
667 Expression::Merge(merge) => {
668 children.push(("this", &merge.this));
669 children.push(("using", &merge.using));
670 if let Some(on) = &merge.on {
671 children.push(("on", on));
672 }
673 if let Some(using_cond) = &merge.using_cond {
674 children.push(("using_cond", using_cond));
675 }
676 if let Some(whens) = &merge.whens {
677 children.push(("whens", whens));
678 }
679 if let Some(with_) = &merge.with_ {
680 children.push(("with_", with_));
681 }
682 if let Some(returning) = &merge.returning {
683 children.push(("returning", returning));
684 }
685 }
686 Expression::Any(q) | Expression::All(q) => {
687 children.push(("this", &q.this));
688 children.push(("subquery", &q.subquery));
689 }
690 Expression::Ordered(o) => {
691 children.push(("this", &o.this));
692 }
693 Expression::Interval(i) => {
694 if let Some(ref this) = i.this {
695 children.push(("this", this));
696 }
697 }
698 Expression::Describe(d) => {
699 children.push(("target", &d.target));
700 }
701 Expression::CreateTask(ct) => {
702 children.push(("body", &ct.body));
703 }
704 Expression::Prepare(prepare) => {
705 children.push(("statement", &prepare.statement));
706 }
707 Expression::Execute(exec) => {
708 children.push(("this", &exec.this));
709 for argument in &exec.arguments {
710 children.push(("argument", argument));
711 }
712 for parameter in &exec.parameters {
713 children.push(("parameter", ¶meter.value));
714 }
715 }
716 Expression::Analyze(a) => {
717 if let Some(this) = &a.this {
718 children.push(("this", this));
719 }
720 if let Some(expr) = &a.expression {
721 children.push(("expression", expr));
722 }
723 }
724 _ => {}
725 }
726
727 children
728}
729
730fn iter_children_lists(expr: &Expression) -> Vec<(&'static str, &[Expression])> {
734 let mut lists = Vec::new();
735
736 match expr {
737 Expression::Select(s) => lists.push(("expressions", s.expressions.as_slice())),
738 Expression::Function(f) => {
739 lists.push(("args", f.args.as_slice()));
740 }
741 Expression::AggregateFunction(f) => {
742 lists.push(("args", f.args.as_slice()));
743 }
744 Expression::From(f) => {
745 lists.push(("expressions", f.expressions.as_slice()));
746 }
747 Expression::GroupBy(g) => {
748 lists.push(("expressions", g.expressions.as_slice()));
749 }
750 Expression::In(i) => {
753 lists.push(("expressions", i.expressions.as_slice()));
754 }
755 Expression::Array(a) => {
756 lists.push(("expressions", a.expressions.as_slice()));
757 }
758 Expression::Tuple(t) => {
759 lists.push(("expressions", t.expressions.as_slice()));
760 }
761 Expression::TryCatch(try_catch) => {
762 lists.push(("try_body", try_catch.try_body.as_slice()));
763 if let Some(catch_body) = &try_catch.catch_body {
764 lists.push(("catch_body", catch_body.as_slice()));
765 }
766 }
767 Expression::Coalesce(c) => {
769 lists.push(("expressions", c.expressions.as_slice()));
770 }
771 Expression::Greatest(g) | Expression::Least(g) => {
772 lists.push(("expressions", g.expressions.as_slice()));
773 }
774 _ => {}
775 }
776
777 lists
778}
779
780pub struct DfsIter<'a> {
789 stack: Vec<&'a Expression>,
790}
791
792impl<'a> DfsIter<'a> {
793 pub fn new(root: &'a Expression) -> Self {
795 Self { stack: vec![root] }
796 }
797}
798
799impl<'a> Iterator for DfsIter<'a> {
800 type Item = &'a Expression;
801
802 fn next(&mut self) -> Option<Self::Item> {
803 let expr = self.stack.pop()?;
804
805 let children: Vec<_> = iter_children(expr).into_iter().map(|(_, e)| e).collect();
807 for child in children.into_iter().rev() {
808 self.stack.push(child);
809 }
810
811 let lists: Vec<_> = iter_children_lists(expr)
812 .into_iter()
813 .flat_map(|(_, es)| es.iter())
814 .collect();
815 for child in lists.into_iter().rev() {
816 self.stack.push(child);
817 }
818
819 Some(expr)
820 }
821}
822
823pub struct BfsIter<'a> {
831 queue: VecDeque<&'a Expression>,
832}
833
834impl<'a> BfsIter<'a> {
835 pub fn new(root: &'a Expression) -> Self {
837 let mut queue = VecDeque::new();
838 queue.push_back(root);
839 Self { queue }
840 }
841}
842
843impl<'a> Iterator for BfsIter<'a> {
844 type Item = &'a Expression;
845
846 fn next(&mut self) -> Option<Self::Item> {
847 let expr = self.queue.pop_front()?;
848
849 for (_, child) in iter_children(expr) {
851 self.queue.push_back(child);
852 }
853
854 for (_, children) in iter_children_lists(expr) {
855 for child in children {
856 self.queue.push_back(child);
857 }
858 }
859
860 Some(expr)
861 }
862}
863
/// Traversal and search helpers for expression trees.
pub trait ExpressionWalk {
    /// Returns a depth-first iterator starting at `self`.
    fn dfs(&self) -> DfsIter<'_>;

    /// Returns a breadth-first iterator starting at `self`.
    fn bfs(&self) -> BfsIter<'_>;

    /// Returns the first node for which `predicate` is true, or `None` if no
    /// node matches.
    fn find<F>(&self, predicate: F) -> Option<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Returns every node for which `predicate` is true.
    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Returns `true` if at least one node satisfies `predicate`.
    fn contains<F>(&self, predicate: F) -> bool
    where
        F: Fn(&Expression) -> bool;

    /// Returns the number of nodes that satisfy `predicate`.
    fn count<F>(&self, predicate: F) -> usize
    where
        F: Fn(&Expression) -> bool;

    /// Returns the direct children of `self`.
    fn children(&self) -> Vec<&Expression>;

    /// Returns the height of the tree below `self`; a node without children
    /// has depth `0`.
    fn tree_depth(&self) -> usize;

    /// Consumes `self` and rewrites the tree with `fun`.
    ///
    /// `fun` may return `Ok(None)` for a node; how that case is handled is up
    /// to the implementation.
    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
    where
        F: Fn(Expression) -> crate::Result<Option<Expression>>,
        Self: Sized;
}
926
927impl ExpressionWalk for Expression {
928 fn dfs(&self) -> DfsIter<'_> {
929 DfsIter::new(self)
930 }
931
932 fn bfs(&self) -> BfsIter<'_> {
933 BfsIter::new(self)
934 }
935
936 fn find<F>(&self, predicate: F) -> Option<&Expression>
937 where
938 F: Fn(&Expression) -> bool,
939 {
940 self.dfs().find(|e| predicate(e))
941 }
942
943 fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
944 where
945 F: Fn(&Expression) -> bool,
946 {
947 self.dfs().filter(|e| predicate(e)).collect()
948 }
949
950 fn contains<F>(&self, predicate: F) -> bool
951 where
952 F: Fn(&Expression) -> bool,
953 {
954 self.dfs().any(|e| predicate(e))
955 }
956
957 fn count<F>(&self, predicate: F) -> usize
958 where
959 F: Fn(&Expression) -> bool,
960 {
961 self.dfs().filter(|e| predicate(e)).count()
962 }
963
964 fn children(&self) -> Vec<&Expression> {
965 let mut result: Vec<&Expression> = Vec::new();
966 for (_, child) in iter_children(self) {
967 result.push(child);
968 }
969 for (_, children_list) in iter_children_lists(self) {
970 for child in children_list {
971 result.push(child);
972 }
973 }
974 result
975 }
976
977 fn tree_depth(&self) -> usize {
978 let mut max_depth = 0;
979
980 for (_, child) in iter_children(self) {
981 let child_depth = child.tree_depth();
982 if child_depth + 1 > max_depth {
983 max_depth = child_depth + 1;
984 }
985 }
986
987 for (_, children) in iter_children_lists(self) {
988 for child in children {
989 let child_depth = child.tree_depth();
990 if child_depth + 1 > max_depth {
991 max_depth = child_depth + 1;
992 }
993 }
994 }
995
996 max_depth
997 }
998
999 fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
1000 where
1001 F: Fn(Expression) -> crate::Result<Option<Expression>>,
1002 {
1003 transform(self, &fun)
1004 }
1005}
1006
1007pub fn transform<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
1028where
1029 F: Fn(Expression) -> crate::Result<Option<Expression>>,
1030{
1031 crate::dialects::transform_recursive(expr, &|e| match fun(e)? {
1032 Some(transformed) => Ok(transformed),
1033 None => Ok(Expression::Null(crate::expressions::Null)),
1034 })
1035}
1036
/// Rewrites `expr` by passing `fun` straight to
/// `crate::dialects::transform_recursive`.
///
/// Unlike [`transform`], `fun` must always produce a replacement expression.
pub fn transform_map<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
where
    F: Fn(Expression) -> crate::Result<Expression>,
{
    crate::dialects::transform_recursive(expr, fun)
}
1063
/// Returns `true` if `expr` is an `Expression::Column`.
pub fn is_column(expr: &Expression) -> bool {
    matches!(expr, Expression::Column(_))
}
1074
/// Returns `true` if `expr` is a `Literal`, `Boolean` or `Null` node.
pub fn is_literal(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
    )
}
1082
/// Returns `true` if `expr` is a `Function` or `AggregateFunction` node.
pub fn is_function(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Function(_) | Expression::AggregateFunction(_)
    )
}
1090
/// Returns `true` if `expr` is an `Expression::Subquery`.
pub fn is_subquery(expr: &Expression) -> bool {
    matches!(expr, Expression::Subquery(_))
}
1095
/// Returns `true` if `expr` is an `Expression::Select`.
pub fn is_select(expr: &Expression) -> bool {
    matches!(expr, Expression::Select(_))
}
1100
/// Returns `true` if `expr` is one of the aggregate variants listed below:
/// the generic `AggregateFunction` or one of the dedicated aggregate nodes.
pub fn is_aggregate(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::AggregateFunction(_)
            | Expression::Count(_)
            | Expression::Sum(_)
            | Expression::Avg(_)
            | Expression::Min(_)
            | Expression::Max(_)
            | Expression::GroupConcat(_)
            | Expression::StringAgg(_)
            | Expression::ListAgg(_)
            | Expression::CountIf(_)
            | Expression::SumIf(_)
    )
}
1118
/// Returns `true` if `expr` is an `Expression::WindowFunction`.
pub fn is_window_function(expr: &Expression) -> bool {
    matches!(expr, Expression::WindowFunction(_))
}
1123
1124pub fn get_columns(expr: &Expression) -> Vec<&Expression> {
1128 expr.find_all(is_column)
1129}
1130
1131pub fn get_tables(expr: &Expression) -> Vec<&Expression> {
1139 expr.find_all(|e| matches!(e, Expression::Table(_)))
1140}
1141
1142pub fn get_all_tables(expr: &Expression) -> Vec<Expression> {
1150 use std::collections::HashSet;
1151
1152 let mut seen = HashSet::new();
1153 let mut result = Vec::new();
1154
1155 for node in expr.dfs() {
1157 if let Expression::Table(t) = node {
1158 let qname = table_ref_qualified_name(t);
1159 if seen.insert(qname) {
1160 result.push(node.clone());
1161 }
1162 }
1163
1164 let refs: Vec<&TableRef> = match node {
1166 Expression::Insert(ins) => vec![&ins.table],
1167 Expression::Update(upd) => {
1168 let mut v = vec![&upd.table];
1169 v.extend(upd.extra_tables.iter());
1170 v
1171 }
1172 Expression::Delete(del) => {
1173 let mut v = vec![&del.table];
1174 v.extend(del.using.iter());
1175 v
1176 }
1177 _ => continue,
1178 };
1179 for tref in refs {
1180 if tref.name.name.is_empty() {
1181 continue;
1182 }
1183 let qname = table_ref_qualified_name(tref);
1184 if seen.insert(qname) {
1185 result.push(Expression::Table(Box::new(tref.clone())));
1186 }
1187 }
1188 }
1189
1190 result
1191}
1192
1193fn table_ref_qualified_name(t: &TableRef) -> String {
1195 let mut name = String::new();
1196 if let Some(ref cat) = t.catalog {
1197 name.push_str(&cat.name);
1198 name.push('.');
1199 }
1200 if let Some(ref schema) = t.schema {
1201 name.push_str(&schema.name);
1202 name.push('.');
1203 }
1204 name.push_str(&t.name.name);
1205 name
1206}
1207
1208fn unwrap_merge_table(expr: &Expression) -> Option<&Expression> {
1212 match expr {
1213 Expression::Table(_) => Some(expr),
1214 Expression::Alias(alias) => match &alias.this {
1215 Expression::Table(_) => Some(&alias.this),
1216 _ => None,
1217 },
1218 _ => None,
1219 }
1220}
1221
1222pub fn get_merge_target(expr: &Expression) -> Option<&Expression> {
1227 match expr {
1228 Expression::Merge(m) => unwrap_merge_table(&m.this),
1229 _ => None,
1230 }
1231}
1232
1233pub fn get_merge_source(expr: &Expression) -> Option<&Expression> {
1239 match expr {
1240 Expression::Merge(m) => unwrap_merge_table(&m.using),
1241 _ => None,
1242 }
1243}
1244
1245pub fn contains_aggregate(expr: &Expression) -> bool {
1247 expr.contains(is_aggregate)
1248}
1249
1250pub fn contains_window_function(expr: &Expression) -> bool {
1252 expr.contains(is_window_function)
1253}
1254
1255pub fn contains_subquery(expr: &Expression) -> bool {
1257 expr.contains(is_subquery)
1258}
1259
// Generates `pub fn $name(expr: &Expression) -> bool`, which returns `true`
// when `expr` matches any of the given variant patterns. A trailing comma
// after the last pattern is accepted.
macro_rules! is_type {
    ($name:ident, $($variant:pat),+ $(,)?) => {
        pub fn $name(expr: &Expression) -> bool {
            // The patterns are joined with `|` into a single `matches!` test.
            matches!(expr, $($variant)|+)
        }
    };
}
1273
// Data-modifying statements and set operations.
is_type!(is_insert, Expression::Insert(_));
is_type!(is_update, Expression::Update(_));
is_type!(is_delete, Expression::Delete(_));
is_type!(is_merge, Expression::Merge(_));
is_type!(is_union, Expression::Union(_));
is_type!(is_intersect, Expression::Intersect(_));
is_type!(is_except, Expression::Except(_));

// Leaf-like nodes.
is_type!(is_boolean, Expression::Boolean(_));
is_type!(is_null_literal, Expression::Null(_));
is_type!(is_star, Expression::Star(_));
is_type!(is_identifier, Expression::Identifier(_));
is_type!(is_table, Expression::Table(_));

// Comparison operators.
is_type!(is_eq, Expression::Eq(_));
is_type!(is_neq, Expression::Neq(_));
is_type!(is_lt, Expression::Lt(_));
is_type!(is_lte, Expression::Lte(_));
is_type!(is_gt, Expression::Gt(_));
is_type!(is_gte, Expression::Gte(_));
is_type!(is_like, Expression::Like(_));
is_type!(is_ilike, Expression::ILike(_));

// Arithmetic and concatenation operators.
is_type!(is_add, Expression::Add(_));
is_type!(is_sub, Expression::Sub(_));
is_type!(is_mul, Expression::Mul(_));
is_type!(is_div, Expression::Div(_));
is_type!(is_mod, Expression::Mod(_));
is_type!(is_concat, Expression::Concat(_));

// Logical operators.
is_type!(is_and, Expression::And(_));
is_type!(is_or, Expression::Or(_));
is_type!(is_not, Expression::Not(_));

// Other predicates.
is_type!(is_in, Expression::In(_));
is_type!(is_between, Expression::Between(_));
is_type!(is_is_null, Expression::IsNull(_));
is_type!(is_exists, Expression::Exists(_));

// Dedicated function, cast and CASE nodes.
is_type!(is_count, Expression::Count(_));
is_type!(is_sum, Expression::Sum(_));
is_type!(is_avg, Expression::Avg(_));
is_type!(is_min_func, Expression::Min(_));
is_type!(is_max_func, Expression::Max(_));
is_type!(is_coalesce, Expression::Coalesce(_));
is_type!(is_null_if, Expression::NullIf(_));
is_type!(is_cast, Expression::Cast(_));
is_type!(is_try_cast, Expression::TryCast(_));
is_type!(is_safe_cast, Expression::SafeCast(_));
is_type!(is_case, Expression::Case(_));

// Query clauses and wrappers.
is_type!(is_from, Expression::From(_));
is_type!(is_join, Expression::Join(_));
is_type!(is_where, Expression::Where(_));
is_type!(is_group_by, Expression::GroupBy(_));
is_type!(is_having, Expression::Having(_));
is_type!(is_order_by, Expression::OrderBy(_));
is_type!(is_limit, Expression::Limit(_));
is_type!(is_offset, Expression::Offset(_));
is_type!(is_with, Expression::With(_));
is_type!(is_cte, Expression::Cte(_));
is_type!(is_alias, Expression::Alias(_));
is_type!(is_paren, Expression::Paren(_));
is_type!(is_ordered, Expression::Ordered(_));

// DDL statements.
is_type!(is_create_table, Expression::CreateTable(_));
is_type!(is_drop_table, Expression::DropTable(_));
is_type!(is_alter_table, Expression::AlterTable(_));
is_type!(is_create_index, Expression::CreateIndex(_));
is_type!(is_drop_index, Expression::DropIndex(_));
is_type!(is_create_view, Expression::CreateView(_));
is_type!(is_drop_view, Expression::DropView(_));
1355
/// Returns `true` if `expr` is a `Select`, `Insert`, `Update`, `Delete` or
/// `Merge` statement.
pub fn is_query(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Select(_)
            | Expression::Insert(_)
            | Expression::Update(_)
            | Expression::Delete(_)
            | Expression::Merge(_)
    )
}
1371
/// Returns `true` if `expr` is a `Union`, `Intersect` or `Except` node.
pub fn is_set_operation(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
    )
}
1379
/// Returns `true` if `expr` is one of the comparison operators listed below
/// (`Like` and `ILike` are included).
pub fn is_comparison(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Eq(_)
            | Expression::Neq(_)
            | Expression::Lt(_)
            | Expression::Lte(_)
            | Expression::Gt(_)
            | Expression::Gte(_)
            | Expression::Like(_)
            | Expression::ILike(_)
    )
}
1394
/// Returns `true` if `expr` is an `Add`, `Sub`, `Mul`, `Div` or `Mod` node.
pub fn is_arithmetic(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::Add(_)
            | Expression::Sub(_)
            | Expression::Mul(_)
            | Expression::Div(_)
            | Expression::Mod(_)
    )
}
1406
/// Returns `true` if `expr` is an `And`, `Or` or `Not` node.
pub fn is_logical(expr: &Expression) -> bool {
    matches!(
        expr,
        Expression::And(_) | Expression::Or(_) | Expression::Not(_)
    )
}
1414
1415pub fn is_ddl(expr: &Expression) -> bool {
1417 matches!(
1418 expr,
1419 Expression::CreateTable(_)
1420 | Expression::DropTable(_)
1421 | Expression::Undrop(_)
1422 | Expression::AlterTable(_)
1423 | Expression::CreateIndex(_)
1424 | Expression::DropIndex(_)
1425 | Expression::CreateView(_)
1426 | Expression::DropView(_)
1427 | Expression::AlterView(_)
1428 | Expression::CreateSchema(_)
1429 | Expression::DropSchema(_)
1430 | Expression::CreateDatabase(_)
1431 | Expression::DropDatabase(_)
1432 | Expression::CreateFunction(_)
1433 | Expression::DropFunction(_)
1434 | Expression::CreateProcedure(_)
1435 | Expression::DropProcedure(_)
1436 | Expression::CreateSequence(_)
1437 | Expression::CreateSynonym(_)
1438 | Expression::DropSequence(_)
1439 | Expression::AlterSequence(_)
1440 | Expression::CreateTrigger(_)
1441 | Expression::DropTrigger(_)
1442 | Expression::CreateType(_)
1443 | Expression::DropType(_)
1444 )
1445}
1446
1447pub fn find_parent<'a>(root: &'a Expression, target: &Expression) -> Option<&'a Expression> {
1454 fn search<'a>(node: &'a Expression, target: *const Expression) -> Option<&'a Expression> {
1455 for (_, child) in iter_children(node) {
1456 if std::ptr::eq(child, target) {
1457 return Some(node);
1458 }
1459 if let Some(found) = search(child, target) {
1460 return Some(found);
1461 }
1462 }
1463 for (_, children_list) in iter_children_lists(node) {
1464 for child in children_list {
1465 if std::ptr::eq(child, target) {
1466 return Some(node);
1467 }
1468 if let Some(found) = search(child, target) {
1469 return Some(found);
1470 }
1471 }
1472 }
1473 None
1474 }
1475
1476 search(root, target as *const Expression)
1477}
1478
1479pub fn find_ancestor<'a, F>(
1485 root: &'a Expression,
1486 target: &Expression,
1487 predicate: F,
1488) -> Option<&'a Expression>
1489where
1490 F: Fn(&Expression) -> bool,
1491{
1492 fn build_path<'a>(
1494 node: &'a Expression,
1495 target: *const Expression,
1496 path: &mut Vec<&'a Expression>,
1497 ) -> bool {
1498 if std::ptr::eq(node, target) {
1499 return true;
1500 }
1501 path.push(node);
1502 for (_, child) in iter_children(node) {
1503 if build_path(child, target, path) {
1504 return true;
1505 }
1506 }
1507 for (_, children_list) in iter_children_lists(node) {
1508 for child in children_list {
1509 if build_path(child, target, path) {
1510 return true;
1511 }
1512 }
1513 }
1514 path.pop();
1515 false
1516 }
1517
1518 let mut path = Vec::new();
1519 if !build_path(root, target as *const Expression, &mut path) {
1520 return None;
1521 }
1522
1523 for ancestor in path.iter().rev() {
1525 if predicate(ancestor) {
1526 return Some(ancestor);
1527 }
1528 }
1529 None
1530}
1531
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{BinaryOp, Column, Identifier, Literal};

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Builds an unquoted column reference named `name` with no table
    /// qualifier, comments, span or inferred type.
    fn make_column(name: &str) -> Expression {
        Expression::boxed_column(Column {
            name: Identifier {
                name: name.to_string(),
                quoted: false,
                trailing_comments: vec![],
                span: None,
            },
            table: None,
            join_mark: false,
            trailing_comments: vec![],
            span: None,
            inferred_type: None,
        })
    }

    /// Builds a numeric literal holding `value`.
    fn make_literal(value: i64) -> Expression {
        Expression::Literal(Box::new(Literal::Number(value.to_string())))
    }

    /// Builds the payload shared by all binary operator variants, with no
    /// comments attached and no inferred type.
    fn binary(left: Expression, right: Expression) -> Box<BinaryOp> {
        Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
            inferred_type: None,
        })
    }

    /// Builds the expression `a = 1` used by most of the small-tree tests.
    fn column_eq_literal() -> Expression {
        Expression::Eq(binary(make_column("a"), make_literal(1)))
    }

    /// Returns the table name when `expr` is a `Table` node, `None` otherwise.
    fn table_name(expr: &Expression) -> Option<&str> {
        match expr {
            Expression::Table(t) => Some(t.name.name.as_str()),
            _ => None,
        }
    }

    // ---------------------------------------------------------------------
    // Traversal and search
    // ---------------------------------------------------------------------

    #[test]
    fn test_dfs_simple() {
        let expr = column_eq_literal();

        // The operator comes first, followed by its left and right operands.
        let nodes: Vec<_> = expr.dfs().collect();
        assert_eq!(nodes.len(), 3);
        assert!(matches!(nodes[0], Expression::Eq(_)));
        assert!(matches!(nodes[1], Expression::Column(_)));
        assert!(matches!(nodes[2], Expression::Literal(_)));
    }

    #[test]
    fn test_find() {
        let expr = column_eq_literal();

        assert!(matches!(
            expr.find(is_column),
            Some(Expression::Column(_))
        ));
        assert!(matches!(
            expr.find(is_literal),
            Some(Expression::Literal(_))
        ));
    }

    #[test]
    fn test_find_all() {
        let expr = Expression::And(binary(make_column("a"), make_column("b")));

        let columns = expr.find_all(is_column);
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_contains() {
        let expr = column_eq_literal();

        assert!(expr.contains(is_column));
        assert!(expr.contains(is_literal));
        assert!(!expr.contains(is_subquery));
    }

    #[test]
    fn test_count() {
        // a = (b + 1)
        let inner = Expression::Add(binary(make_column("b"), make_literal(1)));
        let expr = Expression::Eq(binary(make_column("a"), inner));

        assert_eq!(expr.count(is_column), 2);
        assert_eq!(expr.count(is_literal), 1);
    }

    #[test]
    fn test_tree_depth() {
        // A lone leaf has depth 0.
        assert_eq!(make_literal(1).tree_depth(), 0);

        // One operator above two leaves has depth 1.
        assert_eq!(column_eq_literal().tree_depth(), 1);

        // a = (b + 1): two operator levels above the deepest leaf.
        let inner = Expression::Add(binary(make_column("b"), make_literal(1)));
        let outer = Expression::Eq(binary(make_column("a"), inner));
        assert_eq!(outer.tree_depth(), 2);
    }

    #[test]
    fn test_tree_context() {
        let expr = column_eq_literal();
        let ctx = TreeContext::build(&expr);

        // Ids are handed out in visit order: 0 = Eq, 1 = left, 2 = right.
        let root_info = ctx.get(0).unwrap();
        assert!(root_info.parent_id.is_none());

        let left_info = ctx.get(1).unwrap();
        assert_eq!(left_info.parent_id, Some(0));
        assert_eq!(left_info.arg_key, "left");

        let right_info = ctx.get(2).unwrap();
        assert_eq!(right_info.parent_id, Some(0));
        assert_eq!(right_info.arg_key, "right");
    }

    // ---------------------------------------------------------------------
    // Transformations
    // ---------------------------------------------------------------------

    #[test]
    fn test_transform_rename_columns() {
        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                if c.name.name == "a" {
                    return Ok(Expression::boxed_column(Column {
                        name: Identifier::new("alpha"),
                        table: c.table.clone(),
                        join_mark: false,
                        trailing_comments: vec![],
                        span: None,
                        inferred_type: None,
                    }));
                }
            }
            Ok(e)
        })
        .unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert!(sql.contains("alpha"), "Expected 'alpha' in: {}", sql);
        assert!(sql.contains("b"), "Expected 'b' in: {}", sql);
    }

    #[test]
    fn test_transform_noop() {
        let ast = crate::parser::Parser::parse_sql("SELECT 1 + 2").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr.clone(), &|e| Ok(e)).unwrap();
        let sql1 = crate::generator::Generator::sql(&expr).unwrap();
        let sql2 = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql1, sql2);
    }

    #[test]
    fn test_transform_nested() {
        let ast = crate::parser::Parser::parse_sql("SELECT a + b FROM t").unwrap();
        let expr = ast[0].clone();
        // Every column is replaced by a literal: `a` -> 1, anything else -> 2.
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                return Ok(Expression::Literal(Box::new(Literal::Number(
                    if c.name.name == "a" { "1" } else { "2" }.to_string(),
                ))));
            }
            Ok(e)
        })
        .unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql, "SELECT 1 + 2 FROM t");
    }

    #[test]
    fn test_transform_error() {
        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t").unwrap();
        let expr = ast[0].clone();
        // An error raised by the callback must surface from `transform_map`.
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                if c.name.name == "a" {
                    return Err(crate::error::Error::parse("test error", 0, 0, 0, 0));
                }
            }
            Ok(e)
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_transform_owned_trait() {
        let ast = crate::parser::Parser::parse_sql("SELECT x FROM t").unwrap();
        let expr = ast[0].clone();
        let result = expr.transform_owned(|e| Ok(Some(e))).unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql, "SELECT x FROM t");
    }

    // ---------------------------------------------------------------------
    // children() / get_tables()
    // ---------------------------------------------------------------------

    #[test]
    fn test_children_leaf() {
        let lit = make_literal(1);
        assert_eq!(lit.children().len(), 0);
    }

    #[test]
    fn test_children_binary_op() {
        let expr = column_eq_literal();

        let children = expr.children();
        assert_eq!(children.len(), 2);
        assert!(matches!(children[0], Expression::Column(_)));
        assert!(matches!(children[1], Expression::Literal(_)));
    }

    #[test]
    fn test_children_select() {
        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
        let expr = &ast[0];
        let children = expr.children();
        assert!(children.len() >= 2);
    }

    #[test]
    fn test_children_select_includes_from_and_join_sources() {
        let ast = crate::parser::Parser::parse_sql(
            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id",
        )
        .unwrap();
        let expr = &ast[0];
        let children = expr.children();

        let table_names: Vec<&str> = children.iter().filter_map(|e| table_name(e)).collect();

        assert!(table_names.contains(&"users"));
        assert!(table_names.contains(&"orders"));
    }

    #[test]
    fn test_get_tables_includes_insert_query_sources() {
        let ast = crate::parser::Parser::parse_sql(
            "INSERT INTO dst (id) SELECT s.id FROM src s JOIN dim d ON s.id = d.id",
        )
        .unwrap();
        let expr = &ast[0];
        let tables = get_tables(expr);
        let names: Vec<&str> = tables.iter().filter_map(|e| table_name(e)).collect();

        assert!(names.contains(&"src"));
        assert!(names.contains(&"dim"));
    }

    // ---------------------------------------------------------------------
    // find_parent / find_ancestor / TreeContext::ancestors_of
    // ---------------------------------------------------------------------

    #[test]
    fn test_find_parent_binary() {
        let expr = column_eq_literal();

        // The target must be a reference into `expr` itself, because
        // `find_parent` compares addresses.
        let col = expr.find(is_column).unwrap();
        assert!(matches!(
            super::find_parent(&expr, col),
            Some(Expression::Eq(_))
        ));
    }

    #[test]
    fn test_find_parent_root_has_none() {
        let lit = make_literal(1);
        let parent = super::find_parent(&lit, &lit);
        assert!(parent.is_none());
    }

    #[test]
    fn test_find_ancestor_select() {
        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t WHERE a > 1").unwrap();
        let expr = &ast[0];

        // First column named `a` in DFS order. Whether that is the projected
        // column or the one in the WHERE clause, it sits below the SELECT.
        let col_a = expr
            .dfs()
            .find(|e| matches!(e, Expression::Column(c) if c.name.name == "a"))
            .expect("query should contain a column named `a`");

        assert!(matches!(
            super::find_ancestor(expr, col_a, is_select),
            Some(Expression::Select(_))
        ));
    }

    #[test]
    fn test_find_ancestor_no_match() {
        let expr = column_eq_literal();

        let col = expr.find(is_column).unwrap();
        let ancestor = super::find_ancestor(&expr, col, is_select);
        assert!(ancestor.is_none());
    }

    #[test]
    fn test_ancestors() {
        // b = (a + 1)
        let inner = Expression::Add(binary(make_column("a"), make_literal(1)));
        let outer = Expression::Eq(binary(make_column("b"), inner));

        let ctx = TreeContext::build(&outer);

        // Ids in visit order: 0 = Eq, 1 = column b, 2 = Add, 3 = column a,
        // 4 = literal. Ancestors are listed nearest first.
        let ancestors = ctx.ancestors_of(3);
        assert_eq!(ancestors, vec![2, 0]);
    }

    // ---------------------------------------------------------------------
    // MERGE helpers
    // ---------------------------------------------------------------------

    #[test]
    fn test_get_merge_target_and_source() {
        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);

        let sql = "MERGE INTO orders o USING customers c ON o.customer_id = c.id WHEN MATCHED THEN UPDATE SET amount = amount + 100";
        let exprs = dialect.parse(sql).unwrap();
        let expr = &exprs[0];

        assert!(is_merge(expr));
        assert!(is_query(expr));

        let target = get_merge_target(expr).expect("should find target table");
        assert!(matches!(target, Expression::Table(_)));
        if let Expression::Table(t) = target {
            assert_eq!(t.name.name, "orders");
        }

        let source = get_merge_source(expr).expect("should find source table");
        assert!(matches!(source, Expression::Table(_)));
        if let Expression::Table(t) = source {
            assert_eq!(t.name.name, "customers");
        }
    }

    #[test]
    fn test_get_merge_source_subquery_returns_none() {
        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);

        let sql = "MERGE INTO orders o USING (SELECT * FROM customers) c ON o.customer_id = c.id WHEN MATCHED THEN DELETE";
        let exprs = dialect.parse(sql).unwrap();
        let expr = &exprs[0];

        assert!(get_merge_target(expr).is_some());
        assert!(get_merge_source(expr).is_none());
    }

    #[test]
    fn test_get_merge_on_non_merge_returns_none() {
        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
        let exprs = dialect.parse("SELECT 1").unwrap();
        assert!(get_merge_target(&exprs[0]).is_none());
        assert!(get_merge_source(&exprs[0]).is_none());
    }

    // ---------------------------------------------------------------------
    // get_tables() with subqueries
    // ---------------------------------------------------------------------

    #[test]
    fn test_get_tables_finds_tables_inside_in_subquery() {
        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
        let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount > 1000)";
        let exprs = dialect.parse(sql).unwrap();
        let tables = get_tables(&exprs[0]);
        let names: Vec<&str> = tables.iter().filter_map(|e| table_name(e)).collect();
        assert!(names.contains(&"customers"), "should find outer table");
        assert!(names.contains(&"orders"), "should find subquery table");
    }

    #[test]
    fn test_get_tables_finds_tables_inside_exists_subquery() {
        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
        let sql = "SELECT * FROM customers c WHERE EXISTS (SELECT 1 FROM orders o WHERE o.customer_id = c.id)";
        let exprs = dialect.parse(sql).unwrap();
        let tables = get_tables(&exprs[0]);
        let names: Vec<&str> = tables.iter().filter_map(|e| table_name(e)).collect();
        assert!(names.contains(&"customers"), "should find outer table");
        assert!(
            names.contains(&"orders"),
            "should find EXISTS subquery table"
        );
    }

    // NOTE(review): despite the test name, the SQL below is the same
    // uncorrelated IN-subquery as in
    // `test_get_tables_finds_tables_inside_in_subquery`; what differs is that
    // it is parsed with the TSQL dialect instead of the generic one.
    #[test]
    fn test_get_tables_finds_tables_in_correlated_subquery() {
        let dialect = crate::Dialect::get(crate::dialects::DialectType::TSQL);
        let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount > 1000)";
        let exprs = dialect.parse(sql).unwrap();
        let tables = get_tables(&exprs[0]);
        let names: Vec<&str> = tables.iter().filter_map(|e| table_name(e)).collect();
        assert!(
            names.contains(&"customers"),
            "TSQL: should find outer table"
        );
        assert!(
            names.contains(&"orders"),
            "TSQL: should find subquery table"
        );
    }
}