diskann_quantization/alloc/
aligned.rs1use std::ptr::NonNull;
7
8use thiserror::Error;
9
10use super::{AllocatorCore, AllocatorError, GlobalAllocator};
11use crate::num::PowerOfTwo;
12
13#[derive(Debug, Clone, Copy)]
17pub struct AlignedAllocator {
18 alignment: u8,
20}
21
22impl AlignedAllocator {
23 #[inline]
25 pub const fn new(alignment: PowerOfTwo) -> Self {
26 Self {
27 alignment: alignment.raw().trailing_zeros() as u8,
30 }
31 }
32
33 #[inline]
34 pub const fn alignment(&self) -> usize {
35 1usize << (self.alignment as usize)
36 }
37}
38
// NOTE(review): nothing in this file constructs this error, and because its tuple field
// is private it cannot be constructed outside this module (and its children) either.
// Confirm whether a fallible constructor was intended, or whether the type can be removed.
/// Error reporting an alignment value that is not a power of two.
///
/// The wrapped `usize` is the rejected value; it is included in the `Display` message.
#[derive(Debug, Clone, Copy, Error)]
#[error("alignment {0} must be a power of two")]
pub struct NotPowerOfTwo(usize);
42
43unsafe impl AllocatorCore for AlignedAllocator {
46 #[inline]
47 fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
48 let layout = layout
50 .align_to(self.alignment())
51 .map_err(|_| AllocatorError)?;
52 GlobalAllocator.allocate(layout)
53 }
54
55 #[inline]
56 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
57 #[allow(clippy::expect_used)]
60 let layout = layout
61 .align_to(self.alignment())
62 .expect("invalid layout provided");
63
64 unsafe { GlobalAllocator.deallocate(ptr, layout) }
68 }
69}
70
#[cfg(test)]
mod tests {
    use std::alloc::Layout;

    use super::*;

    /// Deallocates its allocation on drop, so memory is released even when an
    /// assertion in the test body panics.
    struct Guard<'a> {
        ptr: NonNull<[u8]>,
        // The layout originally passed to `allocate` (before the allocator raised its
        // alignment); `deallocate` expects exactly this layout back.
        layout: Layout,
        allocator: &'a AlignedAllocator,
    }

    impl Drop for Guard<'_> {
        fn drop(&mut self) {
            // SAFETY: `ptr` was returned by `allocator.allocate(layout)` and is
            // deallocated exactly once, here.
            unsafe { self.allocator.deallocate(self.ptr, self.layout) }
        }
    }

    /// Allocate `layout` from `alloc` and check that the returned slice has the requested
    /// size and is aligned to at least `alloc.alignment()`.
    fn check_allocation(alloc: &AlignedAllocator, layout: Layout) {
        let ptr = alloc.allocate(layout).unwrap();
        let _guard = Guard {
            ptr,
            layout,
            allocator: alloc,
        };

        let alignment = alloc.alignment();
        assert_eq!(ptr.len(), layout.size());
        assert_eq!(
            (ptr.cast::<u8>().as_ptr() as usize) % alignment,
            0,
            "ptr {:?} is not aligned to {}",
            ptr,
            alignment
        );
    }

    #[test]
    fn test_aligned_allocator() {
        let powers_of_two = [
            1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
        ];
        let trials = 10;
        for power in powers_of_two {
            let alloc = AlignedAllocator::new(PowerOfTwo::new(power).unwrap());
            assert_eq!(alloc.alignment(), power);

            for size in 1..=trials {
                // A layout that already carries the target alignment.
                check_allocation(&alloc, Layout::from_size_align(size, power).unwrap());
                // A byte-aligned layout: the allocator itself must raise the alignment
                // to `power`. The previous case alone cannot detect a broken adjustment,
                // since the global allocator already honors the layout's own alignment.
                check_allocation(&alloc, Layout::from_size_align(size, 1).unwrap());
            }
        }
    }
}