diskann_quantization/alloc/
mod.rs1use std::{alloc::Layout, ptr::NonNull};
7
8mod aligned;
9mod bump;
10mod poly;
11mod traits;
12
13pub use aligned::{AlignedAllocator, NotPowerOfTwo};
14pub use bump::BumpAllocator;
15pub use poly::{poly, CompoundError, Poly, TrustedIter};
16pub use traits::{Allocator, AllocatorCore, AllocatorError};
17
/// A stateless allocator that forwards every request to the Rust global allocator
/// (`std::alloc::alloc` / `std::alloc::dealloc`).
///
/// The type carries no data, so all values of it are interchangeable. `Default` and `Hash`
/// are derived alongside the comparison traits so the allocator can satisfy generic
/// `A: Default` bounds and be used as a key in hashed collections.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GlobalAllocator;
21
22unsafe impl AllocatorCore for GlobalAllocator {
27 fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocatorError> {
28 if layout.size() == 0 {
29 return Err(AllocatorError);
30 }
31
32 let ptr = unsafe { std::alloc::alloc(layout) };
34 let ptr = std::ptr::slice_from_raw_parts_mut(ptr, layout.size());
35 NonNull::new(ptr).ok_or(AllocatorError)
36 }
37
38 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: Layout) {
39 unsafe { std::alloc::dealloc(ptr.as_ptr().cast::<u8>(), layout) }
42 }
43}
44
45trait DebugAllocator: AllocatorCore + std::fmt::Debug {}
50impl<T> DebugAllocator for T where T: AllocatorCore + std::fmt::Debug {}
51
52#[derive(Debug, Clone, Copy)]
58pub struct ScopedAllocator<'a> {
59 allocator: &'a dyn DebugAllocator,
60}
61
62impl<'a> ScopedAllocator<'a> {
63 pub const fn new<T>(allocator: &'a T) -> Self
65 where
66 T: AllocatorCore + std::fmt::Debug,
67 {
68 Self { allocator }
69 }
70}
71
72impl ScopedAllocator<'static> {
73 pub const fn global() -> Self {
76 Self {
77 allocator: &GlobalAllocator,
78 }
79 }
80}
81
82unsafe impl AllocatorCore for ScopedAllocator<'_> {
84 fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
85 self.allocator.allocate(layout)
86 }
87
88 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
89 self.allocator.deallocate(ptr, layout)
90 }
91}
92
93pub(crate) trait TryClone: Sized {
105 fn try_clone(&self) -> Result<Self, AllocatorError>;
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Test allocator that counts calls and then delegates to [`GlobalAllocator`], used
    /// to observe that `ScopedAllocator` really forwards to what it wraps.
    #[derive(Debug, Default)]
    struct Counting {
        allocs: std::cell::Cell<usize>,
        frees: std::cell::Cell<usize>,
    }

    // SAFETY: both methods delegate to `GlobalAllocator` with unchanged arguments.
    unsafe impl AllocatorCore for Counting {
        fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocatorError> {
            self.allocs.set(self.allocs.get() + 1);
            GlobalAllocator.allocate(layout)
        }

        unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: Layout) {
            self.frees.set(self.frees.get() + 1);
            // SAFETY: the caller's obligations are forwarded unchanged.
            unsafe { GlobalAllocator.deallocate(ptr, layout) }
        }
    }

    /// Allocates room for one `T` from `alloc`, checks the block's length and alignment,
    /// then releases it.
    fn test_alloc<T, A>(alloc: &A)
    where
        A: AllocatorCore,
    {
        let layout = Layout::new::<T>();
        let ptr = alloc.allocate(layout).unwrap();

        assert_eq!(ptr.len(), layout.size());
        assert_eq!(ptr.len(), std::mem::size_of::<T>());
        assert_eq!((ptr.as_ptr().cast::<u8>() as usize) % layout.align(), 0);

        // SAFETY: `ptr` came from `alloc.allocate(layout)` above and is freed exactly once.
        unsafe { alloc.deallocate(ptr, layout) };
    }

    /// Runs the shared battery of layouts against `alloc`. Every allocator passed here is
    /// backed by `GlobalAllocator`, which rejects zero-sized layouts.
    fn test_all_layouts<A>(alloc: &A)
    where
        A: AllocatorCore,
    {
        assert!(alloc.allocate(Layout::new::<()>()).is_err());

        test_alloc::<(u8,), _>(alloc);
        test_alloc::<(u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8, u8, u8), _>(alloc);

        test_alloc::<(u16,), _>(alloc);
        test_alloc::<(u16, u16), _>(alloc);
        test_alloc::<(u16, u16, u16), _>(alloc);
        test_alloc::<(u16, u16, u16, u16), _>(alloc);
        test_alloc::<(u16, u16, u16, u16, u16), _>(alloc);

        test_alloc::<(u32,), _>(alloc);
        test_alloc::<(u32, u32), _>(alloc);
        test_alloc::<(u32, u32, u32), _>(alloc);

        test_alloc::<String, _>(alloc);
    }

    #[test]
    fn test_global_allocator() {
        test_all_layouts(&GlobalAllocator);
    }

    #[test]
    fn test_scoped_allocator() {
        test_all_layouts(&ScopedAllocator::global());
        test_all_layouts(&ScopedAllocator::new(&GlobalAllocator));

        // `global()` is a `const fn`, so a handle can be built at compile time.
        const GLOBAL: ScopedAllocator<'static> = ScopedAllocator::global();
        test_all_layouts(&GLOBAL);

        assert_eq!(
            format!("{:?}", ScopedAllocator::global()),
            "ScopedAllocator { allocator: GlobalAllocator }"
        );
    }

    #[test]
    fn test_scoped_allocator_forwards() {
        let counting = Counting::default();
        let scoped = ScopedAllocator::new(&counting);
        let layout = Layout::new::<u64>();

        let ptr = scoped.allocate(layout).unwrap();
        assert_eq!((counting.allocs.get(), counting.frees.get()), (1, 0));

        // SAFETY: `ptr` came from `scoped.allocate(layout)` and is freed exactly once.
        unsafe { scoped.deallocate(ptr, layout) };
        assert_eq!((counting.allocs.get(), counting.frees.get()), (1, 1));

        // Copies of the handle still reach the same backing allocator.
        let copy = scoped;
        assert!(copy.allocate(Layout::new::<()>()).is_err());
        assert_eq!(counting.allocs.get(), 2);
    }
}