diskann_quantization/alloc/
mod.rs1use std::{alloc::Layout, ptr::NonNull};
7
8mod aligned;
9mod bump;
10mod poly;
11mod traits;
12
13pub use aligned::{AlignedAllocator, NotPowerOfTwo};
14pub use bump::BumpAllocator;
15pub use poly::{CompoundError, Poly, TrustedIter, poly};
16pub use traits::{Allocator, AllocatorCore, AllocatorError};
17
/// Zero-sized handle that routes allocation requests to Rust's global allocator
/// (`std::alloc::alloc` / `std::alloc::dealloc`).
///
/// Requests for zero-sized layouts are refused with [`AllocatorError`] instead of being
/// passed through to the global allocator.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GlobalAllocator;
21
22unsafe impl AllocatorCore for GlobalAllocator {
27 fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocatorError> {
28 if layout.size() == 0 {
29 return Err(AllocatorError);
30 }
31
32 let ptr = unsafe { std::alloc::alloc(layout) };
34 let ptr = std::ptr::slice_from_raw_parts_mut(ptr, layout.size());
35 NonNull::new(ptr).ok_or(AllocatorError)
36 }
37
38 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: Layout) {
39 unsafe { std::alloc::dealloc(ptr.as_ptr().cast::<u8>(), layout) }
42 }
43}
44
45trait DebugAllocator: AllocatorCore + std::fmt::Debug {}
50impl<T> DebugAllocator for T where T: AllocatorCore + std::fmt::Debug {}
51
52#[derive(Debug, Clone, Copy)]
58pub struct ScopedAllocator<'a> {
59 allocator: &'a dyn DebugAllocator,
60}
61
62impl<'a> ScopedAllocator<'a> {
63 pub const fn new<T>(allocator: &'a T) -> Self
65 where
66 T: AllocatorCore + std::fmt::Debug,
67 {
68 Self { allocator }
69 }
70}
71
72impl ScopedAllocator<'static> {
73 pub const fn global() -> Self {
76 Self {
77 allocator: &GlobalAllocator,
78 }
79 }
80}
81
82unsafe impl AllocatorCore for ScopedAllocator<'_> {
84 fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
85 self.allocator.allocate(layout)
86 }
87
88 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
89 unsafe { self.allocator.deallocate(ptr, layout) }
91 }
92}
93
94pub(crate) trait TryClone: Sized {
106 fn try_clone(&self) -> Result<Self, AllocatorError>;
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates a block shaped like `T` from `alloc` and checks its length and alignment.
    /// It then verifies the whole block is writable and readable, and frees it again.
    fn check_alloc<A, T>(alloc: &A)
    where
        A: AllocatorCore,
    {
        let layout = Layout::new::<T>();
        let ptr = alloc.allocate(layout).unwrap();

        assert_eq!(ptr.len(), layout.size());
        assert_eq!(ptr.len(), std::mem::size_of::<T>());
        assert_eq!((ptr.as_ptr().cast::<u8>() as usize) % layout.align(), 0);

        // Touch every byte so a block shorter than reported would be caught (e.g. under Miri).
        let base = ptr.as_ptr().cast::<u8>();
        // SAFETY: `base` points to `ptr.len()` freshly allocated, unaliased bytes; they are
        // fully initialized by `write_bytes` before being viewed as a slice.
        let bytes = unsafe {
            std::ptr::write_bytes(base, 0xA5, ptr.len());
            std::slice::from_raw_parts(base, ptr.len())
        };
        assert!(bytes.iter().all(|&b| b == 0xA5));

        // SAFETY: `ptr` was returned by `alloc.allocate(layout)` and has not been freed yet.
        unsafe { alloc.deallocate(ptr, layout) };
    }

    /// Runs [`check_alloc`] for `T` against `GlobalAllocator` directly and through both
    /// `ScopedAllocator` constructors.
    fn test_alloc<T>() {
        check_alloc::<_, T>(&GlobalAllocator);
        check_alloc::<_, T>(&ScopedAllocator::global());
        check_alloc::<_, T>(&ScopedAllocator::new(&GlobalAllocator));
    }

    #[test]
    fn test_global_allocator() {
        // Zero-sized layouts are refused rather than forwarded to `std::alloc::alloc`.
        assert!(GlobalAllocator.allocate(Layout::new::<()>()).is_err());

        test_alloc::<(u8,)>();
        test_alloc::<(u8, u8)>();
        test_alloc::<(u8, u8, u8)>();
        test_alloc::<(u8, u8, u8, u8)>();
        test_alloc::<(u8, u8, u8, u8, u8)>();
        test_alloc::<(u8, u8, u8, u8, u8, u8)>();
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8)>();
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8, u8)>();
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8, u8, u8)>();

        test_alloc::<(u16,)>();
        test_alloc::<(u16, u16)>();
        test_alloc::<(u16, u16, u16)>();
        test_alloc::<(u16, u16, u16, u16)>();
        test_alloc::<(u16, u16, u16, u16, u16)>();

        test_alloc::<(u32,)>();
        test_alloc::<(u32, u32)>();
        test_alloc::<(u32, u32, u32)>();

        test_alloc::<String>();
    }

    #[test]
    fn test_scoped_allocator_rejects_zero_sized() {
        let zst = Layout::new::<()>();
        assert!(ScopedAllocator::global().allocate(zst).is_err());
        assert!(ScopedAllocator::new(&GlobalAllocator).allocate(zst).is_err());
    }

    #[test]
    fn test_scoped_allocator_is_const_and_debug() {
        // `global` must be usable in const contexts.
        const GLOBAL: ScopedAllocator<'static> = ScopedAllocator::global();
        let debug = format!("{:?}", GLOBAL);
        assert!(debug.contains("GlobalAllocator"), "{}", debug);
    }
}