datafusion_common/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`TreeNode`] for visiting and rewriting expression and plan trees
19
20use crate::Result;
21use std::collections::HashMap;
22use std::hash::Hash;
23use std::sync::Arc;
24
/// Drives the combined pre-order / children / post-order steps of a
/// transforming traversal for a single node.
///
/// Arguments, in order:
/// * an expression evaluating to the `Result<Transformed<_>>` of the
///   pre-order (`f_down`) step for the current node,
/// * a closure handed to `map_children` to rewrite each child, and
/// * a closure applied to the node in the post-order (`f_up`) step.
///
/// Whether the children and the post-order step actually run is decided by
/// `Transformed::transform_children` and `Transformed::transform_parent`,
/// based on the recursion state carried by each intermediate result. Errors
/// from the pre-order step and from the children step are propagated with `?`.
macro_rules! handle_transform_recursion {
    ($f_down:expr, $f_child:expr, $f_up:expr) => {{
        let pre_order = $f_down?;
        let with_children =
            pre_order.transform_children(|node| node.map_children($f_child))?;
        with_children.transform_parent($f_up)
    }};
}
33
34/// API for inspecting and rewriting tree data structures.
35///
36/// The `TreeNode` API is used to express algorithms separately from traversing
37/// the structure of `TreeNode`s, avoiding substantial code duplication.
38///
39/// This trait is implemented for plans ([`ExecutionPlan`], [`LogicalPlan`]) and
40/// expression trees ([`PhysicalExpr`], [`Expr`]) as well as Plan+Payload
41/// combinations [`PlanContext`] and [`ExprContext`].
42///
43/// # Overview
44/// There are three categories of TreeNode APIs:
45///
46/// 1. "Inspecting" APIs to traverse a tree of `&TreeNodes`:
47///    [`apply`], [`visit`], [`exists`].
48///
49/// 2. "Transforming" APIs that traverse and consume a tree of `TreeNode`s
50///    producing possibly changed `TreeNode`s: [`transform`], [`transform_up`],
51///    [`transform_down`], [`transform_down_up`], and [`rewrite`].
52///
53/// 3. Internal APIs used to implement the `TreeNode` API: [`apply_children`],
54///    and [`map_children`].
55///
56/// | Traversal Order | Inspecting | Transforming |
57/// | --- | --- | --- |
58/// | top-down | [`apply`], [`exists`] | [`transform_down`]|
59/// | bottom-up | | [`transform`] , [`transform_up`]|
60/// | combined with separate `f_down` and `f_up` closures | | [`transform_down_up`] |
61/// | combined with `f_down()` and `f_up()` in an object | [`visit`]  | [`rewrite`] |
62///
63/// **Note**:while there is currently no in-place mutation API that uses `&mut
64/// TreeNode`, the transforming APIs are efficient and optimized to avoid
65/// cloning.
66///
67/// [`apply`]: Self::apply
68/// [`visit`]: Self::visit
69/// [`exists`]: Self::exists
70/// [`transform`]: Self::transform
71/// [`transform_up`]: Self::transform_up
72/// [`transform_down`]: Self::transform_down
73/// [`transform_down_up`]: Self::transform_down_up
74/// [`rewrite`]: Self::rewrite
75/// [`apply_children`]: Self::apply_children
76/// [`map_children`]: Self::map_children
77///
78/// # Terminology
79/// The following terms are used in this trait
80///
81/// * `f_down`: Invoked before any children of the current node are visited.
82/// * `f_up`: Invoked after all children of the current node are visited.
83/// * `f`: closure that is applied to the current node.
84/// * `map_*`: applies a transformation to rewrite owned nodes
85/// * `apply_*`:  invokes a function on borrowed nodes
86/// * `transform_`: applies a transformation to rewrite owned nodes
87///
88/// <!-- Since these are in the datafusion-common crate, can't use intra doc links) -->
89/// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html
90/// [`PhysicalExpr`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.PhysicalExpr.html
91/// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html
92/// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html
93/// [`PlanContext`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/tree_node/struct.PlanContext.html
94/// [`ExprContext`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/tree_node/struct.ExprContext.html
95pub trait TreeNode: Sized {
96    /// Visit the tree node with a [`TreeNodeVisitor`], performing a
97    /// depth-first walk of the node and its children.
98    ///
99    /// [`TreeNodeVisitor::f_down()`] is called in top-down order (before
100    /// children are visited), [`TreeNodeVisitor::f_up()`] is called in
101    /// bottom-up order (after children are visited).
102    ///
103    /// # Return Value
104    /// Specifies how the tree walk ended. See [`TreeNodeRecursion`] for details.
105    ///
106    /// # See Also:
107    /// * [`Self::apply`] for inspecting nodes with a closure
108    /// * [`Self::rewrite`] to rewrite owned `TreeNode`s
109    ///
110    /// # Example
111    /// Consider the following tree structure:
112    /// ```text
113    /// ParentNode
114    ///    left: ChildNode1
115    ///    right: ChildNode2
116    /// ```
117    ///
118    /// Here, the nodes would be visited using the following order:
119    /// ```text
120    /// TreeNodeVisitor::f_down(ParentNode)
121    /// TreeNodeVisitor::f_down(ChildNode1)
122    /// TreeNodeVisitor::f_up(ChildNode1)
123    /// TreeNodeVisitor::f_down(ChildNode2)
124    /// TreeNodeVisitor::f_up(ChildNode2)
125    /// TreeNodeVisitor::f_up(ParentNode)
126    /// ```
127    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
128    fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
129        &'n self,
130        visitor: &mut V,
131    ) -> Result<TreeNodeRecursion> {
132        visitor
133            .f_down(self)?
134            .visit_children(|| self.apply_children(|c| c.visit(visitor)))?
135            .visit_parent(|| visitor.f_up(self))
136    }
137
138    /// Rewrite the tree node with a [`TreeNodeRewriter`], performing a
139    /// depth-first walk of the node and its children.
140    ///
141    /// [`TreeNodeRewriter::f_down()`] is called in top-down order (before
142    /// children are visited), [`TreeNodeRewriter::f_up()`] is called in
143    /// bottom-up order (after children are visited).
144    ///
145    /// Note: If using the default [`TreeNodeRewriter::f_up`] or
146    /// [`TreeNodeRewriter::f_down`] that do nothing, consider using
147    /// [`Self::transform_down`] instead.
148    ///
149    /// # Return Value
150    /// The returns value specifies how the tree walk should proceed. See
151    /// [`TreeNodeRecursion`] for details. If an [`Err`] is returned, the
152    /// recursion stops immediately.
153    ///
154    /// # See Also
155    /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
156    /// * [Self::transform_down_up] for a top-down (pre-order) traversal.
157    /// * [Self::transform_down] for a top-down (pre-order) traversal.
158    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
159    ///
160    /// # Example
161    /// Consider the following tree structure:
162    /// ```text
163    /// ParentNode
164    ///    left: ChildNode1
165    ///    right: ChildNode2
166    /// ```
167    ///
168    /// Here, the nodes would be visited using the following order:
169    /// ```text
170    /// TreeNodeRewriter::f_down(ParentNode)
171    /// TreeNodeRewriter::f_down(ChildNode1)
172    /// TreeNodeRewriter::f_up(ChildNode1)
173    /// TreeNodeRewriter::f_down(ChildNode2)
174    /// TreeNodeRewriter::f_up(ChildNode2)
175    /// TreeNodeRewriter::f_up(ParentNode)
176    /// ```
177    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
178    fn rewrite<R: TreeNodeRewriter<Node = Self>>(
179        self,
180        rewriter: &mut R,
181    ) -> Result<Transformed<Self>> {
182        handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| {
183            rewriter.f_up(n)
184        })
185    }
186
187    /// Applies `f` to the node then each of its children, recursively (a
188    /// top-down, pre-order traversal).
189    ///
190    /// The return [`TreeNodeRecursion`] controls the recursion and can cause
191    /// an early return.
192    ///
193    /// # See Also
194    /// * [`Self::transform_down`] for the equivalent transformation API.
195    /// * [`Self::visit`] for both top-down and bottom up traversal.
196    fn apply<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
197        &'n self,
198        mut f: F,
199    ) -> Result<TreeNodeRecursion> {
200        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
201        fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
202            node: &'n N,
203            f: &mut F,
204        ) -> Result<TreeNodeRecursion> {
205            f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
206        }
207
208        apply_impl(self, &mut f)
209    }
210
211    /// Recursively rewrite the node's children and then the node using `f`
212    /// (a bottom-up post-order traversal).
213    ///
214    /// A synonym of [`Self::transform_up`].
215    fn transform<F: FnMut(Self) -> Result<Transformed<Self>>>(
216        self,
217        f: F,
218    ) -> Result<Transformed<Self>> {
219        self.transform_up(f)
220    }
221
222    /// Recursively rewrite the tree using `f` in a top-down (pre-order)
223    /// fashion.
224    ///
225    /// `f` is applied to the node first, and then its children.
226    ///
227    /// # See Also
228    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
229    /// * [Self::transform_down_up] for a combined traversal with closures
230    /// * [`Self::rewrite`] for a combined traversal with a visitor
231    fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
232        self,
233        mut f: F,
234    ) -> Result<Transformed<Self>> {
235        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
236        fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
237            node: N,
238            f: &mut F,
239        ) -> Result<Transformed<N>> {
240            f(node)?.transform_children(|n| n.map_children(|c| transform_down_impl(c, f)))
241        }
242
243        transform_down_impl(self, &mut f)
244    }
245
246    /// Recursively rewrite the node using `f` in a bottom-up (post-order)
247    /// fashion.
248    ///
249    /// `f` is applied to the node's  children first, and then to the node itself.
250    ///
251    /// # See Also
252    /// * [`Self::transform_down`] top-down (pre-order) traversal.
253    /// * [Self::transform_down_up] for a combined traversal with closures
254    /// * [`Self::rewrite`] for a combined traversal with a visitor
255    fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
256        self,
257        mut f: F,
258    ) -> Result<Transformed<Self>> {
259        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
260        fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
261            node: N,
262            f: &mut F,
263        ) -> Result<Transformed<N>> {
264            node.map_children(|c| transform_up_impl(c, f))?
265                .transform_parent(f)
266        }
267
268        transform_up_impl(self, &mut f)
269    }
270
271    /// Transforms the node using `f_down` while traversing the tree top-down
272    /// (pre-order), and using `f_up` while traversing the tree bottom-up
273    /// (post-order).
274    ///
275    /// The method behaves the same as calling [`Self::transform_down`] followed
276    /// by [`Self::transform_up`] on the same node. Use this method if you want
277    /// to start the `f_up` process right where `f_down` jumps. This can make
278    /// the whole process faster by reducing the number of `f_up` steps.
279    ///
280    /// # See Also
281    /// * [`Self::transform_up`] for a bottom-up (post-order) traversal.
282    /// * [Self::transform_down] for a top-down (pre-order) traversal.
283    /// * [`Self::rewrite`] for a combined traversal with a visitor
284    ///
285    /// # Example
286    /// Consider the following tree structure:
287    /// ```text
288    /// ParentNode
289    ///    left: ChildNode1
290    ///    right: ChildNode2
291    /// ```
292    ///
293    /// The nodes are visited using the following order:
294    /// ```text
295    /// f_down(ParentNode)
296    /// f_down(ChildNode1)
297    /// f_up(ChildNode1)
298    /// f_down(ChildNode2)
299    /// f_up(ChildNode2)
300    /// f_up(ParentNode)
301    /// ```
302    ///
303    /// See [`TreeNodeRecursion`] for more details on controlling the traversal.
304    ///
305    /// If `f_down` or `f_up` returns [`Err`], the recursion stops immediately.
306    ///
307    /// Example:
308    /// ```text
309    ///                                               |   +---+
310    ///                                               |   | J |
311    ///                                               |   +---+
312    ///                                               |     |
313    ///                                               |   +---+
314    ///                  TreeNodeRecursion::Continue  |   | I |
315    ///                                               |   +---+
316    ///                                               |     |
317    ///                                               |   +---+
318    ///                                              \|/  | F |
319    ///                                               '   +---+
320    ///                                                  /     \ ___________________
321    ///                  When `f_down` is           +---+                           \ ---+
322    ///                  applied on node "E",       | E |                            | G |
323    ///                  it returns with "Jump".    +---+                            +---+
324    ///                                               |                                |
325    ///                                             +---+                            +---+
326    ///                                             | C |                            | H |
327    ///                                             +---+                            +---+
328    ///                                             /   \
329    ///                                        +---+     +---+
330    ///                                        | B |     | D |
331    ///                                        +---+     +---+
332    ///                                                    |
333    ///                                                  +---+
334    ///                                                  | A |
335    ///                                                  +---+
336    ///
337    /// Instead of starting from leaf nodes, `f_up` starts from the node "E".
338    ///                                                   +---+
339    ///                                               |   | J |
340    ///                                               |   +---+
341    ///                                               |     |
342    ///                                               |   +---+
343    ///                                               |   | I |
344    ///                                               |   +---+
345    ///                                               |     |
346    ///                                              /    +---+
347    ///                                            /      | F |
348    ///                                          /        +---+
349    ///                                        /         /     \ ______________________
350    ///                                       |     +---+   .                          \ ---+
351    ///                                       |     | E |  /|\  After `f_down` jumps    | G |
352    ///                                       |     +---+   |   on node E, `f_up`       +---+
353    ///                                        \------| ---/   if applied on node E.      |
354    ///                                             +---+                               +---+
355    ///                                             | C |                               | H |
356    ///                                             +---+                               +---+
357    ///                                             /   \
358    ///                                        +---+     +---+
359    ///                                        | B |     | D |
360    ///                                        +---+     +---+
361    ///                                                    |
362    ///                                                  +---+
363    ///                                                  | A |
364    ///                                                  +---+
365    /// ```
366    fn transform_down_up<
367        FD: FnMut(Self) -> Result<Transformed<Self>>,
368        FU: FnMut(Self) -> Result<Transformed<Self>>,
369    >(
370        self,
371        mut f_down: FD,
372        mut f_up: FU,
373    ) -> Result<Transformed<Self>> {
374        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
375        fn transform_down_up_impl<
376            N: TreeNode,
377            FD: FnMut(N) -> Result<Transformed<N>>,
378            FU: FnMut(N) -> Result<Transformed<N>>,
379        >(
380            node: N,
381            f_down: &mut FD,
382            f_up: &mut FU,
383        ) -> Result<Transformed<N>> {
384            handle_transform_recursion!(
385                f_down(node),
386                |c| transform_down_up_impl(c, f_down, f_up),
387                f_up
388            )
389        }
390
391        transform_down_up_impl(self, &mut f_down, &mut f_up)
392    }
393
394    /// Returns true if `f` returns true for any node in the tree.
395    ///
396    /// Stops recursion as soon as a matching node is found
397    fn exists<F: FnMut(&Self) -> Result<bool>>(&self, mut f: F) -> Result<bool> {
398        let mut found = false;
399        self.apply(|n| {
400            Ok(if f(n)? {
401                found = true;
402                TreeNodeRecursion::Stop
403            } else {
404                TreeNodeRecursion::Continue
405            })
406        })
407        .map(|_| found)
408    }
409
410    /// Low-level API used to implement other APIs.
411    ///
412    /// If you want to implement the [`TreeNode`] trait for your own type, you
413    /// should implement this method and [`Self::map_children`].
414    ///
415    /// Users should use one of the higher level APIs described on [`Self`].
416    ///
417    /// Description: Apply `f` to inspect node's children (but not the node
418    /// itself).
419    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
420        &'n self,
421        f: F,
422    ) -> Result<TreeNodeRecursion>;
423
424    /// Low-level API used to implement other APIs.
425    ///
426    /// If you want to implement the [`TreeNode`] trait for your own type, you
427    /// should implement this method and [`Self::apply_children`].
428    ///
429    /// Users should use one of the higher level APIs described on [`Self`].
430    ///
431    /// Description: Apply `f` to rewrite the node's children (but not the node itself).
432    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
433        self,
434        f: F,
435    ) -> Result<Transformed<Self>>;
436}
437
438/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
439/// inspecting [`TreeNode`]s via [`TreeNode::visit`].
440///
441/// See [`TreeNode`] for more details on available APIs
442///
443/// When passed to [`TreeNode::visit`], [`TreeNodeVisitor::f_down`] and
444/// [`TreeNodeVisitor::f_up`] are invoked recursively on the tree.
445/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
446///
447/// # Return Value
448/// The returns value of `f_up` and `f_down` specifies how the tree walk should
449/// proceed. See [`TreeNodeRecursion`] for details. If an [`Err`] is returned,
450/// the recursion stops immediately.
451///
452/// Note: If using the default implementations of [`TreeNodeVisitor::f_up`] or
453/// [`TreeNodeVisitor::f_down`] that do nothing, consider using
454/// [`TreeNode::apply`] instead.
455///
456/// # See Also:
457/// * [`TreeNode::rewrite`] to rewrite owned `TreeNode`s
458pub trait TreeNodeVisitor<'n>: Sized {
459    /// The node type which is visitable.
460    type Node: TreeNode;
461
462    /// Invoked while traversing down the tree, before any children are visited.
463    /// Default implementation continues the recursion.
464    fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
465        Ok(TreeNodeRecursion::Continue)
466    }
467
468    /// Invoked while traversing up the tree after children are visited. Default
469    /// implementation continues the recursion.
470    fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
471        Ok(TreeNodeRecursion::Continue)
472    }
473}
474
475/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
476/// rewriting [`TreeNode`]s via [`TreeNode::rewrite`].
477///
478/// For example you can implement this trait on a struct to rewrite `Expr` or
479/// `LogicalPlan` that needs to track state during the rewrite.
480///
481/// See [`TreeNode`] for more details on available APIs
482///
483/// When passed to [`TreeNode::rewrite`], [`TreeNodeRewriter::f_down`] and
484/// [`TreeNodeRewriter::f_up`] are invoked recursively on the tree.
485/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
486///
487/// # Return Value
488/// The returns value of `f_up` and `f_down` specifies how the tree walk should
489/// proceed. See [`TreeNodeRecursion`] for details. If an [`Err`] is returned,
490/// the recursion stops immediately.
491///
492/// Note: If using the default implementations of [`TreeNodeRewriter::f_up`] or
493/// [`TreeNodeRewriter::f_down`] that do nothing, consider using
494/// [`TreeNode::transform_up`] or [`TreeNode::transform_down`] instead.
495///
496/// # See Also:
497/// * [`TreeNode::visit`] to inspect borrowed `TreeNode`s
498pub trait TreeNodeRewriter: Sized {
499    /// The node type which is rewritable.
500    type Node: TreeNode;
501
502    /// Invoked while traversing down the tree before any children are rewritten.
503    /// Default implementation returns the node as is and continues recursion.
504    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
505        Ok(Transformed::no(node))
506    }
507
508    /// Invoked while traversing up the tree after all children have been rewritten.
509    /// Default implementation returns the node as is and continues recursion.
510    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
511        Ok(Transformed::no(node))
512    }
513}
514
/// Controls how [`TreeNode`] recursions should proceed.
///
/// This is a fieldless, `Copy` enum, so it also implements `Eq` and `Hash`
/// and can be used as a key in hash-based collections.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum TreeNodeRecursion {
    /// Continue recursion with the next node.
    Continue,
    /// In top-down traversals, skip recursing into children but continue with
    /// the next node, which actually means pruning of the subtree.
    ///
    /// In bottom-up traversals, bypass calling bottom-up closures till the next
    /// leaf node.
    ///
    /// In combined traversals, if it is the `f_down` (pre-order) phase, execution
    /// "jumps" to the next `f_up` (post-order) phase by shortcutting its children.
    /// If it is the `f_up` (post-order) phase, execution "jumps" to the next `f_down`
    /// (pre-order) phase by shortcutting its parent nodes until the first parent node
    /// having unvisited children path.
    Jump,
    /// Stop recursion.
    Stop,
}
535
536impl TreeNodeRecursion {
537    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
538    /// value and the fact that `f` is visiting the current node's children.
539    pub fn visit_children<F: FnOnce() -> Result<TreeNodeRecursion>>(
540        self,
541        f: F,
542    ) -> Result<TreeNodeRecursion> {
543        match self {
544            TreeNodeRecursion::Continue => f(),
545            TreeNodeRecursion::Jump => Ok(TreeNodeRecursion::Continue),
546            TreeNodeRecursion::Stop => Ok(self),
547        }
548    }
549
550    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
551    /// value and the fact that `f` is visiting the current node's sibling.
552    pub fn visit_sibling<F: FnOnce() -> Result<TreeNodeRecursion>>(
553        self,
554        f: F,
555    ) -> Result<TreeNodeRecursion> {
556        match self {
557            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => f(),
558            TreeNodeRecursion::Stop => Ok(self),
559        }
560    }
561
562    /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`]
563    /// value and the fact that `f` is visiting the current node's parent.
564    pub fn visit_parent<F: FnOnce() -> Result<TreeNodeRecursion>>(
565        self,
566        f: F,
567    ) -> Result<TreeNodeRecursion> {
568        match self {
569            TreeNodeRecursion::Continue => f(),
570            TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self),
571        }
572    }
573}
574
575/// Result of tree walk / transformation APIs
576///
577/// `Transformed` is a wrapper around the tree node data (e.g. `Expr` or
578/// `LogicalPlan`). It is used to indicate whether the node was transformed
579/// and how the recursion should proceed.
580///
581/// [`TreeNode`] API users control the transformation by returning:
582/// - The resulting (possibly transformed) node,
583/// - `transformed`: flag indicating whether any change was made to the node
584/// - `tnr`: [`TreeNodeRecursion`] specifying how to proceed with the recursion.
585///
586/// At the end of the transformation, the return value will contain:
587/// - The final (possibly transformed) tree,
588/// - `transformed`: flag indicating whether any change was made to the node
589/// - `tnr`: [`TreeNodeRecursion`] specifying how the recursion ended.
590///
591/// See also
592/// * [`Transformed::update_data`] to modify the node without changing the `transformed` flag
593/// * [`Transformed::map_data`] for fallable operation that return the same type
594/// * [`Transformed::transform_data`] to chain fallable transformations
595/// * [`TransformedResult`] for working with `Result<Transformed<U>>`
596///
597/// # Examples
598///
599/// Use [`Transformed::yes`] and [`Transformed::no`] to signal that a node was
600/// rewritten and the recursion should continue:
601///
602/// ```
603/// # use datafusion_common::tree_node::Transformed;
604/// # // note use i64 instead of Expr as Expr is not in datafusion-common
605/// # fn orig_expr() -> i64 { 1 }
606/// # fn make_new_expr(i: i64) -> i64 { 2 }
607/// let expr = orig_expr();
608///
609/// // Create a new `Transformed` object signaling the node was not rewritten
610/// let ret = Transformed::no(expr.clone());
611/// assert!(!ret.transformed);
612///
613/// // Create a new `Transformed` object signaling the node was rewritten
614/// let ret = Transformed::yes(expr);
615/// assert!(ret.transformed)
616/// ```
617///
618/// Access the node within the `Transformed` object:
619/// ```
620/// # use datafusion_common::tree_node::Transformed;
621/// # // note use i64 instead of Expr as Expr is not in datafusion-common
622/// # fn orig_expr() -> i64 { 1 }
623/// # fn make_new_expr(i: i64) -> i64 { 2 }
624/// let expr = orig_expr();
625///
626/// // `Transformed` object signaling the node was not rewritten
627/// let ret = Transformed::no(expr.clone());
628/// // Access the inner object using .data
629/// assert_eq!(expr, ret.data);
630/// ```
631///
632/// Transform the node within the `Transformed` object.
633///
634/// ```
635/// # use datafusion_common::tree_node::Transformed;
636/// # // note use i64 instead of Expr as Expr is not in datafusion-common
637/// # fn orig_expr() -> i64 { 1 }
638/// # fn make_new_expr(i: i64) -> i64 { 2 }
639/// let expr = orig_expr();
640/// let ret = Transformed::no(expr.clone())
641///   .transform_data(|expr| {
642///    // closure returns a result and potentially transforms the node
643///    // in this example, it does transform the node
644///    let new_expr = make_new_expr(expr);
645///    Ok(Transformed::yes(new_expr))
646///  }).unwrap();
647/// // transformed flag is the union of the original ans closure's  transformed flag
648/// assert!(ret.transformed);
649/// ```
650/// # Example APIs that use `TreeNode`
651/// - [`TreeNode`],
652/// - [`TreeNode::rewrite`],
653/// - [`TreeNode::transform_down`],
654/// - [`TreeNode::transform_up`],
655/// - [`TreeNode::transform_down_up`]
656#[derive(PartialEq, Debug)]
657pub struct Transformed<T> {
658    pub data: T,
659    pub transformed: bool,
660    pub tnr: TreeNodeRecursion,
661}
662
663impl<T> Transformed<T> {
664    /// Create a new `Transformed` object with the given information.
665    pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self {
666        Self {
667            data,
668            transformed,
669            tnr,
670        }
671    }
672
673    /// Create a `Transformed` with `transformed` and [`TreeNodeRecursion::Continue`].
674    pub fn new_transformed(data: T, transformed: bool) -> Self {
675        Self::new(data, transformed, TreeNodeRecursion::Continue)
676    }
677
678    /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement.
679    pub fn yes(data: T) -> Self {
680        Self::new(data, true, TreeNodeRecursion::Continue)
681    }
682
683    /// Wrapper for transformed data with [`TreeNodeRecursion::Stop`] statement.
684    pub fn complete(data: T) -> Self {
685        Self::new(data, true, TreeNodeRecursion::Stop)
686    }
687
688    /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement.
689    pub fn no(data: T) -> Self {
690        Self::new(data, false, TreeNodeRecursion::Continue)
691    }
692
693    /// Applies an infallible `f` to the data of this [`Transformed`] object,
694    /// without modifying the `transformed` flag.
695    pub fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
696        Transformed::new(f(self.data), self.transformed, self.tnr)
697    }
698
699    /// Applies a fallible `f` (returns `Result`) to the data of this
700    /// [`Transformed`] object, without modifying the `transformed` flag.
701    pub fn map_data<U, F: FnOnce(T) -> Result<U>>(self, f: F) -> Result<Transformed<U>> {
702        f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
703    }
704
705    /// Applies a fallible transforming `f` to the data of this [`Transformed`]
706    /// object.
707    ///
708    /// The returned `Transformed` object has the `transformed` flag set if either
709    /// `self` or the return value of `f` have the `transformed` flag set.
710    pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>(
711        self,
712        f: F,
713    ) -> Result<Transformed<U>> {
714        f(self.data).map(|mut t| {
715            t.transformed |= self.transformed;
716            t
717        })
718    }
719
720    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
721    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
722    /// node's children.
723    pub fn transform_children<F: FnOnce(T) -> Result<Transformed<T>>>(
724        mut self,
725        f: F,
726    ) -> Result<Transformed<T>> {
727        match self.tnr {
728            TreeNodeRecursion::Continue => {
729                return f(self.data).map(|mut t| {
730                    t.transformed |= self.transformed;
731                    t
732                });
733            }
734            TreeNodeRecursion::Jump => {
735                self.tnr = TreeNodeRecursion::Continue;
736            }
737            TreeNodeRecursion::Stop => {}
738        }
739        Ok(self)
740    }
741
742    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
743    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
744    /// node's sibling.
745    pub fn transform_sibling<F: FnOnce(T) -> Result<Transformed<T>>>(
746        self,
747        f: F,
748    ) -> Result<Transformed<T>> {
749        match self.tnr {
750            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
751                f(self.data).map(|mut t| {
752                    t.transformed |= self.transformed;
753                    t
754                })
755            }
756            TreeNodeRecursion::Stop => Ok(self),
757        }
758    }
759
760    /// Maps the [`Transformed`] object to the result of the given `f` depending on the
761    /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current
762    /// node's parent.
763    pub fn transform_parent<F: FnOnce(T) -> Result<Transformed<T>>>(
764        self,
765        f: F,
766    ) -> Result<Transformed<T>> {
767        match self.tnr {
768            TreeNodeRecursion::Continue => f(self.data).map(|mut t| {
769                t.transformed |= self.transformed;
770                t
771            }),
772            TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self),
773        }
774    }
775}
776
/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped.
/// The elements of the container are siblings so the continuation rules are similar to
/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`].
///
/// This file implements the trait for `Box`, `Arc`, `Option`, `Vec`, `HashMap` and
/// tuples of two to four containers, so nested combinations of these compose.
pub trait TreeNodeContainer<'a, T: 'a>: Sized {
    /// Applies `f` to all elements of the container.
    /// This method is usually called from [`TreeNode::apply_children`] implementations as
    /// a node is actually a container of the node's children.
    ///
    /// The implementations in this file return `Continue` for an empty container
    /// and stop visiting further elements once `f` returns `Stop`.
    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
        &'a self,
        f: F,
    ) -> Result<TreeNodeRecursion>;

    /// Maps all elements of the container with `f`.
    /// This method is usually called from [`TreeNode::map_children`] implementations as
    /// a node is actually a container of the node's children.
    ///
    /// The implementations in this file keep the elements that follow a `Stop`
    /// unchanged and report `transformed` if any invocation of `f` did.
    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
        self,
        f: F,
    ) -> Result<Transformed<Self>>;
}
797
798impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box<C> {
799    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
800        &'a self,
801        f: F,
802    ) -> Result<TreeNodeRecursion> {
803        self.as_ref().apply_elements(f)
804    }
805
806    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
807        self,
808        f: F,
809    ) -> Result<Transformed<Self>> {
810        (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c)))
811    }
812}
813
814impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc<C> {
815    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
816        &'a self,
817        f: F,
818    ) -> Result<TreeNodeRecursion> {
819        self.as_ref().apply_elements(f)
820    }
821
822    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
823        self,
824        f: F,
825    ) -> Result<Transformed<Self>> {
826        Arc::unwrap_or_clone(self)
827            .map_elements(f)?
828            .map_data(|c| Ok(Arc::new(c)))
829    }
830}
831
832impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option<C> {
833    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
834        &'a self,
835        f: F,
836    ) -> Result<TreeNodeRecursion> {
837        match self {
838            Some(t) => t.apply_elements(f),
839            None => Ok(TreeNodeRecursion::Continue),
840        }
841    }
842
843    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
844        self,
845        f: F,
846    ) -> Result<Transformed<Self>> {
847        self.map_or(Ok(Transformed::no(None)), |c| {
848            c.map_elements(f)?.map_data(|c| Ok(Some(c)))
849        })
850    }
851}
852
853impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec<C> {
854    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
855        &'a self,
856        mut f: F,
857    ) -> Result<TreeNodeRecursion> {
858        let mut tnr = TreeNodeRecursion::Continue;
859        for c in self {
860            tnr = c.apply_elements(&mut f)?;
861            match tnr {
862                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
863                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
864            }
865        }
866        Ok(tnr)
867    }
868
869    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
870        self,
871        mut f: F,
872    ) -> Result<Transformed<Self>> {
873        let mut tnr = TreeNodeRecursion::Continue;
874        let mut transformed = false;
875        self.into_iter()
876            .map(|c| match tnr {
877                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
878                    c.map_elements(&mut f).map(|result| {
879                        tnr = result.tnr;
880                        transformed |= result.transformed;
881                        result.data
882                    })
883                }
884                TreeNodeRecursion::Stop => Ok(c),
885            })
886            .collect::<Result<Vec<_>>>()
887            .map(|data| Transformed::new(data, transformed, tnr))
888    }
889}
890
891impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T>
892    for HashMap<K, C>
893{
894    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
895        &'a self,
896        mut f: F,
897    ) -> Result<TreeNodeRecursion> {
898        let mut tnr = TreeNodeRecursion::Continue;
899        for c in self.values() {
900            tnr = c.apply_elements(&mut f)?;
901            match tnr {
902                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
903                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
904            }
905        }
906        Ok(tnr)
907    }
908
909    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
910        self,
911        mut f: F,
912    ) -> Result<Transformed<Self>> {
913        let mut tnr = TreeNodeRecursion::Continue;
914        let mut transformed = false;
915        self.into_iter()
916            .map(|(k, c)| match tnr {
917                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
918                    c.map_elements(&mut f).map(|result| {
919                        tnr = result.tnr;
920                        transformed |= result.transformed;
921                        (k, result.data)
922                    })
923                }
924                TreeNodeRecursion::Stop => Ok((k, c)),
925            })
926            .collect::<Result<HashMap<_, _>>>()
927            .map(|data| Transformed::new(data, transformed, tnr))
928    }
929}
930
931impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
932    TreeNodeContainer<'a, T> for (C0, C1)
933{
934    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
935        &'a self,
936        mut f: F,
937    ) -> Result<TreeNodeRecursion> {
938        self.0
939            .apply_elements(&mut f)?
940            .visit_sibling(|| self.1.apply_elements(&mut f))
941    }
942
943    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
944        self,
945        mut f: F,
946    ) -> Result<Transformed<Self>> {
947        self.0
948            .map_elements(&mut f)?
949            .map_data(|new_c0| Ok((new_c0, self.1)))?
950            .transform_sibling(|(new_c0, c1)| {
951                c1.map_elements(&mut f)?
952                    .map_data(|new_c1| Ok((new_c0, new_c1)))
953            })
954    }
955}
956
957impl<
958        'a,
959        T: 'a,
960        C0: TreeNodeContainer<'a, T>,
961        C1: TreeNodeContainer<'a, T>,
962        C2: TreeNodeContainer<'a, T>,
963    > TreeNodeContainer<'a, T> for (C0, C1, C2)
964{
965    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
966        &'a self,
967        mut f: F,
968    ) -> Result<TreeNodeRecursion> {
969        self.0
970            .apply_elements(&mut f)?
971            .visit_sibling(|| self.1.apply_elements(&mut f))?
972            .visit_sibling(|| self.2.apply_elements(&mut f))
973    }
974
975    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
976        self,
977        mut f: F,
978    ) -> Result<Transformed<Self>> {
979        self.0
980            .map_elements(&mut f)?
981            .map_data(|new_c0| Ok((new_c0, self.1, self.2)))?
982            .transform_sibling(|(new_c0, c1, c2)| {
983                c1.map_elements(&mut f)?
984                    .map_data(|new_c1| Ok((new_c0, new_c1, c2)))
985            })?
986            .transform_sibling(|(new_c0, new_c1, c2)| {
987                c2.map_elements(&mut f)?
988                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2)))
989            })
990    }
991}
992
993impl<
994        'a,
995        T: 'a,
996        C0: TreeNodeContainer<'a, T>,
997        C1: TreeNodeContainer<'a, T>,
998        C2: TreeNodeContainer<'a, T>,
999        C3: TreeNodeContainer<'a, T>,
1000    > TreeNodeContainer<'a, T> for (C0, C1, C2, C3)
1001{
1002    fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1003        &'a self,
1004        mut f: F,
1005    ) -> Result<TreeNodeRecursion> {
1006        self.0
1007            .apply_elements(&mut f)?
1008            .visit_sibling(|| self.1.apply_elements(&mut f))?
1009            .visit_sibling(|| self.2.apply_elements(&mut f))?
1010            .visit_sibling(|| self.3.apply_elements(&mut f))
1011    }
1012
1013    fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
1014        self,
1015        mut f: F,
1016    ) -> Result<Transformed<Self>> {
1017        self.0
1018            .map_elements(&mut f)?
1019            .map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))?
1020            .transform_sibling(|(new_c0, c1, c2, c3)| {
1021                c1.map_elements(&mut f)?
1022                    .map_data(|new_c1| Ok((new_c0, new_c1, c2, c3)))
1023            })?
1024            .transform_sibling(|(new_c0, new_c1, c2, c3)| {
1025                c2.map_elements(&mut f)?
1026                    .map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3)))
1027            })?
1028            .transform_sibling(|(new_c0, new_c1, new_c2, c3)| {
1029                c3.map_elements(&mut f)?
1030                    .map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3)))
1031            })
1032    }
1033}
1034
1035/// [`TreeNodeRefContainer`] contains references to elements that a function can be
1036/// applied on. The elements of the container are siblings so the continuation rules are
1037/// similar to [`TreeNodeRecursion::visit_sibling`].
1038///
1039/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference
1040/// elements (`T`) are not derived from the container's lifetime.
1041/// A typical usage of this container is in `Expr::apply_children` when we need to
1042/// construct a temporary container to be able to call `apply_ref_elements` on a
1043/// collection of tree node references. But in that case the container's temporary
1044/// lifetime is different to the lifetime of tree nodes that we put into it.
1045/// Please find an example use case in `Expr::apply_children` with the `Expr::Case` case.
1046///
1047/// Most of the cases we don't need to create a temporary container with
1048/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`.
1049/// Please find an example use case in `Expr::apply_children` with the `Expr::GroupingSet`
1050/// case.
1051pub trait TreeNodeRefContainer<'a, T: 'a>: Sized {
1052    /// Applies `f` to all elements of the container.
1053    /// This method is usually called from [`TreeNode::apply_children`] implementations as
1054    /// a node is actually a container of the node's children.
1055    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1056        &self,
1057        f: F,
1058    ) -> Result<TreeNodeRecursion>;
1059}
1060
1061impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> {
1062    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1063        &self,
1064        mut f: F,
1065    ) -> Result<TreeNodeRecursion> {
1066        let mut tnr = TreeNodeRecursion::Continue;
1067        for c in self {
1068            tnr = c.apply_elements(&mut f)?;
1069            match tnr {
1070                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
1071                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
1072            }
1073        }
1074        Ok(tnr)
1075    }
1076}
1077
1078impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>>
1079    TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1)
1080{
1081    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1082        &self,
1083        mut f: F,
1084    ) -> Result<TreeNodeRecursion> {
1085        self.0
1086            .apply_elements(&mut f)?
1087            .visit_sibling(|| self.1.apply_elements(&mut f))
1088    }
1089}
1090
1091impl<
1092        'a,
1093        T: 'a,
1094        C0: TreeNodeContainer<'a, T>,
1095        C1: TreeNodeContainer<'a, T>,
1096        C2: TreeNodeContainer<'a, T>,
1097    > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2)
1098{
1099    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1100        &self,
1101        mut f: F,
1102    ) -> Result<TreeNodeRecursion> {
1103        self.0
1104            .apply_elements(&mut f)?
1105            .visit_sibling(|| self.1.apply_elements(&mut f))?
1106            .visit_sibling(|| self.2.apply_elements(&mut f))
1107    }
1108}
1109
1110impl<
1111        'a,
1112        T: 'a,
1113        C0: TreeNodeContainer<'a, T>,
1114        C1: TreeNodeContainer<'a, T>,
1115        C2: TreeNodeContainer<'a, T>,
1116        C3: TreeNodeContainer<'a, T>,
1117    > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3)
1118{
1119    fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1120        &self,
1121        mut f: F,
1122    ) -> Result<TreeNodeRecursion> {
1123        self.0
1124            .apply_elements(&mut f)?
1125            .visit_sibling(|| self.1.apply_elements(&mut f))?
1126            .visit_sibling(|| self.2.apply_elements(&mut f))?
1127            .visit_sibling(|| self.3.apply_elements(&mut f))
1128    }
1129}
1130
1131/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
1132pub trait TreeNodeIterator: Iterator {
1133    /// Apples `f` to each item in this iterator
1134    ///
1135    /// Visits all items in the iterator unless
1136    /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`.
1137    ///
1138    /// # Returns
1139    /// Error if `f` returns an error or `Ok(TreeNodeRecursion)` from the last invocation
1140    /// of `f` or `Continue` if the iterator is empty
1141    fn apply_until_stop<F: FnMut(Self::Item) -> Result<TreeNodeRecursion>>(
1142        self,
1143        f: F,
1144    ) -> Result<TreeNodeRecursion>;
1145
1146    /// Apples `f` to each item in this iterator
1147    ///
1148    /// Visits all items in the iterator unless
1149    /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`.
1150    ///
1151    /// # Returns
1152    /// Error if `f` returns an error
1153    ///
1154    /// Ok(Transformed) such that:
1155    /// 1. `transformed` is true if any return from `f` had transformed true
1156    /// 2. `data` from the last invocation of `f`
1157    /// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty
1158    fn map_until_stop_and_collect<
1159        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
1160    >(
1161        self,
1162        f: F,
1163    ) -> Result<Transformed<Vec<Self::Item>>>;
1164}
1165
1166impl<I: Iterator> TreeNodeIterator for I {
1167    fn apply_until_stop<F: FnMut(Self::Item) -> Result<TreeNodeRecursion>>(
1168        self,
1169        mut f: F,
1170    ) -> Result<TreeNodeRecursion> {
1171        let mut tnr = TreeNodeRecursion::Continue;
1172        for i in self {
1173            tnr = f(i)?;
1174            match tnr {
1175                TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {}
1176                TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
1177            }
1178        }
1179        Ok(tnr)
1180    }
1181
1182    fn map_until_stop_and_collect<
1183        F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
1184    >(
1185        self,
1186        mut f: F,
1187    ) -> Result<Transformed<Vec<Self::Item>>> {
1188        let mut tnr = TreeNodeRecursion::Continue;
1189        let mut transformed = false;
1190        self.map(|item| match tnr {
1191            TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
1192                f(item).map(|result| {
1193                    tnr = result.tnr;
1194                    transformed |= result.transformed;
1195                    result.data
1196                })
1197            }
1198            TreeNodeRecursion::Stop => Ok(item),
1199        })
1200        .collect::<Result<Vec<_>>>()
1201        .map(|data| Transformed::new(data, transformed, tnr))
1202    }
1203}
1204
1205/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
1206///
1207/// # Example
1208/// Access the internal data of a `Result<Transformed<T>>`
1209/// as a `Result<T>` using the `data` method:
1210/// ```
1211/// # use datafusion_common::Result;
1212/// # use datafusion_common::tree_node::{Transformed, TransformedResult};
1213/// # // note use i64 instead of Expr as Expr is not in datafusion-common
1214/// # fn update_expr() -> i64 { 1 }
1215/// # fn main() -> Result<()> {
1216/// let transformed: Result<Transformed<_>> = Ok(Transformed::yes(update_expr()));
1217/// // access the internal data of the transformed result, or return the error
1218/// let transformed_expr = transformed.data()?;
1219/// # Ok(())
1220/// # }
1221/// ```
1222pub trait TransformedResult<T> {
1223    fn data(self) -> Result<T>;
1224
1225    fn transformed(self) -> Result<bool>;
1226
1227    fn tnr(self) -> Result<TreeNodeRecursion>;
1228}
1229
1230impl<T> TransformedResult<T> for Result<Transformed<T>> {
1231    fn data(self) -> Result<T> {
1232        self.map(|t| t.data)
1233    }
1234
1235    fn transformed(self) -> Result<bool> {
1236        self.map(|t| t.transformed)
1237    }
1238
1239    fn tnr(self) -> Result<TreeNodeRecursion> {
1240        self.map(|t| t.tnr)
1241    }
1242}
1243
1244/// Helper trait for implementing [`TreeNode`] that have children stored as
1245/// `Arc`s. If some trait object, such as `dyn T`, implements this trait,
1246/// its related `Arc<dyn T>` will automatically implement [`TreeNode`].
1247pub trait DynTreeNode {
1248    /// Returns all children of the specified `TreeNode`.
1249    fn arc_children(&self) -> Vec<&Arc<Self>>;
1250
1251    /// Constructs a new node with the specified children.
1252    fn with_new_arc_children(
1253        &self,
1254        arc_self: Arc<Self>,
1255        new_children: Vec<Arc<Self>>,
1256    ) -> Result<Arc<Self>>;
1257}
1258
1259/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
1260/// (such as [`Arc<dyn PhysicalExpr>`]).
1261impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
1262    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1263        &'n self,
1264        f: F,
1265    ) -> Result<TreeNodeRecursion> {
1266        self.arc_children().into_iter().apply_until_stop(f)
1267    }
1268
1269    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1270        self,
1271        f: F,
1272    ) -> Result<Transformed<Self>> {
1273        let children = self.arc_children();
1274        if !children.is_empty() {
1275            let new_children = children
1276                .into_iter()
1277                .cloned()
1278                .map_until_stop_and_collect(f)?;
1279            // Propagate up `new_children.transformed` and `new_children.tnr`
1280            // along with the node containing transformed children.
1281            if new_children.transformed {
1282                let arc_self = Arc::clone(&self);
1283                new_children.map_data(|new_children| {
1284                    self.with_new_arc_children(arc_self, new_children)
1285                })
1286            } else {
1287                Ok(Transformed::new(self, false, new_children.tnr))
1288            }
1289        } else {
1290            Ok(Transformed::no(self))
1291        }
1292    }
1293}
1294
/// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for
/// trees that contain nodes with payloads. This approach ensures safe execution of algorithms
/// involving payloads, by enforcing rules for detaching and reattaching child nodes.
///
/// Every [`ConcreteTreeNode`] gets a [`TreeNode`] implementation through a blanket impl. That
/// impl visits children via [`Self::children`] and rewrites them via [`Self::take_children`]
/// followed by [`Self::with_new_children`].
pub trait ConcreteTreeNode: Sized {
    /// Provides read-only access to child nodes.
    fn children(&self) -> &[Self];

