pumpkin_core/containers/
key_value_heap.rs

//! A heap where the keys range from [0, ..., n - 1] and the values are nonnegative numbers
//! (typically floating-point). The heap can be queried to return the key with the maximum value,
//! and certain keys can be (temporarily) removed/re-added as necessary. It allows
//! increasing/decreasing the values of its entries.
5
6// The implementation could be more efficient in the following ways:
7//  - Currently more comparisons are done than necessary when sifting
8//  - Possibly the recursion could be unrolled
9use std::ops::AddAssign;
10use std::ops::DivAssign;
11
12use super::KeyedVec;
13use super::StorageKey;
14use crate::containers::HashSet;
15use crate::pumpkin_assert_moderate;
16
17/// A [max-heap](https://en.wikipedia.org/wiki/Min-max_heap)
18/// which allows for generalised `Key`s (required to implement [StorageKey]) and `Value`s (which are
19/// required to be ordered, divisible and addable).
20#[derive(Debug, Clone)]
21pub struct KeyValueHeap<Key, Value> {
22    /// Contains the values stored as a heap; the value of key `i` is at index
23    /// [`KeyValueHeap::map_key_to_position\[i\]`][KeyValueHeap::map_key_to_position]
24    values: Vec<Value>,
25    /// `map_key_to_position[i]` is the index of the value of the key `i` in
26    /// [`KeyValueHeap::values`]
27    map_key_to_position: KeyedVec<Key, usize>,
28    /// `map_position_to_key[i]` is the key which is associated with `i` in
29    /// [`KeyValueHeap::values`]
30    map_position_to_key: Vec<Key>,
31    /// The index of the last element in [`KeyValueHeap::values`]
32    end_position: usize,
33}
34
35impl<Key: StorageKey, Value> Default for KeyValueHeap<Key, Value> {
36    fn default() -> Self {
37        Self {
38            values: Default::default(),
39            map_key_to_position: Default::default(),
40            map_position_to_key: Default::default(),
41            end_position: Default::default(),
42        }
43    }
44}
45
46impl<Key, Value> KeyValueHeap<Key, Value> {
47    pub(crate) const fn new() -> Self {
48        Self {
49            values: Vec::new(),
50            map_key_to_position: KeyedVec::new(),
51            map_position_to_key: Vec::new(),
52            end_position: 0,
53        }
54    }
55}
56
57impl<Key, Value> KeyValueHeap<Key, Value>
58where
59    Key: StorageKey + Copy,
60    Value: AddAssign<Value> + DivAssign<Value> + PartialOrd + Default + Copy,
61{
62    /// Get the keys in the heap.
63    ///
64    /// The order in which the keys are yielded is unspecified.
65    pub(crate) fn keys(&self) -> impl Iterator<Item = Key> + '_ {
66        self.map_position_to_key[..self.end_position]
67            .iter()
68            .copied()
69    }
70
71    /// Return the key with maximum value from the heap, or None if the heap is empty. Note that
72    /// this does not delete the key (see [`KeyValueHeap::pop_max`] to get and delete).
73    ///
74    /// The time-complexity of this operation is O(1)
75    pub(crate) fn peek_max(&self) -> Option<(&Key, &Value)> {
76        if self.has_no_nonremoved_elements() {
77            None
78        } else {
79            Some((
80                &self.map_position_to_key[0],
81                &self.values[self.map_key_to_position[&self.map_position_to_key[0]]],
82            ))
83        }
84    }
85
86    pub(crate) fn get_value(&self, key: Key) -> &Value {
87        pumpkin_assert_moderate!(
88            key.index() < self.map_key_to_position.len(),
89            "Attempted to get key with index {} for a map with length {}",
90            key.index(),
91            self.map_key_to_position.len()
92        );
93        &self.values[self.map_key_to_position[key]]
94    }
95
96    /// Deletes the key with maximum value from the heap and returns it, or None if the heap is
97    /// empty.
98    ///
99    ///  The time-complexity of this operation is O(logn).
100    pub(crate) fn pop_max(&mut self) -> Option<Key> {
101        if !self.has_no_nonremoved_elements() {
102            let best_key = self.map_position_to_key[0];
103            pumpkin_assert_moderate!(0 == self.map_key_to_position[best_key]);
104            // pumpkin_assert_extreme!(self.is_max_at_top());
105            self.delete_key(best_key);
106            Some(best_key)
107        } else {
108            None
109        }
110    }
111
112    /// Increments the value of the element of 'key' by 'increment'
113    ///
114    /// The worst-case time-complexity of this operation is O(logn); average case is likely to be
115    /// better
116    pub(crate) fn increment(&mut self, key: Key, increment: Value) {
117        let position = self.map_key_to_position[key];
118        self.values[position] += increment;
119        // Recall that increment may be applied to keys not present
120        // So we only apply sift up in case the key is present
121        if self.is_key_present(key) {
122            self.sift_up(position);
123        }
124    }
125
126    /// Restores the entry with key 'key' to the heap if the key is not present, otherwise does
127    /// nothing. Its value is the previous value used before 'delete_key' was called.
128    ///
129    ///  The run-time complexity of this operation is O(logn)
130    pub(crate) fn restore_key(&mut self, key: Key) {
131        if !self.is_key_present(key) {
132            // The key is somewhere in the range [end_position, max_size-1]
133            // We place the key at the end of the heap, increase end_position, and sift up
134            let position = self.map_key_to_position[key];
135            pumpkin_assert_moderate!(position >= self.end_position);
136            self.swap_positions(position, self.end_position);
137            self.end_position += 1;
138            self.sift_up(self.end_position - 1);
139        }
140    }
141
142    /// Removes the entry with key 'key' (temporarily) from the heap if the key is present,
143    /// otherwise does nothing. Its value remains recorded internally and is available upon
144    /// calling [`KeyValueHeap::restore_key`]. The value can still be subjected to
145    /// [`KeyValueHeap::divide_values`].
146    ///
147    /// The run-time complexity of this operation is O(logn)
148    pub(crate) fn delete_key(&mut self, key: Key) {
149        if self.is_key_present(key) {
150            // Place the key at the end of the heap, decrement the heap, and sift down to ensure a
151            // valid heap
152            let position = self.map_key_to_position[key];
153            self.swap_positions(position, self.end_position - 1);
154            self.end_position -= 1;
155            if position < self.end_position {
156                self.sift_down(position);
157            }
158        }
159    }
160
161    /// Returns how many elements are in the heap (including the (temporarily) "removed" values)
162    pub(crate) fn len(&self) -> usize {
163        self.values.len()
164    }
165
166    pub(crate) fn num_nonremoved_elements(&self) -> usize {
167        self.end_position
168    }
169
170    /// Returns whether there are elements left in the heap (excluding the "removed" values)
171    pub(crate) fn has_no_nonremoved_elements(&self) -> bool {
172        self.num_nonremoved_elements() == 0
173    }
174
175    /// Returns whether the key is currently not (temporarily) remove
176    pub(crate) fn is_key_present(&self, key: Key) -> bool {
177        key.index() < self.map_key_to_position.len()
178            && self.map_key_to_position[key] < self.end_position
179    }
180
181    /// Increases the size of the heap by one and adjust the data structures appropriately by adding
182    /// `Key` and `Value`
183    pub(crate) fn grow(&mut self, key: Key, value: Value) {
184        let last_index = self.values.len();
185        self.values.push(value);
186        // Initially the key is placed placed at the very end, will be placed in the correct
187        // position below to ensure a valid heap structure
188        let _ = self.map_key_to_position.push(last_index);
189        self.map_position_to_key.push(key);
190        pumpkin_assert_moderate!(
191            self.map_position_to_key[last_index].index() == key.index()
192                && self.map_key_to_position[key] == last_index
193        );
194        self.swap_positions(self.end_position, last_index);
195        self.end_position += 1;
196        self.sift_up(self.end_position - 1);
197    }
198
199    pub(crate) fn clear(&mut self) {
200        self.values.clear();
201        self.map_key_to_position.clear();
202        self.map_position_to_key.clear();
203        self.end_position = 0;
204    }
205
206    /// Divides all the values in the heap by 'divisor'. This will also affect the values of keys
207    /// that have been [`KeyValueHeap::delete_key`].
208    ///
209    /// The run-time complexity of this operation is O(n)
210    pub(crate) fn divide_values(&mut self, divisor: Value) {
211        for value in self.values.iter_mut() {
212            *value /= divisor;
213        }
214    }
215
216    fn swap_positions(&mut self, a: usize, b: usize) {
217        let key_i = self.map_position_to_key[a];
218        pumpkin_assert_moderate!(self.map_key_to_position[key_i] == a);
219        let key_j = self.map_position_to_key[b];
220        pumpkin_assert_moderate!(self.map_key_to_position[key_j] == b);
221
222        self.values.swap(a, b);
223        self.map_position_to_key.swap(a, b);
224        self.map_key_to_position.swap(key_i.index(), key_j.index());
225
226        pumpkin_assert_moderate!(
227            self.map_key_to_position[key_i] == b && self.map_key_to_position[key_j] == a
228        );
229
230        pumpkin_assert_moderate!(
231            self.map_key_to_position
232                .iter()
233                .collect::<HashSet<&usize>>()
234                .len()
235                == self.map_key_to_position.len()
236        )
237    }
238
239    fn sift_up(&mut self, position: usize) {
240        // Only sift up if not at the root
241        if position > 0 {
242            let parent_position = KeyValueHeap::<Key, Value>::get_parent_position(position);
243            // Continue sift up if the heap property is violated
244            if self.values[parent_position] < self.values[position] {
245                self.swap_positions(parent_position, position);
246                self.sift_up(parent_position);
247            }
248        }
249    }
250
251    fn sift_down(&mut self, position: usize) {
252        pumpkin_assert_moderate!(position < self.end_position);
253
254        if !self.is_heap_locally(position) {
255            let largest_child_position = self.get_largest_child_position(position);
256            self.swap_positions(largest_child_position, position);
257            self.sift_down(largest_child_position);
258        }
259    }
260
261    fn is_heap_locally(&self, position: usize) -> bool {
262        // Either the node is a leaf, or it satisfies the heap property (the value of the parent is
263        // at least as large as the values of its child)
264        let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
265        let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
266
267        if self.is_leaf(position) {
268            return true;
269        }
270
271        // if does not have right child, then just compare with left child.
272        if right_child_position >= self.end_position {
273            return self.values[position] >= self.values[left_child_position];
274        }
275
276        // Otherwise the node has two children, compare with both.
277        self.values[position] >= self.values[left_child_position]
278            && self.values[position] >= self.values[right_child_position]
279    }
280
281    fn is_leaf(&self, position: usize) -> bool {
282        KeyValueHeap::<Key, Value>::get_left_child_position(position) >= self.end_position
283    }
284
285    fn get_largest_child_position(&self, position: usize) -> usize {
286        pumpkin_assert_moderate!(!self.is_leaf(position));
287
288        let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
289        let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
290
291        if right_child_position < self.end_position
292            && self.values[right_child_position] > self.values[left_child_position]
293        {
294            right_child_position
295        } else {
296            left_child_position
297        }
298    }
299
300    fn get_parent_position(child_position: usize) -> usize {
301        pumpkin_assert_moderate!(child_position > 0, "Root has no parent.");
302        (child_position - 1) / 2
303    }
304
305    fn get_left_child_position(position: usize) -> usize {
306        2 * position + 1
307    }
308
309    fn get_right_child_position(position: usize) -> usize {
310        2 * position + 2
311    }
312}
313
#[cfg(test)]
mod test {
    use super::KeyValueHeap;

