Algod/data_structures/
heap.rs

// Heap data structure.
// Takes a comparator function (a plain `fn` pointer, so non-capturing closures work too),
// which allows a min-heap, a max-heap, or ordering by a custom key.
3
4use std::cmp::Ord;
5use std::default::Default;
6
/// A binary heap whose ordering is decided by a caller-supplied comparator.
///
/// `comparator(a, b)` returns `true` when `a` must leave the heap before `b`:
/// `a < b` gives a min-heap, `a > b` a max-heap, and comparing a single field
/// orders the heap by that key.
///
/// Elements are removed through the [`Iterator`] implementation: each call to
/// `next` pops and returns the current top element.
pub struct Heap<T> {
    /// Number of elements currently stored; the placeholder in slot 0 is not counted.
    count: usize,
    /// Backing storage, 1-indexed: slot 0 holds an unused placeholder so that the
    /// parent of `i` is `i / 2` and its children are `2 * i` and `2 * i + 1`.
    items: Vec<T>,
    /// Returns `true` when its first argument should sit above its second.
    comparator: fn(&T, &T) -> bool,
}

impl<T> Heap<T>
where
    T: Default,
{
    /// Creates an empty heap ordered by `comparator`.
    ///
    /// `comparator(a, b)` must return `true` when `a` should be yielded before `b`.
    pub fn new(comparator: fn(&T, &T) -> bool) -> Self {
        Self {
            count: 0,
            // Slot 0 is never compared or returned; it only shifts the real
            // elements to start at index 1 so the parent/child arithmetic works.
            // A `Vec` holds a single type, so `T::default()` is used as filler.
            items: vec![T::default()],
            comparator,
        }
    }

    /// Returns the number of elements currently in the heap.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when the heap holds no elements.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Inserts `value` and restores the heap order.
    pub fn add(&mut self, value: T) {
        self.items.push(value);
        self.count += 1;

        // The new element sits in the last slot; bubble it up from there.
        let last = self.count;
        self.heapify_up(last);
    }

    /// Moves the element at `idx` towards the root until its parent is allowed
    /// to stay above it.
    fn heapify_up(&mut self, mut idx: usize) {
        // Index 1 is the root and has no parent to compare against.
        while idx > 1 {
            let pdx = self.parent_idx(idx);
            if !(self.comparator)(&self.items[idx], &self.items[pdx]) {
                // The parent may stay on top, and everything above it was
                // already ordered, so there is nothing left to fix.
                break;
            }
            self.items.swap(idx, pdx);
            idx = pdx;
        }
    }

    /// Moves the element at `idx` towards the leaves until it is allowed to
    /// stay above both of its children.
    fn heapify_down(&mut self, mut idx: usize) {
        while self.children_present(idx) {
            let cdx = self.smallest_child_idx(idx);
            if (self.comparator)(&self.items[idx], &self.items[cdx]) {
                // Already ahead of its preferred child, hence of both children.
                break;
            }
            self.items.swap(idx, cdx);
            idx = cdx;
        }
    }

    /// Index of the parent of `idx` (yields 0, the placeholder slot, for the root).
    fn parent_idx(&self, idx: usize) -> usize {
        idx / 2
    }

    /// Returns `true` when `idx` has at least a left child.
    fn children_present(&self, idx: usize) -> bool {
        self.left_child_idx(idx) <= self.count
    }

    /// Index of the left child of `idx`.
    fn left_child_idx(&self, idx: usize) -> usize {
        idx * 2
    }

    /// Index of the right child of `idx`.
    fn right_child_idx(&self, idx: usize) -> usize {
        self.left_child_idx(idx) + 1
    }

    /// Index of the child of `idx` that the comparator ranks first.
    ///
    /// Despite the name this is the *larger* child in a max-heap. The caller
    /// must ensure `idx` has at least a left child.
    fn smallest_child_idx(&self, idx: usize) -> usize {
        let ldx = self.left_child_idx(idx);
        let rdx = self.right_child_idx(idx);
        if rdx > self.count {
            // Only the left child exists.
            return ldx;
        }
        if (self.comparator)(&self.items[ldx], &self.items[rdx]) {
            ldx
        } else {
            rdx
        }
    }
}

