datafusion_common/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`TreeNode`] for visiting and rewriting expression and plan trees
19
20use crate::Result;
21use std::collections::HashMap;
22use std::hash::Hash;
23use std::sync::Arc;
24
/// Runs one combined "down, children, up" step of a transforming traversal.
///
/// * `$f_down`: expression producing the node's `Result<Transformed<_>>` before
///   its children are processed.
/// * `$f_child`: closure handed to `map_children` to rewrite each child.
/// * `$f_up`: closure handed to `transform_parent` once the children are done.
///
/// Whether the children and parent steps actually run is decided by
/// `Transformed::transform_children` / `Transformed::transform_parent`.
/// Errors from the first two steps are propagated with `?`, so the macro can
/// only be expanded inside a function that returns a `Result`.
macro_rules! handle_transform_recursion {
    ($f_down:expr, $f_child:expr, $f_up:expr) => {{
        let after_down = $f_down?;
        let after_children =
            after_down.transform_children(|node| node.map_children($f_child))?;
        after_children.transform_parent($f_up)
    }};
}
33
34/// API for inspecting and rewriting tree data structures.
35///
36/// The `TreeNode` API is used to express algorithms separately from traversing
37/// the structure of `TreeNode`s, avoiding substantial code duplication.
38///
39/// This trait is implemented for plans ([`ExecutionPlan`], [`LogicalPlan`]) and
40/// expression trees ([`PhysicalExpr`], [`Expr`]) as well as Plan+Payload
41/// combinations [`PlanContext`] and [`ExprContext`].
42///
43/// # Overview
44/// There are three categories of TreeNode APIs:
45///
46/// 1. "Inspecting" APIs to traverse a tree of `&TreeNodes`:
47///    [`apply`], [`visit`], [`exists`].
48///
49/// 2. "Transforming" APIs that traverse and consume a tree of `TreeNode`s
50///    producing possibly changed `TreeNode`s: [`transform`], [`transform_up`],
51///    [`transform_down`], [`transform_down_up`], and [`rewrite`].
52///
53/// 3. Internal APIs used to implement the `TreeNode` API: [`apply_children`],
54///    and [`map_children`].
55///
56/// | Traversal Order | Inspecting | Transforming |
57/// | --- | --- | --- |
58/// | top-down | [`apply`], [`exists`] | [`transform_down`]|
59/// | bottom-up | | [`transform`] , [`transform_up`]|
60/// | combined with separate `f_down` and `f_up` closures | | [`transform_down_up`] |
61/// | combined with `f_down()` and `f_up()` in an object | [`visit`]  | [`rewrite`] |
62///
63/// **Note**:while there is currently no in-place mutation API that uses `&mut
64/// TreeNode`, the transforming APIs are efficient and optimized to avoid
65/// cloning.
66///
67/// [`apply`]: Self::apply
68/// [`visit`]: Self::visit
69/// [`exists`]: Self::exists
70/// [`transform`]: Self::transform
71/// [`transform_up`]: Self::transform_up
72/// [`transform_down`]: Self::transform_down
73/// [`transform_down_up`]: Self::transform_down_up
74/// [`rewrite`]: Self::rewrite
75/// [`apply_children`]: Self::apply_children
76/// [`map_children`]: Self::map_children
77///
78/// # Terminology
79/// The following terms are used in this trait
80///
81/// * `f_down`: Invoked before any children of the current node are visited.
82/// * `f_up`: Invoked after all children of the current node are visited.
83/// * `f`: closure that is applied to the current node.
84/// * `map_*`: applies a transformation to rewrite owned nodes
85/// * `apply_*`:  invokes a function on borrowed nodes
86/// * `transform_`: applies a transformation to rewrite owned nodes
87///
88/// <!-- Since these are in the datafusion-common crate, can't use intra doc links) -->
89/// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html
90/// [`PhysicalExpr`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.PhysicalExpr.html
91/// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html
92/// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html
93/// [`PlanContext`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/tree_node/struct.PlanContext.html
94/// [`ExprContext`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/tree_node/struct.ExprContext.html
95pub trait TreeNode: Sized {
96    /// Visit the tree node with a [`TreeNodeVisitor`], performing a
97    /// depth-first walk of the node and its children.
98    ///
99    /// [`TreeNodeVisitor::f_down()`] is called in top-down order (before
100    /// children are visited), [`TreeNodeVisitor::f_up()`] is called in
101    /// bottom-up order (after children are visited).
102    ///
103    /// # Return Value
104    /// Specifies how the tree walk ended. See [`TreeNodeRecursion`] for details.
105    ///
106    /// # See Also:
107    /// * [`Self::apply`] for inspecting nodes with a closure
108    /// * [`Self::rewrite`] to rewrite owned `TreeNode`s
109    ///
110    /// # Example
111    /// Consider the following tree structure:
112    /// ```text
113    /// ParentNode
114    ///    left: ChildNode1
115    ///    right: ChildNode2
116    /// ```
117    ///
118    /// Here, the nodes would be visited using the following order:
119    /// ```text
120    /// TreeNodeVisitor::f_down(ParentNode)
121    /// TreeNodeVisitor::f_down(ChildNode1)
122    /// TreeNodeVisitor::f_up(ChildNode1)
123    /// TreeNodeVisitor::f_down(ChildNode2)
124    /// TreeNodeVisitor::f_up(ChildNode2)
125    /// TreeNodeVisitor::f_up(ParentNode)
126    /// ```
127    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
128    fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
129        &'n self,
130        visitor: &mut V,
131    ) -> Result<TreeNodeRecursion> {
132        visitor
133            .f_down(self)?
134            .visit_children(|| self.apply_children(|c| c.visit(visitor)))?
135            .visit_parent(|| visitor.f_up(self))
136    }
137
138    /// Rewrite the tree node with a [`TreeNodeRewriter`], performing a
139    /// depth-first walk of the node and its children.
140    ///
141    /// [`TreeNodeRewriter::f_down()`] is called in top-down order (before
142    /// children are visited), [`TreeNodeRewriter::f_up()`] is called in
143    /// bottom-up order (after children are visited).
144    ///
145    /// Note: If using the default [`TreeNodeRewriter::f_up`] or
146    /// [`TreeNodeRewriter::f_down`] that do nothing, consider using
147    /// [`Self::transform_down`] instead.
148    ///
149    /// # Return Value
150    /// The returns value specifies how the tree walk should proceed. See
151    /// [`TreeNodeRecursion`] for details. If an [`Err`] is returned, the
152    /// recursion stops immediately.
153    ///
154    /// # See Also
155    /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
156    /// * [Self::transform_down_up] for a top-down (pre-order) traversal.
157    /// * [Self::transform_down] for a top-down (pre-order) traversal.
158    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
159    ///
160    /// # Example
161    /// Consider the following tree structure:
162    /// ```text
163    /// ParentNode
164    ///    left: ChildNode1
165    ///    right: ChildNode2
166    /// ```
167    ///
168    /// Here, the nodes would be visited using the following order:
169    /// ```text
170    /// TreeNodeRewriter::f_down(ParentNode)
171    /// TreeNodeRewriter::f_down(ChildNode1)
172    /// TreeNodeRewriter::f_up(ChildNode1)
173    /// TreeNodeRewriter::f_down(ChildNode2)
174    /// TreeNodeRewriter::f_up(ChildNode2)
175    /// TreeNodeRewriter::f_up(ParentNode)
176    /// ```
177    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
178    fn rewrite<R: TreeNodeRewriter<Node = Self>>(
179        self,
180        rewriter: &mut R,
181    ) -> Result<Transformed<Self>> {
182        handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| {
183            rewriter.f_up(n)
184        })
185    }
186
187    /// Applies `f` to the node then each of its children, recursively (a
188    /// top-down, pre-order traversal).
189    ///
190    /// The return [`TreeNodeRecursion`] controls the recursion and can cause
191    /// an early return.
192    ///
193    /// # See Also
194    /// * [`Self::transform_down`] for the equivalent transformation API.
195    /// * [`Self::visit`] for both top-down and bottom up traversal.
196    fn apply<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
197        &'n self,
198        mut f: F,
199    ) -> Result<TreeNodeRecursion> {
200        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
201        fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
202            node: &'n N,
203            f: &mut F,
204        ) -> Result<TreeNodeRecursion> {
205            f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
206        }
207
208        apply_impl(self, &mut f)
209    }
210
211    /// Recursively rewrite the node's children and then the node using `f`
212    /// (a bottom-up post-order traversal).
213    ///
214    /// A synonym of [`Self::transform_up`].
215    fn transform<F: FnMut(Self) -> Result<Transformed<Self>>>(
216        self,
217        f: F,
218    ) -> Result<Transformed<Self>> {
219        self.transform_up(f)
220    }
221
222    /// Recursively rewrite the tree using `f` in a top-down (pre-order)
223    /// fashion.
224    ///
225    /// `f` is applied to the node first, and then its children.
226    ///
227    /// # See Also
228    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
229    /// * [Self::transform_down_up] for a combined traversal with closures
230    /// * [`Self::rewrite`] for a combined traversal with a visitor
231    fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
232        self,
233        mut f: F,
234    ) -> Result<Transformed<Self>> {
235        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
236        fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
237            node: N,
238            f: &mut F,
239        ) -> Result<Transformed<N>> {
240            f(node)?.transform_children(|n| n.map_children(|c| transform_down_impl(c, f)))
241        }
242
243        transform_down_impl(self, &mut f)
244    }
245
246    /// Recursively rewrite the node using `f` in a bottom-up (post-order)
247    /// fashion.
248    ///
249    /// `f` is applied to the node's  children first, and then to the node itself.
250    ///
251    /// # See Also
252    /// * [`Self::transform_down`] top-down (pre-order) traversal.
253    /// * [Self::transform_down_up] for a combined traversal with closures
254    /// * [`Self::rewrite`] for a combined traversal with a visitor
255    fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
256        self,
257        mut f: F,
258    ) -> Result<Transformed<Self>> {
259        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
260        fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
261            node: N,
262            f: &mut F,
263        ) -> Result<Transformed<N>> {
264            node.map_children(|c| transform_up_impl(c, f))?
265                .transform_parent(f)
266        }
267
268        transform_up_impl(self, &mut f)
269    }
270
271    /// Transforms the node using `f_down` while traversing the tree top-down
272    /// (pre-order), and using `f_up` while traversing the tree bottom-up
273    /// (post-order).
274    ///
275    /// The method behaves the same as calling [`Self::transform_down`] followed
276    /// by [`Self::transform_up`] on the same node. Use this method if you want
277    /// to start the `f_up` process right where `f_down` jumps. This can make
278    /// the whole process faster by reducing the number of `f_up` steps.
279    ///
280    /// # See Also
281    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
282    /// * [Self::transform_down] for a top-down (pre-order) traversal.
283    /// * [`Self::rewrite`] for a combined traversal with a visitor
284    ///
285    /// # Example
286    /// Consider the following tree structure:
287    /// ```text
288    /// ParentNode
289    ///    left: ChildNode1
290    ///    right: ChildNode2
291    /// ```
292    ///
293    /// The nodes are visited using the following order:
294    /// ```text
295    /// f_down(ParentNode)
296    /// f_down(ChildNode1)
297    /// f_up(ChildNode1)
298    /// f_down(ChildNode2)
299    /// f_up(ChildNode2)
300    /// f_up(ParentNode)
301    /// ```
302    ///
303    /// See [`TreeNodeRecursion`] for more details on controlling the traversal.
304    ///
305    /// If `f_down` or `f_up` returns [`Err`], the recursion stops immediately.
306    ///
307    /// Example:
308    /// ```text
309    ///                                               |   +---+
310    ///                                               |   | J |
311    ///                                               |   +---+
312    ///                                               |     |
313    ///                                               |   +---+
314    ///                  TreeNodeRecursion::Continue  |   | I |
315    ///                                               |   +---+
316    ///                                               |     |
317    ///                                               |   +---+
318    ///                                              \|/  | F |
319    ///                                               '   +---+
320    ///                                                  /     \ ___________________
321    ///                  When `f_down` is           +---+                           \ ---+
322    ///                  applied on node "E",       | E |                            | G |
323    ///                  it returns with "Jump".    +---+                            +---+
324    ///                                               |                                |
325    ///                                             +---+                            +---+
326    ///                                             | C |                            | H |
327    ///                                             +---+                            +---+
328    ///                                             /   \
329    ///                                        +---+     +---+
330    ///                                        | B |     | D |
331    ///                                        +---+     +---+
332    ///                                                    |
333    ///                                                  +---+
334    ///                                                  | A |
335    ///                                                  +---+
336    ///
337    /// Instead of starting from leaf nodes, `f_up` starts from the node "E".
338    ///                                                   +---+
339    ///                                               |   | J |
340    ///                                               |   +---+
341    ///                                               |     |
342    ///                                               |   +---+
343    ///                                               |   | I |
344    ///                                               |   +---+
345    ///                                               |     |
346    ///                                              /    +---+
347    ///                                            /      | F |
348    ///                                          /        +---+
349    ///                                        /         /     \ ______________________
350    ///                                       |     +---+   .                          \ ---+
351    ///                                       |     | E |  /|\  After `f_down` jumps    | G |
352    ///                                       |     +---+   |   on node E, `f_up`       +---+
353    ///                                        \------| ---/   if applied on node E.      |
354    ///                                             +---+                               +---+
355    ///                                             | C |                               | H |
356    ///                                             +---+                               +---+
357    ///                                             /   \
358    ///                                        +---+     +---+
359    ///                                        | B |     | D |
360    ///                                        +---+     +---+
361    ///                                                    |
362    ///                                                  +---+
363    ///                                                  | A |
364    ///                                                  +---+
365    /// ```
366    fn transform_down_up<
367        FD: FnMut(Self) -> Result<Transformed<Self>>,
368        FU: FnMut(Self) -> Result<Transformed<Self>>,
369    >(
370        self,
371        mut f_down: FD,
372        mut f_up: FU,
373    ) -> Result<Transformed<Self>> {
374        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
375        fn transform_down_up_impl<
376            N: TreeNode,
377            FD: FnMut(N) -> Result<Transformed<N>>,
378            FU: FnMut(N) -> Result<Transformed<N>>,
379        >(
380            node: N,
381            f_down: &mut FD,
382            f_up: &mut FU,
383        ) -> Result<Transformed<N>> {
384            handle_transform_recursion!(
385                f_down(node),
386                |c| transform_down_up_impl(c, f_down, f_up),
387                f_up
388            )
389        }
390
391        transform_down_up_impl(self, &mut f_down, &mut f_up)
392    }
393
394    /// Returns true if `f` returns true for any node in the tree.
395    ///
396    /// Stops recursion as soon as a matching node is found
397    fn exists<F: FnMut(&Self) -> Result<bool>>(&self, mut f: F) -> Result<bool> {
398        let mut found = false;
399        self.apply(|n| {
400            Ok(if f(n)? {
401                found = true;
402                TreeNodeRecursion::Stop
403            } else {
404                TreeNodeRecursion::Continue
405            })
406        })
407        .map(|_| found)
408    }
409
410    /// Low-level API used to implement other APIs.
411    ///
412    /// If you want to implement the [`TreeNode`] trait for your own type, you
413    /// should implement this method and [`Self::map_children`].
414    ///
415    /// Users should use one of the higher level APIs described on [`Self`].
416    ///
417    /// Description: Apply `f` to inspect node's children (but not the node
418    /// itself).
419    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
420        &'n self,
421        f: F,
422    ) -> Result<TreeNodeRecursion>;
423
424    /// Low-level API used to implement other APIs.
425    ///
426    /// If you want to implement the [`TreeNode`] trait for your own type, you
427    /// should implement this method and [`Self::apply_children`].
428    ///
429    /// Users should use one of the higher level APIs described on [`Self`].
430    ///
431    /// Description: Apply `f` to rewrite the node's children (but not the node itself).
432    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
433        self,
434        f: F,
435    ) -> Result<Transformed<Self>>;
436}
437
438/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
439/// inspecting [`TreeNode`]s via [`TreeNode::visit`].
440///
441/// See [`TreeNode`] for more details on available APIs
442///
443/// When passed to [`TreeNode::visit`], [`TreeNodeVisitor::f_down`] and
444/// [`TreeNodeVisitor::f_up`] are invoked recursively on the tree.
445/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
446///
447/// # Return Value
448/// The returns value of `f_up` and `f_down` specifies how the tree walk should
449/// proceed. See [`TreeNodeRecursion`] for details. If an [`Err`] is returned,
450/// the recursion stops immediately.
451///
452/// Note: If using the default implementations of [`TreeNodeVisitor::f_up`] or
453/// [`TreeNodeVisitor::f_down`] that do nothing, consider using
454/// [`TreeNode::apply`] instead.
455///
456/// # See Also:
457/// * [`TreeNode::rewrite`] to rewrite owned `TreeNode`s
458pub trait TreeNodeVisitor<'n>: Sized {
459    /// The node type which is visitable.
460    type Node: TreeNode;
461
462    /// Invoked while traversing down the tree, before any children are visited.
463    /// Default implementation continues the recursion.
464    fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
465        Ok(TreeNodeRecursion::Continue)
466    }
467
468    /// Invoked while traversing up the tree after children are visited. Default
469    /// implementation continues the recursion.
470    fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
471        Ok(TreeNodeRecursion::Continue)
472    }
473}
474
475/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
476/// rewriting [`TreeNode`]s via [`TreeNode::rewrite`].
477///
478/// For example you can implement this trait on a struct to rewrite `Expr` or
479/// `LogicalPlan` that needs to track state during the rewrite.
480///
481/// See [`TreeNode`] for more details on available APIs
482///
483/// When passed to [`TreeNode::rewrite`], [`TreeNodeRewriter::f_down`] and
484/// [`TreeNodeRewriter::f_up`] are invoked recursively on the tree.
485/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
486///
487/// # Return Value
488/// The returns value of `f_up` and `f_down` specifies how the tree walk should
489/// proceed. See [`TreeNodeRecursion`] for details. If an [`Err`] is returned,
490/// the recursion stops immediately.
491///
492/// Note: If using the default implementations of [`TreeNodeRewriter::f_up`] or
493/// [`TreeNodeRewriter::f_down`] that do nothing, consider using
494/// [`TreeNode::transform_up`] or [`TreeNode::transform_down`] instead.
495///
496/// # See Also:
497/// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s
498pub trait TreeNodeRewriter: Sized {
499    /// The node type which is rewritable.
500    type Node: TreeNode;
501
502    /// Invoked while traversing down the tree before any children are rewritten.
503    /// Default implementation returns the node as is and continues recursion.
504    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
505        Ok(Transformed::no(node))
506    }
507
508    /// Invoked while traversing up the tree after all children have been rewritten.
509    /// Default implementation returns the node as is and continues recursion.
510    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
511        Ok(Transformed::no(node))
512    }
513}
514
/// Controls how [`TreeNode`] recursions should proceed.
///
/// Returned by traversal closures and visitors (directly, or as the `tnr`
/// field of a [`Transformed`]) to tell the traversal what to do next.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum TreeNodeRecursion {
    /// Continue recursion with the next node.
    Continue,
    /// In top-down traversals, skip recursing into children but continue with
    /// the next node, which effectively prunes the subtree.
    ///
    /// In bottom-up traversals, bypass calling bottom-up closures till the next
    /// leaf node.
    ///
    /// In combined traversals, if it is the `f_down` (pre-order) phase, execution
    /// "jumps" to the next `f_up` (post-order) phase by shortcutting its children.
    /// If it is the `f_up` (post-order) phase, execution "jumps" to the next `f_down`
    /// (pre-order) phase by shortcutting its parent nodes until the first parent node
    /// having an unvisited child path.
    Jump,
    /// Stop recursion.
    Stop,
}
535
536impl TreeNodeRecursion {
537    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
538    /// value and the fact that `f` is visiting the current node's children.
539    pub fn visit_children<F: FnOnce() -> Result<TreeNodeRecursion>>(
540        self,
541        f: F,
542    ) -> Result<TreeNodeRecursion> {
543        match self {
544            TreeNodeRecursion::Continue => f(),
545            TreeNodeRecursion::Jump => Ok(TreeNodeRecursion::Continue),
546            TreeNodeRecursion::Stop => Ok(self),
547        }
548    }
549
550    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
551    /// value and the fact that `f` is visiting the current node's sibling.
552    pub fn visit_sibling<F: FnOnce() -> Result<TreeNodeRecursion>>(
553        self,
554        f: F,
555    ) -> Result<TreeNodeRecursion> {
556        match self {
557            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => f(),
558            TreeNodeRecursion::Stop => Ok(self),
559        }
560    }
561
562    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
563    /// value and the fact that `f` is visiting the current node's parent.
564    pub fn visit_parent<F: FnOnce() -> Result<TreeNodeRecursion>>(
565        self,
566        f: F,
567    ) -> Result<TreeNodeRecursion> {
568        match self {
569            TreeNodeRecursion::Continue => f(),
570            TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self),
571        }
572    }
573}
574
575/// Result of tree walk / transformation APIs
576///
577/// `Transformed` is a wrapper around the tree node data (e.g. `Expr` or
578/// `LogicalPlan`). It is used to indicate whether the node was transformed
579/// and how the recursion should proceed.
580///
581/// [`TreeNode`] API users control the transformation by returning:
582/// - The resulting (possibly transformed) node,
583/// - `transformed`: flag indicating whether any change was made to the node
584/// - `tnr`: [`TreeNodeRecursion`] specifying how to proceed with the recursion.
585///
586/// At the end of the transformation, the return value will contain:
587/// - The final (possibly transformed) tree,
588/// - `transformed`: flag indicating whether any change was made to the node
589/// - `tnr`: [`TreeNodeRecursion`] specifying how the recursion ended.
590///
591/// See also
592/// * [`Transformed::update_data`] to modify the node without changing the `transformed` flag
593/// * [`Transformed::map_data`] for fallable operation that return the same type
594/// * [`Transformed::transform_data`] to chain fallable transformations
595/// * [`TransformedResult`] for working with `Result<Transformed<U>>`
596///
597/// # Examples
598///
599/// Use [`Transformed::yes`] and [`Transformed::no`] to signal that a node was
600/// rewritten and the recursion should continue:
601///
602/// ```
603/// # use datafusion_common::tree_node::Transformed;
604/// # // note use i64 instead of Expr as Expr is not in datafusion-common
605/// # fn orig_expr() -> i64 { 1 }
606/// # fn make_new_expr(i: i64) -> i64 { 2 }
607/// let expr = orig_expr();
608///
609/// // Create a new `Transformed` object signaling the node was not rewritten
610/// let ret = Transformed::no(expr.clone());
611/// assert!(!ret.transformed);
612///
613/// // Create a new `Transformed` object signaling the node was rewritten
614/// let ret = Transformed::yes(expr);
615/// assert!(ret.transformed)
616/// ```
617///
618/// Access the node within the `Transformed` object:
619/// ```
620/// # use datafusion_common::tree_node::Transformed;
621/// # // note use i64 instead of Expr as Expr is not in datafusion-common
622/// # fn orig_expr() -> i64 { 1 }
623/// # fn make_new_expr(i: i64) -> i64 { 2 }
624/// let expr = orig_expr();
625///
626/// // `Transformed` object signaling the node was not rewritten
627/// let ret = Transformed::no(expr.clone());
628/// // Access the inner object using .data
629/// assert_eq!(expr, ret.data);
630/// ```
631///
632/// Transform the node within the `Transformed` object.
633///
634/// ```
635/// # use datafusion_common::tree_node::Transformed;
636/// # // note use i64 instead of Expr as Expr is not in datafusion-common
637/// # fn orig_expr() -> i64 { 1 }
638/// # fn make_new_expr(i: i64) -> i64 { 2 }
639/// let expr = orig_expr();
640/// let ret = Transformed::no(expr.clone())
641///     .transform_data(|expr| {
642///         // closure returns a result and potentially transforms the node
643///         // in this example, it does transform the node
644///         let new_expr = make_new_expr(expr);
645///         Ok(Transformed::yes(new_expr))
646///     })
647///     .unwrap();
648/// // transformed flag is the union of the original ans closure's  transformed flag
649/// assert!(ret.transformed);
650/// ```
651/// # Example APIs that use `TreeNode`
652/// - [`TreeNode`],
653/// - [`TreeNode::rewrite`],
654/// - [`TreeNode::transform_down`],
655/// - [`TreeNode::transform_up`],
656/// - [`TreeNode::transform_down_up`]
657#[derive(PartialEq, Debug)]
658pub struct Transformed<T> {
659    pub data: T,
660    pub transformed: bool,
661    pub tnr: TreeNodeRecursion,
662}
663
664impl<T> Transformed<T> {
665    /// Create a new `Transformed` object with the given information.
666    pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self {
667        Self {
668            data,
669            transformed,
670            tnr,
671        }
672    }
673
674    /// Create a `Transformed` with `transformed` and [`TreeNodeRecursion::Continue`].
675    pub fn new_transformed(data: T, transformed: bool) -> Self {
676        Self::new(data, transformed, TreeNodeRecursion::Continue)
677    }
678
679    /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement.
680    pub fn yes(data: T) -> Self {
681        Self::new(data, true, TreeNodeRecursion::Continue)
682    }
683
684    /// Wrapper for transformed data with [`TreeNodeRecursion::Stop`] statement.
685    pub fn complete(data: T) -> Self {
686        Self::new(data, true, TreeNodeRecursion::Stop)
687    }
688
689    /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement.
690    pub fn no(data: T) -> Self {
691        Self::new(data, false, TreeNodeRecursion::Continue)
692    }
693
694    /// Applies an infallible `f` to the data of this [`Transformed`] object,
695    /// without modifying the `transformed` flag.
696    pub fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
697        Transformed::new(f(self.data), self.transformed, self.tnr)
698    }
699
700    /// Applies a fallible `f` (returns `Result`) to the data of this
701    /// [`Transformed`] object, without modifying the `transformed` flag.
702    pub fn map_data<U, F: FnOnce(T) -> Result<U>>(self, f: F) -> Result<Transformed<U>> {
703        f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
704    }
705
706    /// Applies a fallible transforming `f` to the data of this [`Transformed`]
707    /// object.
708    ///
709    /// The returned `Transformed` object has the `transformed` flag set if either
710    /// `self` or the return value of `f` have the `transformed` flag set.
711    pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>(
712        self,
713        f: F,
714    ) -> Result<Transformed<U>> {
715        f(self.data).map(|mut t| {
716            t.transformed |= self.transformed;
717            t
718        })
719    }
720
721    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
722    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
723    /// node's children.
724    pub fn transform_children<F: FnOnce(T) -> Result<Transformed<T>>>(
725        mut self,
726        f: F,
727    ) -> Result<Transformed<T>> {
728        match self.tnr {
729            TreeNodeRecursion::Continue => {
730                return f(self.data).map(|mut t| {
731                    t.transformed |= self.transformed;
732                    t
733                });
734            }
735            TreeNodeRecursion::Jump => {
736                self.tnr = TreeNodeRecursion::Continue;
737            }
738            TreeNodeRecursion::Stop => {}
739        }
740        Ok(self)
741    }
742
743    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
744    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
745    /// node's sibling.
746    pub fn transform_sibling<F: FnOnce(T) -> Result<Transformed<T>>>(
747        self,
748        f: F,
749    ) -> Result<Transformed<T>> {
750        match self.tnr {
751            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
752                f(self.data).map(|mut t| {
753                    t.transformed |= self.transformed;
754                    t
755                })
756            }
757            TreeNodeRecursion::Stop => Ok(self),
758        }
759    }
760
761    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
762    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
763    /// node's parent.
764    pub fn transform_parent<F: FnOnce(T) -> Result<Transformed<T>>>(
765        self,
766        f: F,
767    ) -> Result<Transformed<T>> {
768        match self.tnr {
769            TreeNodeRecursion::Continue => f(self.data).map(|mut t| {
770                t.transformed |= self.transformed;
771                t
772            }),
773            TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self),
774        }
775    }
776}
777
778/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped.
779/// The elements of the container are siblings so the continuation rules are similar to
780/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`].
781pub trait TreeNodeContainer<'a, T: 'a>: Sized {
782    /// Applies `f` to all elements of the container.
783    /// This method is usually called from [`TreeNode::apply_children`] implementations as
784    /// a node is actually a container of the node's children.
785    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
786        &'a self,
787        f: F,
788    ) -> Result<TreeNodeRecursion>;
789
790    /// Maps all elements of the container with `f`.
791    /// This method is usually called from [`TreeNode::map_children`] implementations as
792    /// a node is actually a container of the node's children.
793    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
794        self,
795        f: F,
796    ) -> Result<Transformed<Self>>;
797}
798
799impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box<C> {
800    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
801        &'a self,
802        f: F,
803    ) -> Result<TreeNodeRecursion> {
804        self.as_ref().apply_elements(f)
805    }
806
807    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
808        self,
809        f: F,
810    ) -> Result<Transformed<Self>> {
811        (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c)))
812    }
813}
814
815impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc<C> {
816    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
817        &'a self,
818        f: F,
819    ) -> Result<TreeNodeRecursion> {
820        self.as_ref().apply_elements(f)
821    }
822
823    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
824        self,
825        f: F,
826    ) -> Result<Transformed<Self>> {
827        Arc::unwrap_or_clone(self)
828            .map_elements(f)?
829            .map_data(|c| Ok(Arc::new(c)))
830    }
831}
832
833impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option<C> {
834    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
835        &'a self,
836        f: F,
837    ) -> Result<TreeNodeRecursion> {
838        match self {
839            Some(t) => t.apply_elements(f),
840            None => Ok(TreeNodeRecursion::Continue),
841        }
842    }
843
844    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
845        self,
846        f: F,
847    ) -> Result<Transformed<Self>> {
848        self.map_or(Ok(Transformed::no(None)), |c| {
849            c.map_elements(f)?.map_data(|c| Ok(Some(c)))
850        })
851    }
852}
853
854impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec<C> {
855    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
856        &'a self,
857        mut f: F,
858    ) -> Result<TreeNodeRecursion> {
859        let mut tnr = TreeNodeRecursion::Continue;
860        for c in self {
861            tnr = c.apply_elements(&mut f)?;
862            match tnr {
863                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
864                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
865            }
866        }
867        Ok(tnr)
868    }
869
870    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
871        self,
872        mut f: F,
873    ) -> Result<Transformed<Self>> {
874        let mut tnr = TreeNodeRecursion::Continue;
875        let mut transformed = false;
876        self.into_iter()
877            .map(|c| match tnr {
878                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
879                    c.map_elements(&mut f).map(|result| {
880                        tnr = result.tnr;
881                        transformed |= result.transformed;
882                        result.data
883                    })
884                }
885                TreeNodeRecursion::Stop => Ok(c),
886            })
887            .collect::<Result<Vec<_>>>()
888            .map(|data| Transformed::new(data, transformed, tnr))
889    }
890}
891
892impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T>
893    for HashMap<K, C>
894{
895    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
896        &'a self,
897        mut f: F,
898    ) -> Result<TreeNodeRecursion> {
899        let mut tnr = TreeNodeRecursion::Continue;
900        for c in self.values() {
901            tnr = c.apply_elements(&mut f)?;
902            match tnr {
903                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
904                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
905            }
906        }
907        Ok(tnr)
908    }
909
910    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
911        self,
912        mut f: F,
913    ) -> Result<Transformed<Self>> {
914        let mut tnr = TreeNodeRecursion::Continue;
915        let mut transformed = false;
916        self.into_iter()
917            .map(|(k, c)| match tnr {
918                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
919                    c.map_elements(&mut f).map(|result| {
920                        tnr = result.tnr;
921                        transformed |= result.transformed;
922                        (k, result.data)
923                    })
924                }
925                TreeNodeRecursion::Stop => Ok((k, c)),
926            })
927            .collect::<Result<HashMap<_, _>>>()
928            .map(|data| Transformed::new(data, transformed, tnr))
929    }
930}
931
932impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
933    TreeNodeContainer<'a, T> for (C0, C1)
934{
935    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
936        &'a self,
937        mut f: F,
938    ) -> Result<TreeNodeRecursion> {
939        self.0
940            .apply_elements(&mut f)?
941            .visit_sibling(|| self.1.apply_elements(&mut f))
942    }
943
944    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
945        self,
946        mut f: F,
947    ) -> Result<Transformed<Self>> {
948        self.0
949            .map_elements(&mut f)?
950            .map_data(|new_c0| Ok((new_c0, self.1)))?
951            .transform_sibling(|(new_c0, c1)| {
952                c1.map_elements(&mut f)?
953                    .map_data(|new_c1| Ok((new_c0, new_c1)))
954            })
955    }
956}
957
958impl<
959        'a,
960        T: 'a,
961        C0: TreeNodeContainer<'a, T>,
962        C1: TreeNodeContainer<'a, T>,
963        C2: TreeNodeContainer<'a, T>,
964    > TreeNodeContainer<'a, T> for (C0, C1, C2)
965{
966    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
967        &'a self,
968        mut f: F,
969    ) -> Result<TreeNodeRecursion> {
970        self.0
971            .apply_elements(&mut f)?
972            .visit_sibling(|| self.1.apply_elements(&mut f))?
973            .visit_sibling(|| self.2.apply_elements(&mut f))
974    }
975
976    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
977        self,
978        mut f: F,
979    ) -> Result<Transformed<Self>> {
980        self.0
981            .map_elements(&mut f)?
982            .map_data(|new_c0| Ok((new_c0, self.1, self.2)))?
983            .transform_sibling(|(new_c0, c1, c2)| {
984                c1.map_elements(&mut f)?
985                    .map_data(|new_c1| Ok((new_c0, new_c1, c2)))
986            })?
987            .transform_sibling(|(new_c0, new_c1, c2)| {
988                c2.map_elements(&mut f)?
989                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2)))
990            })
991    }
992}
993
994impl<
995        'a,
996        T: 'a,
997        C0: TreeNodeContainer<'a, T>,
998        C1: TreeNodeContainer<'a, T>,
999        C2: TreeNodeContainer<'a, T>,
1000        C3: TreeNodeContainer<'a, T>,
1001    > TreeNodeContainer<'a, T> for (C0, C1, C2, C3)
1002{
1003    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1004        &'a self,
1005        mut f: F,
1006    ) -> Result<TreeNodeRecursion> {
1007        self.0
1008            .apply_elements(&mut f)?
1009            .visit_sibling(|| self.1.apply_elements(&mut f))?
1010            .visit_sibling(|| self.2.apply_elements(&mut f))?
1011            .visit_sibling(|| self.3.apply_elements(&mut f))
1012    }
1013
1014    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
1015        self,
1016        mut f: F,
1017    ) -> Result<Transformed<Self>> {
1018        self.0
1019            .map_elements(&mut f)?
1020            .map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))?
1021            .transform_sibling(|(new_c0, c1, c2, c3)| {
1022                c1.map_elements(&mut f)?
1023                    .map_data(|new_c1| Ok((new_c0, new_c1, c2, c3)))
1024            })?
1025            .transform_sibling(|(new_c0, new_c1, c2, c3)| {
1026                c2.map_elements(&mut f)?
1027                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3)))
1028            })?
1029            .transform_sibling(|(new_c0, new_c1, new_c2, c3)| {
1030                c3.map_elements(&mut f)?
1031                    .map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3)))
1032            })
1033    }
1034}
1035
1036/// [`TreeNodeRefContainer`] contains references to elements that a function can be
1037/// applied on. The elements of the container are siblings so the continuation rules are
1038/// similar to [`TreeNodeRecursion::visit_sibling`].
1039///
1040/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference
1041/// elements (`T`) are not derived from the container's lifetime.
1042/// A typical usage of this container is in `Expr::apply_children` when we need to
1043/// construct a temporary container to be able to call `apply_ref_elements` on a
1044/// collection of tree node references. But in that case the container's temporary
1045/// lifetime is different to the lifetime of tree nodes that we put into it.
1046/// Please find an example use case in `Expr::apply_children` with the `Expr::Case` case.
1047///
1048/// Most of the cases we don't need to create a temporary container with
1049/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`.
1050/// Please find an example use case in `Expr::apply_children` with the `Expr::GroupingSet`
1051/// case.
1052pub trait TreeNodeRefContainer<'a, T: 'a>: Sized {
1053    /// Applies `f` to all elements of the container.
1054    /// This method is usually called from [`TreeNode::apply_children`] implementations as
1055    /// a node is actually a container of the node's children.
1056    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1057        &self,
1058        f: F,
1059    ) -> Result<TreeNodeRecursion>;
1060}
1061
1062impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> {
1063    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1064        &self,
1065        mut f: F,
1066    ) -> Result<TreeNodeRecursion> {
1067        let mut tnr = TreeNodeRecursion::Continue;
1068        for c in self {
1069            tnr = c.apply_elements(&mut f)?;
1070            match tnr {
1071                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
1072                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
1073            }
1074        }
1075        Ok(tnr)
1076    }
1077}
1078
1079impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
1080    TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1)
1081{
1082    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1083        &self,
1084        mut f: F,
1085    ) -> Result<TreeNodeRecursion> {
1086        self.0
1087            .apply_elements(&mut f)?
1088            .visit_sibling(|| self.1.apply_elements(&mut f))
1089    }
1090}
1091
1092impl<
1093        'a,
1094        T: 'a,
1095        C0: TreeNodeContainer<'a, T>,
1096        C1: TreeNodeContainer<'a, T>,
1097        C2: TreeNodeContainer<'a, T>,
1098    > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2)
1099{
1100    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1101        &self,
1102        mut f: F,
1103    ) -> Result<TreeNodeRecursion> {
1104        self.0
1105            .apply_elements(&mut f)?
1106            .visit_sibling(|| self.1.apply_elements(&mut f))?
1107            .visit_sibling(|| self.2.apply_elements(&mut f))
1108    }
1109}
1110
1111impl<
1112        'a,
1113        T: 'a,
1114        C0: TreeNodeContainer<'a, T>,
1115        C1: TreeNodeContainer<'a, T>,
1116        C2: TreeNodeContainer<'a, T>,
1117        C3: TreeNodeContainer<'a, T>,
1118    > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3)
1119{
1120    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1121        &self,
1122        mut f: F,
1123    ) -> Result<TreeNodeRecursion> {
1124        self.0
1125            .apply_elements(&mut f)?
1126            .visit_sibling(|| self.1.apply_elements(&mut f))?
1127            .visit_sibling(|| self.2.apply_elements(&mut f))?
1128            .visit_sibling(|| self.3.apply_elements(&mut f))
1129    }
1130}
1131
1132/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
1133pub trait TreeNodeIterator: Iterator {
1134    /// Apples `f` to each item in this iterator
1135    ///
1136    /// Visits all items in the iterator unless
1137    /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`.
1138    ///
1139    /// # Returns
1140    /// Error if `f` returns an error or `Ok(TreeNodeRecursion)` from the last invocation
1141    /// of `f` or `Continue` if the iterator is empty
1142    fn apply_until_stop<F: FnMut(Self::Item) -> Result<TreeNodeRecursion>>(
1143        self,
1144        f: F,
1145    ) -> Result<TreeNodeRecursion>;
1146
1147    /// Apples `f` to each item in this iterator
1148    ///
1149    /// Visits all items in the iterator unless
1150    /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`.
1151    ///
1152    /// # Returns
1153    /// Error if `f` returns an error
1154    ///
1155    /// Ok(Transformed) such that:
1156    /// 1. `transformed` is true if any return from `f` had transformed true
1157    /// 2. `data` from the last invocation of `f`
1158    /// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty
1159    fn map_until_stop_and_collect<
1160        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
1161    >(
1162        self,
1163        f: F,
1164    ) -> Result<Transformed<Vec<Self::Item>>>;
1165}
1166
1167impl<I: Iterator> TreeNodeIterator for I {
1168    fn apply_until_stop<F: FnMut(Self::Item) -> Result<TreeNodeRecursion>>(
1169        self,
1170        mut f: F,
1171    ) -> Result<TreeNodeRecursion> {
1172        let mut tnr = TreeNodeRecursion::Continue;
1173        for i in self {
1174            tnr = f(i)?;
1175            match tnr {
1176                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
1177                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
1178            }
1179        }
1180        Ok(tnr)
1181    }
1182
1183    fn map_until_stop_and_collect<
1184        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
1185    >(
1186        self,
1187        mut f: F,
1188    ) -> Result<Transformed<Vec<Self::Item>>> {
1189        let mut tnr = TreeNodeRecursion::Continue;
1190        let mut transformed = false;
1191        self.map(|item| match tnr {
1192            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
1193                f(item).map(|result| {
1194                    tnr = result.tnr;
1195                    transformed |= result.transformed;
1196                    result.data
1197                })
1198            }
1199            TreeNodeRecursion::Stop => Ok(item),
1200        })
1201        .collect::<Result<Vec<_>>>()
1202        .map(|data| Transformed::new(data, transformed, tnr))
1203    }
1204}
1205
1206/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
1207///
1208/// # Example
1209/// Access the internal data of a `Result<Transformed<T>>`
1210/// as a `Result<T>` using the `data` method:
1211/// ```
1212/// # use datafusion_common::Result;
1213/// # use datafusion_common::tree_node::{Transformed, TransformedResult};
1214/// # // note use i64 instead of Expr as Expr is not in datafusion-common
1215/// # fn update_expr() -> i64 { 1 }
1216/// # fn main() -> Result<()> {
1217/// let transformed: Result<Transformed<_>> = Ok(Transformed::yes(update_expr()));
1218/// // access the internal data of the transformed result, or return the error
1219/// let transformed_expr = transformed.data()?;
1220/// # Ok(())
1221/// # }
1222/// ```
1223pub trait TransformedResult<T> {
1224    fn data(self) -> Result<T>;
1225
1226    fn transformed(self) -> Result<bool>;
1227
1228    fn tnr(self) -> Result<TreeNodeRecursion>;
1229}
1230
1231impl<T> TransformedResult<T> for Result<Transformed<T>> {
1232    fn data(self) -> Result<T> {
1233        self.map(|t| t.data)
1234    }
1235
1236    fn transformed(self) -> Result<bool> {
1237        self.map(|t| t.transformed)
1238    }
1239
1240    fn tnr(self) -> Result<TreeNodeRecursion> {
1241        self.map(|t| t.tnr)
1242    }
1243}
1244
1245/// Helper trait for implementing [`TreeNode`] that have children stored as
1246/// `Arc`s. If some trait object, such as `dyn T`, implements this trait,
1247/// its related `Arc<dyn T>` will automatically implement [`TreeNode`].
1248pub trait DynTreeNode {
1249    /// Returns all children of the specified `TreeNode`.
1250    fn arc_children(&self) -> Vec<&Arc<Self>>;
1251
1252    /// Constructs a new node with the specified children.
1253    fn with_new_arc_children(
1254        &self,
1255        arc_self: Arc<Self>,
1256        new_children: Vec<Arc<Self>>,
1257    ) -> Result<Arc<Self>>;
1258}
1259
1260/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
1261/// (such as [`Arc<dyn PhysicalExpr>`]).
1262impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
1263    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1264        &'n self,
1265        f: F,
1266    ) -> Result<TreeNodeRecursion> {
1267        self.arc_children().into_iter().apply_until_stop(f)
1268    }
1269
1270    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1271        self,
1272        f: F,
1273    ) -> Result<Transformed<Self>> {
1274        let children = self.arc_children();
1275        if !children.is_empty() {
1276            let new_children = children
1277                .into_iter()
1278                .cloned()
1279                .map_until_stop_and_collect(f)?;
1280            // Propagate up `new_children.transformed` and `new_children.tnr`
1281            // along with the node containing transformed children.
1282            if new_children.transformed {
1283                let arc_self = Arc::clone(&self);
1284                new_children.map_data(|new_children| {
1285                    self.with_new_arc_children(arc_self, new_children)
1286                })
1287            } else {
1288                Ok(Transformed::new(self, false, new_children.tnr))
1289            }
1290        } else {
1291            Ok(Transformed::no(self))
1292        }
1293    }
1294}
1295
1296/// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for
1297/// trees that contain nodes with payloads. This approach ensures safe execution of algorithms
1298/// involving payloads, by enforcing rules for detaching and reattaching child nodes.
1299pub trait ConcreteTreeNode: Sized {
1300    /// Provides read-only access to child nodes.
1301    fn children(&self) -> &[Self];
1302
1303    /// Detaches the node from its children, returning the node itself and its detached children.
1304    fn take_children(self) -> (Self, Vec<Self>);
1305
1306    /// Reattaches updated child nodes to the node, returning the updated node.
1307    fn with_new_children(self, children: Vec<Self>) -> Result<Self>;
1308}
1309
1310impl<T: ConcreteTreeNode> TreeNode for T {
1311    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1312        &'n self,
1313        f: F,
1314    ) -> Result<TreeNodeRecursion> {
1315        self.children().iter().apply_until_stop(f)
1316    }
1317
1318    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1319        self,
1320        f: F,
1321    ) -> Result<Transformed<Self>> {
1322        let (new_self, children) = self.take_children();
1323        if !children.is_empty() {
1324            let new_children = children.into_iter().map_until_stop_and_collect(f)?;
1325            // Propagate up `new_children.transformed` and `new_children.tnr` along with
1326            // the node containing transformed children.
1327            new_children.map_data(|new_children| new_self.with_new_children(new_children))
1328        } else {
1329            Ok(Transformed::no(new_self))
1330        }
1331    }
1332}
1333
1334#[cfg(test)]
1335pub(crate) mod tests {
1336    use std::collections::HashMap;
1337    use std::fmt::Display;
1338
1339    use crate::tree_node::{
1340        Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter,
1341        TreeNodeVisitor,
1342    };
1343    use crate::Result;
1344
1345    #[derive(Debug, Eq, Hash, PartialEq, Clone)]
1346    pub struct TestTreeNode<T> {
1347        pub(crate) children: Vec<TestTreeNode<T>>,
1348        pub(crate) data: T,
1349    }
1350
1351    impl<T> TestTreeNode<T> {
1352        pub(crate) fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
1353            Self { children, data }
1354        }
1355
1356        pub(crate) fn new_leaf(data: T) -> Self {
1357            Self {
1358                children: vec![],
1359                data,
1360            }
1361        }
1362
1363        pub(crate) fn is_leaf(&self) -> bool {
1364            self.children.is_empty()
1365        }
1366    }
1367
1368    impl<T> TreeNode for TestTreeNode<T> {
1369        fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1370            &'n self,
1371            f: F,
1372        ) -> Result<TreeNodeRecursion> {
1373            self.children.apply_elements(f)
1374        }
1375
1376        fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1377            self,
1378            f: F,
1379        ) -> Result<Transformed<Self>> {
1380            Ok(self
1381                .children
1382                .map_elements(f)?
1383                .update_data(|new_children| Self {
1384                    children: new_children,
1385                    ..self
1386                }))
1387        }
1388    }
1389
1390    impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode<T> {
1391        fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
1392            &'a self,
1393            mut f: F,
1394        ) -> Result<TreeNodeRecursion> {
1395            f(self)
1396        }
1397
1398        fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
1399            self,
1400            mut f: F,
1401        ) -> Result<Transformed<Self>> {
1402            f(self)
1403        }
1404    }
1405
    // Shared fixture tree used by the traversal tests. Each node's payload is its
    // lowercase letter, and children are stored left-to-right as drawn:
    //
    //       J
    //       |
    //       I
    //       |
    //       F
    //     /   \
    //    E     G
    //    |     |
    //    C     H
    //  /   \
    // B     D
    //       |
    //       A
    fn test_tree() -> TestTreeNode<String> {
        let node_a = TestTreeNode::new_leaf("a".to_string());
        let node_b = TestTreeNode::new_leaf("b".to_string());
        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
        let node_h = TestTreeNode::new_leaf("h".to_string());
        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
        TestTreeNode::new(vec![node_i], "j".to_string())
    }
