Skip to main content

polyglot_sql/
traversal.rs

1//! Tree traversal utilities for SQL expression ASTs.
2//!
3//! This module provides read-only traversal, search, and transformation utilities
4//! for the [`Expression`] tree produced by the parser. Because Rust's ownership
5//! model does not allow parent pointers inside the AST, parent information is
6//! tracked externally via [`TreeContext`] (built on demand).
7//!
8//! # Traversal
9//!
10//! Two iterator types are provided:
11//! - [`DfsIter`] -- depth-first (pre-order) traversal using a stack. Visits a node
12//!   before its children. Good for top-down analysis and early termination.
13//! - [`BfsIter`] -- breadth-first (level-order) traversal using a queue. Visits all
14//!   nodes at depth N before any node at depth N+1. Good for level-aware analysis.
15//!
16//! Both are available through the [`ExpressionWalk`] trait methods [`dfs`](ExpressionWalk::dfs)
17//! and [`bfs`](ExpressionWalk::bfs).
18//!
19//! # Searching
20//!
21//! The [`ExpressionWalk`] trait also provides convenience methods for finding expressions:
22//! [`find`](ExpressionWalk::find), [`find_all`](ExpressionWalk::find_all),
23//! [`contains`](ExpressionWalk::contains), and [`count`](ExpressionWalk::count).
24//! Common predicates are available as free functions: [`is_column`], [`is_literal`],
25//! [`is_function`], [`is_aggregate`], [`is_window_function`], [`is_subquery`], and
26//! [`is_select`].
27//!
28//! # Transformation
29//!
30//! The [`transform`] and [`transform_map`] functions perform bottom-up (post-order)
31//! tree rewrites, delegating to [`transform_recursive`](crate::dialects::transform_recursive).
32//! The [`ExpressionWalk::transform_owned`] method provides the same capability as
33//! an owned method on `Expression`.
34//!
35//! Based on traversal patterns from `sqlglot/expressions.py`.
36
37use crate::expressions::{Expression, TableRef};
38use std::collections::{HashMap, VecDeque};
39
/// Unique identifier for expression nodes during traversal.
///
/// Identifiers are plain counters handed out by [`TreeContext::build`]; they are
/// only meaningful relative to the context that produced them.
pub type NodeId = usize;

