// ts_bart/node/descendants.rs

1use crate::{
2    Node,
3    node::{Child, child_storage::Storage},
4};
5
6impl<T, C> Node<T, C>
7where
8    C: Storage + ?Sized,
9{
10    /// Get all the descendant [`Child`]ren of this node. Order is depth-first,
11    /// in-order by address.
12    #[inline]
13    pub fn descendants(
14        &self,
15        include_self: bool,
16    ) -> impl Iterator<Item = (heapless::Vec<u8, 16>, Child<&Self, &T>)> {
17        DescendantIter::new(self, include_self)
18    }
19
20    /// Get all the descendant [`Node`]s of this node. Order is depth-first,
21    /// in-order by address.
22    #[inline]
23    pub fn descendant_nodes(
24        &self,
25        include_self: bool,
26    ) -> impl Iterator<Item = (heapless::Vec<u8, 16>, &Self)> {
27        DescendantIter::new(self, include_self).filter_map(|(addr, child)| match child {
28            Child::Path(node) => Some((addr, node)),
29            _ => None,
30        })
31    }
32}
33
34/// Provides a DFS walk of the trie rooted at a given node.
35struct DescendantIter<'a, T, C>
36where
37    C: Storage + ?Sized,
38{
39    /// The current (being-iterated) item in the trie.
40    node_path: heapless::Vec<(u8, &'a Node<T, C>), 16>,
41
42    /// The next child address to be considered.
43    next_child: u8,
44
45    yield_self: bool,
46}
47
48impl<'a, T, C> DescendantIter<'a, T, C>
49where
50    C: Storage + ?Sized,
51{
52    #[inline]
53    fn new(node: &'a Node<T, C>, include_self: bool) -> Self {
54        Self {
55            // The first address in the path is ignored.
56            node_path: heapless::Vec::from_iter([(0, node)]),
57            next_child: 0,
58            yield_self: include_self,
59        }
60    }
61}
62
impl<'a, T, C> Iterator for DescendantIter<'a, T, C>
where
    C: Storage + ?Sized,
{
    /// The address path from the root down to the yielded child (the root's
    /// own placeholder address is not included), paired with a borrowed view
    /// of that child.
    type Item = (heapless::Vec<u8, 16>, Child<&'a Node<T, C>, &'a T>);

    /// Advances the pre-order walk by one child.
    ///
    /// Scans the node on top of `node_path` for its first child whose address
    /// is at or after `next_child`:
    /// - A [`Child::Path`] is pushed onto the stack, so its own children are
    ///   scanned next, and is yielded immediately.
    /// - Any other child is yielded, and scanning resumes just after it.
    ///
    /// Once a node has no further children, it is popped. Scanning then
    /// resumes in its parent just after the popped node's address. When the
    /// stack is empty, `None` is returned, including on every later call.
    fn next(&mut self) -> Option<Self::Item> {
        // The root itself is reported once, with an empty path, before any
        // of its descendants.
        if self.yield_self {
            self.yield_self = false;
            return Some((
                heapless::Vec::new(),
                // invariant: always constructed with `new`, there is always at least one
                // entry
                Child::Path(self.node_path.first().unwrap().1),
            ));
        }

        // Loop until a child is yielded or the stack runs out. Popping an
        // exhausted node produces nothing, so scanning continues in the parent.
        while let Some(&(this_addr, node)) = self.node_path.last() {
            // NOTE(review): relies on `children.iter()` yielding entries in
            // ascending address order (the tests observe ascending output).
            let ret = match node
                .children
                .iter()
                .find(|&(addr, _)| addr >= self.next_child)
            {
                Some((addr, child)) => {
                    // Full address path: the address of every stacked node
                    // below the root, followed by this child's address.
                    let mut path = self
                        .node_path
                        .iter()
                        .map(|&(addr, _node)| addr)
                        .skip(1) // skip the root item's path
                        .collect::<heapless::Vec<u8, 16>>();

                    path.push(addr).unwrap();

                    if let Child::Path(node) = child {
                        // Descend into the child node and scan it from address 0.
                        // Returning here deliberately skips the `next_child`
                        // bump below: the new node must start at 0, not 1.
                        // invariant: node path is sized to fit any ipv4/ipv6 addr
                        self.node_path
                            .push((addr, C::as_ref(node)))
                            // Discard the rejected element so `unwrap` only
                            // needs `(): Debug`.
                            .map_err(|_| ())
                            .unwrap();
                        self.next_child = 0;
                        return Some((path, child.as_ref().map_node(C::as_ref)));
                    }

                    // Record the address just yielded; it is advanced past below.
                    self.next_child = addr;

                    Some((path, child.as_ref().map_node(C::as_ref)))
                }

                None => {
                    // This node has no children left. Pop it and resume in the
                    // parent at this node's own address (advanced past below).
                    self.node_path.pop();
                    self.next_child = this_addr;
                    None
                }
            };

            // `next_child` is a `u8`, so "one past 255" cannot be represented.
            // If the address just handled was 255, the node now on top of the
            // stack has no children left to scan. Pop it and carry the resume
            // point up one level, repeating while ancestors also sat at 255.
            // Popping the root sets `next_child` to its placeholder address (0),
            // which ends this loop before the increment below.
            while self.next_child == 255 {
                let Some((popped, _)) = self.node_path.pop() else {
                    break;
                };

                self.next_child = popped;
            }

            // Move past the address just handled.
            self.next_child += 1;

            if ret.is_some() {
                return ret;
            }
        }

        None
    }
}
136
#[cfg(test)]
mod test {
    use super::*;
    use crate::node::StrideOpsExt;

    #[test]
    fn zero() {
        assert_eq!(0, Node::<()>::EMPTY.descendants(false).count());
        assert_eq!(0, Node::<()>::EMPTY.descendant_nodes(false).count());

        assert_eq!(1, Node::<()>::EMPTY.descendants(true).count());
        assert_eq!(1, Node::<()>::EMPTY.descendant_nodes(true).count());
    }

    #[test]
    fn single_level() {
        let node = Node::<()>::EMPTY
            .with_child(0, Child::Fringe(()))
            .with_child(1, Child::Fringe(()))
            .with_child(255, Child::Fringe(()));

        assert_eq!(3, node.descendants(false).count());

        let node = node.with_child(
            3,
            Child::Leaf {
                prefix: Default::default(),
                value: Default::default(),
            },
        );
        assert_eq!(4, node.descendants(false).count());
        assert_eq!(0, node.descendant_nodes(false).count());

        node.descendants(false)
            .zip([0, 1, 3, 255])
            .for_each(|((path, child), addr)| {
                assert_eq!(&path, &[addr]);
                match child {
                    Child::Leaf { .. } | Child::Fringe(..) => {}
                    Child::Path(..) => panic!(),
                }
            });
    }

    #[test]
    fn multi_level() {
        let node = Node::<()>::EMPTY
            .with_child(0, Child::dummy_leaf())
            .with_child(
                2,
                Node::EMPTY
                    .with_child(12, Child::dummy_leaf())
                    .with_child(32, Child::dummy_fringe())
                    .into_child(),
            )
            .with_child(5, Child::dummy_leaf())
            .with_child(
                255,
                Node::EMPTY
                    .with_child(0, Child::dummy_fringe())
                    .with_child(1, Child::dummy_fringe())
                    .with_child(255, Node::EMPTY.into_child())
                    .into_child(),
            );

        assert_eq!(9, node.descendants(false).count());
        assert_eq!(10, node.descendants(true).count());
        assert_eq!(3, node.descendant_nodes(false).count());
        assert_eq!(4, node.descendant_nodes(true).count());

        for ((path, _child), expected_path) in node.descendants(false).zip([
            &[0u8] as &[u8],
            &[2],
            &[2, 12],
            &[2, 32],
            &[5],
            &[255],
            &[255, 0],
            &[255, 1],
            &[255, 255],
        ]) {
            assert_eq!(&path, expected_path);
        }

        for ((path, _child), expected_path) in
            node.descendant_nodes(false)
                .zip([&[2u8] as &[u8], &[255], &[255, 255]])
        {
            assert_eq!(&path, expected_path);
        }
    }

    /// A child at address 255 inside a nested node must hand the walk back to
    /// the parent, so later siblings of that nested node are still visited.
    #[test]
    fn last_slot_in_nested_node() {
        let node = Node::<()>::EMPTY
            .with_child(
                1,
                Node::EMPTY
                    .with_child(7, Child::dummy_leaf())
                    .with_child(255, Child::dummy_fringe())
                    .into_child(),
            )
            .with_child(2, Child::dummy_leaf());

        assert_eq!(4, node.descendants(false).count());
        assert_eq!(1, node.descendant_nodes(false).count());

        for ((path, _child), expected_path) in node
            .descendants(false)
            .zip([&[1u8] as &[u8], &[1, 7], &[1, 255], &[2]])
        {
            assert_eq!(&path, expected_path);
        }
    }

    /// With `include_self`, the root comes first with an empty path, and the
    /// iterator keeps returning `None` once exhausted.
    #[test]
    fn include_self_yields_root_first() {
        let node = Node::<()>::EMPTY.with_child(4, Child::dummy_leaf());

        let mut iter = node.descendants(true);

        let (path, child) = iter.next().unwrap();
        assert!(path.is_empty());
        assert!(matches!(child, Child::Path(..)));

        let (path, child) = iter.next().unwrap();
        assert_eq!(&path, &[4u8]);
        assert!(!matches!(child, Child::Path(..)));

        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }
}