// diskann_quantization/alloc/bump.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
use std::{
    cell::UnsafeCell,
    ptr::NonNull,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
};

use super::{AlignedAllocator, AllocatorCore, AllocatorError, Poly};
use crate::num::PowerOfTwo;
17
/// An [`AllocatorCore`] that pre-allocates a large buffer of memory and then satisfies
/// allocation requests from that buffer.
///
/// Note that the memory for allocations made through this allocator and its clones will not
/// be freed until all clones have been dropped.
///
/// Memory allocation through this allocator is thread safe.
#[derive(Debug, Clone)]
pub struct BumpAllocator {
    // Shared state: `Clone` is a cheap reference-count bump, and all clones allocate from
    // the same underlying buffer.
    inner: Arc<BumpAllocatorInner>,
}
29
// Shared state behind a `BumpAllocator` and all of its clones.
#[derive(Debug)]
struct BumpAllocatorInner {
    // The backing storage. `UnsafeCell` is needed because `allocate` hands out mutable
    // access to disjoint sub-ranges of this buffer through a shared (`&self`) reference.
    buffer: Poly<UnsafeCell<[u8]>, AlignedAllocator>,
    // Byte offset of the first unallocated position in `buffer`. Advanced atomically by
    // `allocate`; never rewound (deallocation is a no-op).
    head: AtomicUsize,
}
35
// SAFETY: Allocation and deallocation are thread-safe, so `BumpAllocatorInner` can be sent
// between threads.
unsafe impl Send for BumpAllocatorInner {}
39
// SAFETY: Allocation and deallocation are thread-safe, so `BumpAllocatorInner` can be shared
// between threads.
unsafe impl Sync for BumpAllocatorInner {}
43
// A manual impl is required because the `UnsafeCell` field makes this type
// `!RefUnwindSafe` by default.
//
// Interior mutation only occurs in a section of code that should not panic. So even if
// we're unwinding around a `&BumpAllocatorInner`, we should not break invariants.
impl std::panic::RefUnwindSafe for BumpAllocatorInner {}
47
48impl BumpAllocator {
49    /// Construct a new [`BumpAllocator`] with room for `capacity` bytes. The base pointer
50    /// for the allocator will be aligned to at least `alignment` bytes.
51    ///
52    /// Returns an error if `alignment` is not a power of two or if an error occurs during
53    /// memory allocation.
54    pub fn new(capacity: usize, alignment: PowerOfTwo) -> Result<Self, AllocatorError> {
55        let allocator = AlignedAllocator::new(alignment);
56        let buffer = Poly::<[u8], _>::new_uninit_slice(capacity.max(1), allocator)?;
57        let (ptr, alloc) = Poly::into_raw(buffer);
58
59        // SAFETY: The layout for `UnsafeCell<T>` is the same as the layout for `T`, so
60        // casting from `[u8]` to `UnsafeCell<[u8]>` is safe.
61        //
62        // It is safe to cast away `MaybeUninit` because `u8` has not padding and is valid
63        // for all bit patterns.
64        //
65        // Finally, it is safe to construct `NonNull` because `ptr` is already `NonNull`.
66        let buffer = unsafe {
67            Poly::from_raw(
68                NonNull::new_unchecked(ptr.as_ptr() as *mut UnsafeCell<[u8]>),
69                alloc,
70            )
71        };
72
73        Ok(Self {
74            inner: std::sync::Arc::new(BumpAllocatorInner {
75                buffer,
76                head: Default::default(),
77            }),
78        })
79    }
80
81    /// Return the capacity this allocator was created with.
82    pub fn capacity(&self) -> usize {
83        self.inner.buffer.get().len()
84    }
85
86    /// Return a pointer to the base of the buffer behind this allocator.
87    pub fn as_ptr(&self) -> *const u8 {
88        self.inner.buffer.get().cast::<u8>().cast_const()
89    }
90}
91
92/// Given a `base` address and a current `offset` from that base, compute a `new_offset`
93/// such that the range spanned by `[base + offset, base + new_offset)` has sufficient room
94/// to fulfill the allocation request in `layout`.
95fn next(base: usize, offset: usize, layout: std::alloc::Layout) -> Option<usize> {
96    let p = PowerOfTwo::from_align(&layout);
97    p.arg_checked_next_multiple_of(base + offset)
98        .map(|x| x - base)
99        .and_then(|x| x.checked_add(layout.size()))
100}
101
// SAFETY: The implementation of `BumpAllocator` ensures that upon success
//
// 1. Allocations provided from the buffer are properly aligned, regardless of the current
//    state of the `head` pointer.
// 2. The allocation is always of the requested size.
//
// If both of these cannot be satisfied without running off the end of the page, an error
// is returned.
unsafe impl AllocatorCore for BumpAllocator {
    /// Reserve `layout.size()` bytes aligned to `layout.align()` from the shared buffer.
    ///
    /// Returns an error when the remaining space cannot satisfy the request. Space is
    /// never reclaimed before all clones are dropped, so repeated failure is terminal.
    fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
        // Get the base pointer as an integer for alignment calculations.
        let base = self.as_ptr() as usize;

