ex3-ic-stable-structures 0.6.0-beta.3

A collection of data structures for fearless canister upgrades.
Documentation
use crate::base_vec::{BaseVec, InitError};
use crate::storable::Storable;
use crate::{GrowFailed, Memory};
use std::fmt;

#[cfg(test)]
mod tests;

/// Magic bytes handed to `BaseVec::new` and `BaseVec::init` for every heap created or
/// loaded by this module.
// NOTE(review): presumably the base vector stores/validates these bytes in its header —
// confirm against `base_vec`.
const MAGIC: [u8; 3] = *b"SMH"; // Short for "stable min heap".

/// A [binary min heap](https://en.wikipedia.org/wiki/Binary_heap) whose items are kept in a
/// `BaseVec` on top of the memory `M`.
// NB. Unlike [std::collections::BinaryHeap], which is a max-heap, this heap yields the
// smallest items first.
// Motivation: max heaps are handy for sorting, but most everyday programming tasks call for
// min heaps.
pub struct MinHeap<T, M>(BaseVec<T, M>)
where
    T: Storable + PartialOrd,
    M: Memory;

// Note: Heap Invariant
// ~~~~~~~~~~~~~~~~~~~~
//
// HeapInvariant(heap, i, j) :=
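//   ∀ k: i ≤ k ≤ j: LET p = (k - 1)/2 IN (p ≥ i) => heap[p] ≤ heap[k]
//
// In words: every element at an index in [i, j] whose parent index also lies in that range
// is not smaller than its parent.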

impl<T, M> MinHeap<T, M>
where
    T: Storable + PartialOrd,
    M: Memory,
{
    /// Builds an empty heap backed by `memory`.
    ///
    /// Any data structure previously stored in `memory` is overwritten.
    ///
    /// Complexity: O(1)
    pub fn new(memory: M) -> Result<Self, GrowFailed> {
        let storage = BaseVec::<T, M>::new(memory, MAGIC)?;
        Ok(Self(storage))
    }

    /// Opens a heap that lives in `memory`.
    ///
    /// Complexity: O(1)
    ///
    /// PRECONDITION: `memory` is either empty or already holds a valid
    /// stable heap.
    pub fn init(memory: M) -> Result<Self, InitError> {
        let storage = BaseVec::<T, M>::init(memory, MAGIC)?;
        Ok(Self(storage))
    }

    /// Number of items currently stored in the heap.
    ///
    /// Complexity: O(1)
    pub fn len(&self) -> u64 {
        let Self(items) = self;
        items.len()
    }

    /// Tells whether the heap holds no items at all.
    ///
    /// Complexity: O(1)
    pub fn is_empty(&self) -> bool {
        let Self(items) = self;
        items.is_empty()
    }

    /// Adds `item` to the heap.
    ///
    /// The item is appended to the underlying vector and then moved towards
    /// the root until its parent is no longer larger than it.
    ///
    /// Complexity: O(log(self.len()))
    pub fn push(&mut self, item: &T) -> Result<(), GrowFailed> {
        self.0.push(item)?;
        let last_index = self.0.len() - 1;
        self.bubble_up(last_index, item);
        debug_assert_eq!(Ok(()), self.check_invariant());
        Ok(())
    }

    /// Takes the smallest item out of the heap and returns it,
    /// or returns `None` when the heap is empty.
    ///
    /// Complexity: O(log(self.len()))
    pub fn pop(&mut self) -> Option<T> {
        let count = self.len();
        if count == 0 {
            return None;
        }
        if count == 1 {
            // A single item is both the root and the last slot.
            return self.0.pop();
        }

        // Remember the root, then move the last item into the root slot
        // and let it sink to its place among the remaining `count - 1` items.
        let root = self.0.get(0).unwrap();
        let tail = self.0.pop().unwrap();
        self.0.set(0, &tail);
        self.bubble_down(0, count - 1, &tail);
        debug_assert_eq!(Ok(()), self.check_invariant());
        Some(root)
    }

    /// Returns the smallest item without removing it,
    /// or `None` when the heap is empty.
    ///
    /// Complexity: O(1)
    pub fn peek(&self) -> Option<T> {
        let Self(items) = self;
        items.get(0)
    }

    /// Iterates over every value of the underlying vector, in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
        let Self(items) = self;
        items.iter()
    }

    /// Consumes the heap and hands back the memory it was stored in.
    pub fn into_memory(self) -> M {
        let Self(items) = self;
        items.into_memory()
    }

    /// Verifies HeapInvariant(self, 0, self.len() - 1): no item may compare
    /// less than its parent. The error names the first offending pair of indices.
    #[allow(dead_code)]
    fn check_invariant(&self) -> Result<(), String> {
        (1..self.len()).try_for_each(|i| {
            let p = (i - 1) / 2;
            let item = self.0.get(i).unwrap();
            let parent = self.0.get(p).unwrap();
            if is_less(&item, &parent) {
                Err(format!(
                    "Binary heap invariant violated in indices {i} and {p}"
                ))
            } else {
                Ok(())
            }
        })
    }

    /// Moves `item` from slot `i` towards the root until its parent is not larger.
    ///
    /// PRECONDITION: self.0.get(i) == item
    fn bubble_up(&mut self, mut i: u64, item: &T) {
        let start = i;
        // LOOP INVARIANT: HeapInvariant(self, i, self.len() - 1)
        while i > 0 {
            let p = (i - 1) / 2;
            let parent = self.0.get(p).unwrap();
            if !is_less(item, &parent) {
                break;
            }
            // The parent is larger: shift it one level down and keep climbing.
            self.0.set(i, &parent);
            i = p;
        }
        // Every shift strictly decreases `i`, so `i != start` means slot `i`
        // no longer contains the item and it has to be written there.
        if i != start {
            self.0.set(i, item);
        }
    }

    /// Moves `item` from slot `i` towards the leaves, considering only the
    /// first `n` slots, until neither child is smaller than it.
    ///
    /// PRECONDITION: self.0.get(i) == item
    fn bubble_down(&mut self, mut i: u64, n: u64, item: &T) {
        let start = i;
        // LOOP INVARIANT: HeapInvariant(self, 0, i)
        loop {
            let l = i * 2 + 1;
            let r = l + 1;
            if l >= n {
                // Slot `i` has no children within bounds.
                break;
            }

            let left = self.0.get(l).unwrap();
            // The right child is only read when it is within the array bounds.
            let right = if r < n {
                Some(self.0.get(r).unwrap())
            } else {
                None
            };

            // Pick the left child when it is the only one or strictly smaller
            // than the right one; otherwise pick the right child.
            let (min_index, min_elem) = match &right {
                Some(right) if !is_less(&left, right) => (r, right),
                _ => (l, &left),
            };

            if !is_less(min_elem, item) {
                break;
            }
            // The chosen child is smaller: lift it one level up and keep sinking.
            self.0.set(i, min_elem);
            i = min_index;
        }
        // Every shift strictly increases `i`, so `i != start` means slot `i`
        // no longer contains the item and it has to be written there.
        if i != start {
            self.0.set(i, item);
        }
    }
}

/// Returns `true` only when `x.partial_cmp(y)` reports that `x` is strictly less than `y`.
///
/// Equal and incomparable pairs (`partial_cmp` returning `None`) both yield `false`.
fn is_less<T>(x: &T, y: &T) -> bool
where
    T: PartialOrd,
{
    use std::cmp::Ordering;

    matches!(x.partial_cmp(y), Some(Ordering::Less))
}

/// Debug output of a heap is whatever the underlying vector prints.
impl<T, M> fmt::Debug for MinHeap<T, M>
where
    T: Storable + PartialOrd + fmt::Debug,
    M: Memory,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(items) = self;
        fmt::Debug::fmt(items, f)
    }
}