Skip to main content

datafusion_common/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`TreeNode`] for visiting and rewriting expression and plan trees
19
20use crate::Result;
21use std::collections::HashMap;
22use std::hash::Hash;
23use std::sync::Arc;
24
/// Drives the three phases shared by the combined (top-down + bottom-up)
/// transforming traversals, `TreeNode::rewrite` and
/// `TreeNode::transform_down_up`:
///
/// 1. `$f_down` is evaluated and any error is propagated with `?`;
/// 2. the children are rewritten with `$f_child` through
///    `transform_children` + `map_children`; `transform_children` decides,
///    from the current `TreeNodeRecursion`, whether `$f_child` runs at all;
/// 3. `$f_up` is handed to `transform_parent`, which again consults the
///    recursion state before applying it.
///
/// The whole expansion evaluates to the `Result` produced by
/// `transform_parent`.
macro_rules! handle_transform_recursion {
    ($f_down:expr, $f_child:expr, $f_up:expr) => {{
        let pre_order = $f_down?;
        let with_children =
            pre_order.transform_children(|node| node.map_children($f_child))?;
        with_children.transform_parent($f_up)
    }};
}
33
34/// API for inspecting and rewriting tree data structures.
35///
36/// The `TreeNode` API is used to express algorithms separately from traversing
37/// the structure of `TreeNode`s, avoiding substantial code duplication.
38///
39/// This trait is implemented for plans ([`ExecutionPlan`], [`LogicalPlan`]) and
40/// expression trees ([`PhysicalExpr`], [`Expr`]) as well as Plan+Payload
41/// combinations [`PlanContext`] and [`ExprContext`].
42///
43/// # Overview
44/// There are three categories of TreeNode APIs:
45///
46/// 1. "Inspecting" APIs to traverse a tree of `&TreeNodes`:
47///    [`apply`], [`visit`], [`exists`].
48///
49/// 2. "Transforming" APIs that traverse and consume a tree of `TreeNode`s
50///    producing possibly changed `TreeNode`s: [`transform`], [`transform_up`],
51///    [`transform_down`], [`transform_down_up`], and [`rewrite`].
52///
53/// 3. Internal APIs used to implement the `TreeNode` API: [`apply_children`],
54///    and [`map_children`].
55///
56/// | Traversal Order | Inspecting | Transforming |
57/// | --- | --- | --- |
58/// | top-down | [`apply`], [`exists`] | [`transform_down`]|
59/// | bottom-up | | [`transform`] , [`transform_up`]|
60/// | combined with separate `f_down` and `f_up` closures | | [`transform_down_up`] |
61/// | combined with `f_down()` and `f_up()` in an object | [`visit`]  | [`rewrite`] |
62///
63/// **Note**:while there is currently no in-place mutation API that uses `&mut
64/// TreeNode`, the transforming APIs are efficient and optimized to avoid
65/// cloning.
66///
67/// [`apply`]: Self::apply
68/// [`visit`]: Self::visit
69/// [`exists`]: Self::exists
70/// [`transform`]: Self::transform
71/// [`transform_up`]: Self::transform_up
72/// [`transform_down`]: Self::transform_down
73/// [`transform_down_up`]: Self::transform_down_up
74/// [`rewrite`]: Self::rewrite
75/// [`apply_children`]: Self::apply_children
76/// [`map_children`]: Self::map_children
77///
78/// # Terminology
79/// The following terms are used in this trait
80///
81/// * `f_down`: Invoked before any children of the current node are visited.
82/// * `f_up`: Invoked after all children of the current node are visited.
83/// * `f`: closure that is applied to the current node.
84/// * `map_*`: applies a transformation to rewrite owned nodes
85/// * `apply_*`:  invokes a function on borrowed nodes
86/// * `transform_`: applies a transformation to rewrite owned nodes
87///
88/// <!-- Since these are in the datafusion-common crate, can't use intra doc links) -->
89/// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html
90/// [`PhysicalExpr`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.PhysicalExpr.html
91/// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html
92/// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html
93/// [`PlanContext`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/tree_node/struct.PlanContext.html
94/// [`ExprContext`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/tree_node/struct.ExprContext.html
95pub trait TreeNode: Sized {
96    /// Visit the tree node with a [`TreeNodeVisitor`], performing a
97    /// depth-first walk of the node and its children.
98    ///
99    /// [`TreeNodeVisitor::f_down()`] is called in top-down order (before
100    /// children are visited), [`TreeNodeVisitor::f_up()`] is called in
101    /// bottom-up order (after children are visited).
102    ///
103    /// # Return Value
104    /// Specifies how the tree walk ended. See [`TreeNodeRecursion`] for details.
105    ///
106    /// # See Also:
107    /// * [`Self::apply`] for inspecting nodes with a closure
108    /// * [`Self::rewrite`] to rewrite owned `TreeNode`s
109    ///
110    /// # Example
111    /// Consider the following tree structure:
112    /// ```text
113    /// ParentNode
114    ///    left: ChildNode1
115    ///    right: ChildNode2
116    /// ```
117    ///
118    /// Here, the nodes would be visited using the following order:
119    /// ```text
120    /// TreeNodeVisitor::f_down(ParentNode)
121    /// TreeNodeVisitor::f_down(ChildNode1)
122    /// TreeNodeVisitor::f_up(ChildNode1)
123    /// TreeNodeVisitor::f_down(ChildNode2)
124    /// TreeNodeVisitor::f_up(ChildNode2)
125    /// TreeNodeVisitor::f_up(ParentNode)
126    /// ```
127    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
128    fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
129        &'n self,
130        visitor: &mut V,
131    ) -> Result<TreeNodeRecursion> {
132        visitor
133            .f_down(self)?
134            .visit_children(|| self.apply_children(|c| c.visit(visitor)))?
135            .visit_parent(|| visitor.f_up(self))
136    }
137
138    /// Rewrite the tree node with a [`TreeNodeRewriter`], performing a
139    /// depth-first walk of the node and its children.
140    ///
141    /// [`TreeNodeRewriter::f_down()`] is called in top-down order (before
142    /// children are visited), [`TreeNodeRewriter::f_up()`] is called in
143    /// bottom-up order (after children are visited).
144    ///
145    /// Note: If using the default [`TreeNodeRewriter::f_up`] or
146    /// [`TreeNodeRewriter::f_down`] that do nothing, consider using
147    /// [`Self::transform_down`] instead.
148    ///
149    /// # Return Value
150    /// The returns value specifies how the tree walk should proceed. See
151    /// [`TreeNodeRecursion`] for details. If an [`Err`] is returned, the
152    /// recursion stops immediately.
153    ///
154    /// # See Also
155    /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
156    /// * [Self::transform_down_up] for a top-down (pre-order) traversal.
157    /// * [Self::transform_down] for a top-down (pre-order) traversal.
158    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
159    ///
160    /// # Example
161    /// Consider the following tree structure:
162    /// ```text
163    /// ParentNode
164    ///    left: ChildNode1
165    ///    right: ChildNode2
166    /// ```
167    ///
168    /// Here, the nodes would be visited using the following order:
169    /// ```text
170    /// TreeNodeRewriter::f_down(ParentNode)
171    /// TreeNodeRewriter::f_down(ChildNode1)
172    /// TreeNodeRewriter::f_up(ChildNode1)
173    /// TreeNodeRewriter::f_down(ChildNode2)
174    /// TreeNodeRewriter::f_up(ChildNode2)
175    /// TreeNodeRewriter::f_up(ParentNode)
176    /// ```
177    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
178    fn rewrite<R: TreeNodeRewriter<Node = Self>>(
179        self,
180        rewriter: &mut R,
181    ) -> Result<Transformed<Self>> {
182        handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| {
183            rewriter.f_up(n)
184        })
185    }
186
187    /// Applies `f` to the node then each of its children, recursively (a
188    /// top-down, pre-order traversal).
189    ///
190    /// The return [`TreeNodeRecursion`] controls the recursion and can cause
191    /// an early return.
192    ///
193    /// # See Also
194    /// * [`Self::transform_down`] for the equivalent transformation API.
195    /// * [`Self::visit`] for both top-down and bottom up traversal.
196    fn apply<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
197        &'n self,
198        mut f: F,
199    ) -> Result<TreeNodeRecursion> {
200        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
201        fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
202            node: &'n N,
203            f: &mut F,
204        ) -> Result<TreeNodeRecursion> {
205            f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
206        }
207
208        apply_impl(self, &mut f)
209    }
210
211    /// Recursively rewrite the node's children and then the node using `f`
212    /// (a bottom-up post-order traversal).
213    ///
214    /// A synonym of [`Self::transform_up`].
215    fn transform<F: FnMut(Self) -> Result<Transformed<Self>>>(
216        self,
217        f: F,
218    ) -> Result<Transformed<Self>> {
219        self.transform_up(f)
220    }
221
222    /// Recursively rewrite the tree using `f` in a top-down (pre-order)
223    /// fashion.
224    ///
225    /// `f` is applied to the node first, and then its children.
226    ///
227    /// # See Also
228    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
229    /// * [Self::transform_down_up] for a combined traversal with closures
230    /// * [`Self::rewrite`] for a combined traversal with a visitor
231    fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
232        self,
233        mut f: F,
234    ) -> Result<Transformed<Self>> {
235        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
236        fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
237            node: N,
238            f: &mut F,
239        ) -> Result<Transformed<N>> {
240            f(node)?.transform_children(|n| n.map_children(|c| transform_down_impl(c, f)))
241        }
242
243        transform_down_impl(self, &mut f)
244    }
245
246    /// Recursively rewrite the node using `f` in a bottom-up (post-order)
247    /// fashion.
248    ///
249    /// `f` is applied to the node's  children first, and then to the node itself.
250    ///
251    /// # See Also
252    /// * [`Self::transform_down`] top-down (pre-order) traversal.
253    /// * [Self::transform_down_up] for a combined traversal with closures
254    /// * [`Self::rewrite`] for a combined traversal with a visitor
255    fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
256        self,
257        mut f: F,
258    ) -> Result<Transformed<Self>> {
259        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
260        fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
261            node: N,
262            f: &mut F,
263        ) -> Result<Transformed<N>> {
264            node.map_children(|c| transform_up_impl(c, f))?
265                .transform_parent(f)
266        }
267
268        transform_up_impl(self, &mut f)
269    }
270
271    /// Transforms the node using `f_down` while traversing the tree top-down
272    /// (pre-order), and using `f_up` while traversing the tree bottom-up
273    /// (post-order).
274    ///
275    /// The method behaves the same as calling [`Self::transform_down`] followed
276    /// by [`Self::transform_up`] on the same node. Use this method if you want
277    /// to start the `f_up` process right where `f_down` jumps. This can make
278    /// the whole process faster by reducing the number of `f_up` steps.
279    ///
280    /// # See Also
281    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
282    /// * [Self::transform_down] for a top-down (pre-order) traversal.
283    /// * [`Self::rewrite`] for a combined traversal with a visitor
284    ///
285    /// # Example
286    /// Consider the following tree structure:
287    /// ```text
288    /// ParentNode
289    ///    left: ChildNode1
290    ///    right: ChildNode2
291    /// ```
292    ///
293    /// The nodes are visited using the following order:
294    /// ```text
295    /// f_down(ParentNode)
296    /// f_down(ChildNode1)
297    /// f_up(ChildNode1)
298    /// f_down(ChildNode2)
299    /// f_up(ChildNode2)
300    /// f_up(ParentNode)
301    /// ```
302    ///
303    /// See [`TreeNodeRecursion`] for more details on controlling the traversal.
304    ///
305    /// If `f_down` or `f_up` returns [`Err`], the recursion stops immediately.
306    ///
307    /// Example:
308    /// ```text
309    ///                                               |   +---+
310    ///                                               |   | J |
311    ///                                               |   +---+
312    ///                                               |     |
313    ///                                               |   +---+
314    ///                  TreeNodeRecursion::Continue  |   | I |
315    ///                                               |   +---+
316    ///                                               |     |
317    ///                                               |   +---+
318    ///                                              \|/  | F |
319    ///                                               '   +---+
320    ///                                                  /     \ ___________________
321    ///                  When `f_down` is           +---+                           \ ---+
322    ///                  applied on node "E",       | E |                            | G |
323    ///                  it returns with "Jump".    +---+                            +---+
324    ///                                               |                                |
325    ///                                             +---+                            +---+
326    ///                                             | C |                            | H |
327    ///                                             +---+                            +---+
328    ///                                             /   \
329    ///                                        +---+     +---+
330    ///                                        | B |     | D |
331    ///                                        +---+     +---+
332    ///                                                    |
333    ///                                                  +---+
334    ///                                                  | A |
335    ///                                                  +---+
336    ///
337    /// Instead of starting from leaf nodes, `f_up` starts from the node "E".
338    ///                                                   +---+
339    ///                                               |   | J |
340    ///                                               |   +---+
341    ///                                               |     |
342    ///                                               |   +---+
343    ///                                               |   | I |
344    ///                                               |   +---+
345    ///                                               |     |
346    ///                                              /    +---+
347    ///                                            /      | F |
348    ///                                          /        +---+
349    ///                                        /         /     \ ______________________
350    ///                                       |     +---+   .                          \ ---+
351    ///                                       |     | E |  /|\  After `f_down` jumps    | G |
352    ///                                       |     +---+   |   on node E, `f_up`       +---+
353    ///                                        \------| ---/   if applied on node E.      |
354    ///                                             +---+                               +---+
355    ///                                             | C |                               | H |
356    ///                                             +---+                               +---+
357    ///                                             /   \
358    ///                                        +---+     +---+
359    ///                                        | B |     | D |
360    ///                                        +---+     +---+
361    ///                                                    |
362    ///                                                  +---+
363    ///                                                  | A |
364    ///                                                  +---+
365    /// ```
366    fn transform_down_up<
367        FD: FnMut(Self) -> Result<Transformed<Self>>,
368        FU: FnMut(Self) -> Result<Transformed<Self>>,
369    >(
370        self,
371        mut f_down: FD,
372        mut f_up: FU,
373    ) -> Result<Transformed<Self>> {
374        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
375        fn transform_down_up_impl<
376            N: TreeNode,
377            FD: FnMut(N) -> Result<Transformed<N>>,
378            FU: FnMut(N) -> Result<Transformed<N>>,
379        >(
380            node: N,
381            f_down: &mut FD,
382            f_up: &mut FU,
383        ) -> Result<Transformed<N>> {
384            handle_transform_recursion!(
385                f_down(node),
386                |c| transform_down_up_impl(c, f_down, f_up),
387                f_up
388            )
389        }
390
391        transform_down_up_impl(self, &mut f_down, &mut f_up)
392    }
393
394    /// Returns true if `f` returns true for any node in the tree.
395    ///
396    /// Stops recursion as soon as a matching node is found
397    fn exists<F: FnMut(&Self) -> Result<bool>>(&self, mut f: F) -> Result<bool> {
398        let mut found = false;
399        self.apply(|n| {
400            Ok(if f(n)? {
401                found = true;
402                TreeNodeRecursion::Stop
403            } else {
404                TreeNodeRecursion::Continue
405            })
406        })
407        .map(|_| found)
408    }
409
410    /// Low-level API used to implement other APIs.
411    ///
412    /// If you want to implement the [`TreeNode`] trait for your own type, you
413    /// should implement this method and [`Self::map_children`].
414    ///
415    /// Users should use one of the higher level APIs described on [`Self`].
416    ///
417    /// Description: Apply `f` to inspect node's children (but not the node
418    /// itself).
419    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
420        &'n self,
421        f: F,
422    ) -> Result<TreeNodeRecursion>;
423
424    /// Low-level API used to implement other APIs.
425    ///
426    /// If you want to implement the [`TreeNode`] trait for your own type, you
427    /// should implement this method and [`Self::apply_children`].
428    ///
429    /// Users should use one of the higher level APIs described on [`Self`].
430    ///
431    /// Description: Apply `f` to rewrite the node's children (but not the node itself).
432    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
433        self,
434        f: F,
435    ) -> Result<Transformed<Self>>;
436}
437
438/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
439/// inspecting [`TreeNode`]s via [`TreeNode::visit`].
440///
441/// See [`TreeNode`] for more details on available APIs
442///
443/// When passed to [`TreeNode::visit`], [`TreeNodeVisitor::f_down`] and
444/// [`TreeNodeVisitor::f_up`] are invoked recursively on the tree.
445/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
446///
447/// # Return Value
448/// The returns value of `f_up` and `f_down` specifies how the tree walk should
449/// proceed. See [`TreeNodeRecursion`] for details. If an [`Err`] is returned,
450/// the recursion stops immediately.
451///
452/// Note: If using the default implementations of [`TreeNodeVisitor::f_up`] or
453/// [`TreeNodeVisitor::f_down`] that do nothing, consider using
454/// [`TreeNode::apply`] instead.
455///
456/// # See Also:
457/// * [`TreeNode::rewrite`] to rewrite owned `TreeNode`s
458pub trait TreeNodeVisitor<'n>: Sized {
459    /// The node type which is visitable.
460    type Node: TreeNode;
461
462    /// Invoked while traversing down the tree, before any children are visited.
463    /// Default implementation continues the recursion.
464    fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
465        Ok(TreeNodeRecursion::Continue)
466    }
467
468    /// Invoked while traversing up the tree after children are visited. Default
469    /// implementation continues the recursion.
470    fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
471        Ok(TreeNodeRecursion::Continue)
472    }
473}
474
475/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
476/// rewriting [`TreeNode`]s via [`TreeNode::rewrite`].
477///
478/// For example you can implement this trait on a struct to rewrite `Expr` or
479/// `LogicalPlan` that needs to track state during the rewrite.
480///
481/// See [`TreeNode`] for more details on available APIs
482///
483/// When passed to [`TreeNode::rewrite`], [`TreeNodeRewriter::f_down`] and
484/// [`TreeNodeRewriter::f_up`] are invoked recursively on the tree.
485/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
486///
487/// # Return Value
488/// The returns value of `f_up` and `f_down` specifies how the tree walk should
489/// proceed. See [`TreeNodeRecursion`] for details. If an [`Err`] is returned,
490/// the recursion stops immediately.
491///
492/// Note: If using the default implementations of [`TreeNodeRewriter::f_up`] or
493/// [`TreeNodeRewriter::f_down`] that do nothing, consider using
494/// [`TreeNode::transform_up`] or [`TreeNode::transform_down`] instead.
495///
496/// # See Also:
497/// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s
498pub trait TreeNodeRewriter: Sized {
499    /// The node type which is rewritable.
500    type Node: TreeNode;
501
502    /// Invoked while traversing down the tree before any children are rewritten.
503    /// Default implementation returns the node as is and continues recursion.
504    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
505        Ok(Transformed::no(node))
506    }
507
508    /// Invoked while traversing up the tree after all children have been rewritten.
509    /// Default implementation returns the node as is and continues recursion.
510    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
511        Ok(Transformed::no(node))
512    }
513}
514
/// Tells a [`TreeNode`] traversal what to do once the current node has been
/// handled.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TreeNodeRecursion {
    /// Carry on with the traversal as usual and move to the next node.
    Continue,
    /// Skip part of the tree; what gets skipped depends on the traversal:
    ///
    /// * Top-down traversals: do not descend into the current node's children
    ///   (the subtree is pruned), but keep going with the next node.
    /// * Bottom-up traversals: skip calling the bottom-up closures until the
    ///   next leaf node.
    /// * Combined traversals: when returned from `f_down` (pre-order), the
    ///   node's children are skipped and execution "jumps" straight to the
    ///   `f_up` (post-order) phase. When returned from `f_up`, execution
    ///   "jumps" to the next `f_down` (pre-order) phase, skipping parent nodes
    ///   until the first parent that still has unvisited children.
    Jump,
    /// Halt the traversal entirely.
    Stop,
}
535
536impl TreeNodeRecursion {
537    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
538    /// value and the fact that `f` is visiting the current node's children.
539    pub fn visit_children<F: FnOnce() -> Result<TreeNodeRecursion>>(
540        self,
541        f: F,
542    ) -> Result<TreeNodeRecursion> {
543        match self {
544            TreeNodeRecursion::Continue => f(),
545            TreeNodeRecursion::Jump => Ok(TreeNodeRecursion::Continue),
546            TreeNodeRecursion::Stop => Ok(self),
547        }
548    }
549
550    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
551    /// value and the fact that `f` is visiting the current node's sibling.
552    pub fn visit_sibling<F: FnOnce() -> Result<TreeNodeRecursion>>(
553        self,
554        f: F,
555    ) -> Result<TreeNodeRecursion> {
556        match self {
557            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => f(),
558            TreeNodeRecursion::Stop => Ok(self),
559        }
560    }
561
562    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
563    /// value and the fact that `f` is visiting the current node's parent.
564    pub fn visit_parent<F: FnOnce() -> Result<TreeNodeRecursion>>(
565        self,
566        f: F,
567    ) -> Result<TreeNodeRecursion> {
568        match self {
569            TreeNodeRecursion::Continue => f(),
570            TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self),
571        }
572    }
573}
574
575/// Result of tree walk / transformation APIs
576///
577/// `Transformed` is a wrapper around the tree node data (e.g. `Expr` or
578/// `LogicalPlan`). It is used to indicate whether the node was transformed
579/// and how the recursion should proceed.
580///
581/// [`TreeNode`] API users control the transformation by returning:
582/// - The resulting (possibly transformed) node,
583/// - `transformed`: flag indicating whether any change was made to the node
584/// - `tnr`: [`TreeNodeRecursion`] specifying how to proceed with the recursion.
585///
586/// At the end of the transformation, the return value will contain:
587/// - The final (possibly transformed) tree,
588/// - `transformed`: flag indicating whether any change was made to the node
589/// - `tnr`: [`TreeNodeRecursion`] specifying how the recursion ended.
590///
591/// See also
592/// * [`Transformed::update_data`] to modify the node without changing the `transformed` flag
593/// * [`Transformed::map_data`] for fallable operation that return the same type
594/// * [`Transformed::transform_data`] to chain fallable transformations
595/// * [`TransformedResult`] for working with `Result<Transformed<U>>`
596///
597/// # Examples
598///
599/// Use [`Transformed::yes`] and [`Transformed::no`] to signal that a node was
600/// rewritten and the recursion should continue:
601///
602/// ```
603/// # use datafusion_common::tree_node::Transformed;
604/// # // note use i64 instead of Expr as Expr is not in datafusion-common
605/// # fn orig_expr() -> i64 { 1 }
606/// # fn make_new_expr(i: i64) -> i64 { 2 }
607/// let expr = orig_expr();
608///
609/// // Create a new `Transformed` object signaling the node was not rewritten
610/// let ret = Transformed::no(expr.clone());
611/// assert!(!ret.transformed);
612///
613/// // Create a new `Transformed` object signaling the node was rewritten
614/// let ret = Transformed::yes(expr);
615/// assert!(ret.transformed)
616/// ```
617///
618/// Access the node within the `Transformed` object:
619/// ```
620/// # use datafusion_common::tree_node::Transformed;
621/// # // note use i64 instead of Expr as Expr is not in datafusion-common
622/// # fn orig_expr() -> i64 { 1 }
623/// # fn make_new_expr(i: i64) -> i64 { 2 }
624/// let expr = orig_expr();
625///
626/// // `Transformed` object signaling the node was not rewritten
627/// let ret = Transformed::no(expr.clone());
628/// // Access the inner object using .data
629/// assert_eq!(expr, ret.data);
630/// ```
631///
632/// Transform the node within the `Transformed` object.
633///
634/// ```
635/// # use datafusion_common::tree_node::Transformed;
636/// # // note use i64 instead of Expr as Expr is not in datafusion-common
637/// # fn orig_expr() -> i64 { 1 }
638/// # fn make_new_expr(i: i64) -> i64 { 2 }
639/// let expr = orig_expr();
640/// let ret = Transformed::no(expr.clone())
641///     .transform_data(|expr| {
642///         // closure returns a result and potentially transforms the node
643///         // in this example, it does transform the node
644///         let new_expr = make_new_expr(expr);
645///         Ok(Transformed::yes(new_expr))
646///     })
647///     .unwrap();
648/// // transformed flag is the union of the original ans closure's  transformed flag
649/// assert!(ret.transformed);
650/// ```
651/// # Example APIs that use `TreeNode`
652/// - [`TreeNode`],
653/// - [`TreeNode::rewrite`],
654/// - [`TreeNode::transform_down`],
655/// - [`TreeNode::transform_up`],
656/// - [`TreeNode::transform_down_up`]
657#[derive(PartialEq, Debug)]
658pub struct Transformed<T> {
659    pub data: T,
660    pub transformed: bool,
661    pub tnr: TreeNodeRecursion,
662}
663
664impl<T> Transformed<T> {
665    /// Create a new `Transformed` object with the given information.
666    pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self {
667        Self {
668            data,
669            transformed,
670            tnr,
671        }
672    }
673
674    /// Create a `Transformed` with `transformed` and [`TreeNodeRecursion::Continue`].
675    pub fn new_transformed(data: T, transformed: bool) -> Self {
676        Self::new(data, transformed, TreeNodeRecursion::Continue)
677    }
678
679    /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement.
680    pub fn yes(data: T) -> Self {
681        Self::new(data, true, TreeNodeRecursion::Continue)
682    }
683
684    /// Wrapper for transformed data with [`TreeNodeRecursion::Stop`] statement.
685    pub fn complete(data: T) -> Self {
686        Self::new(data, true, TreeNodeRecursion::Stop)
687    }
688
689    /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement.
690    pub fn no(data: T) -> Self {
691        Self::new(data, false, TreeNodeRecursion::Continue)
692    }
693
694    /// Applies an infallible `f` to the data of this [`Transformed`] object,
695    /// without modifying the `transformed` flag.
696    pub fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
697        Transformed::new(f(self.data), self.transformed, self.tnr)
698    }
699
700    /// Applies a fallible `f` (returns `Result`) to the data of this
701    /// [`Transformed`] object, without modifying the `transformed` flag.
702    pub fn map_data<U, F: FnOnce(T) -> Result<U>>(self, f: F) -> Result<Transformed<U>> {
703        f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
704    }
705
706    /// Applies a fallible transforming `f` to the data of this [`Transformed`]
707    /// object.
708    ///
709    /// The returned `Transformed` object has the `transformed` flag set if either
710    /// `self` or the return value of `f` have the `transformed` flag set.
711    pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>(
712        self,
713        f: F,
714    ) -> Result<Transformed<U>> {
715        f(self.data).map(|mut t| {
716            t.transformed |= self.transformed;
717            t
718        })
719    }
720
721    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
722    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
723    /// node's children.
724    pub fn transform_children<F: FnOnce(T) -> Result<Transformed<T>>>(
725        mut self,
726        f: F,
727    ) -> Result<Transformed<T>> {
728        match self.tnr {
729            TreeNodeRecursion::Continue => {
730                return f(self.data).map(|mut t| {
731                    t.transformed |= self.transformed;
732                    t
733                });
734            }
735            TreeNodeRecursion::Jump => {
736                self.tnr = TreeNodeRecursion::Continue;
737            }
738            TreeNodeRecursion::Stop => {}
739        }
740        Ok(self)
741    }
742
743    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
744    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
745    /// node's sibling.
746    pub fn transform_sibling<F: FnOnce(T) -> Result<Transformed<T>>>(
747        self,
748        f: F,
749    ) -> Result<Transformed<T>> {
750        match self.tnr {
751            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
752                f(self.data).map(|mut t| {
753                    t.transformed |= self.transformed;
754                    t
755                })
756            }
757            TreeNodeRecursion::Stop => Ok(self),
758        }
759    }
760
761    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
762    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
763    /// node's parent.
764    pub fn transform_parent<F: FnOnce(T) -> Result<Transformed<T>>>(
765        self,
766        f: F,
767    ) -> Result<Transformed<T>> {
768        match self.tnr {
769            TreeNodeRecursion::Continue => f(self.data).map(|mut t| {
770                t.transformed |= self.transformed;
771                t
772            }),
773            TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self),
774        }
775    }
776}
777
778/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped.
779/// The elements of the container are siblings so the continuation rules are similar to
780/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`].
781pub trait TreeNodeContainer<'a, T: 'a>: Sized {
782    /// Applies `f` to all elements of the container.
783    /// This method is usually called from [`TreeNode::apply_children`] implementations as
784    /// a node is actually a container of the node's children.
785    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
786        &'a self,
787        f: F,
788    ) -> Result<TreeNodeRecursion>;
789
790    /// Maps all elements of the container with `f`.
791    /// This method is usually called from [`TreeNode::map_children`] implementations as
792    /// a node is actually a container of the node's children.
793    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
794        self,
795        f: F,
796    ) -> Result<Transformed<Self>>;
797}
798
799impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Default> TreeNodeContainer<'a, T>
800    for Box<C>
801{
802    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
803        &'a self,
804        f: F,
805    ) -> Result<TreeNodeRecursion> {
806        self.as_ref().apply_elements(f)
807    }
808
809    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
810        mut self,
811        f: F,
812    ) -> Result<Transformed<Self>> {
813        // Rewrite in place so the existing heap allocation can be reused.
814        // `mem::take` hands the inner `C` to `f` while leaving
815        // `C::default()` in the slot, so an unwinding drop finds a valid
816        // `C` even if `f` panics or the `?` short-circuits.
817        let inner = std::mem::take(&mut *self);
818        Ok(inner.map_elements(f)?.update_data(|c| {
819            *self = c;
820            self
821        }))
822    }
823}
824
825impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone + Default> TreeNodeContainer<'a, T>
826    for Arc<C>
827{
828    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
829        &'a self,
830        f: F,
831    ) -> Result<TreeNodeRecursion> {
832        self.as_ref().apply_elements(f)
833    }
834
835    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
836        mut self,
837        f: F,
838    ) -> Result<Transformed<Self>> {
839        // Rewrite in place using the same `mem::take` strategy as
840        // `Box<C>::map_elements`. `Arc::make_mut` gives us exclusive
841        // access (cloning `C` first if we were sharing), after which
842        // `get_mut` is infallible.
843        let inner = std::mem::take(Arc::make_mut(&mut self));
844        Ok(inner.map_elements(f)?.update_data(|c| {
845            *Arc::get_mut(&mut self).unwrap() = c;
846            self
847        }))
848    }
849}
850
851impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option<C> {
852    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
853        &'a self,
854        f: F,
855    ) -> Result<TreeNodeRecursion> {
856        match self {
857            Some(t) => t.apply_elements(f),
858            None => Ok(TreeNodeRecursion::Continue),
859        }
860    }
861
862    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
863        self,
864        f: F,
865    ) -> Result<Transformed<Self>> {
866        self.map_or(Ok(Transformed::no(None)), |c| {
867            c.map_elements(f)?.map_data(|c| Ok(Some(c)))
868        })
869    }
870}
871
872impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec<C> {
873    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
874        &'a self,
875        mut f: F,
876    ) -> Result<TreeNodeRecursion> {
877        let mut tnr = TreeNodeRecursion::Continue;
878        for c in self {
879            tnr = c.apply_elements(&mut f)?;
880            match tnr {
881                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
882                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
883            }
884        }
885        Ok(tnr)
886    }
887
888    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
889        self,
890        mut f: F,
891    ) -> Result<Transformed<Self>> {
892        let mut tnr = TreeNodeRecursion::Continue;
893        let mut transformed = false;
894        self.into_iter()
895            .map(|c| match tnr {
896                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
897                    c.map_elements(&mut f).map(|result| {
898                        tnr = result.tnr;
899                        transformed |= result.transformed;
900                        result.data
901                    })
902                }
903                TreeNodeRecursion::Stop => Ok(c),
904            })
905            .collect::<Result<Vec<_>>>()
906            .map(|data| Transformed::new(data, transformed, tnr))
907    }
908}
909
910impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T>
911    for HashMap<K, C>
912{
913    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
914        &'a self,
915        mut f: F,
916    ) -> Result<TreeNodeRecursion> {
917        let mut tnr = TreeNodeRecursion::Continue;
918        for c in self.values() {
919            tnr = c.apply_elements(&mut f)?;
920            match tnr {
921                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
922                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
923            }
924        }
925        Ok(tnr)
926    }
927
928    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
929        self,
930        mut f: F,
931    ) -> Result<Transformed<Self>> {
932        let mut tnr = TreeNodeRecursion::Continue;
933        let mut transformed = false;
934        self.into_iter()
935            .map(|(k, c)| match tnr {
936                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
937                    c.map_elements(&mut f).map(|result| {
938                        tnr = result.tnr;
939                        transformed |= result.transformed;
940                        (k, result.data)
941                    })
942                }
943                TreeNodeRecursion::Stop => Ok((k, c)),
944            })
945            .collect::<Result<HashMap<_, _>>>()
946            .map(|data| Transformed::new(data, transformed, tnr))
947    }
948}
949
950impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
951    TreeNodeContainer<'a, T> for (C0, C1)
952{
953    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
954        &'a self,
955        mut f: F,
956    ) -> Result<TreeNodeRecursion> {
957        self.0
958            .apply_elements(&mut f)?
959            .visit_sibling(|| self.1.apply_elements(&mut f))
960    }
961
962    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
963        self,
964        mut f: F,
965    ) -> Result<Transformed<Self>> {
966        self.0
967            .map_elements(&mut f)?
968            .map_data(|new_c0| Ok((new_c0, self.1)))?
969            .transform_sibling(|(new_c0, c1)| {
970                c1.map_elements(&mut f)?
971                    .map_data(|new_c1| Ok((new_c0, new_c1)))
972            })
973    }
974}
975
976impl<
977    'a,
978    T: 'a,
979    C0: TreeNodeContainer<'a, T>,
980    C1: TreeNodeContainer<'a, T>,
981    C2: TreeNodeContainer<'a, T>,
982> TreeNodeContainer<'a, T> for (C0, C1, C2)
983{
984    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
985        &'a self,
986        mut f: F,
987    ) -> Result<TreeNodeRecursion> {
988        self.0
989            .apply_elements(&mut f)?
990            .visit_sibling(|| self.1.apply_elements(&mut f))?
991            .visit_sibling(|| self.2.apply_elements(&mut f))
992    }
993
994    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
995        self,
996        mut f: F,
997    ) -> Result<Transformed<Self>> {
998        self.0
999            .map_elements(&mut f)?
1000            .map_data(|new_c0| Ok((new_c0, self.1, self.2)))?
1001            .transform_sibling(|(new_c0, c1, c2)| {
1002                c1.map_elements(&mut f)?
1003                    .map_data(|new_c1| Ok((new_c0, new_c1, c2)))
1004            })?
1005            .transform_sibling(|(new_c0, new_c1, c2)| {
1006                c2.map_elements(&mut f)?
1007                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2)))
1008            })
1009    }
1010}
1011
1012impl<
1013    'a,
1014    T: 'a,
1015    C0: TreeNodeContainer<'a, T>,
1016    C1: TreeNodeContainer<'a, T>,
1017    C2: TreeNodeContainer<'a, T>,
1018    C3: TreeNodeContainer<'a, T>,
1019> TreeNodeContainer<'a, T> for (C0, C1, C2, C3)
1020{
1021    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1022        &'a self,
1023        mut f: F,
1024    ) -> Result<TreeNodeRecursion> {
1025        self.0
1026            .apply_elements(&mut f)?
1027            .visit_sibling(|| self.1.apply_elements(&mut f))?
1028            .visit_sibling(|| self.2.apply_elements(&mut f))?
1029            .visit_sibling(|| self.3.apply_elements(&mut f))
1030    }
1031
1032    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
1033        self,
1034        mut f: F,
1035    ) -> Result<Transformed<Self>> {
1036        self.0
1037            .map_elements(&mut f)?
1038            .map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))?
1039            .transform_sibling(|(new_c0, c1, c2, c3)| {
1040                c1.map_elements(&mut f)?
1041                    .map_data(|new_c1| Ok((new_c0, new_c1, c2, c3)))
1042            })?
1043            .transform_sibling(|(new_c0, new_c1, c2, c3)| {
1044                c2.map_elements(&mut f)?
1045                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3)))
1046            })?
1047            .transform_sibling(|(new_c0, new_c1, new_c2, c3)| {
1048                c3.map_elements(&mut f)?
1049                    .map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3)))
1050            })
1051    }
1052}
1053
1054/// [`TreeNodeRefContainer`] contains references to elements that a function can be
1055/// applied on. The elements of the container are siblings so the continuation rules are
1056/// similar to [`TreeNodeRecursion::visit_sibling`].
1057///
1058/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference
1059/// elements (`T`) are not derived from the container's lifetime.
1060/// A typical usage of this container is in `Expr::apply_children` when we need to
1061/// construct a temporary container to be able to call `apply_ref_elements` on a
1062/// collection of tree node references. But in that case the container's temporary
1063/// lifetime is different to the lifetime of tree nodes that we put into it.
1064/// Please find an example use case in `Expr::apply_children` with the `Expr::Case` case.
1065///
1066/// Most of the cases we don't need to create a temporary container with
1067/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`.
1068/// Please find an example use case in `Expr::apply_children` with the `Expr::GroupingSet`
1069/// case.
1070pub trait TreeNodeRefContainer<'a, T: 'a>: Sized {
1071    /// Applies `f` to all elements of the container.
1072    /// This method is usually called from [`TreeNode::apply_children`] implementations as
1073    /// a node is actually a container of the node's children.
1074    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1075        &self,
1076        f: F,
1077    ) -> Result<TreeNodeRecursion>;
1078}
1079
1080impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> {
1081    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1082        &self,
1083        mut f: F,
1084    ) -> Result<TreeNodeRecursion> {
1085        let mut tnr = TreeNodeRecursion::Continue;
1086        for c in self {
1087            tnr = c.apply_elements(&mut f)?;
1088            match tnr {
1089                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
1090                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
1091            }
1092        }
1093        Ok(tnr)
1094    }
1095}
1096
1097impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
1098    TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1)
1099{
1100    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1101        &self,
1102        mut f: F,
1103    ) -> Result<TreeNodeRecursion> {
1104        self.0
1105            .apply_elements(&mut f)?
1106            .visit_sibling(|| self.1.apply_elements(&mut f))
1107    }
1108}
1109
1110impl<
1111    'a,
1112    T: 'a,
1113    C0: TreeNodeContainer<'a, T>,
1114    C1: TreeNodeContainer<'a, T>,
1115    C2: TreeNodeContainer<'a, T>,
1116> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2)
1117{
1118    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1119        &self,
1120        mut f: F,
1121    ) -> Result<TreeNodeRecursion> {
1122        self.0
1123            .apply_elements(&mut f)?
1124            .visit_sibling(|| self.1.apply_elements(&mut f))?
1125            .visit_sibling(|| self.2.apply_elements(&mut f))
1126    }
1127}
1128
1129impl<
1130    'a,
1131    T: 'a,
1132    C0: TreeNodeContainer<'a, T>,
1133    C1: TreeNodeContainer<'a, T>,
1134    C2: TreeNodeContainer<'a, T>,
1135    C3: TreeNodeContainer<'a, T>,
1136> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3)
1137{
1138    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1139        &self,
1140        mut f: F,
1141    ) -> Result<TreeNodeRecursion> {
1142        self.0
1143            .apply_elements(&mut f)?
1144            .visit_sibling(|| self.1.apply_elements(&mut f))?
1145            .visit_sibling(|| self.2.apply_elements(&mut f))?
1146            .visit_sibling(|| self.3.apply_elements(&mut f))
1147    }
1148}
1149
1150/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
1151pub trait TreeNodeIterator: Iterator {
1152    /// Apples `f` to each item in this iterator
1153    ///
1154    /// Visits all items in the iterator unless
1155    /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`.
1156    ///
1157    /// # Returns
1158    /// Error if `f` returns an error or `Ok(TreeNodeRecursion)` from the last invocation
1159    /// of `f` or `Continue` if the iterator is empty
1160    fn apply_until_stop<F: FnMut(Self::Item) -> Result<TreeNodeRecursion>>(
1161        self,
1162        f: F,
1163    ) -> Result<TreeNodeRecursion>;
1164
1165    /// Apples `f` to each item in this iterator
1166    ///
1167    /// Visits all items in the iterator unless
1168    /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`.
1169    ///
1170    /// # Returns
1171    /// Error if `f` returns an error
1172    ///
1173    /// Ok(Transformed) such that:
1174    /// 1. `transformed` is true if any return from `f` had transformed true
1175    /// 2. `data` from the last invocation of `f`
1176    /// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty
1177    fn map_until_stop_and_collect<
1178        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
1179    >(
1180        self,
1181        f: F,
1182    ) -> Result<Transformed<Vec<Self::Item>>>;
1183}
1184
1185impl<I: Iterator> TreeNodeIterator for I {
1186    fn apply_until_stop<F: FnMut(Self::Item) -> Result<TreeNodeRecursion>>(
1187        self,
1188        mut f: F,
1189    ) -> Result<TreeNodeRecursion> {
1190        let mut tnr = TreeNodeRecursion::Continue;
1191        for i in self {
1192            tnr = f(i)?;
1193            match tnr {
1194                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
1195                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
1196            }
1197        }
1198        Ok(tnr)
1199    }
1200
1201    fn map_until_stop_and_collect<
1202        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
1203    >(
1204        self,
1205        mut f: F,
1206    ) -> Result<Transformed<Vec<Self::Item>>> {
1207        let mut tnr = TreeNodeRecursion::Continue;
1208        let mut transformed = false;
1209        self.map(|item| match tnr {
1210            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
1211                f(item).map(|result| {
1212                    tnr = result.tnr;
1213                    transformed |= result.transformed;
1214                    result.data
1215                })
1216            }
1217            TreeNodeRecursion::Stop => Ok(item),
1218        })
1219        .collect::<Result<Vec<_>>>()
1220        .map(|data| Transformed::new(data, transformed, tnr))
1221    }
1222}
1223
1224/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
1225///
1226/// # Example
1227/// Access the internal data of a `Result<Transformed<T>>`
1228/// as a `Result<T>` using the `data` method:
1229/// ```
1230/// # use datafusion_common::Result;
1231/// # use datafusion_common::tree_node::{Transformed, TransformedResult};
1232/// # // note use i64 instead of Expr as Expr is not in datafusion-common
1233/// # fn update_expr() -> i64 { 1 }
1234/// # fn main() -> Result<()> {
1235/// let transformed: Result<Transformed<_>> = Ok(Transformed::yes(update_expr()));
1236/// // access the internal data of the transformed result, or return the error
1237/// let transformed_expr = transformed.data()?;
1238/// # Ok(())
1239/// # }
1240/// ```
1241pub trait TransformedResult<T> {
1242    fn data(self) -> Result<T>;
1243
1244    fn transformed(self) -> Result<bool>;
1245
1246    fn tnr(self) -> Result<TreeNodeRecursion>;
1247}
1248
1249impl<T> TransformedResult<T> for Result<Transformed<T>> {
1250    fn data(self) -> Result<T> {
1251        self.map(|t| t.data)
1252    }
1253
1254    fn transformed(self) -> Result<bool> {
1255        self.map(|t| t.transformed)
1256    }
1257
1258    fn tnr(self) -> Result<TreeNodeRecursion> {
1259        self.map(|t| t.tnr)
1260    }
1261}
1262
1263/// Helper trait for implementing [`TreeNode`] that have children stored as
1264/// `Arc`s. If some trait object, such as `dyn T`, implements this trait,
1265/// its related `Arc<dyn T>` will automatically implement [`TreeNode`].
1266pub trait DynTreeNode {
1267    /// Returns all children of the specified `TreeNode`.
1268    fn arc_children(&self) -> Vec<&Arc<Self>>;
1269
1270    /// Constructs a new node with the specified children.
1271    fn with_new_arc_children(
1272        &self,
1273        arc_self: Arc<Self>,
1274        new_children: Vec<Arc<Self>>,
1275    ) -> Result<Arc<Self>>;
1276}
1277
1278/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
1279/// (such as [`Arc<dyn PhysicalExpr>`]).
1280impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
1281    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1282        &'n self,
1283        f: F,
1284    ) -> Result<TreeNodeRecursion> {
1285        self.arc_children().into_iter().apply_until_stop(f)
1286    }
1287
1288    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1289        self,
1290        f: F,
1291    ) -> Result<Transformed<Self>> {
1292        let children = self.arc_children();
1293        if !children.is_empty() {
1294            let new_children = children
1295                .into_iter()
1296                .cloned()
1297                .map_until_stop_and_collect(f)?;
1298            // Propagate up `new_children.transformed` and `new_children.tnr`
1299            // along with the node containing transformed children.
1300            if new_children.transformed {
1301                let arc_self = Arc::clone(&self);
1302                new_children.map_data(|new_children| {
1303                    self.with_new_arc_children(arc_self, new_children)
1304                })
1305            } else {
1306                Ok(Transformed::new(self, false, new_children.tnr))
1307            }
1308        } else {
1309            Ok(Transformed::no(self))
1310        }
1311    }
1312}
1313
1314/// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for
1315/// trees that contain nodes with payloads. This approach ensures safe execution of algorithms
1316/// involving payloads, by enforcing rules for detaching and reattaching child nodes.
1317pub trait ConcreteTreeNode: Sized {
1318    /// Provides read-only access to child nodes.
1319    fn children(&self) -> &[Self];
1320
1321    /// Detaches the node from its children, returning the node itself and its detached children.
1322    fn take_children(self) -> (Self, Vec<Self>);
1323
1324    /// Reattaches updated child nodes to the node, returning the updated node.
1325    fn with_new_children(self, children: Vec<Self>) -> Result<Self>;
1326}
1327
1328impl<T: ConcreteTreeNode> TreeNode for T {
1329    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1330        &'n self,
1331        f: F,
1332    ) -> Result<TreeNodeRecursion> {
1333        self.children().iter().apply_until_stop(f)
1334    }
1335
1336    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1337        self,
1338        f: F,
1339    ) -> Result<Transformed<Self>> {
1340        let (new_self, children) = self.take_children();
1341        if !children.is_empty() {
1342            let new_children = children.into_iter().map_until_stop_and_collect(f)?;
1343            // Propagate up `new_children.transformed` and `new_children.tnr` along with
1344            // the node containing transformed children.
1345            new_children.map_data(|new_children| new_self.with_new_children(new_children))
1346        } else {
1347            Ok(Transformed::no(new_self))
1348        }
1349    }
1350}
1351
1352#[cfg(test)]
1353pub(crate) mod tests {
1354    use std::collections::HashMap;
1355    use std::fmt::Display;
1356    use std::sync::Arc;
1357
1358    use crate::Result;
1359    use crate::tree_node::{
1360        Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter,
1361        TreeNodeVisitor,
1362    };
1363
1364    #[derive(Debug, Default, Eq, Hash, PartialEq, Clone)]
1365    pub struct TestTreeNode<T> {
1366        pub(crate) children: Vec<TestTreeNode<T>>,
1367        pub(crate) data: T,
1368    }
1369
1370    impl<T> TestTreeNode<T> {
1371        pub(crate) fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
1372            Self { children, data }
1373        }
1374
1375        pub(crate) fn new_leaf(data: T) -> Self {
1376            Self {
1377                children: vec![],
1378                data,
1379            }
1380        }
1381
1382        pub(crate) fn is_leaf(&self) -> bool {
1383            self.children.is_empty()
1384        }
1385    }
1386
1387    impl<T> TreeNode for TestTreeNode<T> {
1388        fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1389            &'n self,
1390            f: F,
1391        ) -> Result<TreeNodeRecursion> {
1392            self.children.apply_elements(f)
1393        }
1394
1395        fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1396            self,
1397            f: F,
1398        ) -> Result<Transformed<Self>> {
1399            Ok(self
1400                .children
1401                .map_elements(f)?
1402                .update_data(|new_children| Self {
1403                    children: new_children,
1404                    ..self
1405                }))
1406        }
1407    }
1408
1409    impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode<T> {
1410        fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
1411            &'a self,
1412            mut f: F,
1413        ) -> Result<TreeNodeRecursion> {
1414            f(self)
1415        }
1416
1417        fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
1418            self,
1419            mut f: F,
1420        ) -> Result<Transformed<Self>> {
1421            f(self)
1422        }
1423    }
1424
1425    //       J
1426    //       |
1427    //       I
1428    //       |
1429    //       F
1430    //     /   \
1431    //    E     G
1432    //    |     |
1433    //    C     H
1434    //  /   \
1435    // B     D
1436    //       |
1437    //       A
1438    fn test_tree() -> TestTreeNode<String> {
1439        let node_a = TestTreeNode::new_leaf("a".to_string());
1440        let node_b = TestTreeNode::new_leaf("b".to_string());
1441        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1442        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1443        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1444        let node_h = TestTreeNode::new_leaf("h".to_string());
1445        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1446        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1447        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1448        TestTreeNode::new(vec![node_i], "j".to_string())
1449    }
1450
1451    // Continue on all nodes
1452    // Expected visits in a combined traversal
1453    fn all_visits() -> Vec<String> {
1454        vec![
1455            "f_down(j)",
1456            "f_down(i)",
1457            "f_down(f)",
1458            "f_down(e)",
1459            "f_down(c)",
1460            "f_down(b)",
1461            "f_up(b)",
1462            "f_down(d)",
1463            "f_down(a)",
1464            "f_up(a)",
1465            "f_up(d)",
1466            "f_up(c)",
1467            "f_up(e)",
1468            "f_down(g)",
1469            "f_down(h)",
1470            "f_up(h)",
1471            "f_up(g)",
1472            "f_up(f)",
1473            "f_up(i)",
1474            "f_up(j)",
1475        ]
1476        .into_iter()
1477        .map(|s| s.to_string())
1478        .collect()
1479    }
1480
1481    // Expected transformed tree after a combined traversal
1482    fn transformed_tree() -> TestTreeNode<String> {
1483        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1484        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1485        let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
1486        let node_c =
1487            TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
1488        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1489        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1490        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1491        let node_f =
1492            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1493        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1494        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1495    }
1496
1497    // Expected transformed tree after a top-down traversal
1498    fn transformed_down_tree() -> TestTreeNode<String> {
1499        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1500        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
1501        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1502        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1503        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1504        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
1505        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
1506        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1507        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1508        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1509    }
1510
1511    // Expected transformed tree after a bottom-up traversal
1512    fn transformed_up_tree() -> TestTreeNode<String> {
1513        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1514        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1515        let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
1516        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string());
1517        let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
1518        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
1519        let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
1520        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string());
1521        let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
1522        TestTreeNode::new(vec![node_i], "f_up(j)".to_string())
1523    }
1524
1525    // f_down Jump on A node
1526    fn f_down_jump_on_a_visits() -> Vec<String> {
1527        vec![
1528            "f_down(j)",
1529            "f_down(i)",
1530            "f_down(f)",
1531            "f_down(e)",
1532            "f_down(c)",
1533            "f_down(b)",
1534            "f_up(b)",
1535            "f_down(d)",
1536            "f_down(a)",
1537            "f_up(a)",
1538            "f_up(d)",
1539            "f_up(c)",
1540            "f_up(e)",
1541            "f_down(g)",
1542            "f_down(h)",
1543            "f_up(h)",
1544            "f_up(g)",
1545            "f_up(f)",
1546            "f_up(i)",
1547            "f_up(j)",
1548        ]
1549        .into_iter()
1550        .map(|s| s.to_string())
1551        .collect()
1552    }
1553
1554    fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode<String> {
1555        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1556        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
1557        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1558        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1559        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1560        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
1561        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
1562        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1563        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1564        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1565    }
1566
1567    // f_down Jump on E node
1568    fn f_down_jump_on_e_visits() -> Vec<String> {
1569        vec![
1570            "f_down(j)",
1571            "f_down(i)",
1572            "f_down(f)",
1573            "f_down(e)",
1574            "f_up(e)",
1575            "f_down(g)",
1576            "f_down(h)",
1577            "f_up(h)",
1578            "f_up(g)",
1579            "f_up(f)",
1580            "f_up(i)",
1581            "f_up(j)",
1582        ]
1583        .into_iter()
1584        .map(|s| s.to_string())
1585        .collect()
1586    }
1587
1588    fn f_down_jump_on_e_transformed_tree() -> TestTreeNode<String> {
1589        let node_a = TestTreeNode::new_leaf("a".to_string());
1590        let node_b = TestTreeNode::new_leaf("b".to_string());
1591        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1592        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1593        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1594        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1595        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1596        let node_f =
1597            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1598        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1599        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1600    }
1601
1602    fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode<String> {
1603        let node_a = TestTreeNode::new_leaf("a".to_string());
1604        let node_b = TestTreeNode::new_leaf("b".to_string());
1605        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1606        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1607        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1608        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
1609        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
1610        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1611        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1612        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1613    }
1614
1615    // f_up Jump on A node
1616    fn f_up_jump_on_a_visits() -> Vec<String> {
1617        vec![
1618            "f_down(j)",
1619            "f_down(i)",
1620            "f_down(f)",
1621            "f_down(e)",
1622            "f_down(c)",
1623            "f_down(b)",
1624            "f_up(b)",
1625            "f_down(d)",
1626            "f_down(a)",
1627            "f_up(a)",
1628            "f_down(g)",
1629            "f_down(h)",
1630            "f_up(h)",
1631            "f_up(g)",
1632            "f_up(f)",
1633            "f_up(i)",
1634            "f_up(j)",
1635        ]
1636        .into_iter()
1637        .map(|s| s.to_string())
1638        .collect()
1639    }
1640
1641    fn f_up_jump_on_a_transformed_tree() -> TestTreeNode<String> {
1642        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1643        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1644        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1645        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1646        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1647        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1648        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1649        let node_f =
1650            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1651        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1652        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1653    }
1654
1655    fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode<String> {
1656        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1657        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1658        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1659        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1660        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1661        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
1662        let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
1663        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string());
1664        let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
1665        TestTreeNode::new(vec![node_i], "f_up(j)".to_string())
1666    }
1667
1668    // f_up Jump on E node
1669    fn f_up_jump_on_e_visits() -> Vec<String> {
1670        vec![
1671            "f_down(j)",
1672            "f_down(i)",
1673            "f_down(f)",
1674            "f_down(e)",
1675            "f_down(c)",
1676            "f_down(b)",
1677            "f_up(b)",
1678            "f_down(d)",
1679            "f_down(a)",
1680            "f_up(a)",
1681            "f_up(d)",
1682            "f_up(c)",
1683            "f_up(e)",
1684            "f_down(g)",
1685            "f_down(h)",
1686            "f_up(h)",
1687            "f_up(g)",
1688            "f_up(f)",
1689            "f_up(i)",
1690            "f_up(j)",
1691        ]
1692        .into_iter()
1693        .map(|s| s.to_string())
1694        .collect()
1695    }
1696
    // Expected `transform_down_up` result when `f_up` returns `Jump` on
    // `f_down(e)`. As `f_up_jump_on_e_visits` shows, this event skips no
    // calls (the next call is `f_down(g)` and `f(i, j)` are still left
    // normally). The result therefore equals the fully transformed tree.
    fn f_up_jump_on_e_transformed_tree() -> TestTreeNode<String> {
        transformed_tree()
    }

    // Expected `transform_up` result when `f_up` returns `Jump` on `e`.
    // For the same reason, it equals the plain bottom-up result.
    fn f_up_jump_on_e_transformed_up_tree() -> TestTreeNode<String> {
        transformed_up_tree()
    }
