1#![forbid(unsafe_code)]
2
3use std::collections::HashMap;
34
/// A node of the input tree handed to [`VebTree::build`].
///
/// Nodes refer to their children by id rather than by reference, so the whole input can be
/// supplied as a flat list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TreeNode<T> {
    /// Identifier used to reference this node from its parent's `children`.
    pub id: u32,
    /// Payload stored with the node.
    pub data: T,
    /// Ids of this node's children, in order.
    pub children: Vec<u32>,
}
45
46impl<T> TreeNode<T> {
47 pub fn new(id: u32, data: T, children: Vec<u32>) -> Self {
49 Self { id, data, children }
50 }
51}
52
/// One node of a [`VebTree`], stored at its van Emde Boas position.
///
/// `child_indices` and `parent_index` are positions in the tree's node array (see
/// [`VebTree::get_by_index`]), not node ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VebEntry<T> {
    /// Id of the node, as given in the input [`TreeNode`].
    pub id: u32,
    /// Payload of the node.
    pub data: T,
    /// Array positions of this node's children, in the order the input listed them.
    pub child_indices: Vec<u32>,
    /// Array position of this node's parent, or `u32::MAX` for the root.
    pub parent_index: u32,
    /// Distance from the root (the root has depth 0).
    pub depth: u16,
}
67
/// A tree whose nodes are stored in a single vector in van Emde Boas (vEB) order.
///
/// Built by [`VebTree::build`]; nodes can be looked up by id or by array position.
#[derive(Debug, Clone)]
pub struct VebTree<T> {
    // Node entries in vEB order; each entry's `child_indices` / `parent_index` point into
    // this vector.
    nodes: Vec<VebEntry<T>>,
    // Maps each node id to its position in `nodes`.
    index: HashMap<u32, u32>,
}
76
77impl<T: Clone> VebTree<T> {
78 pub fn build(input: Vec<TreeNode<T>>) -> Self {
84 if input.is_empty() {
85 return Self {
86 nodes: Vec::new(),
87 index: HashMap::new(),
88 };
89 }
90
91 let node_map: HashMap<u32, &TreeNode<T>> = input.iter().map(|n| (n.id, n)).collect();
93
94 let all_children: std::collections::HashSet<u32> = input
96 .iter()
97 .flat_map(|n| n.children.iter().copied())
98 .collect();
99 let root_id = input
100 .iter()
101 .find(|n| !all_children.contains(&n.id))
102 .map(|n| n.id)
103 .unwrap_or(input[0].id);
104
105 let mut depths: HashMap<u32, u16> = HashMap::new();
107 let mut queue = std::collections::VecDeque::new();
108 queue.push_back((root_id, 0u16));
109 while let Some((nid, d)) = queue.pop_front() {
110 depths.insert(nid, d);
111 if let Some(node) = node_map.get(&nid) {
112 for &cid in &node.children {
113 queue.push_back((cid, d + 1));
114 }
115 }
116 }
117
118 let mut dfs_order: Vec<u32> = Vec::with_capacity(input.len());
120 let mut stack = vec![root_id];
121 while let Some(nid) = stack.pop() {
122 dfs_order.push(nid);
123 if let Some(node) = node_map.get(&nid) {
124 for &cid in node.children.iter().rev() {
126 stack.push(cid);
127 }
128 }
129 }
130
131 let veb_order = veb_layout_order(&dfs_order, &node_map);
133
134 let mut id_to_pos: HashMap<u32, u32> = HashMap::with_capacity(veb_order.len());
136 for (pos, &nid) in veb_order.iter().enumerate() {
137 id_to_pos.insert(nid, pos as u32);
138 }
139
140 let mut parent_map: HashMap<u32, u32> = HashMap::new();
142 for node in &input {
143 for &cid in &node.children {
144 parent_map.insert(cid, node.id);
145 }
146 }
147
148 let nodes: Vec<VebEntry<T>> = veb_order
149 .iter()
150 .map(|&nid| {
151 let node = node_map[&nid];
152 let child_indices: Vec<u32> = node
153 .children
154 .iter()
155 .filter_map(|cid| id_to_pos.get(cid).copied())
156 .collect();
157 let parent_index = parent_map
158 .get(&nid)
159 .and_then(|pid| id_to_pos.get(pid).copied())
160 .unwrap_or(u32::MAX);
161 VebEntry {
162 id: nid,
163 data: node.data.clone(),
164 child_indices,
165 parent_index,
166 depth: depths.get(&nid).copied().unwrap_or(0),
167 }
168 })
169 .collect();
170
171 Self {
172 nodes,
173 index: id_to_pos,
174 }
175 }
176
177 #[inline]
179 pub fn len(&self) -> usize {
180 self.nodes.len()
181 }
182
183 #[inline]
185 pub fn is_empty(&self) -> bool {
186 self.nodes.is_empty()
187 }
188
189 #[inline]
191 pub fn get(&self, id: u32) -> Option<&VebEntry<T>> {
192 self.index.get(&id).map(|&pos| &self.nodes[pos as usize])
193 }
194
195 #[inline]
197 pub fn get_by_index(&self, idx: u32) -> Option<&VebEntry<T>> {
198 self.nodes.get(idx as usize)
199 }
200
201 pub fn iter(&self) -> impl Iterator<Item = &VebEntry<T>> {
203 self.nodes.iter()
204 }
205
206 pub fn iter_dfs(&self) -> Vec<&VebEntry<T>> {
208 if self.nodes.is_empty() {
209 return Vec::new();
210 }
211 let mut result = Vec::with_capacity(self.nodes.len());
212 let mut stack = vec![0u32]; while let Some(idx) = stack.pop() {
214 if let Some(entry) = self.nodes.get(idx as usize) {
215 result.push(entry);
216 for &ci in entry.child_indices.iter().rev() {
217 stack.push(ci);
218 }
219 }
220 }
221 result
222 }
223
224 pub fn root(&self) -> Option<&VebEntry<T>> {
226 self.nodes.first()
227 }
228
229 pub fn as_slice(&self) -> &[VebEntry<T>] {
231 &self.nodes
232 }
233}
234
235fn veb_layout_order<T>(dfs_order: &[u32], node_map: &HashMap<u32, &TreeNode<T>>) -> Vec<u32> {
242 if dfs_order.len() <= 1 {
243 return dfs_order.to_vec();
244 }
245
246 let root = dfs_order[0];
248 let mut depths: HashMap<u32, u16> = HashMap::new();
249 let mut queue = std::collections::VecDeque::new();
250 queue.push_back((root, 0u16));
251 let subtree_set: std::collections::HashSet<u32> = dfs_order.iter().copied().collect();
252 while let Some((nid, d)) = queue.pop_front() {
253 depths.insert(nid, d);
254 if let Some(node) = node_map.get(&nid) {
255 for &cid in &node.children {
256 if subtree_set.contains(&cid) {
257 queue.push_back((cid, d + 1));
258 }
259 }
260 }
261 }
262
263 let max_depth = depths.values().copied().max().unwrap_or(0);
264 if max_depth <= 1 {
265 return dfs_order.to_vec();
267 }
268
269 let mid_depth = max_depth / 2;
270
271 let mut top: Vec<u32> = Vec::new();
273 let mut bottom_roots: Vec<u32> = Vec::new();
274 let mut bottom_subtrees: HashMap<u32, Vec<u32>> = HashMap::new();
275
276 for &nid in dfs_order {
277 let d = depths.get(&nid).copied().unwrap_or(0);
278 if d <= mid_depth {
279 top.push(nid);
280 if let Some(node) = node_map.get(&nid) {
282 for &cid in &node.children {
283 if subtree_set.contains(&cid) {
284 let cd = depths.get(&cid).copied().unwrap_or(0);
285 if cd > mid_depth {
286 bottom_roots.push(cid);
287 }
288 }
289 }
290 }
291 }
292 }
293
294 for &br in &bottom_roots {
296 let mut subtree = Vec::new();
297 let mut stack = vec![br];
298 while let Some(nid) = stack.pop() {
299 if subtree_set.contains(&nid) {
300 subtree.push(nid);
301 if let Some(node) = node_map.get(&nid) {
302 for &cid in node.children.iter().rev() {
303 if subtree_set.contains(&cid) {
304 stack.push(cid);
305 }
306 }
307 }
308 }
309 }
310 bottom_subtrees.insert(br, subtree);
311 }
312
313 let mut result = veb_layout_order(&top, node_map);
315 for &br in &bottom_roots {
316 if let Some(subtree) = bottom_subtrees.get(&br) {
317 result.extend(veb_layout_order(subtree, node_map));
318 }
319 }
320 result
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a complete binary tree of the given height. The root gets id 0; every other id
    /// is handed out when its parent creates its two children.
    fn make_binary_tree(depth: u16) -> Vec<TreeNode<String>> {
        fn build(
            id: u32,
            depth: u16,
            remaining: u16,
            next_id: &mut u32,
            nodes: &mut Vec<TreeNode<String>>,
        ) {
            let label = format!("node_{id}_d{depth}");
            if remaining == 0 {
                nodes.push(TreeNode::new(id, label, vec![]));
                return;
            }
            let left = *next_id;
            let right = left + 1;
            *next_id += 2;
            nodes.push(TreeNode::new(id, label, vec![left, right]));
            build(left, depth + 1, remaining - 1, next_id, nodes);
            build(right, depth + 1, remaining - 1, next_id, nodes);
        }

        let mut nodes = Vec::new();
        let mut next_id = 1u32;
        build(0, 0, depth, &mut next_id, &mut nodes);
        nodes
    }

    #[test]
    fn empty_tree() {
        let tree: VebTree<&str> = VebTree::build(vec![]);
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert!(tree.root().is_none());
    }

    #[test]
    fn single_node() {
        let tree = VebTree::build(vec![TreeNode::new(42, "solo", vec![])]);
        assert_eq!(tree.len(), 1);
        let root = tree.root().unwrap();
        assert_eq!(root.id, 42);
        assert_eq!(root.data, "solo");
        assert!(root.child_indices.is_empty());
        assert_eq!(root.parent_index, u32::MAX);
    }

    #[test]
    fn three_node_tree() {
        let nodes = vec![
            TreeNode::new(0, "root", vec![1, 2]),
            TreeNode::new(1, "left", vec![]),
            TreeNode::new(2, "right", vec![]),
        ];
        let tree = VebTree::build(nodes);
        assert_eq!(tree.len(), 3);
        assert_eq!(tree.get(0).unwrap().data, "root");
        assert_eq!(tree.get(1).unwrap().data, "left");
        assert_eq!(tree.get(2).unwrap().data, "right");
    }

    #[test]
    fn lookup_by_id() {
        let nodes = vec![
            TreeNode::new(10, "a", vec![20, 30]),
            TreeNode::new(20, "b", vec![]),
            TreeNode::new(30, "c", vec![]),
        ];
        let tree = VebTree::build(nodes);
        assert_eq!(tree.get(10).unwrap().data, "a");
        assert_eq!(tree.get(20).unwrap().data, "b");
        assert_eq!(tree.get(30).unwrap().data, "c");
        assert!(tree.get(99).is_none());
    }

    #[test]
    fn parent_indices_correct() {
        let nodes = vec![
            TreeNode::new(0, "r", vec![1, 2]),
            TreeNode::new(1, "l", vec![3]),
            TreeNode::new(2, "r2", vec![]),
            TreeNode::new(3, "ll", vec![]),
        ];
        let tree = VebTree::build(nodes);
        let root = tree.get(0).unwrap();
        assert_eq!(root.parent_index, u32::MAX);

        let left = tree.get(1).unwrap();
        let root_pos = tree.index[&0];
        assert_eq!(left.parent_index, root_pos);

        let ll = tree.get(3).unwrap();
        let left_pos = tree.index[&1];
        assert_eq!(ll.parent_index, left_pos);
    }

    #[test]
    fn child_indices_correct() {
        let nodes = vec![
            TreeNode::new(0, "r", vec![1, 2]),
            TreeNode::new(1, "l", vec![]),
            TreeNode::new(2, "r2", vec![]),
        ];
        let tree = VebTree::build(nodes);
        let root = tree.get(0).unwrap();
        assert_eq!(root.child_indices.len(), 2);

        for &ci in &root.child_indices {
            let child = tree.get_by_index(ci).unwrap();
            assert!(child.id == 1 || child.id == 2);
        }
    }

    #[test]
    fn dfs_iteration_preserves_all_nodes() {
        let nodes = make_binary_tree(3);
        let count = nodes.len();
        let tree = VebTree::build(nodes);
        let dfs = tree.iter_dfs();
        assert_eq!(dfs.len(), count);

        let mut ids: Vec<u32> = dfs.iter().map(|e| e.id).collect();
        ids.sort();
        let expected: Vec<u32> = (0..count as u32).collect();
        assert_eq!(ids, expected);
    }

    #[test]
    fn dfs_root_first() {
        let tree = VebTree::build(make_binary_tree(3));
        let dfs = tree.iter_dfs();
        assert_eq!(dfs[0].id, 0);
    }

    #[test]
    fn veb_order_contains_all_nodes() {
        let nodes = make_binary_tree(4);
        let count = nodes.len();
        let tree = VebTree::build(nodes);
        assert_eq!(tree.len(), count);

        let veb_ids: Vec<u32> = tree.iter().map(|e| e.id).collect();
        assert_eq!(veb_ids.len(), count);

        let mut sorted = veb_ids.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), count);
    }

    #[test]
    fn veb_order_of_binary_tree() {
        // make_binary_tree(3) produces:
        //   0 -> [1, 2], 1 -> [3, 4], 3 -> [5, 6], 4 -> [7, 8],
        //   2 -> [9, 10], 9 -> [11, 12], 10 -> [13, 14].
        // Cutting at half the height keeps {0, 1, 2} on top, followed by the four
        // height-1 bottom subtrees in the order they hang off the top.
        let tree = VebTree::build(make_binary_tree(3));
        let ids: Vec<u32> = tree.iter().map(|e| e.id).collect();
        assert_eq!(ids, vec![0, 1, 2, 3, 5, 6, 4, 7, 8, 9, 11, 12, 10, 13, 14]);
    }

    #[test]
    fn parent_and_child_links_agree() {
        let tree = VebTree::build(make_binary_tree(4));
        assert_eq!(tree.root().unwrap().id, 0);
        assert_eq!(tree.root().unwrap().parent_index, u32::MAX);
        for (pos, entry) in tree.iter().enumerate() {
            for &ci in &entry.child_indices {
                let child = tree.get_by_index(ci).unwrap();
                assert_eq!(child.parent_index, pos as u32);
                assert_eq!(child.depth, entry.depth + 1);
            }
        }
    }

    #[test]
    fn depth_values_correct() {
        let nodes = vec![
            TreeNode::new(0, "d0", vec![1, 2]),
            TreeNode::new(1, "d1a", vec![3]),
            TreeNode::new(2, "d1b", vec![]),
            TreeNode::new(3, "d2", vec![]),
        ];
        let tree = VebTree::build(nodes);
        assert_eq!(tree.get(0).unwrap().depth, 0);
        assert_eq!(tree.get(1).unwrap().depth, 1);
        assert_eq!(tree.get(2).unwrap().depth, 1);
        assert_eq!(tree.get(3).unwrap().depth, 2);
    }

    #[test]
    fn large_tree_1000_nodes() {
        let nodes: Vec<TreeNode<u32>> = (0..1000)
            .map(|i| {
                let children = if i < 999 { vec![i + 1] } else { vec![] };
                TreeNode::new(i, i, children)
            })
            .collect();
        let tree = VebTree::build(nodes);
        assert_eq!(tree.len(), 1000);
        assert_eq!(tree.get(0).unwrap().depth, 0);
        assert_eq!(tree.get(999).unwrap().depth, 999);
    }

    #[test]
    fn veb_and_dfs_visit_same_nodes() {
        let tree = VebTree::build(make_binary_tree(4));

        let veb_ids: std::collections::HashSet<u32> = tree.iter().map(|e| e.id).collect();
        let dfs_ids: std::collections::HashSet<u32> =
            tree.iter_dfs().iter().map(|e| e.id).collect();
        assert_eq!(veb_ids, dfs_ids);
    }

    #[test]
    fn wide_tree() {
        let mut nodes = vec![TreeNode::new(0, 0u32, (1..=100).collect())];
        for i in 1..=100 {
            nodes.push(TreeNode::new(i, i, vec![]));
        }
        let tree = VebTree::build(nodes);
        assert_eq!(tree.len(), 101);
        assert_eq!(tree.get(0).unwrap().child_indices.len(), 100);
    }

    #[test]
    fn rebuild_produces_same_result() {
        let nodes = make_binary_tree(3);
        let tree1 = VebTree::build(nodes.clone());
        let tree2 = VebTree::build(nodes);

        let ids1: Vec<u32> = tree1.iter().map(|e| e.id).collect();
        let ids2: Vec<u32> = tree2.iter().map(|e| e.id).collect();
        assert_eq!(ids1, ids2);
    }
}