rstl_queue/priority_queue.rs

1use std::{mem::ManuallyDrop, ptr};
2
3use collection::{Collection, Disposable};
4
5use crate::traits::QueueLike;
6
7pub use crate::traits::PriorityQueueLike;
8
/// A binary min-heap: `front`/`dequeue` yield the smallest element according to `Ord`.
///
/// Wrap elements in [`std::cmp::Reverse`] to get max-heap behaviour.
///
/// Elements live in a `Vec` laid out as an implicit binary tree: the children of
/// index `i` are at `2 * i + 1` and `2 * i + 2`, and every parent is `<=` its children.
#[derive(Debug, Clone)]
pub struct PriorityQueue<T> {
    /// Set to `true` by `Disposable::dispose`.
    disposed: bool,
    /// Heap-ordered storage.
    elements: Vec<T>,
}

/// Panic-safety guard for the "hole" technique used by the sift routines.
///
/// While sifting, the slot at `pos` is logically vacant: its value has been moved
/// into `item`, and the slot may hold a bitwise duplicate of another element.
/// Dropping the guard writes `item` into that slot. This leaves the storage fully
/// initialized after a normal exit and also during unwinding (e.g. when a
/// user-supplied `Ord` implementation panics).
struct RestoreOnDrop<T> {
    /// Base pointer of the heap storage.
    ptr: *mut T,
    /// Index of the current hole; must always be in bounds.
    pos: usize,
    /// The element lifted out of the heap; written back into the hole on drop.
    item: ManuallyDrop<T>,
}

impl<T> Drop for RestoreOnDrop<T> {
    fn drop(&mut self) {
        // SAFETY: `self.pos` is always in bounds and marks the current hole, so this
        // write neither clobbers a live element nor leaks one. `item` is taken exactly
        // once because `drop` runs exactly once.
        unsafe {
            let dst = self.ptr.add(self.pos);
            ptr::write(dst, ManuallyDrop::take(&mut self.item));
        }
    }
}

