1use std::ops::AddAssign;
10use std::ops::DivAssign;
11
12use super::KeyedVec;
13use super::StorageKey;
14use crate::containers::HashSet;
15use crate::pumpkin_assert_moderate;
16
/// A max-heap of keys ordered by an associated value, where keys can be temporarily removed
/// and later restored without losing their value.
///
/// All keys ever added are stored in position order. Positions `0..end_position` form the
/// binary heap (the largest value is at position 0). Positions at or beyond `end_position` hold
/// keys that have been removed but whose values are still kept.
#[derive(Debug, Clone)]
pub struct KeyValueHeap<Key, Value> {
    /// The value of each element, indexed by heap position (not by key).
    values: Vec<Value>,
    /// For every key, its current position in `values` and `map_position_to_key`.
    map_key_to_position: KeyedVec<Key, usize>,
    /// For every position, the key stored there; the inverse of `map_key_to_position`.
    map_position_to_key: Vec<Key>,
    /// Number of non-removed keys; positions `0..end_position` make up the heap proper.
    end_position: usize,
}
34
35impl<Key: StorageKey, Value> Default for KeyValueHeap<Key, Value> {
36 fn default() -> Self {
37 Self {
38 values: Default::default(),
39 map_key_to_position: Default::default(),
40 map_position_to_key: Default::default(),
41 end_position: Default::default(),
42 }
43 }
44}
45
46impl<Key, Value> KeyValueHeap<Key, Value> {
47 pub(crate) const fn new() -> Self {
48 Self {
49 values: Vec::new(),
50 map_key_to_position: KeyedVec::new(),
51 map_position_to_key: Vec::new(),
52 end_position: 0,
53 }
54 }
55}
56
57impl<Key, Value> KeyValueHeap<Key, Value>
58where
59 Key: StorageKey + Copy,
60 Value: AddAssign<Value> + DivAssign<Value> + PartialOrd + Default + Copy,
61{
62 pub(crate) fn keys(&self) -> impl Iterator<Item = Key> + '_ {
66 self.map_position_to_key[..self.end_position]
67 .iter()
68 .copied()
69 }
70
71 pub(crate) fn peek_max(&self) -> Option<(&Key, &Value)> {
76 if self.has_no_nonremoved_elements() {
77 None
78 } else {
79 Some((
80 &self.map_position_to_key[0],
81 &self.values[self.map_key_to_position[&self.map_position_to_key[0]]],
82 ))
83 }
84 }
85
86 pub(crate) fn get_value(&self, key: Key) -> &Value {
87 pumpkin_assert_moderate!(
88 key.index() < self.map_key_to_position.len(),
89 "Attempted to get key with index {} for a map with length {}",
90 key.index(),
91 self.map_key_to_position.len()
92 );
93 &self.values[self.map_key_to_position[key]]
94 }
95
96 pub(crate) fn pop_max(&mut self) -> Option<Key> {
101 if !self.has_no_nonremoved_elements() {
102 let best_key = self.map_position_to_key[0];
103 pumpkin_assert_moderate!(0 == self.map_key_to_position[best_key]);
104 self.delete_key(best_key);
106 Some(best_key)
107 } else {
108 None
109 }
110 }
111
112 pub(crate) fn increment(&mut self, key: Key, increment: Value) {
117 let position = self.map_key_to_position[key];
118 self.values[position] += increment;
119 if self.is_key_present(key) {
122 self.sift_up(position);
123 }
124 }
125
126 pub(crate) fn restore_key(&mut self, key: Key) {
131 if !self.is_key_present(key) {
132 let position = self.map_key_to_position[key];
135 pumpkin_assert_moderate!(position >= self.end_position);
136 self.swap_positions(position, self.end_position);
137 self.end_position += 1;
138 self.sift_up(self.end_position - 1);
139 }
140 }
141
142 pub(crate) fn delete_key(&mut self, key: Key) {
149 if self.is_key_present(key) {
150 let position = self.map_key_to_position[key];
153 self.swap_positions(position, self.end_position - 1);
154 self.end_position -= 1;
155 if position < self.end_position {
156 self.sift_down(position);
157 }
158 }
159 }
160
161 pub(crate) fn len(&self) -> usize {
163 self.values.len()
164 }
165
166 pub(crate) fn num_nonremoved_elements(&self) -> usize {
167 self.end_position
168 }
169
170 pub(crate) fn has_no_nonremoved_elements(&self) -> bool {
172 self.num_nonremoved_elements() == 0
173 }
174
175 pub(crate) fn is_key_present(&self, key: Key) -> bool {
177 key.index() < self.map_key_to_position.len()
178 && self.map_key_to_position[key] < self.end_position
179 }
180
181 pub(crate) fn grow(&mut self, key: Key, value: Value) {
184 let last_index = self.values.len();
185 self.values.push(value);
186 let _ = self.map_key_to_position.push(last_index);
189 self.map_position_to_key.push(key);
190 pumpkin_assert_moderate!(
191 self.map_position_to_key[last_index].index() == key.index()
192 && self.map_key_to_position[key] == last_index
193 );
194 self.swap_positions(self.end_position, last_index);
195 self.end_position += 1;
196 self.sift_up(self.end_position - 1);
197 }
198
199 pub(crate) fn clear(&mut self) {
200 self.values.clear();
201 self.map_key_to_position.clear();
202 self.map_position_to_key.clear();
203 self.end_position = 0;
204 }
205
206 pub(crate) fn divide_values(&mut self, divisor: Value) {
211 for value in self.values.iter_mut() {
212 *value /= divisor;
213 }
214 }
215
216 fn swap_positions(&mut self, a: usize, b: usize) {
217 let key_i = self.map_position_to_key[a];
218 pumpkin_assert_moderate!(self.map_key_to_position[key_i] == a);
219 let key_j = self.map_position_to_key[b];
220 pumpkin_assert_moderate!(self.map_key_to_position[key_j] == b);
221
222 self.values.swap(a, b);
223 self.map_position_to_key.swap(a, b);
224 self.map_key_to_position.swap(key_i.index(), key_j.index());
225
226 pumpkin_assert_moderate!(
227 self.map_key_to_position[key_i] == b && self.map_key_to_position[key_j] == a
228 );
229
230 pumpkin_assert_moderate!(
231 self.map_key_to_position
232 .iter()
233 .collect::<HashSet<&usize>>()
234 .len()
235 == self.map_key_to_position.len()
236 )
237 }
238
239 fn sift_up(&mut self, position: usize) {
240 if position > 0 {
242 let parent_position = KeyValueHeap::<Key, Value>::get_parent_position(position);
243 if self.values[parent_position] < self.values[position] {
245 self.swap_positions(parent_position, position);
246 self.sift_up(parent_position);
247 }
248 }
249 }
250
251 fn sift_down(&mut self, position: usize) {
252 pumpkin_assert_moderate!(position < self.end_position);
253
254 if !self.is_heap_locally(position) {
255 let largest_child_position = self.get_largest_child_position(position);
256 self.swap_positions(largest_child_position, position);
257 self.sift_down(largest_child_position);
258 }
259 }
260
261 fn is_heap_locally(&self, position: usize) -> bool {
262 let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
265 let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
266
267 if self.is_leaf(position) {
268 return true;
269 }
270
271 if right_child_position >= self.end_position {
273 return self.values[position] >= self.values[left_child_position];
274 }
275
276 self.values[position] >= self.values[left_child_position]
278 && self.values[position] >= self.values[right_child_position]
279 }
280
281 fn is_leaf(&self, position: usize) -> bool {
282 KeyValueHeap::<Key, Value>::get_left_child_position(position) >= self.end_position
283 }
284
285 fn get_largest_child_position(&self, position: usize) -> usize {
286 pumpkin_assert_moderate!(!self.is_leaf(position));
287
288 let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
289 let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
290
291 if right_child_position < self.end_position
292 && self.values[right_child_position] > self.values[left_child_position]
293 {
294 right_child_position
295 } else {
296 left_child_position
297 }
298 }
299
300 fn get_parent_position(child_position: usize) -> usize {
301 pumpkin_assert_moderate!(child_position > 0, "Root has no parent.");
302 (child_position - 1) / 2
303 }
304
305 fn get_left_child_position(position: usize) -> usize {
306 2 * position + 1
307 }
308
309 fn get_right_child_position(position: usize) -> usize {
310 2 * position + 2
311 }
312}
313
#[cfg(test)]
mod test {
    use super::KeyValueHeap;