/// Information about a node's parent relationship.
///
/// Equality compares all three fields, which makes it easy to assert on the
/// recorded position of a node (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParentInfo {
    /// The NodeId of the parent (`None` for the root)
    pub parent_id: Option<NodeId>,
    /// Which argument/field in the parent this node occupies (empty for the root)
    pub arg_key: String,
    /// Index if the node is part of a list (e.g., expressions in SELECT)
    pub index: Option<usize>,
}
53
54/// External parent-tracking context for an expression tree.
55///
56/// Since Rust's ownership model does not allow intrusive parent pointers in the AST,
57/// `TreeContext` provides an on-demand side-table that maps each node (identified by
58/// a [`NodeId`]) to its [`ParentInfo`] (parent node, field name, and list index).
59///
60/// Build a context from any expression root with [`TreeContext::build`], then query
61/// parent relationships with [`get`](TreeContext::get), ancestry chains with
62/// [`ancestors_of`](TreeContext::ancestors_of), or tree depth with
63/// [`depth_of`](TreeContext::depth_of).
64///
65/// This is useful when analysis requires upward navigation (e.g., determining whether
66/// a column reference appears inside a WHERE clause or a JOIN condition).
67#[derive(Debug, Default)]
68pub struct TreeContext {
69    /// Map from NodeId to parent information
70    nodes: HashMap<NodeId, ParentInfo>,
71    /// Counter for generating NodeIds
72    next_id: NodeId,
73    /// Stack for tracking current path during traversal
74    path: Vec<(NodeId, String, Option<usize>)>,
75}
76
77impl TreeContext {
78    /// Create a new empty tree context
79    pub fn new() -> Self {
80        Self::default()
81    }
82
83    /// Build context from an expression tree
84    pub fn build(root: &Expression) -> Self {
85        let mut ctx = Self::new();
86        ctx.visit_expr(root);
87        ctx
88    }
89
90    /// Visit an expression and record parent information
91    fn visit_expr(&mut self, expr: &Expression) -> NodeId {
92        let id = self.next_id;
93        self.next_id += 1;
94
95        // Record parent info based on current path
96        let parent_info = if let Some((parent_id, arg_key, index)) = self.path.last() {
97            ParentInfo {
98                parent_id: Some(*parent_id),
99                arg_key: arg_key.clone(),
100                index: *index,
101            }
102        } else {
103            ParentInfo {
104                parent_id: None,
105                arg_key: String::new(),
106                index: None,
107            }
108        };
109        self.nodes.insert(id, parent_info);
110
111        // Visit children
112        for (key, child) in iter_children(expr) {
113            self.path.push((id, key.to_string(), None));
114            self.visit_expr(child);
115            self.path.pop();
116        }
117
118        // Visit children in lists
119        for (key, children) in iter_children_lists(expr) {
120            for (idx, child) in children.iter().enumerate() {
121                self.path.push((id, key.to_string(), Some(idx)));
122                self.visit_expr(child);
123                self.path.pop();
124            }
125        }
126
127        id
128    }
129
130    /// Get parent info for a node
131    pub fn get(&self, id: NodeId) -> Option<&ParentInfo> {
132        self.nodes.get(&id)
133    }
134
135    /// Get the depth of a node (0 for root)
136    pub fn depth_of(&self, id: NodeId) -> usize {
137        let mut depth = 0;
138        let mut current = id;
139        while let Some(info) = self.nodes.get(&current) {
140            if let Some(parent_id) = info.parent_id {
141                depth += 1;
142                current = parent_id;
143            } else {
144                break;
145            }
146        }
147        depth
148    }
149
150    /// Get ancestors of a node (parent, grandparent, etc.)
151    pub fn ancestors_of(&self, id: NodeId) -> Vec<NodeId> {
152        let mut ancestors = Vec::new();
153        let mut current = id;
154        while let Some(info) = self.nodes.get(&current) {
155            if let Some(parent_id) = info.parent_id {
156                ancestors.push(parent_id);
157                current = parent_id;
158            } else {
159                break;
160            }
161        }
162        ancestors
163    }
164}
165
166/// Iterate over single-child fields of an expression
167///
168/// Returns an iterator of (field_name, &Expression) pairs.
169fn iter_children(expr: &Expression) -> Vec<(&'static str, &Expression)> {
170    let mut children = Vec::new();
171
172    match expr {
173        Expression::Select(s) => {
174            if let Some(from) = &s.from {
175                for source in &from.expressions {
176                    children.push(("from", source));
177                }
178            }
179            for join in &s.joins {
180                children.push(("join_this", &join.this));
181                if let Some(on) = &join.on {
182                    children.push(("join_on", on));
183                }
184                if let Some(match_condition) = &join.match_condition {
185                    children.push(("join_match_condition", match_condition));
186                }
187                for pivot in &join.pivots {
188                    children.push(("join_pivot", pivot));
189                }
190            }
191            for lateral_view in &s.lateral_views {
192                children.push(("lateral_view", &lateral_view.this));
193            }
194            if let Some(prewhere) = &s.prewhere {
195                children.push(("prewhere", prewhere));
196            }
197            if let Some(where_clause) = &s.where_clause {
198                children.push(("where", &where_clause.this));
199            }
200            if let Some(group_by) = &s.group_by {
201                for e in &group_by.expressions {
202                    children.push(("group_by", e));
203                }
204            }
205            if let Some(having) = &s.having {
206                children.push(("having", &having.this));
207            }
208            if let Some(qualify) = &s.qualify {
209                children.push(("qualify", &qualify.this));
210            }
211            if let Some(order_by) = &s.order_by {
212                for ordered in &order_by.expressions {
213                    children.push(("order_by", &ordered.this));
214                }
215            }
216            if let Some(distribute_by) = &s.distribute_by {
217                for e in &distribute_by.expressions {
218                    children.push(("distribute_by", e));
219                }
220            }
221            if let Some(cluster_by) = &s.cluster_by {
222                for ordered in &cluster_by.expressions {
223                    children.push(("cluster_by", &ordered.this));
224                }
225            }
226            if let Some(sort_by) = &s.sort_by {
227                for ordered in &sort_by.expressions {
228                    children.push(("sort_by", &ordered.this));
229                }
230            }
231            if let Some(limit) = &s.limit {
232                children.push(("limit", &limit.this));
233            }
234            if let Some(offset) = &s.offset {
235                children.push(("offset", &offset.this));
236            }
237            if let Some(limit_by) = &s.limit_by {
238                for e in limit_by {
239                    children.push(("limit_by", e));
240                }
241            }
242            if let Some(fetch) = &s.fetch {
243                if let Some(count) = &fetch.count {
244                    children.push(("fetch", count));
245                }
246            }
247            if let Some(top) = &s.top {
248                children.push(("top", &top.this));
249            }
250            if let Some(with) = &s.with {
251                for cte in &with.ctes {
252                    children.push(("with_cte", &cte.this));
253                }
254                if let Some(search) = &with.search {
255                    children.push(("with_search", search));
256                }
257            }
258            if let Some(sample) = &s.sample {
259                children.push(("sample_size", &sample.size));
260                if let Some(seed) = &sample.seed {
261                    children.push(("sample_seed", seed));
262                }
263                if let Some(offset) = &sample.offset {
264                    children.push(("sample_offset", offset));
265                }
266                if let Some(bucket_numerator) = &sample.bucket_numerator {
267                    children.push(("sample_bucket_numerator", bucket_numerator));
268                }
269                if let Some(bucket_denominator) = &sample.bucket_denominator {
270                    children.push(("sample_bucket_denominator", bucket_denominator));
271                }
272                if let Some(bucket_field) = &sample.bucket_field {
273                    children.push(("sample_bucket_field", bucket_field));
274                }
275            }
276            if let Some(connect) = &s.connect {
277                if let Some(start) = &connect.start {
278                    children.push(("connect_start", start));
279                }
280                children.push(("connect", &connect.connect));
281            }
282            if let Some(into) = &s.into {
283                children.push(("into", &into.this));
284            }
285            for lock in &s.locks {
286                for e in &lock.expressions {
287                    children.push(("lock_expression", e));
288                }
289                if let Some(wait) = &lock.wait {
290                    children.push(("lock_wait", wait));
291                }
292                if let Some(key) = &lock.key {
293                    children.push(("lock_key", key));
294                }
295                if let Some(update) = &lock.update {
296                    children.push(("lock_update", update));
297                }
298            }
299            for e in &s.for_xml {
300                children.push(("for_xml", e));
301            }
302        }
303        Expression::With(with) => {
304            for cte in &with.ctes {
305                children.push(("cte", &cte.this));
306            }
307            if let Some(search) = &with.search {
308                children.push(("search", search));
309            }
310        }
311        Expression::Cte(cte) => {
312            children.push(("this", &cte.this));
313        }
314        Expression::Insert(insert) => {
315            if let Some(query) = &insert.query {
316                children.push(("query", query));
317            }
318            if let Some(with) = &insert.with {
319                for cte in &with.ctes {
320                    children.push(("with_cte", &cte.this));
321                }
322                if let Some(search) = &with.search {
323                    children.push(("with_search", search));
324                }
325            }
326            if let Some(on_conflict) = &insert.on_conflict {
327                children.push(("on_conflict", on_conflict));
328            }
329            if let Some(replace_where) = &insert.replace_where {
330                children.push(("replace_where", replace_where));
331            }
332            if let Some(source) = &insert.source {
333                children.push(("source", source));
334            }
335            if let Some(function_target) = &insert.function_target {
336                children.push(("function_target", function_target));
337            }
338            if let Some(partition_by) = &insert.partition_by {
339                children.push(("partition_by", partition_by));
340            }
341            if let Some(output) = &insert.output {
342                for column in &output.columns {
343                    children.push(("output_column", column));
344                }
345                if let Some(into_table) = &output.into_table {
346                    children.push(("output_into_table", into_table));
347                }
348            }
349            for row in &insert.values {
350                for value in row {
351                    children.push(("value", value));
352                }
353            }
354            for (_, value) in &insert.partition {
355                if let Some(value) = value {
356                    children.push(("partition_value", value));
357                }
358            }
359            for returning in &insert.returning {
360                children.push(("returning", returning));
361            }
362            for setting in &insert.settings {
363                children.push(("setting", setting));
364            }
365        }
366        Expression::Update(update) => {
367            if let Some(from_clause) = &update.from_clause {
368                for source in &from_clause.expressions {
369                    children.push(("from", source));
370                }
371            }
372            for join in &update.table_joins {
373                children.push(("table_join_this", &join.this));
374                if let Some(on) = &join.on {
375                    children.push(("table_join_on", on));
376                }
377            }
378            for join in &update.from_joins {
379                children.push(("from_join_this", &join.this));
380                if let Some(on) = &join.on {
381                    children.push(("from_join_on", on));
382                }
383            }
384            for (_, value) in &update.set {
385                children.push(("set_value", value));
386            }
387            if let Some(where_clause) = &update.where_clause {
388                children.push(("where", &where_clause.this));
389            }
390            if let Some(output) = &update.output {
391                for column in &output.columns {
392                    children.push(("output_column", column));
393                }
394                if let Some(into_table) = &output.into_table {
395                    children.push(("output_into_table", into_table));
396                }
397            }
398            if let Some(with) = &update.with {
399                for cte in &with.ctes {
400                    children.push(("with_cte", &cte.this));
401                }
402                if let Some(search) = &with.search {
403                    children.push(("with_search", search));
404                }
405            }
406            if let Some(limit) = &update.limit {
407                children.push(("limit", limit));
408            }
409            if let Some(order_by) = &update.order_by {
410                for ordered in &order_by.expressions {
411                    children.push(("order_by", &ordered.this));
412                }
413            }
414            for returning in &update.returning {
415                children.push(("returning", returning));
416            }
417        }
418        Expression::Delete(delete) => {
419            if let Some(with) = &delete.with {
420                for cte in &with.ctes {
421                    children.push(("with_cte", &cte.this));
422                }
423                if let Some(search) = &with.search {
424                    children.push(("with_search", search));
425                }
426            }
427            if let Some(where_clause) = &delete.where_clause {
428                children.push(("where", &where_clause.this));
429            }
430            if let Some(output) = &delete.output {
431                for column in &output.columns {
432                    children.push(("output_column", column));
433                }
434                if let Some(into_table) = &output.into_table {
435                    children.push(("output_into_table", into_table));
436                }
437            }
438            if let Some(limit) = &delete.limit {
439                children.push(("limit", limit));
440            }
441            if let Some(order_by) = &delete.order_by {
442                for ordered in &order_by.expressions {
443                    children.push(("order_by", &ordered.this));
444                }
445            }
446            for returning in &delete.returning {
447                children.push(("returning", returning));
448            }
449            for join in &delete.joins {
450                children.push(("join_this", &join.this));
451                if let Some(on) = &join.on {
452                    children.push(("join_on", on));
453                }
454            }
455        }
456        Expression::Join(join) => {
457            children.push(("this", &join.this));
458            if let Some(on) = &join.on {
459                children.push(("on", on));
460            }
461            if let Some(match_condition) = &join.match_condition {
462                children.push(("match_condition", match_condition));
463            }
464            for pivot in &join.pivots {
465                children.push(("pivot", pivot));
466            }
467        }
468        Expression::Alias(a) => {
469            children.push(("this", &a.this));
470        }
471        Expression::Cast(c) => {
472            children.push(("this", &c.this));
473        }
474        Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
475            children.push(("this", &u.this));
476        }
477        Expression::Paren(p) => {
478            children.push(("this", &p.this));
479        }
480        Expression::IsNull(i) => {
481            children.push(("this", &i.this));
482        }
483        Expression::Exists(e) => {
484            children.push(("this", &e.this));
485        }
486        Expression::Subquery(s) => {
487            children.push(("this", &s.this));
488        }
489        Expression::Where(w) => {
490            children.push(("this", &w.this));
491        }
492        Expression::Having(h) => {
493            children.push(("this", &h.this));
494        }
495        Expression::Qualify(q) => {
496            children.push(("this", &q.this));
497        }
498        Expression::And(op)
499        | Expression::Or(op)
500        | Expression::Add(op)
501        | Expression::Sub(op)
502        | Expression::Mul(op)
503        | Expression::Div(op)
504        | Expression::Mod(op)
505        | Expression::Eq(op)
506        | Expression::Neq(op)
507        | Expression::Lt(op)
508        | Expression::Lte(op)
509        | Expression::Gt(op)
510        | Expression::Gte(op)
511        | Expression::BitwiseAnd(op)
512        | Expression::BitwiseOr(op)
513        | Expression::BitwiseXor(op)
514        | Expression::Concat(op) => {
515            children.push(("left", &op.left));
516            children.push(("right", &op.right));
517        }
518        Expression::Like(op) | Expression::ILike(op) => {
519            children.push(("left", &op.left));
520            children.push(("right", &op.right));
521        }
522        Expression::Between(b) => {
523            children.push(("this", &b.this));
524            children.push(("low", &b.low));
525            children.push(("high", &b.high));
526        }
527        Expression::In(i) => {
528            children.push(("this", &i.this));
529            if let Some(ref query) = i.query {
530                children.push(("query", query));
531            }
532            if let Some(ref unnest) = i.unnest {
533                children.push(("unnest", unnest));
534            }
535        }
536        Expression::Case(c) => {
537            if let Some(ref operand) = &c.operand {
538                children.push(("operand", operand));
539            }
540        }
541        Expression::WindowFunction(wf) => {
542            children.push(("this", &wf.this));
543        }
544        Expression::Union(u) => {
545            children.push(("left", &u.left));
546            children.push(("right", &u.right));
547            if let Some(with) = &u.with {
548                for cte in &with.ctes {
549                    children.push(("with_cte", &cte.this));
550                }
551                if let Some(search) = &with.search {
552                    children.push(("with_search", search));
553                }
554            }
555            if let Some(order_by) = &u.order_by {
556                for ordered in &order_by.expressions {
557                    children.push(("order_by", &ordered.this));
558                }
559            }
560            if let Some(limit) = &u.limit {
561                children.push(("limit", limit));
562            }
563            if let Some(offset) = &u.offset {
564                children.push(("offset", offset));
565            }
566            if let Some(distribute_by) = &u.distribute_by {
567                for e in &distribute_by.expressions {
568                    children.push(("distribute_by", e));
569                }
570            }
571            if let Some(sort_by) = &u.sort_by {
572                for ordered in &sort_by.expressions {
573                    children.push(("sort_by", &ordered.this));
574                }
575            }
576            if let Some(cluster_by) = &u.cluster_by {
577                for ordered in &cluster_by.expressions {
578                    children.push(("cluster_by", &ordered.this));
579                }
580            }
581            for e in &u.on_columns {
582                children.push(("on_column", e));
583            }
584        }
585        Expression::Intersect(i) => {
586            children.push(("left", &i.left));
587            children.push(("right", &i.right));
588            if let Some(with) = &i.with {
589                for cte in &with.ctes {
590                    children.push(("with_cte", &cte.this));
591                }
592                if let Some(search) = &with.search {
593                    children.push(("with_search", search));
594                }
595            }
596            if let Some(order_by) = &i.order_by {
597                for ordered in &order_by.expressions {
598                    children.push(("order_by", &ordered.this));
599                }
600            }
601            if let Some(limit) = &i.limit {
602                children.push(("limit", limit));
603            }
604            if let Some(offset) = &i.offset {
605                children.push(("offset", offset));
606            }
607            if let Some(distribute_by) = &i.distribute_by {
608                for e in &distribute_by.expressions {
609                    children.push(("distribute_by", e));
610                }
611            }
612            if let Some(sort_by) = &i.sort_by {
613                for ordered in &sort_by.expressions {
614                    children.push(("sort_by", &ordered.this));
615                }
616            }
617            if let Some(cluster_by) = &i.cluster_by {
618                for ordered in &cluster_by.expressions {
619                    children.push(("cluster_by", &ordered.this));
620                }
621            }
622            for e in &i.on_columns {
623                children.push(("on_column", e));
624            }
625        }
626        Expression::Except(e) => {
627            children.push(("left", &e.left));
628            children.push(("right", &e.right));
629            if let Some(with) = &e.with {
630                for cte in &with.ctes {
631                    children.push(("with_cte", &cte.this));
632                }
633                if let Some(search) = &with.search {
634                    children.push(("with_search", search));
635                }
636            }
637            if let Some(order_by) = &e.order_by {
638                for ordered in &order_by.expressions {
639                    children.push(("order_by", &ordered.this));
640                }
641            }
642            if let Some(limit) = &e.limit {
643                children.push(("limit", limit));
644            }
645            if let Some(offset) = &e.offset {
646                children.push(("offset", offset));
647            }
648            if let Some(distribute_by) = &e.distribute_by {
649                for expr in &distribute_by.expressions {
650                    children.push(("distribute_by", expr));
651                }
652            }
653            if let Some(sort_by) = &e.sort_by {
654                for ordered in &sort_by.expressions {
655                    children.push(("sort_by", &ordered.this));
656                }
657            }
658            if let Some(cluster_by) = &e.cluster_by {
659                for ordered in &cluster_by.expressions {
660                    children.push(("cluster_by", &ordered.this));
661                }
662            }
663            for expr in &e.on_columns {
664                children.push(("on_column", expr));
665            }
666        }
667        Expression::Merge(merge) => {
668            children.push(("this", &merge.this));
669            children.push(("using", &merge.using));
670            if let Some(on) = &merge.on {
671                children.push(("on", on));
672            }
673            if let Some(using_cond) = &merge.using_cond {
674                children.push(("using_cond", using_cond));
675            }
676            if let Some(whens) = &merge.whens {
677                children.push(("whens", whens));
678            }
679            if let Some(with_) = &merge.with_ {
680                children.push(("with_", with_));
681            }
682            if let Some(returning) = &merge.returning {
683                children.push(("returning", returning));
684            }
685        }
686        Expression::Any(q) | Expression::All(q) => {
687            children.push(("this", &q.this));
688            children.push(("subquery", &q.subquery));
689        }
690        Expression::Ordered(o) => {
691            children.push(("this", &o.this));
692        }
693        Expression::Interval(i) => {
694            if let Some(ref this) = i.this {
695                children.push(("this", this));
696            }
697        }
698        Expression::Describe(d) => {
699            children.push(("target", &d.target));
700        }
701        Expression::CreateTask(ct) => {
702            children.push(("body", &ct.body));
703        }
704        Expression::Prepare(prepare) => {
705            children.push(("statement", &prepare.statement));
706        }
707        Expression::Execute(exec) => {
708            children.push(("this", &exec.this));
709            for argument in &exec.arguments {
710                children.push(("argument", argument));
711            }
712            for parameter in &exec.parameters {
713                children.push(("parameter", &parameter.value));
714            }
715        }
716        Expression::Analyze(a) => {
717            if let Some(this) = &a.this {
718                children.push(("this", this));
719            }
720            if let Some(expr) = &a.expression {
721                children.push(("expression", expr));
722            }
723        }
724        _ => {}
725    }
726
727    children
728}
729
730/// Iterate over list-child fields of an expression
731///
732/// Returns an iterator of (field_name, &[Expression]) pairs.
733fn iter_children_lists(expr: &Expression) -> Vec<(&'static str, &[Expression])> {
734    let mut lists = Vec::new();
735
736    match expr {
737        Expression::Select(s) => lists.push(("expressions", s.expressions.as_slice())),
738        Expression::Function(f) => {
739            lists.push(("args", f.args.as_slice()));
740        }
741        Expression::AggregateFunction(f) => {
742            lists.push(("args", f.args.as_slice()));
743        }
744        Expression::From(f) => {
745            lists.push(("expressions", f.expressions.as_slice()));
746        }
747        Expression::GroupBy(g) => {
748            lists.push(("expressions", g.expressions.as_slice()));
749        }
750        // OrderBy.expressions is Vec<Ordered>, not Vec<Expression>
751        // We handle Ordered items via iter_children
752        Expression::In(i) => {
753            lists.push(("expressions", i.expressions.as_slice()));
754        }
755        Expression::Array(a) => {
756            lists.push(("expressions", a.expressions.as_slice()));
757        }
758        Expression::Tuple(t) => {
759            lists.push(("expressions", t.expressions.as_slice()));
760        }
761        Expression::TryCatch(try_catch) => {
762            lists.push(("try_body", try_catch.try_body.as_slice()));
763            if let Some(catch_body) = &try_catch.catch_body {
764                lists.push(("catch_body", catch_body.as_slice()));
765            }
766        }
767        // Values.expressions is Vec<Tuple>, handle specially
768        Expression::Coalesce(c) => {
769            lists.push(("expressions", c.expressions.as_slice()));
770        }
771        Expression::Greatest(g) | Expression::Least(g) => {
772            lists.push(("expressions", g.expressions.as_slice()));
773        }
774        _ => {}
775    }
776
777    lists
778}
779
780/// Pre-order depth-first iterator over an expression tree.
781///
782/// Visits each node before its children, using a stack-based approach. This means
783/// the root is yielded first, followed by the entire left subtree (recursively),
784/// then the right subtree. For a binary expression `a + b`, the iteration order
785/// is: `Add`, `a`, `b`.
786///
787/// Created via [`ExpressionWalk::dfs`] or [`DfsIter::new`].
788pub struct DfsIter<'a> {
789    stack: Vec<&'a Expression>,
790}
791
792impl<'a> DfsIter<'a> {
793    /// Create a new DFS iterator starting from the given expression
794    pub fn new(root: &'a Expression) -> Self {
795        Self { stack: vec![root] }
796    }
797}
798
799impl<'a> Iterator for DfsIter<'a> {
800    type Item = &'a Expression;
801
802    fn next(&mut self) -> Option<Self::Item> {
803        let expr = self.stack.pop()?;
804
805        // Add children in reverse order so they come out in forward order
806        let children: Vec<_> = iter_children(expr).into_iter().map(|(_, e)| e).collect();
807        for child in children.into_iter().rev() {
808            self.stack.push(child);
809        }
810
811        let lists: Vec<_> = iter_children_lists(expr)
812            .into_iter()
813            .flat_map(|(_, es)| es.iter())
814            .collect();
815        for child in lists.into_iter().rev() {
816            self.stack.push(child);
817        }
818
819        Some(expr)
820    }
821}
822
823/// Level-order breadth-first iterator over an expression tree.
824///
825/// Visits all nodes at depth N before any node at depth N+1, using a queue-based
826/// approach. For a tree `(a + b) = c`, the iteration order is: `Eq` (depth 0),
827/// `Add`, `c` (depth 1), `a`, `b` (depth 2).
828///
829/// Created via [`ExpressionWalk::bfs`] or [`BfsIter::new`].
830pub struct BfsIter<'a> {
831    queue: VecDeque<&'a Expression>,
832}
833
834impl<'a> BfsIter<'a> {
835    /// Create a new BFS iterator starting from the given expression
836    pub fn new(root: &'a Expression) -> Self {
837        let mut queue = VecDeque::new();
838        queue.push_back(root);
839        Self { queue }
840    }
841}
842
843impl<'a> Iterator for BfsIter<'a> {
844    type Item = &'a Expression;
845
846    fn next(&mut self) -> Option<Self::Item> {
847        let expr = self.queue.pop_front()?;
848
849        // Add children to queue
850        for (_, child) in iter_children(expr) {
851            self.queue.push_back(child);
852        }
853
854        for (_, children) in iter_children_lists(expr) {
855            for child in children {
856                self.queue.push_back(child);
857            }
858        }
859
860        Some(expr)
861    }
862}
863
/// Extension trait that adds traversal and search methods to [`Expression`].
///
/// This trait is implemented for `Expression` and provides a fluent API for
/// iterating, searching, measuring, and transforming expression trees without
/// needing to import the iterator types directly.
///
/// In the `Expression` implementation, the search helpers ([`find`](Self::find),
/// [`find_all`](Self::find_all), [`contains`](Self::contains) and
/// [`count`](Self::count)) all walk the tree in [`dfs`](Self::dfs) order, and
/// the node the method is called on is itself a candidate.
pub trait ExpressionWalk {
    /// Returns a depth-first (pre-order) iterator over this expression and all descendants.
    ///
    /// The root node is yielded first, and each child's subtree is exhausted
    /// before the next child is visited. When a node has both list-valued
    /// children (e.g. the `expressions` of an `IN` or `ARRAY`) and single-valued
    /// children, the list-valued children and their subtrees are yielded first.
    fn dfs(&self) -> DfsIter<'_>;

