1use std::{mem::ManuallyDrop, ptr};
2
3use collection::{Collection, Disposable};
4
5use crate::traits::QueueLike;
6
7pub use crate::traits::PriorityQueueLike;
8
/// A binary min-heap priority queue: [`QueueLike::front`] is always the smallest element
/// according to `T`'s [`Ord`] implementation (wrap items in [`std::cmp::Reverse`] to get
/// max-heap order).
///
/// The `Ord` bound lives on the `impl` blocks that need it rather than on the type itself,
/// so the type can be named, stored, cloned or debug-printed without extra bounds.
#[derive(Debug, Clone)]
pub struct PriorityQueue<T> {
    /// Set to `true` by `Disposable::dispose`.
    disposed: bool,
    /// Implicit binary heap: the parent of slot `i` is `(i - 1) / 2`, and every parent
    /// compares `<=` its children.
    elements: Vec<T>,
}
17
/// Guard for a "hole" in a heap buffer: one element has been bitwise-moved out of the
/// buffer into `item`, leaving slot `pos` logically vacant. Its `Drop` impl moves `item`
/// back into `ptr.add(pos)`, so the buffer is fully initialized again even on unwind.
struct RestoreOnDrop<T> {
    /// The element currently lifted out of the buffer.
    item: ManuallyDrop<T>,
    /// Base pointer of the buffer the hole lives in.
    ptr: *mut T,
    /// Index of the vacant slot; sift routines update it as the hole moves.
    pos: usize,
}
23
24impl<T> Drop for RestoreOnDrop<T> {
25 fn drop(&mut self) {
26 unsafe {
29 let dst = self.ptr.add(self.pos);
30 ptr::write(dst, ManuallyDrop::take(&mut self.item));
31 }
32 }
33}
34
35impl<T> PriorityQueue<T>
36where
37 T: Ord,
38{
39 pub fn new() -> Self {
40 Self {
41 disposed: false,
42 elements: Vec::new(),
43 }
44 }
45
46 #[inline(always)]
47 fn up(&mut self, index: usize) {
48 let len = self.elements.len();
49 if index == 0 || index >= len {
50 return;
51 }
52
53 unsafe {
54 let ptr = self.elements.as_mut_ptr();
55 let mut pos = index;
56
57 let item = ManuallyDrop::new(ptr::read(ptr.add(index)));
58 let mut restore = RestoreOnDrop { ptr, pos, item };
59 let item_ptr = (&restore.item as *const ManuallyDrop<T>).cast::<T>();
60
61 while pos > 0 {
62 let parent = (pos - 1) >> 1;
63 let parent_ref: &T = &*ptr.add(parent);
64 let item_ref: &T = &*item_ptr;
65 if parent_ref <= item_ref {
66 break;
67 }
68
69 ptr::copy_nonoverlapping(ptr.add(parent), ptr.add(pos), 1);
70 pos = parent;
71 restore.pos = pos;
72 }
73 }
74 }
75
76 #[inline(always)]
77 fn down(&mut self, index: usize) {
78 let n = self.elements.len();
79 if index >= n {
80 return;
81 }
82
83 unsafe {
84 let ptr = self.elements.as_mut_ptr();
85 let mut pos = index;
86
87 let item = ManuallyDrop::new(ptr::read(ptr.add(index)));
88 let mut restore = RestoreOnDrop { ptr, pos, item };
89 let item_ptr = (&restore.item as *const ManuallyDrop<T>).cast::<T>();
90
91 loop {
92 let left = (pos << 1) + 1;
93 if left >= n {
94 break;
95 }
96
97 let right = left + 1;
98 let mut child = left;
99 if right < n {
100 let right_ref: &T = &*ptr.add(right);
101 let left_ref: &T = &*ptr.add(left);
102 if right_ref < left_ref {
103 child = right;
104 }
105 }
106
107 let item_ref: &T = &*item_ptr;
108 let child_ref: &T = &*ptr.add(child);
109 if item_ref <= child_ref {
110 break;
111 }
112
113 ptr::copy_nonoverlapping(ptr.add(child), ptr.add(pos), 1);
114 pos = child;
115 restore.pos = pos;
116 }
117 }
118 }
119
120 #[inline(always)]
121 fn down_to_bottom_then_up(&mut self, start: usize) {
122 let n = self.elements.len();
123 if start >= n {
124 return;
125 }
126
127 unsafe {
128 let ptr = self.elements.as_mut_ptr();
129 let item = ManuallyDrop::new(ptr::read(ptr.add(start)));
130 let mut restore = RestoreOnDrop {
131 ptr,
132 pos: start,
133 item,
134 };
135 let item_ptr = (&restore.item as *const ManuallyDrop<T>).cast::<T>();
136
137 Self::sift_down_to_bottom(ptr, n, &mut restore);
138 Self::sift_up_from(start, ptr, &mut restore, item_ptr);
139 }
140 }
141
142 #[inline(always)]
143 unsafe fn sift_down_to_bottom(ptr: *mut T, n: usize, restore: &mut RestoreOnDrop<T>) {
144 let mut pos = restore.pos;
145 let mut child = (pos << 1) + 1;
146
147 while child + 1 < n {
148 let right = child + 1;
149 let right_ref: &T = unsafe { &*ptr.add(right) };
150 let child_ref: &T = unsafe { &*ptr.add(child) };
151 if right_ref < child_ref {
152 child = right;
153 }
154
155 unsafe { ptr::copy_nonoverlapping(ptr.add(child), ptr.add(pos), 1) };
157 pos = child;
158 restore.pos = pos;
159 child = (pos << 1) + 1;
160 }
161
162 if child < n {
163 unsafe { ptr::copy_nonoverlapping(ptr.add(child), ptr.add(pos), 1) };
165 pos = child;
166 restore.pos = pos;
167 }
168 }
169
170 #[inline(always)]
171 unsafe fn sift_up_from(
172 start: usize,
173 ptr: *mut T,
174 restore: &mut RestoreOnDrop<T>,
175 item_ptr: *const T,
176 ) {
177 let mut pos = restore.pos;
178
179 while pos > start {
180 let parent = (pos - 1) >> 1;
181 let parent_ref: &T = unsafe { &*ptr.add(parent) };
182 let item_ref: &T = unsafe { &*item_ptr };
183 if parent_ref <= item_ref {
184 break;
185 }
186
187 unsafe { ptr::copy_nonoverlapping(ptr.add(parent), ptr.add(pos), 1) };
189 pos = parent;
190 restore.pos = pos;
191 }
192 }
193
194 fn fast_build(&mut self) {
195 let n = self.elements.len();
196 if n <= 1 {
197 return;
198 }
199
200 let last_parent = (n >> 1) - 1;
201 for p in (0..=last_parent).rev() {
202 self.down(p);
203 }
204 }
205}
206
207impl<T> Default for PriorityQueue<T>
208where
209 T: Ord,
210{
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216impl<T> QueueLike<T> for PriorityQueue<T>
217where
218 T: Ord,
219{
220 fn front(&self) -> Option<&T> {
221 self.elements.first()
222 }
223
224 fn enqueue(&mut self, element: T) {
225 self.elements.push(element);
226 let index = self.elements.len() - 1;
227 self.up(index);
228 }
229
230 fn dequeue(&mut self) -> Option<T> {
231 self.elements.pop().map(|mut item| {
232 if !self.elements.is_empty() {
233 std::mem::swap(&mut item, &mut self.elements[0]);
234 self.down_to_bottom_then_up(0);
235 }
236 item
237 })
238 }
239
240 fn enqueues<I>(&mut self, elements: I)
241 where
242 I: IntoIterator<Item = T>,
243 {
244 let size = self.elements.len();
245 self.elements.extend(elements);
246
247 let next_size = self.elements.len();
248 if next_size == size {
249 return;
250 }
251
252 let new_added = next_size - size;
253 let next_size_f64 = next_size as f64;
254 if (new_added as f64) * next_size_f64.log2() > next_size_f64 {
255 self.fast_build();
256 } else {
257 for i in size..next_size {
258 self.up(i);
259 }
260 }
261 }
262
263 fn replace_front(&mut self, new_back: T) -> Option<T> {
264 if self.elements.is_empty() {
265 self.elements.push(new_back);
266 return None;
267 }
268
269 let removed = std::mem::replace(&mut self.elements[0], new_back);
270 self.down(0);
271 Some(removed)
272 }
273}
274
275impl<T> PriorityQueueLike<T> for PriorityQueue<T> where T: Ord {}
276
277impl<T> Collection for PriorityQueue<T>
278where
279 T: Ord,
280{
281 type Item = T;
282 type Iter<'a>
283 = std::slice::Iter<'a, T>
284 where
285 Self: 'a;
286
287 fn iter(&self) -> Self::Iter<'_> {
288 self.elements.iter()
289 }
290
291 fn size(&self) -> usize {
292 self.elements.len()
293 }
294
295 fn clear(&mut self) {
296 self.elements.clear();
297 }
298
299 fn retain<F>(&mut self, mut f: F) -> usize
300 where
301 F: FnMut(&Self::Item) -> bool,
302 {
303 let before = self.elements.len();
304 if before == 0 {
305 return 0;
306 }
307
308 self.elements.retain(|item| f(item));
309 let removed = before - self.elements.len();
310 if removed > 0 {
311 self.fast_build();
312 }
313 removed
314 }
315}
316
317impl<T> Disposable for PriorityQueue<T>
318where
319 T: Ord,
320{
321 fn dispose(&mut self) {
322 self.disposed = true;
323 self.elements.clear();
324 }
325
326 fn is_disposed(&self) -> bool {
327 self.disposed
328 }
329}
330
331impl<'a, T> IntoIterator for &'a PriorityQueue<T>
332where
333 T: Ord,
334{
335 type Item = &'a T;
336 type IntoIter = std::slice::Iter<'a, T>;
337
338 fn into_iter(self) -> Self::IntoIter {
339 self.elements.iter()
340 }
341}
342
// Unit tests for `PriorityQueue`: targeted checks of each public operation plus a
// randomized differential test against a naive Vec model and std's `BinaryHeap`.
#[cfg(test)]
mod tests {
    use std::{cmp::Reverse, collections::BinaryHeap};

