// diskann_quantization/alloc/bump.rs
use std::{
    cell::UnsafeCell,
    ptr::NonNull,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
};

use super::{AlignedAllocator, AllocatorCore, AllocatorError, Poly};
use crate::num::PowerOfTwo;
/// A bump (arena) allocator: hands out memory by atomically advancing a
/// cursor over one fixed-size, aligned buffer. Individual deallocation is a
/// no-op; memory is reclaimed only when all clones are dropped.
///
/// `Clone` is cheap — clones share the same buffer through an `Arc`.
#[derive(Debug, Clone)]
pub struct BumpAllocator {
    inner: Arc<BumpAllocatorInner>,
}
29
/// Shared state behind a `BumpAllocator`, reference-counted via `Arc`.
#[derive(Debug)]
struct BumpAllocatorInner {
    // Backing storage. `UnsafeCell` because allocated regions are written
    // through while the allocator itself is only ever borrowed shared.
    buffer: Poly<UnsafeCell<[u8]>, AlignedAllocator>,
    // Offset (in bytes from the buffer base) of the first unreserved byte.
    // Advanced with a CAS loop in `allocate`; never decreases.
    head: AtomicUsize,
}
35
// SAFETY: the buffer contents are only reached through disjoint regions
// reserved via the atomic `head` cursor, so moving the inner state across
// threads cannot create aliasing of the `UnsafeCell` contents.
// NOTE(review): soundness relies on `allocate` being the only code that
// carves up `buffer` — confirm nothing else touches it directly.
unsafe impl Send for BumpAllocatorInner {}

// SAFETY: concurrent `allocate` calls coordinate exclusively through the
// atomic `head`, so shared access from multiple threads never hands out
// overlapping regions of `buffer`.
unsafe impl Sync for BumpAllocatorInner {}

// A panic cannot leave state observable through a shared reference torn:
// the only mutation is the single atomic bump of `head`.
impl std::panic::RefUnwindSafe for BumpAllocatorInner {}
48impl BumpAllocator {
49 pub fn new(capacity: usize, alignment: PowerOfTwo) -> Result<Self, AllocatorError> {
55 let allocator = AlignedAllocator::new(alignment);
56 let buffer = Poly::<[u8], _>::new_uninit_slice(capacity.max(1), allocator)?;
57 let (ptr, alloc) = Poly::into_raw(buffer);
58
59 let buffer = unsafe {
67 Poly::from_raw(
68 NonNull::new_unchecked(ptr.as_ptr() as *mut UnsafeCell<[u8]>),
69 alloc,
70 )
71 };
72
73 Ok(Self {
74 inner: std::sync::Arc::new(BumpAllocatorInner {
75 buffer,
76 head: Default::default(),
77 }),
78 })
79 }
80
81 pub fn capacity(&self) -> usize {
83 self.inner.buffer.get().len()
84 }
85
86 pub fn as_ptr(&self) -> *const u8 {
88 self.inner.buffer.get().cast::<u8>().cast_const()
89 }
90}
91
92fn next(base: usize, offset: usize, layout: std::alloc::Layout) -> Option<usize> {
96 let p = PowerOfTwo::from_align(&layout);
97 p.arg_checked_next_multiple_of(base + offset)
98 .map(|x| x - base)
99 .and_then(|x| x.checked_add(layout.size()))
100}
101
102unsafe impl AllocatorCore for BumpAllocator {
111 fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
112 let base = self.as_ptr() as usize;
114
115 let compute_next = |head: usize| -> Result<usize, AllocatorError> {
118 let new_head = next(base, head, layout).ok_or(AllocatorError)?;
119 if new_head > self.capacity() {
120 Err(AllocatorError)
121 } else {
122 Ok(new_head)
123 }
124 };
125
126 let mut old_head = self.inner.head.load(Ordering::Relaxed);
130 let mut new_head = compute_next(old_head)?;
131 loop {
132 match self.inner.head.compare_exchange(
133 old_head,
134 new_head,
135 Ordering::Relaxed,
136 Ordering::Relaxed,
137 ) {
138 Ok(_) => break,
139 Err(h) => {
140 old_head = h;
141 new_head = compute_next(h)?;
142 }
143 }
144 }
145
146 let ptr = unsafe { self.as_ptr().add(old_head) };
148
149 let ptr =
152 unsafe { ptr.add(PowerOfTwo::from_align(&layout).arg_align_offset(ptr as usize)) };
153
154 NonNull::new(std::ptr::slice_from_raw_parts_mut(
157 ptr.cast_mut(),
158 layout.size(),
159 ))
160 .ok_or(AllocatorError)
161 }
162
163 unsafe fn deallocate(&self, _ptr: NonNull<[u8]>, _layout: std::alloc::Layout) {}
166}
167
#[cfg(test)]
mod tests {
    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::alloc::Poly;

