Skip to main content

diskann_quantization/algorithms/
heap.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use thiserror::Error;
7
/// A fixed-size max-heap that operates in place on a non-empty mutable slice.
///
/// The heap borrows its storage, so the number of elements never changes after creation.
/// The only supported mutation is updating the maximum (root) element, after which the
/// heap order is restored.
///
/// Elements use the implicit binary-tree layout: the children of the element at index
/// `i` are stored at indices `2 * i + 1` and `2 * i + 2`, and every parent compares
/// greater than or equal to its children.
#[derive(Debug)]
pub struct SliceHeap<'a, T: Ord + Copy> {
    /// Backing storage. The constructors reject empty slices, so this is never empty.
    data: &'a mut [T],
}
14
15#[derive(Debug, Error)]
16#[error("heap cannot be constructed from an empty slice")]
17pub struct EmptySlice;
18
19impl<'a, T: Ord + Copy> SliceHeap<'a, T> {
20    /// Creates a new `SliceHeap` from a mutable slice.
21    /// The slice is assumed to be unordered initially and will be heapified.
22    ///
23    /// # Errors
24    ///
25    /// Returns `EmptySlice` if the input slice is empty.
26    pub fn new(data: &'a mut [T]) -> Result<Self, EmptySlice> {
27        if data.is_empty() {
28            return Err(EmptySlice);
29        }
30
31        let mut heap = SliceHeap { data };
32        heap.heapify();
33        Ok(heap)
34    }
35
36    /// Creates a new `SliceHeap` from a mutable slice without heapifying.
37    /// Use this if you know the slice is already in heap order.
38    ///
39    /// # Errors
40    ///
41    /// Returns `EmptySlice` if the input slice is empty.
42    pub fn new_unchecked(data: &'a mut [T]) -> Result<Self, EmptySlice> {
43        if data.is_empty() {
44            return Err(EmptySlice);
45        }
46
47        Ok(SliceHeap { data })
48    }
49
50    /// Returns the number of elements in the heap.
51    pub fn len(&self) -> usize {
52        self.data.len()
53    }
54
55    /// Always returns `false` as the heap can never be empty
56    pub fn is_empty(&self) -> bool {
57        false
58    }
59
60    /// Returns a reference to the greatest element in the heap, or `None` if empty.
61    pub fn peek(&self) -> Option<&T> {
62        self.data.first()
63    }
64
65    /// Updates the root element in place and restores the heap property.
66    /// This allows direct mutation of the maximum element.
67    ///
68    /// Since the heap cannot be empty (enforced by construction), this operation always succeeds.
69    pub fn update_root<F>(&mut self, update_fn: F)
70    where
71        F: FnOnce(&mut T),
72    {
73        // SAFETY: The heap is guaranteed to be unempty.
74        let root = unsafe { self.data.get_unchecked_mut(0) };
75        update_fn(root);
76        self.sift_down(0);
77    }
78
79    /// Converts the entire slice into a heap.
80    pub fn heapify(&mut self) {
81        if self.data.len() <= 1 {
82            return;
83        }
84
85        // Start from the last non-leaf node and sift down
86        let start = (self.data.len() - 2) / 2;
87        for i in (0..=start).rev() {
88            self.sift_down(i);
89        }
90    }
91
92    /// Returns a slice of all heap elements in heap order (not sorted order).
93    pub fn as_slice(&self) -> &[T] {
94        self.data
95    }
96
97    /// Get the element as position `pos`.
98    ///
99    /// # Safety
100    ///
101    /// `pos < self.len()` (checked in debug mode).
102    unsafe fn get_unchecked(&self, pos: usize) -> &T {
103        debug_assert!(pos < self.len());
104
105        // SAFETY: Inherited from caller.
106        unsafe { self.data.get_unchecked(pos) }
107    }
108
109    /// Swap the two elements as positions `a` and `b`.
110    ///
111    /// # Safety
112    ///
113    /// All the following must hold (these are checked in debug mode):
114    ///
115    /// 1. `a < self.len()`.
116    /// 2. `b < self.len()`.
117    /// 3. `a != b`.
118    unsafe fn swap_unchecked(&mut self, a: usize, b: usize) {
119        debug_assert!(a < self.len());
120        debug_assert!(b < self.len());
121        debug_assert!(a != b);
122        let base = self.data.as_mut_ptr();
123
124        // SAFETY: The safety requirements of this function imply that the pointer arithmetic
125        // is valid and that the non-overlapping criteria are satisfied.
126        unsafe { std::ptr::swap_nonoverlapping(base.add(a), base.add(b), 1) }
127    }
128
129    /// The implementation of this function is largely copied from `sift_down_range` in
130    /// https://doc.rust-lang.org/src/alloc/collections/binary_heap/mod.rs.html#776.
131    ///
132    /// Since we've constrainted `T: Copy`, we don't need to worry about the `Hole` helper
133    /// data structures.
134    fn sift_down(&mut self, mut pos: usize) {
135        const {
136            assert!(
137                std::mem::size_of::<T>() != 0,
138                "cannot operate on a `SliceHeap` with a zero sized type"
139            )
140        };
141
142        let len = self.len();
143
144        // Since the maximum allocation size is `isize::MAX`, the maximum value that `pos`
145        // can be while satisfying the safety requirements is `isize::MAX`.
146        //
147        // This means that `2 * pos + 1 == usize::MAX` so this operation never overflows.
148        let mut child = 2 * pos + 1;
149
150        // Loop Invariant: child == 2 * pos + 1
151        while child <= len.saturating_sub(2) {
152            // compare with the greater of the two children
153            // SAFETY: We have the following:
154            //  * `child >= 1`: By loop invariant. If we enter this loop, then we're
155            //    guaranteed that `len >= 3`.
156            //  * `child < self.len() - 1` and thus `child + 1 < self.len()` - so both are
157            //    valid indices.
158            child += unsafe { self.get_unchecked(child) <= self.get_unchecked(child + 1) } as usize;
159
160            // If we are already in order, stop.
161            //
162            // SAFETY: `child` is now either the old `child` or the old `child + 1`
163            // We already proven that both are `< self.len()`.
164            //
165            // Furthermore, since `pos < child` (no matter which one is chosen), `pos` is
166            // also in-bounds.
167            if unsafe { self.get_unchecked(pos) >= self.get_unchecked(child) } {
168                return;
169            }
170
171            // SAFETY: We've proven that `pos` and `child` are in-bounds. Since
172            //  * `child = 2 * pos + 1 > pos`.
173            //  * `child + 1 = 2 * pos + 2 > pos`.
174            // we are guaranteed that `pos != child`.
175            unsafe { self.swap_unchecked(pos, child) };
176            pos = child;
177            child = 2 * pos + 1;
178        }
179
180        // SAFETY: We've explicitly checked that `child < self.len()` and from the loop
181        // invariante above, `pos < child`. So both accesses are in-bounds.
182        if child == len - 1 && unsafe { self.get_unchecked(pos) < self.get_unchecked(child) } {
183            // SAFETY: We've proved that `pos` and `child` are in-bounds. From the loop
184            // invariant above, `pos != child`, so the swap is valid.
185            unsafe { self.swap_unchecked(pos, child) };
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use std::collections::BinaryHeap;

    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    /// Construction heapifies the input and exposes the maximum via `peek`.
    #[test]
    fn test_basic_heap_creation() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 8);
        assert!(!heap.is_empty());
        assert_eq!(heap.peek(), Some(&9));
    }

    /// `update_root` hands out the current maximum and re-establishes heap order.
    #[test]
    fn test_update_root() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        // Update max (9) to 5
        heap.update_root(|x| {
            assert_eq!(*x, 9);
            *x = 5
        });

        assert_eq!(heap.peek(), Some(&6));

        // Update max to 10 (should become new max)
        heap.update_root(|x| {
            assert_eq!(*x, 6);
            *x = 10
        });
        assert_eq!(heap.peek(), Some(&10));

        // If we update to the same value, it should remain in place.
        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 10;
        });
        assert_eq!(heap.peek(), Some(&10));

        // Update max to 1 (should sink to bottom)
        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 1
        });
        assert_eq!(heap.peek(), Some(&5));
    }

    /// Both constructors reject an empty slice.
    #[test]
    fn test_empty_heap() {
        let mut data: [i32; 0] = [];
        let result = SliceHeap::new(&mut data);

        assert!(matches!(result, Err(EmptySlice)));

        let result_unchecked = SliceHeap::new_unchecked(&mut data);
        assert!(matches!(result_unchecked, Err(EmptySlice)));
    }

    /// A one-element heap has no children to sift through.
    #[test]
    fn test_single_element() {
        let mut data = [42];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 1);
        assert_eq!(heap.peek(), Some(&42));

        heap.update_root(|x| *x = 100);
        assert_eq!(heap.peek(), Some(&100));

        heap.update_root(|x| *x = 10);
        assert_eq!(heap.peek(), Some(&10));
    }

    /// `heapify` can be invoked manually after `new_unchecked`.
    #[test]
    fn test_heapify() {
        let mut data = [1, 2, 3, 4, 5];
        let mut heap = SliceHeap::new_unchecked(&mut data).unwrap(); // Not heapified

        // Manually heapify
        heap.heapify();

        assert_eq!(heap.peek(), Some(&5));

        // Verify heap property by updating max to minimum and checking order
        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&4));

        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&3));
    }

    /// The heap property holds after every root update.
    #[test]
    fn test_heap_property_maintained() {
        let mut data = [10, 8, 9, 4, 7, 5, 3, 2, 1, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        // Repeatedly replace the max with values from 9 down to 1.
        for new_val in (1..10).rev() {
            heap.update_root(|x| *x = new_val);

            // Verify heap property: parent >= children
            verify_heap_property(heap.as_slice());
        }
    }

    /// Drives a `SliceHeap` and a `BinaryHeap` with the same random root replacements
    /// and checks that they agree after every step.
    fn fuzz_test_impl(heap_size: usize, num_operations: usize, rng: &mut StdRng) {
        // Generate initial data
        let mut slice_data: Vec<i32> = (0..heap_size)
            .map(|_| rng.random_range(-100..100))
            .collect();

        // Create heaps
        let mut binary_heap: BinaryHeap<i32> = slice_data.iter().copied().collect();
        let mut slice_heap = SliceHeap::new(&mut slice_data).unwrap();

        // Verify initial state
        assert_eq!(slice_heap.peek().copied(), binary_heap.peek().copied());

        // Perform random operations
        for iteration in 0..num_operations {
            // Generate a random new value for the maximum element
            let new_value = rng.random_range(-200..200);

            // Update slice heap
            let slice_old_max = slice_heap.peek().copied();
            slice_heap.update_root(|x| *x = new_value);
            let slice_new_max = slice_heap.peek().copied();

            // Update binary heap (remove max, add new value)
            let binary_old_max = binary_heap.pop();
            binary_heap.push(new_value);
            let binary_new_max = binary_heap.peek().copied();

            // Verify they had the same maximum before the update
            assert_eq!(
                slice_old_max,
                binary_old_max,
                "Iteration {}: Old maxima differ after updating {} to {}. SliceHeap old max: {:?}, BinaryHeap old max: {:?}",
                iteration,
                slice_old_max.unwrap_or(0),
                new_value,
                slice_old_max,
                binary_old_max
            );

            // Verify they have the same maximum after the update
            assert_eq!(
                slice_new_max,
                binary_new_max,
                "Iteration {}: Maxima differ after updating {} to {}. SliceHeap max: {:?}, BinaryHeap max: {:?}",
                iteration,
                slice_old_max.unwrap_or(0),
                new_value,
                slice_new_max,
                binary_new_max
            );

            // Verify heap property is maintained in slice heap
            verify_heap_property(slice_heap.as_slice());

            // Occasionally verify that both heaps contain the same elements (when sorted)
            if iteration % 100 == 0 {
                let mut slice_elements: Vec<i32> = slice_heap.as_slice().to_vec();
                slice_elements.sort_unstable();

                // `BinaryHeap::into_sorted_vec()` returns ascending order, matching the
                // sort above, so the two can be compared directly.
                let binary_elements: Vec<i32> = binary_heap.clone().into_sorted_vec();

                assert_eq!(
                    slice_elements, binary_elements,
                    "Iteration {}: Heap contents differ when sorted",
                    iteration
                );
            }
        }
    }

    #[test]
    fn fuzz_test_against_binary_heap() {
        let mut rng = StdRng::seed_from_u64(0x0d270403030e30bb);

        // Heap of size 1.
        fuzz_test_impl(1, 101, &mut rng);

        // Heap of size 2.
        fuzz_test_impl(2, 101, &mut rng);

        // Heap size not power of two.
        fuzz_test_impl(1000, 1000, &mut rng);

        // Heap size power of two.
        fuzz_test_impl(128, 1000, &mut rng);
    }

    #[test]
    fn fuzz_test_edge_cases() {
        let mut rng = StdRng::seed_from_u64(123);

        // Test with small heaps
        for heap_size in 1..=10 {
            let mut data: Vec<i32> = (0..heap_size)
                .map(|_| rng.random_range(-100..100))
                .collect();
            let mut heap = SliceHeap::new(&mut data).unwrap();

            // Perform random updates
            for _ in 0..50 {
                let new_value = rng.random_range(-200..200);
                heap.update_root(|x| *x = new_value);

                // Verify heap property
                verify_heap_property(heap.as_slice());

                // Verify max is actually the maximum
                let max = heap.peek().unwrap();
                assert!(
                    heap.as_slice().iter().all(|&x| x <= *max),
                    "Max element {} is not actually the maximum in heap: {:?}",
                    max,
                    heap.as_slice()
                );
            }
        }
    }

    /// Helper function to verify the heap property holds for a slice: every parent at
    /// index `i` must be `>=` its children at `2 * i + 1` and `2 * i + 2`.
    fn verify_heap_property(slice: &[i32]) {
        for i in 0..slice.len() {
            let left = 2 * i + 1;
            let right = 2 * i + 2;

            if left < slice.len() {
                assert!(
                    slice[i] >= slice[left],
                    "Heap property violated: parent {} at index {} < left child {} at index {}. Full heap: {:?}",
                    slice[i],
                    i,
                    slice[left],
                    left,
                    slice
                );
            }

            if right < slice.len() {
                assert!(
                    slice[i] >= slice[right],
                    "Heap property violated: parent {} at index {} < right child {} at index {}. Full heap: {:?}",
                    slice[i],
                    i,
                    slice[right],
                    right,
                    slice
                );
            }
        }
    }
}