1431
1432    // Continue on all nodes
1433    // Expected visits in a combined traversal
1434    fn all_visits() -> Vec<String> {
1435        vec![
1436            "f_down(j)",
1437            "f_down(i)",
1438            "f_down(f)",
1439            "f_down(e)",
1440            "f_down(c)",
1441            "f_down(b)",
1442            "f_up(b)",
1443            "f_down(d)",
1444            "f_down(a)",
1445            "f_up(a)",
1446            "f_up(d)",
1447            "f_up(c)",
1448            "f_up(e)",
1449            "f_down(g)",
1450            "f_down(h)",
1451            "f_up(h)",
1452            "f_up(g)",
1453            "f_up(f)",
1454            "f_up(i)",
1455            "f_up(j)",
1456        ]
1457        .into_iter()
1458        .map(|s| s.to_string())
1459        .collect()
1460    }
1461
    // Expected transformed tree after a combined traversal: both closures ran on
    // every node, so each payload `x` became `f_up(f_down(x))`.
    fn transformed_tree() -> TestTreeNode<String> {
        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
        let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
        let node_c =
            TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
        let node_f =
            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
    }
1477
    // Expected transformed tree after a top-down traversal: only `f_down` ran, so
    // each payload `x` became `f_down(x)`; the shape matches `test_tree()`.
    fn transformed_down_tree() -> TestTreeNode<String> {
        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
    }