    /// Smoke test: distinct allocations must not alias — interleaved writes
    /// through `a` and `c` must leave `b` untouched.
    #[test]
    fn test_bump_allocator() {
        let allocator = BumpAllocator::new(128, PowerOfTwo::new(1).unwrap()).unwrap();
        let mut a = Poly::new(0usize, allocator.clone()).unwrap();
        let mut b = Poly::new(1usize, allocator.clone()).unwrap();
        let mut c = Poly::new(2usize, allocator.clone()).unwrap();

        *b = 5;
        *a = 10;
        *c = 87;
        *a = 20;

        assert_eq!(*b, 5);
    }

    /// `Poly::new_with` must reserve the outer value's slot before the
    /// closure runs, so nested allocations land after it in the buffer.
    #[test]
    fn poly_new_with_allocates_first() {
        let allocator = BumpAllocator::new(128, PowerOfTwo::new(64).unwrap()).unwrap();

        struct Nested {
            inner: Poly<[usize], BumpAllocator>,
            value: f32,
        }

        let poly = Poly::<Nested, _>::new_with(
            |a| -> Result<_, AllocatorError> {
                Ok(Nested {
                    inner: Poly::from_iter(0..10, a)?,
                    value: 10.0,
                })
            },
            allocator.clone(),
        )
        .unwrap();

        assert!(poly.inner.iter().enumerate().all(|(i, v)| i == *v));
        assert_eq!(poly.value, 10.0);

        // The outer `Nested` sits at the very start of the buffer...
        let base = allocator.as_ptr();
        assert_eq!(base, Poly::as_ptr(&poly).cast::<u8>());
        // ...and the nested slice begins 32 bytes in. NOTE(review): the 32
        // hard-codes `Nested`'s size/padding — layout-sensitive; confirm if
        // fields change.
        assert_eq!(
            base.wrapping_add(32),
            Poly::as_ptr(&poly.inner).cast::<u8>()
        );
    }

    /// Stress helper: allocate single `T::default()` values until the arena
    /// is exhausted, randomly dropping live ones to exercise concurrent
    /// allocate/drop interleavings (deterministic per `seed`).
    fn values<T: Default>(alloc: BumpAllocator, seed: u64) {
        let mut buf = Vec::new();
        let mut rng = StdRng::seed_from_u64(seed);

        let index_dist = Uniform::new(0, 10).unwrap();

        // Stops at the first failed allocation (arena full).
        while let Ok(poly) = Poly::new(T::default(), alloc.clone()) {
            buf.push(poly);
            if buf.len() == 10 {
                buf.remove(index_dist.sample(&mut rng));
            }
        }
    }

    /// Same as `values`, but allocates random-length slices (0..10
    /// elements) to mix differently sized/aligned requests into the arena.
    fn slices<T: Default>(alloc: BumpAllocator, seed: u64) {
        let mut buf = Vec::new();
        let mut rng = StdRng::seed_from_u64(seed);

        let dist = Uniform::new(0, 10).unwrap();

        while let Ok(poly) = Poly::from_iter(
            (0..dist.sample(&mut rng)).map(|_| T::default()),
            alloc.clone(),
        ) {
            buf.push(poly);
            if buf.len() == 10 {
                buf.remove(dist.sample(&mut rng));
            }
        }
    }

    /// Runs four allocation workloads (values and slices, Copy and
    /// heap-owning element types) concurrently against one shared arena.
    fn stress_test_impl() {
        let alloc = BumpAllocator::new(4096, PowerOfTwo::new(1).unwrap()).unwrap();

        let c0 = alloc.clone();
        let c1 = alloc.clone();
        let c2 = alloc.clone();
        let c3 = alloc.clone();
        let handles = [
            std::thread::spawn(move || values::<u8>(c0, 0xa7c0b68e3ece66f7)),
            std::thread::spawn(move || values::<String>(c1, 0x72f0fbcaaefbc884)),
            std::thread::spawn(move || slices::<u16>(c2, 0x447a846ceb3eeda9)),
            std::thread::spawn(move || slices::<String>(c3, 0xd34c7cbedaf165ad)),
        ];

        for h in handles.into_iter() {
            h.join().unwrap();
        }
    }

    /// Repeated stress runs; trimmed down under Miri since each interpreted
    /// iteration is expensive.
    #[test]
    fn stress_test() {
        let trials = if cfg!(miri) { 3 } else { 100 };

        for _ in 0..trials {
            stress_test_impl();
        }
    }
}