    /// Returns a breadth-first (level-order) iterator over this expression and all descendants.
    ///
    /// All nodes at depth N are yielded before any node at depth N+1. Among the
    /// children of one node, single-valued children are queued before the
    /// elements of list-valued children.
    fn bfs(&self) -> BfsIter<'_>;

    /// Finds the first expression matching `predicate` in depth-first order
    /// (the order produced by [`dfs`](Self::dfs)).
    ///
    /// Returns `None` if neither this node nor any descendant matches.
    fn find<F>(&self, predicate: F) -> Option<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Collects all expressions matching `predicate`, in depth-first order.
    ///
    /// Returns an empty vector if neither this node nor any descendant matches.
    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
    where
        F: Fn(&Expression) -> bool;

    /// Returns `true` if this node or any descendant matches `predicate`.
    fn contains<F>(&self, predicate: F) -> bool
    where
        F: Fn(&Expression) -> bool;

    /// Counts how many nodes (including this one) match `predicate`.
    fn count<F>(&self, predicate: F) -> usize
    where
        F: Fn(&Expression) -> bool;

    /// Returns direct child expressions of this node.
    ///
    /// The result lists every single-valued child field first, followed by the
    /// elements of each list-valued child field, as one flat vector of
    /// references. Leaf nodes return an empty vector.
    fn children(&self) -> Vec<&Expression>;