1491
    // Expected transformed tree after a bottom-up traversal: only `f_up` ran, so
    // each payload `x` became `f_up(x)`; the shape matches `test_tree()`.
    fn transformed_up_tree() -> TestTreeNode<String> {
        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
        let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string());
        let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
        let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string());
        let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
        TestTreeNode::new(vec![node_i], "f_up(j)".to_string())
    }
1505
1506    // f_down Jump on A node
1507    fn f_down_jump_on_a_visits() -> Vec<String> {
1508        vec![
1509            "f_down(j)",
1510            "f_down(i)",
1511            "f_down(f)",
1512            "f_down(e)",
1513            "f_down(c)",
1514            "f_down(b)",
1515            "f_up(b)",
1516            "f_down(d)",
1517            "f_down(a)",
1518            "f_up(a)",
1519            "f_up(d)",
1520            "f_up(c)",
1521            "f_up(e)",
1522            "f_down(g)",
1523            "f_down(h)",
1524            "f_up(h)",
1525            "f_up(g)",
1526            "f_up(f)",
1527            "f_up(i)",
1528            "f_up(j)",
1529        ]
1530        .into_iter()
1531        .map(|s| s.to_string())
1532        .collect()
1533    }
1534
    // Expected result of a top-down transform where `f_down` jumps on A. A is a
    // leaf, so no subtree is skipped: every payload `x` became `f_down(x)`.
    fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode<String> {
        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
    }