impl<T> Heap<T>
where
    T: Default + Ord,
{
    /// Creates an empty heap that yields its smallest element first.
    pub fn new_min() -> Self {
        Self::new(|a, b| a < b)
    }

    /// Creates an empty heap that yields its largest element first.
    pub fn new_max() -> Self {
        Self::new(|a, b| a > b)
    }
}

impl<T> Iterator for Heap<T>
where
    T: Default,
{
    type Item = T;

    /// Removes and returns the top element, or `None` when the heap is empty.
    fn next(&mut self) -> Option<T> {
        if self.count == 0 {
            return None;
        }
        // `swap_remove` takes the root out and drops the last element into its
        // place in one step, which is exactly what a heap pop needs.
        let top = self.items.swap_remove(1);
        self.count -= 1;

        // Sink the element that now sits at the root (no-op without children).
        self.heapify_down(1);

        Some(top)
    }

    /// Reports the exact number of elements still stored in the heap.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.count, Some(self.count))
    }
}
132
133pub struct MinHeap;
134
135impl MinHeap {
136    #[allow(clippy::new_ret_no_self)]
137    pub fn new<T>() -> Heap<T>
138    where
139        T: Default + Ord,
140    {
141        Heap::new(|a, b| a < b)
142    }
143}
144
145pub struct MaxHeap;
146
147impl MaxHeap {
148    #[allow(clippy::new_ret_no_self)]
149    pub fn new<T>() -> Heap<T>
150    where
151        T: Default + Ord,
152    {
153        Heap::new(|a, b| a > b)
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Popping from a heap that never received an element yields `None`.
    #[test]
    fn test_empty_heap() {
        let mut empty = MaxHeap::new::<i32>();
        assert_eq!(empty.next(), None);
    }

    /// A min-heap hands elements back smallest-first, including after a late insert.
    #[test]
    fn test_min_heap() {
        let mut heap = MinHeap::new();
        for value in [4, 2, 9, 11] {
            heap.add(value);
        }
        assert_eq!(heap.len(), 4);
        for expected in [2, 4, 9] {
            assert_eq!(heap.next(), Some(expected));
        }
        // 11 is still stored, so the freshly added 1 must come out first.
        heap.add(1);
        assert_eq!(heap.next(), Some(1));
    }

    /// A max-heap hands elements back largest-first, including after a late insert.
    #[test]
    fn test_max_heap() {
        let mut heap = MaxHeap::new();
        for value in [4, 2, 9, 11] {
            heap.add(value);
        }
        assert_eq!(heap.len(), 4);
        for expected in [11, 9, 4] {
            assert_eq!(heap.next(), Some(expected));
        }
        // 2 is still stored and outranks the freshly added 1.
        heap.add(1);
        assert_eq!(heap.next(), Some(2));
    }

    /// Two-field type without `Ord`, used to exercise a custom key comparator.
    /// Field 0 is `x`, field 1 is `y`.
    #[derive(Default)]
    struct Point(i32, i32);

    /// A heap built with an explicit comparator orders elements by that key (`x` here).
    #[test]
    fn test_key_heap() {
        let mut heap = Heap::<Point>::new(|lhs, rhs| lhs.0 < rhs.0);
        for (x, y) in [(1, 5), (3, 10), (-2, 4)] {
            heap.add(Point(x, y));
        }
        assert_eq!(heap.len(), 3);
        assert_eq!(heap.next().unwrap().0, -2);
        assert_eq!(heap.next().unwrap().0, 1);
        heap.add(Point(50, 34));
        assert_eq!(heap.next().unwrap().0, 3);
    }
}