        // Return the new head that ensures we have room for the allocation defined by
        // `layout` starting from the current head.
        let compute_next = |head: usize| -> Result<usize, AllocatorError> {
            let new_head = next(base, head, layout).ok_or(AllocatorError)?;
            if new_head > self.capacity() {
                Err(AllocatorError)
            } else {
                Ok(new_head)
            }
        };

        // Spin until we successfully update the `head` pointer. Successful update indicates
        // that we own the span of memory between `old_head` and `new_head` and can provide
        // that for the allocation.
        //
        // `Relaxed` ordering suffices here: the CAS alone decides ownership of the byte
        // range, and the returned memory is freshly handed out to a single caller.
        let mut old_head = self.inner.head.load(Ordering::Relaxed);
        let mut new_head = compute_next(old_head)?;
        loop {
            match self.inner.head.compare_exchange(
                old_head,
                new_head,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(h) => {
                    // Another thread advanced the head; recompute from its value. This can
                    // bail out with an error if the newly observed head leaves no room.
                    old_head = h;
                    new_head = compute_next(h)?;
                }
            }
        }

        // SAFETY: `old_head` is guaranteed to be within the range of the buffer allocation.
        let ptr = unsafe { self.as_ptr().add(old_head) };

        // SAFETY: The computation of `new_head` ensures that we have space to do this
        // alignment.
        let ptr =
            unsafe { ptr.add(PowerOfTwo::from_align(&layout).arg_align_offset(ptr as usize)) };

