Skip to main content

polyglot_sql/
traversal.rs

1//! Tree traversal utilities for SQL expression ASTs.
2//!
3//! This module provides read-only traversal, search, and transformation utilities
4//! for the [`Expression`] tree produced by the parser. Because Rust's ownership
5//! model does not allow parent pointers inside the AST, parent information is
6//! tracked externally via [`TreeContext`] (built on demand).
7//!
8//! # Traversal
9//!
10//! Two iterator types are provided:
11//! - [`DfsIter`] -- depth-first (pre-order) traversal using a stack. Visits a node
12//!   before its children. Good for top-down analysis and early termination.
13//! - [`BfsIter`] -- breadth-first (level-order) traversal using a queue. Visits all
14//!   nodes at depth N before any node at depth N+1. Good for level-aware analysis.
15//!
16//! Both are available through the [`ExpressionWalk`] trait methods [`dfs`](ExpressionWalk::dfs)
17//! and [`bfs`](ExpressionWalk::bfs).
18//!
19//! # Searching
20//!
21//! The [`ExpressionWalk`] trait also provides convenience methods for finding expressions:
22//! [`find`](ExpressionWalk::find), [`find_all`](ExpressionWalk::find_all),
23//! [`contains`](ExpressionWalk::contains), and [`count`](ExpressionWalk::count).
24//! Common predicates are available as free functions: [`is_column`], [`is_literal`],
25//! [`is_function`], [`is_aggregate`], [`is_window_function`], [`is_subquery`], and
26//! [`is_select`].
27//!
28//! # Transformation
29//!
30//! The [`transform`] and [`transform_map`] functions perform bottom-up (post-order)
31//! tree rewrites, delegating to [`transform_recursive`](crate::dialects::transform_recursive).
32//! The [`ExpressionWalk::transform_owned`] method provides the same capability as
33//! an owned method on `Expression`.
34//!
35//! Based on traversal patterns from `sqlglot/expressions.py`.
36
37use crate::expressions::Expression;
38use std::collections::{HashMap, VecDeque};
39
/// Unique identifier for an expression node within a [`TreeContext`].
///
/// Ids are assigned sequentially, starting at 0, in the order nodes are first
/// visited while the context is built; a node always receives its id before
/// any of its children do.
pub type NodeId = usize;
42
/// Describes how a node is attached to its parent.
///
/// One `ParentInfo` is stored per node in a [`TreeContext`]. For the root,
/// `parent_id` is `None`, `arg_key` is the empty string and `index` is `None`.
#[derive(Debug, Clone)]
pub struct ParentInfo {
    /// The `NodeId` of the parent (`None` for the root).
    pub parent_id: Option<NodeId>,
    /// Name of the parent field this node occupies (e.g. `"this"`, `"left"`,
    /// `"expressions"`).
    pub arg_key: String,
    /// Position within the parent's list field, or `None` when the node sits in
    /// a single-valued field.
    pub index: Option<usize>,
}
53
54/// External parent-tracking context for an expression tree.
55///
56/// Since Rust's ownership model does not allow intrusive parent pointers in the AST,
57/// `TreeContext` provides an on-demand side-table that maps each node (identified by
58/// a [`NodeId`]) to its [`ParentInfo`] (parent node, field name, and list index).
59///
60/// Build a context from any expression root with [`TreeContext::build`], then query
61/// parent relationships with [`get`](TreeContext::get), ancestry chains with
62/// [`ancestors_of`](TreeContext::ancestors_of), or tree depth with
63/// [`depth_of`](TreeContext::depth_of).
64///
65/// This is useful when analysis requires upward navigation (e.g., determining whether
66/// a column reference appears inside a WHERE clause or a JOIN condition).
67#[derive(Debug, Default)]
68pub struct TreeContext {
69    /// Map from NodeId to parent information
70    nodes: HashMap<NodeId, ParentInfo>,
71    /// Counter for generating NodeIds
72    next_id: NodeId,
73    /// Stack for tracking current path during traversal
74    path: Vec<(NodeId, String, Option<usize>)>,
75}
76
77impl TreeContext {
78    /// Create a new empty tree context
79    pub fn new() -> Self {
80        Self::default()
81    }
82
83    /// Build context from an expression tree
84    pub fn build(root: &Expression) -> Self {
85        let mut ctx = Self::new();
86        ctx.visit_expr(root);
87        ctx
88    }
89
90    /// Visit an expression and record parent information
91    fn visit_expr(&mut self, expr: &Expression) -> NodeId {
92        let id = self.next_id;
93        self.next_id += 1;
94
95        // Record parent info based on current path
96        let parent_info = if let Some((parent_id, arg_key, index)) = self.path.last() {
97            ParentInfo {
98                parent_id: Some(*parent_id),
99                arg_key: arg_key.clone(),
100                index: *index,
101            }
102        } else {
103            ParentInfo {
104                parent_id: None,
105                arg_key: String::new(),
106                index: None,
107            }
108        };
109        self.nodes.insert(id, parent_info);
110
111        // Visit children
112        for (key, child) in iter_children(expr) {
113            self.path.push((id, key.to_string(), None));
114            self.visit_expr(child);
115            self.path.pop();
116        }
117
118        // Visit children in lists
119        for (key, children) in iter_children_lists(expr) {
120            for (idx, child) in children.iter().enumerate() {
121                self.path.push((id, key.to_string(), Some(idx)));
122                self.visit_expr(child);
123                self.path.pop();
124            }
125        }
126
127        id
128    }
129
130    /// Get parent info for a node
131    pub fn get(&self, id: NodeId) -> Option<&ParentInfo> {
132        self.nodes.get(&id)
133    }
134
135    /// Get the depth of a node (0 for root)
136    pub fn depth_of(&self, id: NodeId) -> usize {
137        let mut depth = 0;
138        let mut current = id;
139        while let Some(info) = self.nodes.get(&current) {
140            if let Some(parent_id) = info.parent_id {
141                depth += 1;
142                current = parent_id;
143            } else {
144                break;
145            }
146        }
147        depth
148    }
149
150    /// Get ancestors of a node (parent, grandparent, etc.)
151    pub fn ancestors_of(&self, id: NodeId) -> Vec<NodeId> {
152        let mut ancestors = Vec::new();
153        let mut current = id;
154        while let Some(info) = self.nodes.get(&current) {
155            if let Some(parent_id) = info.parent_id {
156                ancestors.push(parent_id);
157                current = parent_id;
158            } else {
159                break;
160            }
161        }
162        ancestors
163    }
164}
165
/// Collects the single-valued child fields of an expression.
///
/// Returns a `Vec` of `(field_name, &Expression)` pairs in field order. Variants
/// that are not listed below yield an empty vector, i.e. they are treated as
/// leaves by everything built on top of this function (the iterators,
/// [`TreeContext`], [`find_parent`], ...). List-valued fields are reported
/// separately by [`iter_children_lists`].
fn iter_children(expr: &Expression) -> Vec<(&'static str, &Expression)> {
    let mut children = Vec::new();

    match expr {
        // Wrappers and unary-style nodes: a single `this` operand.
        Expression::Alias(a) => {
            children.push(("this", &a.this));
        }
        Expression::Cast(c) => {
            children.push(("this", &c.this));
        }
        Expression::Not(u) | Expression::Neg(u) | Expression::BitwiseNot(u) => {
            children.push(("this", &u.this));
        }
        Expression::Paren(p) => {
            children.push(("this", &p.this));
        }
        Expression::IsNull(i) => {
            children.push(("this", &i.this));
        }
        Expression::Exists(e) => {
            children.push(("this", &e.this));
        }
        Expression::Subquery(s) => {
            children.push(("this", &s.this));
        }
        Expression::Where(w) => {
            children.push(("this", &w.this));
        }
        Expression::Having(h) => {
            children.push(("this", &h.this));
        }
        Expression::Qualify(q) => {
            children.push(("this", &q.this));
        }
        // Binary operators: `left` is reported before `right`.
        Expression::And(op)
        | Expression::Or(op)
        | Expression::Add(op)
        | Expression::Sub(op)
        | Expression::Mul(op)
        | Expression::Div(op)
        | Expression::Mod(op)
        | Expression::Eq(op)
        | Expression::Neq(op)
        | Expression::Lt(op)
        | Expression::Lte(op)
        | Expression::Gt(op)
        | Expression::Gte(op)
        | Expression::BitwiseAnd(op)
        | Expression::BitwiseOr(op)
        | Expression::BitwiseXor(op)
        | Expression::Concat(op) => {
            children.push(("left", &op.left));
            children.push(("right", &op.right));
        }
        Expression::Like(op) | Expression::ILike(op) => {
            children.push(("left", &op.left));
            children.push(("right", &op.right));
        }
        Expression::Between(b) => {
            children.push(("this", &b.this));
            children.push(("low", &b.low));
            children.push(("high", &b.high));
        }
        // Only the tested operand; the value list is reported by
        // `iter_children_lists`.
        Expression::In(i) => {
            children.push(("this", &i.this));
        }
        // Only the optional operand is reported here.
        // NOTE(review): any WHEN/THEN/ELSE branches are not traversed — confirm
        // whether that is intentional.
        Expression::Case(c) => {
            if let Some(ref operand) = &c.operand {
                children.push(("operand", operand));
            }
        }
        Expression::WindowFunction(wf) => {
            children.push(("this", &wf.this));
        }
        // Set operations: both operand queries.
        Expression::Union(u) => {
            children.push(("left", &u.left));
            children.push(("right", &u.right));
        }
        Expression::Intersect(i) => {
            children.push(("left", &i.left));
            children.push(("right", &i.right));
        }
        Expression::Except(e) => {
            children.push(("left", &e.left));
            children.push(("right", &e.right));
        }
        Expression::Ordered(o) => {
            children.push(("this", &o.this));
        }
        Expression::Interval(i) => {
            if let Some(ref this) = i.this {
                children.push(("this", this));
            }
        }
        // Every other variant is treated as having no single-valued children.
        _ => {}
    }

    children
}
268
/// Collects the list-valued child fields of an expression.
///
/// Returns a `Vec` of `(field_name, &[Expression])` pairs. Variants that are not
/// listed below yield an empty vector. Single-valued fields are reported
/// separately by [`iter_children`].
fn iter_children_lists(expr: &Expression) -> Vec<(&'static str, &[Expression])> {
    let mut lists = Vec::new();

    match expr {
        Expression::Select(s) => {
            // Only the projection list is reported for a SELECT.
            lists.push(("expressions", s.expressions.as_slice()));
            // Note: FROM, JOINs, etc. are stored differently
        }
        Expression::Function(f) => {
            lists.push(("args", f.args.as_slice()));
        }
        Expression::AggregateFunction(f) => {
            lists.push(("args", f.args.as_slice()));
        }
        Expression::From(f) => {
            lists.push(("expressions", f.expressions.as_slice()));
        }
        Expression::GroupBy(g) => {
            lists.push(("expressions", g.expressions.as_slice()));
        }
        // OrderBy.expressions is Vec<Ordered>, not Vec<Expression>
        // We handle Ordered items via iter_children
        Expression::In(i) => {
            // The tested operand (`this`) is reported by `iter_children`.
            lists.push(("expressions", i.expressions.as_slice()));
        }
        Expression::Array(a) => {
            lists.push(("expressions", a.expressions.as_slice()));
        }
        Expression::Tuple(t) => {
            lists.push(("expressions", t.expressions.as_slice()));
        }
        // Values.expressions is Vec<Tuple>, handle specially
        Expression::Coalesce(c) => {
            lists.push(("expressions", c.expressions.as_slice()));
        }
        Expression::Greatest(g) | Expression::Least(g) => {
            lists.push(("expressions", g.expressions.as_slice()));
        }
        // Every other variant is treated as having no list-valued children.
        _ => {}
    }

    lists
}
315
316/// Pre-order depth-first iterator over an expression tree.
317///
318/// Visits each node before its children, using a stack-based approach. This means
319/// the root is yielded first, followed by the entire left subtree (recursively),
320/// then the right subtree. For a binary expression `a + b`, the iteration order
321/// is: `Add`, `a`, `b`.
322///
323/// Created via [`ExpressionWalk::dfs`] or [`DfsIter::new`].
324pub struct DfsIter<'a> {
325    stack: Vec<&'a Expression>,
326}
327
328impl<'a> DfsIter<'a> {
329    /// Create a new DFS iterator starting from the given expression
330    pub fn new(root: &'a Expression) -> Self {
331        Self { stack: vec![root] }
332    }
333}
334
335impl<'a> Iterator for DfsIter<'a> {
336    type Item = &'a Expression;
337
338    fn next(&mut self) -> Option<Self::Item> {
339        let expr = self.stack.pop()?;
340
341        // Add children in reverse order so they come out in forward order
342        let children: Vec<_> = iter_children(expr).into_iter().map(|(_, e)| e).collect();
343        for child in children.into_iter().rev() {
344            self.stack.push(child);
345        }
346
347        let lists: Vec<_> = iter_children_lists(expr)
348            .into_iter()
349            .flat_map(|(_, es)| es.iter())
350            .collect();
351        for child in lists.into_iter().rev() {
352            self.stack.push(child);
353        }
354
355        Some(expr)
356    }
357}
358
359/// Level-order breadth-first iterator over an expression tree.
360///
361/// Visits all nodes at depth N before any node at depth N+1, using a queue-based
362/// approach. For a tree `(a + b) = c`, the iteration order is: `Eq` (depth 0),
363/// `Add`, `c` (depth 1), `a`, `b` (depth 2).
364///
365/// Created via [`ExpressionWalk::bfs`] or [`BfsIter::new`].
366pub struct BfsIter<'a> {
367    queue: VecDeque<&'a Expression>,
368}
369
370impl<'a> BfsIter<'a> {
371    /// Create a new BFS iterator starting from the given expression
372    pub fn new(root: &'a Expression) -> Self {
373        let mut queue = VecDeque::new();
374        queue.push_back(root);
375        Self { queue }
376    }
377}
378
379impl<'a> Iterator for BfsIter<'a> {
380    type Item = &'a Expression;
381
382    fn next(&mut self) -> Option<Self::Item> {
383        let expr = self.queue.pop_front()?;
384
385        // Add children to queue
386        for (_, child) in iter_children(expr) {
387            self.queue.push_back(child);
388        }
389
390        for (_, children) in iter_children_lists(expr) {
391            for child in children {
392                self.queue.push_back(child);
393            }
394        }
395
396        Some(expr)
397    }
398}
399
400/// Extension trait that adds traversal and search methods to [`Expression`].
401///
402/// This trait is implemented for `Expression` and provides a fluent API for
403/// iterating, searching, measuring, and transforming expression trees without
404/// needing to import the iterator types directly.
405pub trait ExpressionWalk {
406    /// Returns a depth-first (pre-order) iterator over this expression and all descendants.
407    ///
408    /// The root node is yielded first, then its children are visited recursively
409    /// from left to right.
410    fn dfs(&self) -> DfsIter<'_>;
411
412    /// Returns a breadth-first (level-order) iterator over this expression and all descendants.
413    ///
414    /// All nodes at depth N are yielded before any node at depth N+1.
415    fn bfs(&self) -> BfsIter<'_>;
416
417    /// Finds the first expression matching `predicate` in depth-first order.
418    ///
419    /// Returns `None` if no descendant (including this node) matches.
420    fn find<F>(&self, predicate: F) -> Option<&Expression>
421    where
422        F: Fn(&Expression) -> bool;
423
424    /// Collects all expressions matching `predicate` in depth-first order.
425    ///
426    /// Returns an empty vector if no descendants match.
427    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
428    where
429        F: Fn(&Expression) -> bool;
430
431    /// Returns `true` if this node or any descendant matches `predicate`.
432    fn contains<F>(&self, predicate: F) -> bool
433    where
434        F: Fn(&Expression) -> bool;
435
436    /// Counts how many nodes (including this one) match `predicate`.
437    fn count<F>(&self, predicate: F) -> usize
438    where
439        F: Fn(&Expression) -> bool;
440
441    /// Returns direct child expressions of this node.
442    ///
443    /// Collects all single-child fields and list-child fields into a flat vector
444    /// of references. Leaf nodes return an empty vector.
445    fn children(&self) -> Vec<&Expression>;
446
447    /// Returns the maximum depth of the expression tree rooted at this node.
448    ///
449    /// A leaf node has depth 0, a node whose deepest child is a leaf has depth 1, etc.
450    fn tree_depth(&self) -> usize;
451
452    /// Transforms this expression tree bottom-up using the given function (owned variant).
453    ///
454    /// Children are transformed first, then `fun` is called on the resulting node.
455    /// Return `Ok(None)` from `fun` to replace a node with `NULL`.
456    /// Return `Ok(Some(expr))` to substitute the node with `expr`.
457    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
458    where
459        F: Fn(Expression) -> crate::Result<Option<Expression>>,
460        Self: Sized;
461}
462
463impl ExpressionWalk for Expression {
464    fn dfs(&self) -> DfsIter<'_> {
465        DfsIter::new(self)
466    }
467
468    fn bfs(&self) -> BfsIter<'_> {
469        BfsIter::new(self)
470    }
471
472    fn find<F>(&self, predicate: F) -> Option<&Expression>
473    where
474        F: Fn(&Expression) -> bool,
475    {
476        self.dfs().find(|e| predicate(e))
477    }
478
479    fn find_all<F>(&self, predicate: F) -> Vec<&Expression>
480    where
481        F: Fn(&Expression) -> bool,
482    {
483        self.dfs().filter(|e| predicate(e)).collect()
484    }
485
486    fn contains<F>(&self, predicate: F) -> bool
487    where
488        F: Fn(&Expression) -> bool,
489    {
490        self.dfs().any(|e| predicate(e))
491    }
492
493    fn count<F>(&self, predicate: F) -> usize
494    where
495        F: Fn(&Expression) -> bool,
496    {
497        self.dfs().filter(|e| predicate(e)).count()
498    }
499
500    fn children(&self) -> Vec<&Expression> {
501        let mut result: Vec<&Expression> = Vec::new();
502        for (_, child) in iter_children(self) {
503            result.push(child);
504        }
505        for (_, children_list) in iter_children_lists(self) {
506            for child in children_list {
507                result.push(child);
508            }
509        }
510        result
511    }
512
513    fn tree_depth(&self) -> usize {
514        let mut max_depth = 0;
515
516        for (_, child) in iter_children(self) {
517            let child_depth = child.tree_depth();
518            if child_depth + 1 > max_depth {
519                max_depth = child_depth + 1;
520            }
521        }
522
523        for (_, children) in iter_children_lists(self) {
524            for child in children {
525                let child_depth = child.tree_depth();
526                if child_depth + 1 > max_depth {
527                    max_depth = child_depth + 1;
528                }
529            }
530        }
531
532        max_depth
533    }
534
535    fn transform_owned<F>(self, fun: F) -> crate::Result<Expression>
536    where
537        F: Fn(Expression) -> crate::Result<Option<Expression>>,
538    {
539        transform(self, &fun)
540    }
541}
542
543/// Transforms an expression tree bottom-up, with optional node removal.
544///
545/// Recursively transforms all children first, then applies `fun` to the resulting node.
546/// If `fun` returns `Ok(None)`, the node is replaced with an `Expression::Null`.
547/// If `fun` returns `Ok(Some(expr))`, the node is replaced with `expr`.
548///
549/// This is the primary transformation entry point when callers need the ability to
550/// "delete" nodes by returning `None`.
551///
552/// # Example
553///
554/// ```rust,ignore
555/// use polyglot_sql::traversal::transform;
556///
557/// // Remove all Paren wrapper nodes from a tree
558/// let result = transform(expr, &|e| match e {
559///     Expression::Paren(p) => Ok(Some(p.this)),
560///     other => Ok(Some(other)),
561/// })?;
562/// ```
563pub fn transform<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
564where
565    F: Fn(Expression) -> crate::Result<Option<Expression>>,
566{
567    crate::dialects::transform_recursive(expr, &|e| match fun(e)? {
568        Some(transformed) => Ok(transformed),
569        None => Ok(Expression::Null(crate::expressions::Null)),
570    })
571}
572
/// Transforms an expression tree bottom-up without node removal.
///
/// Like [`transform`], but `fun` returns an `Expression` directly rather than
/// `Option<Expression>`, so nodes cannot be deleted. This is a convenience wrapper
/// for the common case where every node is mapped to exactly one output node.
///
/// This is a direct pass-through to
/// [`transform_recursive`](crate::dialects::transform_recursive); its result,
/// including any error, is returned unchanged.
///
/// # Example
///
/// ```rust,ignore
/// use polyglot_sql::traversal::transform_map;
///
/// // Uppercase all column names in a tree
/// let result = transform_map(expr, &|e| match e {
///     Expression::Column(mut c) => {
///         c.name.name = c.name.name.to_uppercase();
///         Ok(Expression::Column(c))
///     }
///     other => Ok(other),
/// })?;
/// ```
pub fn transform_map<F>(expr: Expression, fun: &F) -> crate::Result<Expression>
where
    F: Fn(Expression) -> crate::Result<Expression>,
{
    crate::dialects::transform_recursive(expr, fun)
}
599
600// ---------------------------------------------------------------------------
601// Common expression predicates
602// ---------------------------------------------------------------------------
603// These free functions are intended for use with the search methods on
604// `ExpressionWalk` (e.g., `expr.find(is_column)`, `expr.contains(is_aggregate)`).
605
/// Returns `true` if `expr` is a column reference ([`Expression::Column`]).
///
/// Only the node itself is inspected; descendants are not searched.
pub fn is_column(expr: &Expression) -> bool {
    matches!(expr, Expression::Column(_))
}
610
611/// Returns `true` if `expr` is a literal value (number, string, boolean, or NULL).
612pub fn is_literal(expr: &Expression) -> bool {
613    matches!(
614        expr,
615        Expression::Literal(_) | Expression::Boolean(_) | Expression::Null(_)
616    )
617}
618
619/// Returns `true` if `expr` is a function call (regular or aggregate).
620pub fn is_function(expr: &Expression) -> bool {
621    matches!(
622        expr,
623        Expression::Function(_) | Expression::AggregateFunction(_)
624    )
625}
626
/// Returns `true` if `expr` is a subquery ([`Expression::Subquery`]).
///
/// Only the node itself is inspected; use [`contains_subquery`] to search a
/// whole tree.
pub fn is_subquery(expr: &Expression) -> bool {
    matches!(expr, Expression::Subquery(_))
}
631
/// Returns `true` if `expr` is a SELECT statement ([`Expression::Select`]).
///
/// Only the node itself is inspected; descendants are not searched.
pub fn is_select(expr: &Expression) -> bool {
    matches!(expr, Expression::Select(_))
}
636
/// Returns `true` if `expr` is an aggregate function ([`Expression::AggregateFunction`]).
///
/// Only the generic `AggregateFunction` variant is matched. Variants such as
/// [`Expression::Count`] or [`Expression::Sum`], which have their own
/// predicates in this module, are not reported by this function.
pub fn is_aggregate(expr: &Expression) -> bool {
    matches!(expr, Expression::AggregateFunction(_))
}
641
/// Returns `true` if `expr` is a window function ([`Expression::WindowFunction`]).
///
/// Only the node itself is inspected; use [`contains_window_function`] to
/// search a whole tree.
pub fn is_window_function(expr: &Expression) -> bool {
    matches!(expr, Expression::WindowFunction(_))
}
646
/// Collects all column references ([`Expression::Column`]) from the expression tree.
///
/// Performs a depth-first search and returns references to every column node found,
/// in the order the search encounters them. `expr` itself is included if it is a
/// column.
pub fn get_columns(expr: &Expression) -> Vec<&Expression> {
    expr.find_all(is_column)
}
653
654/// Collects all table references ([`Expression::Table`]) from the expression tree.
655///
656/// Performs a depth-first search and returns references to every table node found.
657pub fn get_tables(expr: &Expression) -> Vec<&Expression> {
658    expr.find_all(|e| matches!(e, Expression::Table(_)))
659}
660
/// Returns `true` if the expression tree contains any aggregate function calls.
///
/// A node counts as an aggregate when [`is_aggregate`] accepts it; `expr`
/// itself is included in the search.
pub fn contains_aggregate(expr: &Expression) -> bool {
    expr.contains(is_aggregate)
}
665
/// Returns `true` if the expression tree contains any window function calls.
///
/// A node counts as a window function when [`is_window_function`] accepts it;
/// `expr` itself is included in the search.
pub fn contains_window_function(expr: &Expression) -> bool {
    expr.contains(is_window_function)
}
670
/// Returns `true` if the expression tree contains any subquery nodes.
///
/// A node counts as a subquery when [`is_subquery`] accepts it; `expr` itself
/// is included in the search.
pub fn contains_subquery(expr: &Expression) -> bool {
    expr.contains(is_subquery)
}
675
676// ---------------------------------------------------------------------------
677// Extended type predicates
678// ---------------------------------------------------------------------------
679
/// Macro for generating simple type-predicate functions.
///
/// `is_type!(name, Pattern1, Pattern2, ...)` expands to
/// `pub fn name(expr: &Expression) -> bool`, which returns `true` exactly when
/// `expr` matches one of the listed patterns. A trailing comma is accepted.
///
/// The generated function's rustdoc names the pattern(s) it checks, so each
/// predicate is documented individually rather than with one shared sentence.
macro_rules! is_type {
    ($name:ident, $($variant:pat),+ $(,)?) => {
        #[doc = concat!(
            "Returns `true` if `expr` matches `",
            stringify!($($variant)|+),
            "`."
        )]
        pub fn $name(expr: &Expression) -> bool {
            matches!(expr, $($variant)|+)
        }
    };
}
689
// Each invocation below generates `pub fn <name>(expr: &Expression) -> bool`
// that tests `expr` against the single variant named on the same line.