    use collection::{Collection, Disposable};

    use crate::traits::{PriorityQueueLike, QueueLike};

    use super::PriorityQueue;

    /// Tiny deterministic xorshift PRNG so the randomized test is reproducible without
    /// external crates. A zero seed would stay zero forever, so seeds must be non-zero.
    #[derive(Clone)]
    struct XorShift64 {
        state: u64,
    }

    impl XorShift64 {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        /// Advances the state with the 13/7/17 xorshift triple and returns it.
        fn next_u64(&mut self) -> u64 {
            let mut x = self.state;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.state = x;
            x
        }

        /// Returns a value in `0..bound` (modulo reduction; slight bias is fine for tests).
        fn next_i32_in(&mut self, bound: i32) -> i32 {
            debug_assert!(bound > 0);
            (self.next_u64() % bound as u64) as i32
        }
    }

    /// Dequeues until empty, returning the elements in the order the queue produced them.
    fn drain_all<T>(q: &mut PriorityQueue<T>) -> Vec<T>
    where
        T: Ord,
    {
        let mut out = Vec::new();
        while let Some(x) = q.dequeue() {
            out.push(x);
        }
        out
    }

    // Basic min-heap contract: empty behaviour, batch insert, replace_front, ordered drain.
    #[test]
    fn queue_like_min_heap_ops_should_work() {
        let mut q = PriorityQueue::new();

        assert_eq!(q.front(), None);
        assert_eq!(q.dequeue(), None);

        q.enqueues([4, 2, 5, 1, 3]);
        assert_eq!(q.front(), Some(&1));
        assert_eq!(q.replace_front(6), Some(1));
        assert_eq!(drain_all(&mut q), vec![2, 3, 4, 5, 6]);
    }

