diskann_quantization/alloc/
aligned.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

6use std::ptr::NonNull;
7
8use thiserror::Error;
9
10use super::{AllocatorCore, AllocatorError, GlobalAllocator};
11use crate::num::PowerOfTwo;
12
/// An [`AllocatorCore`] that allocates memory aligned to at least a specified alignment.
///
/// Each request is forwarded to [`GlobalAllocator`] after its alignment has been raised to
/// the larger of the requested alignment and this allocator's alignment. Requests that
/// already ask for a stricter alignment keep it; the requested size is not changed.
///
/// This can be useful for large allocations that need a predictable base alignment.
#[derive(Debug, Clone, Copy)]
pub struct AlignedAllocator {
    /// Base-2 logarithm of the alignment: the alignment in bytes is `1 << alignment`.
    ///
    /// Storing the exponent instead of the alignment itself keeps this type a single byte
    /// while still being able to represent every power of two that fits in a `usize`.
    alignment: u8,
}
21
22impl AlignedAllocator {
23    /// Construct a new allocator that uses the given alignment.
24    #[inline]
25    pub const fn new(alignment: PowerOfTwo) -> Self {
26        Self {
27            // CAST: `trailing_zeros` returns as most 63 (because we've removed 0), so
28            // the conversion is always lossless.
29            alignment: alignment.raw().trailing_zeros() as u8,
30        }
31    }
32
33    #[inline]
34    pub const fn alignment(&self) -> usize {
35        1usize << (self.alignment as usize)
36    }
37}
38
39#[derive(Debug, Clone, Copy, Error)]
40#[error("alignment {0} must be a power of two")]
41pub struct NotPowerOfTwo(usize);
42
43// SAFETY: We are making the alignment potentially stricter before forwarding to the
44// `GlobalAllocator`.
45unsafe impl AllocatorCore for AlignedAllocator {
46    #[inline]
47    fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
48        // Bump up the alignment.
49        let layout = layout
50            .align_to(self.alignment())
51            .map_err(|_| AllocatorError)?;
52        GlobalAllocator.allocate(layout)
53    }
54
55    #[inline]
56    unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
57        // Lint: The given `layout` **should** be the same as that passed to `allocate`,
58        // which must have succeeded for the pointer to be valid in the first place.
59        #[allow(clippy::expect_used)]
60        let layout = layout
61            .align_to(self.alignment())
62            .expect("invalid layout provided");
63        GlobalAllocator.deallocate(ptr, layout)
64    }
65}
66
///////////
// Tests //
///////////
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Deallocates `ptr` on drop so allocations are released even if an assertion panics.
    struct Guard<'a> {
        ptr: NonNull<[u8]>,
        layout: std::alloc::Layout,
        allocator: &'a AlignedAllocator,
    }

    impl Drop for Guard<'_> {
        fn drop(&mut self) {
            // SAFETY: The guard is constructed immediately from a successful
            // `allocator.allocate(layout)` call, and this is the only place the pointer is
            // deallocated, so it is freed exactly once with the layout it was allocated with.
            unsafe { self.allocator.deallocate(self.ptr, self.layout) }
        }
    }

    #[test]
    fn test_aligned_allocator() {
        let powers_of_two: [usize; 14] = [
            1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
        ];
        let trials: usize = 10;
        for power in powers_of_two {
            let alloc = AlignedAllocator::new(PowerOfTwo::new(power).unwrap());
            assert_eq!(alloc.alignment(), power);

            // Request every alignment, both weaker and stricter than the allocator's own.
            // Weaker requests must be bumped up to `power`; stricter ones must be honored.
            for requested in powers_of_two {
                let expected = power.max(requested);
                for size in 1..=trials {
                    let layout = std::alloc::Layout::from_size_align(size, requested).unwrap();
                    let ptr = alloc.allocate(layout).unwrap();

                    // Ensure we deallocate if we panic.
                    let _guard = Guard {
                        ptr,
                        layout,
                        allocator: &alloc,
                    };

                    assert_eq!(ptr.len(), size);
                    assert_eq!(
                        (ptr.cast::<u8>().as_ptr() as usize) % expected,
                        0,
                        "ptr {:?} is not aligned to {} (allocator alignment {}, requested {})",
                        ptr,
                        expected,
                        power,
                        requested,
                    );
                }
            }
        }
    }
}