Skip to main content

diskann_quantization/algorithms/
heap.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use thiserror::Error;
7
/// A fixed-size max-heap that operates in place on a non-empty mutable slice.
///
/// The heap borrows its storage; it never allocates and its length never changes after
/// creation. The only supported mutation is updating the root (maximum) element, after
/// which the heap order is restored.
///
/// The `Debug` output shows the backing slice in heap order, not in sorted order.
#[derive(Debug)]
pub struct SliceHeap<'a, T: Ord + Copy> {
    /// Backing storage. Both constructors reject empty slices, so this is never empty.
    data: &'a mut [T],
}
14
15#[derive(Debug, Error)]
16#[error("heap cannot be constructed from an empty slice")]
17pub struct EmptySlice;
18
19impl<'a, T: Ord + Copy> SliceHeap<'a, T> {
20    /// Creates a new `SliceHeap` from a mutable slice.
21    /// The slice is assumed to be unordered initially and will be heapified.
22    ///
23    /// # Errors
24    ///
25    /// Returns `EmptySlice` if the input slice is empty.
26    pub fn new(data: &'a mut [T]) -> Result<Self, EmptySlice> {
27        if data.is_empty() {
28            return Err(EmptySlice);
29        }
30
31        let mut heap = SliceHeap { data };
32        heap.heapify();
33        Ok(heap)
34    }
35
36    /// Creates a new `SliceHeap` from a mutable slice without heapifying.
37    /// Use this if you know the slice is already in heap order.
38    ///
39    /// # Errors
40    ///
41    /// Returns `EmptySlice` if the input slice is empty.
42    pub fn new_unchecked(data: &'a mut [T]) -> Result<Self, EmptySlice> {
43        if data.is_empty() {
44            return Err(EmptySlice);
45        }
46
47        Ok(SliceHeap { data })
48    }
49
50    /// Returns the number of elements in the heap.
51    pub fn len(&self) -> usize {
52        self.data.len()
53    }
54
55    /// Always returns `false` as the heap can never be empty
56    pub fn is_empty(&self) -> bool {
57        false
58    }
59
60    /// Returns a reference to the greatest element in the heap, or `None` if empty.
61    pub fn peek(&self) -> Option<&T> {
62        self.data.first()
63    }
64
65    /// Updates the root element in place and restores the heap property.
66    /// This allows direct mutation of the maximum element.
67    ///
68    /// Since the heap cannot be empty (enforced by construction), this operation always succeeds.
69    pub fn update_root<F>(&mut self, update_fn: F)
70    where
71        F: FnOnce(&mut T),
72    {
73        // SAFETY: The heap is guaranteed to be unempty.
74        let root = unsafe { self.data.get_unchecked_mut(0) };
75        update_fn(root);
76        self.sift_down(0);
77    }
78
79    /// Converts the entire slice into a heap.
80    pub fn heapify(&mut self) {
81        if self.data.len() <= 1 {
82            return;
83        }
84
85        // Start from the last non-leaf node and sift down
86        let start = (self.data.len() - 2) / 2;
87        for i in (0..=start).rev() {
88            self.sift_down(i);
89        }
90    }
91
92    /// Returns a slice of all heap elements in heap order (not sorted order).
93    pub fn as_slice(&self) -> &[T] {
94        self.data
95    }
96
97    /// Get the element as position `pos`.
98    ///
99    /// # Safety
100    ///
101    /// `pos < self.len()` (checked in debug mode).
102    unsafe fn get_unchecked(&self, pos: usize) -> &T {
103        debug_assert!(pos < self.len());
104        self.data.get_unchecked(pos)
105    }
106
107    /// Swap the two elements as positions `a` and `b`.
108    ///
109    /// # Safety
110    ///
111    /// All the following must hold (these are checked in debug mode):
112    ///
113    /// 1. `a < self.len()`.
114    /// 2. `b < self.len()`.
115    /// 3. `a != b`.
116    unsafe fn swap_unchecked(&mut self, a: usize, b: usize) {
117        debug_assert!(a < self.len());
118        debug_assert!(b < self.len());
119        debug_assert!(a != b);
120        let base = self.data.as_mut_ptr();
121
122        // SAFETY: The safety requirements of this function imply that the pointer arithmetic
123        // is valid and that the non-overlapping criteria are satisfied.
124        unsafe { std::ptr::swap_nonoverlapping(base.add(a), base.add(b), 1) }
125    }
126
127    /// The implementation of this function is largely copied from `sift_down_range` in
128    /// https://doc.rust-lang.org/src/alloc/collections/binary_heap/mod.rs.html#776.
129    ///
130    /// Since we've constrainted `T: Copy`, we don't need to worry about the `Hole` helper
131    /// data structures.
132    fn sift_down(&mut self, mut pos: usize) {
133        const {
134            assert!(
135                std::mem::size_of::<T>() != 0,
136                "cannot operate on a `SliceHeap` with a zero sized type"
137            )
138        };
139
140        let len = self.len();
141
142        // Since the maximum allocation size is `isize::MAX`, the maximum value that `pos`
143        // can be while satisfying the safety requirements is `isize::MAX`.
144        //
145        // This means that `2 * pos + 1 == usize::MAX` so this operation never overflows.
146        let mut child = 2 * pos + 1;
147
148        // Loop Invariant: child == 2 * pos + 1
149        while child <= len.saturating_sub(2) {
150            // compare with the greater of the two children
151            // SAFETY: We have the following:
152            //  * `child >= 1`: By loop invariant. If we enter this loop, then we're
153            //    guaranteed that `len >= 3`.
154            //  * `child < self.len() - 1` and thus `child + 1 < self.len()` - so both are
155            //    valid indices.
156            child += unsafe { self.get_unchecked(child) <= self.get_unchecked(child + 1) } as usize;
157
158            // If we are already in order, stop.
159            //
160            // SAFETY: `child` is now either the old `child` or the old `child + 1`
161            // We already proven that both are `< self.len()`.
162            //
163            // Furthermore, since `pos < child` (no matter which one is chosen), `pos` is
164            // also in-bounds.
165            if unsafe { self.get_unchecked(pos) >= self.get_unchecked(child) } {
166                return;
167            }
168
169            // SAFETY: We've proven that `pos` and `child` are in-bounds. Since
170            //  * `child = 2 * pos + 1 > pos`.
171            //  * `child + 1 = 2 * pos + 2 > pos`.
172            // we are guaranteed that `pos != child`.
173            unsafe { self.swap_unchecked(pos, child) };
174            pos = child;
175            child = 2 * pos + 1;
176        }
177
178        // SAFETY: We've explicitly checked that `child < self.len()` and from the loop
179        // invariante above, `pos < child`. So both accesses are in-bounds.
180        if child == len - 1 && unsafe { self.get_unchecked(pos) < self.get_unchecked(child) } {
181            // SAFETY: We've proved that `pos` and `child` are in-bounds. From the loop
182            // invariant above, `pos != child`, so the swap is valid.
183            unsafe { self.swap_unchecked(pos, child) };
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    //! Unit and randomized tests for `SliceHeap`.
    //!
    //! The randomized tests use `std::collections::BinaryHeap` as the reference
    //! implementation: replacing the root of a `SliceHeap` is compared against a
    //! `pop` followed by a `push` on the reference heap.

    use std::collections::BinaryHeap;

    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;

    /// Construction heapifies the slice and exposes the maximum at the root.
    #[test]
    fn test_basic_heap_creation() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 8);
        assert!(!heap.is_empty());
        assert_eq!(heap.peek(), Some(&9));
    }

    /// Replacing the root with smaller, larger, and equal values keeps the maximum correct.
    #[test]
    fn test_update_root() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        // Update max (9) to 5
        heap.update_root(|x| {
            assert_eq!(*x, 9);
            *x = 5
        });

        assert_eq!(heap.peek(), Some(&6));

        // Update max to 10 (should become new max)
        heap.update_root(|x| {
            assert_eq!(*x, 6);
            *x = 10
        });
        assert_eq!(heap.peek(), Some(&10));

        // If we update to the same value, it should remain in place.
        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 10;
        });
        assert_eq!(heap.peek(), Some(&10));

        // Update max to 1 (should sink to bottom)
        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 1
        });
        assert_eq!(heap.peek(), Some(&5));
    }

    /// Both constructors reject an empty slice.
    #[test]
    fn test_empty_heap() {
        let mut data: [i32; 0] = [];
        let result = SliceHeap::new(&mut data);

        assert!(matches!(result, Err(EmptySlice)));

        let result_unchecked = SliceHeap::new_unchecked(&mut data);
        assert!(matches!(result_unchecked, Err(EmptySlice)));
    }

    /// A one-element heap has no children, so any update must stay at the root.
    #[test]
    fn test_single_element() {
        let mut data = [42];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 1);
        assert_eq!(heap.peek(), Some(&42));

        heap.update_root(|x| *x = 100);
        assert_eq!(heap.peek(), Some(&100));

        heap.update_root(|x| *x = 10);
        assert_eq!(heap.peek(), Some(&10));
    }

    /// `heapify` establishes heap order on a heap built with `new_unchecked`.
    #[test]
    fn test_heapify() {
        let mut data = [1, 2, 3, 4, 5];
        let mut heap = SliceHeap::new_unchecked(&mut data).unwrap(); // Not heapified

        // Manually heapify
        heap.heapify();

        assert_eq!(heap.peek(), Some(&5));

        // Verify heap property by updating max to minimum and checking order
        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&4));

        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&3));
    }

    /// The heap property holds after every root update.
    #[test]
    fn test_heap_property_maintained() {
        let mut data = [10, 8, 9, 4, 7, 5, 3, 2, 1, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        // Repeatedly update max with smaller values
        for new_val in (1..10).rev() {
            heap.update_root(|x| *x = new_val);

            // Verify heap property: parent >= children
            verify_heap_property(heap.as_slice());
        }
    }

    /// Runs `num_operations` random root replacements on a heap of `heap_size` random
    /// elements, checking the `SliceHeap` against a reference `BinaryHeap` after each step.
    fn fuzz_test_impl(heap_size: usize, num_operations: usize, rng: &mut StdRng) {
        // Generate initial data
        let mut slice_data: Vec<i32> = (0..heap_size)
            .map(|_| rng.random_range(-100..100))
            .collect();

        // Create heaps
        let mut binary_heap: BinaryHeap<i32> = slice_data.iter().copied().collect();
        let mut slice_heap = SliceHeap::new(&mut slice_data).unwrap();

        // Verify initial state
        assert_eq!(slice_heap.peek().copied(), binary_heap.peek().copied());

        // Perform random operations
        for iteration in 0..num_operations {
            // Generate a random new value for the maximum element
            let new_value = rng.random_range(-200..200);

            // Update slice heap
            let slice_old_max = slice_heap.peek().copied();
            slice_heap.update_root(|x| *x = new_value);
            let slice_new_max = slice_heap.peek().copied();

            // Update binary heap (remove max, add new value)
            let binary_old_max = binary_heap.pop();
            binary_heap.push(new_value);
            let binary_new_max = binary_heap.peek().copied();

            // Verify they have the same maximum
            assert_eq!(
                slice_old_max, binary_old_max,
                "Iteration {}: Old maxima differ after updating {} to {}. SliceHeap old max: {:?}, BinaryHeap old max: {:?}",
                iteration, slice_old_max.unwrap_or(0), new_value, slice_old_max, binary_old_max
            );

            assert_eq!(
                slice_new_max, binary_new_max,
                "Iteration {}: Maxima differ after updating {} to {}. SliceHeap max: {:?}, BinaryHeap max: {:?}",
                iteration, slice_old_max.unwrap_or(0), new_value, slice_new_max, binary_new_max
            );

            // Verify heap property is maintained in slice heap
            verify_heap_property(slice_heap.as_slice());

            // Occasionally verify that both heaps contain the same elements (when sorted)
            if iteration % 100 == 0 {
                let mut slice_elements: Vec<i32> = slice_heap.as_slice().to_vec();
                slice_elements.sort_unstable();
                slice_elements.reverse(); // Sort descending

                let mut binary_elements: Vec<i32> = binary_heap.clone().into_sorted_vec();
                binary_elements.reverse(); // BinaryHeap::into_sorted_vec() returns ascending, we want descending

                assert_eq!(
                    slice_elements, binary_elements,
                    "Iteration {}: Heap contents differ when sorted",
                    iteration
                );
            }
        }
    }

    /// Differential fuzzing against `BinaryHeap` over several heap sizes.
    #[test]
    fn fuzz_test_against_binary_heap() {
        let mut rng = StdRng::seed_from_u64(0x0d270403030e30bb);

        // Heap of size 1.
        fuzz_test_impl(1, 101, &mut rng);

        // Heap of size 2.
        fuzz_test_impl(2, 101, &mut rng);

        // Heap size not power of two.
        fuzz_test_impl(1000, 1000, &mut rng);

        // Heap size power of two.
        fuzz_test_impl(128, 1000, &mut rng);
    }

    /// Random updates on every heap size from 1 to 10.
    #[test]
    fn fuzz_test_edge_cases() {
        let mut rng = StdRng::seed_from_u64(123);

        // Test with small heaps
        for heap_size in 1..=10 {
            let mut data: Vec<i32> = (0..heap_size)
                .map(|_| rng.random_range(-100..100))
                .collect();
            let mut heap = SliceHeap::new(&mut data).unwrap();

            // Perform random updates
            for _ in 0..50 {
                let new_value = rng.random_range(-200..200);
                heap.update_root(|x| *x = new_value);

                // Verify heap property
                verify_heap_property(heap.as_slice());

                // Verify max is actually the maximum
                let max = heap.peek().unwrap();
                assert!(
                    heap.as_slice().iter().all(|&x| x <= *max),
                    "Max element {} is not actually the maximum in heap: {:?}",
                    max,
                    heap.as_slice()
                );
            }
        }
    }

    /// Helper function to verify the heap property holds for a slice.
    ///
    /// Panics with the offending parent/child pair and the full slice if any parent is
    /// smaller than one of its children.
    fn verify_heap_property(slice: &[i32]) {
        for i in 0..slice.len() {
            let left = 2 * i + 1;
            let right = 2 * i + 2;

            if left < slice.len() {
                assert!(
                    slice[i] >= slice[left],
                    "Heap property violated: parent {} at index {} < left child {} at index {}. Full heap: {:?}",
                    slice[i], i, slice[left], left, slice
                );
            }

            if right < slice.len() {
                assert!(
                    slice[i] >= slice[right],
                    "Heap property violated: parent {} at index {} < right child {} at index {}. Full heap: {:?}",
                    slice[i], i, slice[right], right, slice
                );
            }
        }
    }
}