1use std::ops::AddAssign;
10use std::ops::DivAssign;
11
12use super::KeyedVec;
13use super::StorageKey;
14use crate::containers::HashSet;
15use crate::pumpkin_assert_moderate;
16
/// An indexed max-heap that associates a value with each key and keeps the key with the largest
/// value at the root (position 0).
///
/// For every position `p` the arrays satisfy
/// `map_key_to_position[map_position_to_key[p]] == p`, and `values[p]` is the value of the key at
/// position `p`. Positions `0..end_position` form the heap; positions `end_position..` hold keys
/// that have been removed from the heap but whose values are retained so they can be restored.
#[derive(Debug, Clone)]
pub struct KeyValueHeap<Key, Value> {
    /// The value stored at each position; `values[p]` belongs to the key `map_position_to_key[p]`.
    values: Vec<Value>,
    /// For each key, its current position in `values` / `map_position_to_key`.
    map_key_to_position: KeyedVec<Key, usize>,
    /// For each position, the key currently stored there.
    map_position_to_key: Vec<Key>,
    /// One past the last heap position; every position from here on holds a removed key.
    end_position: usize,
}
34
35impl<Key: StorageKey, Value> Default for KeyValueHeap<Key, Value> {
36 fn default() -> Self {
37 Self {
38 values: Default::default(),
39 map_key_to_position: Default::default(),
40 map_position_to_key: Default::default(),
41 end_position: Default::default(),
42 }
43 }
44}
45
46impl<Key, Value> KeyValueHeap<Key, Value> {
47 pub(crate) const fn new() -> Self {
48 Self {
49 values: Vec::new(),
50 map_key_to_position: KeyedVec::new(),
51 map_position_to_key: Vec::new(),
52 end_position: 0,
53 }
54 }
55}
56
57impl<Key, Value> KeyValueHeap<Key, Value>
58where
59 Key: StorageKey + Copy,
60 Value: AddAssign<Value> + DivAssign<Value> + PartialOrd + Default + Copy,
61{
62 pub(crate) fn keys(&self) -> impl Iterator<Item = Key> + '_ {
66 self.map_position_to_key[..self.end_position]
67 .iter()
68 .copied()
69 }
70
71 pub(crate) fn peek_max(&self) -> Option<(&Key, &Value)> {
76 if self.has_no_nonremoved_elements() {
77 None
78 } else {
79 Some((
80 &self.map_position_to_key[0],
81 &self.values[self.map_key_to_position[&self.map_position_to_key[0]]],
82 ))
83 }
84 }
85
86 pub fn get_value(&self, key: Key) -> &Value {
87 pumpkin_assert_moderate!(
88 key.index() < self.map_key_to_position.len(),
89 "Attempted to get key with index {} for a map with length {}",
90 key.index(),
91 self.map_key_to_position.len()
92 );
93 &self.values[self.map_key_to_position[key]]
94 }
95
96 pub fn pop_max(&mut self) -> Option<Key> {
101 if !self.has_no_nonremoved_elements() {
102 let best_key = self.map_position_to_key[0];
103 pumpkin_assert_moderate!(0 == self.map_key_to_position[best_key]);
104 self.delete_key(best_key);
106 Some(best_key)
107 } else {
108 None
109 }
110 }
111
112 pub fn increment(&mut self, key: Key, increment: Value) {
117 let position = self.map_key_to_position[key];
118 self.values[position] += increment;
119 if self.is_key_present(key) {
122 self.sift_up(position);
123 }
124 }
125
126 pub fn restore_key(&mut self, key: Key) {
131 if !self.is_key_present(key) {
132 let position = self.map_key_to_position[key];
135 pumpkin_assert_moderate!(position >= self.end_position);
136 self.swap_positions(position, self.end_position);
137 self.end_position += 1;
138 self.sift_up(self.end_position - 1);
139 }
140 }
141
142 pub fn delete_key(&mut self, key: Key) {
149 if self.is_key_present(key) {
150 let position = self.map_key_to_position[key];
153 self.swap_positions(position, self.end_position - 1);
154 self.end_position -= 1;
155 if position < self.end_position {
156 self.sift_down(position);
157 }
158 }
159 }
160
161 pub fn len(&self) -> usize {
163 self.values.len()
164 }
165
166 pub fn is_empty(&self) -> bool {
169 self.len() == 0
170 }
171
172 pub fn num_nonremoved_elements(&self) -> usize {
173 self.end_position
174 }
175
176 pub(crate) fn has_no_nonremoved_elements(&self) -> bool {
178 self.num_nonremoved_elements() == 0
179 }
180
181 pub fn is_key_present(&self, key: Key) -> bool {
183 key.index() < self.map_key_to_position.len()
184 && self.map_key_to_position[key] < self.end_position
185 }
186
187 pub fn grow(&mut self, key: Key, value: Value) {
190 let last_index = self.values.len();
191 self.values.push(value);
192 let _ = self.map_key_to_position.push(last_index);
195 self.map_position_to_key.push(key);
196 pumpkin_assert_moderate!(
197 self.map_position_to_key[last_index].index() == key.index()
198 && self.map_key_to_position[key] == last_index
199 );
200 self.swap_positions(self.end_position, last_index);
201 self.end_position += 1;
202 self.sift_up(self.end_position - 1);
203 }
204
205 pub fn clear(&mut self) {
206 self.values.clear();
207 self.map_key_to_position.clear();
208 self.map_position_to_key.clear();
209 self.end_position = 0;
210 }
211
212 pub fn divide_values(&mut self, divisor: Value) {
217 for value in self.values.iter_mut() {
218 *value /= divisor;
219 }
220 }
221
222 fn swap_positions(&mut self, a: usize, b: usize) {
223 let key_i = self.map_position_to_key[a];
224 pumpkin_assert_moderate!(self.map_key_to_position[key_i] == a);
225 let key_j = self.map_position_to_key[b];
226 pumpkin_assert_moderate!(self.map_key_to_position[key_j] == b);
227
228 self.values.swap(a, b);
229 self.map_position_to_key.swap(a, b);
230 self.map_key_to_position.swap(key_i.index(), key_j.index());
231
232 pumpkin_assert_moderate!(
233 self.map_key_to_position[key_i] == b && self.map_key_to_position[key_j] == a
234 );
235
236 pumpkin_assert_moderate!(
237 self.map_key_to_position
238 .iter()
239 .collect::<HashSet<&usize>>()
240 .len()
241 == self.map_key_to_position.len()
242 )
243 }
244
245 fn sift_up(&mut self, position: usize) {
246 if position > 0 {
248 let parent_position = KeyValueHeap::<Key, Value>::get_parent_position(position);
249 if self.values[parent_position] < self.values[position] {
251 self.swap_positions(parent_position, position);
252 self.sift_up(parent_position);
253 }
254 }
255 }
256
257 fn sift_down(&mut self, position: usize) {
258 pumpkin_assert_moderate!(position < self.end_position);
259
260 if !self.is_heap_locally(position) {
261 let largest_child_position = self.get_largest_child_position(position);
262 self.swap_positions(largest_child_position, position);
263 self.sift_down(largest_child_position);
264 }
265 }
266
267 fn is_heap_locally(&self, position: usize) -> bool {
268 let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
271 let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
272
273 if self.is_leaf(position) {
274 return true;
275 }
276
277 if right_child_position >= self.end_position {
279 return self.values[position] >= self.values[left_child_position];
280 }
281
282 self.values[position] >= self.values[left_child_position]
284 && self.values[position] >= self.values[right_child_position]
285 }
286
287 fn is_leaf(&self, position: usize) -> bool {
288 KeyValueHeap::<Key, Value>::get_left_child_position(position) >= self.end_position
289 }
290
291 fn get_largest_child_position(&self, position: usize) -> usize {
292 pumpkin_assert_moderate!(!self.is_leaf(position));
293
294 let left_child_position = KeyValueHeap::<Key, Value>::get_left_child_position(position);
295 let right_child_position = KeyValueHeap::<Key, Value>::get_right_child_position(position);
296
297 if right_child_position < self.end_position
298 && self.values[right_child_position] > self.values[left_child_position]
299 {
300 right_child_position
301 } else {
302 left_child_position
303 }
304 }
305
306 fn get_parent_position(child_position: usize) -> usize {
307 pumpkin_assert_moderate!(child_position > 0, "Root has no parent.");
308 (child_position - 1) / 2
309 }
310
311 fn get_left_child_position(position: usize) -> usize {
312 2 * position + 1
313 }
314
315 fn get_right_child_position(position: usize) -> usize {
316 2 * position + 2
317 }
318}
319
320#[cfg(test)]
321mod test {
322 use super::KeyValueHeap;
323
324 #[test]
325 fn failing_test_case() {
326 let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();
327
328 heap.grow(0, 7);
329 heap.grow(1, 5);
330
331 assert_eq!(heap.pop_max().unwrap(), 0);
332
333 heap.grow(2, 7);
334 heap.grow(3, 6);
335
336 assert_eq!(heap.pop_max().unwrap(), 2);
337 assert_eq!(heap.pop_max().unwrap(), 3);
338 }
339
340 #[test]
341 fn failing_test_case2() {
342 let mut heap: KeyValueHeap<usize, u32> = KeyValueHeap::default();
343
344 heap.grow(0, 5);
345 heap.grow(1, 7);
346 heap.grow(2, 6);
347
348 assert_eq!(heap.pop_max().unwrap(), 1);
349 assert_eq!(heap.pop_max().unwrap(), 2);
350 }
351
352 fn heap_sort_test_helper(numbers: Vec<usize>) {
354 let mut sorted_numbers = numbers.clone();
355 sorted_numbers.sort();
356 sorted_numbers.reverse();
357
358 let mut heap: KeyValueHeap<usize, usize> = KeyValueHeap::default();
359 for n in numbers.iter().enumerate() {
360 heap.grow(n.0, *n.1);
361 }
362
363 let mut heap_sorted_vector: Vec<usize> = vec![];
364 while let Some(index) = heap.pop_max() {
365 heap_sorted_vector.push(numbers[index]);
366 }
367
368 assert_eq!(heap_sorted_vector, sorted_numbers);
369 }
370
371 #[test]
372 fn trivial() {
373 let mut heap: KeyValueHeap<usize, usize> = KeyValueHeap::default();
374 heap.grow(0, 5);
375 assert_eq!(heap.pop_max(), Some(0));
376 assert!(heap.has_no_nonremoved_elements());
377 assert_eq!(heap.pop_max(), None);
378 }
379
380 #[test]
381 fn trivial_sort() {
382 heap_sort_test_helper(vec![5]);
383 }
384
385 #[test]
386 fn simple() {
387 heap_sort_test_helper(vec![5, 10]);
388 }
389
390 #[test]
391 fn random1() {
392 heap_sort_test_helper(vec![5, 10, 3]);
393 }
394
395 #[test]
396 fn random2() {
397 heap_sort_test_helper(vec![3, 10, 5]);
398 }
399
400 #[test]
401 fn random3() {
402 heap_sort_test_helper(vec![1, 2, 3, 4]);
403 }
404
405 #[test]
406 fn duplicates() {
407 heap_sort_test_helper(vec![2, 2, 1, 1, 3, 3, 3]);
408 }
409}