1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
use crate::base_vec::{BaseVec, InitError};
use crate::storable::BoundedStorable;
use crate::{GrowFailed, Memory};
use std::fmt;

#[cfg(test)]
mod tests;

// Three-byte tag handed to `BaseVec::new` and `BaseVec::init` by the heap constructors.
// NOTE(review): presumably `BaseVec` stores/validates it in the memory header so that a heap
// is not confused with another structure — confirm in `base_vec`.
const MAGIC: [u8; 3] = *b"SMH"; // Short for "stable min heap".

/// A [binary min heap](https://en.wikipedia.org/wiki/Binary_heap) whose items are kept in a
/// `BaseVec` built on top of a `Memory`.
// NB. Unlike [std::collections::BinaryHeap], which is a max-heap, this heap hands out the
// smallest items first. Rationale: max heaps are handy for sorting, but most everyday
// programming tasks call for min heaps.
pub struct MinHeap<T, M>(BaseVec<T, M>)
where
    T: BoundedStorable + PartialOrd,
    M: Memory;

// Note: Heap Invariant
// ~~~~~~~~~~~~~~~~~~~~
//
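// HeapInvariant(heap, i, j) :=
//   ∀ k: i ≤ k ≤ j: LET p = (k - 1)/2 IN (i ≤ p) => heap[p] ≤ heap[k]
//
// In words: every item in heap[i..=j] whose parent index is also ≥ i is not
// less than that parent.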

impl<T, M> MinHeap<T, M>
where
    T: BoundedStorable + PartialOrd,
    M: Memory,
{
    /// Builds a new, empty heap in `memory`. Any data structure the memory
    /// might have contained before is overwritten.
    ///
    /// A `GrowFailed` error from the underlying vector is passed through
    /// unchanged.
    ///
    /// Complexity: O(1)
    pub fn new(memory: M) -> Result<Self, GrowFailed> {
        let items = BaseVec::<T, M>::new(memory, MAGIC)?;
        Ok(Self(items))
    }

    /// Initializes a heap in `memory`.
    ///
    /// PRECONDITION: `memory` is either empty or contains a valid stable heap.
    ///
    /// An `InitError` from the underlying vector is passed through unchanged.
    ///
    /// Complexity: O(1)
    pub fn init(memory: M) -> Result<Self, InitError> {
        let items = BaseVec::<T, M>::init(memory, MAGIC)?;
        Ok(Self(items))
    }

    /// Number of items currently stored in the heap.
    ///
    /// Complexity: O(1)
    pub fn len(&self) -> u64 {
        self.0.len()
    }

    /// Whether the heap holds no items at all.
    ///
    /// Complexity: O(1)
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Adds `item` to the heap.
    ///
    /// The item is first appended to the underlying vector (which may fail
    /// with `GrowFailed`) and then moved towards the root until its parent
    /// is no longer greater than it.
    ///
    /// Complexity: O(log(self.len()))
    pub fn push(&mut self, item: &T) -> Result<(), GrowFailed> {
        self.0.push(item)?;
        let last = self.0.len() - 1;
        self.bubble_up(last, item);
        debug_assert_eq!(Ok(()), self.check_invariant());
        Ok(())
    }

    /// Takes the smallest item out of the heap and returns it, or returns
    /// `None` when the heap is empty.
    ///
    /// Complexity: O(log(self.len()))
    pub fn pop(&mut self) -> Option<T> {
        let n = self.len();
        if n == 0 {
            return None;
        }
        if n == 1 {
            // A single item is both the root and the tail; nothing to restore.
            return self.0.pop();
        }

        // Remember the root, move the tail item into its slot, and sift that
        // item down within the remaining `n - 1` items.
        let root = self.0.get(0).unwrap();
        let tail = self.0.pop().unwrap();
        self.0.set(0, &tail);
        self.bubble_down(0, n - 1, &tail);
        debug_assert_eq!(Ok(()), self.check_invariant());
        Some(root)
    }

    /// Returns the smallest item without removing it, or `None` when the
    /// heap is empty.
    ///
    /// Complexity: O(1)
    pub fn peek(&self) -> Option<T> {
        self.0.get(0)
    }

    /// Iterates over every value of the underlying vector; the order is
    /// arbitrary (it is the storage order, not sorted order).
    pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
        self.0.iter()
    }

    /// Consumes the heap and hands back the memory instance it was built on.
    pub fn into_memory(self) -> M {
        self.0.into_memory()
    }

    /// Checks the HeapInvariant(self, 0, self.len() - 1): no item may be
    /// less than its parent. Reports the first offending index pair.
    #[allow(dead_code)]
    fn check_invariant(&self) -> Result<(), String> {
        (1..self.len()).try_for_each(|i| {
            let p = (i - 1) / 2;
            let child = self.0.get(i).unwrap();
            let parent = self.0.get(p).unwrap();
            if is_less(&child, &parent) {
                Err(format!(
                    "Binary heap invariant violated in indices {i} and {p}"
                ))
            } else {
                Ok(())
            }
        })
    }

    /// Moves `item` from slot `pos` towards the root while it is less than
    /// its parent.
    ///
    /// PRECONDITION: self.0.get(pos) == item
    fn bubble_up(&mut self, mut pos: u64, item: &T) {
        // Parents are shifted down into the hole instead of swapping; the
        // item itself is written once at the end, and only if the hole moved.
        let start = pos;
        // LOOP INVARIANT: HeapInvariant(self, pos, self.len() - 1)
        while pos > 0 {
            let parent_pos = (pos - 1) / 2;
            let parent = self.0.get(parent_pos).unwrap();
            if !is_less(item, &parent) {
                break;
            }
            self.0.set(pos, &parent);
            pos = parent_pos;
        }
        if pos != start {
            self.0.set(pos, item);
        }
    }

    /// Moves `item` from slot `pos` towards the leaves while one of its
    /// children is less than it. Only indices below `n` are treated as part
    /// of the heap.
    ///
    /// PRECONDITION: self.0.get(pos) == item
    fn bubble_down(&mut self, mut pos: u64, n: u64, item: &T) {
        // Children are shifted up into the hole instead of swapping; the
        // item itself is written once at the end, and only if the hole moved.
        let start = pos;
        // LOOP INVARIANT: HeapInvariant(self, 0, pos)
        loop {
            let left_pos = pos * 2 + 1;
            let right_pos = left_pos + 1;

            if left_pos >= n {
                // No children within bounds: `pos` is a leaf.
                break;
            }

            let left = self.0.get(left_pos).unwrap();
            let (child_pos, child) = if right_pos < n {
                // Both children exist; the right one wins unless the left is
                // strictly less.
                let right = self.0.get(right_pos).unwrap();
                if is_less(&left, &right) {
                    (left_pos, left)
                } else {
                    (right_pos, right)
                }
            } else {
                // Only the left child is within the array bounds.
                (left_pos, left)
            };

            if !is_less(&child, item) {
                break;
            }
            self.0.set(pos, &child);
            pos = child_pos;
        }
        if pos != start {
            self.0.set(pos, item);
        }
    }
}

/// Returns `true` only when `x` compares strictly less than `y`.
///
/// Incomparable pairs (`partial_cmp` yields `None`, e.g. a float NaN) are
/// reported as "not less".
fn is_less<T: PartialOrd>(x: &T, y: &T) -> bool {
    matches!(x.partial_cmp(y), Some(std::cmp::Ordering::Less))
}

/// Debug output of the heap is exactly the debug output of the underlying
/// vector (items appear in storage order).
impl<T, M> fmt::Debug for MinHeap<T, M>
where
    T: BoundedStorable + PartialOrd + fmt::Debug,
    M: Memory,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}