Skip to main content

diskann_quantization/bits/
ptr.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use std::{marker::PhantomData, ptr::NonNull};
7
8pub use sealed::{AsMutPtr, AsPtr, Precursor};
9
////////////////////
// Pointer Traits //
////////////////////
13
/// A constant pointer with an associated lifetime.
///
/// Semantically this behaves like a `&'a T` to the first element of a slice. Like a shared
/// reference, it can be freely copied regardless of whether `T` itself is `Copy`.
///
/// The only safe way to construct this type is through implementations of the
/// `Precursor` trait (defined in this module's `sealed` submodule and re-exported here)
/// or by reborrowing a bitslice.
#[derive(Debug)]
pub struct SlicePtr<'a, T> {
    ptr: NonNull<T>,
    lifetime: PhantomData<&'a T>,
}

// `Clone` and `Copy` are implemented by hand because `#[derive]` would add `T: Clone` and
// `T: Copy` bounds. Copying a `SlicePtr` only duplicates a shared pointer (exactly like
// copying a `&'a T`), so no bound on `T` is required.
impl<T> Clone for SlicePtr<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SlicePtr<'_, T> {}

impl<T> SlicePtr<'_, T> {
    /// Wrap a raw non-null pointer without any validation.
    ///
    /// # Safety
    ///
    /// It is the caller's responsibility to ensure the correct lifetime is attached to
    /// the underlying pointer and that the constraints of any unsafe traits implemented by
    /// this type are upheld.
    pub(crate) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
        Self {
            ptr,
            lifetime: PhantomData,
        }
    }
}

// SAFETY: `SlicePtr` has shared-reference (`&'a T`) semantics.
// `&T: Send` iff `T: Sync`.
unsafe impl<T: Sync> Send for SlicePtr<'_, T> {}

// SAFETY: `SlicePtr` has shared-reference (`&'a T`) semantics.
// `&T: Sync` iff `T: Sync`.
unsafe impl<T: Sync> Sync for SlicePtr<'_, T> {}
44
/// A mutable pointer with an associated lifetime.
///
/// Semantically this behaves like a `&'a mut T` to the first element of a slice. It is
/// intentionally not `Clone`/`Copy`, since that would duplicate a unique borrow.
///
/// The only safe way to construct this type is through implementations of the
/// `Precursor` trait (defined in this module's `sealed` submodule and re-exported here)
/// or by reborrowing a bitslice.
#[derive(Debug)]
pub struct MutSlicePtr<'a, T> {
    ptr: NonNull<T>,
    lifetime: PhantomData<&'a mut T>,
}

impl<T> MutSlicePtr<'_, T> {
    /// Wrap a raw non-null pointer without any validation.
    ///
    /// # Safety
    ///
    /// It is the caller's responsibility to ensure the correct lifetime is attached to
    /// the underlying pointer and that the constraints of any unsafe traits implemented by
    /// this type are upheld.
    pub(crate) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
        Self {
            ptr,
            lifetime: PhantomData,
        }
    }
}

// SAFETY: `MutSlicePtr` has mutable-reference (`&'a mut T`) semantics.
// `&mut T: Send` iff `T: Send`.
unsafe impl<T: Send> Send for MutSlicePtr<'_, T> {}