// Query: data-modifying statements and set operations
is_type!(is_insert, Expression::Insert(_));
is_type!(is_update, Expression::Update(_));
is_type!(is_delete, Expression::Delete(_));
is_type!(is_union, Expression::Union(_));
is_type!(is_intersect, Expression::Intersect(_));
is_type!(is_except, Expression::Except(_));

// Identifiers & literals
is_type!(is_boolean, Expression::Boolean(_));
is_type!(is_null_literal, Expression::Null(_));
is_type!(is_star, Expression::Star(_));
is_type!(is_identifier, Expression::Identifier(_));
is_type!(is_table, Expression::Table(_));

// Comparison
is_type!(is_eq, Expression::Eq(_));
is_type!(is_neq, Expression::Neq(_));
is_type!(is_lt, Expression::Lt(_));
is_type!(is_lte, Expression::Lte(_));
is_type!(is_gt, Expression::Gt(_));
is_type!(is_gte, Expression::Gte(_));
is_type!(is_like, Expression::Like(_));
is_type!(is_ilike, Expression::ILike(_));

// Arithmetic (plus string concatenation)
is_type!(is_add, Expression::Add(_));
is_type!(is_sub, Expression::Sub(_));
is_type!(is_mul, Expression::Mul(_));
is_type!(is_div, Expression::Div(_));
is_type!(is_mod, Expression::Mod(_));
is_type!(is_concat, Expression::Concat(_));

