Skip to main content

ark/tree/
mod.rs

1
2pub mod signed;
3
4use std::cmp;
5
/// The maximum number of children of an internal node in the tree.
const RADIX: usize = 4;

/// A node of a [Tree], either a leaf or an internal node.
///
/// Nodes refer to each other (parent and children) by their index in the tree.
#[derive(Debug, Clone)]
pub struct Node {
	/// Index of this node in the tree.
	idx: u32,
	/// Index of the parent node, `None` for the root.
	parent: Option<u32>,
	/// Indices of the child nodes; all `None` for a leaf.
	children: [Option<u32>; RADIX],
	/// Exclusive range of leaves, allowed to revolve back to 0.
	leaves: (u32, u32),
	/// Total number of leaves in the tree this node belongs to.
	nb_tree_leaves: u32,
	/// 0 for a leaf, otherwise one more than the highest level among the children.
	level: u32,
}

impl Node {
	/// Create the leaf node with index `idx` in a tree of `nb_tree_leaves` leaves.
	///
	/// The leaf's range covers only itself: `(idx, (idx + 1) % nb_tree_leaves)`.
	///
	/// # Panics
	///
	/// Panics if `idx` or `nb_tree_leaves` does not fit in a `u32`, or if
	/// `nb_tree_leaves` is 0.
	// Callers pass `idx < nb_tree_leaves <= u32::MAX`, so `idx + 1` can't overflow.
	#[allow(clippy::arithmetic_side_effects)]
	fn new_leaf(idx: usize, nb_tree_leaves: usize) -> Node {
		let idx = u32::try_from(idx).expect("leaf index doesn't fit in u32");
		let nb_tree_leaves = u32::try_from(nb_tree_leaves)
			.expect("number of leaves doesn't fit in u32");
		debug_assert!(idx < nb_tree_leaves, "leaf index out of range");
		Node {
			idx,
			parent: None,
			children: [None; RADIX],
			leaves: (idx, (idx + 1) % nb_tree_leaves),
			nb_tree_leaves,
			level: 0,
		}
	}

	/// The index of this node in the tree.
	pub fn idx(&self) -> usize {
		self.idx as usize
	}

	/// The index among internal nodes, starting after the leaves
	///
	/// # Panics
	///
	/// Panics if this node is a leaf node, if [Node::is_leaf] returns true.
	pub fn internal_idx(&self) -> usize {
		self.idx.checked_sub(self.nb_tree_leaves)
			.expect("called internal_idx on leaf node") as usize
	}

	/// The index of the parent node, `None` for the root.
	pub fn parent(&self) -> Option<usize> {
		self.parent.map(|p| p as usize)
	}

	/// The indices of the child nodes, in order; empty for a leaf.
	pub fn children(&self) -> impl Iterator<Item = usize> {
		// The array is `Copy`: iterate it by value so the returned iterator
		// doesn't borrow `self` (and behaves the same in every edition).
		IntoIterator::into_iter(self.children).flatten().map(|c| c as usize)
	}

	/// The level of the node in the tree: 0 for a leaf, otherwise one more
	/// than the highest level among its children.
	pub fn level(&self) -> usize {
		self.level as usize
	}

	/// The internal level of the node in the tree, i.e. `level() - 1`.
	///
	/// Returns 0 for a node whose children are all leaves.
	///
	/// # Panics
	///
	/// Panics if this node is a leaf node, if [Node::is_leaf] returns true.
	pub fn internal_level(&self) -> usize {
		self.level.checked_sub(1).expect("called internal_level on leaf node") as usize
	}

	/// An iterator over all leaf indices under this node.
	///
	/// Leaves are yielded in increasing order, wrapping around from the last
	/// leaf of the tree back to leaf 0 where the range does. When both ends of
	/// the range are equal, the node covers every leaf of the tree.
	// At most `nb_tree_leaves` values are produced, each reduced modulo `nb`.
	#[allow(clippy::arithmetic_side_effects)]
	pub fn leaves(&self) -> impl Iterator<Item = usize> + Clone {
		let (first, last) = self.leaves;
		let nb = self.nb_tree_leaves;
		(first..)
			.take(nb as usize)
			.map(move |e| e % nb)
			// Equal ends mean the range wraps all the way around: keep all `nb` leaves.
			.take_while(move |e| first == last || *e != last)
			.map(|e| e as usize)
	}