    /// Detaches the node from its children, returning the node itself and its detached children.
    fn take_children(self) -> (Self, Vec<Self>);

    /// Reattaches updated child nodes to the node, returning the updated node.
    ///
    /// Returns an error if the children cannot be reattached (implementation-defined).
    fn with_new_children(self, children: Vec<Self>) -> Result<Self>;
}
1308
1309impl<T: ConcreteTreeNode> TreeNode for T {
1310    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1311        &'n self,
1312        f: F,
1313    ) -> Result<TreeNodeRecursion> {
1314        self.children().iter().apply_until_stop(f)
1315    }
1316
1317    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1318        self,
1319        f: F,
1320    ) -> Result<Transformed<Self>> {
1321        let (new_self, children) = self.take_children();
1322        if !children.is_empty() {
1323            let new_children = children.into_iter().map_until_stop_and_collect(f)?;
1324            // Propagate up `new_children.transformed` and `new_children.tnr` along with
1325            // the node containing transformed children.
1326            new_children.map_data(|new_children| new_self.with_new_children(new_children))
1327        } else {
1328            Ok(Transformed::no(new_self))
1329        }
1330    }
1331}
1332
1333#[cfg(test)]
1334pub(crate) mod tests {
1335    use std::collections::HashMap;
1336    use std::fmt::Display;
1337
1338    use crate::tree_node::{
1339        Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter,
1340        TreeNodeVisitor,
1341    };
1342    use crate::Result;
1343
1344    #[derive(Debug, Eq, Hash, PartialEq, Clone)]
1345    pub struct TestTreeNode<T> {
1346        pub(crate) children: Vec<TestTreeNode<T>>,
1347        pub(crate) data: T,
1348    }
1349
1350    impl<T> TestTreeNode<T> {
1351        pub(crate) fn new(children: Vec<TestTreeNode<T>>, data: T) -> Self {
1352            Self { children, data }
1353        }
1354
1355        pub(crate) fn new_leaf(data: T) -> Self {
1356            Self {
1357                children: vec![],
1358                data,
1359            }
1360        }
1361
1362        pub(crate) fn is_leaf(&self) -> bool {
1363            self.children.is_empty()
1364        }
1365    }
1366
1367    impl<T> TreeNode for TestTreeNode<T> {
1368        fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
1369            &'n self,
1370            f: F,
1371        ) -> Result<TreeNodeRecursion> {
1372            self.children.apply_elements(f)
1373        }
1374
1375        fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
1376            self,
1377            f: F,
1378        ) -> Result<Transformed<Self>> {
1379            Ok(self
1380                .children
1381                .map_elements(f)?
1382                .update_data(|new_children| Self {
1383                    children: new_children,
1384                    ..self
1385                }))
1386        }
1387    }
1388
1389    impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode<T> {
1390        fn apply_elements<F: FnMut(&'a Self) -> Result<TreeNodeRecursion>>(
1391            &'a self,
1392            mut f: F,
1393        ) -> Result<TreeNodeRecursion> {
1394            f(self)
1395        }
1396
1397        fn map_elements<F: FnMut(Self) -> Result<Transformed<Self>>>(
1398            self,
1399            mut f: F,
1400        ) -> Result<Transformed<Self>> {
1401            f(self)
1402        }
1403    }
1404
1405    //       J
1406    //       |
1407    //       I
1408    //       |
1409    //       F
1410    //     /   \
1411    //    E     G
1412    //    |     |
1413    //    C     H
1414    //  /   \
1415    // B     D
1416    //       |
1417    //       A
1418    fn test_tree() -> TestTreeNode<String> {
1419        let node_a = TestTreeNode::new_leaf("a".to_string());
1420        let node_b = TestTreeNode::new_leaf("b".to_string());
1421        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1422        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1423        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1424        let node_h = TestTreeNode::new_leaf("h".to_string());
1425        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1426        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1427        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1428        TestTreeNode::new(vec![node_i], "j".to_string())
1429    }
1430
1431    // Continue on all nodes
1432    // Expected visits in a combined traversal
1433    fn all_visits() -> Vec<String> {
1434        vec![
1435            "f_down(j)",
1436            "f_down(i)",
1437            "f_down(f)",
1438            "f_down(e)",
1439            "f_down(c)",
1440            "f_down(b)",
1441            "f_up(b)",
1442            "f_down(d)",
1443            "f_down(a)",
1444            "f_up(a)",
1445            "f_up(d)",
1446            "f_up(c)",
1447            "f_up(e)",
1448            "f_down(g)",
1449            "f_down(h)",
1450            "f_up(h)",
1451            "f_up(g)",
1452            "f_up(f)",
1453            "f_up(i)",
1454            "f_up(j)",
1455        ]
1456        .into_iter()
1457        .map(|s| s.to_string())
1458        .collect()
1459    }
1460
1461    // Expected transformed tree after a combined traversal
1462    fn transformed_tree() -> TestTreeNode<String> {
1463        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1464        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1465        let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
1466        let node_c =
1467            TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
1468        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1469        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1470        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1471        let node_f =
1472            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1473        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1474        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1475    }
1476
1477    // Expected transformed tree after a top-down traversal
1478    fn transformed_down_tree() -> TestTreeNode<String> {
1479        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1480        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
1481        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1482        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1483        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1484        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
1485        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
1486        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1487        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1488        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1489    }
1490
1491    // Expected transformed tree after a bottom-up traversal
1492    fn transformed_up_tree() -> TestTreeNode<String> {
1493        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1494        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1495        let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
1496        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string());
1497        let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
1498        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
1499        let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
1500        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string());
1501        let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
1502        TestTreeNode::new(vec![node_i], "f_up(j)".to_string())
1503    }
1504
1505    // f_down Jump on A node
1506    fn f_down_jump_on_a_visits() -> Vec<String> {
1507        vec![
1508            "f_down(j)",
1509            "f_down(i)",
1510            "f_down(f)",
1511            "f_down(e)",
1512            "f_down(c)",
1513            "f_down(b)",
1514            "f_up(b)",
1515            "f_down(d)",
1516            "f_down(a)",
1517            "f_up(a)",
1518            "f_up(d)",
1519            "f_up(c)",
1520            "f_up(e)",
1521            "f_down(g)",
1522            "f_down(h)",
1523            "f_up(h)",
1524            "f_up(g)",
1525            "f_up(f)",
1526            "f_up(i)",
1527            "f_up(j)",
1528        ]
1529        .into_iter()
1530        .map(|s| s.to_string())
1531        .collect()
1532    }
1533
1534    fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode<String> {
1535        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1536        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
1537        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1538        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1539        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1540        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
1541        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
1542        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1543        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1544        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1545    }
1546
1547    // f_down Jump on E node
1548    fn f_down_jump_on_e_visits() -> Vec<String> {
1549        vec![
1550            "f_down(j)",
1551            "f_down(i)",
1552            "f_down(f)",
1553            "f_down(e)",
1554            "f_up(e)",
1555            "f_down(g)",
1556            "f_down(h)",
1557            "f_up(h)",
1558            "f_up(g)",
1559            "f_up(f)",
1560            "f_up(i)",
1561            "f_up(j)",
1562        ]
1563        .into_iter()
1564        .map(|s| s.to_string())
1565        .collect()
1566    }
1567
1568    fn f_down_jump_on_e_transformed_tree() -> TestTreeNode<String> {
1569        let node_a = TestTreeNode::new_leaf("a".to_string());
1570        let node_b = TestTreeNode::new_leaf("b".to_string());
1571        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1572        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1573        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1574        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1575        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1576        let node_f =
1577            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1578        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1579        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1580    }
1581
1582    fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode<String> {
1583        let node_a = TestTreeNode::new_leaf("a".to_string());
1584        let node_b = TestTreeNode::new_leaf("b".to_string());
1585        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1586        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1587        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1588        let node_h = TestTreeNode::new_leaf("f_down(h)".to_string());
1589        let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string());
1590        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1591        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1592        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1593    }
1594
1595    // f_up Jump on A node
1596    fn f_up_jump_on_a_visits() -> Vec<String> {
1597        vec![
1598            "f_down(j)",
1599            "f_down(i)",
1600            "f_down(f)",
1601            "f_down(e)",
1602            "f_down(c)",
1603            "f_down(b)",
1604            "f_up(b)",
1605            "f_down(d)",
1606            "f_down(a)",
1607            "f_up(a)",
1608            "f_down(g)",
1609            "f_down(h)",
1610            "f_up(h)",
1611            "f_up(g)",
1612            "f_up(f)",
1613            "f_up(i)",
1614            "f_up(j)",
1615        ]
1616        .into_iter()
1617        .map(|s| s.to_string())
1618        .collect()
1619    }
1620
1621    fn f_up_jump_on_a_transformed_tree() -> TestTreeNode<String> {
1622        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1623        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1624        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1625        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1626        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1627        let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string());
1628        let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
1629        let node_f =
1630            TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
1631        let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
1632        TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
1633    }
1634
1635    fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode<String> {
1636        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1637        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1638        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1639        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1640        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1641        let node_h = TestTreeNode::new_leaf("f_up(h)".to_string());
1642        let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string());
1643        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string());
1644        let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string());
1645        TestTreeNode::new(vec![node_i], "f_up(j)".to_string())
1646    }
1647
1648    // f_up Jump on E node
1649    fn f_up_jump_on_e_visits() -> Vec<String> {
1650        vec![
1651            "f_down(j)",
1652            "f_down(i)",
1653            "f_down(f)",
1654            "f_down(e)",
1655            "f_down(c)",
1656            "f_down(b)",
1657            "f_up(b)",
1658            "f_down(d)",
1659            "f_down(a)",
1660            "f_up(a)",
1661            "f_up(d)",
1662            "f_up(c)",
1663            "f_up(e)",
1664            "f_down(g)",
1665            "f_down(h)",
1666            "f_up(h)",
1667            "f_up(g)",
1668            "f_up(f)",
1669            "f_up(i)",
1670            "f_up(j)",
1671        ]
1672        .into_iter()
1673        .map(|s| s.to_string())
1674        .collect()
1675    }
1676
    /// Expected `rewrite` / `transform_down_up` result when `f_up` returns
    /// `Jump` on `e` (labelled `f_down(e)` at that point).
    ///
    /// Jumping there skips nothing that would otherwise be transformed, so the
    /// expectation is simply the fully transformed tree.
    fn f_up_jump_on_e_transformed_tree() -> TestTreeNode<String> {
        transformed_tree()
    }