// Logical
is_type!(is_and, Expression::And(_));
is_type!(is_or, Expression::Or(_));
is_type!(is_not, Expression::Not(_));

// Predicates
is_type!(is_in, Expression::In(_));
is_type!(is_between, Expression::Between(_));
is_type!(is_is_null, Expression::IsNull(_));
is_type!(is_exists, Expression::Exists(_));

// Functions: variants dedicated to one specific function or construct
is_type!(is_count, Expression::Count(_));
is_type!(is_sum, Expression::Sum(_));
is_type!(is_avg, Expression::Avg(_));
is_type!(is_min_func, Expression::Min(_));
is_type!(is_max_func, Expression::Max(_));
is_type!(is_coalesce, Expression::Coalesce(_));
is_type!(is_null_if, Expression::NullIf(_));
is_type!(is_cast, Expression::Cast(_));
is_type!(is_try_cast, Expression::TryCast(_));
is_type!(is_safe_cast, Expression::SafeCast(_));
is_type!(is_case, Expression::Case(_));

// Clauses and wrapper nodes
is_type!(is_from, Expression::From(_));
is_type!(is_join, Expression::Join(_));
is_type!(is_where, Expression::Where(_));
is_type!(is_group_by, Expression::GroupBy(_));
is_type!(is_having, Expression::Having(_));
is_type!(is_order_by, Expression::OrderBy(_));
is_type!(is_limit, Expression::Limit(_));
is_type!(is_offset, Expression::Offset(_));
is_type!(is_with, Expression::With(_));
is_type!(is_cte, Expression::Cte(_));
is_type!(is_alias, Expression::Alias(_));
is_type!(is_paren, Expression::Paren(_));
is_type!(is_ordered, Expression::Ordered(_));