    /// Returns the maximum depth of the expression tree rooted at this node.
    ///
    /// A leaf node has depth 0, a node whose deepest child is a leaf has depth 1, etc.
    /// The `Expression` implementation recurses once per level, so very deep
    /// trees use a proportional amount of call stack.
    fn tree_depth(&self) -> usize;

    /// Transforms this expression tree bottom-up using the given function (owned variant).
    ///
    /// Children are transformed first, then `fun` is called on the resulting node.
    /// Return `Ok(None)` from `fun` to replace a node with `Expression::Null`.
    /// Return `Ok(Some(expr))` to substitute the node with `expr`.
    /// Errors returned by `fun` are propagated to the caller. This is the
    /// method form of the free function [`transform`].
    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
    where
        F: Fn(Expression) -> crate::Result<Option<Expression>>,
        Self: Sized;
}
926
927impl ExpressionWalk for Expression {
928    fn dfs(&self) -> DfsIter<'_> {
929        DfsIter::new(self)
930    }
931
932    fn bfs(&self) -> BfsIter<'_> {
933        BfsIter::new(self)
934    }
935
936    fn find<F>(&self, predicate: F) -> Option<&Expression>
937    where
938        F: Fn(&Expression) -> bool,
939    {
940        self.dfs().find(|e| predicate(e))
941    }
942
943    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
944    where
945        F: Fn(&Expression) -> bool,
946    {
947        self.dfs().filter(|e| predicate(e)).collect()
948    }
949
950    fn contains<F>(&self, predicate: F) -> bool
951    where
952        F: Fn(&Expression) -> bool,
953    {
954        self.dfs().any(|e| predicate(e))
955    }
956
957    fn count<F>(&self, predicate: F) -> usize
958    where
959        F: Fn(&Expression) -> bool,
960    {
961        self.dfs().filter(|e| predicate(e)).count()
962    }
963
964    fn children(&self) -> Vec<&Expression> {
965        let mut result: Vec<&Expression> = Vec::new();
966        for (_, child) in iter_children(self) {
967            result.push(child);
968        }
969        for (_, children_list) in iter_children_lists(self) {
970            for child in children_list {
971                result.push(child);
972            }
973        }
974        result
975    }
976
977    fn tree_depth(&self) -> usize {
978        let mut max_depth = 0;
979
980        for (_, child) in iter_children(self) {
981            let child_depth = child.tree_depth();
982            if child_depth + 1 > max_depth {
983                max_depth = child_depth + 1;
984            }
985        }
986
987        for (_, children) in iter_children_lists(self) {
988            for child in children {
989                let child_depth = child.tree_depth();
990                if child_depth + 1 > max_depth {
991                    max_depth = child_depth + 1;
992                }
993            }
994        }
995
996        max_depth
997    }
998
999    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
1000    where
1001        F: Fn(Expression) -> crate::Result<Option<Expression>>,
1002    {
1003        transform(self, &fun)
1004    }
1005}
1006
1007/// Transforms an expression tree bottom-up, with optional node removal.
1008///
1009/// Recursively transforms all children first, then applies `fun` to the resulting node.
1010/// If `fun` returns `Ok(None)`, the node is replaced with an `Expression::Null`.
1011/// If `fun` returns `Ok(Some(expr))`, the node is replaced with `expr`.
1012///
1013/// This is the primary transformation entry point when callers need the ability to
1014/// "delete" nodes by returning `None`.
1015///
1016/// # Example
1017///
1018/// ```rust,ignore
1019/// use polyglot_sql::traversal::transform;
1020///
1021/// // Remove all Paren wrapper nodes from a tree
1022/// let result = transform(expr, &|e| match e {
1023///     Expression::Paren(p) => Ok(Some(p.this)),
1024///     other => Ok(Some(other)),
1025/// })?;
1026/// ```
1027pub fn transform<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
1028where
1029    F: Fn(Expression) -> crate::Result<Option<Expression>>,
1030{
1031    crate::dialects::transform_recursive(expr, &|e| match fun(e)? {
1032        Some(transformed) => Ok(transformed),
1033        None => Ok(Expression::Null(crate::expressions::Null)),
1034    })
1035}
1036
1037/// Transforms an expression tree bottom-up without node removal.
1038///
1039/// Like [`transform`], but `fun` returns an `Expression` directly rather than
1040/// `Option<Expression>`, so nodes cannot be deleted. This is a convenience wrapper
1041/// for the common case where every node is mapped to exactly one output node.
1042///
1043/// # Example
1044///
1045/// ```rust,ignore
1046/// use polyglot_sql::traversal::transform_map;
1047///
1048/// // Uppercase all column names in a tree
1049/// let result = transform_map(expr, &|e| match e {
1050///     Expression::Column(mut c) => {
1051///         c.name.name = c.name.name.to_uppercase();
1052///         Ok(Expression::Column(c))
1053///     }
1054///     other => Ok(other),
1055/// })?;
1056/// ```
1057pub fn transform_map<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
1058where
1059    F: Fn(Expression) -> crate::Result<Expression>,
1060{
1061    crate::dialects::transform_recursive(expr, fun)
1062}
1063
1064// ---------------------------------------------------------------------------
1065// Common expression predicates
1066// ---------------------------------------------------------------------------
1067// These free functions are intended for use with the search methods on
1068// `ExpressionWalk` (e.g., `expr.find(is_column)`, `expr.contains(is_aggregate)`).
1069
1070/// Returns `true` if `expr` is a column reference ([`Expression::Column`]).
1071pub fn is_column(expr: &Expression) -> bool {
1072    matches!(expr, Expression::Column(_))
1073}
1074
1075/// Returns `true` if `expr` is a literal value (number, string, boolean, or NULL).
1076pub fn is_literal(expr: &Expression) -> bool {
1077    matches!(
1078        expr,
1079        Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
1080    )
1081}
1082
1083/// Returns `true` if `expr` is a function call (regular or aggregate).
1084pub fn is_function(expr: &Expression) -> bool {
1085    matches!(
1086        expr,
1087        Expression::Function(_) | Expression::AggregateFunction(_)
1088    )
1089}
1090
1091/// Returns `true` if `expr` is a subquery ([`Expression::Subquery`]).
1092pub fn is_subquery(expr: &Expression) -> bool {
1093    matches!(expr, Expression::Subquery(_))
1094}
1095
1096/// Returns `true` if `expr` is a SELECT statement ([`Expression::Select`]).
1097pub fn is_select(expr: &Expression) -> bool {
1098    matches!(expr, Expression::Select(_))
1099}
1100
1101/// Returns `true` if `expr` is an aggregate function.
1102pub fn is_aggregate(expr: &Expression) -> bool {
1103    matches!(
1104        expr,
1105        Expression::AggregateFunction(_)
1106            | Expression::Count(_)
1107            | Expression::Sum(_)
1108            | Expression::Avg(_)
1109            | Expression::Min(_)
1110            | Expression::Max(_)
1111            | Expression::GroupConcat(_)
1112            | Expression::StringAgg(_)
1113            | Expression::ListAgg(_)
1114            | Expression::CountIf(_)
1115            | Expression::SumIf(_)
1116    )
1117}
1118
1119/// Returns `true` if `expr` is a window function ([`Expression::WindowFunction`]).
1120pub fn is_window_function(expr: &Expression) -> bool {
1121    matches!(expr, Expression::WindowFunction(_))
1122}
1123
1124/// Collects all column references ([`Expression::Column`]) from the expression tree.
1125///
1126/// Performs a depth-first search and returns references to every column node found.
1127pub fn get_columns(expr: &Expression) -> Vec<&Expression> {
1128    expr.find_all(is_column)
1129}
1130
1131/// Collects all table references ([`Expression::Table`]) from the expression tree.
1132///
1133/// Performs a depth-first search and returns references to every table node found.
1134///
1135/// Note: DML target tables (`Insert.table`, `Update.table`, `Delete.table`) are
1136/// stored as `TableRef` struct fields, not as `Expression::Table` nodes, so they
1137/// are not reachable via tree traversal. Use [`get_all_tables`] to include those.
1138pub fn get_tables(expr: &Expression) -> Vec<&Expression> {
1139    expr.find_all(|e| matches!(e, Expression::Table(_)))
1140}
1141
1142/// Collects **all** referenced tables from the expression tree, including DML
1143/// target tables that are stored as `TableRef` struct fields and are therefore
1144/// not reachable through normal tree traversal.
1145///
1146/// Returns owned `Expression::Table` values. This is the comprehensive version
1147/// of [`get_tables`] — use it when you need to discover every table referenced
1148/// in a statement, including inside CTE bodies containing INSERT/UPDATE/DELETE.
1149pub fn get_all_tables(expr: &Expression) -> Vec<Expression> {
1150    use std::collections::HashSet;
1151
1152    let mut seen = HashSet::new();
1153    let mut result = Vec::new();
1154
1155    // First: collect all Expression::Table nodes found via DFS.
1156    for node in expr.dfs() {
1157        if let Expression::Table(t) = node {
1158            let qname = table_ref_qualified_name(t);
1159            if seen.insert(qname) {
1160                result.push(node.clone());
1161            }
1162        }
1163
1164        // Also extract DML target TableRef fields not reachable via iter_children.
1165        let refs: Vec<&TableRef> = match node {
1166            Expression::Insert(ins) => vec![&ins.table],
1167            Expression::Update(upd) => {
1168                let mut v = vec![&upd.table];
1169                v.extend(upd.extra_tables.iter());
1170                v
1171            }
1172            Expression::Delete(del) => {
1173                let mut v = vec![&del.table];
1174                v.extend(del.using.iter());
1175                v
1176            }
1177            _ => continue,
1178        };
1179        for tref in refs {
1180            if tref.name.name.is_empty() {
1181                continue;
1182            }
1183            let qname = table_ref_qualified_name(tref);
1184            if seen.insert(qname) {
1185                result.push(Expression::Table(Box::new(tref.clone())));
1186            }
1187        }
1188    }
1189
1190    result
1191}
1192
1193/// Build a qualified name string from a TableRef for deduplication purposes.
1194fn table_ref_qualified_name(t: &TableRef) -> String {
1195    let mut name = String::new();
1196    if let Some(ref cat) = t.catalog {
1197        name.push_str(&cat.name);
1198        name.push('.');
1199    }
1200    if let Some(ref schema) = t.schema {
1201        name.push_str(&schema.name);
1202        name.push('.');
1203    }
1204    name.push_str(&t.name.name);
1205    name
1206}
1207
1208/// Extracts the underlying [`Expression::Table`] from a MERGE field that may
1209/// be a bare `Table`, an `Alias` wrapping a `Table`, or an `Identifier`.
1210/// Returns `None` if the expression doesn't contain a recognisable table.
1211fn unwrap_merge_table(expr: &Expression) -> Option<&Expression> {
1212    match expr {
1213        Expression::Table(_) => Some(expr),
1214        Expression::Alias(alias) => match &alias.this {
1215            Expression::Table(_) => Some(&alias.this),
1216            _ => None,
1217        },
1218        _ => None,
1219    }
1220}
1221
1222/// Returns the target table of a MERGE statement (the `Merge.this` field),
1223/// unwrapping any alias wrapper to yield the underlying [`Expression::Table`].
1224///
1225/// Returns `None` if `expr` is not a `Merge` or the target isn't a recognisable table.
1226pub fn get_merge_target(expr: &Expression) -> Option<&Expression> {
1227    match expr {
1228        Expression::Merge(m) => unwrap_merge_table(&m.this),
1229        _ => None,
1230    }
1231}
1232
1233/// Returns the source table of a MERGE statement (the `Merge.using` field),
1234/// unwrapping any alias wrapper to yield the underlying [`Expression::Table`].
1235///
1236/// Returns `None` if `expr` is not a `Merge`, the source isn't a recognisable
1237/// table (e.g. it's a subquery), or the source is otherwise unresolvable.
1238pub fn get_merge_source(expr: &Expression) -> Option<&Expression> {
1239    match expr {
1240        Expression::Merge(m) => unwrap_merge_table(&m.using),
1241        _ => None,
1242    }
1243}
1244
1245/// Returns `true` if the expression tree contains any aggregate function calls.
1246pub fn contains_aggregate(expr: &Expression) -> bool {
1247    expr.contains(is_aggregate)
1248}
1249
1250/// Returns `true` if the expression tree contains any window function calls.
1251pub fn contains_window_function(expr: &Expression) -> bool {
1252    expr.contains(is_window_function)
1253}
1254
1255/// Returns `true` if the expression tree contains any subquery nodes.
1256pub fn contains_subquery(expr: &Expression) -> bool {
1257    expr.contains(is_subquery)
1258}
1259
1260// ---------------------------------------------------------------------------
1261// Extended type predicates
1262// ---------------------------------------------------------------------------
1263
/// Generates a public predicate `fn $name(expr: &Expression) -> bool` that
/// returns `true` when `expr` matches any of the given patterns.
///
/// Usage: `is_type!(name, Pattern1, Pattern2, ...)`. A trailing comma is allowed.
/// Each generated function gets a rustdoc line naming the exact pattern(s) it
/// tests, so the dozens of predicates built with this macro can be told apart
/// in the API documentation. Previously they all shared one generic sentence.
macro_rules! is_type {
    ($name:ident, $($variant:pat),+ $(,)?) => {
        #[doc = concat!(
            "Returns `true` if `expr` matches `",
            stringify!($($variant)|+),
            "`."
        )]
        pub fn $name(expr: &Expression) -> bool {
            matches!(expr, $($variant)|+)
        }
    };
}
1273
// Single-variant type predicates. Each `is_type!` line below expands to
// `pub fn NAME(expr: &Expression) -> bool` matching exactly the listed variant.
// Multi-variant groupings (`is_query`, `is_comparison`, `is_ddl`, ...) are
// written out by hand further below.