// SAFETY: `MutSlicePtr` has mutable-reference (`&'a mut T`) semantics.
// `&mut T: Sync` iff `T: Sync`.
unsafe impl<T: Sync> Sync for MutSlicePtr<'_, T> {}
74
75mod sealed {
76    use std::{marker::PhantomData, ptr::NonNull};
77
78    use super::{MutSlicePtr, SlicePtr};
79    use crate::alloc::{AllocatorCore, Poly};
80
81    /// A precursor for a pointer type that is used as the base of a `BitSlice`.
82    ///
83    /// This trait is *unsafe* because implementing it incorrectly will lead to lifetime
84    /// violations, out-of-bounds accesses, and other undefined behavior.
85    ///
86    /// # Safety
87    ///
88    /// There are two components to a safe implementation of `Precursor`.
89    ///
90    /// ## Memory Safety
91    ///
92    /// It is the implementors responsibility to ensure all the preconditions necessary so
93    /// the following expression is safe:
94    /// ```ignore
95    /// let x: impl Precursor ...
96    /// let len = x.len();
97    /// // The length and derived pointer **must** maintain the following:
98    /// unsafe { std::slice::from_raw_parts(x.precursor_into().as_ptr(), len) }
99    /// ```
100    /// Furthermore, if `Target: AsMutPtr`, then the preconditions for
101    /// `std::slice::from_raw_parts_mut` must be upheld.
102    ///
103    /// ## Lifetime Safety
104    ///
105    /// If is the implementors responsibility that `Target` correctly captures lifetimes
106    /// so the slice obtained from the precursor obeys Rust's rules for references.
107    pub unsafe trait Precursor<Target>
108    where
109        Target: AsPtr,
110    {
111        /// Consume `self` and convert into `Target`.
112        fn precursor_into(self) -> Target;
113
114        /// Return the number of elements of type `<Target as AsPtr>::Type` that are valid
115        /// from the pointer returned by `self.precursor_into().as_ptr()`.
116        fn precursor_len(&self) -> usize;
117    }
118
119    /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
120    /// meaning in particular they will not return null pointers, even with zero size.
121    ///
122    /// This implementation simply breaks the slice into its raw parts.
123    ///
124    /// The `SlicePtr` captures the reference lifetime.
125    unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a [T] {
126        fn precursor_into(self) -> SlicePtr<'a, T> {
127            SlicePtr {
128                // Safety: Slices cannot yield null pointers.
129                //
130                // The `cast_mut()` is safe because we do not provide a way to retrieve a
131                // mutable pointer from `SlicePtr`.
132                ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
133                lifetime: PhantomData,
134            }
135        }
136        fn precursor_len(&self) -> usize {
137            <[T]>::len(self)
138        }
139    }
140
141    /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
142    /// meaning in particular they will not return null pointers, even with zero size.
143    ///
144    /// This implementation simply breaks the slice into its raw parts.
145    ///
146    /// The `SlicePtr` captures the reference lifetime. Decaying a mutable borrow to a
147    /// normal borrow is allowed.
148    unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a mut [T] {
149        fn precursor_into(self) -> SlicePtr<'a, T> {
150            SlicePtr {
151                // Safety: Slices cannot yield null pointers.
152                //
153                // The `cast_mut()` is safe because we do not provide a way to retrieve a
154                // mutable pointer from `SlicePtr`.
155                ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
156                lifetime: PhantomData,
157            }
158        }
159        fn precursor_len(&self) -> usize {
160            self.len()
161        }
162    }
163
164    /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`
165    /// and `std::slice::from_raw_parts_mut` meaning in particular they will not return
166    /// null pointers, even with zero size.
167    ///
168    /// This implementation simply breaks the slice into its raw parts.
169    ///
170    /// The `SlicePtr` captures the mutable reference lifetime.
171    unsafe impl<'a, T> Precursor<MutSlicePtr<'a, T>> for &'a mut [T] {
172        fn precursor_into(self) -> MutSlicePtr<'a, T> {
173            MutSlicePtr {
174                // Safety: Slices cannot yield null pointers.
175                ptr: unsafe { NonNull::new_unchecked(self.as_mut_ptr()) },
176                lifetime: PhantomData,
177            }
178        }
179        fn precursor_len(&self) -> usize {
180            self.len()
181        }
182    }
183
184    /// Safety: This implementation forwards through the [`Poly`], whose implementation of
185    /// `AsPtr` and `AsPtrMut` follow the same logic as those of captured slices.
186    ///
187    /// Since a [`Poly`] owns its contents, the lifetime requirements are satisfied.
188    unsafe impl<T, A> Precursor<Poly<[T], A>> for Poly<[T], A>
189    where
190        A: AllocatorCore,
191    {
192        fn precursor_into(self) -> Poly<[T], A> {
193            self
194        }
195        fn precursor_len(&self) -> usize {
196            self.len()
197        }
198    }
199
200    /// Obtain a constant base pointer for a slice of data.
201    ///
202    /// # Safety
203    ///
204    /// The returned pointer must never be null!
205    pub unsafe trait AsPtr {
206        type Type;
207        fn as_ptr(&self) -> *const Self::Type;
208    }
209
210    /// Obtain a mutable base pointer for a slice of data.
211    ///
212    /// # Safety
213    ///
214    /// The returned pointer must never be null! Furthermore, the mutable pointer must
215    /// originally be derived from a mutable pointer.
216    pub unsafe trait AsMutPtr: AsPtr {
217        fn as_mut_ptr(&mut self) -> *mut Self::Type;
218    }
219
220    /// Safety: SlicePtr may only contain non-null pointers.
221    unsafe impl<T> AsPtr for SlicePtr<'_, T> {
222        type Type = T;
223        fn as_ptr(&self) -> *const T {
224            self.ptr.as_ptr().cast_const()
225        }
226    }
227
228    /// Safety: SlicePtr may only contain non-null pointers.
229    unsafe impl<T> AsPtr for MutSlicePtr<'_, T> {
230        type Type = T;
231        fn as_ptr(&self) -> *const T {
232            // The const-cast is allowed by variance.
233            self.ptr.as_ptr().cast_const()
234        }
235    }
236
237    /// Safety: SlicePtr may only contain non-null pointers. The only way to construct
238    /// a `MutSlicePtr` is from a mutable reference, so the underlying pointer is indeed
239    /// mutable.
240    unsafe impl<T> AsMutPtr for MutSlicePtr<'_, T> {
241        fn as_mut_ptr(&mut self) -> *mut T {
242            self.ptr.as_ptr()
243        }
244    }
245
246    /// Safety: Slices never return a null pointer.
247    unsafe impl<T, A> AsPtr for Poly<[T], A>
248    where
249        A: AllocatorCore,
250    {
251        type Type = T;
252        fn as_ptr(&self) -> *const T {
253            <[T]>::as_ptr(self)
254        }
255    }
256
257    /// Safety: Slices never return a null pointer. A mutable reference to `self` signals
258    /// an exclusive borrow - so the underlying pointer is indeed mutable.
259    unsafe impl<T, A> AsMutPtr for Poly<[T], A>
260    where
261        A: AllocatorCore,
262    {
263        fn as_mut_ptr(&mut self) -> *mut T {
264            <[T]>::as_mut_ptr(&mut *self)
265        }
266    }
267}
268
///////////
// Tests //
///////////
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::{GlobalAllocator, Poly};

    ///////////////////////////////
    // SlicePtr from Const Slice //
    ///////////////////////////////

    /// Converting `&[u8]` into a `SlicePtr` must preserve the base pointer and the length,
    /// and the pair must be usable with `std::slice::from_raw_parts`.
    fn slice_ptr_from_const_slice(base: &[u8]) {
        let ptr = base.as_ptr();
        assert!(!ptr.is_null(), "slices must not return null pointers");

        let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr);
        assert_eq!(base.precursor_len(), base.len());

        // SAFETY: The `Precursor` contract guarantees that the pointer/length pair is valid
        // for `from_raw_parts` while `base` is borrowed. Run under Miri to verify.
        let derived = unsafe {
            std::slice::from_raw_parts(base.precursor_into().as_ptr(), base.precursor_len())
        };
        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), base.len());
    }

    #[test]
    fn test_slice_ptr_from_const_slice() {
        slice_ptr_from_const_slice(&[]);
        slice_ptr_from_const_slice(&[1]);
        slice_ptr_from_const_slice(&[1, 2]);

        for len in 0..10 {
            let base: Vec<u8> = vec![0; len];
            slice_ptr_from_const_slice(&base);
        }
    }

    /////////////////////////////
    // SlicePtr from Mut Slice //
    /////////////////////////////

    /// Decaying `&mut [u8]` into a `SlicePtr` must preserve the base pointer and length.
    fn slice_ptr_from_mut_slice(base: &mut [u8]) {
        let ptr = base.as_ptr();
        let len = base.len();
        assert!(!ptr.is_null(), "slices must not return null pointers");

        let precursor_len = <&mut [u8] as Precursor<SlicePtr<'_, u8>>>::precursor_len(&base);
        assert_eq!(precursor_len, base.len());

        let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr);

        // SAFETY: The `Precursor` contract guarantees that the pointer/length pair is valid
        // for `from_raw_parts`. Run under Miri to verify.
        let derived = unsafe { std::slice::from_raw_parts(slice_ptr.as_ptr(), precursor_len) };

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_slice_ptr_from_mut_slice() {
        slice_ptr_from_mut_slice(&mut []);
        slice_ptr_from_mut_slice(&mut [1]);
        slice_ptr_from_mut_slice(&mut [1, 2]);

        for len in 0..10 {
            let mut base: Vec<u8> = vec![0; len];
            slice_ptr_from_mut_slice(&mut base);
        }
    }

    ////////////////////////////////
    // MutSlicePtr from Mut Slice //
    ////////////////////////////////

    /// Converting `&mut [u8]` into a `MutSlicePtr` must preserve the base pointer and
    /// length, and the pair must be usable with `std::slice::from_raw_parts_mut`.
    fn mut_slice_ptr_from_mut_slice(base: &mut [u8]) {
        let ptr = base.as_mut_ptr();
        let len = base.len();

        assert!(!ptr.is_null(), "slices must not return null pointers");

        // Query the `MutSlicePtr` implementation (the one under test here), not the
        // `SlicePtr` implementation already covered by `slice_ptr_from_mut_slice`.
        let precursor_len = <&mut [u8] as Precursor<MutSlicePtr<'_, u8>>>::precursor_len(&base);
        assert_eq!(precursor_len, base.len());

        let mut slice_ptr: MutSlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr.cast_const());
        assert_eq!(slice_ptr.as_mut_ptr(), ptr);

        // SAFETY: The `Precursor` contract guarantees that the pointer/length pair is valid
        // for `from_raw_parts_mut`. Run under Miri to verify.
        let derived =
            unsafe { std::slice::from_raw_parts_mut(slice_ptr.as_mut_ptr(), precursor_len) };

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_mut_slice_ptr_from_mut_slice() {
        mut_slice_ptr_from_mut_slice(&mut []);
        mut_slice_ptr_from_mut_slice(&mut [1]);
        mut_slice_ptr_from_mut_slice(&mut [1, 2]);

        for len in 0..10 {
            let mut base: Vec<u8> = vec![0; len];
            mut_slice_ptr_from_mut_slice(&mut base);
        }
    }

    //////////
    // Poly //
    //////////

    /// A `Poly<[u8], _>` is its own precursor; pointers and length must pass through.
    fn box_from_box(base: Poly<[u8], GlobalAllocator>) {
        let ptr = base.as_ptr();
        let len = base.len();

        assert!(!ptr.is_null(), "slices must not return null pointers");

        assert_eq!(base.precursor_len(), len);
        let mut derived = base.precursor_into();

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.as_mut_ptr(), ptr.cast_mut());
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_box() {
        box_from_box(Poly::from_iter([].into_iter(), GlobalAllocator).unwrap());
        box_from_box(Poly::from_iter([1].into_iter(), GlobalAllocator).unwrap());
        box_from_box(Poly::from_iter([1, 2].into_iter(), GlobalAllocator).unwrap());

        for len in 0..10 {
            let base = Poly::broadcast(0, len, GlobalAllocator).unwrap();
            box_from_box(base);
        }
    }

    ///////////////////////
    // Auto-Trait Bounds //
    ///////////////////////

    /// The pointer wrappers must be `Send`/`Sync` whenever the corresponding reference
    /// types (`&T` / `&mut T`) are.
    #[test]
    fn test_send_sync() {
        fn assert_send<X: Send>() {}
        fn assert_sync<X: Sync>() {}

        assert_send::<SlicePtr<'static, u8>>();
        assert_sync::<SlicePtr<'static, u8>>();
        assert_send::<MutSlicePtr<'static, u8>>();
        assert_sync::<MutSlicePtr<'static, u8>>();
    }
}