    // Enqueuing 2 after 1 stops the sift-up immediately (parent <= item);
    // enqueuing 0 afterwards must bubble all the way to the root.
    #[test]
    fn enqueue_should_cover_up_break_path() {
        let mut q = PriorityQueue::new();

        q.enqueue(1);
        q.enqueue(2);
        q.enqueue(0);

        assert_eq!(q.front(), Some(&0));
        assert_eq!(drain_all(&mut q), vec![0, 1, 2]);
    }

    // replace_front on an empty queue inserts and returns None; on a single element it swaps.
    #[test]
    fn replace_front_should_handle_empty_and_single_item() {
        let mut q = PriorityQueue::new();

        assert_eq!(q.replace_front(10), None);
        assert_eq!(q.front(), Some(&10));

        assert_eq!(q.replace_front(5), Some(10));
        assert_eq!(q.front(), Some(&5));
        assert_eq!(q.dequeue(), Some(5));
    }

    // Per the cost heuristic in `enqueues`, the 2-element batch is sifted up element by
    // element while the 100-element batch triggers a full rebuild; both must yield a heap.
    #[test]
    fn enqueues_should_work_for_small_and_large_batch() {
        let mut q = PriorityQueue::new();

        q.enqueues([5, 4]);
        q.enqueues(0..100);

        assert_eq!(q.size(), 102);
        assert_eq!(q.front(), Some(&0));

        let drained = drain_all(&mut q);
        assert_eq!(drained.len(), 102);
        assert!(drained.windows(2).all(|w| w[0] <= w[1]));
    }

