Skip to main content

pumpkin_core/containers/
key_value_heap.rs

//! A heap where the keys range from [0, ..., n - 1] and the values are nonnegative floating-point
//! numbers. The heap can be queried to return the key with the maximum value, and certain keys can
//! be (temporarily) removed/re-added as necessary. It allows increasing/decreasing the values of
//! its entries.
5
6// The implementation could be more efficient in the following ways:
7//  - Currently more comparisons are done than necessary when sifting
8//  - Possibly the recursion could be unrolled
9use std::ops::AddAssign;
10use std::ops::DivAssign;
11
12use super::KeyedVec;
13use super::StorageKey;
14use crate::containers::HashSet;
15use crate::pumpkin_assert_moderate;
16
17/// A [max-heap](https://en.wikipedia.org/wiki/Min-max_heap)
18/// which allows for generalised `Key`s (required to implement [StorageKey]) and `Value`s (which are
19/// required to be ordered, divisible and addable).
20#[derive(Debug, Clone)]
21pub struct KeyValueHeap<Key, Value> {
22    /// Contains the values stored as a heap; the value of key `i` is at index
23    /// [`KeyValueHeap::map_key_to_position\[i\]`][KeyValueHeap::map_key_to_position]
24    values: Vec<Value>,
25    /// `map_key_to_position[i]` is the index of the value of the key `i` in
26    /// [`KeyValueHeap::values`]
27    map_key_to_position: KeyedVec<Key, usize>,
28    /// `map_position_to_key[i]` is the key which is associated with `i` in
29    /// [`KeyValueHeap::values`]
30    map_position_to_key: Vec<Key>,
31    /// The index of the last element in [`KeyValueHeap::values`]
32    end_position: usize,
33}
34
35impl<Key: StorageKey, Value> Default for KeyValueHeap<Key, Value> {
36    fn default() -> Self {
37        Self {
38            values: Default::default(),
39            map_key_to_position: Default::default(),
40            map_position_to_key: Default::default(),
41            end_position: Default::default(),
42        }
43    }
44}
45
46impl<Key, Value> KeyValueHeap<Key, Value> {
47    pub(crate) const fn new() -> Self {
48        Self {
49            values: Vec::new(),
50            map_key_to_position: KeyedVec::new(),
51            map_position_to_key: Vec::new(),
52            end_position: 0,
53        }
54    }
55}
56
57impl<Key, Value> KeyValueHeap<Key, Value>
58where
59    Key: StorageKey + Copy,
60    Value: AddAssign<Value> + DivAssign<Value> + PartialOrd + Default + Copy,
61{
62    /// Get the keys in the heap.
63    ///
64    /// The order in which the keys are yielded is unspecified.
65    pub(crate) fn keys(&self) -> impl Iterator<Item = Key> + '_ {
66        self.map_position_to_key[..self.end_position]
67            .iter()
68            .copied()
69    }
70
71    /// Return the key with maximum value from the heap, or None if the heap is empty. Note that
72    /// this does not delete the key (see [`KeyValueHeap::pop_max`] to get and delete).
73    ///
74    /// The time-complexity of this operation is O(1)
75    pub(crate) fn peek_max(&self) -> Option<(&Key, &Value)> {
76        if self.has_no_nonremoved_elements() {
77            None
78        } else {
79            Some((
80                &self.map_position_to_key[0],
81                &self.values[self.map_key_to_position[&self.map_position_to_key[0]]],
82            ))
83        }
84    }
85
86    pub fn get_value(&self, key: Key) -> &Value {
87        pumpkin_assert_moderate!(
88            key.index() < self.map_key_to_position.len(),
89            "Attempted to get key with index {} for a map with length {}",
90            key.index(),
91            self.map_key_to_position.len()
92        );
93        &self.values[self.map_key_to_position[key]]
94    }
95
96    /// Deletes the key with maximum value from the heap and returns it, or None if the heap is
97    /// empty.
98    ///
99    ///  The time-complexity of this operation is O(logn).
100    pub fn pop_max(&mut self) -> Option<Key> {
101        if !self.has_no_nonremoved_elements() {
102            let best_key = self.map_position_to_key[0];
103            pumpkin_assert_moderate!(0 == self.map_key_to_position[best_key]);
104            // pumpkin_assert_extreme!(self.is_max_at_top());
105            self.delete_key(best_key);
106            Some(best_key)
107        } else {
108            None
109        }
110    }
111
112    /// Increments the value of the element of 'key' by 'increment'
113    ///
114    /// The worst-case time-complexity of this operation is O(logn); average case is likely to be
115    /// better
116    pub fn increment(&mut self, key: Key, increment: Value) {
117        let position = self.map_key_to_position[key];
118        self.values[position] += increment;
119        // Recall that increment may be applied to keys not present
120        // So we only apply sift up in case the key is present
121        if self.is_key_present(key) {
122            self.sift_up(position);
123        }
124    }
125
126    /// Restores the entry with key 'key' to the heap if the key is not present, otherwise does
127    /// nothing. Its value is the previous value used before 'delete_key' was called.
128    ///
129    ///  The run-time complexity of this operation is O(logn)
130    pub fn restore_key(&mut self, key: Key) {
131        if !self.is_key_present(key) {
132            // The key is somewhere in the range [end_position, max_size-1]
133            // We place the key at the end of the heap, increase end_position, and sift up
134            let position = self.map_key_to_position[key];
135            pumpkin_assert_moderate!(position >= self.end_position);
136            self.swap_positions(position, self.end_position);
137            self.end_position += 1;
138            self.sift_up(self.end_position - 1);
139        }
140    }
141
142    /// Removes the entry with key 'key' (temporarily) from the heap if the key is present,
143    /// otherwise does nothing. Its value remains recorded internally and is available upon
144    /// calling [`KeyValueHeap::restore_key`]. The value can still be subjected to
145    /// [`KeyValueHeap::divide_values`].
146    ///
147    /// The run-time complexity of this operation is O(logn)
148    pub fn delete_key(&mut self, key: Key) {
149        if self.is_key_present(key) {
150            // Place the key at the end of the heap, decrement the heap, and sift down to ensure a
151            // valid heap
152            let position = self.map_key_to_position[key];
153            self.swap_positions(position, self.end_position - 1);
154            self.end_position -= 1;
155            if position < self.end_position {
156                self.sift_down(position);
157            }
158        }
159    }
160
161    /// Returns how many elements are in the heap (including the (temporarily) "removed" values)
162    pub fn len(&self) -> usize {
163        self.values.len()
164    }
165
166    /// Returns whether there are no elements in the heap (including the (temporarily) "removed"
167    /// values)
168    pub fn is_empty(&self) -> bool {
169        self.len() == 0
170    }
171
172    pub fn num_nonremoved_elements(&self) -> usize {
173        self.end_position
174    }
175
176    /// Returns whether there are elements left in the heap (excluding the "removed" values)
177    pub(crate) fn has_no_nonremoved_elements(&self) -> bool {
178        self.num_nonremoved_elements() == 0
179    }
180
181    /// Returns whether the key is currently not (temporarily) remove
182    pub fn is_key_present(&self, key: Key) -> bool {
183        key.index() < self.map_key_to_position.len()
184            && self.map_key_to_position[key] < self.end_position
185    }
186
187    /// Increases the size of the heap by one and adjust the data structures appropriately by adding
188    /// `Key` and `Value`
189    pub fn grow(&mut self, key: Key, value: Value) {
190        let last_index = self.values.len();
191        self.values.push(value);
192        // Initially the key is placed placed at the very end, will be placed in the correct
193        // position below to ensure a valid heap structure
194        let _ = self.map_key_to_position.push(last_index);
195        self.map_position_to_key.push(key);
196        pumpkin_assert_moderate!(
197            self.map_position_to_key[last_index].index() == key.index()
198                && self.map_key_to_position[key] == last_index
199        );
200        self.swap_positions(self.end_position, last_index);
201        self.end_position += 1;
202        self.sift_up(self.end_position - 1);
203    }
204
205    pub fn clear(&mut self) {
206        self.values.clear();
207        self.map_key_to_position.clear();
208        self.map_position_to_key.clear();
209        self.end_position = 0;
210    }
211
212    /// Divides all the values in the heap by 'divisor'. This will also affect the values of keys
213    /// that have been [`KeyValueHeap::delete_key`].
214    ///
215    /// The run-time complexity of this operation is O(n)
216    pub fn divide_values(&mut self, divisor: Value) {
217        for value in self.values.iter_mut() {
218            *value /= divisor;
219        }
220    }
221
222    fn swap_positions(&mut self, a: usize, b: usize) {
223        let key_i = self.map_position_to_key[a];
224        pumpkin_assert_moderate!(self.map_key_to_position[key_i] == a);
225        let key_j = self.map_position_to_key[b];
226        pumpkin_assert_moderate!(self.map_key_to_position[key_j] == b);
227
228        self.values.swap(a, b);
229        self.map_position_to_key.swap(a, b);
230        self.map_key_to_position.swap(key_i.index(), key_j.index());
231
232        pumpkin_assert_moderate!(
233            self.map_key_to_position[key_i] == b && self.map_key_to_position[key_j] == a
234        );
235
236        pumpkin_assert_moderate!(
237            self.map_key_to_position
238                .iter()
239                .collect::<HashSet<&usize>>()
240                .len()
241                == self.map_key_to_position.len()
242        )
243    }
244
245    fn sift_up(&mut self, position: usize) {
246        // Only sift up if not at the root
247        if position > 0 {
248            let parent_position = KeyValueHeap::<Key, Value>::get_parent_position(position);
249            // Continue sift up if the heap property is violated
250            if self.values[parent_position] < self.values[position] {
251                self.swap_positions(parent_position, position);
252                self.sift_up(parent_position);
253            }
254        }
255    }
256
257    fn sift_down(&mut self, position: usize) {
258        pumpkin_assert_moderate!(position < self.end_position);
259
260        if !self.is_heap_locally(position) {
261            let largest_child_position = self.get_largest_child_position(position);
262            self.swap_positions(largest_child_position, position);
263            self.sift_down(largest_child_position);
264        }
265    }
266
267    fn is_heap_locally(&self, position: usize) -> bool {
268        // Either the node is a leaf, or it satisfies the heap property (the value of the parent is
269        // at least as large as the values of its child)
270        let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
271        let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
272
273        if self.is_leaf(position) {
274            return true;
275        }
276
277        // if does not have right child, then just compare with left child.
278        if right_child_position >= self.end_position {
279            return self.values[position] >= self.values[left_child_position];
280        }
281
282        // Otherwise the node has two children, compare with both.
283        self.values[position] >= self.values[left_child_position]
284            && self.values[position] >= self.values[right_child_position]
285    }
286
287    fn is_leaf(&self, position: usize) -> bool {
288        KeyValueHeap::<Key, Value>::get_left_child_position(position) >= self.end_position
289    }
290
291    fn get_largest_child_position(&self, position: usize) -> usize {
292        pumpkin_assert_moderate!(!self.is_leaf(position));
293
294        let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
295        let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
296
297        if right_child_position < self.end_position
298            && self.values[right_child_position] > self.values[left_child_position]
299        {
300            right_child_position
301        } else {
302            left_child_position
303        }
304    }
305
306    fn get_parent_position(child_position: usize) -> usize {
307        pumpkin_assert_moderate!(child_position > 0, "Root has no parent.");
308        (child_position - 1) / 2
309    }
310
311    fn get_left_child_position(position: usize) -> usize {
312        2 * position + 1
313    }
314
315    fn get_right_child_position(position: usize) -> usize {
316        2 * position + 2
317    }
318}
319
#[cfg(test)]
mod test {
    use super::KeyValueHeap;