1547
1548    // f_down Jump on E node
1549    fn f_down_jump_on_e_visits() -> Vec<String> {
1550        vec![
1551            "f_down(j)",
1552            "f_down(i)",
1553            "f_down(f)",
1554            "f_down(e)",
1555            "f_up(e)",
1556            "f_down(g)",
1557            "f_down(h)",
1558            "f_up(h)",
1559            "f_up(g)",
1560            "f_up(f)",
1561            "f_up(i)",
1562            "f_up(j)",
1563        ]
1564        .into_iter()
1565        .map(|s| s.to_string())
1566        .collect()
1567    }
1568
1569    fn f_down_jump_on_e_transformed_tree() -> TestTreeNode<String> {
1570        let node_a = TestTreeNode::new_leaf("a".to_string());
1571        let node_b = TestTreeNode::new_leaf("b".to_string());
1572        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1573        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1574        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1575        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1576        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1577        let node_f =
1578            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1579        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1580        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1581    }
1582
1583    fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode<String> {
1584        let node_a = TestTreeNode::new_leaf("a".to_string());
1585        let node_b = TestTreeNode::new_leaf("b".to_string());
1586        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1587        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1588        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1589        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
1590        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
1591        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1592        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1593        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1594    }
1595
1596    // f_up Jump on A node
1597    fn f_up_jump_on_a_visits() -> Vec<String> {
1598        vec![
1599            "f_down(j)",
1600            "f_down(i)",
1601            "f_down(f)",
1602            "f_down(e)",
1603            "f_down(c)",
1604            "f_down(b)",
1605            "f_up(b)",
1606            "f_down(d)",
1607            "f_down(a)",
1608            "f_up(a)",
1609            "f_down(g)",
1610            "f_down(h)",
1611            "f_up(h)",
1612            "f_up(g)",
1613            "f_up(f)",
1614            "f_up(i)",
1615            "f_up(j)",
1616        ]
1617        .into_iter()
1618        .map(|s| s.to_string())
1619        .collect()
1620    }
1621
1622    fn f_up_jump_on_a_transformed_tree() -> TestTreeNode<String> {
1623        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1624        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1625        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1626        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1627        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1628        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1629        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1630        let node_f =
1631            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1632        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1633        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1634    }
1635
1636    fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode<String> {
1637        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1638        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1639        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1640        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1641        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1642        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
1643        let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
1644        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string());
1645        let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
1646        TestTreeNode::new(vec![node_i], "f_up(j)".to_string())
1647    }
1648
1649    // f_up Jump on E node
1650    fn f_up_jump_on_e_visits() -> Vec<String> {
1651        vec![
1652            "f_down(j)",
1653            "f_down(i)",
1654            "f_down(f)",
1655            "f_down(e)",
1656            "f_down(c)",
1657            "f_down(b)",
1658            "f_up(b)",
1659            "f_down(d)",
1660            "f_down(a)",
1661            "f_up(a)",
1662            "f_up(d)",
1663            "f_up(c)",
1664            "f_up(e)",
1665            "f_down(g)",
1666            "f_down(h)",
1667            "f_up(h)",
1668            "f_up(g)",
1669            "f_up(f)",
1670            "f_up(i)",
1671            "f_up(j)",
1672        ]
1673        .into_iter()
1674        .map(|s| s.to_string())
1675        .collect()
1676    }
1677
    /// Expected combined rewrite when `f_up` returns `Jump` on `f_down(e)`.
    /// The `f_up_jump_on_e_visits` trace shows that no callback is skipped
    /// after `f_up(e)`, so this is the same tree as the unconditional rewrite.
    fn f_up_jump_on_e_transformed_tree() -> TestTreeNode<String> {
        transformed_tree()
    }