1680
    /// Expected `transform_up` result when `f_up` returns `Jump` on `e`.
    ///
    /// Jumping there skips nothing that would otherwise be transformed, so the
    /// expectation is simply the fully `transform_up`-ed tree.
    fn f_up_jump_on_e_transformed_up_tree() -> TestTreeNode<String> {
        transformed_up_tree()
    }
1684
1685    // f_down Stop on A node
1686
1687    fn f_down_stop_on_a_visits() -> Vec<String> {
1688        vec![
1689            "f_down(j)",
1690            "f_down(i)",
1691            "f_down(f)",
1692            "f_down(e)",
1693            "f_down(c)",
1694            "f_down(b)",
1695            "f_up(b)",
1696            "f_down(d)",
1697            "f_down(a)",
1698        ]
1699        .into_iter()
1700        .map(|s| s.to_string())
1701        .collect()
1702    }
1703
1704    fn f_down_stop_on_a_transformed_tree() -> TestTreeNode<String> {
1705        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1706        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1707        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1708        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1709        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1710        let node_h = TestTreeNode::new_leaf("h".to_string());
1711        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1712        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1713        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1714        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1715    }
1716
1717    fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode<String> {
1718        let node_a = TestTreeNode::new_leaf("f_down(a)".to_string());
1719        let node_b = TestTreeNode::new_leaf("f_down(b)".to_string());
1720        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1721        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1722        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1723        let node_h = TestTreeNode::new_leaf("h".to_string());
1724        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1725        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1726        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1727        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1728    }
1729
1730    // f_down Stop on E node
1731    fn f_down_stop_on_e_visits() -> Vec<String> {
1732        vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"]
1733            .into_iter()
1734            .map(|s| s.to_string())
1735            .collect()
1736    }
1737
1738    fn f_down_stop_on_e_transformed_tree() -> TestTreeNode<String> {
1739        let node_a = TestTreeNode::new_leaf("a".to_string());
1740        let node_b = TestTreeNode::new_leaf("b".to_string());
1741        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1742        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1743        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1744        let node_h = TestTreeNode::new_leaf("h".to_string());
1745        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1746        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1747        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1748        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1749    }
1750
1751    fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode<String> {
1752        let node_a = TestTreeNode::new_leaf("a".to_string());
1753        let node_b = TestTreeNode::new_leaf("b".to_string());
1754        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1755        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1756        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1757        let node_h = TestTreeNode::new_leaf("h".to_string());
1758        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1759        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1760        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1761        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1762    }
1763
1764    // f_up Stop on A node
1765    fn f_up_stop_on_a_visits() -> Vec<String> {
1766        vec![
1767            "f_down(j)",
1768            "f_down(i)",
1769            "f_down(f)",
1770            "f_down(e)",
1771            "f_down(c)",
1772            "f_down(b)",
1773            "f_up(b)",
1774            "f_down(d)",
1775            "f_down(a)",
1776            "f_up(a)",
1777        ]
1778        .into_iter()
1779        .map(|s| s.to_string())
1780        .collect()
1781    }
1782
1783    fn f_up_stop_on_a_transformed_tree() -> TestTreeNode<String> {
1784        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1785        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1786        let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string());
1787        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string());
1788        let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
1789        let node_h = TestTreeNode::new_leaf("h".to_string());
1790        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1791        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1792        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1793        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1794    }
1795
1796    fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode<String> {
1797        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1798        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1799        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
1800        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
1801        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
1802        let node_h = TestTreeNode::new_leaf("h".to_string());
1803        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1804        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1805        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1806        TestTreeNode::new(vec![node_i], "j".to_string())
1807    }
1808
1809    // f_up Stop on E node
1810    fn f_up_stop_on_e_visits() -> Vec<String> {
1811        vec![
1812            "f_down(j)",
1813            "f_down(i)",
1814            "f_down(f)",
1815            "f_down(e)",
1816            "f_down(c)",
1817            "f_down(b)",
1818            "f_up(b)",
1819            "f_down(d)",
1820            "f_down(a)",
1821            "f_up(a)",
1822            "f_up(d)",
1823            "f_up(c)",
1824            "f_up(e)",
1825        ]
1826        .into_iter()
1827        .map(|s| s.to_string())
1828        .collect()
1829    }
1830
1831    fn f_up_stop_on_e_transformed_tree() -> TestTreeNode<String> {
1832        let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string());
1833        let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string());
1834        let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
1835        let node_c =
1836            TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
1837        let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
1838        let node_h = TestTreeNode::new_leaf("h".to_string());
1839        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1840        let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string());
1841        let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string());
1842        TestTreeNode::new(vec![node_i], "f_down(j)".to_string())
1843    }
1844
1845    fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode<String> {
1846        let node_a = TestTreeNode::new_leaf("f_up(a)".to_string());
1847        let node_b = TestTreeNode::new_leaf("f_up(b)".to_string());
1848        let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string());
1849        let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string());
1850        let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string());
1851        let node_h = TestTreeNode::new_leaf("h".to_string());
1852        let node_g = TestTreeNode::new(vec![node_h], "g".to_string());
1853        let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string());
1854        let node_i = TestTreeNode::new(vec![node_f], "i".to_string());
1855        TestTreeNode::new(vec![node_i], "j".to_string())
1856    }
1857
1858    fn down_visits(visits: Vec<String>) -> Vec<String> {
1859        visits
1860            .into_iter()
1861            .filter(|v| v.starts_with("f_down"))
1862            .collect()
1863    }
1864
1865    type TestVisitorF<T> = Box<dyn FnMut(&TestTreeNode<T>) -> Result<TreeNodeRecursion>>;
1866
1867    struct TestVisitor<T> {
1868        visits: Vec<String>,
1869        f_down: TestVisitorF<T>,
1870        f_up: TestVisitorF<T>,
1871    }
1872
1873    impl<T> TestVisitor<T> {
1874        fn new(f_down: TestVisitorF<T>, f_up: TestVisitorF<T>) -> Self {
1875            Self {
1876                visits: vec![],
1877                f_down,
1878                f_up,
1879            }
1880        }
1881    }
1882
1883    impl<'n, T: Display> TreeNodeVisitor<'n> for TestVisitor<T> {
1884        type Node = TestTreeNode<T>;
1885
1886        fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1887            self.visits.push(format!("f_down({})", node.data));
1888            (*self.f_down)(node)
1889        }
1890
1891        fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
1892            self.visits.push(format!("f_up({})", node.data));
1893            (*self.f_up)(node)
1894        }
1895    }
1896
1897    fn visit_continue<T>(_: &TestTreeNode<T>) -> Result<TreeNodeRecursion> {
1898        Ok(TreeNodeRecursion::Continue)
1899    }
1900
1901    fn visit_event_on<T: PartialEq, D: Into<T>>(
1902        data: D,
1903        event: TreeNodeRecursion,
1904    ) -> impl FnMut(&TestTreeNode<T>) -> Result<TreeNodeRecursion> {
1905        let d = data.into();
1906        move |node| {
1907            Ok(if node.data == d {
1908                event
1909            } else {
1910                TreeNodeRecursion::Continue
1911            })
1912        }
1913    }
1914
1915    macro_rules! visit_test {
1916        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_VISITS:expr) => {
1917            #[test]
1918            fn $NAME() -> Result<()> {
1919                let tree = test_tree();
1920                let mut visitor = TestVisitor::new(Box::new($F_DOWN), Box::new($F_UP));
1921                tree.visit(&mut visitor)?;
1922                assert_eq!(visitor.visits, $EXPECTED_VISITS);
1923
1924                Ok(())
1925            }
1926        };
1927    }
1928
1929    macro_rules! test_apply {
1930        ($NAME:ident, $F:expr, $EXPECTED_VISITS:expr) => {
1931            #[test]
1932            fn $NAME() -> Result<()> {
1933                let tree = test_tree();
1934                let mut visits = vec![];
1935                tree.apply(|node| {
1936                    visits.push(format!("f_down({})", node.data));
1937                    $F(node)
1938                })?;
1939                assert_eq!(visits, $EXPECTED_VISITS);
1940
1941                Ok(())
1942            }
1943        };
1944    }
1945
1946    type TestRewriterF<T> =
1947        Box<dyn FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>>>;
1948
1949    struct TestRewriter<T> {
1950        f_down: TestRewriterF<T>,
1951        f_up: TestRewriterF<T>,
1952    }
1953
1954    impl<T> TestRewriter<T> {
1955        fn new(f_down: TestRewriterF<T>, f_up: TestRewriterF<T>) -> Self {
1956            Self { f_down, f_up }
1957        }
1958    }
1959
1960    impl<T: Display> TreeNodeRewriter for TestRewriter<T> {
1961        type Node = TestTreeNode<T>;
1962
1963        fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1964            (*self.f_down)(node)
1965        }
1966
1967        fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1968            (*self.f_up)(node)
1969        }
1970    }
1971
1972    fn transform_yes<N: Display, T: Display + From<String>>(
1973        transformation_name: N,
1974    ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> {
1975        move |node| {
1976            Ok(Transformed::yes(TestTreeNode::new(
1977                node.children,
1978                format!("{}({})", transformation_name, node.data).into(),
1979            )))
1980        }
1981    }
1982
1983    fn transform_and_event_on<
1984        N: Display,
1985        T: PartialEq + Display + From<String>,
1986        D: Into<T>,
1987    >(
1988        transformation_name: N,
1989        data: D,
1990        event: TreeNodeRecursion,
1991    ) -> impl FnMut(TestTreeNode<T>) -> Result<Transformed<TestTreeNode<T>>> {
1992        let d = data.into();
1993        move |node| {
1994            let new_node = TestTreeNode::new(
1995                node.children,
1996                format!("{}({})", transformation_name, node.data).into(),
1997            );
1998            Ok(if node.data == d {
1999                Transformed::new(new_node, true, event)
2000            } else {
2001                Transformed::yes(new_node)
2002            })
2003        }
2004    }
2005
2006    macro_rules! rewrite_test {
2007        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => {
2008            #[test]
2009            fn $NAME() -> Result<()> {
2010                let tree = test_tree();
2011                let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP));
2012                assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE);
2013
2014                Ok(())
2015            }
2016        };
2017    }
2018
2019    macro_rules! transform_test {
2020        ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => {
2021            #[test]
2022            fn $NAME() -> Result<()> {
2023                let tree = test_tree();
2024                assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE);
2025
2026                Ok(())
2027            }
2028        };
2029    }
2030
2031    macro_rules! transform_down_test {
2032        ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => {
2033            #[test]
2034            fn $NAME() -> Result<()> {
2035                let tree = test_tree();
2036                assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE);
2037
2038                Ok(())
2039            }
2040        };
2041    }
2042
2043    macro_rules! transform_up_test {
2044        ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => {
2045            #[test]
2046            fn $NAME() -> Result<()> {
2047                let tree = test_tree();
2048                assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE);
2049
2050                Ok(())
2051            }
2052        };
2053    }
2054
    // `TreeNode::visit` tests. Each case compares the log recorded by
    // `TestVisitor` with the expected `f_down`/`f_up` call sequence when one
    // of the callbacks returns `Jump` or `Stop` on node `a` or `e`.
    visit_test!(test_visit, visit_continue, visit_continue, all_visits());
    visit_test!(
        test_visit_f_down_jump_on_a,
        visit_event_on("a", TreeNodeRecursion::Jump),
        visit_continue,
        f_down_jump_on_a_visits()
    );
    visit_test!(
        test_visit_f_down_jump_on_e,
        visit_event_on("e", TreeNodeRecursion::Jump),
        visit_continue,
        f_down_jump_on_e_visits()
    );
    visit_test!(
        test_visit_f_up_jump_on_a,
        visit_continue,
        visit_event_on("a", TreeNodeRecursion::Jump),
        f_up_jump_on_a_visits()
    );
    visit_test!(
        test_visit_f_up_jump_on_e,
        visit_continue,
        visit_event_on("e", TreeNodeRecursion::Jump),
        f_up_jump_on_e_visits()
    );
    visit_test!(
        test_visit_f_down_stop_on_a,
        visit_event_on("a", TreeNodeRecursion::Stop),
        visit_continue,
        f_down_stop_on_a_visits()
    );
    visit_test!(
        test_visit_f_down_stop_on_e,
        visit_event_on("e", TreeNodeRecursion::Stop),
        visit_continue,
        f_down_stop_on_e_visits()
    );
    visit_test!(
        test_visit_f_up_stop_on_a,
        visit_continue,
        visit_event_on("a", TreeNodeRecursion::Stop),
        f_up_stop_on_a_visits()
    );
    visit_test!(
        test_visit_f_up_stop_on_e,
        visit_continue,
        visit_event_on("e", TreeNodeRecursion::Stop),
        f_up_stop_on_e_visits()
    );