1704
1705    // f_down Stop on A node
1706
1707    fn f_down_stop_on_a_visits() -> Vec<String> {
1708        vec![
1709            "f_down(j)",
1710            "f_down(i)",
1711            "f_down(f)",
1712            "f_down(e)",
1713            "f_down(c)",
1714            "f_down(b)",
1715            "f_up(b)",
1716            "f_down(d)",
1717            "f_down(a)",
1718        ]
1719        .into_iter()
1720        .map(|s| s.to_string())
1721        .collect()
1722    }
1723
1724    fn f_down_stop_on_a_transformed_tree() -> TestTreeNode<String> {
1725        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1726        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1727        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1728        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1729        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1730        let node_h = TestTreeNode::new_leaf("h".to_string());
1731        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1732        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1733        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1734        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1735    }
1736
1737    fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode<String> {
1738        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1739        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
1740        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1741        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1742        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1743        let node_h = TestTreeNode::new_leaf("h".to_string());
1744        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1745        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1746        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1747        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1748    }
1749
1750    // f_down Stop on E node
1751    fn f_down_stop_on_e_visits() -> Vec<String> {
1752        vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"]
1753            .into_iter()
1754            .map(|s| s.to_string())
1755            .collect()
1756    }
1757
1758    fn f_down_stop_on_e_transformed_tree() -> TestTreeNode<String> {
1759        let node_a = TestTreeNode::new_leaf("a".to_string());
1760        let node_b = TestTreeNode::new_leaf("b".to_string());
1761        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1762        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1763        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1764        let node_h = TestTreeNode::new_leaf("h".to_string());
1765        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1766        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1767        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1768        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1769    }
1770
1771    fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode<String> {
1772        let node_a = TestTreeNode::new_leaf("a".to_string());
1773        let node_b = TestTreeNode::new_leaf("b".to_string());
1774        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1775        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1776        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1777        let node_h = TestTreeNode::new_leaf("h".to_string());
1778        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1779        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1780        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1781        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1782    }
1783
1784    // f_up Stop on A node
1785    fn f_up_stop_on_a_visits() -> Vec<String> {
1786        vec![
1787            "f_down(j)",
1788            "f_down(i)",
1789            "f_down(f)",
1790            "f_down(e)",
1791            "f_down(c)",
1792            "f_down(b)",
1793            "f_up(b)",
1794            "f_down(d)",
1795            "f_down(a)",
1796            "f_up(a)",
1797        ]
1798        .into_iter()
1799        .map(|s| s.to_string())
1800        .collect()
1801    }
1802
1803    fn f_up_stop_on_a_transformed_tree() -> TestTreeNode<String> {
1804        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1805        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1806        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1807        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1808        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1809        let node_h = TestTreeNode::new_leaf("h".to_string());
1810        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1811        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1812        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1813        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1814    }
1815
1816    fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode<String> {
1817        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1818        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1819        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1820        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1821        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1822        let node_h = TestTreeNode::new_leaf("h".to_string());
1823        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1824        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1825        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1826        TestTreeNode::new(vec![node_i], "j".to_string())
1827    }
1828
1829    // f_up Stop on E node
1830    fn f_up_stop_on_e_visits() -> Vec<String> {
1831        vec![
1832            "f_down(j)",
1833            "f_down(i)",
1834            "f_down(f)",
1835            "f_down(e)",
1836            "f_down(c)",
1837            "f_down(b)",
1838            "f_up(b)",
1839            "f_down(d)",
1840            "f_down(a)",
1841            "f_up(a)",
1842            "f_up(d)",
1843            "f_up(c)",
1844            "f_up(e)",
1845        ]
1846        .into_iter()
1847        .map(|s| s.to_string())
1848        .collect()
1849    }
1850
1851    fn f_up_stop_on_e_transformed_tree() -> TestTreeNode<String> {
1852        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1853        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1854        let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
1855        let node_c =
1856            TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
1857        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1858        let node_h = TestTreeNode::new_leaf("h".to_string());
1859        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1860        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1861        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1862        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1863    }
1864
1865    fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode<String> {
1866        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1867        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1868        let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
1869        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string());
1870        let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
1871        let node_h = TestTreeNode::new_leaf("h".to_string());
1872        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1873        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1874        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1875        TestTreeNode::new(vec![node_i], "j".to_string())
1876    }
1877
1878    fn down_visits(visits: Vec<String>) -> Vec<String> {
1879        visits
1880            .into_iter()
1881            .filter(|v| v.starts_with("f_down"))
1882            .collect()
1883    }
1884
1885    type TestVisitorF<T> = Box<dyn FnMut(&TestTreeNode<T>) -> Result<TreeNodeRecursion>>;
1886
1887    struct TestVisitor<T> {
1888        visits: Vec<String>,
1889        f_down: TestVisitorF<T>,
1890        f_up: TestVisitorF<T>,
1891    }
1892
1893    impl<T> TestVisitor<T> {
1894        fn new(f_down: TestVisitorF<T>, f_up: TestVisitorF<T>) -> Self {
1895            Self {
1896                visits: vec![],
1897                f_down,
1898                f_up,
1899            }
1900        }
1901    }
1902
1903    impl<'n, T: Display> TreeNodeVisitor<'n> for TestVisitor<T> {
1904        type Node = TestTreeNode<T>;
1905
1906        fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1907            self.visits.push(format!("f_down({})", node.data));
1908            (*self.f_down)(node)
1909        }
1910
1911        fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1912            self.visits.push(format!("f_up({})", node.data));
1913            (*self.f_up)(node)
1914        }
1915    }
1916
1917    fn visit_continue<T>(_: &TestTreeNode<T>) -> Result<TreeNodeRecursion> {
1918        Ok(TreeNodeRecursion::Continue)
1919    }
1920
1921    fn visit_event_on<T: PartialEq, D: Into<T>>(
1922        data: D,
1923        event: TreeNodeRecursion,
1924    ) -> impl FnMut(&TestTreeNode<T>) -> Result<TreeNodeRecursion> {
1925        let d = data.into();
1926        move |node| {
1927            Ok(if node.data == d {
1928                event
1929            } else {
1930                TreeNodeRecursion::Continue
1931            })
1932        }
1933    }
1934
1935    macro_rules! visit_test {
1936        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_VISITS:expr) => {
1937            #[test]
1938            fn $NAME() -> Result<()> {
1939                let tree = test_tree();
1940                let mut visitor = TestVisitor::new(Box::new($F_DOWN), Box::new($F_UP));
1941                tree.visit(&mut visitor)?;
1942                assert_eq!(visitor.visits, $EXPECTED_VISITS);
1943
1944                Ok(())
1945            }
1946        };
1947    }
1948
1949    macro_rules! test_apply {
1950        ($NAME:ident, $F:expr, $EXPECTED_VISITS:expr) => {
1951            #[test]
1952            fn $NAME() -> Result<()> {
1953                let tree = test_tree();
1954                let mut visits = vec![];
1955                tree.apply(|node| {
1956                    visits.push(format!("f_down({})", node.data));
1957                    $F(node)
1958                })?;
1959                assert_eq!(visits, $EXPECTED_VISITS);
1960
1961                Ok(())
1962            }
1963        };
1964    }
1965
1966    type TestRewriterF<T> =
1967        Box<dyn FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>>>;
1968
1969    struct TestRewriter<T> {
1970        f_down: TestRewriterF<T>,
1971        f_up: TestRewriterF<T>,
1972    }
1973
1974    impl<T> TestRewriter<T> {
1975        fn new(f_down: TestRewriterF<T>, f_up: TestRewriterF<T>) -> Self {
1976            Self { f_down, f_up }
1977        }
1978    }
1979
1980    impl<T: Display> TreeNodeRewriter for TestRewriter<T> {
1981        type Node = TestTreeNode<T>;
1982
1983        fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1984            (*self.f_down)(node)
1985        }
1986
1987        fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1988            (*self.f_up)(node)
1989        }
1990    }
1991
1992    fn transform_yes<N: Display, T: Display + From<String>>(
1993        transformation_name: N,
1994    ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> {
1995        move |node| {
1996            Ok(Transformed::yes(TestTreeNode::new(
1997                node.children,
1998                format!("{}({})", transformation_name, node.data).into(),
1999            )))
2000        }
2001    }
2002
2003    fn transform_and_event_on<
2004        N: Display,
2005        T: PartialEq + Display + From<String>,
2006        D: Into<T>,
2007    >(
2008        transformation_name: N,
2009        data: D,
2010        event: TreeNodeRecursion,
2011    ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> {
2012        let d = data.into();
2013        move |node| {
2014            let new_node = TestTreeNode::new(
2015                node.children,
2016                format!("{}({})", transformation_name, node.data).into(),
2017            );
2018            Ok(if node.data == d {
2019                Transformed::new(new_node, true, event)
2020            } else {
2021                Transformed::yes(new_node)
2022            })
2023        }
2024    }
2025
2026    macro_rules! rewrite_test {
2027        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => {
2028            #[test]
2029            fn $NAME() -> Result<()> {
2030                let tree = test_tree();
2031                let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP));
2032                assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE);
2033
2034                Ok(())
2035            }
2036        };
2037    }
2038
2039    macro_rules! transform_test {
2040        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => {
2041            #[test]
2042            fn $NAME() -> Result<()> {
2043                let tree = test_tree();
2044                assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE);
2045
2046                Ok(())
2047            }
2048        };
2049    }
2050
2051    macro_rules! transform_down_test {
2052        ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => {
2053            #[test]
2054            fn $NAME() -> Result<()> {
2055                let tree = test_tree();
2056                assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE);
2057
2058                Ok(())
2059            }
2060        };
2061    }
2062
2063    macro_rules! transform_up_test {
2064        ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => {
2065            #[test]
2066            fn $NAME() -> Result<()> {
2067                let tree = test_tree();
2068                assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE);
2069
2070                Ok(())
2071            }
2072        };
2073    }
2074
    // `visit` tests. Each one compares the visitor's recorded f_down/f_up log
    // with the matching `*_visits()` fixture. Events match the raw node data
    // ("a", "e") because visiting does not relabel nodes.
    visit_test!(test_visit, visit_continue, visit_continue, all_visits());
    visit_test!(
        test_visit_f_down_jump_on_a,
        visit_event_on("a", TreeNodeRecursion::Jump),
        visit_continue,
        f_down_jump_on_a_visits()
    );
    visit_test!(
        test_visit_f_down_jump_on_e,
        visit_event_on("e", TreeNodeRecursion::Jump),
        visit_continue,
        f_down_jump_on_e_visits()
    );
    visit_test!(
        test_visit_f_up_jump_on_a,
        visit_continue,
        visit_event_on("a", TreeNodeRecursion::Jump),
        f_up_jump_on_a_visits()
    );
    visit_test!(
        test_visit_f_up_jump_on_e,
        visit_continue,
        visit_event_on("e", TreeNodeRecursion::Jump),
        f_up_jump_on_e_visits()
    );
    visit_test!(
        test_visit_f_down_stop_on_a,
        visit_event_on("a", TreeNodeRecursion::Stop),
        visit_continue,
        f_down_stop_on_a_visits()
    );
    visit_test!(
        test_visit_f_down_stop_on_e,
        visit_event_on("e", TreeNodeRecursion::Stop),
        visit_continue,
        f_down_stop_on_e_visits()
    );
    visit_test!(
        test_visit_f_up_stop_on_a,
        visit_continue,
        visit_event_on("a", TreeNodeRecursion::Stop),
        f_up_stop_on_a_visits()
    );
    visit_test!(
        test_visit_f_up_stop_on_e,
        visit_continue,
        visit_event_on("e", TreeNodeRecursion::Stop),
        f_up_stop_on_e_visits()
    );

    // `apply` tests. `apply` takes a single top-down closure and
    // `test_apply!` logs only "f_down(..)" entries. The expectations are
    // therefore the `f_down` entries of the corresponding visit fixtures.
    // Only the f_down-event scenarios have an `apply` counterpart.
    test_apply!(test_apply, visit_continue, down_visits(all_visits()));
    test_apply!(
        test_apply_f_down_jump_on_a,
        visit_event_on("a", TreeNodeRecursion::Jump),
        down_visits(f_down_jump_on_a_visits())
    );
    test_apply!(
        test_apply_f_down_jump_on_e,
        visit_event_on("e", TreeNodeRecursion::Jump),
        down_visits(f_down_jump_on_e_visits())
    );
    test_apply!(
        test_apply_f_down_stop_on_a,
        visit_event_on("a", TreeNodeRecursion::Stop),
        down_visits(f_down_stop_on_a_visits())
    );
    test_apply!(
        test_apply_f_down_stop_on_e,
        visit_event_on("e", TreeNodeRecursion::Stop),
        down_visits(f_down_stop_on_e_visits())
    );
