certified_vars/rbtree/
iterator.rs

1use super::{Node, RbTree};
2use crate::label::Label;
3use crate::AsHashTree;
4use std::marker::PhantomData;
5
6/// An iterator over key-values in a RbTree.
7pub struct RbTreeIterator<'tree, K: 'static + Label, V: AsHashTree + 'static> {
8    visit: *mut Node<K, V>,
9    stack: Vec<*mut Node<K, V>>,
10    remaining_elements: usize,
11    lifetime: PhantomData<&'tree RbTree<K, V>>,
12}
13
14impl<'tree, K: 'static + Label, V: AsHashTree + 'static> RbTreeIterator<'tree, K, V> {
15    pub fn new(tree: &'tree RbTree<K, V>) -> Self {
16        Self {
17            visit: tree.root,
18            stack: Vec::with_capacity(8),
19            remaining_elements: tree.len(),
20            lifetime: PhantomData::default(),
21        }
22    }
23}
24
25impl<'tree, K: 'static + Label, V: AsHashTree + 'static> Iterator for RbTreeIterator<'tree, K, V> {
26    type Item = (&'tree K, &'tree V);
27
28    #[inline]
29    fn next(&mut self) -> Option<Self::Item> {
30        unsafe {
31            while !self.visit.is_null() {
32                self.stack.push(self.visit);
33                self.visit = (*self.visit).left;
34            }
35
36            if let Some(node) = self.stack.pop() {
37                self.visit = (*node).right;
38                self.remaining_elements -= 1;
39                return Some((&(*node).key, &(*node).value));
40            }
41
42            None
43        }
44    }
45
46    #[inline]
47    fn size_hint(&self) -> (usize, Option<usize>) {
48        (self.remaining_elements, Some(self.remaining_elements))
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserting the keys `[0]..[249]` with matching values must make the
    /// iterator yield exactly the values `0..250`, in ascending order.
    #[test]
    fn should_visit_all() {
        let mut tree = RbTree::<[u8; 1], u8>::new();
        for byte in 0..250u8 {
            tree.insert([byte], byte);
        }

        // Gather everything the iterator produces and compare it in one go:
        // this checks both the order of the values and how many were visited.
        let seen: Vec<u8> = RbTreeIterator::new(&tree)
            .map(|(_, value)| *value)
            .collect();
        let wanted: Vec<u8> = (0..250u8).collect();

        assert_eq!(seen, wanted);
    }
}
75}