Skip to main content

diskann_quantization/alloc/
mod.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{alloc::Layout, ptr::NonNull};
7
8mod aligned;
9mod bump;
10mod poly;
11mod traits;
12
13pub use aligned::{AlignedAllocator, NotPowerOfTwo};
14pub use bump::BumpAllocator;
15pub use poly::{CompoundError, Poly, TrustedIter, poly};
16pub use traits::{Allocator, AllocatorCore, AllocatorError};
17
/// Zero-sized handle that routes allocations through Rust's global allocator
/// (`std::alloc::alloc` / `std::alloc::dealloc`).
///
/// Requests for zero bytes are not supported and are reported as an error.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GlobalAllocator;
21
22// SAFETY: This is a simple wrapper around Rust's built-in allocation and deallocation
23// methods, augmented slightly to handle zero sized layouts by returning a dangling pointer.
24//
25// The returned slice from `allocate` always has the exact size and alignment as `layout`.
26unsafe impl AllocatorCore for GlobalAllocator {
27    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocatorError> {
28        if layout.size() == 0 {
29            return Err(AllocatorError);
30        }
31
32        // SAFETY: `layout` has a non-zero size.
33        let ptr = unsafe { std::alloc::alloc(layout) };
34        let ptr = std::ptr::slice_from_raw_parts_mut(ptr, layout.size());
35        NonNull::new(ptr).ok_or(AllocatorError)
36    }
37
38    unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: Layout) {
39        // SAFETY: The caller has the responsibility to ensure that `ptr` and `layout`
40        // came from a previous allocation.
41        unsafe { std::alloc::dealloc(ptr.as_ptr().cast::<u8>(), layout) }
42    }
43}
44
45////////////
46// Scoped //
47////////////
48
49trait DebugAllocator: AllocatorCore + std::fmt::Debug {}
50impl<T> DebugAllocator for T where T: AllocatorCore + std::fmt::Debug {}
51
52/// A dynamic wrapper around an `AllocatorCore` that provides the guarantee that all
53/// allocated object are tied to a given scope.
54///
55/// Additionally, this can allow the use of an allocator that is not `Clone` in contexts
56/// where a clonable allocator is needed (provided the scoping limitations are acceptable).
57#[derive(Debug, Clone, Copy)]
58pub struct ScopedAllocator<'a> {
59    allocator: &'a dyn DebugAllocator,
60}
61
62impl<'a> ScopedAllocator<'a> {
63    /// Construct a new `ScopedAllocator` around the provided `allocator`.
64    pub const fn new<T>(allocator: &'a T) -> Self
65    where
66        T: AllocatorCore + std::fmt::Debug,
67    {
68        Self { allocator }
69    }
70}
71
72impl ScopedAllocator<'static> {
73    /// A convenience method for construcing a `ScopedAllocator` around the [`GlobalAllocator`]
74    /// for cases where a more specialized allocator is not needed.
75    pub const fn global() -> Self {
76        Self {
77            allocator: &GlobalAllocator,
78        }
79    }
80}
81
82// SAFETY: This allocator simply delegates to the underlying allocator.
83unsafe impl AllocatorCore for ScopedAllocator<'_> {
84    fn allocate(&self, layout: std::alloc::Layout) -> Result<NonNull<[u8]>, AllocatorError> {
85        self.allocator.allocate(layout)
86    }
87
88    unsafe fn deallocate(&self, ptr: NonNull<[u8]>, layout: std::alloc::Layout) {
89        // SAFETY: Inherited from caller.
90        unsafe { self.allocator.deallocate(ptr, layout) }
91    }
92}
93
94///////////////
95// Try Clone //
96///////////////
97
98/// A trait like [`Clone`] that allows graceful allocation failure.
99///
100/// # NOTE
101///
102/// Keep this `pub(crate)` for now because we do not want general users of the crate
103/// relying on the current implementations for [`Poly`]. In particular, the base case should
104/// be `Poly<T> where T: TryClone` instead of `Poly<T> where T: Clone`.
105pub(crate) trait TryClone: Sized {
106    /// Returns a duplicate of the value.
107    fn try_clone(&self) -> Result<Self, AllocatorError>;
108}
109
110///////////
111// Tests //
112///////////
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates room for a single `T` from `alloc` and checks that the returned slice has
    /// exactly the requested size and alignment. Then hands the memory back.
    fn test_alloc<T, A>(alloc: &A)
    where
        A: AllocatorCore,
    {
        let layout = Layout::new::<T>();
        let ptr = alloc.allocate(layout).unwrap();

        assert_eq!(ptr.len(), layout.size());
        assert_eq!(ptr.len(), std::mem::size_of::<T>());
        assert_eq!((ptr.as_ptr().cast::<u8>() as usize) % layout.align(), 0);

        // SAFETY: `ptr` was obtained from `alloc` with the specified `layout`.
        unsafe { alloc.deallocate(ptr, layout) };
    }

    /// Checks that `alloc` rejects zero-sized requests. Then runs `test_alloc` over a spread
    /// of sizes and alignments.
    fn test_allocator<A>(alloc: &A)
    where
        A: AllocatorCore,
    {
        assert!(alloc.allocate(Layout::new::<()>()).is_err());

        test_alloc::<(u8,), _>(alloc);
        test_alloc::<(u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8, u8), _>(alloc);
        test_alloc::<(u8, u8, u8, u8, u8, u8, u8, u8, u8), _>(alloc);

        test_alloc::<(u16,), _>(alloc);
        test_alloc::<(u16, u16), _>(alloc);
        test_alloc::<(u16, u16, u16), _>(alloc);
        test_alloc::<(u16, u16, u16, u16), _>(alloc);
        test_alloc::<(u16, u16, u16, u16, u16), _>(alloc);

        test_alloc::<(u32,), _>(alloc);
        test_alloc::<(u32, u32), _>(alloc);
        test_alloc::<(u32, u32, u32), _>(alloc);

        test_alloc::<String, _>(alloc);
    }

    #[test]
    fn test_global_allocator() {
        test_allocator(&GlobalAllocator);
    }

    #[test]
    fn test_scoped_allocator() {
        // The `'static` convenience constructor.
        test_allocator(&ScopedAllocator::global());

        // Wrapping a borrowed allocator.
        let base = GlobalAllocator;
        let scoped = ScopedAllocator::new(&base);
        test_allocator(&scoped);

        // `ScopedAllocator` is `Copy`: both the copy and the original remain usable.
        let copy = scoped;
        test_allocator(&copy);
        test_allocator(&scoped);
    }
}