    // An empty batch must leave both an empty and a non-empty queue untouched.
    #[test]
    fn enqueues_empty_should_be_noop() {
        let mut q = PriorityQueue::new();

        q.enqueues(std::iter::empty());
        assert!(q.is_empty());

        q.enqueue(3);
        q.enqueues(std::iter::empty());
        assert_eq!(q.front(), Some(&3));
        assert_eq!(q.size(), 1);
    }

    // Wrapping items in `Reverse` turns the min-heap into a max-heap.
    #[test]
    fn reverse_ord_should_support_max_heap() {
        let mut q = PriorityQueue::new();

        q.enqueues([
            Reverse(1_i32),
            Reverse(5_i32),
            Reverse(2_i32),
            Reverse(4_i32),
            Reverse(3_i32),
        ]);

        assert_eq!(q.front(), Some(&Reverse(5)));
        assert_eq!(
            drain_all(&mut q),
            vec![Reverse(5), Reverse(4), Reverse(3), Reverse(2), Reverse(1)]
        );
    }

    // Removing elements via retain must leave a valid heap (drain comes out sorted),
    // including removing the current minimum from a two-element queue.
    #[test]
    fn retain_should_rebuild_heap() {
        let mut q = PriorityQueue::new();
        q.enqueues(1..=8);

        let removed = q.retain(|x| *x % 2 == 0);
        assert_eq!(removed, 4);
        assert_eq!(drain_all(&mut q), vec![2, 4, 6, 8]);

        let mut single = PriorityQueue::new();
        single.enqueues([1, 2]);
        let removed_single = single.retain(|x| *x == 2);
        assert_eq!(removed_single, 1);
        assert_eq!(single.front(), Some(&2));
    }

    #[test]
    fn retain_on_empty_should_return_zero() {
        let mut q = PriorityQueue::<i32>::new();

        let removed = q.retain(|_| true);
        assert_eq!(removed, 0);
        assert!(q.is_empty());
    }

    // Iteration exposes heap order, so results are sorted before comparing.
    #[test]
    fn iter_should_be_unsorted_but_complete() {
        let mut q = PriorityQueue::new();
        q.enqueues([7, 1, 9, 3, 5]);

        let mut from_iter: Vec<i32> = q.iter().copied().collect();
        from_iter.sort();
        assert_eq!(from_iter, vec![1, 3, 5, 7, 9]);

        let mut from_into_iter: Vec<i32> = (&q).into_iter().copied().collect();
        from_into_iter.sort();
        assert_eq!(from_into_iter, vec![1, 3, 5, 7, 9]);
    }

    // Exercises the `Collection` trait methods (including its provided ones such as
    // `count`/`collect`/`is_empty`) and the `Disposable` flag.
    #[test]
    fn collection_and_dispose_contract_should_work() {
        let mut q = PriorityQueue::new();
        q.enqueues([3, 1, 2]);

        assert_eq!(Collection::size(&q), 3);
        assert_eq!(Collection::count(&q, |x| *x % 2 == 1), 2);

        let mut all = Collection::collect(&q);
        all.sort();
        assert_eq!(all, vec![1, 2, 3]);

        Collection::clear(&mut q);
        assert!(Collection::is_empty(&q));

        assert!(!Disposable::is_disposed(&q));
        Disposable::dispose(&mut q);
        assert!(Disposable::is_disposed(&q));
        assert!(Collection::is_empty(&q));
    }

    // Compile-time check that the trait impl exists.
    #[test]
    fn priority_queue_like_should_be_implemented() {
        fn assert_priority_queue_like<Q: PriorityQueueLike<i32>>(_q: &Q) {}

        let q = PriorityQueue::new();
        assert_priority_queue_like(&q);
    }

    // --- Naive reference model: an unsorted Vec scanned linearly for the minimum. ---

    fn model_front(model: &[i32]) -> Option<i32> {
        model.iter().min().copied()
    }

    fn model_dequeue(model: &mut Vec<i32>) -> Option<i32> {
        let (idx, _) = model.iter().enumerate().min_by_key(|(_, x)| *x)?;
        Some(model.swap_remove(idx))
    }

    /// Mirrors `replace_front`: remove the minimum (if any), then always insert `new_back`.
    fn model_replace_front(model: &mut Vec<i32>, new_back: i32) -> Option<i32> {
        match model_dequeue(model) {
            Some(removed) => {
                model.push(new_back);
                Some(removed)
            }
            None => {
                model.push(new_back);
                None
            }
        }
    }

