ic_stable_structures/
min_heap.rs

1use crate::base_vec::BaseVec;
2use crate::storable::Storable;
3use crate::Memory;
4use std::fmt;
5
6#[cfg(test)]
7mod tests;
8
/// Magic bytes handed to `BaseVec` to identify this data structure in memory.
/// The letters stand for "Stable Min Heap".
const MAGIC: [u8; 3] = [b'S', b'M', b'H'];
10
11/// An implementation of the [binary min heap](https://en.wikipedia.org/wiki/Binary_heap).
12// NB. Contrary to [std::collections::BinaryHeap], this heap is a min-heap (smallest items come first).
13// Motivation: max heaps are helpful for sorting, but most daily programming tasks require min
14// heaps.
15pub struct MinHeap<T: Storable + PartialOrd, M: Memory>(BaseVec<T, M>);
16
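// Note: Heap Invariant
// ~~~~~~~~~~~~~~~~~~~~
//
// HeapInvariant(heap, i, j) :=
//   ∀ k: max(i, 1) ≤ k ≤ j: LET p = (k - 1)/2 IN (i ≤ p) => heap[p] ≤ heap[k]
//
// In words: every element of heap[i..=j] whose parent also lies in heap[i..=j] is not
// smaller than that parent. (k starts at 1 because the root, index 0, has no parent.)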
22
23impl<T, M> MinHeap<T, M>
24where
25    T: Storable + PartialOrd,
26    M: Memory,
27{
28    /// Creates a new empty heap in the specified memory,
29    /// overwriting any data structures the memory might have
30    /// contained.
31    ///
32    /// Complexity: O(1)
33    pub fn new(memory: M) -> Self {
34        BaseVec::<T, M>::new(memory, MAGIC)
35            .map(Self)
36            .expect("Failed to create a new heap")
37    }
38
39    /// Initializes a heap in the specified memory.
40    ///
41    /// Complexity: O(1)
42    ///
43    /// PRECONDITION: the memory is either empty or contains a valid
44    /// stable heap.
45    pub fn init(memory: M) -> Self {
46        BaseVec::<T, M>::init(memory, MAGIC)
47            .map(Self)
48            .expect("Failed to initialize a heap")
49    }
50
51    /// Returns the number of items in the heap.
52    ///
53    /// Complexity: O(1)
54    pub fn len(&self) -> u64 {
55        self.0.len()
56    }
57
58    /// Returns true if the heap is empty.
59    ///
60    /// Complexity: O(1)
61    pub fn is_empty(&self) -> bool {
62        self.0.is_empty()
63    }
64
65    /// Pushes an item onto the heap.
66    ///
67    /// Complexity: O(log(self.len()))
68    pub fn push(&mut self, item: &T) {
69        self.0.push(item).expect("heap push failed");
70        self.bubble_up(self.0.len() - 1, item);
71        debug_assert_eq!(Ok(()), self.check_invariant());
72    }
73
74    /// Removes the smallest item from the heap and returns it.
75    /// Returns `None` if the heap is empty.
76    ///
77    /// Complexity: O(log(self.len()))
78    pub fn pop(&mut self) -> Option<T> {
79        let n = self.len();
80        match n {
81            0 => None,
82            1 => self.0.pop(),
83            _more => {
84                let smallest = self.0.get(0).unwrap();
85                let last = self.0.pop().unwrap();
86                self.0.set(0, &last);
87                self.bubble_down(0, n - 1, &last);
88                debug_assert_eq!(Ok(()), self.check_invariant());
89                Some(smallest)
90            }
91        }
92    }
93
94    /// Returns the smallest item in the heap.
95    /// Returns `None` if the heap is empty.
96    ///
97    /// Complexity: O(1)
98    pub fn peek(&self) -> Option<T> {
99        self.0.get(0)
100    }
101
102    /// Returns an iterator visiting all values in the underlying vector, in arbitrary order.
103    pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
104        self.0.iter()
105    }
106
107    /// Returns the underlying memory instance.
108    pub fn into_memory(self) -> M {
109        self.0.into_memory()
110    }
111
112    #[allow(dead_code)]
113    /// Checks the HeapInvariant(self, 0, self.len() - 1)
114    fn check_invariant(&self) -> Result<(), String> {
115        let n = self.len();
116        for i in 1..n {
117            let p = (i - 1) / 2;
118            let item = self.0.get(i).unwrap();
119            let parent = self.0.get(p).unwrap();
120            if is_less(&item, &parent) {
121                return Err(format!(
122                    "Binary heap invariant violated in indices {i} and {p}"
123                ));
124            }
125        }
126        Ok(())
127    }
128
129    /// PRECONDITION: self.0.get(i) == item
130    fn bubble_up(&mut self, mut i: u64, item: &T) {
131        // We set the flag if self.0.get(i) does not contain the item anymore.
132        let mut swapped = false;
133        // LOOP INVARIANT: HeapInvariant(self, i, self.len() - 1)
134        while i > 0 {
135            let p = (i - 1) / 2;
136            let parent = self.0.get(p).unwrap();
137            if is_less(item, &parent) {
138                self.0.set(i, &parent);
139                swapped = true;
140            } else {
141                break;
142            }
143            i = p;
144        }
145        if swapped {
146            self.0.set(i, item);
147        }
148    }
149
150    /// PRECONDITION: self.0.get(i) == item
151    fn bubble_down(&mut self, mut i: u64, n: u64, item: &T) {
152        // We set the flag if self.0.get(i) does not contain the item anymore.
153        let mut swapped = false;
154        // LOOP INVARIANT: HeapInvariant(self, 0, i)
155        loop {
156            let l = i * 2 + 1;
157            let r = l + 1;
158
159            if n <= l {
160                break;
161            }
162
163            if n <= r {
164                // Only the left child is within the array bounds.
165
166                let left = self.0.get(l).unwrap();
167                if is_less(&left, item) {
168                    self.0.set(i, &left);
169                    swapped = true;
170                    i = l;
171                    continue;
172                }
173            } else {
174                // Both children are within the array bounds.
175
176                let left = self.0.get(l).unwrap();
177                let right = self.0.get(r).unwrap();
178
179                let (min_index, min_elem) = if is_less(&left, &right) {
180                    (l, &left)
181                } else {
182                    (r, &right)
183                };
184
185                if is_less(min_elem, item) {
186                    self.0.set(i, min_elem);
187                    swapped = true;
188                    i = min_index;
189                    continue;
190                }
191            }
192            break;
193        }
194        if swapped {
195            self.0.set(i, item);
196        }
197    }
198}
199
/// Reports whether `x` is strictly less than `y`.
///
/// Pairs that `PartialOrd` cannot order (`partial_cmp` returns `None`, e.g. a
/// NaN float) are treated as "not less".
fn is_less<T: PartialOrd>(x: &T, y: &T) -> bool {
    matches!(x.partial_cmp(y), Some(std::cmp::Ordering::Less))
}
203
204impl<T, M> fmt::Debug for MinHeap<T, M>
205where
206    T: Storable + PartialOrd + fmt::Debug,
207    M: Memory,
208{
209    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
210        self.0.fmt(fmt)
211    }
212}