// DDL (see `is_ddl` for the composite check, which covers more variants)
is_type!(is_create_table, Expression::CreateTable(_));
is_type!(is_drop_table, Expression::DropTable(_));
is_type!(is_alter_table, Expression::AlterTable(_));
is_type!(is_create_index, Expression::CreateIndex(_));
is_type!(is_drop_index, Expression::DropIndex(_));
is_type!(is_create_view, Expression::CreateView(_));
is_type!(is_drop_view, Expression::DropView(_));
770
771// ---------------------------------------------------------------------------
772// Composite predicates
773// ---------------------------------------------------------------------------
774
775/// Returns `true` if `expr` is a query statement (SELECT, INSERT, UPDATE, or DELETE).
776pub fn is_query(expr: &Expression) -> bool {
777    matches!(
778        expr,
779        Expression::Select(_) | Expression::Insert(_) | Expression::Update(_) | Expression::Delete(_)
780    )
781}
782
783/// Returns `true` if `expr` is a set operation (UNION, INTERSECT, or EXCEPT).
784pub fn is_set_operation(expr: &Expression) -> bool {
785    matches!(
786        expr,
787        Expression::Union(_) | Expression::Intersect(_) | Expression::Except(_)
788    )
789}
790
791/// Returns `true` if `expr` is a comparison operator.
792pub fn is_comparison(expr: &Expression) -> bool {
793    matches!(
794        expr,
795        Expression::Eq(_)
796            | Expression::Neq(_)
797            | Expression::Lt(_)
798            | Expression::Lte(_)
799            | Expression::Gt(_)
800            | Expression::Gte(_)
801            | Expression::Like(_)
802            | Expression::ILike(_)
803    )
804}
805
806/// Returns `true` if `expr` is an arithmetic operator.
807pub fn is_arithmetic(expr: &Expression) -> bool {
808    matches!(
809        expr,
810        Expression::Add(_)
811            | Expression::Sub(_)
812            | Expression::Mul(_)
813            | Expression::Div(_)
814            | Expression::Mod(_)
815    )
816}
817
818/// Returns `true` if `expr` is a logical operator (AND, OR, NOT).
819pub fn is_logical(expr: &Expression) -> bool {
820    matches!(
821        expr,
822        Expression::And(_) | Expression::Or(_) | Expression::Not(_)
823    )
824}
825
826/// Returns `true` if `expr` is a DDL statement.
827pub fn is_ddl(expr: &Expression) -> bool {
828    matches!(
829        expr,
830        Expression::CreateTable(_)
831            | Expression::DropTable(_)
832            | Expression::AlterTable(_)
833            | Expression::CreateIndex(_)
834            | Expression::DropIndex(_)
835            | Expression::CreateView(_)
836            | Expression::DropView(_)
837            | Expression::AlterView(_)
838            | Expression::CreateSchema(_)
839            | Expression::DropSchema(_)
840            | Expression::CreateDatabase(_)
841            | Expression::DropDatabase(_)
842            | Expression::CreateFunction(_)
843            | Expression::DropFunction(_)
844            | Expression::CreateProcedure(_)
845            | Expression::DropProcedure(_)
846            | Expression::CreateSequence(_)
847            | Expression::DropSequence(_)
848            | Expression::AlterSequence(_)
849            | Expression::CreateTrigger(_)
850            | Expression::DropTrigger(_)
851            | Expression::CreateType(_)
852            | Expression::DropType(_)
853    )
854}
855
856/// Find the parent of `target` within the tree rooted at `root`.
857///
858/// Uses pointer identity ([`std::ptr::eq`]) — `target` must be a reference
859/// obtained from the same tree (e.g., via [`ExpressionWalk::find`] or DFS iteration).
860///
861/// Returns `None` if `target` is the root itself or is not found in the tree.
862pub fn find_parent<'a>(root: &'a Expression, target: &Expression) -> Option<&'a Expression> {
863    fn search<'a>(node: &'a Expression, target: *const Expression) -> Option<&'a Expression> {
864        for (_, child) in iter_children(node) {
865            if std::ptr::eq(child, target) {
866                return Some(node);
867            }
868            if let Some(found) = search(child, target) {
869                return Some(found);
870            }
871        }
872        for (_, children_list) in iter_children_lists(node) {
873            for child in children_list {
874                if std::ptr::eq(child, target) {
875                    return Some(node);
876                }
877                if let Some(found) = search(child, target) {
878                    return Some(found);
879                }
880            }
881        }
882        None
883    }
884
885    search(root, target as *const Expression)
886}
887
888/// Find the first ancestor of `target` matching `predicate`, walking from
889/// parent toward root.
890///
891/// Uses pointer identity for target lookup. Returns `None` if no ancestor
892/// matches or `target` is not found in the tree.
893pub fn find_ancestor<'a, F>(
894    root: &'a Expression,
895    target: &Expression,
896    predicate: F,
897) -> Option<&'a Expression>
898where
899    F: Fn(&Expression) -> bool,
900{
901    // Build path from root to target
902    fn build_path<'a>(
903        node: &'a Expression,
904        target: *const Expression,
905        path: &mut Vec<&'a Expression>,
906    ) -> bool {
907        if std::ptr::eq(node, target) {
908            return true;
909        }
910        path.push(node);
911        for (_, child) in iter_children(node) {
912            if build_path(child, target, path) {
913                return true;
914            }
915        }
916        for (_, children_list) in iter_children_lists(node) {
917            for child in children_list {
918                if build_path(child, target, path) {
919                    return true;
920                }
921            }
922        }
923        path.pop();
924        false
925    }
926
927    let mut path = Vec::new();
928    if !build_path(root, target as *const Expression, &mut path) {
929        return None;
930    }
931
932    // Walk path in reverse (parent first, then grandparent, etc.)
933    for ancestor in path.iter().rev() {
934        if predicate(ancestor) {
935            return Some(ancestor);
936        }
937    }
938    None
939}
940
#[cfg(test)]
mod tests {
    //! Unit tests for traversal, search, transformation and parent lookup.
    //!
    //! Small trees are assembled by hand with the helpers below; larger ones
    //! come from the parser.