    #[test]
    fn failing_test_case() {
        let mut heap = KeyValueHeap::<usize, u32>::default();

        heap.grow(0, 7);
        heap.grow(1, 5);
        assert_eq!(heap.pop_max(), Some(0));

        // Growing after a pop must place the new keys in front of the removed key.
        heap.grow(2, 7);
        heap.grow(3, 6);
        assert_eq!(heap.pop_max(), Some(2));
        assert_eq!(heap.pop_max(), Some(3));
    }

    #[test]
    fn failing_test_case2() {
        let mut heap = KeyValueHeap::<usize, u32>::default();

        for (key, value) in [(0, 5), (1, 7), (2, 6)] {
            heap.grow(key, value);
        }

        assert_eq!(heap.pop_max(), Some(1));
        assert_eq!(heap.pop_max(), Some(2));
    }

    /// Drains the heap built from `numbers` (key `i` has value `numbers[i]`) and checks that the
    /// values come out in descending order, i.e. identical to a reverse sort.
    fn heap_sort_test_helper(numbers: Vec<usize>) {
        let mut heap: KeyValueHeap<usize, usize> = KeyValueHeap::default();
        for (key, &value) in numbers.iter().enumerate() {
            heap.grow(key, value);
        }

        let drained: Vec<usize> = std::iter::from_fn(|| heap.pop_max())
            .map(|key| numbers[key])
            .collect();

        let mut expected = numbers;
        expected.sort_by(|lhs, rhs| rhs.cmp(lhs));

        assert_eq!(drained, expected);
    }

    #[test]
    fn trivial() {
        let mut heap = KeyValueHeap::<usize, usize>::default();
        heap.grow(0, 5);

        assert_eq!(heap.pop_max(), Some(0));
        assert!(heap.has_no_nonremoved_elements());
        assert_eq!(heap.pop_max(), None);
    }

    #[test]
    fn trivial_sort() {
        heap_sort_test_helper(vec![5]);
    }

    #[test]
    fn simple() {
        heap_sort_test_helper(vec![5, 10]);
    }

    #[test]
    fn random1() {
        heap_sort_test_helper(vec![5, 10, 3]);
    }

    #[test]
    fn random2() {
        heap_sort_test_helper(vec![3, 10, 5]);
    }

    #[test]
    fn random3() {
        heap_sort_test_helper(vec![1, 2, 3, 4]);
    }

    #[test]
    fn duplicates() {
        heap_sort_test_helper(vec![2, 2, 1, 1, 3, 3, 3]);
    }
}