diskann_quantization/alloc/
mod.rs1use std::{alloc::Layout, ptr::NonNull};
7
8mod aligned;
9mod aligned_slice;
10mod bump;
11mod poly;
12mod traits;
13
14pub use aligned::{AlignedAllocator, NotPowerOfTwo};
15pub use aligned_slice::{AlignedSlice, aligned_slice};
16pub use bump::BumpAllocator;
17pub use poly::{CompoundError, Poly, TrustedIter, poly};
18pub use traits::{Allocator, AllocatorCore, AllocatorError};
19
/// Stateless handle that allocates through Rust's global allocator.
///
/// The type has no fields, so it is free to copy and all values compare equal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GlobalAllocator;
23
24unsafe impl AllocatorCore for GlobalAllocator {
29 fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocatorError> {
30 if layout.size() == 0 {
31 return Err(AllocatorError);
32 }
33
34 let ptr = unsafe { std::alloc::alloc(layout) };
36 let ptr = std::ptr::slice_from_raw_parts_mut(ptr, layout.size());
37 NonNull::new(ptr).ok_or(AllocatorError)
38 }
39
40 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: Layout) {
41 unsafe { std::alloc::dealloc(ptr.as_ptr().cast::<u8>(), layout) }
44 }
45}
46
47trait DebugAllocator: AllocatorCore + std::fmt::Debug {}
52impl<T> DebugAllocator for T where T: AllocatorCore + std::fmt::Debug {}
53
54#[derive(Debug, Clone, Copy)]
60pub struct ScopedAllocator<'a> {
61 allocator: &'a dyn DebugAllocator,
62}
63
64impl<'a> ScopedAllocator<'a> {
65 pub const fn new<T>(allocator: &'a T) -> Self
67 where
68 T: AllocatorCore + std::fmt::Debug,
69 {
70 Self { allocator }
71 }
72}
73
74impl ScopedAllocator<'static> {
75 pub const fn global() -> Self {
78 Self {
79 allocator: &GlobalAllocator,
80 }
81 }
82}
83
84unsafe impl AllocatorCore for ScopedAllocator<'_> {
86 fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
87 self.allocator.allocate(layout)
88 }
89
90 unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
91 unsafe { self.allocator.deallocate(ptr, layout) }
93 }
94}
95
96pub(crate) trait TryClone: Sized {
108 fn try_clone(&self) -> Result<Self, AllocatorError>;
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates room for one `T` through `GlobalAllocator`, checks the reported length
    /// and the alignment of the returned address, then frees the block.
    fn test_alloc<T>() {
        let layout = Layout::new::<T>();
        let block = GlobalAllocator.allocate(layout).unwrap();

        let bytes = block.len();
        assert_eq!(bytes, layout.size());
        assert_eq!(bytes, std::mem::size_of::<T>());

        let address = block.cast::<u8>().as_ptr() as usize;
        assert_eq!(address % layout.align(), 0);

        // SAFETY: `block` was just returned by `GlobalAllocator::allocate` for this exact
        // `layout` and has not been freed yet.
        unsafe { GlobalAllocator.deallocate(block, layout) };
    }

    /// Runs `test_alloc` once for every listed type, in the order written.
    macro_rules! check_layouts {
        ($($ty:ty),+ $(,)?) => {
            $( test_alloc::<$ty>(); )+
        };
    }

    #[test]
    fn test_global_allocator() {
        // Zero-sized layouts must be refused.
        let zero_sized = Layout::new::<()>();
        assert!(GlobalAllocator.allocate(zero_sized).is_err());

        // Alignment 1, sizes 1 through 9.
        check_layouts!(
            (u8,),
            (u8, u8),
            (u8, u8, u8),
            (u8, u8, u8, u8),
            (u8, u8, u8, u8, u8),
            (u8, u8, u8, u8, u8, u8),
            (u8, u8, u8, u8, u8, u8, u8),
            (u8, u8, u8, u8, u8, u8, u8, u8),
            (u8, u8, u8, u8, u8, u8, u8, u8, u8),
        );

        // Alignment 2.
        check_layouts!(
            (u16,),
            (u16, u16),
            (u16, u16, u16),
            (u16, u16, u16, u16),
            (u16, u16, u16, u16, u16),
        );

        // Alignment 4.
        check_layouts!((u32,), (u32, u32), (u32, u32, u32));

        // A standard-library type with a non-trivial layout.
        check_layouts!(String);
    }
}