        // SAFETY: The computation of `new_head` ensures that we have space to construct
        // a slice of this length after alignment.
        NonNull::new(std::ptr::slice_from_raw_parts_mut(
            ptr.cast_mut(),
            layout.size(),
        ))
        .ok_or(AllocatorError)
    }

    // No work to do in deallocation - dropping the reference count for the bump allocator
    // is sufficient.
    unsafe fn deallocate(&self, _ptr: NonNull<[u8]>, _layout: std::alloc::Layout) {}
}
167
///////////
// Tests //
///////////
171
#[cfg(test)]
mod tests {
    use rand::{
        SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::alloc::Poly;

    ///////////////////
    // BumpAllocator //
    ///////////////////

    /// Allocate several values through clones of one allocator and interleave writes so
    /// that any aliasing between the allocations would be observable.
    #[test]
    fn test_bump_allocator() {
        let allocator = BumpAllocator::new(128, PowerOfTwo::new(1).unwrap()).unwrap();
        let mut a = Poly::new(0usize, allocator.clone()).unwrap();
        let mut b = Poly::new(1usize, allocator.clone()).unwrap();
        let mut c = Poly::new(2usize, allocator.clone()).unwrap();

        *b = 5;
        *a = 10;
        *c = 87;
        *a = 20;

        // Check every allocation, not just `b`: a write through one `Poly` must not
        // clobber its neighbors.
        assert_eq!(*a, 20);
        assert_eq!(*b, 5);
        assert_eq!(*c, 87);
    }

    /// `Poly::new_with` must allocate the outer object before running the closure that
    /// performs nested allocations, so the outer object sits at the buffer base.
    #[test]
    fn poly_new_with_allocates_first() {
        let allocator = BumpAllocator::new(128, PowerOfTwo::new(64).unwrap()).unwrap();

        struct Nested {
            inner: Poly<[usize], BumpAllocator>,
            value: f32,
        }

        let poly = Poly::<Nested, _>::new_with(
            |a| -> Result<_, AllocatorError> {
                Ok(Nested {
                    inner: Poly::from_iter(0..10, a)?,
                    value: 10.0,
                })
            },
            allocator.clone(),
        )
        .unwrap();

        // Ensure that `poly` was initialized properly.
        assert!(poly.inner.iter().enumerate().all(|(i, v)| i == *v));
        assert_eq!(poly.value, 10.0);

        // Ensure that `poly` was allocated before `poly.inner`.
        let base = allocator.as_ptr();
        assert_eq!(base, Poly::as_ptr(&poly).cast::<u8>());
        // NOTE: the 32-byte offset assumes the current layout of `Nested` — revisit if
        // its fields change.
        assert_eq!(
            base.wrapping_add(32),
            Poly::as_ptr(&poly.inner).cast::<u8>()
        );
    }

    /// Allocate single values until the allocator is exhausted, randomly dropping some
    /// along the way to exercise (no-op) deallocation.
    fn values<T: Default>(alloc: BumpAllocator, seed: u64) {
        let mut buf = Vec::new();
        let mut rng = StdRng::seed_from_u64(seed);

        let index_dist = Uniform::new(0, 10).unwrap();

        // Terminates because freed space is never reclaimed, so allocation eventually fails.
        while let Ok(poly) = Poly::new(T::default(), alloc.clone()) {
            buf.push(poly);
            if buf.len() == 10 {
                buf.remove(index_dist.sample(&mut rng));
            }
        }
    }

    /// Like `values`, but allocates randomly-sized slices (including empty ones).
    fn slices<T: Default>(alloc: BumpAllocator, seed: u64) {
        let mut buf = Vec::new();
        let mut rng = StdRng::seed_from_u64(seed);

        let dist = Uniform::new(0, 10).unwrap();

        while let Ok(poly) = Poly::from_iter(
            (0..dist.sample(&mut rng)).map(|_| T::default()),
            alloc.clone(),
        ) {
            buf.push(poly);
            if buf.len() == 10 {
                buf.remove(dist.sample(&mut rng));
            }
        }
    }

    /// Hammer one allocator from four threads with mixed value/slice workloads. Run under
    /// Miri this also checks the unsafe pointer arithmetic in `allocate`.
    fn stress_test_impl() {
        let alloc = BumpAllocator::new(4096, PowerOfTwo::new(1).unwrap()).unwrap();

        let c0 = alloc.clone();
        let c1 = alloc.clone();
        let c2 = alloc.clone();
        let c3 = alloc.clone();
        let handles = [
            std::thread::spawn(move || values::<u8>(c0, 0xa7c0b68e3ece66f7)),
            std::thread::spawn(move || values::<String>(c1, 0x72f0fbcaaefbc884)),
            std::thread::spawn(move || slices::<u16>(c2, 0x447a846ceb3eeda9)),
            std::thread::spawn(move || slices::<String>(c3, 0xd34c7cbedaf165ad)),
        ];

        for h in handles.into_iter() {
            h.join().unwrap();
        }
    }

    #[test]
    fn stress_test() {
        // Miri is slow; keep the trial count small there.
        let trials = if cfg!(miri) { 3 } else { 100 };

        for _ in 0..trials {
            stress_test_impl();
        }
    }
}
293}