    #[test]
    fn failing_test_case() {
        let mut heap = KeyValueHeap::<usize, u32>::default();

        heap.grow(0, 7);
        heap.grow(1, 5);
        assert_eq!(heap.pop_max().unwrap(), 0);

        heap.grow(2, 7);
        heap.grow(3, 6);
        assert_eq!(heap.pop_max().unwrap(), 2);
        assert_eq!(heap.pop_max().unwrap(), 3);
    }

    #[test]
    fn failing_test_case2() {
        let mut heap = KeyValueHeap::<usize, u32>::default();
        for (key, value) in [(0, 5), (1, 7), (2, 6)] {
            heap.grow(key, value);
        }

        assert_eq!(heap.pop_max().unwrap(), 1);
        assert_eq!(heap.pop_max().unwrap(), 2);
    }

    /// Inserts `numbers` into a heap (key `i` gets value `numbers[i]`), drains it with
    /// `pop_max`, and checks that the values come out in non-increasing order.
    fn heap_sort_test_helper(numbers: Vec<usize>) {
        let mut expected = numbers.clone();
        expected.sort_by(|a, b| b.cmp(a));

        let mut heap = KeyValueHeap::<usize, usize>::default();
        for (key, &value) in numbers.iter().enumerate() {
            heap.grow(key, value);
        }

        let drained: Vec<usize> = std::iter::from_fn(|| heap.pop_max())
            .map(|key| numbers[key])
            .collect();

        assert_eq!(drained, expected);
    }

    #[test]
    fn trivial() {
        let mut heap = KeyValueHeap::<usize, usize>::default();
        heap.grow(0, 5);
        assert_eq!(heap.pop_max(), Some(0));
        assert!(heap.has_no_nonremoved_elements());
        assert_eq!(heap.pop_max(), None);
    }

    #[test]
    fn trivial_sort() {
        heap_sort_test_helper(vec![5]);
    }

    #[test]
    fn simple() {
        heap_sort_test_helper(vec![5, 10]);
    }

    #[test]
    fn random1() {
        heap_sort_test_helper(vec![5, 10, 3]);
    }

    #[test]
    fn random2() {
        heap_sort_test_helper(vec![3, 10, 5]);
    }

    #[test]
    fn random3() {
        heap_sort_test_helper(vec![1, 2, 3, 4]);
    }

    #[test]
    fn duplicates() {
        heap_sort_test_helper(vec![2, 2, 1, 1, 3, 3, 3]);
    }
}