2104
    // `TreeNode::apply` tests. `apply` only has a top-down callback, so each
    // expected log is the `f_down(..)` part of the matching visitor log.
    test_apply!(test_apply, visit_continue, down_visits(all_visits()));
    test_apply!(
        test_apply_f_down_jump_on_a,
        visit_event_on("a", TreeNodeRecursion::Jump),
        down_visits(f_down_jump_on_a_visits())
    );
    test_apply!(
        test_apply_f_down_jump_on_e,
        visit_event_on("e", TreeNodeRecursion::Jump),
        down_visits(f_down_jump_on_e_visits())
    );
    test_apply!(
        test_apply_f_down_stop_on_a,
        visit_event_on("a", TreeNodeRecursion::Stop),
        down_visits(f_down_stop_on_a_visits())
    );
    test_apply!(
        test_apply_f_down_stop_on_e,
        visit_event_on("e", TreeNodeRecursion::Stop),
        down_visits(f_down_stop_on_e_visits())
    );
2126
    // `TreeNode::rewrite` tests. `f_up` sees labels already rewritten by
    // `f_down`, so `f_up` events are keyed on `f_down(<x>)`. A `Stop` is
    // expected to be reported back in the returned `Transformed`.
    rewrite_test!(
        test_rewrite,
        transform_yes("f_down"),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_down_jump_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_down_jump_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(f_down_jump_on_e_transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_up_jump_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_a_transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_up_jump_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_e_transformed_tree())
    );
    rewrite_test!(
        test_rewrite_f_down_stop_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    rewrite_test!(
        test_rewrite_f_down_stop_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    rewrite_test!(
        test_rewrite_f_up_stop_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    rewrite_test!(
        test_rewrite_f_up_stop_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
2197
    // `TreeNode::transform_down_up` tests. They reuse the same callbacks and
    // expected trees as the `rewrite_test!` cases, since both drive
    // `f_down`/`f_up` over the same tree.
    transform_test!(
        test_transform,
        transform_yes("f_down"),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    transform_test!(
        test_transform_f_down_jump_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(transformed_tree())
    );
    transform_test!(
        test_transform_f_down_jump_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
        transform_yes("f_up"),
        Transformed::yes(f_down_jump_on_e_transformed_tree())
    );
    transform_test!(
        test_transform_f_up_jump_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_a_transformed_tree())
    );
    transform_test!(
        test_transform_f_up_jump_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_e_transformed_tree())
    );
    transform_test!(
        test_transform_f_down_stop_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_test!(
        test_transform_f_down_stop_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
        transform_yes("f_up"),
        Transformed::new(
            f_down_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_test!(
        test_transform_f_up_stop_on_a,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_a_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_test!(
        test_transform_f_up_stop_on_e,
        transform_yes("f_down"),
        transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_e_transformed_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
2268
    // `TreeNode::transform_down` tests. Only the top-down callback runs, so
    // the expected trees contain `f_down(..)` labels only.
    transform_down_test!(
        test_transform_down,
        transform_yes("f_down"),
        Transformed::yes(transformed_down_tree())
    );
    transform_down_test!(
        test_transform_down_f_down_jump_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
        Transformed::yes(f_down_jump_on_a_transformed_down_tree())
    );
    transform_down_test!(
        test_transform_down_f_down_jump_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
        Transformed::yes(f_down_jump_on_e_transformed_down_tree())
    );
    transform_down_test!(
        test_transform_down_f_down_stop_on_a,
        transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
        Transformed::new(
            f_down_stop_on_a_transformed_down_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_down_test!(
        test_transform_down_f_down_stop_on_e,
        transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
        Transformed::new(
            f_down_stop_on_e_transformed_down_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
2302
    // `TreeNode::transform_up` tests. There is no `f_down` pass, so events are
    // keyed on the original labels (`a`, `e`) rather than `f_down(<x>)`.
    transform_up_test!(
        test_transform_up,
        transform_yes("f_up"),
        Transformed::yes(transformed_up_tree())
    );
    transform_up_test!(
        test_transform_up_f_up_jump_on_a,
        transform_and_event_on("f_up", "a", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_a_transformed_up_tree())
    );
    transform_up_test!(
        test_transform_up_f_up_jump_on_e,
        transform_and_event_on("f_up", "e", TreeNodeRecursion::Jump),
        Transformed::yes(f_up_jump_on_e_transformed_up_tree())
    );
    transform_up_test!(
        test_transform_up_f_up_stop_on_a,
        transform_and_event_on("f_up", "a", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_a_transformed_up_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
    transform_up_test!(
        test_transform_up_f_up_stop_on_e,
        transform_and_event_on("f_up", "e", TreeNodeRecursion::Stop),
        Transformed::new(
            f_up_stop_on_e_transformed_up_tree(),
            true,
            TreeNodeRecursion::Stop
        )
    );
2336
2337    //             F
2338    //          /  |  \
2339    //       /     |     \
2340    //    E        C        A
2341    //    |      /   \
2342    //    C     B     D
2343    //  /   \         |
2344    // B     D        A
2345    //       |
2346    //       A
2347    #[test]
2348    fn test_apply_and_visit_references() -> Result<()> {
2349        let node_a = TestTreeNode::new_leaf("a".to_string());
2350        let node_b = TestTreeNode::new_leaf("b".to_string());
2351        let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
2352        let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
2353        let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
2354        let node_a_2 = TestTreeNode::new_leaf("a".to_string());
2355        let node_b_2 = TestTreeNode::new_leaf("b".to_string());
2356        let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
2357        let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string());
2358        let node_a_3 = TestTreeNode::new_leaf("a".to_string());
2359        let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string());
2360
2361        let node_f_ref = &tree;
2362        let node_e_ref = &node_f_ref.children[0];
2363        let node_c_ref = &node_e_ref.children[0];
2364        let node_b_ref = &node_c_ref.children[0];
2365        let node_d_ref = &node_c_ref.children[1];
2366        let node_a_ref = &node_d_ref.children[0];
2367
2368        let mut m: HashMap<&TestTreeNode<String>, usize> = HashMap::new();
2369        tree.apply(|e| {
2370            *m.entry(e).or_insert(0) += 1;
2371            Ok(TreeNodeRecursion::Continue)
2372        })?;
2373
2374        let expected = HashMap::from([
2375            (node_f_ref, 1),
2376            (node_e_ref, 1),
2377            (node_c_ref, 2),
2378            (node_d_ref, 2),
2379            (node_b_ref, 2),
2380            (node_a_ref, 3),
2381        ]);
2382        assert_eq!(m, expected);
2383
2384        struct TestVisitor<'n> {
2385            m: HashMap<&'n TestTreeNode<String>, (usize, usize)>,
2386        }
2387
2388        impl<'n> TreeNodeVisitor<'n> for TestVisitor<'n> {
2389            type Node = TestTreeNode<String>;
2390
2391            fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2392                let (down_count, _) = self.m.entry(node).or_insert((0, 0));
2393                *down_count += 1;
2394                Ok(TreeNodeRecursion::Continue)
2395            }
2396
2397            fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
2398                let (_, up_count) = self.m.entry(node).or_insert((0, 0));
2399                *up_count += 1;
2400                Ok(TreeNodeRecursion::Continue)
2401            }
2402        }
2403
2404        let mut visitor = TestVisitor { m: HashMap::new() };
2405        tree.visit(&mut visitor)?;
2406
2407        let expected = HashMap::from([
2408            (node_f_ref, (1, 1)),
2409            (node_e_ref, (1, 1)),
2410            (node_c_ref, (2, 2)),
2411            (node_d_ref, (2, 2)),
2412            (node_b_ref, (2, 2)),
2413            (node_a_ref, (3, 3)),
2414        ]);
2415        assert_eq!(visitor.m, expected);
2416
2417        Ok(())
2418    }
2419
2420    #[cfg(feature = "recursive_protection")]
2421    #[test]
2422    fn test_large_tree() {
2423        let mut item = TestTreeNode::new_leaf("initial".to_string());
2424        for i in 0..3000 {
2425            item = TestTreeNode::new(vec![item], format!("parent-{i}"));
2426        }
2427
2428        let mut visitor =
2429            TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue));
2430
2431        item.visit(&mut visitor).unwrap();
2432    }
2433}