// Query statements and set operations
is_type!(is_insert, Expression::Insert(_));
is_type!(is_update, Expression::Update(_));
is_type!(is_delete, Expression::Delete(_));
is_type!(is_merge, Expression::Merge(_));
is_type!(is_union, Expression::Union(_));
is_type!(is_intersect, Expression::Intersect(_));
is_type!(is_except, Expression::Except(_));

// Identifiers & literals
// (`is_null_literal` tests for the `Null` node; it is unrelated to `IS NULL`, see `is_is_null`.)
is_type!(is_boolean, Expression::Boolean(_));
is_type!(is_null_literal, Expression::Null(_));
is_type!(is_star, Expression::Star(_));
is_type!(is_identifier, Expression::Identifier(_));
is_type!(is_table, Expression::Table(_));

// Comparison (the composite `is_comparison` covers exactly this set)
is_type!(is_eq, Expression::Eq(_));
is_type!(is_neq, Expression::Neq(_));
is_type!(is_lt, Expression::Lt(_));
is_type!(is_lte, Expression::Lte(_));
is_type!(is_gt, Expression::Gt(_));
is_type!(is_gte, Expression::Gte(_));
is_type!(is_like, Expression::Like(_));
is_type!(is_ilike, Expression::ILike(_));

// Arithmetic
// NOTE: `is_concat` is grouped here, but the composite `is_arithmetic` does
// NOT include `Concat`.
is_type!(is_add, Expression::Add(_));
is_type!(is_sub, Expression::Sub(_));
is_type!(is_mul, Expression::Mul(_));
is_type!(is_div, Expression::Div(_));
is_type!(is_mod, Expression::Mod(_));
is_type!(is_concat, Expression::Concat(_));