2146
    // `rewrite` tests using `TestRewriter`.
    // - `f_up` events match "f_down(<name>)", because `transform_yes("f_down")`
    //   has already relabelled the node by the time `f_up` sees it.
    // - Stop scenarios expect the returned `Transformed` to carry
    //   `TreeNodeRecursion::Stop`.
    rewrite_test!(
        test_rewrite,
        transform_yes("f_down"),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_down_jump_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_down_jump_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(f_down_jump_on_e_transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_up_jump_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_a_transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_up_jump_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_e_transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_down_stop_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    rewrite_test!(
        test_rewrite_f_down_stop_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    rewrite_test!(
        test_rewrite_f_up_stop_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    rewrite_test!(
        test_rewrite_f_up_stop_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
2217
    // `transform_down_up` tests. They use the same closures and the same
    // expected trees as the `rewrite` tests above, pinning that both entry
    // points agree on every Jump/Stop scenario.
    transform_test!(
        test_transform,
        transform_yes("f_down"),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    transform_test!(
        test_transform_f_down_jump_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    transform_test!(
        test_transform_f_down_jump_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(f_down_jump_on_e_transformed_tree())
    );
    transform_test!(
        test_transform_f_up_jump_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_a_transformed_tree())
    );
    transform_test!(
        test_transform_f_up_jump_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_e_transformed_tree())
    );
    transform_test!(
        test_transform_f_down_stop_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_test!(
        test_transform_f_down_stop_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_test!(
        test_transform_f_up_stop_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_test!(
        test_transform_f_up_stop_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
2288
    // Single-direction `transform_down` tests. Only `f_down` runs, so events
    // match the raw node data ("a", "e").
    transform_down_test!(
        test_transform_down,
        transform_yes("f_down"),
        Transformed::yes(transformed_down_tree())
    );
    transform_down_test!(
        test_transform_down_f_down_jump_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
        Transformed::yes(f_down_jump_on_a_transformed_down_tree())
    );
    transform_down_test!(
        test_transform_down_f_down_jump_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
        Transformed::yes(f_down_jump_on_e_transformed_down_tree())
    );
    transform_down_test!(
        test_transform_down_f_down_stop_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
        Transformed::new(
            f_down_stop_on_a_transformed_down_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_down_test!(
        test_transform_down_f_down_stop_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
        Transformed::new(
            f_down_stop_on_e_transformed_down_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );

    // Single-direction `transform_up` tests. No `f_down` relabelling precedes
    // `f_up` here, so events match the raw node data ("a", "e") rather than
    // "f_down(..)".
    transform_up_test!(
        test_transform_up,
        transform_yes("f_up"),
        Transformed::yes(transformed_up_tree())
    );
    transform_up_test!(
        test_transform_up_f_up_jump_on_a,
        transform_and_event_on("f_up", "a", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_a_transformed_up_tree())
    );
    transform_up_test!(
        test_transform_up_f_up_jump_on_e,
        transform_and_event_on("f_up", "e", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_e_transformed_up_tree())
    );
    transform_up_test!(
        test_transform_up_f_up_stop_on_a,
        transform_and_event_on("f_up", "a", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_a_transformed_up_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_up_test!(
        test_transform_up_f_up_stop_on_e,
        transform_and_event_on("f_up", "e", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_e_transformed_up_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
2356
2357    //             F
2358    //          /  |  \
2359    //       /     |     \
2360    //    E        C        A
2361    //    |      /   \
2362    //    C     B     D
2363    //  /   \         |
2364    // B     D        A
2365    //       |
2366    //       A
2367    #[test]
2368    fn test_apply_and_visit_references() -> Result<()> {
2369        let node_a = TestTreeNode::new_leaf("a".to_string());
2370        let node_b = TestTreeNode::new_leaf("b".to_string());
2371        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
2372        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
2373        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
2374        let node_a_2 = TestTreeNode::new_leaf("a".to_string());
2375        let node_b_2 = TestTreeNode::new_leaf("b".to_string());
2376        let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
2377        let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string());
2378        let node_a_3 = TestTreeNode::new_leaf("a".to_string());
2379        let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string());
2380
2381        let node_f_ref = &tree;
2382        let node_e_ref = &node_f_ref.children[0];
2383        let node_c_ref = &node_e_ref.children[0];
2384        let node_b_ref = &node_c_ref.children[0];
2385        let node_d_ref = &node_c_ref.children[1];
2386        let node_a_ref = &node_d_ref.children[0];
2387
2388        let mut m: HashMap<&TestTreeNode<String>, usize> = HashMap::new();
2389        tree.apply(|e| {
2390            *m.entry(e).or_insert(0) += 1;
2391            Ok(TreeNodeRecursion::Continue)
2392        })?;
2393
2394        let expected = HashMap::from([
2395            (node_f_ref, 1),
2396            (node_e_ref, 1),
2397            (node_c_ref, 2),
2398            (node_d_ref, 2),
2399            (node_b_ref, 2),
2400            (node_a_ref, 3),
2401        ]);
2402        assert_eq!(m, expected);
2403
2404        struct TestVisitor<'n> {
2405            m: HashMap<&'n TestTreeNode<String>, (usize, usize)>,
2406        }
2407
2408        impl<'n> TreeNodeVisitor<'n> for TestVisitor<'n> {
2409            type Node = TestTreeNode<String>;
2410
2411            fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2412                let (down_count, _) = self.m.entry(node).or_insert((0, 0));
2413                *down_count += 1;
2414                Ok(TreeNodeRecursion::Continue)
2415            }
2416
2417            fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2418                let (_, up_count) = self.m.entry(node).or_insert((0, 0));
2419                *up_count += 1;
2420                Ok(TreeNodeRecursion::Continue)
2421            }
2422        }
2423
2424        let mut visitor = TestVisitor { m: HashMap::new() };
2425        tree.visit(&mut visitor)?;
2426
2427        let expected = HashMap::from([
2428            (node_f_ref, (1, 1)),
2429            (node_e_ref, (1, 1)),
2430            (node_c_ref, (2, 2)),
2431            (node_d_ref, (2, 2)),
2432            (node_b_ref, (2, 2)),
2433            (node_a_ref, (3, 3)),
2434        ]);
2435        assert_eq!(visitor.m, expected);
2436
2437        Ok(())
2438    }
2439
2440    #[cfg(feature = "recursive_protection")]
2441    #[test]
2442    fn test_large_tree() {
2443        let mut item = TestTreeNode::new_leaf("initial".to_string());
2444        for i in 0..3000 {
2445            item = TestTreeNode::new(vec![item], format!("parent-{i}"));
2446        }
2447
2448        let mut visitor =
2449            TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue));
2450
2451        item.visit(&mut visitor).unwrap();
2452    }
2453
2454    #[test]
2455    fn box_map_elements_reuses_allocation() {
2456        let boxed = Box::new(TestTreeNode::new_leaf(42i32));
2457        let before: *const TestTreeNode<i32> = &*boxed;
2458        let out = boxed.map_elements(|n| Ok(Transformed::no(n))).unwrap();
2459        let after: *const TestTreeNode<i32> = &*out.data;
2460        assert_eq!(after, before);
2461    }
2462
2463    #[test]
2464    fn arc_map_elements_reuses_allocation_when_unique() {
2465        let arc = Arc::new(TestTreeNode::new_leaf(42i32));
2466        let before = Arc::as_ptr(&arc);
2467        let out = arc.map_elements(|n| Ok(Transformed::no(n))).unwrap();
2468        assert_eq!(Arc::as_ptr(&out.data), before);
2469    }
2470
2471    #[test]
2472    fn arc_map_elements_clones_when_shared() {
2473        // When the input `Arc` is shared, `make_mut` clones into a fresh
2474        // allocation, so the reuse optimization does not apply.
2475        let arc = Arc::new(TestTreeNode::new_leaf(42i32));
2476        let _keepalive = Arc::clone(&arc);
2477        let before = Arc::as_ptr(&arc);
2478        let out = arc.map_elements(|n| Ok(Transformed::no(n))).unwrap();
2479        assert_ne!(Arc::as_ptr(&out.data), before);
2480    }
2481
2482    #[test]
2483    fn box_map_elements_panic() {
2484        use std::panic::{AssertUnwindSafe, catch_unwind};
2485        let boxed = Box::new(TestTreeNode::new_leaf(42i32));
2486        let result = catch_unwind(AssertUnwindSafe(|| {
2487            boxed
2488                .map_elements(|_: TestTreeNode<i32>| -> Result<_> {
2489                    panic!("simulated panic during rewrite")
2490                })
2491                .ok()
2492        }));
2493        assert!(result.is_err());
2494    }
2495}