    #[test]
    fn failing_test_case() {
        let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();

        heap.grow(0, 7);
        heap.grow(1, 5);

        assert_eq!(heap.pop_max().unwrap(), 0);

        heap.grow(2, 7);
        heap.grow(3, 6);

        assert_eq!(heap.pop_max().unwrap(), 2);
        assert_eq!(heap.pop_max().unwrap(), 3);
    }

    #[test]
    fn failing_test_case2() {
        let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();

        heap.grow(0, 5);
        heap.grow(1, 7);
        heap.grow(2, 6);

        assert_eq!(heap.pop_max().unwrap(), 1);
        assert_eq!(heap.pop_max().unwrap(), 2);
    }

    // Uses the heap to sort the input vectors, and compare with a sorted version of the vector.
    fn heap_sort_test_helper(numbers: Vec<usize>) {
        let mut sorted_numbers = numbers.clone();
        sorted_numbers.sort_unstable_by(|a, b| b.cmp(a));

        let mut heap: KeyValueHeap<usize, usize> = KeyValueHeap::default();
        for (key, &value) in numbers.iter().enumerate() {
            heap.grow(key, value);
        }

        let mut heap_sorted_vector: Vec<usize> = vec![];
        while let Some(index) = heap.pop_max() {
            heap_sorted_vector.push(numbers[index]);
        }

        assert_eq!(heap_sorted_vector, sorted_numbers);
    }