impl<T> PriorityQueue<T>
where
    T: Ord,
{
    /// Creates an empty queue.
    pub fn new() -> Self {
        Self {
            disposed: false,
            elements: Vec::new(),
        }
    }

    /// Moves the element at `index` towards the root until its parent is `<=` it.
    ///
    /// Does nothing for the root or for an out-of-range `index`.
    #[inline(always)]
    fn up(&mut self, index: usize) {
        if index == 0 || index >= self.elements.len() {
            return;
        }

        // SAFETY: `index < len`, so the read is of an initialized element whose
        // ownership moves into the guard. `sift_up_from` only visits ancestors of
        // `index`, all of which are in bounds, and the guard refills the hole.
        unsafe {
            let ptr = self.elements.as_mut_ptr();
            let mut hole = RestoreOnDrop {
                ptr,
                pos: index,
                item: ManuallyDrop::new(ptr::read(ptr.add(index))),
            };
            Self::sift_up_from(0, ptr, &mut hole);
        }
    }

    /// Moves the element at `index` towards the leaves. At each level it swaps with
    /// its smaller child until neither child is smaller than it.
    ///
    /// Does nothing for an out-of-range `index`.
    #[inline(always)]
    fn down(&mut self, index: usize) {
        let n = self.elements.len();
        if index >= n {
            return;
        }

        // SAFETY: every index dereferenced below is checked against `n`, the length
        // of `self.elements`, which cannot change while `&mut self` is held. The guard
        // keeps exactly one hole and refills it on exit or unwinding.
        unsafe {
            let ptr = self.elements.as_mut_ptr();
            let mut hole = RestoreOnDrop {
                ptr,
                pos: index,
                item: ManuallyDrop::new(ptr::read(ptr.add(index))),
            };

            loop {
                let left = (hole.pos << 1) + 1;
                if left >= n {
                    break;
                }

                // Pick the smaller child; ties go to the left one.
                let right = left + 1;
                let child = if right < n && *ptr.add(right) < *ptr.add(left) {
                    right
                } else {
                    left
                };

                // Read the held element through the guard itself, so no borrow of
                // `item` has to outlive the writes to `hole.pos`.
                if *hole.item <= *ptr.add(child) {
                    break;
                }

                ptr::copy_nonoverlapping(ptr.add(child), ptr.add(hole.pos), 1);
                hole.pos = child;
            }
        }
    }

    /// Re-seats the element at `start` in two phases. First the hole is pushed all
    /// the way down to a leaf along the path of smaller children, without comparing
    /// against the element. Then the element is sifted back up, never above `start`.
    /// `dequeue` uses this after moving the last element into the root.
    ///
    /// Does nothing for an out-of-range `start`.
    #[inline(always)]
    fn down_to_bottom_then_up(&mut self, start: usize) {
        let n = self.elements.len();
        if start >= n {
            return;
        }

        // SAFETY: `start < n`, so the read is of an initialized element. Both helpers
        // keep `hole.pos` within `start..n` and only dereference in-bounds indices.
        unsafe {
            let ptr = self.elements.as_mut_ptr();
            let mut hole = RestoreOnDrop {
                ptr,
                pos: start,
                item: ManuallyDrop::new(ptr::read(ptr.add(start))),
            };
            Self::sift_down_to_bottom(ptr, n, &mut hole);
            // `sift_up_from` reads the element through `hole` itself, so no pointer
            // into `hole.item` has to survive the `&mut hole` reborrows.
            Self::sift_up_from(start, ptr, &mut hole);
        }
    }

    /// Moves the hole from `hole.pos` down to a leaf. At each level it shifts the
    /// smaller child up into the hole; ties go to the left child. It never compares
    /// against the held element.
    ///
    /// # Safety
    ///
    /// `ptr` must point to `n` elements that are initialized except for the hole,
    /// `hole.ptr` must equal `ptr`, and `hole.pos` must be `< n`.
    #[inline(always)]
    unsafe fn sift_down_to_bottom(ptr: *mut T, n: usize, hole: &mut RestoreOnDrop<T>) {
        let mut pos = hole.pos;
        let mut child = (pos << 1) + 1;

        // Loop while both children exist.
        while child + 1 < n {
            let right = child + 1;
            // SAFETY: `child < right < n`, and neither is the hole (both are below `pos`).
            let right_is_smaller = unsafe { *ptr.add(right) < *ptr.add(child) };
            if right_is_smaller {
                child = right;
            }

            // SAFETY: `child < n` and `pos < child`, so both slots are valid and distinct.
            unsafe { ptr::copy_nonoverlapping(ptr.add(child), ptr.add(pos), 1) };
            pos = child;
            hole.pos = pos;
            child = (pos << 1) + 1;
        }

        // A lone left child may remain on the last level.
        if child < n {
            // SAFETY: `child < n` and `pos < child`, so both slots are valid and distinct.
            unsafe { ptr::copy_nonoverlapping(ptr.add(child), ptr.add(pos), 1) };
            hole.pos = child;
        }
    }

    /// Moves the hole from `hole.pos` towards the root while the parent is strictly
    /// greater than the held element. It never moves the hole to an index `<= start`
    /// from above; `start` is where sifting stops.
    ///
    /// # Safety
    ///
    /// `ptr` must point to the heap storage (initialized except for the hole),
    /// `hole.ptr` must equal `ptr`, and `hole.pos` must be in bounds.
    #[inline(always)]
    unsafe fn sift_up_from(start: usize, ptr: *mut T, hole: &mut RestoreOnDrop<T>) {
        let mut pos = hole.pos;

        while pos > start {
            let parent = (pos - 1) >> 1;
            // SAFETY: `parent < pos`, which is in bounds, and `parent` is not the hole.
            let parent_ref: &T = unsafe { &*ptr.add(parent) };
            // Borrow the held element freshly from the guard on every comparison.
            if *parent_ref <= *hole.item {
                break;
            }

            // SAFETY: `parent` and `pos` are distinct in-bounds indices.
            unsafe { ptr::copy_nonoverlapping(ptr.add(parent), ptr.add(pos), 1) };
            pos = parent;
            hole.pos = pos;
        }
    }

    /// Restores the heap property over the whole vector in O(n) using Floyd's
    /// bottom-up construction: every internal node is sifted down, from the last
    /// parent back to the root.
    fn fast_build(&mut self) {
        let n = self.elements.len();
        // Indices `n / 2..n` are leaves; the range is empty when `n <= 1`.
        for parent in (0..n / 2).rev() {
            self.down(parent);
        }
    }
}
206
207impl<T> Default for PriorityQueue<T>
208where
209    T: Ord,
210{
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216impl<T> QueueLike<T> for PriorityQueue<T>
217where
218    T: Ord,
219{
220    fn front(&self) -> Option<&T> {
221        self.elements.first()
222    }
223
224    fn enqueue(&mut self, element: T) {
225        self.elements.push(element);
226        let index = self.elements.len() - 1;
227        self.up(index);
228    }
229
230    fn dequeue(&mut self) -> Option<T> {
231        self.elements.pop().map(|mut item| {
232            if !self.elements.is_empty() {
233                std::mem::swap(&mut item, &mut self.elements[0]);
234                self.down_to_bottom_then_up(0);
235            }
236            item
237        })
238    }
239
240    fn enqueues<I>(&mut self, elements: I)
241    where
242        I: IntoIterator<Item = T>,
243    {
244        let size = self.elements.len();
245        self.elements.extend(elements);
246
247        let next_size = self.elements.len();
248        if next_size == size {
249            return;
250        }
251
252        let new_added = next_size - size;
253        let next_size_f64 = next_size as f64;
254        if (new_added as f64) * next_size_f64.log2() > next_size_f64 {
255            self.fast_build();
256        } else {
257            for i in size..next_size {
258                self.up(i);
259            }
260        }
261    }
262
263    fn replace_front(&mut self, new_back: T) -> Option<T> {
264        if self.elements.is_empty() {
265            self.elements.push(new_back);
266            return None;
267        }
268
269        let removed = std::mem::replace(&mut self.elements[0], new_back);
270        self.down(0);
271        Some(removed)
272    }
273}
274
275impl<T> PriorityQueueLike<T> for PriorityQueue<T> where T: Ord {}
276
277impl<T> Collection for PriorityQueue<T>
278where
279    T: Ord,
280{
281    type Item = T;
282    type Iter<'a>
283        = std::slice::Iter<'a, T>
284    where
285        Self: 'a;
286
287    fn iter(&self) -> Self::Iter<'_> {
288        self.elements.iter()
289    }
290
291    fn size(&self) -> usize {
292        self.elements.len()
293    }
294
295    fn clear(&mut self) {
296        self.elements.clear();
297    }
298
299    fn retain<F>(&mut self, mut f: F) -> usize
300    where
301        F: FnMut(&Self::Item) -> bool,
302    {
303        let before = self.elements.len();
304        if before == 0 {
305            return 0;
306        }
307
308        self.elements.retain(|item| f(item));
309        let removed = before - self.elements.len();
310        if removed > 0 {
311            self.fast_build();
312        }
313        removed
314    }
315}
316
317impl<T> Disposable for PriorityQueue<T>
318where
319    T: Ord,
320{
321    fn dispose(&mut self) {
322        self.disposed = true;
323        self.elements.clear();
324    }
325
326    fn is_disposed(&self) -> bool {
327        self.disposed
328    }
329}
330
331impl<'a, T> IntoIterator for &'a PriorityQueue<T>
332where
333    T: Ord,
334{
335    type Item = &'a T;
336    type IntoIter = std::slice::Iter<'a, T>;
337
338    fn into_iter(self) -> Self::IntoIter {
339        self.elements.iter()
340    }
341}
342
#[cfg(test)]
mod tests {
    use std::{cmp::Reverse, collections::BinaryHeap};