	/// Whether this node has no children.
	pub fn is_leaf(&self) -> bool {
		self.children.iter().all(|o| o.is_none())
	}

	/// Whether this node has no parent.
	pub fn is_root(&self) -> bool {
		self.parent.is_none()
	}
}
89
90//TODO(stevenroose) consider eliminating this type in favor of straight in-line iterators
91// for all nodes and for branches
92/// A radix-4 tree.
93#[derive(Debug, Clone)]
94pub struct Tree {
95	/// The nodes in the tree, starting with all the leaves
96	/// and then building up towards the root.
97	nodes: Vec<Node>,
98	nb_leaves: usize,
99}
100
101impl Tree {
102	/// Calculate the total number of nodes a tree would have
103	/// for the given number of leaves.
104	// Tree-size accumulation: bounded by RADIX-base log of nb_leaves.
105	#[allow(clippy::arithmetic_side_effects)]
106	pub fn nb_nodes_for_leaves(nb_leaves: usize) -> usize {
107		let mut ret = nb_leaves;
108		let mut left = nb_leaves;
109		while left > 1 {
110			let radix = cmp::min(left, RADIX);
111			left -= radix;
112			left += 1;
113			ret += 1;
114		}
115		ret
116	}
117
118	// Tree construction: cursor/nb_children/level bounded by RADIX-base log of nb_leaves.
119	#[allow(clippy::arithmetic_side_effects)]
120	pub fn new(
121		nb_leaves: usize,
122	) -> Tree {
123		assert_ne!(nb_leaves, 0, "trees can't be empty");
124
125		let mut nodes = Vec::with_capacity(Tree::nb_nodes_for_leaves(nb_leaves));
126
127		// First we add all the leaves to the tree.
128		nodes.extend((0..nb_leaves).map(|i| Node::new_leaf(i, nb_leaves)));
129
130		let mut cursor = 0;
131		// As long as there is more than 1 element on the leftover stack,
132		// we have to add more nodes.
133		while cursor < nodes.len() - 1 {
134			let mut children = [None; RADIX];
135			let mut nb_children = 0;
136			let mut max_child_level = 0;
137			while cursor < nodes.len() && nb_children < RADIX {
138				children[nb_children] = Some(cursor as u32);
139
140				let new_idx = nodes.len(); // idx of next node
141				let child = &mut nodes[cursor];
142				child.parent = Some(new_idx as u32);
143
144				// adjust level and leaf indices
145				if child.level > max_child_level {
146					max_child_level = child.level;
147				}
148
149				cursor += 1;
150				nb_children += 1;
151			}
152			nodes.push(Node {
153				idx: nodes.len() as u32,
154				leaves: (
155					nodes[children.first().unwrap().unwrap() as usize].leaves.0,
156					nodes[children.iter().filter_map(|c| *c).last().unwrap() as usize].leaves.1,
157				),
158				children,
159				level: max_child_level + 1,
160				parent: None,
161				nb_tree_leaves: nb_leaves as u32,
162			});
163		}
164
165		Tree { nodes, nb_leaves }
166	}
167
168	pub fn nb_leaves(&self) -> usize {
169		self.nb_leaves
170	}
171
172	pub fn nb_nodes(&self) -> usize {
173		self.nodes.len()
174	}
175
176	/// The number of internal nodes
177	pub fn nb_internal_nodes(&self) -> usize {
178		self.nodes.len().checked_sub(self.nb_leaves)
179			.expect("tree can't have less nodes than leaves")
180	}
181
182	pub fn node_at(&self, node_idx: usize) -> &Node {
183		self.nodes.get(node_idx).expect("node_idx out of bounds")
184	}
185
186	pub fn root(&self) -> &Node {
187		self.nodes.last().expect("no empty trees")
188	}
189
190	/// Iterate over all nodes, starting with the leaves, towards the root.
191	pub fn iter(&self) -> std::slice::Iter<'_, Node> {
192		self.nodes.iter()
193	}
194
195	/// Iterate over all internal nodes, starting with the ones
196	/// right beyond the leaves, towards the root.
197	pub fn iter_internal(&self) -> std::slice::Iter<'_, Node> {
198		self.nodes[self.nb_leaves..].iter()
199	}
200
201	/// Iterate over all nodes, starting with the leaves, towards the root.
202	pub fn into_iter(self) -> std::vec::IntoIter<Node> {
203		self.nodes.into_iter()
204	}
205
206	/// Iterate nodes over a branch starting at the leaf
207	/// with index `leaf_idx` ending in the root.
208	pub fn iter_branch(&self, leaf_idx: usize) -> BranchIter<'_> {
209		assert!(leaf_idx < self.nodes.len());
210		BranchIter {
211			tree: &self,
212			cursor: Some(leaf_idx),
213		}
214	}
215
216	/// Iterate over ancestors of a node with child indices.
217	///
218	/// Starting from `node_idx`, walks up towards the root. The starting node
219	/// is excluded from iteration. Each returned tuple `(ancestor_idx, child_idx)`
220	/// indicates that `child_idx` is the child position that leads back down
221	/// towards `node_idx`.
222	///
223	/// # Example
224	///
225	/// For a node 12 with children `[4, 5, 6, 7]`:
226	/// ```text
227	/// iter_branch_with_output(6) yields (12, 2), ..., (root_idx, ...)
228	/// ```
229	/// Node 6 is at child index 2 (0-indexed) of node 12.
230	pub fn iter_branch_with_output(&self, node_idx: usize) -> BranchWithOutputIter<'_> {
231		assert!(node_idx < self.nodes.len());
232		BranchWithOutputIter {
233			tree: self,
234			prev_idx: node_idx,
235			cursor: self.nodes[node_idx].parent(),
236		}
237	}
238
239	pub fn parent_idx_of(&self, idx: usize) -> Option<usize> {
240		self.nodes.get(idx).and_then(|n| n.parent.map(|c| c as usize))
241	}
242
243	/// Returns index of the the parent of the node with given `idx`,
244	/// and the index of the node among its siblings.
245	pub fn parent_idx_of_with_sibling_idx(&self, idx: usize) -> Option<(usize, usize)> {
246		self.nodes.get(idx).and_then(|n| n.parent).map(|parent_idx| {
247			let child_idx = self.nodes[parent_idx as usize].children.iter()
248				.position(|c| *c == Some(idx as u32))
249				.expect("broken tree");
250			(self.nodes[parent_idx as usize].idx as usize, child_idx as usize)
251		})
252	}
253
254}
255
256/// Iterates a tree branch.
257#[derive(Clone)]
258pub struct BranchIter<'a> {
259	tree: &'a Tree,
260	cursor: Option<usize>,
261}
262
263impl<'a> Iterator for BranchIter<'a> {
264	type Item = &'a Node;
265	fn next(&mut self) -> Option<Self::Item> {
266		if let Some(cursor) = self.cursor {
267			let ret = &self.tree.nodes[cursor];
268			self.cursor = ret.parent();
269			Some(ret)
270		} else {
271			None
272		}
273	}
274}
275
276/// Iterates ancestors of a node, returning (node_idx, child_idx) tuples.
277#[derive(Clone)]
278pub struct BranchWithOutputIter<'a> {
279	tree: &'a Tree,
280	prev_idx: usize,
281	cursor: Option<usize>,
282}
283
284impl<'a> Iterator for BranchWithOutputIter<'a> {
285	type Item = (usize, usize);
286	fn next(&mut self) -> Option<Self::Item> {
287		let cursor = self.cursor?;
288		let node = &self.tree.nodes[cursor];
289		let child_idx = node.children()
290			.position(|c| c == self.prev_idx)
291			.expect("broken tree");
292		self.prev_idx = cursor;
293		self.cursor = node.parent();
294		Some((cursor, child_idx))
295	}
296}
297
#[cfg(test)]
mod test {
	use std::collections::HashSet;

