diskann_quantization/alloc/
aligned.rs1use std::ptr::NonNull;
7
8use thiserror::Error;
9
10use super::{AllocatorCore, AllocatorError, GlobalAllocator};
11use crate::num::PowerOfTwo;
12
/// Allocator that forwards to `GlobalAllocator` after raising each requested
/// layout's alignment to at least a fixed power of two (see the
/// `AllocatorCore` impl below).
#[derive(Debug, Clone, Copy)]
pub struct AlignedAllocator {
    /// Base-2 exponent of the alignment in bytes; `alignment()` returns
    /// `1 << alignment`. Storing only the exponent keeps this type one byte.
    alignment: u8,
}
21
22impl AlignedAllocator {
23 #[inline]
25 pub const fn new(alignment: PowerOfTwo) -> Self {
26 Self {
27 alignment: alignment.raw().trailing_zeros() as u8,
30 }
31 }
32
33 #[inline]
34 pub const fn alignment(&self) -> usize {
35 1usize << (self.alignment as usize)
36 }
37}
38
/// Error carrying an alignment value that is not a power of two.
///
/// Its `Display` output is `"alignment {value} must be a power of two"`.
// NOTE(review): this type is not constructed anywhere in the code visible in
// this file; confirm it is still used by other modules.
#[derive(Debug, Clone, Copy, Error)]
#[error("alignment {0} must be a power of two")]
pub struct NotPowerOfTwo(usize);
42
43unsafe impl AllocatorCore for AlignedAllocator {
46 #[inline]
47 fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
48 let layout = layout
50 .align_to(self.alignment())
51 .map_err(|_| AllocatorError)?;
52 GlobalAllocator.allocate(layout)
53 }
54
55 #[inline]
56 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
57 #[allow(clippy::expect_used)]
60 let layout = layout
61 .align_to(self.alignment())
62 .expect("invalid layout provided");
63 GlobalAllocator.deallocate(ptr, layout)
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a block to its allocator when dropped, so every allocation is
    /// released even if a later assertion in the test panics.
    struct Guard<'a> {
        ptr: NonNull<[u8]>,
        layout: std::alloc::Layout,
        allocator: &'a AlignedAllocator,
    }

    impl Drop for Guard<'_> {
        fn drop(&mut self) {
            // SAFETY: `ptr` was returned by `allocator.allocate(layout)` in
            // `check_allocation`, and this guard releases it exactly once.
            unsafe { self.allocator.deallocate(self.ptr, self.layout) }
        }
    }

    /// Allocates `size` bytes from `alloc` using a layout aligned to
    /// `layout_align`. Checks that the block has length `size` and that its
    /// address is a multiple of `expected_align`.
    fn check_allocation(
        alloc: &AlignedAllocator,
        size: usize,
        layout_align: usize,
        expected_align: usize,
    ) {
        let layout = std::alloc::Layout::from_size_align(size, layout_align).unwrap();
        let ptr = alloc.allocate(layout).unwrap();
        let _guard = Guard {
            ptr,
            layout,
            allocator: alloc,
        };

        assert_eq!(ptr.len(), size);
        assert_eq!(
            (ptr.cast::<u8>().as_ptr() as usize) % expected_align,
            0,
            "ptr {:?} is not aligned to {}",
            ptr,
            expected_align
        );
    }

    #[test]
    fn test_aligned_allocator() {
        let powers_of_two = [
            1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
        ];
        let trials = 10;
        for power in powers_of_two {
            let alloc = AlignedAllocator::new(PowerOfTwo::new(power).unwrap());
            assert_eq!(alloc.alignment(), power);

            for size in 1..=trials {
                // Layout already requests the allocator's alignment.
                check_allocation(&alloc, size, power, power);
                // Byte-aligned layout: the allocator itself must raise the
                // alignment to `power`.
                check_allocation(&alloc, size, 1, power);
                // Over-aligned layout: the larger layout alignment must be kept.
                check_allocation(&alloc, size, 2 * power, 2 * power);
            }
        }
    }
}