    use collection::{Collection, Disposable};

    use super::PriorityQueue;
    use crate::traits::{PriorityQueueLike, QueueLike};

    /// Small deterministic xorshift64 generator, so the randomized test needs no
    /// external crates. A zero seed stays at zero forever, so seeds must be non-zero.
    #[derive(Clone)]
    struct XorShift64 {
        s: u64,
    }

    impl XorShift64 {
        fn new(seed: u64) -> Self {
            XorShift64 { s: seed }
        }

        fn next_u64(&mut self) -> u64 {
            self.s ^= self.s << 13;
            self.s ^= self.s >> 7;
            self.s ^= self.s << 17;
            self.s
        }

        /// Returns a value in `0..bound`.
        fn next_i32_in(&mut self, bound: i32) -> i32 {
            debug_assert!(bound > 0);
            let raw = self.next_u64();
            (raw % bound as u64) as i32
        }
    }

    /// Dequeues until empty and collects the elements in the order they come out.
    fn drain_all<T: Ord>(q: &mut PriorityQueue<T>) -> Vec<T> {
        std::iter::from_fn(|| q.dequeue()).collect()
    }

    #[test]
    fn queue_like_min_heap_ops_should_work() {
        let mut queue = PriorityQueue::new();

        assert_eq!(queue.front(), None);
        assert_eq!(queue.dequeue(), None);

        queue.enqueues([4, 2, 5, 1, 3]);
        assert_eq!(queue.front(), Some(&1));
        assert_eq!(queue.replace_front(6), Some(1));
        assert_eq!(drain_all(&mut queue), [2, 3, 4, 5, 6]);
    }

