Skip to main content

polyglot_sql/
traversal.rs

1//! Tree traversal utilities for SQL expression ASTs.
2//!
3//! This module provides read-only traversal, search, and transformation utilities
4//! for the [`Expression`] tree produced by the parser. Because Rust's ownership
5//! model does not allow parent pointers inside the AST, parent information is
6//! tracked externally via [`TreeContext`] (built on demand).
7//!
8//! # Traversal
9//!
10//! Two iterator types are provided:
11//! - [`DfsIter`] -- depth-first (pre-order) traversal using a stack. Visits a node
12//!   before its children. Good for top-down analysis and early termination.
13//! - [`BfsIter`] -- breadth-first (level-order) traversal using a queue. Visits all
14//!   nodes at depth N before any node at depth N+1. Good for level-aware analysis.
15//!
16//! Both are available through the [`ExpressionWalk`] trait methods [`dfs`](ExpressionWalk::dfs)
17//! and [`bfs`](ExpressionWalk::bfs).
18//!
19//! # Searching
20//!
21//! The [`ExpressionWalk`] trait also provides convenience methods for finding expressions:
22//! [`find`](ExpressionWalk::find), [`find_all`](ExpressionWalk::find_all),
23//! [`contains`](ExpressionWalk::contains), and [`count`](ExpressionWalk::count).
24//! Common predicates are available as free functions: [`is_column`], [`is_literal`],
25//! [`is_function`], [`is_aggregate`], [`is_window_function`], [`is_subquery`], and
26//! [`is_select`].
27//!
28//! # Transformation
29//!
30//! The [`transform`] and [`transform_map`] functions perform bottom-up (post-order)
31//! tree rewrites, delegating to [`transform_recursive`](crate::dialects::transform_recursive).
32//! The [`ExpressionWalk::transform_owned`] method provides the same capability as
33//! an owned method on `Expression`.
34//!
35//! Based on traversal patterns from `sqlglot/expressions.py`.
36
37use crate::expressions::Expression;
38use std::collections::{HashMap, VecDeque};
39
/// Unique identifier for expression nodes during traversal.
///
/// Identifiers are assigned sequentially by [`TreeContext`] as it visits nodes.
pub type NodeId = usize;