    #[test]
    fn trivial() {
        let mut heap: KeyValueHeap<usize, usize> = KeyValueHeap::default();
        heap.grow(0, 5);
        assert_eq!(heap.pop_max(), Some(0));
        assert!(heap.has_no_nonremoved_elements());
        assert_eq!(heap.pop_max(), None);
    }

    #[test]
    fn trivial_sort() {
        heap_sort_test_helper(vec![5]);
    }

    #[test]
    fn simple() {
        heap_sort_test_helper(vec![5, 10]);
    }

    #[test]
    fn random1() {
        heap_sort_test_helper(vec![5, 10, 3]);
    }

    #[test]
    fn random2() {
        heap_sort_test_helper(vec![3, 10, 5]);
    }

    #[test]
    fn random3() {
        heap_sort_test_helper(vec![1, 2, 3, 4]);
    }

    #[test]
    fn duplicates() {
        heap_sort_test_helper(vec![2, 2, 1, 1, 3, 3, 3]);
    }

    #[test]
    fn deleted_keys_are_skipped_and_can_be_restored() {
        let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();
        heap.grow(0, 5);
        heap.grow(1, 10);
        heap.grow(2, 7);

        heap.delete_key(1);
        assert!(!heap.is_key_present(1));
        assert_eq!(heap.len(), 3);
        assert_eq!(heap.num_nonremoved_elements(), 2);
        assert_eq!(heap.peek_max().map(|(key, _)| *key), Some(2));

        heap.restore_key(1);
        assert!(heap.is_key_present(1));
        assert_eq!(heap.pop_max(), Some(1));
        assert_eq!(heap.pop_max(), Some(2));
        assert_eq!(heap.pop_max(), Some(0));
        assert_eq!(heap.pop_max(), None);
    }

    #[test]
    fn increment_moves_key_to_top() {
        let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();
        heap.grow(0, 3);
        heap.grow(1, 5);
        heap.grow(2, 4);

        heap.increment(0, 10);
        assert_eq!(*heap.get_value(0), 13);

        assert_eq!(heap.pop_max(), Some(0));
        assert_eq!(heap.pop_max(), Some(1));
        assert_eq!(heap.pop_max(), Some(2));
    }

    #[test]
    fn increment_of_removed_key_is_kept_until_restored() {
        let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();
        heap.grow(0, 1);
        heap.grow(1, 2);

        heap.delete_key(0);
        heap.increment(0, 10);
        assert_eq!(heap.peek_max().map(|(key, _)| *key), Some(1));

        heap.restore_key(0);
        assert_eq!(heap.pop_max(), Some(0));
    }

    #[test]
    fn divide_values_affects_all_keys() {
        let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();
        heap.grow(0, 4);
        heap.grow(1, 8);
        heap.delete_key(0);

        heap.divide_values(2);

        assert_eq!(*heap.get_value(0), 2);
        assert_eq!(*heap.get_value(1), 4);
    }
}