diskann_quantization/algorithms/
heap.rs1use thiserror::Error;
7
/// A binary max-heap stored in a caller-provided, non-empty mutable slice.
///
/// The heap does not own or allocate storage: all operations permute the
/// borrowed elements in place and the slice length never changes. Trait bounds
/// (`T: Ord + Copy`) live on the `impl` blocks that need them rather than on
/// the type itself, so merely naming the type imposes no requirements.
#[derive(Debug)]
pub struct SliceHeap<'a, T> {
    /// Backing storage in implicit binary-tree order: the children of index
    /// `i` live at `2 * i + 1` and `2 * i + 2`.
    data: &'a mut [T],
}
14
15#[derive(Debug, Error)]
16#[error("heap cannot be constructed from an empty slice")]
17pub struct EmptySlice;
18
19impl<'a, T: Ord + Copy> SliceHeap<'a, T> {
20 pub fn new(data: &'a mut [T]) -> Result<Self, EmptySlice> {
27 if data.is_empty() {
28 return Err(EmptySlice);
29 }
30
31 let mut heap = SliceHeap { data };
32 heap.heapify();
33 Ok(heap)
34 }
35
36 pub fn new_unchecked(data: &'a mut [T]) -> Result<Self, EmptySlice> {
43 if data.is_empty() {
44 return Err(EmptySlice);
45 }
46
47 Ok(SliceHeap { data })
48 }
49
50 pub fn len(&self) -> usize {
52 self.data.len()
53 }
54
55 pub fn is_empty(&self) -> bool {
57 false
58 }
59
60 pub fn peek(&self) -> Option<&T> {
62 self.data.first()
63 }
64
65 pub fn update_root<F>(&mut self, update_fn: F)
70 where
71 F: FnOnce(&mut T),
72 {
73 let root = unsafe { self.data.get_unchecked_mut(0) };
75 update_fn(root);
76 self.sift_down(0);
77 }
78
79 pub fn heapify(&mut self) {
81 if self.data.len() <= 1 {
82 return;
83 }
84
85 let start = (self.data.len() - 2) / 2;
87 for i in (0..=start).rev() {
88 self.sift_down(i);
89 }
90 }
91
92 pub fn as_slice(&self) -> &[T] {
94 self.data
95 }
96
97 unsafe fn get_unchecked(&self, pos: usize) -> &T {
103 debug_assert!(pos < self.len());
104
105 unsafe { self.data.get_unchecked(pos) }
107 }
108
109 unsafe fn swap_unchecked(&mut self, a: usize, b: usize) {
119 debug_assert!(a < self.len());
120 debug_assert!(b < self.len());
121 debug_assert!(a != b);
122 let base = self.data.as_mut_ptr();
123
124 unsafe { std::ptr::swap_nonoverlapping(base.add(a), base.add(b), 1) }
127 }
128
129 fn sift_down(&mut self, mut pos: usize) {
135 const {
136 assert!(
137 std::mem::size_of::<T>() != 0,
138 "cannot operate on a `SliceHeap` with a zero sized type"
139 )
140 };
141
142 let len = self.len();
143
144 let mut child = 2 * pos + 1;
149
150 while child <= len.saturating_sub(2) {
152 child += unsafe { self.get_unchecked(child) <= self.get_unchecked(child + 1) } as usize;
159
160 if unsafe { self.get_unchecked(pos) >= self.get_unchecked(child) } {
168 return;
169 }
170
171 unsafe { self.swap_unchecked(pos, child) };
176 pos = child;
177 child = 2 * pos + 1;
178 }
179
180 if child == len - 1 && unsafe { self.get_unchecked(pos) < self.get_unchecked(child) } {
183 unsafe { self.swap_unchecked(pos, child) };
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use std::collections::BinaryHeap;

    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    #[test]
    fn test_basic_heap_creation() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 8);
        assert!(!heap.is_empty());
        assert_eq!(heap.peek(), Some(&9));
        verify_heap_property(heap.as_slice());
    }

    #[test]
    fn test_update_root() {
        let mut data = [3, 1, 4, 1, 5, 9, 2, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        // Lowering the maximum promotes the next-largest element.
        heap.update_root(|x| {
            assert_eq!(*x, 9);
            *x = 5;
        });
        assert_eq!(heap.peek(), Some(&6));

        // Raising the root keeps it at the top.
        heap.update_root(|x| {
            assert_eq!(*x, 6);
            *x = 10;
        });
        assert_eq!(heap.peek(), Some(&10));

        // Writing back the same value leaves the root in place.
        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 10;
        });
        assert_eq!(heap.peek(), Some(&10));

        heap.update_root(|x| {
            assert_eq!(*x, 10);
            *x = 1;
        });
        assert_eq!(heap.peek(), Some(&5));
    }

    #[test]
    fn test_empty_heap() {
        let mut data: [i32; 0] = [];
        let result = SliceHeap::new(&mut data);
        assert!(matches!(result, Err(EmptySlice)));

        let result_unchecked = SliceHeap::new_unchecked(&mut data);
        assert!(matches!(result_unchecked, Err(EmptySlice)));
    }

    #[test]
    fn test_single_element() {
        let mut data = [42];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        assert_eq!(heap.len(), 1);
        assert_eq!(heap.peek(), Some(&42));

        heap.update_root(|x| *x = 100);
        assert_eq!(heap.peek(), Some(&100));

        heap.update_root(|x| *x = 10);
        assert_eq!(heap.peek(), Some(&10));
    }

    #[test]
    fn test_heapify() {
        let mut data = [1, 2, 3, 4, 5];
        // `new_unchecked` takes the slice as-is, so it must be heapified explicitly.
        let mut heap = SliceHeap::new_unchecked(&mut data).unwrap();
        assert_eq!(heap.as_slice(), &[1, 2, 3, 4, 5]);
        heap.heapify();
        verify_heap_property(heap.as_slice());

        assert_eq!(heap.peek(), Some(&5));

        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&4));

        heap.update_root(|x| *x = 0);
        assert_eq!(heap.peek(), Some(&3));
    }

    #[test]
    fn test_heap_property_maintained() {
        let mut data = [10, 8, 9, 4, 7, 5, 3, 2, 1, 6];
        let mut heap = SliceHeap::new(&mut data).unwrap();

        for new_val in (1..10).rev() {
            heap.update_root(|x| *x = new_val);
            verify_heap_property(heap.as_slice());
        }
    }

    /// Drives a `SliceHeap` and a `BinaryHeap` with the same random
    /// replace-the-maximum operations and checks that they stay in agreement.
    fn fuzz_test_impl(heap_size: usize, num_operations: usize, rng: &mut StdRng) {
        let mut slice_data: Vec<i32> = (0..heap_size)
            .map(|_| rng.random_range(-100..100))
            .collect();

        let mut binary_heap: BinaryHeap<i32> = slice_data.iter().copied().collect();
        let mut slice_heap = SliceHeap::new(&mut slice_data).unwrap();

        assert_eq!(slice_heap.peek().copied(), binary_heap.peek().copied());

        for iteration in 0..num_operations {
            let new_value = rng.random_range(-200..200);

            // `update_root` on the slice heap mirrors pop + push on `BinaryHeap`.
            let slice_old_max = slice_heap.peek().copied();
            slice_heap.update_root(|x| *x = new_value);
            let slice_new_max = slice_heap.peek().copied();

            let binary_old_max = binary_heap.pop();
            binary_heap.push(new_value);
            let binary_new_max = binary_heap.peek().copied();

            assert_eq!(
                slice_old_max,
                binary_old_max,
                "Iteration {}: Old maxima differ after updating {} to {}. SliceHeap old max: {:?}, BinaryHeap old max: {:?}",
                iteration,
                slice_old_max.unwrap_or(0),
                new_value,
                slice_old_max,
                binary_old_max
            );

            assert_eq!(
                slice_new_max,
                binary_new_max,
                "Iteration {}: Maxima differ after updating {} to {}. SliceHeap max: {:?}, BinaryHeap max: {:?}",
                iteration,
                slice_old_max.unwrap_or(0),
                new_value,
                slice_new_max,
                binary_new_max
            );

            verify_heap_property(slice_heap.as_slice());

            // Periodically compare the full multiset of elements, not just maxima.
            // `into_sorted_vec` is ascending, matching `sort_unstable`.
            if iteration % 100 == 0 {
                let mut slice_elements = slice_heap.as_slice().to_vec();
                slice_elements.sort_unstable();
                assert_eq!(
                    slice_elements,
                    binary_heap.clone().into_sorted_vec(),
                    "Iteration {}: Heap contents differ when sorted",
                    iteration
                );
            }
        }
    }

    #[test]
    fn fuzz_test_against_binary_heap() {
        let mut rng = StdRng::seed_from_u64(0x0d270403030e30bb);

        fuzz_test_impl(1, 101, &mut rng);
        fuzz_test_impl(2, 101, &mut rng);
        fuzz_test_impl(10, 101, &mut rng);

        // The larger runs are too slow under Miri.
        #[cfg(not(miri))]
        {
            fuzz_test_impl(1000, 1000, &mut rng);
            fuzz_test_impl(128, 1000, &mut rng);
        }
    }

    #[test]
    fn fuzz_test_edge_cases() {
        let mut rng = StdRng::seed_from_u64(123);

        for heap_size in 1..=10 {
            let mut data: Vec<i32> = (0..heap_size)
                .map(|_| rng.random_range(-100..100))
                .collect();
            let mut heap = SliceHeap::new(&mut data).unwrap();

            for _ in 0..50 {
                let new_value = rng.random_range(-200..200);
                heap.update_root(|x| *x = new_value);

                verify_heap_property(heap.as_slice());

                // The root must dominate every element, not just its children.
                let max = *heap.peek().unwrap();
                assert!(
                    heap.as_slice().iter().all(|&x| x <= max),
                    "Max element {} is not actually the maximum in heap: {:?}",
                    max,
                    heap.as_slice()
                );
            }
        }
    }

    /// Asserts the max-heap invariant: every element at index `i > 0` is `<=`
    /// its parent at index `(i - 1) / 2`.
    fn verify_heap_property(slice: &[i32]) {
        for (child, &value) in slice.iter().enumerate().skip(1) {
            let parent = (child - 1) / 2;
            assert!(
                slice[parent] >= value,
                "Heap property violated: parent {} at index {} < child {} at index {}. Full heap: {:?}",
                slice[parent],
                parent,
                value,
                child,
                slice
            );
        }
    }
}