    /// Pops a `Reverse`-wrapped std max-heap (i.e. a min-heap) into ascending order.
    fn sorted_dequeue_from_binary_heap(heap: &mut BinaryHeap<Reverse<i32>>) -> Vec<i32> {
        let mut out = Vec::new();
        while let Some(Reverse(x)) = heap.pop() {
            out.push(x);
        }
        out
    }

    // Differential test: applies the same random operation stream to the queue, the naive
    // model and std's BinaryHeap, and checks they agree after every step.
    // Op codes: 0 enqueue, 1 dequeue, 2 replace_front, 3 batch enqueues (0..=7 items),
    // 4 retain by a random residue class, 5 clear.
    #[test]
    fn randomized_ops_should_match_model_and_binary_heap() {
        let seeds = [1_u64, 7, 97, 0x1234_5678, 0xDEAD_BEEF, 0xCAFE_BABE];

        for seed in seeds {
            let mut rng = XorShift64::new(seed);
            let mut q = PriorityQueue::new();
            let mut model: Vec<i32> = Vec::new();
            let mut bh = BinaryHeap::<Reverse<i32>>::new();

            for step in 0..5000 {
                match rng.next_u64() % 6 {
                    0 => {
                        // Values in -5000..5000 so duplicates and negatives occur.
                        let x = rng.next_i32_in(10_000) - 5000;
                        q.enqueue(x);
                        model.push(x);
                        bh.push(Reverse(x));
                    }
                    1 => {
                        let got = q.dequeue();
                        let expect = model_dequeue(&mut model);
                        let bh_expect = bh.pop().map(|Reverse(x)| x);
                        assert_eq!(got, expect);
                        assert_eq!(got, bh_expect);
                    }
                    2 => {
                        let x = rng.next_i32_in(10_000) - 5000;
                        let got = q.replace_front(x);
                        let expect = model_replace_front(&mut model, x);
                        let bh_expect = match bh.pop() {
                            Some(Reverse(v)) => {
                                bh.push(Reverse(x));
                                Some(v)
                            }
                            None => {
                                bh.push(Reverse(x));
                                None
                            }
                        };
                        assert_eq!(got, expect);
                        assert_eq!(got, bh_expect);
                    }
                    3 => {
                        let batch_size = (rng.next_u64() % 8) as usize;
                        let mut batch = Vec::with_capacity(batch_size);
                        for _ in 0..batch_size {
                            let x = rng.next_i32_in(10_000) - 5000;
                            batch.push(x);
                        }
                        q.enqueues(batch.iter().copied());
                        for &x in &batch {
                            model.push(x);
                            bh.push(Reverse(x));
                        }
                    }
                    4 => {
                        // rem_euclid keeps the residue non-negative for negative values,
                        // so `rem` (drawn from 0..div) can match them.
                        let div = (rng.next_u64() % 5 + 2) as i32;
                        let rem = (rng.next_u64() % div as u64) as i32;
                        let removed = q.retain(|x| x.rem_euclid(div) != rem);

                        let before = model.len();
                        model.retain(|x| x.rem_euclid(div) != rem);
                        let expect_removed = before - model.len();

                        // Filter the std heap by draining it and pushing survivors back.
                        let mut kept = Vec::new();
                        while let Some(Reverse(x)) = bh.pop() {
                            if x.rem_euclid(div) != rem {
                                kept.push(x);
                            }
                        }
                        for x in kept {
                            bh.push(Reverse(x));
                        }

                        assert_eq!(removed, expect_removed);
                    }
                    _ => {
                        q.clear();
                        model.clear();
                        bh.clear();
                    }
                }

                // Cheap invariants checked after every operation.
                assert_eq!(q.size(), model.len());
                assert_eq!(q.front().copied(), model_front(&model));
                assert_eq!(q.front().copied(), bh.peek().map(|Reverse(x)| *x));

                // Periodically verify full ordering by draining clones (originals untouched).
                if step % 257 == 0 {
                    let mut q_clone = q.clone();
                    let mut expected = model.clone();
                    expected.sort_unstable();
                    let actual = drain_all(&mut q_clone);
                    assert_eq!(actual, expected);

                    let mut bh_clone = bh.clone();
                    assert_eq!(actual, sorted_dequeue_from_binary_heap(&mut bh_clone));
                }
            }

            // Final full drain of all three structures must agree.
            let mut expected = model;
            expected.sort_unstable();
            assert_eq!(drain_all(&mut q), expected);
            assert_eq!(expected, sorted_dequeue_from_binary_heap(&mut bh));
        }
    }
}