// Logical (the composite `is_logical` covers exactly this set)
is_type!(is_and, Expression::And(_));
is_type!(is_or, Expression::Or(_));
is_type!(is_not, Expression::Not(_));

// Predicates (not part of `is_comparison`)
is_type!(is_in, Expression::In(_));
is_type!(is_between, Expression::Between(_));
is_type!(is_is_null, Expression::IsNull(_));
is_type!(is_exists, Expression::Exists(_));

// Functions: typed function, cast and CASE nodes (not matched by `is_function`)
is_type!(is_count, Expression::Count(_));
is_type!(is_sum, Expression::Sum(_));
is_type!(is_avg, Expression::Avg(_));
is_type!(is_min_func, Expression::Min(_));
is_type!(is_max_func, Expression::Max(_));
is_type!(is_coalesce, Expression::Coalesce(_));
is_type!(is_null_if, Expression::NullIf(_));
is_type!(is_cast, Expression::Cast(_));
is_type!(is_try_cast, Expression::TryCast(_));
is_type!(is_safe_cast, Expression::SafeCast(_));
is_type!(is_case, Expression::Case(_));

// Clauses
is_type!(is_from, Expression::From(_));
is_type!(is_join, Expression::Join(_));
is_type!(is_where, Expression::Where(_));
is_type!(is_group_by, Expression::GroupBy(_));
is_type!(is_having, Expression::Having(_));
is_type!(is_order_by, Expression::OrderBy(_));
is_type!(is_limit, Expression::Limit(_));
is_type!(is_offset, Expression::Offset(_));
is_type!(is_with, Expression::With(_));
is_type!(is_cte, Expression::Cte(_));
is_type!(is_alias, Expression::Alias(_));
is_type!(is_paren, Expression::Paren(_));
is_type!(is_ordered, Expression::Ordered(_));