1681
    /// Expected `transform_up` result when `f_up` returns `Jump` on `e`.
    /// Nothing is skipped after `e`, so this equals the unconditional
    /// bottom-up rewrite.
    fn f_up_jump_on_e_transformed_up_tree() -> TestTreeNode<String> {
        transformed_up_tree()
    }
1685
1686    // f_down Stop on A node
1687
1688    fn f_down_stop_on_a_visits() -> Vec<String> {
1689        vec![
1690            "f_down(j)",
1691            "f_down(i)",
1692            "f_down(f)",
1693            "f_down(e)",
1694            "f_down(c)",
1695            "f_down(b)",
1696            "f_up(b)",
1697            "f_down(d)",
1698            "f_down(a)",
1699        ]
1700        .into_iter()
1701        .map(|s| s.to_string())
1702        .collect()
1703    }
1704
1705    fn f_down_stop_on_a_transformed_tree() -> TestTreeNode<String> {
1706        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1707        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1708        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1709        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1710        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1711        let node_h = TestTreeNode::new_leaf("h".to_string());
1712        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1713        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1714        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1715        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1716    }
1717
1718    fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode<String> {
1719        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1720        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
1721        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1722        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1723        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1724        let node_h = TestTreeNode::new_leaf("h".to_string());
1725        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1726        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1727        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1728        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1729    }
1730
1731    // f_down Stop on E node
1732    fn f_down_stop_on_e_visits() -> Vec<String> {
1733        vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"]
1734            .into_iter()
1735            .map(|s| s.to_string())
1736            .collect()
1737    }
1738
1739    fn f_down_stop_on_e_transformed_tree() -> TestTreeNode<String> {
1740        let node_a = TestTreeNode::new_leaf("a".to_string());
1741        let node_b = TestTreeNode::new_leaf("b".to_string());
1742        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1743        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1744        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1745        let node_h = TestTreeNode::new_leaf("h".to_string());
1746        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1747        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1748        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1749        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1750    }
1751
1752    fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode<String> {
1753        let node_a = TestTreeNode::new_leaf("a".to_string());
1754        let node_b = TestTreeNode::new_leaf("b".to_string());
1755        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1756        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1757        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1758        let node_h = TestTreeNode::new_leaf("h".to_string());
1759        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1760        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1761        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1762        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1763    }
1764
1765    // f_up Stop on A node
1766    fn f_up_stop_on_a_visits() -> Vec<String> {
1767        vec![
1768            "f_down(j)",
1769            "f_down(i)",
1770            "f_down(f)",
1771            "f_down(e)",
1772            "f_down(c)",
1773            "f_down(b)",
1774            "f_up(b)",
1775            "f_down(d)",
1776            "f_down(a)",
1777            "f_up(a)",
1778        ]
1779        .into_iter()
1780        .map(|s| s.to_string())
1781        .collect()
1782    }
1783
1784    fn f_up_stop_on_a_transformed_tree() -> TestTreeNode<String> {
1785        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1786        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1787        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1788        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1789        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1790        let node_h = TestTreeNode::new_leaf("h".to_string());
1791        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1792        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1793        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1794        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1795    }
1796
1797    fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode<String> {
1798        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1799        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1800        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1801        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1802        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1803        let node_h = TestTreeNode::new_leaf("h".to_string());
1804        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1805        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1806        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1807        TestTreeNode::new(vec![node_i], "j".to_string())
1808    }
1809
1810    // f_up Stop on E node
1811    fn f_up_stop_on_e_visits() -> Vec<String> {
1812        vec![
1813            "f_down(j)",
1814            "f_down(i)",
1815            "f_down(f)",
1816            "f_down(e)",
1817            "f_down(c)",
1818            "f_down(b)",
1819            "f_up(b)",
1820            "f_down(d)",
1821            "f_down(a)",
1822            "f_up(a)",
1823            "f_up(d)",
1824            "f_up(c)",
1825            "f_up(e)",
1826        ]
1827        .into_iter()
1828        .map(|s| s.to_string())
1829        .collect()
1830    }
1831
1832    fn f_up_stop_on_e_transformed_tree() -> TestTreeNode<String> {
1833        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1834        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1835        let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
1836        let node_c =
1837            TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
1838        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1839        let node_h = TestTreeNode::new_leaf("h".to_string());
1840        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1841        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1842        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1843        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1844    }
1845
1846    fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode<String> {
1847        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1848        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1849        let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
1850        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string());
1851        let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
1852        let node_h = TestTreeNode::new_leaf("h".to_string());
1853        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1854        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1855        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1856        TestTreeNode::new(vec![node_i], "j".to_string())
1857    }
1858
1859    fn down_visits(visits: Vec<String>) -> Vec<String> {
1860        visits
1861            .into_iter()
1862            .filter(|v| v.starts_with("f_down"))
1863            .collect()
1864    }
1865
1866    type TestVisitorF<T> = Box<dyn FnMut(&TestTreeNode<T>) -> Result<TreeNodeRecursion>>;
1867
1868    struct TestVisitor<T> {
1869        visits: Vec<String>,
1870        f_down: TestVisitorF<T>,
1871        f_up: TestVisitorF<T>,
1872    }
1873
1874    impl<T> TestVisitor<T> {
1875        fn new(f_down: TestVisitorF<T>, f_up: TestVisitorF<T>) -> Self {
1876            Self {
1877                visits: vec![],
1878                f_down,
1879                f_up,
1880            }
1881        }
1882    }
1883
1884    impl<'n, T: Display> TreeNodeVisitor<'n> for TestVisitor<T> {
1885        type Node = TestTreeNode<T>;
1886
1887        fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1888            self.visits.push(format!("f_down({})", node.data));
1889            (*self.f_down)(node)
1890        }
1891
1892        fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1893            self.visits.push(format!("f_up({})", node.data));
1894            (*self.f_up)(node)
1895        }
1896    }
1897
1898    fn visit_continue<T>(_: &TestTreeNode<T>) -> Result<TreeNodeRecursion> {
1899        Ok(TreeNodeRecursion::Continue)
1900    }
1901
1902    fn visit_event_on<T: PartialEq, D: Into<T>>(
1903        data: D,
1904        event: TreeNodeRecursion,
1905    ) -> impl FnMut(&TestTreeNode<T>) -> Result<TreeNodeRecursion> {
1906        let d = data.into();
1907        move |node| {
1908            Ok(if node.data == d {
1909                event
1910            } else {
1911                TreeNodeRecursion::Continue
1912            })
1913        }
1914    }
1915
1916    macro_rules! visit_test {
1917        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_VISITS:expr) => {
1918            #[test]
1919            fn $NAME() -> Result<()> {
1920                let tree = test_tree();
1921                let mut visitor = TestVisitor::new(Box::new($F_DOWN), Box::new($F_UP));
1922                tree.visit(&mut visitor)?;
1923                assert_eq!(visitor.visits, $EXPECTED_VISITS);
1924
1925                Ok(())
1926            }
1927        };
1928    }
1929
1930    macro_rules! test_apply {
1931        ($NAME:ident, $F:expr, $EXPECTED_VISITS:expr) => {
1932            #[test]
1933            fn $NAME() -> Result<()> {
1934                let tree = test_tree();
1935                let mut visits = vec![];
1936                tree.apply(|node| {
1937                    visits.push(format!("f_down({})", node.data));
1938                    $F(node)
1939                })?;
1940                assert_eq!(visits, $EXPECTED_VISITS);
1941
1942                Ok(())
1943            }
1944        };
1945    }
1946
1947    type TestRewriterF<T> =
1948        Box<dyn FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>>>;
1949
1950    struct TestRewriter<T> {
1951        f_down: TestRewriterF<T>,
1952        f_up: TestRewriterF<T>,
1953    }
1954
1955    impl<T> TestRewriter<T> {
1956        fn new(f_down: TestRewriterF<T>, f_up: TestRewriterF<T>) -> Self {
1957            Self { f_down, f_up }
1958        }
1959    }
1960
1961    impl<T: Display> TreeNodeRewriter for TestRewriter<T> {
1962        type Node = TestTreeNode<T>;
1963
1964        fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1965            (*self.f_down)(node)
1966        }
1967
1968        fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1969            (*self.f_up)(node)
1970        }
1971    }
1972
1973    fn transform_yes<N: Display, T: Display + From<String>>(
1974        transformation_name: N,
1975    ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> {
1976        move |node| {
1977            Ok(Transformed::yes(TestTreeNode::new(
1978                node.children,
1979                format!("{}({})", transformation_name, node.data).into(),
1980            )))
1981        }
1982    }
1983
1984    fn transform_and_event_on<
1985        N: Display,
1986        T: PartialEq + Display + From<String>,
1987        D: Into<T>,
1988    >(
1989        transformation_name: N,
1990        data: D,
1991        event: TreeNodeRecursion,
1992    ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> {
1993        let d = data.into();
1994        move |node| {
1995            let new_node = TestTreeNode::new(
1996                node.children,
1997                format!("{}({})", transformation_name, node.data).into(),
1998            );
1999            Ok(if node.data == d {
2000                Transformed::new(new_node, true, event)
2001            } else {
2002                Transformed::yes(new_node)
2003            })
2004        }
2005    }
2006
2007    macro_rules! rewrite_test {
2008        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => {
2009            #[test]
2010            fn $NAME() -> Result<()> {
2011                let tree = test_tree();
2012                let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP));
2013                assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE);
2014
2015                Ok(())
2016            }
2017        };
2018    }
2019
2020    macro_rules! transform_test {
2021        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => {
2022            #[test]
2023            fn $NAME() -> Result<()> {
2024                let tree = test_tree();
2025                assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE);
2026
2027                Ok(())
2028            }
2029        };
2030    }
2031
2032    macro_rules! transform_down_test {
2033        ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => {
2034            #[test]
2035            fn $NAME() -> Result<()> {
2036                let tree = test_tree();
2037                assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE);
2038
2039                Ok(())
2040            }
2041        };
2042    }
2043
2044    macro_rules! transform_up_test {
2045        ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => {
2046            #[test]
2047            fn $NAME() -> Result<()> {
2048                let tree = test_tree();
2049                assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE);
2050
2051                Ok(())
2052            }
2053        };
2054    }
2055
    // `visit` tests: one callback continues everywhere, the other returns
    // `Jump` or `Stop` on the leaf `a` or the inner node `e`; each case
    // compares the exact recorded f_down/f_up trace.
    visit_test!(test_visit, visit_continue, visit_continue, all_visits());
    visit_test!(
        test_visit_f_down_jump_on_a,
        visit_event_on("a", TreeNodeRecursion::Jump),
        visit_continue,
        f_down_jump_on_a_visits()
    );
    visit_test!(
        test_visit_f_down_jump_on_e,
        visit_event_on("e", TreeNodeRecursion::Jump),
        visit_continue,
        f_down_jump_on_e_visits()
    );
    visit_test!(
        test_visit_f_up_jump_on_a,
        visit_continue,
        visit_event_on("a", TreeNodeRecursion::Jump),
        f_up_jump_on_a_visits()
    );
    visit_test!(
        test_visit_f_up_jump_on_e,
        visit_continue,
        visit_event_on("e", TreeNodeRecursion::Jump),
        f_up_jump_on_e_visits()
    );
    visit_test!(
        test_visit_f_down_stop_on_a,
        visit_event_on("a", TreeNodeRecursion::Stop),
        visit_continue,
        f_down_stop_on_a_visits()
    );
    visit_test!(
        test_visit_f_down_stop_on_e,
        visit_event_on("e", TreeNodeRecursion::Stop),
        visit_continue,
        f_down_stop_on_e_visits()
    );
    visit_test!(
        test_visit_f_up_stop_on_a,
        visit_continue,
        visit_event_on("a", TreeNodeRecursion::Stop),
        f_up_stop_on_a_visits()
    );
    visit_test!(
        test_visit_f_up_stop_on_e,
        visit_continue,
        visit_event_on("e", TreeNodeRecursion::Stop),
        f_up_stop_on_e_visits()
    );
