1use crate::{
2 Node,
3 node::{Child, child_storage::Storage},
4};
5
6impl<T, C> Node<T, C>
7where
8 C: Storage + ?Sized,
9{
10 #[inline]
13 pub fn descendants(
14 &self,
15 include_self: bool,
16 ) -> impl Iterator<Item = (heapless::Vec<u8, 16>, Child<&Self, &T>)> {
17 DescendantIter::new(self, include_self)
18 }
19
20 #[inline]
23 pub fn descendant_nodes(
24 &self,
25 include_self: bool,
26 ) -> impl Iterator<Item = (heapless::Vec<u8, 16>, &Self)> {
27 DescendantIter::new(self, include_self).filter_map(|(addr, child)| match child {
28 Child::Path(node) => Some((addr, node)),
29 _ => None,
30 })
31 }
32}
33
34struct DescendantIter<'a, T, C>
36where
37 C: Storage + ?Sized,
38{
39 node_path: heapless::Vec<(u8, &'a Node<T, C>), 16>,
41
42 next_child: u8,
44
45 yield_self: bool,
46}
47
48impl<'a, T, C> DescendantIter<'a, T, C>
49where
50 C: Storage + ?Sized,
51{
52 #[inline]
53 fn new(node: &'a Node<T, C>, include_self: bool) -> Self {
54 Self {
55 node_path: heapless::Vec::from_iter([(0, node)]),
57 next_child: 0,
58 yield_self: include_self,
59 }
60 }
61}
62
63impl<'a, T, C> Iterator for DescendantIter<'a, T, C>
64where
65 C: Storage + ?Sized,
66{
67 type Item = (heapless::Vec<u8, 16>, Child<&'a Node<T, C>, &'a T>);
68
69 fn next(&mut self) -> Option<Self::Item> {
70 if self.yield_self {
71 self.yield_self = false;
72 return Some((
73 heapless::Vec::new(),
74 Child::Path(self.node_path.first().unwrap().1),
77 ));
78 }
79
80 while let Some(&(this_addr, node)) = self.node_path.last() {
81 let ret = match node
82 .children
83 .iter()
84 .find(|&(addr, _)| addr >= self.next_child)
85 {
86 Some((addr, child)) => {
87 let mut path = self
88 .node_path
89 .iter()
90 .map(|&(addr, _node)| addr)
91 .skip(1) .collect::<heapless::Vec<u8, 16>>();
93
94 path.push(addr).unwrap();
95
96 if let Child::Path(node) = child {
97 self.node_path
99 .push((addr, C::as_ref(node)))
100 .map_err(|_| ())
101 .unwrap();
102 self.next_child = 0;
103 return Some((path, child.as_ref().map_node(C::as_ref)));
104 }
105
106 self.next_child = addr;
107
108 Some((path, child.as_ref().map_node(C::as_ref)))
109 }
110
111 None => {
112 self.node_path.pop();
113 self.next_child = this_addr;
114 None
115 }
116 };
117
118 while self.next_child == 255 {
119 let Some((popped, _)) = self.node_path.pop() else {
120 break;
121 };
122
123 self.next_child = popped;
124 }
125
126 self.next_child += 1;
127
128 if ret.is_some() {
129 return ret;
130 }
131 }
132
133 None
134 }
135}
136
#[cfg(test)]
mod test {
    use super::*;
    use crate::node::StrideOpsExt;

    #[test]
    fn zero() {
        assert_eq!(0, Node::<()>::EMPTY.descendants(false).count());
        assert_eq!(0, Node::<()>::EMPTY.descendant_nodes(false).count());

        assert_eq!(1, Node::<()>::EMPTY.descendants(true).count());
        assert_eq!(1, Node::<()>::EMPTY.descendant_nodes(true).count());

        // With `include_self`, the node itself comes first, with an empty path.
        let empty = Node::<()>::EMPTY;
        let (path, child) = empty.descendants(true).next().unwrap();
        assert!(path.is_empty());
        assert!(matches!(child, Child::Path(..)));
    }

    #[test]
    fn single_level() {
        let node = Node::<()>::EMPTY
            .with_child(0, Child::Fringe(()))
            .with_child(1, Child::Fringe(()))
            .with_child(255, Child::Fringe(()));

        assert_eq!(3, node.descendants(false).count());

        let node = node.with_child(
            3,
            Child::Leaf {
                prefix: Default::default(),
                value: Default::default(),
            },
        );
        assert_eq!(4, node.descendants(false).count());
        assert_eq!(0, node.descendant_nodes(false).count());

        node.descendants(false)
            .zip([0, 1, 3, 255])
            .for_each(|((path, child), addr)| {
                assert_eq!(&path, &[addr]);
                match child {
                    Child::Leaf { .. } | Child::Fringe(..) => {}
                    Child::Path(..) => panic!(),
                }
            });
    }

    #[test]
    fn multi_level() {
        let node = Node::<()>::EMPTY
            .with_child(0, Child::dummy_leaf())
            .with_child(
                2,
                Node::EMPTY
                    .with_child(12, Child::dummy_leaf())
                    .with_child(32, Child::dummy_fringe())
                    .into_child(),
            )
            .with_child(5, Child::dummy_leaf())
            .with_child(
                255,
                Node::EMPTY
                    .with_child(0, Child::dummy_fringe())
                    .with_child(1, Child::dummy_fringe())
                    .with_child(255, Node::EMPTY.into_child())
                    .into_child(),
            );

        assert_eq!(9, node.descendants(false).count());
        assert_eq!(10, node.descendants(true).count());
        assert_eq!(3, node.descendant_nodes(false).count());
        assert_eq!(4, node.descendant_nodes(true).count());

        for ((path, _child), expected_path) in node.descendants(false).zip([
            &[0u8] as &[u8],
            &[2],
            &[2, 12],
            &[2, 32],
            &[5],
            &[255],
            &[255, 0],
            &[255, 1],
            &[255, 255],
        ]) {
            assert_eq!(&path, expected_path);
        }

        for ((path, _child), expected_path) in
            node.descendant_nodes(false)
                .zip([&[2u8] as &[u8], &[255], &[255, 255]])
        {
            assert_eq!(&path, expected_path);
        }

        // `include_self` only prepends the node itself; the rest is unchanged.
        let mut with_self = node.descendants(true);
        let (path, child) = with_self.next().unwrap();
        assert!(path.is_empty());
        assert!(matches!(child, Child::Path(..)));
        assert!(with_self
            .map(|(path, _)| path)
            .eq(node.descendants(false).map(|(path, _)| path)));
    }

    #[test]
    fn nested_last_slots() {
        // A chain of nodes that all sit in slot 255 exercises the unwinding
        // that pops several finished ancestors at once.
        let node = Node::<()>::EMPTY.with_child(
            255,
            Node::EMPTY
                .with_child(
                    255,
                    Node::EMPTY
                        .with_child(255, Child::dummy_fringe())
                        .into_child(),
                )
                .into_child(),
        );

        let mut iter = node.descendants(false);
        for expected_path in [&[255u8] as &[u8], &[255, 255], &[255, 255, 255]] {
            let (path, _child) = iter.next().unwrap();
            assert_eq!(&path, expected_path);
        }

        // Once exhausted, the iterator keeps returning `None`.
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }
}