	use super::*;

	#[test]
	fn test_simple_tree() {
		for n in 1..100 {
			let tree = Tree::new(n);

			// Every node except the root (the last one) has a parent.
			assert!(tree.nodes.iter().rev().skip(1).all(|node| node.parent.is_some()));
			assert!(tree.root().is_root());
			// Children point back to their parent...
			assert!(tree.nodes.iter().enumerate().skip(tree.nb_leaves).all(|(i, node)| {
				node.children.iter().filter_map(|v| *v)
					.all(|c| tree.nodes[c as usize].parent == Some(i as u32))
			}));
			// ...and parents list their children.
			assert!(tree.nodes.iter().enumerate().rev().skip(1).all(|(i, node)| {
				let parent_idx = node.parent.unwrap() as usize;
				tree.nodes[parent_idx].children.iter().any(|c| *c == Some(i as u32))
			}));
			assert_eq!(Tree::nb_nodes_for_leaves(n), tree.nb_nodes(), "leaves: {}", n);
		}
	}

	#[test]
	fn test_nb_nodes_for_leaves() {
		assert_eq!(Tree::nb_nodes_for_leaves(0), 0);
		assert_eq!(Tree::nb_nodes_for_leaves(1), 1);
		assert_eq!(Tree::nb_nodes_for_leaves(2), 3);
		assert_eq!(Tree::nb_nodes_for_leaves(4), 5);
		assert_eq!(Tree::nb_nodes_for_leaves(5), 7);
		assert_eq!(Tree::nb_nodes_for_leaves(16), 21);
	}