    use super::*;
    use crate::expressions::{BinaryOp, Column, Identifier, Literal};

    /// Builds an unquoted, unqualified column reference named `name`.
    fn make_column(name: &str) -> Expression {
        Expression::Column(Column {
            name: Identifier {
                name: name.to_string(),
                quoted: false,
                trailing_comments: vec![],
            },
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        })
    }

    /// Builds a numeric literal holding the decimal text of `value`.
    fn make_literal(value: i64) -> Expression {
        Expression::Literal(Literal::Number(value.to_string()))
    }

    /// Builds the boxed payload used by binary-operator variants such as
    /// `Expression::Eq` or `Expression::Add`, with no comments attached.
    fn binary(left: Expression, right: Expression) -> Box<BinaryOp> {
        Box::new(BinaryOp {
            left,
            right,
            left_comments: vec![],
            operator_comments: vec![],
            trailing_comments: vec![],
        })
    }

    /// Builds the tree for `a = 1`: an `Eq` over a column and a literal.
    fn col_eq_lit() -> Expression {
        Expression::Eq(binary(make_column("a"), make_literal(1)))
    }

    /// Builds the tree for `b = (a + 1)`: an `Eq` whose right side is an `Add`.
    fn nested_eq_add() -> Expression {
        let inner = Expression::Add(binary(make_column("a"), make_literal(1)));
        Expression::Eq(binary(make_column("b"), inner))
    }