2102        visit_event_on("e", TreeNodeRecursion::Stop),
2103        f_up_stop_on_e_visits()
2104    );
2105
2106    test_apply!(test_apply, visit_continue, down_visits(all_visits()));
2107    test_apply!(
2108        test_apply_f_down_jump_on_a,
2109        visit_event_on("a", TreeNodeRecursion::Jump),
2110        down_visits(f_down_jump_on_a_visits())
2111    );
2112    test_apply!(
2113        test_apply_f_down_jump_on_e,
2114        visit_event_on("e", TreeNodeRecursion::Jump),
2115        down_visits(f_down_jump_on_e_visits())
2116    );
2117    test_apply!(
2118        test_apply_f_down_stop_on_a,
2119        visit_event_on("a", TreeNodeRecursion::Stop),
2120        down_visits(f_down_stop_on_a_visits())
2121    );
2122    test_apply!(
2123        test_apply_f_down_stop_on_e,
2124        visit_event_on("e", TreeNodeRecursion::Stop),
2125        down_visits(f_down_stop_on_e_visits())
2126    );
2127
2128    rewrite_test!(
2129        test_rewrite,
2130        transform_yes("f_down"),
2131        transform_yes("f_up"),
2132        Transformed::yes(transformed_tree())
2133    );
2134    rewrite_test!(
2135        test_rewrite_f_down_jump_on_a,
2136        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
2137        transform_yes("f_up"),
2138        Transformed::yes(transformed_tree())
2139    );
2140    rewrite_test!(
2141        test_rewrite_f_down_jump_on_e,
2142        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
2143        transform_yes("f_up"),
2144        Transformed::yes(f_down_jump_on_e_transformed_tree())
2145    );
2146    rewrite_test!(
2147        test_rewrite_f_up_jump_on_a,
2148        transform_yes("f_down"),
2149        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump),
2150        Transformed::yes(f_up_jump_on_a_transformed_tree())
2151    );
2152    rewrite_test!(
2153        test_rewrite_f_up_jump_on_e,
2154        transform_yes("f_down"),
2155        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump),
2156        Transformed::yes(f_up_jump_on_e_transformed_tree())
2157    );
2158    rewrite_test!(
2159        test_rewrite_f_down_stop_on_a,
2160        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
2161        transform_yes("f_up"),
2162        Transformed::new(
2163            f_down_stop_on_a_transformed_tree(),
2164            true,
2165            TreeNodeRecursion::Stop
2166        )
2167    );
2168    rewrite_test!(
2169        test_rewrite_f_down_stop_on_e,
2170        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
2171        transform_yes("f_up"),
2172        Transformed::new(
2173            f_down_stop_on_e_transformed_tree(),
2174            true,
2175            TreeNodeRecursion::Stop
2176        )
2177    );
2178    rewrite_test!(
2179        test_rewrite_f_up_stop_on_a,
2180        transform_yes("f_down"),
2181        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop),
2182        Transformed::new(
2183            f_up_stop_on_a_transformed_tree(),
2184            true,
2185            TreeNodeRecursion::Stop
2186        )
2187    );
2188    rewrite_test!(
2189        test_rewrite_f_up_stop_on_e,
2190        transform_yes("f_down"),
2191        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop),
2192        Transformed::new(
2193            f_up_stop_on_e_transformed_tree(),
2194            true,
2195            TreeNodeRecursion::Stop
2196        )
2197    );
2198
2199    transform_test!(
2200        test_transform,
2201        transform_yes("f_down"),
2202        transform_yes("f_up"),
2203        Transformed::yes(transformed_tree())
2204    );
2205    transform_test!(
2206        test_transform_f_down_jump_on_a,
2207        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
2208        transform_yes("f_up"),
2209        Transformed::yes(transformed_tree())
2210    );
2211    transform_test!(
2212        test_transform_f_down_jump_on_e,
2213        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
2214        transform_yes("f_up"),
2215        Transformed::yes(f_down_jump_on_e_transformed_tree())
2216    );
2217    transform_test!(
2218        test_transform_f_up_jump_on_a,
2219        transform_yes("f_down"),
2220        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump),
2221        Transformed::yes(f_up_jump_on_a_transformed_tree())
2222    );
2223    transform_test!(
2224        test_transform_f_up_jump_on_e,
2225        transform_yes("f_down"),
2226        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump),
2227        Transformed::yes(f_up_jump_on_e_transformed_tree())
2228    );
2229    transform_test!(
2230        test_transform_f_down_stop_on_a,
2231        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
2232        transform_yes("f_up"),
2233        Transformed::new(
2234            f_down_stop_on_a_transformed_tree(),
2235            true,
2236            TreeNodeRecursion::Stop
2237        )
2238    );
2239    transform_test!(
2240        test_transform_f_down_stop_on_e,
2241        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
2242        transform_yes("f_up"),
2243        Transformed::new(
2244            f_down_stop_on_e_transformed_tree(),
2245            true,
2246            TreeNodeRecursion::Stop
2247        )
2248    );
2249    transform_test!(
2250        test_transform_f_up_stop_on_a,
2251        transform_yes("f_down"),
2252        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop),
2253        Transformed::new(
2254            f_up_stop_on_a_transformed_tree(),
2255            true,
2256            TreeNodeRecursion::Stop
2257        )
2258    );
2259    transform_test!(
2260        test_transform_f_up_stop_on_e,
2261        transform_yes("f_down"),
2262        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop),
2263        Transformed::new(
2264            f_up_stop_on_e_transformed_tree(),
2265            true,
2266            TreeNodeRecursion::Stop
2267        )
2268    );
2269
2270    transform_down_test!(
2271        test_transform_down,
2272        transform_yes("f_down"),
2273        Transformed::yes(transformed_down_tree())
2274    );
2275    transform_down_test!(
2276        test_transform_down_f_down_jump_on_a,
2277        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
2278        Transformed::yes(f_down_jump_on_a_transformed_down_tree())
2279    );
2280    transform_down_test!(
2281        test_transform_down_f_down_jump_on_e,
2282        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
2283        Transformed::yes(f_down_jump_on_e_transformed_down_tree())
2284    );
2285    transform_down_test!(
2286        test_transform_down_f_down_stop_on_a,
2287        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
2288        Transformed::new(
2289            f_down_stop_on_a_transformed_down_tree(),
2290            true,
2291            TreeNodeRecursion::Stop
2292        )
2293    );
2294    transform_down_test!(
2295        test_transform_down_f_down_stop_on_e,
2296        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
2297        Transformed::new(
2298            f_down_stop_on_e_transformed_down_tree(),
2299            true,
2300            TreeNodeRecursion::Stop
2301        )
2302    );
2303
2304    transform_up_test!(
2305        test_transform_up,
2306        transform_yes("f_up"),
2307        Transformed::yes(transformed_up_tree())
2308    );
2309    transform_up_test!(
2310        test_transform_up_f_up_jump_on_a,
2311        transform_and_event_on("f_up", "a", TreeNodeRecursion::Jump),
2312        Transformed::yes(f_up_jump_on_a_transformed_up_tree())
2313    );
2314    transform_up_test!(
2315        test_transform_up_f_up_jump_on_e,
2316        transform_and_event_on("f_up", "e", TreeNodeRecursion::Jump),
2317        Transformed::yes(f_up_jump_on_e_transformed_up_tree())
2318    );
2319    transform_up_test!(
2320        test_transform_up_f_up_stop_on_a,
2321        transform_and_event_on("f_up", "a", TreeNodeRecursion::Stop),
2322        Transformed::new(
2323            f_up_stop_on_a_transformed_up_tree(),
2324            true,
2325            TreeNodeRecursion::Stop
2326        )
2327    );
2328    transform_up_test!(
2329        test_transform_up_f_up_stop_on_e,
2330        transform_and_event_on("f_up", "e", TreeNodeRecursion::Stop),
2331        Transformed::new(
2332            f_up_stop_on_e_transformed_up_tree(),
2333            true,
2334            TreeNodeRecursion::Stop
2335        )
2336    );
2337
2338    //             F
2339    //          /  |  \
2340    //       /     |     \
2341    //    E        C        A
2342    //    |      /   \
2343    //    C     B     D
2344    //  /   \         |
2345    // B     D        A
2346    //       |
2347    //       A
2348    #[test]
2349    fn test_apply_and_visit_references() -> Result<()> {
2350        let node_a = TestTreeNode::new_leaf("a".to_string());
2351        let node_b = TestTreeNode::new_leaf("b".to_string());
2352        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
2353        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
2354        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
2355        let node_a_2 = TestTreeNode::new_leaf("a".to_string());
2356        let node_b_2 = TestTreeNode::new_leaf("b".to_string());
2357        let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
2358        let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string());
2359        let node_a_3 = TestTreeNode::new_leaf("a".to_string());
2360        let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string());
2361
2362        let node_f_ref = &tree;
2363        let node_e_ref = &node_f_ref.children[0];
2364        let node_c_ref = &node_e_ref.children[0];
2365        let node_b_ref = &node_c_ref.children[0];
2366        let node_d_ref = &node_c_ref.children[1];
2367        let node_a_ref = &node_d_ref.children[0];
2368
2369        let mut m: HashMap<&TestTreeNode<String>, usize> = HashMap::new();
2370        tree.apply(|e| {
2371            *m.entry(e).or_insert(0) += 1;
2372            Ok(TreeNodeRecursion::Continue)
2373        })?;
2374
2375        let expected = HashMap::from([
2376            (node_f_ref, 1),
2377            (node_e_ref, 1),
2378            (node_c_ref, 2),
2379            (node_d_ref, 2),
2380            (node_b_ref, 2),
2381            (node_a_ref, 3),
2382        ]);
2383        assert_eq!(m, expected);
2384
2385        struct TestVisitor<'n> {
2386            m: HashMap<&'n TestTreeNode<String>, (usize, usize)>,
2387        }
2388
2389        impl<'n> TreeNodeVisitor<'n> for TestVisitor<'n> {
2390            type Node = TestTreeNode<String>;
2391
2392            fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2393                let (down_count, _) = self.m.entry(node).or_insert((0, 0));
2394                *down_count += 1;
2395                Ok(TreeNodeRecursion::Continue)
2396            }
2397
2398            fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2399                let (_, up_count) = self.m.entry(node).or_insert((0, 0));
2400                *up_count += 1;
2401                Ok(TreeNodeRecursion::Continue)
2402            }
2403        }
2404
2405        let mut visitor = TestVisitor { m: HashMap::new() };
2406        tree.visit(&mut visitor)?;
2407
2408        let expected = HashMap::from([
2409            (node_f_ref, (1, 1)),
2410            (node_e_ref, (1, 1)),
2411            (node_c_ref, (2, 2)),
2412            (node_d_ref, (2, 2)),
2413            (node_b_ref, (2, 2)),
2414            (node_a_ref, (3, 3)),
2415        ]);
2416        assert_eq!(visitor.m, expected);
2417
2418        Ok(())
2419    }
2420
2421    #[cfg(feature = "recursive_protection")]
2422    #[test]
2423    fn test_large_tree() {
2424        let mut item = TestTreeNode::new_leaf("initial".to_string());
2425        for i in 0..3000 {
2426            item = TestTreeNode::new(vec![item], format!("parent-{i}"));
2427        }
2428
2429        let mut visitor =
2430            TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue));
2431
2432        item.visit(&mut visitor).unwrap();
2433    }
2434}