    #[test]
    fn enqueue_should_cover_up_break_path() {
        let mut queue = PriorityQueue::new();
        for value in [1, 2, 0] {
            queue.enqueue(value);
        }

        assert_eq!(queue.front(), Some(&0));
        assert_eq!(drain_all(&mut queue), [0, 1, 2]);
    }

    #[test]
    fn replace_front_should_handle_empty_and_single_item() {
        let mut queue = PriorityQueue::new();

        assert_eq!(queue.replace_front(10), None);
        assert_eq!(queue.front(), Some(&10));

        assert_eq!(queue.replace_front(5), Some(10));
        assert_eq!(queue.front(), Some(&5));
        assert_eq!(queue.dequeue(), Some(5));
    }

    #[test]
    fn enqueues_should_work_for_small_and_large_batch() {
        let mut queue = PriorityQueue::new();
        queue.enqueues([5, 4]);
        queue.enqueues(0..100);

        assert_eq!(queue.size(), 102);
        assert_eq!(queue.front(), Some(&0));

        let drained = drain_all(&mut queue);
        assert_eq!(drained.len(), 102);
        assert!(drained.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn enqueues_empty_should_be_noop() {
        let mut queue = PriorityQueue::new();

        queue.enqueues(std::iter::empty());
        assert!(queue.is_empty());

        queue.enqueue(3);
        queue.enqueues(std::iter::empty());
        assert_eq!(queue.front(), Some(&3));
        assert_eq!(queue.size(), 1);
    }

    #[test]
    fn reverse_ord_should_support_max_heap() {
        let mut queue = PriorityQueue::new();
        queue.enqueues([1_i32, 5, 2, 4, 3].map(Reverse));

        assert_eq!(queue.front(), Some(&Reverse(5)));
        let expected: Vec<Reverse<i32>> = (1..=5).rev().map(Reverse).collect();
        assert_eq!(drain_all(&mut queue), expected);
    }

    #[test]
    fn retain_should_rebuild_heap() {
        let mut evens = PriorityQueue::new();
        evens.enqueues(1..=8);
        assert_eq!(evens.retain(|x| *x % 2 == 0), 4);
        assert_eq!(drain_all(&mut evens), [2, 4, 6, 8]);

        let mut pair = PriorityQueue::new();
        pair.enqueues([1, 2]);
        assert_eq!(pair.retain(|x| *x == 2), 1);
        assert_eq!(pair.front(), Some(&2));
    }

    #[test]
    fn retain_on_empty_should_return_zero() {
        let mut queue: PriorityQueue<i32> = PriorityQueue::new();

        assert_eq!(queue.retain(|_| true), 0);
        assert!(queue.is_empty());
    }

    #[test]
    fn iter_should_be_unsorted_but_complete() {
        fn sorted<'a>(items: impl Iterator<Item = &'a i32>) -> Vec<i32> {
            let mut values: Vec<i32> = items.copied().collect();
            values.sort();
            values
        }

        let mut queue = PriorityQueue::new();
        queue.enqueues([7, 1, 9, 3, 5]);

        assert_eq!(sorted(queue.iter()), vec![1, 3, 5, 7, 9]);
        assert_eq!(sorted((&queue).into_iter()), vec![1, 3, 5, 7, 9]);
    }

    #[test]
    fn collection_and_dispose_contract_should_work() {
        let mut queue = PriorityQueue::new();
        queue.enqueues([3, 1, 2]);

        assert_eq!(Collection::size(&queue), 3);
        assert_eq!(Collection::count(&queue, |x| *x % 2 == 1), 2);

        let mut collected = Collection::collect(&queue);
        collected.sort();
        assert_eq!(collected, vec![1, 2, 3]);

        Collection::clear(&mut queue);
        assert!(Collection::is_empty(&queue));

        assert!(!Disposable::is_disposed(&queue));
        Disposable::dispose(&mut queue);
        assert!(Disposable::is_disposed(&queue));
        assert!(Collection::is_empty(&queue));
    }

    #[test]
    fn priority_queue_like_should_be_implemented() {
        fn requires_priority_queue_like<Q: PriorityQueueLike<i32>>(_: &Q) {}

        requires_priority_queue_like(&PriorityQueue::new());
    }

    /// Reference "front": the minimum of an unordered multiset.
    fn model_front(model: &[i32]) -> Option<i32> {
        model.iter().copied().min()
    }

    /// Reference "dequeue": removes the first occurrence of the minimum.
    fn model_dequeue(model: &mut Vec<i32>) -> Option<i32> {
        let smallest = model_front(model)?;
        let idx = model.iter().position(|&v| v == smallest)?;
        Some(model.swap_remove(idx))
    }

    /// Reference "replace_front": dequeue (if possible), then always push.
    fn model_replace_front(model: &mut Vec<i32>, new_back: i32) -> Option<i32> {
        let removed = model_dequeue(model);
        model.push(new_back);
        removed
    }

    fn sorted_dequeue_from_binary_heap(heap: &mut BinaryHeap<Reverse<i32>>) -> Vec<i32> {
        std::iter::from_fn(|| heap.pop()).map(|Reverse(x)| x).collect()
    }

    #[test]
    fn randomized_ops_should_match_model_and_binary_heap() {
        const SEEDS: [u64; 6] = [1, 7, 97, 0x1234_5678, 0xDEAD_BEEF, 0xCAFE_BABE];
        const STEPS: usize = 5000;

        fn random_value(rng: &mut XorShift64) -> i32 {
            rng.next_i32_in(10_000) - 5000
        }

        for seed in SEEDS {
            let mut rng = XorShift64::new(seed);
            let mut queue = PriorityQueue::new();
            let mut model: Vec<i32> = Vec::new();
            let mut reference = BinaryHeap::<Reverse<i32>>::new();

            for step in 0..STEPS {
                match rng.next_u64() % 6 {
                    0 => {
                        let value = random_value(&mut rng);
                        queue.enqueue(value);
                        model.push(value);
                        reference.push(Reverse(value));
                    }
                    1 => {
                        let got = queue.dequeue();
                        assert_eq!(got, model_dequeue(&mut model));
                        assert_eq!(got, reference.pop().map(|Reverse(v)| v));
                    }
                    2 => {
                        let value = random_value(&mut rng);
                        let got = queue.replace_front(value);
                        let expected = model_replace_front(&mut model, value);
                        let reference_got = reference.pop().map(|Reverse(v)| v);
                        reference.push(Reverse(value));
                        assert_eq!(got, expected);
                        assert_eq!(got, reference_got);
                    }
                    3 => {
                        let batch_size = (rng.next_u64() % 8) as usize;
                        let batch: Vec<i32> =
                            (0..batch_size).map(|_| random_value(&mut rng)).collect();
                        queue.enqueues(batch.iter().copied());
                        model.extend_from_slice(&batch);
                        reference.extend(batch.iter().copied().map(Reverse));
                    }
                    4 => {
                        let div = (rng.next_u64() % 5 + 2) as i32;
                        let rem = (rng.next_u64() % div as u64) as i32;
                        let keep = |x: &i32| x.rem_euclid(div) != rem;

                        let removed = queue.retain(keep);

                        let before = model.len();
                        model.retain(keep);
                        let expected_removed = before - model.len();

                        reference.retain(|Reverse(x)| keep(x));

                        assert_eq!(removed, expected_removed);
                    }
                    _ => {
                        queue.clear();
                        model.clear();
                        reference.clear();
                    }
                }

                assert_eq!(queue.size(), model.len());
                assert_eq!(queue.front().copied(), model_front(&model));
                assert_eq!(queue.front().copied(), reference.peek().map(|r| r.0));

                if step % 257 == 0 {
                    let mut expected = model.clone();
                    expected.sort_unstable();
                    let actual = drain_all(&mut queue.clone());
                    assert_eq!(actual, expected);
                    assert_eq!(actual, sorted_dequeue_from_binary_heap(&mut reference.clone()));
                }
            }

            let mut expected = model;
            expected.sort_unstable();
            assert_eq!(drain_all(&mut queue), expected);
            assert_eq!(expected, sorted_dequeue_from_binary_heap(&mut reference));
        }
    }
}