// DDL (a subset; the composite `is_ddl` recognises more statement kinds)
is_type!(is_create_table, Expression::CreateTable(_));
is_type!(is_drop_table, Expression::DropTable(_));
is_type!(is_alter_table, Expression::AlterTable(_));
is_type!(is_create_index, Expression::CreateIndex(_));
is_type!(is_drop_index, Expression::DropIndex(_));
is_type!(is_create_view, Expression::CreateView(_));
is_type!(is_drop_view, Expression::DropView(_));
1355
1356// ---------------------------------------------------------------------------
1357// Composite predicates
1358// ---------------------------------------------------------------------------
1359
1360/// Returns `true` if `expr` is a query statement (SELECT, INSERT, UPDATE, DELETE, or MERGE).
1361pub fn is_query(expr: &Expression) -> bool {
1362    matches!(
1363        expr,
1364        Expression::Select(_)
1365            | Expression::Insert(_)
1366            | Expression::Update(_)
1367            | Expression::Delete(_)
1368            | Expression::Merge(_)
1369    )
1370}
1371
1372/// Returns `true` if `expr` is a set operation (UNION, INTERSECT, or EXCEPT).
1373pub fn is_set_operation(expr: &Expression) -> bool {
1374    matches!(
1375        expr,
1376        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
1377    )
1378}
1379
1380/// Returns `true` if `expr` is a comparison operator.
1381pub fn is_comparison(expr: &Expression) -> bool {
1382    matches!(
1383        expr,
1384        Expression::Eq(_)
1385            | Expression::Neq(_)
1386            | Expression::Lt(_)
1387            | Expression::Lte(_)
1388            | Expression::Gt(_)
1389            | Expression::Gte(_)
1390            | Expression::Like(_)
1391            | Expression::ILike(_)
1392    )
1393}
1394
1395/// Returns `true` if `expr` is an arithmetic operator.
1396pub fn is_arithmetic(expr: &Expression) -> bool {
1397    matches!(
1398        expr,
1399        Expression::Add(_)
1400            | Expression::Sub(_)
1401            | Expression::Mul(_)
1402            | Expression::Div(_)
1403            | Expression::Mod(_)
1404    )
1405}
1406
1407/// Returns `true` if `expr` is a logical operator (AND, OR, NOT).
1408pub fn is_logical(expr: &Expression) -> bool {
1409    matches!(
1410        expr,
1411        Expression::And(_) | Expression::Or(_) | Expression::Not(_)
1412    )
1413}
1414
1415/// Returns `true` if `expr` is a DDL statement.
1416pub fn is_ddl(expr: &Expression) -> bool {
1417    matches!(
1418        expr,
1419        Expression::CreateTable(_)
1420            | Expression::DropTable(_)
1421            | Expression::Undrop(_)
1422            | Expression::AlterTable(_)
1423            | Expression::CreateIndex(_)
1424            | Expression::DropIndex(_)
1425            | Expression::CreateView(_)
1426            | Expression::DropView(_)
1427            | Expression::AlterView(_)
1428            | Expression::CreateSchema(_)
1429            | Expression::DropSchema(_)
1430            | Expression::CreateDatabase(_)
1431            | Expression::DropDatabase(_)
1432            | Expression::CreateFunction(_)
1433            | Expression::DropFunction(_)
1434            | Expression::CreateProcedure(_)
1435            | Expression::DropProcedure(_)
1436            | Expression::CreateSequence(_)
1437            | Expression::CreateSynonym(_)
1438            | Expression::DropSequence(_)
1439            | Expression::AlterSequence(_)
1440            | Expression::CreateTrigger(_)
1441            | Expression::DropTrigger(_)
1442            | Expression::CreateType(_)
1443            | Expression::DropType(_)
1444    )
1445}
1446
1447/// Find the parent of `target` within the tree rooted at `root`.
1448///
1449/// Uses pointer identity ([`std::ptr::eq`]) — `target` must be a reference
1450/// obtained from the same tree (e.g., via [`ExpressionWalk::find`] or DFS iteration).
1451///
1452/// Returns `None` if `target` is the root itself or is not found in the tree.
1453pub fn find_parent<'a>(root: &'a Expression, target: &Expression) -> Option<&'a Expression> {
1454    fn search<'a>(node: &'a Expression, target: *const Expression) -> Option<&'a Expression> {
1455        for (_, child) in iter_children(node) {
1456            if std::ptr::eq(child, target) {
1457                return Some(node);
1458            }
1459            if let Some(found) = search(child, target) {
1460                return Some(found);
1461            }
1462        }
1463        for (_, children_list) in iter_children_lists(node) {
1464            for child in children_list {
1465                if std::ptr::eq(child, target) {
1466                    return Some(node);
1467                }
1468                if let Some(found) = search(child, target) {
1469                    return Some(found);
1470                }
1471            }
1472        }
1473        None
1474    }
1475
1476    search(root, target as *const Expression)
1477}
1478
1479/// Find the first ancestor of `target` matching `predicate`, walking from
1480/// parent toward root.
1481///
1482/// Uses pointer identity for target lookup. Returns `None` if no ancestor
1483/// matches or `target` is not found in the tree.
1484pub fn find_ancestor<'a, F>(
1485    root: &'a Expression,
1486    target: &Expression,
1487    predicate: F,
1488) -> Option<&'a Expression>
1489where
1490    F: Fn(&Expression) -> bool,
1491{
1492    // Build path from root to target
1493    fn build_path<'a>(
1494        node: &'a Expression,
1495        target: *const Expression,
1496        path: &mut Vec<&'a Expression>,
1497    ) -> bool {
1498        if std::ptr::eq(node, target) {
1499            return true;
1500        }
1501        path.push(node);
1502        for (_, child) in iter_children(node) {
1503            if build_path(child, target, path) {
1504                return true;
1505            }
1506        }
1507        for (_, children_list) in iter_children_lists(node) {
1508            for child in children_list {
1509                if build_path(child, target, path) {
1510                    return true;
1511                }
1512            }
1513        }
1514        path.pop();
1515        false
1516    }
1517
1518    let mut path = Vec::new();
1519    if !build_path(root, target as *const Expression, &mut path) {
1520        return None;
1521    }
1522
1523    // Walk path in reverse (parent first, then grandparent, etc.)
1524    for ancestor in path.iter().rev() {
1525        if predicate(ancestor) {
1526            return Some(ancestor);
1527        }
1528    }
1529    None
1530}
1531
1532#[cfg(test)]
1533mod tests {
1534    use super::*;
1535    use crate::expressions::{BinaryOp, Column, Identifier, Literal};
1536
1537    fn make_column(name: &str) -> Expression {
1538        Expression::boxed_column(Column {
1539            name: Identifier {
1540                name: name.to_string(),
1541                quoted: false,
1542                trailing_comments: vec![],
1543                span: None,
1544            },
1545            table: None,
1546            join_mark: false,
1547            trailing_comments: vec![],
1548            span: None,
1549            inferred_type: None,
1550        })
1551    }
1552
1553    fn make_literal(value: i64) -> Expression {
1554        Expression::Literal(Box::new(Literal::Number(value.to_string())))
1555    }
1556
1557    #[test]
1558    fn test_dfs_simple() {
1559        let left = make_column("a");
1560        let right = make_literal(1);
1561        let expr = Expression::Eq(Box::new(BinaryOp {
1562            left,
1563            right,
1564            left_comments: vec![],
1565            operator_comments: vec![],
1566            trailing_comments: vec![],
1567            inferred_type: None,
1568        }));
1569
1570        let nodes: Vec<_> = expr.dfs().collect();
1571        assert_eq!(nodes.len(), 3); // Eq, Column, Literal
1572        assert!(matches!(nodes[0], Expression::Eq(_)));
1573        assert!(matches!(nodes[1], Expression::Column(_)));
1574        assert!(matches!(nodes[2], Expression::Literal(_)));
1575    }
1576
1577    #[test]
1578    fn test_find() {
1579        let left = make_column("a");
1580        let right = make_literal(1);
1581        let expr = Expression::Eq(Box::new(BinaryOp {
1582            left,
1583            right,
1584            left_comments: vec![],
1585            operator_comments: vec![],
1586            trailing_comments: vec![],
1587            inferred_type: None,
1588        }));
1589
1590        let column = expr.find(is_column);
1591        assert!(column.is_some());
1592        assert!(matches!(column.unwrap(), Expression::Column(_)));
1593
1594        let literal = expr.find(is_literal);
1595        assert!(literal.is_some());
1596        assert!(matches!(literal.unwrap(), Expression::Literal(_)));
1597    }
1598
1599    #[test]
1600    fn test_find_all() {
1601        let col1 = make_column("a");
1602        let col2 = make_column("b");
1603        let expr = Expression::And(Box::new(BinaryOp {
1604            left: col1,
1605            right: col2,
1606            left_comments: vec![],
1607            operator_comments: vec![],
1608            trailing_comments: vec![],
1609            inferred_type: None,
1610        }));
1611
1612        let columns = expr.find_all(is_column);
1613        assert_eq!(columns.len(), 2);
1614    }
1615
1616    #[test]
1617    fn test_contains() {
1618        let col = make_column("a");
1619        let lit = make_literal(1);
1620        let expr = Expression::Eq(Box::new(BinaryOp {
1621            left: col,
1622            right: lit,
1623            left_comments: vec![],
1624            operator_comments: vec![],
1625            trailing_comments: vec![],
1626            inferred_type: None,
1627        }));
1628
1629        assert!(expr.contains(is_column));
1630        assert!(expr.contains(is_literal));
1631        assert!(!expr.contains(is_subquery));
1632    }
1633
1634    #[test]
1635    fn test_count() {
1636        let col1 = make_column("a");
1637        let col2 = make_column("b");
1638        let lit = make_literal(1);
1639
1640        let inner = Expression::Add(Box::new(BinaryOp {
1641            left: col2,
1642            right: lit,
1643            left_comments: vec![],
1644            operator_comments: vec![],
1645            trailing_comments: vec![],
1646            inferred_type: None,
1647        }));
1648
1649        let expr = Expression::Eq(Box::new(BinaryOp {
1650            left: col1,
1651            right: inner,
1652            left_comments: vec![],
1653            operator_comments: vec![],
1654            trailing_comments: vec![],
1655            inferred_type: None,
1656        }));
1657
1658        assert_eq!(expr.count(is_column), 2);
1659        assert_eq!(expr.count(is_literal), 1);
1660    }
1661
1662    #[test]
1663    fn test_tree_depth() {
1664        // Single node
1665        let lit = make_literal(1);
1666        assert_eq!(lit.tree_depth(), 0);
1667
1668        // One level
1669        let col = make_column("a");
1670        let expr = Expression::Eq(Box::new(BinaryOp {
1671            left: col,
1672            right: lit.clone(),
1673            left_comments: vec![],
1674            operator_comments: vec![],
1675            trailing_comments: vec![],
1676            inferred_type: None,
1677        }));
1678        assert_eq!(expr.tree_depth(), 1);
1679
1680        // Two levels
1681        let inner = Expression::Add(Box::new(BinaryOp {
1682            left: make_column("b"),
1683            right: lit,
1684            left_comments: vec![],
1685            operator_comments: vec![],
1686            trailing_comments: vec![],
1687            inferred_type: None,
1688        }));
1689        let outer = Expression::Eq(Box::new(BinaryOp {
1690            left: make_column("a"),
1691            right: inner,
1692            left_comments: vec![],
1693            operator_comments: vec![],
1694            trailing_comments: vec![],
1695            inferred_type: None,
1696        }));
1697        assert_eq!(outer.tree_depth(), 2);
1698    }
1699
1700    #[test]
1701    fn test_tree_context() {
1702        let col = make_column("a");
1703        let lit = make_literal(1);
1704        let expr = Expression::Eq(Box::new(BinaryOp {
1705            left: col,
1706            right: lit,
1707            left_comments: vec![],
1708            operator_comments: vec![],
1709            trailing_comments: vec![],
1710            inferred_type: None,
1711        }));
1712
1713        let ctx = TreeContext::build(&expr);
1714
1715        // Root has no parent
1716        let root_info = ctx.get(0).unwrap();
1717        assert!(root_info.parent_id.is_none());
1718
1719        // Children have root as parent
1720        let left_info = ctx.get(1).unwrap();
1721        assert_eq!(left_info.parent_id, Some(0));
1722        assert_eq!(left_info.arg_key, "left");
1723
1724        let right_info = ctx.get(2).unwrap();
1725        assert_eq!(right_info.parent_id, Some(0));
1726        assert_eq!(right_info.arg_key, "right");
1727    }
1728
    // -- transform() / transform_map() tests --
1730
1731    #[test]
1732    fn test_transform_rename_columns() {
1733        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
1734        let expr = ast[0].clone();
1735        let result = super::transform_map(expr, &|e| {
1736            if let Expression::Column(ref c) = e {
1737                if c.name.name == "a" {
1738                    return Ok(Expression::boxed_column(Column {
1739                        name: Identifier::new("alpha"),
1740                        table: c.table.clone(),
1741                        join_mark: false,
1742                        trailing_comments: vec![],
1743                        span: None,
1744                        inferred_type: None,
1745                    }));
1746                }
1747            }
1748            Ok(e)
1749        })
1750        .unwrap();
1751        let sql = crate::generator::Generator::sql(&result).unwrap();
1752        assert!(sql.contains("alpha"), "Expected 'alpha' in: {}", sql);
1753        assert!(sql.contains("b"), "Expected 'b' in: {}", sql);
1754    }
1755
1756    #[test]
1757    fn test_transform_noop() {
1758        let ast = crate::parser::Parser::parse_sql("SELECT 1 + 2").unwrap();
1759        let expr = ast[0].clone();
1760        let result = super::transform_map(expr.clone(), &|e| Ok(e)).unwrap();
1761        let sql1 = crate::generator::Generator::sql(&expr).unwrap();
1762        let sql2 = crate::generator::Generator::sql(&result).unwrap();
1763        assert_eq!(sql1, sql2);
1764    }
1765
1766    #[test]
1767    fn test_transform_nested() {
1768        let ast = crate::parser::Parser::parse_sql("SELECT a + b FROM t").unwrap();
1769        let expr = ast[0].clone();
1770        let result = super::transform_map(expr, &|e| {
1771            if let Expression::Column(ref c) = e {
1772                return Ok(Expression::Literal(Box::new(Literal::Number(
1773                    if c.name.name == "a" { "1" } else { "2" }.to_string(),
1774                ))));
1775            }
1776            Ok(e)
1777        })
1778        .unwrap();
1779        let sql = crate::generator::Generator::sql(&result).unwrap();
1780        assert_eq!(sql, "SELECT 1 + 2 FROM t");
1781    }
1782
1783    #[test]
1784    fn test_transform_error() {
1785        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t").unwrap();
1786        let expr = ast[0].clone();
1787        let result = super::transform_map(expr, &|e| {
1788            if let Expression::Column(ref c) = e {
1789                if c.name.name == "a" {
1790                    return Err(crate::error::Error::parse("test error", 0, 0, 0, 0));
1791                }
1792            }
1793            Ok(e)
1794        });
1795        assert!(result.is_err());
1796    }
1797
1798    #[test]
1799    fn test_transform_owned_trait() {
1800        let ast = crate::parser::Parser::parse_sql("SELECT x FROM t").unwrap();
1801        let expr = ast[0].clone();
1802        let result = expr.transform_owned(|e| Ok(Some(e))).unwrap();
1803        let sql = crate::generator::Generator::sql(&result).unwrap();
1804        assert_eq!(sql, "SELECT x FROM t");
1805    }
1806
1807    // -- children() tests --
1808
1809    #[test]
1810    fn test_children_leaf() {
1811        let lit = make_literal(1);
1812        assert_eq!(lit.children().len(), 0);
1813    }
1814
1815    #[test]
1816    fn test_children_binary_op() {
1817        let left = make_column("a");
1818        let right = make_literal(1);
1819        let expr = Expression::Eq(Box::new(BinaryOp {
1820            left,
1821            right,
1822            left_comments: vec![],
1823            operator_comments: vec![],
1824            trailing_comments: vec![],
1825            inferred_type: None,
1826        }));
1827        let children = expr.children();
1828        assert_eq!(children.len(), 2);
1829        assert!(matches!(children[0], Expression::Column(_)));
1830        assert!(matches!(children[1], Expression::Literal(_)));
1831    }
1832
1833    #[test]
1834    fn test_children_select() {
1835        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
1836        let expr = &ast[0];
1837        let children = expr.children();
1838        // Should include select list items (a, b)
1839        assert!(children.len() >= 2);
1840    }
1841
1842    #[test]
1843    fn test_children_select_includes_from_and_join_sources() {
1844        let ast = crate::parser::Parser::parse_sql(
1845            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id",
1846        )
1847        .unwrap();
1848        let expr = &ast[0];
1849        let children = expr.children();
1850
1851        let table_names: Vec<&str> = children
1852            .iter()
1853            .filter_map(|e| match e {
1854                Expression::Table(t) => Some(t.name.name.as_str()),
1855                _ => None,
1856            })
1857            .collect();
1858
1859        assert!(table_names.contains(&"users"));
1860        assert!(table_names.contains(&"orders"));
1861    }
1862
1863    #[test]
1864    fn test_get_tables_includes_insert_query_sources() {
1865        let ast = crate::parser::Parser::parse_sql(
1866            "INSERT INTO dst (id) SELECT s.id FROM src s JOIN dim d ON s.id = d.id",
1867        )
1868        .unwrap();
1869        let expr = &ast[0];
1870        let tables = get_tables(expr);
1871        let names: Vec<&str> = tables
1872            .iter()
1873            .filter_map(|e| match e {
1874                Expression::Table(t) => Some(t.name.name.as_str()),
1875                _ => None,
1876            })
1877            .collect();
1878
1879        assert!(names.contains(&"src"));
1880        assert!(names.contains(&"dim"));
1881    }
1882
1883    // -- find_parent() tests --
1884
1885    #[test]
1886    fn test_find_parent_binary() {
1887        let left = make_column("a");
1888        let right = make_literal(1);
1889        let expr = Expression::Eq(Box::new(BinaryOp {
1890            left,
1891            right,
1892            left_comments: vec![],
1893            operator_comments: vec![],
1894            trailing_comments: vec![],
1895            inferred_type: None,
1896        }));
1897
1898        // Find the column child and get its parent
1899        let col = expr.find(is_column).unwrap();
1900        let parent = super::find_parent(&expr, col);
1901        assert!(parent.is_some());
1902        assert!(matches!(parent.unwrap(), Expression::Eq(_)));
1903    }
1904
1905    #[test]
1906    fn test_find_parent_root_has_none() {
1907        let lit = make_literal(1);
1908        let parent = super::find_parent(&lit, &lit);
1909        assert!(parent.is_none());
1910    }
1911
1912    // -- find_ancestor() tests --
1913
1914    #[test]
1915    fn test_find_ancestor_select() {
1916        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t WHERE a > 1").unwrap();
1917        let expr = &ast[0];
1918
1919        // Find a column inside the WHERE clause
1920        let where_col = expr.dfs().find(|e| {
1921            if let Expression::Column(c) = e {
1922                c.name.name == "a"
1923            } else {
1924                false
1925            }
1926        });
1927        assert!(where_col.is_some());
1928
1929        // Find Select ancestor of that column
1930        let ancestor = super::find_ancestor(expr, where_col.unwrap(), is_select);
1931        assert!(ancestor.is_some());
1932        assert!(matches!(ancestor.unwrap(), Expression::Select(_)));
1933    }
1934
1935    #[test]
1936    fn test_find_ancestor_no_match() {
1937        let left = make_column("a");
1938        let right = make_literal(1);
1939        let expr = Expression::Eq(Box::new(BinaryOp {
1940            left,
1941            right,
1942            left_comments: vec![],
1943            operator_comments: vec![],
1944            trailing_comments: vec![],
1945            inferred_type: None,
1946        }));
1947
1948        let col = expr.find(is_column).unwrap();
1949        let ancestor = super::find_ancestor(&expr, col, is_select);
1950        assert!(ancestor.is_none());
1951    }
1952
1953    #[test]
1954    fn test_ancestors() {
1955        let col = make_column("a");
1956        let lit = make_literal(1);
1957        let inner = Expression::Add(Box::new(BinaryOp {
1958            left: col,
1959            right: lit,
1960            left_comments: vec![],
1961            operator_comments: vec![],
1962            trailing_comments: vec![],
1963            inferred_type: None,
1964        }));
1965        let outer = Expression::Eq(Box::new(BinaryOp {
1966            left: make_column("b"),
1967            right: inner,
1968            left_comments: vec![],
1969            operator_comments: vec![],
1970            trailing_comments: vec![],
1971            inferred_type: None,
1972        }));
1973
1974        let ctx = TreeContext::build(&outer);
1975
1976        // The inner Add's left child (column "a") should have ancestors
1977        // Node 0: Eq
1978        // Node 1: Column "b" (left of Eq)
1979        // Node 2: Add (right of Eq)
1980        // Node 3: Column "a" (left of Add)
1981        // Node 4: Literal (right of Add)
1982
1983        let ancestors = ctx.ancestors_of(3);
1984        assert_eq!(ancestors, vec![2, 0]); // Add, then Eq
1985    }
1986
1987    #[test]
1988    fn test_get_merge_target_and_source() {
1989        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
1990
1991        // MERGE with aliased target and source tables
1992        let sql = "MERGE INTO orders o USING customers c ON o.customer_id = c.id WHEN MATCHED THEN UPDATE SET amount = amount + 100";
1993        let exprs = dialect.parse(sql).unwrap();
1994        let expr = &exprs[0];
1995
1996        assert!(is_merge(expr));
1997        assert!(is_query(expr));
1998
1999        let target = get_merge_target(expr).expect("should find target table");
2000        assert!(matches!(target, Expression::Table(_)));
2001        if let Expression::Table(t) = target {
2002            assert_eq!(t.name.name, "orders");
2003        }
2004
2005        let source = get_merge_source(expr).expect("should find source table");
2006        assert!(matches!(source, Expression::Table(_)));
2007        if let Expression::Table(t) = source {
2008            assert_eq!(t.name.name, "customers");
2009        }
2010    }
2011
2012    #[test]
2013    fn test_get_merge_source_subquery_returns_none() {
2014        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
2015
2016        // MERGE with subquery source — get_merge_source should return None
2017        let sql = "MERGE INTO orders o USING (SELECT * FROM customers) c ON o.customer_id = c.id WHEN MATCHED THEN DELETE";
2018        let exprs = dialect.parse(sql).unwrap();
2019        let expr = &exprs[0];
2020
2021        assert!(get_merge_target(expr).is_some());
2022        assert!(get_merge_source(expr).is_none());
2023    }
2024
2025    #[test]
2026    fn test_get_merge_on_non_merge_returns_none() {
2027        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
2028        let exprs = dialect.parse("SELECT 1").unwrap();
2029        assert!(get_merge_target(&exprs[0]).is_none());
2030        assert!(get_merge_source(&exprs[0]).is_none());
2031    }
2032
2033    #[test]
2034    fn test_get_tables_finds_tables_inside_in_subquery() {
2035        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
2036        let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount > 1000)";
2037        let exprs = dialect.parse(sql).unwrap();
2038        let tables = get_tables(&exprs[0]);
2039        let names: Vec<&str> = tables
2040            .iter()
2041            .filter_map(|e| {
2042                if let Expression::Table(t) = e {
2043                    Some(t.name.name.as_str())
2044                } else {
2045                    None
2046                }
2047            })
2048            .collect();
2049        assert!(names.contains(&"customers"), "should find outer table");
2050        assert!(names.contains(&"orders"), "should find subquery table");
2051    }
2052
2053    #[test]
2054    fn test_get_tables_finds_tables_inside_exists_subquery() {
2055        let dialect = crate::Dialect::get(crate::dialects::DialectType::Generic);
2056        let sql = "SELECT * FROM customers c WHERE EXISTS (SELECT 1 FROM orders o WHERE o.customer_id = c.id)";
2057        let exprs = dialect.parse(sql).unwrap();
2058        let tables = get_tables(&exprs[0]);
2059        let names: Vec<&str> = tables
2060            .iter()
2061            .filter_map(|e| {
2062                if let Expression::Table(t) = e {
2063                    Some(t.name.name.as_str())
2064                } else {
2065                    None
2066                }
2067            })
2068            .collect();
2069        assert!(names.contains(&"customers"), "should find outer table");
2070        assert!(
2071            names.contains(&"orders"),
2072            "should find EXISTS subquery table"
2073        );
2074    }
2075
2076    #[test]
2077    fn test_get_tables_finds_tables_in_correlated_subquery() {
2078        let dialect = crate::Dialect::get(crate::dialects::DialectType::TSQL);
2079        let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount > 1000)";
2080        let exprs = dialect.parse(sql).unwrap();
2081        let tables = get_tables(&exprs[0]);
2082        let names: Vec<&str> = tables
2083            .iter()
2084            .filter_map(|e| {
2085                if let Expression::Table(t) = e {
2086                    Some(t.name.name.as_str())
2087                } else {
2088                    None
2089                }
2090            })
2091            .collect();
2092        assert!(
2093            names.contains(&"customers"),
2094            "TSQL: should find outer table"
2095        );
2096        assert!(
2097            names.contains(&"orders"),
2098            "TSQL: should find subquery table"
2099        );
2100    }
2101}