	#[test]
	fn test_leaves_range() {
		for n in 1..42 {
			let tree = Tree::new(n);

			for node in &tree.nodes[0..tree.nb_leaves()] {
				assert_eq!(node.leaves().collect::<Vec<_>>(), vec![node.idx()]);
			}
			for node in tree.iter() {
				if !node.is_leaf() {
					assert_eq!(
						node.leaves().count(),
						node.children().map(|c| tree.nodes[c].leaves().count()).sum::<usize>(),
						"idx: {}", node.idx(),
					);
				}
				assert!(node.leaves().all(|l| l < tree.nb_leaves()));
				assert_eq!(
					node.leaves().count(),
					node.leaves().collect::<HashSet<_>>().len(),
				);
			}
			// The root always covers every leaf exactly once.
			assert_eq!(tree.root().leaves().count(), n);
		}
	}

	#[test]
	fn test_wrapping_root_leaves() {
		// With 5 leaves, node 5 groups leaves 0..=3 and the root groups leaf 4
		// with node 5, so the root's leaf range wraps around past the last leaf.
		let tree = Tree::new(5);
		assert_eq!(tree.nb_nodes(), 7);
		let root = tree.root();
		assert_eq!(root.idx(), 6);
		assert_eq!(root.children().collect::<Vec<_>>(), vec![4, 5]);
		assert_eq!(root.leaves().collect::<Vec<_>>(), vec![4, 0, 1, 2, 3]);
		assert_eq!(root.level(), 2);
		assert_eq!(root.internal_level(), 1);
		assert_eq!(tree.node_at(5).leaves().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
	}

	#[test]
	fn test_iter_branch_with_output() {
		for n in 1..100 {
			let tree = Tree::new(n);

			for start_idx in 0..tree.nb_nodes() {
				let results: Vec<_> = tree.iter_branch_with_output(start_idx).collect();

				// 1. Verify the iterator excludes the starting node
				assert!(results.iter().all(|(idx, _)| *idx != start_idx));

				// 2. Verify each returned node is an ancestor of the previous
				let mut expected_parent = tree.nodes[start_idx].parent();
				for (ancestor_idx, _) in &results {
					assert_eq!(Some(*ancestor_idx), expected_parent);
					expected_parent = tree.nodes[*ancestor_idx].parent();
				}

				// 3. Verify child_idx actually points back down the branch
				let mut prev = start_idx;
				for (ancestor_idx, child_idx) in &results {
					let child = tree.nodes[*ancestor_idx].children().nth(*child_idx).unwrap();
					assert_eq!(child, prev);
					prev = *ancestor_idx;
				}

				// 4. Verify the last node is the root (has no parent)
				if let Some((last_idx, _)) = results.last() {
					assert!(tree.nodes[*last_idx].is_root());
				}

				// 5. Verify consistency with iter_branch (same path, minus starting node)
				let branch_len = tree.iter_branch(start_idx).skip(1).count();
				assert_eq!(results.len(), branch_len);
			}
		}
	}

	#[test]
	fn test_parent_idx_of_with_sibling_idx() {
		for n in 1..50 {
			let tree = Tree::new(n);
			for idx in 0..tree.nb_nodes() {
				// The first ancestor step must agree with the direct lookup.
				assert_eq!(
					tree.parent_idx_of_with_sibling_idx(idx),
					tree.iter_branch_with_output(idx).next(),
				);
				assert_eq!(tree.parent_idx_of(idx), tree.node_at(idx).parent());
			}
			assert_eq!(tree.parent_idx_of(tree.nb_nodes()), None);
			assert_eq!(tree.parent_idx_of_with_sibling_idx(tree.nb_nodes()), None);
		}
	}
}