diskann_quantization/algorithms/
heap.rs1use thiserror::Error;
7
/// A binary max-heap stored in place inside a caller-provided mutable slice.
///
/// The root (index 0) is the largest element. The children of the element at index `i` live at
/// `2 * i + 1` and `2 * i + 2`. Every constructor rejects empty slices, so a `SliceHeap` always
/// holds at least one element.
#[derive(Debug)]
pub struct SliceHeap<'a, T: Ord + Copy> {
    data: &'a mut [T],
}
14
15#[derive(Debug, Error)]
16#[error("heap cannot be constructed from an empty slice")]
17pub struct EmptySlice;
18
19impl<'a, T: Ord + Copy> SliceHeap<'a, T> {
20 pub fn new(data: &'a mut [T]) -> Result<Self, EmptySlice> {
27 if data.is_empty() {
28 return Err(EmptySlice);
29 }
30
31 let mut heap = SliceHeap { data };
32 heap.heapify();
33 Ok(heap)
34 }
35
36 pub fn new_unchecked(data: &'a mut [T]) -> Result<Self, EmptySlice> {
43 if data.is_empty() {
44 return Err(EmptySlice);
45 }
46
47 Ok(SliceHeap { data })
48 }
49
50 pub fn len(&self) -> usize {
52 self.data.len()
53 }
54
55 pub fn is_empty(&self) -> bool {
57 false
58 }
59
60 pub fn peek(&self) -> Option<&T> {
62 self.data.first()
63 }
64
65 pub fn update_root<F>(&mut self, update_fn: F)
70 where
71 F: FnOnce(&mut T),
72 {
73 let root = unsafe { self.data.get_unchecked_mut(0) };
75 update_fn(root);
76 self.sift_down(0);
77 }
78
79 pub fn heapify(&mut self) {
81 if self.data.len() <= 1 {
82 return;
83 }
84
85 let start = (self.data.len() - 2) / 2;
87 for i in (0..=start).rev() {
88 self.sift_down(i);
89 }
90 }
91
92 pub fn as_slice(&self) -> &[T] {
94 self.data
95 }
96
97 unsafe fn get_unchecked(&self, pos: usize) -> &T {
103 debug_assert!(pos < self.len());
104 self.data.get_unchecked(pos)
105 }
106
107 unsafe fn swap_unchecked(&mut self, a: usize, b: usize) {
117 debug_assert!(a < self.len());
118 debug_assert!(b < self.len());
119 debug_assert!(a != b);
120 let base = self.data.as_mut_ptr();
121
122 unsafe { std::ptr::swap_nonoverlapping(base.add(a), base.add(b), 1) }
125 }
126
127 fn sift_down(&mut self, mut pos: usize) {
133 const {
134 assert!(
135 std::mem::size_of::<T>() != 0,
136 "cannot operate on a `SliceHeap` with a zero sized type"
137 )
138 };
139
140 let len = self.len();
141
142 let mut child = 2 * pos + 1;
147
148 while child <= len.saturating_sub(2) {
150 child += unsafe { self.get_unchecked(child) <= self.get_unchecked(child + 1) } as usize;
157
158 if unsafe { self.get_unchecked(pos) >= self.get_unchecked(child) } {
166 return;
167 }
168
169 unsafe { self.swap_unchecked(pos, child) };
174 pos = child;
175 child = 2 * pos + 1;
176 }
177
178 if child == len - 1 && unsafe { self.get_unchecked(pos) < self.get_unchecked(child) } {
181 unsafe { self.swap_unchecked(pos, child) };
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use std::collections::BinaryHeap;

    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;

    #[test]
    fn test_basic_heap_creation() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 8);
        assert!(!heap.is_empty());
        assert_eq!(heap.peek(), Some(&9));
        verify_heap_property(heap.as_slice());
    }

    #[test]
    fn test_update_root() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        // Lowering the root promotes the next-largest element.
        heap.update_root(|x| {
            assert_eq!(*x, 9);
            *x = 5
        });

        assert_eq!(heap.peek(), Some(&6));

        // Raising the root keeps it in place.
        heap.update_root(|x| {
            assert_eq!(*x, 6);
            *x = 10
        });
        assert_eq!(heap.peek(), Some(&10));

        // Writing the same value is a no-op.
        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 10;
        });
        assert_eq!(heap.peek(), Some(&10));

        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 1
        });
        assert_eq!(heap.peek(), Some(&5));
    }

    #[test]
    fn test_empty_heap() {
        let mut data: [i32; 0] = [];
        let result = SliceHeap::new(&mut data);

        assert!(matches!(result, Err(EmptySlice)));

        let result_unchecked = SliceHeap::new_unchecked(&mut data);
        assert!(matches!(result_unchecked, Err(EmptySlice)));
    }

    #[test]
    fn test_single_element() {
        let mut data = [42];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 1);
        assert_eq!(heap.peek(), Some(&42));

        heap.update_root(|x| *x = 100);
        assert_eq!(heap.peek(), Some(&100));

        heap.update_root(|x| *x = 10);
        assert_eq!(heap.peek(), Some(&10));
    }

    #[test]
    fn test_heapify() {
        let mut data = [1, 2, 3, 4, 5];
        let mut heap = SliceHeap::new_unchecked(&mut data).unwrap();

        // `new_unchecked` must not reorder the slice, so the root is still the first element.
        assert_eq!(heap.as_slice(), &[1, 2, 3, 4, 5]);
        assert_eq!(heap.peek(), Some(&1));

        heap.heapify();
        verify_heap_property(heap.as_slice());
        assert_eq!(heap.peek(), Some(&5));

        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&4));

        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&3));
    }

    #[test]
    fn test_heap_property_maintained() {
        let mut data = [10, 8, 9, 4, 7, 5, 3, 2, 1, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        for new_val in (1..10).rev() {
            heap.update_root(|x| *x = new_val);
            verify_heap_property(heap.as_slice());
        }
    }

    /// Drives a `SliceHeap` and a `std::collections::BinaryHeap` with the same sequence of
    /// "replace the maximum" operations and checks that they agree after every step.
    fn fuzz_test_impl(heap_size: usize, num_operations: usize, rng: &mut StdRng) {
        let mut slice_data: Vec<i32> = (0..heap_size)
            .map(|_| rng.random_range(-100..100))
            .collect();

        let mut binary_heap: BinaryHeap<i32> = slice_data.iter().copied().collect();
        let mut slice_heap = SliceHeap::new(&mut slice_data).unwrap();

        verify_heap_property(slice_heap.as_slice());
        assert_eq!(slice_heap.peek().copied(), binary_heap.peek().copied());

        for iteration in 0..num_operations {
            let new_value = rng.random_range(-200..200);

            // `update_root` on the slice heap corresponds to pop + push on `BinaryHeap`.
            let slice_old_max = slice_heap.peek().copied();
            slice_heap.update_root(|x| *x = new_value);
            let slice_new_max = slice_heap.peek().copied();

            let binary_old_max = binary_heap.pop();
            binary_heap.push(new_value);
            let binary_new_max = binary_heap.peek().copied();

            assert_eq!(
                slice_old_max, binary_old_max,
                "Iteration {}: Old maxima differ after updating {} to {}. SliceHeap old max: {:?}, BinaryHeap old max: {:?}",
                iteration, slice_old_max.unwrap_or(0), new_value, slice_old_max, binary_old_max
            );

            assert_eq!(
                slice_new_max, binary_new_max,
                "Iteration {}: Maxima differ after updating {} to {}. SliceHeap max: {:?}, BinaryHeap max: {:?}",
                iteration, slice_old_max.unwrap_or(0), new_value, slice_new_max, binary_new_max
            );

            verify_heap_property(slice_heap.as_slice());

            // Periodically compare the full contents. Both heaps must hold the same
            // multiset of values, so their ascending sorts must be identical.
            if iteration % 100 == 0 {
                let mut slice_elements: Vec<i32> = slice_heap.as_slice().to_vec();
                slice_elements.sort_unstable();
                let binary_elements: Vec<i32> = binary_heap.clone().into_sorted_vec();
                assert_eq!(
                    slice_elements, binary_elements,
                    "Iteration {}: Heap contents differ when sorted",
                    iteration
                );
            }
        }
    }

    #[test]
    fn fuzz_test_against_binary_heap() {
        let mut rng = StdRng::seed_from_u64(0x0d270403030e30bb);

        fuzz_test_impl(1, 101, &mut rng);

        fuzz_test_impl(2, 101, &mut rng);

        fuzz_test_impl(1000, 1000, &mut rng);

        fuzz_test_impl(128, 1000, &mut rng);
    }

    #[test]
    fn fuzz_test_edge_cases() {
        let mut rng = StdRng::seed_from_u64(123);

        for heap_size in 1..=10 {
            let mut data: Vec<i32> = (0..heap_size)
                .map(|_| rng.random_range(-100..100))
                .collect();
            let mut heap = SliceHeap::new(&mut data).unwrap();

            for _ in 0..50 {
                let new_value = rng.random_range(-200..200);
                heap.update_root(|x| *x = new_value);

                verify_heap_property(heap.as_slice());

                let max = heap.peek().unwrap();
                assert!(
                    heap.as_slice().iter().all(|&x| x <= *max),
                    "Max element {} is not actually the maximum in heap: {:?}",
                    max,
                    heap.as_slice()
                );
            }
        }
    }

    /// Asserts that `slice` is a max-heap: each element at index `i` is `>=` its children at
    /// `2 * i + 1` and `2 * i + 2` (where those exist).
    fn verify_heap_property(slice: &[i32]) {
        for i in 0..slice.len() {
            let left = 2 * i + 1;
            let right = 2 * i + 2;

            if left < slice.len() {
                assert!(
                    slice[i] >= slice[left],
                    "Heap property violated: parent {} at index {} < left child {} at index {}. Full heap: {:?}",
                    slice[i], i, slice[left], left, slice
                );
            }

            if right < slice.len() {
                assert!(
                    slice[i] >= slice[right],
                    "Heap property violated: parent {} at index {} < right child {} at index {}. Full heap: {:?}",
                    slice[i], i, slice[right], right, slice
                );
            }
        }
    }
}