/// Information about a node's parent relationship.
///
/// Two entries compare equal only when they name the same parent, the same field,
/// and the same list position.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParentInfo {
    /// The NodeId of the parent (`None` for the root)
    pub parent_id: Option<NodeId>,
    /// Which argument/field in the parent this node occupies (empty for the root)
    pub arg_key: String,
    /// Index if the node is part of a list (e.g., expressions in SELECT);
    /// `None` when the node occupies a single-valued field
    pub index: Option<usize>,
}
53
54/// External parent-tracking context for an expression tree.
55///
56/// Since Rust's ownership model does not allow intrusive parent pointers in the AST,
57/// `TreeContext` provides an on-demand side-table that maps each node (identified by
58/// a [`NodeId`]) to its [`ParentInfo`] (parent node, field name, and list index).
59///
60/// Build a context from any expression root with [`TreeContext::build`], then query
61/// parent relationships with [`get`](TreeContext::get), ancestry chains with
62/// [`ancestors_of`](TreeContext::ancestors_of), or tree depth with
63/// [`depth_of`](TreeContext::depth_of).
64///
65/// This is useful when analysis requires upward navigation (e.g., determining whether
66/// a column reference appears inside a WHERE clause or a JOIN condition).
67#[derive(Debug, Default)]
68pub struct TreeContext {
69    /// Map from NodeId to parent information
70    nodes: HashMap<NodeId, ParentInfo>,
71    /// Counter for generating NodeIds
72    next_id: NodeId,
73    /// Stack for tracking current path during traversal
74    path: Vec<(NodeId, String, Option<usize>)>,
75}
76
77impl TreeContext {
78    /// Create a new empty tree context
79    pub fn new() -> Self {
80        Self::default()
81    }
82
83    /// Build context from an expression tree
84    pub fn build(root: &Expression) -> Self {
85        let mut ctx = Self::new();
86        ctx.visit_expr(root);
87        ctx
88    }
89
90    /// Visit an expression and record parent information
91    fn visit_expr(&mut self, expr: &Expression) -> NodeId {
92        let id = self.next_id;
93        self.next_id += 1;
94
95        // Record parent info based on current path
96        let parent_info = if let Some((parent_id, arg_key, index)) = self.path.last() {
97            ParentInfo {
98                parent_id: Some(*parent_id),
99                arg_key: arg_key.clone(),
100                index: *index,
101            }
102        } else {
103            ParentInfo {
104                parent_id: None,
105                arg_key: String::new(),
106                index: None,
107            }
108        };
109        self.nodes.insert(id, parent_info);
110
111        // Visit children
112        for (key, child) in iter_children(expr) {
113            self.path.push((id, key.to_string(), None));
114            self.visit_expr(child);
115            self.path.pop();
116        }
117
118        // Visit children in lists
119        for (key, children) in iter_children_lists(expr) {
120            for (idx, child) in children.iter().enumerate() {
121                self.path.push((id, key.to_string(), Some(idx)));
122                self.visit_expr(child);
123                self.path.pop();
124            }
125        }
126
127        id
128    }
129
130    /// Get parent info for a node
131    pub fn get(&self, id: NodeId) -> Option<&ParentInfo> {
132        self.nodes.get(&id)
133    }
134
135    /// Get the depth of a node (0 for root)
136    pub fn depth_of(&self, id: NodeId) -> usize {
137        let mut depth = 0;
138        let mut current = id;
139        while let Some(info) = self.nodes.get(&current) {
140            if let Some(parent_id) = info.parent_id {
141                depth += 1;
142                current = parent_id;
143            } else {
144                break;
145            }
146        }
147        depth
148    }
149
150    /// Get ancestors of a node (parent, grandparent, etc.)
151    pub fn ancestors_of(&self, id: NodeId) -> Vec<NodeId> {
152        let mut ancestors = Vec::new();
153        let mut current = id;
154        while let Some(info) = self.nodes.get(&current) {
155            if let Some(parent_id) = info.parent_id {
156                ancestors.push(parent_id);
157                current = parent_id;
158            } else {
159                break;
160            }
161        }
162        ancestors
163    }
164}
165
166/// Iterate over single-child fields of an expression
167///
168/// Returns an iterator of (field_name, &Expression) pairs.
169fn iter_children(expr: &Expression) -> Vec<(&'static str, &Expression)> {
170    let mut children = Vec::new();
171
172    match expr {
173        Expression::Select(s) => {
174            if let Some(from) = &s.from {
175                for source in &from.expressions {
176                    children.push(("from", source));
177                }
178            }
179            for join in &s.joins {
180                children.push(("join_this", &join.this));
181                if let Some(on) = &join.on {
182                    children.push(("join_on", on));
183                }
184                if let Some(match_condition) = &join.match_condition {
185                    children.push(("join_match_condition", match_condition));
186                }
187                for pivot in &join.pivots {
188                    children.push(("join_pivot", pivot));
189                }
190            }
191            for lateral_view in &s.lateral_views {
192                children.push(("lateral_view", &lateral_view.this));
193            }
194            if let Some(prewhere) = &s.prewhere {
195                children.push(("prewhere", prewhere));
196            }
197            if let Some(where_clause) = &s.where_clause {
198                children.push(("where", &where_clause.this));
199            }
200            if let Some(group_by) = &s.group_by {
201                for e in &group_by.expressions {
202                    children.push(("group_by", e));
203                }
204            }
205            if let Some(having) = &s.having {
206                children.push(("having", &having.this));
207            }
208            if let Some(qualify) = &s.qualify {
209                children.push(("qualify", &qualify.this));
210            }
211            if let Some(order_by) = &s.order_by {
212                for ordered in &order_by.expressions {
213                    children.push(("order_by", &ordered.this));
214                }
215            }
216            if let Some(distribute_by) = &s.distribute_by {
217                for e in &distribute_by.expressions {
218                    children.push(("distribute_by", e));
219                }
220            }
221            if let Some(cluster_by) = &s.cluster_by {
222                for ordered in &cluster_by.expressions {
223                    children.push(("cluster_by", &ordered.this));
224                }
225            }
226            if let Some(sort_by) = &s.sort_by {
227                for ordered in &sort_by.expressions {
228                    children.push(("sort_by", &ordered.this));
229                }
230            }
231            if let Some(limit) = &s.limit {
232                children.push(("limit", &limit.this));
233            }
234            if let Some(offset) = &s.offset {
235                children.push(("offset", &offset.this));
236            }
237            if let Some(limit_by) = &s.limit_by {
238                for e in limit_by {
239                    children.push(("limit_by", e));
240                }
241            }
242            if let Some(fetch) = &s.fetch {
243                if let Some(count) = &fetch.count {
244                    children.push(("fetch", count));
245                }
246            }
247            if let Some(top) = &s.top {
248                children.push(("top", &top.this));
249            }
250            if let Some(with) = &s.with {
251                for cte in &with.ctes {
252                    children.push(("with_cte", &cte.this));
253                }
254                if let Some(search) = &with.search {
255                    children.push(("with_search", search));
256                }
257            }
258            if let Some(sample) = &s.sample {
259                children.push(("sample_size", &sample.size));
260                if let Some(seed) = &sample.seed {
261                    children.push(("sample_seed", seed));
262                }
263                if let Some(offset) = &sample.offset {
264                    children.push(("sample_offset", offset));
265                }
266                if let Some(bucket_numerator) = &sample.bucket_numerator {
267                    children.push(("sample_bucket_numerator", bucket_numerator));
268                }
269                if let Some(bucket_denominator) = &sample.bucket_denominator {
270                    children.push(("sample_bucket_denominator", bucket_denominator));
271                }
272                if let Some(bucket_field) = &sample.bucket_field {
273                    children.push(("sample_bucket_field", bucket_field));
274                }
275            }
276            if let Some(connect) = &s.connect {
277                if let Some(start) = &connect.start {
278                    children.push(("connect_start", start));
279                }
280                children.push(("connect", &connect.connect));
281            }
282            if let Some(into) = &s.into {
283                children.push(("into", &into.this));
284            }
285            for lock in &s.locks {
286                for e in &lock.expressions {
287                    children.push(("lock_expression", e));
288                }
289                if let Some(wait) = &lock.wait {
290                    children.push(("lock_wait", wait));
291                }
292                if let Some(key) = &lock.key {
293                    children.push(("lock_key", key));
294                }
295                if let Some(update) = &lock.update {
296                    children.push(("lock_update", update));
297                }
298            }
299            for e in &s.for_xml {
300                children.push(("for_xml", e));
301            }
302        }
303        Expression::With(with) => {
304            for cte in &with.ctes {
305                children.push(("cte", &cte.this));
306            }
307            if let Some(search) = &with.search {
308                children.push(("search", search));
309            }
310        }
311        Expression::Cte(cte) => {
312            children.push(("this", &cte.this));
313        }
314        Expression::Insert(insert) => {
315            if let Some(query) = &insert.query {
316                children.push(("query", query));
317            }
318            if let Some(with) = &insert.with {
319                for cte in &with.ctes {
320                    children.push(("with_cte", &cte.this));
321                }
322                if let Some(search) = &with.search {
323                    children.push(("with_search", search));
324                }
325            }
326            if let Some(on_conflict) = &insert.on_conflict {
327                children.push(("on_conflict", on_conflict));
328            }
329            if let Some(replace_where) = &insert.replace_where {
330                children.push(("replace_where", replace_where));
331            }
332            if let Some(source) = &insert.source {
333                children.push(("source", source));
334            }
335            if let Some(function_target) = &insert.function_target {
336                children.push(("function_target", function_target));
337            }
338            if let Some(partition_by) = &insert.partition_by {
339                children.push(("partition_by", partition_by));
340            }
341            if let Some(output) = &insert.output {
342                for column in &output.columns {
343                    children.push(("output_column", column));
344                }
345                if let Some(into_table) = &output.into_table {
346                    children.push(("output_into_table", into_table));
347                }
348            }
349            for row in &insert.values {
350                for value in row {
351                    children.push(("value", value));
352                }
353            }
354            for (_, value) in &insert.partition {
355                if let Some(value) = value {
356                    children.push(("partition_value", value));
357                }
358            }
359            for returning in &insert.returning {
360                children.push(("returning", returning));
361            }
362            for setting in &insert.settings {
363                children.push(("setting", setting));
364            }
365        }
366        Expression::Update(update) => {
367            if let Some(from_clause) = &update.from_clause {
368                for source in &from_clause.expressions {
369                    children.push(("from", source));
370                }
371            }
372            for join in &update.table_joins {
373                children.push(("table_join_this", &join.this));
374                if let Some(on) = &join.on {
375                    children.push(("table_join_on", on));
376                }
377            }
378            for join in &update.from_joins {
379                children.push(("from_join_this", &join.this));
380                if let Some(on) = &join.on {
381                    children.push(("from_join_on", on));
382                }
383            }
384            for (_, value) in &update.set {
385                children.push(("set_value", value));
386            }
387            if let Some(where_clause) = &update.where_clause {
388                children.push(("where", &where_clause.this));
389            }
390            if let Some(output) = &update.output {
391                for column in &output.columns {
392                    children.push(("output_column", column));
393                }
394                if let Some(into_table) = &output.into_table {
395                    children.push(("output_into_table", into_table));
396                }
397            }
398            if let Some(with) = &update.with {
399                for cte in &with.ctes {
400                    children.push(("with_cte", &cte.this));
401                }
402                if let Some(search) = &with.search {
403                    children.push(("with_search", search));
404                }
405            }
406            if let Some(limit) = &update.limit {
407                children.push(("limit", limit));
408            }
409            if let Some(order_by) = &update.order_by {
410                for ordered in &order_by.expressions {
411                    children.push(("order_by", &ordered.this));
412                }
413            }
414            for returning in &update.returning {
415                children.push(("returning", returning));
416            }
417        }
418        Expression::Delete(delete) => {
419            if let Some(with) = &delete.with {
420                for cte in &with.ctes {
421                    children.push(("with_cte", &cte.this));
422                }
423                if let Some(search) = &with.search {
424                    children.push(("with_search", search));
425                }
426            }
427            if let Some(where_clause) = &delete.where_clause {
428                children.push(("where", &where_clause.this));
429            }
430            if let Some(output) = &delete.output {
431                for column in &output.columns {
432                    children.push(("output_column", column));
433                }
434                if let Some(into_table) = &output.into_table {
435                    children.push(("output_into_table", into_table));
436                }
437            }
438            if let Some(limit) = &delete.limit {
439                children.push(("limit", limit));
440            }
441            if let Some(order_by) = &delete.order_by {
442                for ordered in &order_by.expressions {
443                    children.push(("order_by", &ordered.this));
444                }
445            }
446            for returning in &delete.returning {
447                children.push(("returning", returning));
448            }
449            for join in &delete.joins {
450                children.push(("join_this", &join.this));
451                if let Some(on) = &join.on {
452                    children.push(("join_on", on));
453                }
454            }
455        }
456        Expression::Join(join) => {
457            children.push(("this", &join.this));
458            if let Some(on) = &join.on {
459                children.push(("on", on));
460            }
461            if let Some(match_condition) = &join.match_condition {
462                children.push(("match_condition", match_condition));
463            }
464            for pivot in &join.pivots {
465                children.push(("pivot", pivot));
466            }
467        }
468        Expression::Alias(a) => {
469            children.push(("this", &a.this));
470        }
471        Expression::Cast(c) => {
472            children.push(("this", &c.this));
473        }
474        Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
475            children.push(("this", &u.this));
476        }
477        Expression::Paren(p) => {
478            children.push(("this", &p.this));
479        }
480        Expression::IsNull(i) => {
481            children.push(("this", &i.this));
482        }
483        Expression::Exists(e) => {
484            children.push(("this", &e.this));
485        }
486        Expression::Subquery(s) => {
487            children.push(("this", &s.this));
488        }
489        Expression::Where(w) => {
490            children.push(("this", &w.this));
491        }
492        Expression::Having(h) => {
493            children.push(("this", &h.this));
494        }
495        Expression::Qualify(q) => {
496            children.push(("this", &q.this));
497        }
498        Expression::And(op)
499        | Expression::Or(op)
500        | Expression::Add(op)
501        | Expression::Sub(op)
502        | Expression::Mul(op)
503        | Expression::Div(op)
504        | Expression::Mod(op)
505        | Expression::Eq(op)
506        | Expression::Neq(op)
507        | Expression::Lt(op)
508        | Expression::Lte(op)
509        | Expression::Gt(op)
510        | Expression::Gte(op)
511        | Expression::BitwiseAnd(op)
512        | Expression::BitwiseOr(op)
513        | Expression::BitwiseXor(op)
514        | Expression::Concat(op) => {
515            children.push(("left", &op.left));
516            children.push(("right", &op.right));
517        }
518        Expression::Like(op) | Expression::ILike(op) => {
519            children.push(("left", &op.left));
520            children.push(("right", &op.right));
521        }
522        Expression::Between(b) => {
523            children.push(("this", &b.this));
524            children.push(("low", &b.low));
525            children.push(("high", &b.high));
526        }
527        Expression::In(i) => {
528            children.push(("this", &i.this));
529        }
530        Expression::Case(c) => {
531            if let Some(ref operand) = &c.operand {
532                children.push(("operand", operand));
533            }
534        }
535        Expression::WindowFunction(wf) => {
536            children.push(("this", &wf.this));
537        }
538        Expression::Union(u) => {
539            children.push(("left", &u.left));
540            children.push(("right", &u.right));
541            if let Some(with) = &u.with {
542                for cte in &with.ctes {
543                    children.push(("with_cte", &cte.this));
544                }
545                if let Some(search) = &with.search {
546                    children.push(("with_search", search));
547                }
548            }
549            if let Some(order_by) = &u.order_by {
550                for ordered in &order_by.expressions {
551                    children.push(("order_by", &ordered.this));
552                }
553            }
554            if let Some(limit) = &u.limit {
555                children.push(("limit", limit));
556            }
557            if let Some(offset) = &u.offset {
558                children.push(("offset", offset));
559            }
560            if let Some(distribute_by) = &u.distribute_by {
561                for e in &distribute_by.expressions {
562                    children.push(("distribute_by", e));
563                }
564            }
565            if let Some(sort_by) = &u.sort_by {
566                for ordered in &sort_by.expressions {
567                    children.push(("sort_by", &ordered.this));
568                }
569            }
570            if let Some(cluster_by) = &u.cluster_by {
571                for ordered in &cluster_by.expressions {
572                    children.push(("cluster_by", &ordered.this));
573                }
574            }
575            for e in &u.on_columns {
576                children.push(("on_column", e));
577            }
578        }
579        Expression::Intersect(i) => {
580            children.push(("left", &i.left));
581            children.push(("right", &i.right));
582            if let Some(with) = &i.with {
583                for cte in &with.ctes {
584                    children.push(("with_cte", &cte.this));
585                }
586                if let Some(search) = &with.search {
587                    children.push(("with_search", search));
588                }
589            }
590            if let Some(order_by) = &i.order_by {
591                for ordered in &order_by.expressions {
592                    children.push(("order_by", &ordered.this));
593                }
594            }
595            if let Some(limit) = &i.limit {
596                children.push(("limit", limit));
597            }
598            if let Some(offset) = &i.offset {
599                children.push(("offset", offset));
600            }
601            if let Some(distribute_by) = &i.distribute_by {
602                for e in &distribute_by.expressions {
603                    children.push(("distribute_by", e));
604                }
605            }
606            if let Some(sort_by) = &i.sort_by {
607                for ordered in &sort_by.expressions {
608                    children.push(("sort_by", &ordered.this));
609                }
610            }
611            if let Some(cluster_by) = &i.cluster_by {
612                for ordered in &cluster_by.expressions {
613                    children.push(("cluster_by", &ordered.this));
614                }
615            }
616            for e in &i.on_columns {
617                children.push(("on_column", e));
618            }
619        }
620        Expression::Except(e) => {
621            children.push(("left", &e.left));
622            children.push(("right", &e.right));
623            if let Some(with) = &e.with {
624                for cte in &with.ctes {
625                    children.push(("with_cte", &cte.this));
626                }
627                if let Some(search) = &with.search {
628                    children.push(("with_search", search));
629                }
630            }
631            if let Some(order_by) = &e.order_by {
632                for ordered in &order_by.expressions {
633                    children.push(("order_by", &ordered.this));
634                }
635            }
636            if let Some(limit) = &e.limit {
637                children.push(("limit", limit));
638            }
639            if let Some(offset) = &e.offset {
640                children.push(("offset", offset));
641            }
642            if let Some(distribute_by) = &e.distribute_by {
643                for expr in &distribute_by.expressions {
644                    children.push(("distribute_by", expr));
645                }
646            }
647            if let Some(sort_by) = &e.sort_by {
648                for ordered in &sort_by.expressions {
649                    children.push(("sort_by", &ordered.this));
650                }
651            }
652            if let Some(cluster_by) = &e.cluster_by {
653                for ordered in &cluster_by.expressions {
654                    children.push(("cluster_by", &ordered.this));
655                }
656            }
657            for expr in &e.on_columns {
658                children.push(("on_column", expr));
659            }
660        }
661        Expression::Merge(merge) => {
662            children.push(("this", &merge.this));
663            children.push(("using", &merge.using));
664            if let Some(on) = &merge.on {
665                children.push(("on", on));
666            }
667            if let Some(using_cond) = &merge.using_cond {
668                children.push(("using_cond", using_cond));
669            }
670            if let Some(whens) = &merge.whens {
671                children.push(("whens", whens));
672            }
673            if let Some(with_) = &merge.with_ {
674                children.push(("with_", with_));
675            }
676            if let Some(returning) = &merge.returning {
677                children.push(("returning", returning));
678            }
679        }
680        Expression::Ordered(o) => {
681            children.push(("this", &o.this));
682        }
683        Expression::Interval(i) => {
684            if let Some(ref this) = i.this {
685                children.push(("this", this));
686            }
687        }
688        _ => {}
689    }
690
691    children
692}
693
694/// Iterate over list-child fields of an expression
695///
696/// Returns an iterator of (field_name, &[Expression]) pairs.
697fn iter_children_lists(expr: &Expression) -> Vec<(&'static str, &[Expression])> {
698    let mut lists = Vec::new();
699
700    match expr {
701        Expression::Select(s) => lists.push(("expressions", s.expressions.as_slice())),
702        Expression::Function(f) => {
703            lists.push(("args", f.args.as_slice()));
704        }
705        Expression::AggregateFunction(f) => {
706            lists.push(("args", f.args.as_slice()));
707        }
708        Expression::From(f) => {
709            lists.push(("expressions", f.expressions.as_slice()));
710        }
711        Expression::GroupBy(g) => {
712            lists.push(("expressions", g.expressions.as_slice()));
713        }
714        // OrderBy.expressions is Vec<Ordered>, not Vec<Expression>
715        // We handle Ordered items via iter_children
716        Expression::In(i) => {
717            lists.push(("expressions", i.expressions.as_slice()));
718        }
719        Expression::Array(a) => {
720            lists.push(("expressions", a.expressions.as_slice()));
721        }
722        Expression::Tuple(t) => {
723            lists.push(("expressions", t.expressions.as_slice()));
724        }
725        // Values.expressions is Vec<Tuple>, handle specially
726        Expression::Coalesce(c) => {
727            lists.push(("expressions", c.expressions.as_slice()));
728        }
729        Expression::Greatest(g) | Expression::Least(g) => {
730            lists.push(("expressions", g.expressions.as_slice()));
731        }
732        _ => {}
733    }
734
735    lists
736}
737
738/// Pre-order depth-first iterator over an expression tree.
739///
740/// Visits each node before its children, using a stack-based approach. This means
741/// the root is yielded first, followed by the entire left subtree (recursively),
742/// then the right subtree. For a binary expression `a + b`, the iteration order
743/// is: `Add`, `a`, `b`.
744///
745/// Created via [`ExpressionWalk::dfs`] or [`DfsIter::new`].
746pub struct DfsIter<'a> {
747    stack: Vec<&'a Expression>,
748}
749
750impl<'a> DfsIter<'a> {
751    /// Create a new DFS iterator starting from the given expression
752    pub fn new(root: &'a Expression) -> Self {
753        Self { stack: vec![root] }
754    }
755}
756
757impl<'a> Iterator for DfsIter<'a> {
758    type Item = &'a Expression;
759
760    fn next(&mut self) -> Option<Self::Item> {
761        let expr = self.stack.pop()?;
762
763        // Add children in reverse order so they come out in forward order
764        let children: Vec<_> = iter_children(expr).into_iter().map(|(_, e)| e).collect();
765        for child in children.into_iter().rev() {
766            self.stack.push(child);
767        }
768
769        let lists: Vec<_> = iter_children_lists(expr)
770            .into_iter()
771            .flat_map(|(_, es)| es.iter())
772            .collect();
773        for child in lists.into_iter().rev() {
774            self.stack.push(child);
775        }
776
777        Some(expr)
778    }
779}
780
781/// Level-order breadth-first iterator over an expression tree.
782///
783/// Visits all nodes at depth N before any node at depth N+1, using a queue-based
784/// approach. For a tree `(a + b) = c`, the iteration order is: `Eq` (depth 0),
785/// `Add`, `c` (depth 1), `a`, `b` (depth 2).
786///
787/// Created via [`ExpressionWalk::bfs`] or [`BfsIter::new`].
788pub struct BfsIter<'a> {
789    queue: VecDeque<&'a Expression>,
790}
791
792impl<'a> BfsIter<'a> {
793    /// Create a new BFS iterator starting from the given expression
794    pub fn new(root: &'a Expression) -> Self {
795        let mut queue = VecDeque::new();
796        queue.push_back(root);
797        Self { queue }
798    }
799}
800
801impl<'a> Iterator for BfsIter<'a> {
802    type Item = &'a Expression;
803
804    fn next(&mut self) -> Option<Self::Item> {
805        let expr = self.queue.pop_front()?;
806
807        // Add children to queue
808        for (_, child) in iter_children(expr) {
809            self.queue.push_back(child);
810        }
811
812        for (_, children) in iter_children_lists(expr) {
813            for child in children {
814                self.queue.push_back(child);
815            }
816        }
817
818        Some(expr)
819    }
820}
821
822/// Extension trait that adds traversal and search methods to [`Expression`].
823///
824/// This trait is implemented for `Expression` and provides a fluent API for
825/// iterating, searching, measuring, and transforming expression trees without
826/// needing to import the iterator types directly.
827pub trait ExpressionWalk {
828    /// Returns a depth-first (pre-order) iterator over this expression and all descendants.
829    ///
830    /// The root node is yielded first, then its children are visited recursively
831    /// from left to right.
832    fn dfs(&self) -> DfsIter<'_>;
833
834    /// Returns a breadth-first (level-order) iterator over this expression and all descendants.
835    ///
836    /// All nodes at depth N are yielded before any node at depth N+1.
837    fn bfs(&self) -> BfsIter<'_>;
838
839    /// Finds the first expression matching `predicate` in depth-first order.
840    ///
841    /// Returns `None` if no descendant (including this node) matches.
842    fn find<F>(&self, predicate: F) -> Option<&Expression>
843    where
844        F: Fn(&Expression) -> bool;
845
846    /// Collects all expressions matching `predicate` in depth-first order.
847    ///
848    /// Returns an empty vector if no descendants match.
849    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
850    where
851        F: Fn(&Expression) -> bool;
852
853    /// Returns `true` if this node or any descendant matches `predicate`.
854    fn contains<F>(&self, predicate: F) -> bool
855    where
856        F: Fn(&Expression) -> bool;
857
858    /// Counts how many nodes (including this one) match `predicate`.
859    fn count<F>(&self, predicate: F) -> usize
860    where
861        F: Fn(&Expression) -> bool;
862
863    /// Returns direct child expressions of this node.
864    ///
865    /// Collects all single-child fields and list-child fields into a flat vector
866    /// of references. Leaf nodes return an empty vector.
867    fn children(&self) -> Vec<&Expression>;
868
869    /// Returns the maximum depth of the expression tree rooted at this node.
870    ///
871    /// A leaf node has depth 0, a node whose deepest child is a leaf has depth 1, etc.
872    fn tree_depth(&self) -> usize;
873
874    /// Transforms this expression tree bottom-up using the given function (owned variant).
875    ///
876    /// Children are transformed first, then `fun` is called on the resulting node.
877    /// Return `Ok(None)` from `fun` to replace a node with `NULL`.
878    /// Return `Ok(Some(expr))` to substitute the node with `expr`.
879    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
880    where
881        F: Fn(Expression) -> crate::Result<Option<Expression>>,
882        Self: Sized;
883}
884
885impl ExpressionWalk for Expression {
886    fn dfs(&self) -> DfsIter<'_> {
887        DfsIter::new(self)
888    }
889
890    fn bfs(&self) -> BfsIter<'_> {
891        BfsIter::new(self)
892    }
893
894    fn find<F>(&self, predicate: F) -> Option<&Expression>
895    where
896        F: Fn(&Expression) -> bool,
897    {
898        self.dfs().find(|e| predicate(e))
899    }
900
901    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
902    where
903        F: Fn(&Expression) -> bool,
904    {
905        self.dfs().filter(|e| predicate(e)).collect()
906    }
907
908    fn contains<F>(&self, predicate: F) -> bool
909    where
910        F: Fn(&Expression) -> bool,
911    {
912        self.dfs().any(|e| predicate(e))
913    }
914
915    fn count<F>(&self, predicate: F) -> usize
916    where
917        F: Fn(&Expression) -> bool,
918    {
919        self.dfs().filter(|e| predicate(e)).count()
920    }
921
922    fn children(&self) -> Vec<&Expression> {
923        let mut result: Vec<&Expression> = Vec::new();
924        for (_, child) in iter_children(self) {
925            result.push(child);
926        }
927        for (_, children_list) in iter_children_lists(self) {
928            for child in children_list {
929                result.push(child);
930            }
931        }
932        result
933    }
934
935    fn tree_depth(&self) -> usize {
936        let mut max_depth = 0;
937
938        for (_, child) in iter_children(self) {
939            let child_depth = child.tree_depth();
940            if child_depth + 1 > max_depth {
941                max_depth = child_depth + 1;
942            }
943        }
944
945        for (_, children) in iter_children_lists(self) {
946            for child in children {
947                let child_depth = child.tree_depth();
948                if child_depth + 1 > max_depth {
949                    max_depth = child_depth + 1;
950                }
951            }
952        }
953
954        max_depth
955    }
956
957    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
958    where
959        F: Fn(Expression) -> crate::Result<Option<Expression>>,
960    {
961        transform(self, &fun)
962    }
963}
964
965/// Transforms an expression tree bottom-up, with optional node removal.
966///
967/// Recursively transforms all children first, then applies `fun` to the resulting node.
968/// If `fun` returns `Ok(None)`, the node is replaced with an `Expression::Null`.
969/// If `fun` returns `Ok(Some(expr))`, the node is replaced with `expr`.
970///
971/// This is the primary transformation entry point when callers need the ability to
972/// "delete" nodes by returning `None`.
973///
974/// # Example
975///
976/// ```rust,ignore
977/// use polyglot_sql::traversal::transform;
978///
979/// // Remove all Paren wrapper nodes from a tree
980/// let result = transform(expr, &|e| match e {
981///     Expression::Paren(p) => Ok(Some(p.this)),
982///     other => Ok(Some(other)),
983/// })?;
984/// ```
985pub fn transform<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
986where
987    F: Fn(Expression) -> crate::Result<Option<Expression>>,
988{
989    crate::dialects::transform_recursive(expr, &|e| match fun(e)? {
990        Some(transformed) => Ok(transformed),
991        None => Ok(Expression::Null(crate::expressions::Null)),
992    })
993}
994
995/// Transforms an expression tree bottom-up without node removal.
996///
997/// Like [`transform`], but `fun` returns an `Expression` directly rather than
998/// `Option<Expression>`, so nodes cannot be deleted. This is a convenience wrapper
999/// for the common case where every node is mapped to exactly one output node.
1000///
1001/// # Example
1002///
1003/// ```rust,ignore
1004/// use polyglot_sql::traversal::transform_map;
1005///
1006/// // Uppercase all column names in a tree
1007/// let result = transform_map(expr, &|e| match e {
1008///     Expression::Column(mut c) => {
1009///         c.name.name = c.name.name.to_uppercase();
1010///         Ok(Expression::Column(c))
1011///     }
1012///     other => Ok(other),
1013/// })?;
1014/// ```
1015pub fn transform_map<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
1016where
1017    F: Fn(Expression) -> crate::Result<Expression>,
1018{
1019    crate::dialects::transform_recursive(expr, fun)
1020}
1021
1022// ---------------------------------------------------------------------------
1023// Common expression predicates
1024// ---------------------------------------------------------------------------
1025// These free functions are intended for use with the search methods on
1026// `ExpressionWalk` (e.g., `expr.find(is_column)`, `expr.contains(is_aggregate)`).
1027
1028/// Returns `true` if `expr` is a column reference ([`Expression::Column`]).
1029pub fn is_column(expr: &Expression) -> bool {
1030    matches!(expr, Expression::Column(_))
1031}
1032
1033/// Returns `true` if `expr` is a literal value (number, string, boolean, or NULL).
1034pub fn is_literal(expr: &Expression) -> bool {
1035    matches!(
1036        expr,
1037        Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
1038    )
1039}
1040
1041/// Returns `true` if `expr` is a function call (regular or aggregate).
1042pub fn is_function(expr: &Expression) -> bool {
1043    matches!(
1044        expr,
1045        Expression::Function(_) | Expression::AggregateFunction(_)
1046    )
1047}
1048
1049/// Returns `true` if `expr` is a subquery ([`Expression::Subquery`]).
1050pub fn is_subquery(expr: &Expression) -> bool {
1051    matches!(expr, Expression::Subquery(_))
1052}
1053
1054/// Returns `true` if `expr` is a SELECT statement ([`Expression::Select`]).
1055pub fn is_select(expr: &Expression) -> bool {
1056    matches!(expr, Expression::Select(_))
1057}
1058
1059/// Returns `true` if `expr` is an aggregate function ([`Expression::AggregateFunction`]).
1060pub fn is_aggregate(expr: &Expression) -> bool {
1061    matches!(expr, Expression::AggregateFunction(_))
1062}
1063
1064/// Returns `true` if `expr` is a window function ([`Expression::WindowFunction`]).
1065pub fn is_window_function(expr: &Expression) -> bool {
1066    matches!(expr, Expression::WindowFunction(_))
1067}
1068
1069/// Collects all column references ([`Expression::Column`]) from the expression tree.
1070///
1071/// Performs a depth-first search and returns references to every column node found.
1072pub fn get_columns(expr: &Expression) -> Vec<&Expression> {
1073    expr.find_all(is_column)
1074}
1075
1076/// Collects all table references ([`Expression::Table`]) from the expression tree.
1077///
1078/// Performs a depth-first search and returns references to every table node found.
1079pub fn get_tables(expr: &Expression) -> Vec<&Expression> {
1080    expr.find_all(|e| matches!(e, Expression::Table(_)))
1081}
1082
1083/// Returns `true` if the expression tree contains any aggregate function calls.
1084pub fn contains_aggregate(expr: &Expression) -> bool {
1085    expr.contains(is_aggregate)
1086}
1087
1088/// Returns `true` if the expression tree contains any window function calls.
1089pub fn contains_window_function(expr: &Expression) -> bool {
1090    expr.contains(is_window_function)
1091}
1092
1093/// Returns `true` if the expression tree contains any subquery nodes.
1094pub fn contains_subquery(expr: &Expression) -> bool {
1095    expr.contains(is_subquery)
1096}
1097
1098// ---------------------------------------------------------------------------
1099// Extended type predicates
1100// ---------------------------------------------------------------------------
1101
/// Generates `pub fn $name(expr: &Expression) -> bool`, which returns `true` when
/// `expr` matches any of the listed [`Expression`] variant patterns.
///
/// The generated function's rustdoc names the pattern(s) it tests, so each of the
/// many predicates produced below documents itself instead of sharing one generic
/// sentence.
macro_rules! is_type {
    ($name:ident, $($variant:pat),+ $(,)?) => {
        #[doc = concat!(
            "Returns `true` if `expr` matches `",
            stringify!($($variant)|+),
            "`."
        )]
        pub fn $name(expr: &Expression) -> bool {
            matches!(expr, $($variant)|+)
        }
    };
}
1111
1112// Query
1113is_type!(is_insert, Expression::Insert(_));
1114is_type!(is_update, Expression::Update(_));
1115is_type!(is_delete, Expression::Delete(_));
1116is_type!(is_union, Expression::Union(_));
1117is_type!(is_intersect, Expression::Intersect(_));
1118is_type!(is_except, Expression::Except(_));
1119
1120// Identifiers & literals
1121is_type!(is_boolean, Expression::Boolean(_));
1122is_type!(is_null_literal, Expression::Null(_));
1123is_type!(is_star, Expression::Star(_));
1124is_type!(is_identifier, Expression::Identifier(_));
1125is_type!(is_table, Expression::Table(_));
1126
1127// Comparison
1128is_type!(is_eq, Expression::Eq(_));
1129is_type!(is_neq, Expression::Neq(_));
1130is_type!(is_lt, Expression::Lt(_));
1131is_type!(is_lte, Expression::Lte(_));
1132is_type!(is_gt, Expression::Gt(_));
1133is_type!(is_gte, Expression::Gte(_));
1134is_type!(is_like, Expression::Like(_));
1135is_type!(is_ilike, Expression::ILike(_));
1136
1137// Arithmetic
1138is_type!(is_add, Expression::Add(_));
1139is_type!(is_sub, Expression::Sub(_));
1140is_type!(is_mul, Expression::Mul(_));
1141is_type!(is_div, Expression::Div(_));
1142is_type!(is_mod, Expression::Mod(_));
1143is_type!(is_concat, Expression::Concat(_));
1144
1145// Logical
1146is_type!(is_and, Expression::And(_));
1147is_type!(is_or, Expression::Or(_));
1148is_type!(is_not, Expression::Not(_));
1149
1150// Predicates
1151is_type!(is_in, Expression::In(_));
1152is_type!(is_between, Expression::Between(_));
1153is_type!(is_is_null, Expression::IsNull(_));
1154is_type!(is_exists, Expression::Exists(_));
1155
1156// Functions
1157is_type!(is_count, Expression::Count(_));
1158is_type!(is_sum, Expression::Sum(_));
1159is_type!(is_avg, Expression::Avg(_));
1160is_type!(is_min_func, Expression::Min(_));
1161is_type!(is_max_func, Expression::Max(_));
1162is_type!(is_coalesce, Expression::Coalesce(_));
1163is_type!(is_null_if, Expression::NullIf(_));
1164is_type!(is_cast, Expression::Cast(_));
1165is_type!(is_try_cast, Expression::TryCast(_));
1166is_type!(is_safe_cast, Expression::SafeCast(_));
1167is_type!(is_case, Expression::Case(_));
1168
1169// Clauses
1170is_type!(is_from, Expression::From(_));
1171is_type!(is_join, Expression::Join(_));
1172is_type!(is_where, Expression::Where(_));
1173is_type!(is_group_by, Expression::GroupBy(_));
1174is_type!(is_having, Expression::Having(_));
1175is_type!(is_order_by, Expression::OrderBy(_));
1176is_type!(is_limit, Expression::Limit(_));
1177is_type!(is_offset, Expression::Offset(_));
1178is_type!(is_with, Expression::With(_));
1179is_type!(is_cte, Expression::Cte(_));
1180is_type!(is_alias, Expression::Alias(_));
1181is_type!(is_paren, Expression::Paren(_));
1182is_type!(is_ordered, Expression::Ordered(_));
1183
1184// DDL
1185is_type!(is_create_table, Expression::CreateTable(_));
1186is_type!(is_drop_table, Expression::DropTable(_));
1187is_type!(is_alter_table, Expression::AlterTable(_));
1188is_type!(is_create_index, Expression::CreateIndex(_));
1189is_type!(is_drop_index, Expression::DropIndex(_));
1190is_type!(is_create_view, Expression::CreateView(_));
1191is_type!(is_drop_view, Expression::DropView(_));
1192
1193// ---------------------------------------------------------------------------
1194// Composite predicates
1195// ---------------------------------------------------------------------------
1196
1197/// Returns `true` if `expr` is a query statement (SELECT, INSERT, UPDATE, or DELETE).
1198pub fn is_query(expr: &Expression) -> bool {
1199    matches!(
1200        expr,
1201        Expression::Select(_)
1202            | Expression::Insert(_)
1203            | Expression::Update(_)
1204            | Expression::Delete(_)
1205    )
1206}
1207
1208/// Returns `true` if `expr` is a set operation (UNION, INTERSECT, or EXCEPT).
1209pub fn is_set_operation(expr: &Expression) -> bool {
1210    matches!(
1211        expr,
1212        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
1213    )
1214}
1215
1216/// Returns `true` if `expr` is a comparison operator.
1217pub fn is_comparison(expr: &Expression) -> bool {
1218    matches!(
1219        expr,
1220        Expression::Eq(_)
1221            | Expression::Neq(_)
1222            | Expression::Lt(_)
1223            | Expression::Lte(_)
1224            | Expression::Gt(_)
1225            | Expression::Gte(_)
1226            | Expression::Like(_)
1227            | Expression::ILike(_)
1228    )
1229}
1230
1231/// Returns `true` if `expr` is an arithmetic operator.
1232pub fn is_arithmetic(expr: &Expression) -> bool {
1233    matches!(
1234        expr,
1235        Expression::Add(_)
1236            | Expression::Sub(_)
1237            | Expression::Mul(_)
1238            | Expression::Div(_)
1239            | Expression::Mod(_)
1240    )
1241}
1242
1243/// Returns `true` if `expr` is a logical operator (AND, OR, NOT).
1244pub fn is_logical(expr: &Expression) -> bool {
1245    matches!(
1246        expr,
1247        Expression::And(_) | Expression::Or(_) | Expression::Not(_)
1248    )
1249}
1250
1251/// Returns `true` if `expr` is a DDL statement.
1252pub fn is_ddl(expr: &Expression) -> bool {
1253    matches!(
1254        expr,
1255        Expression::CreateTable(_)
1256            | Expression::DropTable(_)
1257            | Expression::AlterTable(_)
1258            | Expression::CreateIndex(_)
1259            | Expression::DropIndex(_)
1260            | Expression::CreateView(_)
1261            | Expression::DropView(_)
1262            | Expression::AlterView(_)
1263            | Expression::CreateSchema(_)
1264            | Expression::DropSchema(_)
1265            | Expression::CreateDatabase(_)
1266            | Expression::DropDatabase(_)
1267            | Expression::CreateFunction(_)
1268            | Expression::DropFunction(_)
1269            | Expression::CreateProcedure(_)
1270            | Expression::DropProcedure(_)
1271            | Expression::CreateSequence(_)
1272            | Expression::DropSequence(_)
1273            | Expression::AlterSequence(_)
1274            | Expression::CreateTrigger(_)
1275            | Expression::DropTrigger(_)
1276            | Expression::CreateType(_)
1277            | Expression::DropType(_)
1278    )
1279}
1280
1281/// Find the parent of `target` within the tree rooted at `root`.
1282///
1283/// Uses pointer identity ([`std::ptr::eq`]) — `target` must be a reference
1284/// obtained from the same tree (e.g., via [`ExpressionWalk::find`] or DFS iteration).
1285///
1286/// Returns `None` if `target` is the root itself or is not found in the tree.
1287pub fn find_parent<'a>(root: &'a Expression, target: &Expression) -> Option<&'a Expression> {
1288    fn search<'a>(node: &'a Expression, target: *const Expression) -> Option<&'a Expression> {
1289        for (_, child) in iter_children(node) {
1290            if std::ptr::eq(child, target) {
1291                return Some(node);
1292            }
1293            if let Some(found) = search(child, target) {
1294                return Some(found);
1295            }
1296        }
1297        for (_, children_list) in iter_children_lists(node) {
1298            for child in children_list {
1299                if std::ptr::eq(child, target) {
1300                    return Some(node);
1301                }
1302                if let Some(found) = search(child, target) {
1303                    return Some(found);
1304                }
1305            }
1306        }
1307        None
1308    }
1309
1310    search(root, target as *const Expression)
1311}
1312
1313/// Find the first ancestor of `target` matching `predicate`, walking from
1314/// parent toward root.
1315///
1316/// Uses pointer identity for target lookup. Returns `None` if no ancestor
1317/// matches or `target` is not found in the tree.
1318pub fn find_ancestor<'a, F>(
1319    root: &'a Expression,
1320    target: &Expression,
1321    predicate: F,
1322) -> Option<&'a Expression>
1323where
1324    F: Fn(&Expression) -> bool,
1325{
1326    // Build path from root to target
1327    fn build_path<'a>(
1328        node: &'a Expression,
1329        target: *const Expression,
1330        path: &mut Vec<&'a Expression>,
1331    ) -> bool {
1332        if std::ptr::eq(node, target) {
1333            return true;
1334        }
1335        path.push(node);
1336        for (_, child) in iter_children(node) {
1337            if build_path(child, target, path) {
1338                return true;
1339            }
1340        }
1341        for (_, children_list) in iter_children_lists(node) {
1342            for child in children_list {
1343                if build_path(child, target, path) {
1344                    return true;
1345                }
1346            }
1347        }
1348        path.pop();
1349        false
1350    }
1351
1352    let mut path = Vec::new();
1353    if !build_path(root, target as *const Expression, &mut path) {
1354        return None;
1355    }
1356
1357    // Walk path in reverse (parent first, then grandparent, etc.)
1358    for ancestor in path.iter().rev() {
1359        if predicate(ancestor) {
1360            return Some(ancestor);
1361        }
1362    }
1363    None
1364}
1365
1366#[cfg(test)]
1367mod tests {
1368    use super::*;
1369    use crate::expressions::{BinaryOp, Column, Identifier, Literal};
1370
1371    fn make_column(name: &str) -> Expression {
1372        Expression::Column(Column {
1373            name: Identifier {
1374                name: name.to_string(),
1375                quoted: false,
1376                trailing_comments: vec![],
1377            },
1378            table: None,
1379            join_mark: false,
1380            trailing_comments: vec![],
1381        })
1382    }
1383
1384    fn make_literal(value: i64) -> Expression {
1385        Expression::Literal(Literal::Number(value.to_string()))
1386    }
1387
1388    #[test]
1389    fn test_dfs_simple() {
1390        let left = make_column("a");
1391        let right = make_literal(1);
1392        let expr = Expression::Eq(Box::new(BinaryOp {
1393            left,
1394            right,
1395            left_comments: vec![],
1396            operator_comments: vec![],
1397            trailing_comments: vec![],
1398        }));
1399
1400        let nodes: Vec<_> = expr.dfs().collect();
1401        assert_eq!(nodes.len(), 3); // Eq, Column, Literal
1402        assert!(matches!(nodes[0], Expression::Eq(_)));
1403        assert!(matches!(nodes[1], Expression::Column(_)));
1404        assert!(matches!(nodes[2], Expression::Literal(_)));
1405    }
1406
1407    #[test]
1408    fn test_find() {
1409        let left = make_column("a");
1410        let right = make_literal(1);
1411        let expr = Expression::Eq(Box::new(BinaryOp {
1412            left,
1413            right,
1414            left_comments: vec![],
1415            operator_comments: vec![],
1416            trailing_comments: vec![],
1417        }));
1418
1419        let column = expr.find(is_column);
1420        assert!(column.is_some());
1421        assert!(matches!(column.unwrap(), Expression::Column(_)));
1422
1423        let literal = expr.find(is_literal);
1424        assert!(literal.is_some());
1425        assert!(matches!(literal.unwrap(), Expression::Literal(_)));
1426    }
1427
1428    #[test]
1429    fn test_find_all() {
1430        let col1 = make_column("a");
1431        let col2 = make_column("b");
1432        let expr = Expression::And(Box::new(BinaryOp {
1433            left: col1,
1434            right: col2,
1435            left_comments: vec![],
1436            operator_comments: vec![],
1437            trailing_comments: vec![],
1438        }));
1439
1440        let columns = expr.find_all(is_column);
1441        assert_eq!(columns.len(), 2);
1442    }
1443
1444    #[test]
1445    fn test_contains() {
1446        let col = make_column("a");
1447        let lit = make_literal(1);
1448        let expr = Expression::Eq(Box::new(BinaryOp {
1449            left: col,
1450            right: lit,
1451            left_comments: vec![],
1452            operator_comments: vec![],
1453            trailing_comments: vec![],
1454        }));
1455
1456        assert!(expr.contains(is_column));
1457        assert!(expr.contains(is_literal));
1458        assert!(!expr.contains(is_subquery));
1459    }
1460
1461    #[test]
1462    fn test_count() {
1463        let col1 = make_column("a");
1464        let col2 = make_column("b");
1465        let lit = make_literal(1);
1466
1467        let inner = Expression::Add(Box::new(BinaryOp {
1468            left: col2,
1469            right: lit,
1470            left_comments: vec![],
1471            operator_comments: vec![],
1472            trailing_comments: vec![],
1473        }));
1474
1475        let expr = Expression::Eq(Box::new(BinaryOp {
1476            left: col1,
1477            right: inner,
1478            left_comments: vec![],
1479            operator_comments: vec![],
1480            trailing_comments: vec![],
1481        }));
1482
1483        assert_eq!(expr.count(is_column), 2);
1484        assert_eq!(expr.count(is_literal), 1);
1485    }
1486
1487    #[test]
1488    fn test_tree_depth() {
1489        // Single node
1490        let lit = make_literal(1);
1491        assert_eq!(lit.tree_depth(), 0);
1492
1493        // One level
1494        let col = make_column("a");
1495        let expr = Expression::Eq(Box::new(BinaryOp {
1496            left: col,
1497            right: lit.clone(),
1498            left_comments: vec![],
1499            operator_comments: vec![],
1500            trailing_comments: vec![],
1501        }));
1502        assert_eq!(expr.tree_depth(), 1);
1503
1504        // Two levels
1505        let inner = Expression::Add(Box::new(BinaryOp {
1506            left: make_column("b"),
1507            right: lit,
1508            left_comments: vec![],
1509            operator_comments: vec![],
1510            trailing_comments: vec![],
1511        }));
1512        let outer = Expression::Eq(Box::new(BinaryOp {
1513            left: make_column("a"),
1514            right: inner,
1515            left_comments: vec![],
1516            operator_comments: vec![],
1517            trailing_comments: vec![],
1518        }));
1519        assert_eq!(outer.tree_depth(), 2);
1520    }
1521
1522    #[test]
1523    fn test_tree_context() {
1524        let col = make_column("a");
1525        let lit = make_literal(1);
1526        let expr = Expression::Eq(Box::new(BinaryOp {
1527            left: col,
1528            right: lit,
1529            left_comments: vec![],
1530            operator_comments: vec![],
1531            trailing_comments: vec![],
1532        }));
1533
1534        let ctx = TreeContext::build(&expr);
1535
1536        // Root has no parent
1537        let root_info = ctx.get(0).unwrap();
1538        assert!(root_info.parent_id.is_none());
1539
1540        // Children have root as parent
1541        let left_info = ctx.get(1).unwrap();
1542        assert_eq!(left_info.parent_id, Some(0));
1543        assert_eq!(left_info.arg_key, "left");
1544
1545        let right_info = ctx.get(2).unwrap();
1546        assert_eq!(right_info.parent_id, Some(0));
1547        assert_eq!(right_info.arg_key, "right");
1548    }
1549
1550    // -- Step 8: transform / transform_map tests --
1551
1552    #[test]
1553    fn test_transform_rename_columns() {
1554        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
1555        let expr = ast[0].clone();
1556        let result = super::transform_map(expr, &|e| {
1557            if let Expression::Column(ref c) = e {
1558                if c.name.name == "a" {
1559                    return Ok(Expression::Column(Column {
1560                        name: Identifier::new("alpha"),
1561                        table: c.table.clone(),
1562                        join_mark: false,
1563                        trailing_comments: vec![],
1564                    }));
1565                }
1566            }
1567            Ok(e)
1568        })
1569        .unwrap();
1570        let sql = crate::generator::Generator::sql(&result).unwrap();
1571        assert!(sql.contains("alpha"), "Expected 'alpha' in: {}", sql);
1572        assert!(sql.contains("b"), "Expected 'b' in: {}", sql);
1573    }
1574
1575    #[test]
1576    fn test_transform_noop() {
1577        let ast = crate::parser::Parser::parse_sql("SELECT 1 + 2").unwrap();
1578        let expr = ast[0].clone();
1579        let result = super::transform_map(expr.clone(), &|e| Ok(e)).unwrap();
1580        let sql1 = crate::generator::Generator::sql(&expr).unwrap();
1581        let sql2 = crate::generator::Generator::sql(&result).unwrap();
1582        assert_eq!(sql1, sql2);
1583    }
1584
1585    #[test]
1586    fn test_transform_nested() {
1587        let ast = crate::parser::Parser::parse_sql("SELECT a + b FROM t").unwrap();
1588        let expr = ast[0].clone();
1589        let result = super::transform_map(expr, &|e| {
1590            if let Expression::Column(ref c) = e {
1591                return Ok(Expression::Literal(Literal::Number(
1592                    if c.name.name == "a" { "1" } else { "2" }.to_string(),
1593                )));
1594            }
1595            Ok(e)
1596        })
1597        .unwrap();
1598        let sql = crate::generator::Generator::sql(&result).unwrap();
1599        assert_eq!(sql, "SELECT 1 + 2 FROM t");
1600    }
1601
1602    #[test]
1603    fn test_transform_error() {
1604        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t").unwrap();
1605        let expr = ast[0].clone();
1606        let result = super::transform_map(expr, &|e| {
1607            if let Expression::Column(ref c) = e {
1608                if c.name.name == "a" {
1609                    return Err(crate::error::Error::parse("test error", 0, 0));
1610                }
1611            }
1612            Ok(e)
1613        });
1614        assert!(result.is_err());
1615    }
1616
1617    #[test]
1618    fn test_transform_owned_trait() {
1619        let ast = crate::parser::Parser::parse_sql("SELECT x FROM t").unwrap();
1620        let expr = ast[0].clone();
1621        let result = expr.transform_owned(|e| Ok(Some(e))).unwrap();
1622        let sql = crate::generator::Generator::sql(&result).unwrap();
1623        assert_eq!(sql, "SELECT x FROM t");
1624    }
1625
1626    // -- children() tests --
1627
1628    #[test]
1629    fn test_children_leaf() {
1630        let lit = make_literal(1);
1631        assert_eq!(lit.children().len(), 0);
1632    }
1633
1634    #[test]
1635    fn test_children_binary_op() {
1636        let left = make_column("a");
1637        let right = make_literal(1);
1638        let expr = Expression::Eq(Box::new(BinaryOp {
1639            left,
1640            right,
1641            left_comments: vec![],
1642            operator_comments: vec![],
1643            trailing_comments: vec![],
1644        }));
1645        let children = expr.children();
1646        assert_eq!(children.len(), 2);
1647        assert!(matches!(children[0], Expression::Column(_)));
1648        assert!(matches!(children[1], Expression::Literal(_)));
1649    }
1650
1651    #[test]
1652    fn test_children_select() {
1653        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
1654        let expr = &ast[0];
1655        let children = expr.children();
1656        // Should include select list items (a, b)
1657        assert!(children.len() >= 2);
1658    }
1659
1660    #[test]
1661    fn test_children_select_includes_from_and_join_sources() {
1662        let ast = crate::parser::Parser::parse_sql(
1663            "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id",
1664        )
1665        .unwrap();
1666        let expr = &ast[0];
1667        let children = expr.children();
1668
1669        let table_names: Vec<&str> = children
1670            .iter()
1671            .filter_map(|e| match e {
1672                Expression::Table(t) => Some(t.name.name.as_str()),
1673                _ => None,
1674            })
1675            .collect();
1676
1677        assert!(table_names.contains(&"users"));
1678        assert!(table_names.contains(&"orders"));
1679    }
1680
1681    #[test]
1682    fn test_get_tables_includes_insert_query_sources() {
1683        let ast = crate::parser::Parser::parse_sql(
1684            "INSERT INTO dst (id) SELECT s.id FROM src s JOIN dim d ON s.id = d.id",
1685        )
1686        .unwrap();
1687        let expr = &ast[0];
1688        let tables = get_tables(expr);
1689        let names: Vec<&str> = tables
1690            .iter()
1691            .filter_map(|e| match e {
1692                Expression::Table(t) => Some(t.name.name.as_str()),
1693                _ => None,
1694            })
1695            .collect();
1696
1697        assert!(names.contains(&"src"));
1698        assert!(names.contains(&"dim"));
1699    }
1700
1701    // -- find_parent() tests --
1702
1703    #[test]
1704    fn test_find_parent_binary() {
1705        let left = make_column("a");
1706        let right = make_literal(1);
1707        let expr = Expression::Eq(Box::new(BinaryOp {
1708            left,
1709            right,
1710            left_comments: vec![],
1711            operator_comments: vec![],
1712            trailing_comments: vec![],
1713        }));
1714
1715        // Find the column child and get its parent
1716        let col = expr.find(is_column).unwrap();
1717        let parent = super::find_parent(&expr, col);
1718        assert!(parent.is_some());
1719        assert!(matches!(parent.unwrap(), Expression::Eq(_)));
1720    }
1721
1722    #[test]
1723    fn test_find_parent_root_has_none() {
1724        let lit = make_literal(1);
1725        let parent = super::find_parent(&lit, &lit);
1726        assert!(parent.is_none());
1727    }
1728
1729    // -- find_ancestor() tests --
1730
1731    #[test]
1732    fn test_find_ancestor_select() {
1733        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t WHERE a > 1").unwrap();
1734        let expr = &ast[0];
1735
1736        // Find a column inside the WHERE clause
1737        let where_col = expr.dfs().find(|e| {
1738            if let Expression::Column(c) = e {
1739                c.name.name == "a"
1740            } else {
1741                false
1742            }
1743        });
1744        assert!(where_col.is_some());
1745
1746        // Find Select ancestor of that column
1747        let ancestor = super::find_ancestor(expr, where_col.unwrap(), is_select);
1748        assert!(ancestor.is_some());
1749        assert!(matches!(ancestor.unwrap(), Expression::Select(_)));
1750    }
1751
1752    #[test]
1753    fn test_find_ancestor_no_match() {
1754        let left = make_column("a");
1755        let right = make_literal(1);
1756        let expr = Expression::Eq(Box::new(BinaryOp {
1757            left,
1758            right,
1759            left_comments: vec![],
1760            operator_comments: vec![],
1761            trailing_comments: vec![],
1762        }));
1763
1764        let col = expr.find(is_column).unwrap();
1765        let ancestor = super::find_ancestor(&expr, col, is_select);
1766        assert!(ancestor.is_none());
1767    }
1768
1769    #[test]
1770    fn test_ancestors() {
1771        let col = make_column("a");
1772        let lit = make_literal(1);
1773        let inner = Expression::Add(Box::new(BinaryOp {
1774            left: col,
1775            right: lit,
1776            left_comments: vec![],
1777            operator_comments: vec![],
1778            trailing_comments: vec![],
1779        }));
1780        let outer = Expression::Eq(Box::new(BinaryOp {
1781            left: make_column("b"),
1782            right: inner,
1783            left_comments: vec![],
1784            operator_comments: vec![],
1785            trailing_comments: vec![],
1786        }));
1787
1788        let ctx = TreeContext::build(&outer);
1789
1790        // The inner Add's left child (column "a") should have ancestors
1791        // Node 0: Eq
1792        // Node 1: Column "b" (left of Eq)
1793        // Node 2: Add (right of Eq)
1794        // Node 3: Column "a" (left of Add)
1795        // Node 4: Literal (right of Add)
1796
1797        let ancestors = ctx.ancestors_of(3);
1798        assert_eq!(ancestors, vec![2, 0]); // Add, then Eq
1799    }
1800}