Skip to main content

diskann_quantization/alloc/
aligned.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::ptr::NonNull;
7
8use thiserror::Error;
9
10use super::{AllocatorCore, AllocatorError, GlobalAllocator};
11use crate::num::PowerOfTwo;
12
/// An [`AllocatorCore`] that allocates memory aligned to at least a specified alignment.
///
/// Requests whose layout already demands a stricter alignment keep that stricter
/// alignment; this allocator only ever raises the alignment of a layout.
///
/// This can be useful for large allocations that need a predictable base alignment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlignedAllocator {
    /// Base-2 logarithm of the alignment: the alignment in bytes is `1 << alignment`.
    ///
    /// Storing the exponent (rather than the byte count) lets it fit in a `u8`.
    alignment: u8,
}
21
22impl AlignedAllocator {
23    /// Aligned allocators for commonly specified boundaries in the codebase (4..4096)
24    pub const A4: Self = Self::new(PowerOfTwo::V4);
25    pub const A8: Self = Self::new(PowerOfTwo::V8);
26    pub const A16: Self = Self::new(PowerOfTwo::V16);
27    pub const A32: Self = Self::new(PowerOfTwo::V32);
28    pub const A64: Self = Self::new(PowerOfTwo::V64);
29    pub const A128: Self = Self::new(PowerOfTwo::V128);
30    pub const A256: Self = Self::new(PowerOfTwo::V256);
31    pub const A512: Self = Self::new(PowerOfTwo::V512);
32    pub const A1024: Self = Self::new(PowerOfTwo::V1024);
33    pub const A2048: Self = Self::new(PowerOfTwo::V2048);
34    pub const A4096: Self = Self::new(PowerOfTwo::V4096);
35
36    /// Construct a new allocator that uses the given alignment.
37    #[inline]
38    pub const fn new(alignment: PowerOfTwo) -> Self {
39        Self {
40            // CAST: `trailing_zeros` returns as most 63 (because we've removed 0), so
41            // the conversion is always lossless.
42            alignment: alignment.raw().trailing_zeros() as u8,
43        }
44    }
45
46    #[inline]
47    pub const fn alignment(&self) -> usize {
48        1usize << (self.alignment as usize)
49    }
50}
51
52#[derive(Debug, Clone, Copy, Error)]
53#[error("alignment {0} must be a power of two")]
54pub struct NotPowerOfTwo(usize);
55
56// SAFETY: We are making the alignment potentially stricter before forwarding to the
57// `GlobalAllocator`.
58unsafe impl AllocatorCore for AlignedAllocator {
59    #[inline]
60    fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
61        // Bump up the alignment.
62        let layout = layout
63            .align_to(self.alignment())
64            .map_err(|_| AllocatorError)?;
65        GlobalAllocator.allocate(layout)
66    }
67
68    #[inline]
69    unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
70        // Lint: The given `layout` **should** be the same as that passed to `allocate`,
71        // which must have succeeded for the pointer to be valid in the first place.
72        #[allow(clippy::expect_used)]
73        let layout = layout
74            .align_to(self.alignment())
75            .expect("invalid layout provided");
76
77        // SAFETY: If the caller upheld the safety contract of `deallocate`, then this
78        // pointer is safe to deallocate and the layout is compatible with the layout
79        // created with `allocate`.
80        unsafe { GlobalAllocator.deallocate(ptr, layout) }
81    }
82}
83
84///////////
85// Tests //
86///////////
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Deallocates its pointer on drop so a failing assertion does not leak memory.
    struct Guard<'a> {
        ptr: NonNull<[u8]>,
        layout: std::alloc::Layout,
        allocator: &'a AlignedAllocator,
    }

    impl Drop for Guard<'_> {
        fn drop(&mut self) {
            // SAFETY: `ptr` was returned by `allocator.allocate(layout)` and is handed to
            // the guard immediately, so it is deallocated exactly once, here.
            unsafe { self.allocator.deallocate(self.ptr, self.layout) }
        }
    }

    /// Allocate `layout` from `allocator` and check the returned length and that the
    /// base pointer is aligned to `expected_align` bytes.
    fn check(allocator: &AlignedAllocator, layout: std::alloc::Layout, expected_align: usize) {
        let ptr = allocator.allocate(layout).unwrap();

        // Ensure we deallocate if we panic.
        let _guard = Guard {
            ptr,
            layout,
            allocator,
        };

        assert_eq!(ptr.len(), layout.size());
        assert_eq!(
            (ptr.cast::<u8>().as_ptr() as usize) % expected_align,
            0,
            "ptr {:?} is not aligned to {}",
            ptr,
            expected_align
        );
    }

    #[test]
    fn test_aligned_allocator() {
        let powers_of_two = [
            1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
        ];
        let trials = 10;
        for power in powers_of_two {
            let alloc = AlignedAllocator::new(PowerOfTwo::new(power).unwrap());
            assert_eq!(alloc.alignment(), power);

            for trial in 1..=trials {
                // Requested alignment already equals the allocator's alignment.
                let exact = std::alloc::Layout::from_size_align(trial, power).unwrap();
                check(&alloc, exact, power);

                // Requested alignment is minimal: the allocator must raise it to `power`.
                let loose = std::alloc::Layout::from_size_align(trial, 1).unwrap();
                check(&alloc, loose, power);

                // Requested alignment is stricter: it must be preserved, not lowered.
                let stricter = power * 2;
                let strict = std::alloc::Layout::from_size_align(trial, stricter).unwrap();
                check(&alloc, strict, stricter);
            }
        }
    }

    #[test]
    fn test_predefined_alignments() {
        let expected = [
            (AlignedAllocator::A4, 4),
            (AlignedAllocator::A8, 8),
            (AlignedAllocator::A16, 16),
            (AlignedAllocator::A32, 32),
            (AlignedAllocator::A64, 64),
            (AlignedAllocator::A128, 128),
            (AlignedAllocator::A256, 256),
            (AlignedAllocator::A512, 512),
            (AlignedAllocator::A1024, 1024),
            (AlignedAllocator::A2048, 2048),
            (AlignedAllocator::A4096, 4096),
        ];
        for (alloc, bytes) in expected {
            assert_eq!(alloc.alignment(), bytes);
        }
    }
}