Skip to main content

diskann_quantization/alloc/
aligned.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::ptr::NonNull;
7
8use thiserror::Error;
9
10use super::{AllocatorCore, AllocatorError, GlobalAllocator};
11use crate::num::PowerOfTwo;
12
13/// An [`AllocatorCore`] that allocates memory aligned to at least a specified alignment.
14///
15/// This can be useful for large allocations that need a predictable base alignment.
16#[derive(Debug, Clone, Copy)]
17pub struct AlignedAllocator {
18    /// This represents a power of 2.
19    alignment: u8,
20}
21
22impl AlignedAllocator {
23    /// Construct a new allocator that uses the given alignment.
24    #[inline]
25    pub const fn new(alignment: PowerOfTwo) -> Self {
26        Self {
27            // CAST: `trailing_zeros` returns as most 63 (because we've removed 0), so
28            // the conversion is always lossless.
29            alignment: alignment.raw().trailing_zeros() as u8,
30        }
31    }
32
33    #[inline]
34    pub const fn alignment(&self) -> usize {
35        1usize << (self.alignment as usize)
36    }
37}
38
39#[derive(Debug, Clone, Copy, Error)]
40#[error("alignment {0} must be a power of two")]
41pub struct NotPowerOfTwo(usize);
42
43// SAFETY: We are making the alignment potentially stricter before forwarding to the
44// `GlobalAllocator`.
45unsafe impl AllocatorCore for AlignedAllocator {
46    #[inline]
47    fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
48        // Bump up the alignment.
49        let layout = layout
50            .align_to(self.alignment())
51            .map_err(|_| AllocatorError)?;
52        GlobalAllocator.allocate(layout)
53    }
54
55    #[inline]
56    unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
57        // Lint: The given `layout` **should** be the same as that passed to `allocate`,
58        // which must have succeeded for the pointer to be valid in the first place.
59        #[allow(clippy::expect_used)]
60        let layout = layout
61            .align_to(self.alignment())
62            .expect("invalid layout provided");
63
64        // SAFETY: If the caller upheld the safety contract of `deallocate`, then this
65        // pointer is safe to deallocate and the layout is compatible with the layout
66        // created with `allocate`.
67        unsafe { GlobalAllocator.deallocate(ptr, layout) }
68    }
69}
70
///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use super::*;

    /// Releases an allocation when dropped, so memory is freed even if an assertion fails.
    struct Guard<'a> {
        ptr: NonNull<[u8]>,
        layout: std::alloc::Layout,
        allocator: &'a AlignedAllocator,
    }

    impl Drop for Guard<'_> {
        fn drop(&mut self) {
            // SAFETY: A `Guard` is only built from a pointer returned by
            // `allocator.allocate(layout)`, and the pointer is deallocated exactly once,
            // here, with that same layout.
            unsafe { self.allocator.deallocate(self.ptr, self.layout) }
        }
    }

    #[test]
    fn test_aligned_allocator() {
        const TRIALS: usize = 10;
        let powers_of_two: [usize; 14] = [
            1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
        ];

        for power in powers_of_two {
            let alloc = AlignedAllocator::new(PowerOfTwo::new(power).unwrap());
            assert_eq!(alloc.alignment(), power);

            // Pair the allocator with every requested alignment.
            // - Weaker requests must be bumped up to `power`. Deallocating them with the
            //   caller's weaker layout also exercises the adjustment in `deallocate`.
            // - Stricter requests must be honored unchanged.
            for requested in powers_of_two {
                let expected = usize::max(power, requested);
                for size in 1..=TRIALS {
                    let layout = std::alloc::Layout::from_size_align(size, requested).unwrap();
                    let ptr = alloc.allocate(layout).unwrap();

                    // Ensure we deallocate if we panic.
                    let _guard = Guard {
                        ptr,
                        layout,
                        allocator: &alloc,
                    };

                    assert_eq!(ptr.len(), size);
                    assert_eq!(
                        (ptr.cast::<u8>().as_ptr() as usize) % expected,
                        0,
                        "ptr {:?} is not aligned to {} (allocator alignment {}, requested {})",
                        ptr,
                        expected,
                        power,
                        requested
                    );
                }
            }
        }
    }
}
126}