    // -- traversal and search tests --

    #[test]
    fn test_dfs_simple() {
        let expr = col_eq_lit();

        let nodes: Vec<_> = expr.dfs().collect();
        assert_eq!(nodes.len(), 3); // Eq, Column, Literal
        assert!(matches!(nodes[0], Expression::Eq(_)));
        assert!(matches!(nodes[1], Expression::Column(_)));
        assert!(matches!(nodes[2], Expression::Literal(_)));
    }

    #[test]
    fn test_find() {
        let expr = col_eq_lit();

        let column = expr.find(is_column);
        assert!(column.is_some());
        assert!(matches!(column.unwrap(), Expression::Column(_)));

        let literal = expr.find(is_literal);
        assert!(literal.is_some());
        assert!(matches!(literal.unwrap(), Expression::Literal(_)));
    }

    #[test]
    fn test_find_all() {
        let expr = Expression::And(binary(make_column("a"), make_column("b")));

        let columns = expr.find_all(is_column);
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_contains() {
        let expr = col_eq_lit();

        assert!(expr.contains(is_column));
        assert!(expr.contains(is_literal));
        assert!(!expr.contains(is_subquery));
    }

    #[test]
    fn test_count() {
        // a = (b + 1)
        let inner = Expression::Add(binary(make_column("b"), make_literal(1)));
        let expr = Expression::Eq(binary(make_column("a"), inner));

        assert_eq!(expr.count(is_column), 2);
        assert_eq!(expr.count(is_literal), 1);
    }

    #[test]
    fn test_tree_depth() {
        // Single node
        let lit = make_literal(1);
        assert_eq!(lit.tree_depth(), 0);

        // One level
        let expr = Expression::Eq(binary(make_column("a"), lit.clone()));
        assert_eq!(expr.tree_depth(), 1);

        // Two levels
        let inner = Expression::Add(binary(make_column("b"), lit));
        let outer = Expression::Eq(binary(make_column("a"), inner));
        assert_eq!(outer.tree_depth(), 2);
    }

    #[test]
    fn test_tree_context() {
        let expr = col_eq_lit();

        let ctx = TreeContext::build(&expr);

        // Root has no parent
        let root_info = ctx.get(0).unwrap();
        assert!(root_info.parent_id.is_none());

        // Children have root as parent
        let left_info = ctx.get(1).unwrap();
        assert_eq!(left_info.parent_id, Some(0));
        assert_eq!(left_info.arg_key, "left");

        let right_info = ctx.get(2).unwrap();
        assert_eq!(right_info.parent_id, Some(0));
        assert_eq!(right_info.arg_key, "right");
    }

    // -- transform / transform_map tests --

    #[test]
    fn test_transform_rename_columns() {
        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                if c.name.name == "a" {
                    return Ok(Expression::Column(Column {
                        name: Identifier::new("alpha"),
                        table: c.table.clone(),
                        join_mark: false,
                        trailing_comments: vec![],
                    }));
                }
            }
            Ok(e)
        })
        .unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert!(sql.contains("alpha"), "Expected 'alpha' in: {}", sql);
        assert!(sql.contains("b"), "Expected 'b' in: {}", sql);
    }

    #[test]
    fn test_transform_noop() {
        let ast = crate::parser::Parser::parse_sql("SELECT 1 + 2").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr.clone(), &|e| Ok(e)).unwrap();
        let sql1 = crate::generator::Generator::sql(&expr).unwrap();
        let sql2 = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql1, sql2);
    }

    #[test]
    fn test_transform_nested() {
        let ast = crate::parser::Parser::parse_sql("SELECT a + b FROM t").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                return Ok(Expression::Literal(Literal::Number(
                    if c.name.name == "a" { "1" } else { "2" }.to_string(),
                )));
            }
            Ok(e)
        })
        .unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql, "SELECT 1 + 2 FROM t");
    }

    #[test]
    fn test_transform_error() {
        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t").unwrap();
        let expr = ast[0].clone();
        let result = super::transform_map(expr, &|e| {
            if let Expression::Column(ref c) = e {
                if c.name.name == "a" {
                    return Err(crate::error::Error::Parse("test error".to_string()));
                }
            }
            Ok(e)
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_transform_owned_trait() {
        let ast = crate::parser::Parser::parse_sql("SELECT x FROM t").unwrap();
        let expr = ast[0].clone();
        let result = expr.transform_owned(|e| Ok(Some(e))).unwrap();
        let sql = crate::generator::Generator::sql(&result).unwrap();
        assert_eq!(sql, "SELECT x FROM t");
    }

    // -- children() tests --

    #[test]
    fn test_children_leaf() {
        let lit = make_literal(1);
        assert_eq!(lit.children().len(), 0);
    }

    #[test]
    fn test_children_binary_op() {
        let expr = col_eq_lit();
        let children = expr.children();
        assert_eq!(children.len(), 2);
        assert!(matches!(children[0], Expression::Column(_)));
        assert!(matches!(children[1], Expression::Literal(_)));
    }

    #[test]
    fn test_children_select() {
        let ast = crate::parser::Parser::parse_sql("SELECT a, b FROM t").unwrap();
        let expr = &ast[0];
        let children = expr.children();
        // Should include select list items (a, b)
        assert!(children.len() >= 2);
    }

    // -- find_parent() tests --

    #[test]
    fn test_find_parent_binary() {
        let expr = col_eq_lit();

        // Find the column child and get its parent
        let col = expr.find(is_column).unwrap();
        let parent = super::find_parent(&expr, col);
        assert!(parent.is_some());
        assert!(matches!(parent.unwrap(), Expression::Eq(_)));
    }

    #[test]
    fn test_find_parent_nested() {
        let outer = nested_eq_add();

        // Column `a` sits under the inner Add, so Add (not the root Eq) must
        // be reported as its parent.
        let col_a = outer
            .dfs()
            .find(|e| matches!(e, Expression::Column(c) if c.name.name == "a"))
            .unwrap();
        let parent = super::find_parent(&outer, col_a);
        assert!(matches!(parent, Some(Expression::Add(_))));
    }

    #[test]
    fn test_find_parent_root_has_none() {
        let lit = make_literal(1);
        let parent = super::find_parent(&lit, &lit);
        assert!(parent.is_none());
    }

    #[test]
    fn test_find_parent_target_outside_tree() {
        let expr = col_eq_lit();

        // Structurally equal to the tree's column, but a distinct allocation:
        // lookup is by pointer identity, so it must not be found.
        let outsider = make_column("a");
        assert!(super::find_parent(&expr, &outsider).is_none());
    }

    // -- find_ancestor() tests --

    #[test]
    fn test_find_ancestor_select() {
        // The WHERE column gets a name that appears nowhere else in the query,
        // so the lookup below cannot match the projection column instead.
        let ast = crate::parser::Parser::parse_sql("SELECT a FROM t WHERE b > 1").unwrap();
        let expr = &ast[0];

        // Find the column inside the WHERE clause
        let where_col = expr
            .dfs()
            .find(|e| matches!(e, Expression::Column(c) if c.name.name == "b"));
        assert!(where_col.is_some());

        // Find Select ancestor of that column
        let ancestor = super::find_ancestor(expr, where_col.unwrap(), is_select);
        assert!(ancestor.is_some());
        assert!(matches!(ancestor.unwrap(), Expression::Select(_)));
    }

    #[test]
    fn test_find_ancestor_no_match() {
        let expr = col_eq_lit();

        let col = expr.find(is_column).unwrap();
        let ancestor = super::find_ancestor(&expr, col, is_select);
        assert!(ancestor.is_none());
    }

    #[test]
    fn test_find_ancestor_nearest_first() {
        let outer = nested_eq_add();

        let col_a = outer
            .dfs()
            .find(|e| matches!(e, Expression::Column(c) if c.name.name == "a"))
            .unwrap();

        // Both Add and Eq are ancestors; an always-true predicate must pick
        // the closest one.
        let nearest = super::find_ancestor(&outer, col_a, |_| true);
        assert!(matches!(nearest, Some(Expression::Add(_))));
    }

    #[test]
    fn test_find_ancestor_root_target_has_none() {
        // The root has no ancestors, so even an always-true predicate finds nothing.
        let lit = make_literal(1);
        assert!(super::find_ancestor(&lit, &lit, |_| true).is_none());
    }

    // -- TreeContext ancestor tests --

    #[test]
    fn test_ancestors() {
        let outer = nested_eq_add();

        let ctx = TreeContext::build(&outer);

        // The inner Add's left child (column "a") should have ancestors
        // Node 0: Eq
        // Node 1: Column "b" (left of Eq)
        // Node 2: Add (right of Eq)
        // Node 3: Column "a" (left of Add)
        // Node 4: Literal (right of Add)

        let ancestors = ctx.ancestors_of(3);
        assert_eq!(ancestors, vec![2, 0]); // Add, then Eq
    }
}