Skip to main content

diskann_quantization/bits/
ptr.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{marker::PhantomData, ptr::NonNull};
7
8pub use sealed::{AsMutPtr, AsPtr, Precursor};
9
10////////////////////
11// Pointer Traits //
12////////////////////
13
/// A constant pointer with an associated lifetime.
///
/// The `PhantomData<&'a T>` marker ties the pointer to a shared borrow of lifetime `'a`,
/// so a `SlicePtr` cannot outlive the data it was derived from.
///
/// The only safe way to construct this type is through implementations
/// of the [`Precursor`] trait or by reborrowing a bitslice.
///
/// `Clone` and `Copy` are implemented by hand rather than derived: a derive would add
/// `T: Clone` / `T: Copy` bounds, but copying the pointer never copies a `T` (just as
/// `&T` is `Copy` for every `T`).
#[derive(Debug)]
pub struct SlicePtr<'a, T> {
    ptr: NonNull<T>,
    lifetime: PhantomData<&'a T>,
}

impl<T> Clone for SlicePtr<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SlicePtr<'_, T> {}
23
24impl<T> SlicePtr<'_, T> {
25    /// # Safety
26    ///
27    /// It's the callers responsibility to ensure the correct lifetime is attached to
28    /// the underlying pointer and that the constraints of any unsafe traits implemented by
29    /// this type are upheld.
30    pub(super) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
31        Self {
32            ptr,
33            lifetime: PhantomData,
34        }
35    }
36}
37
38// SAFETY: Slices are `Send` when `T: Send`.
39unsafe impl<T: Send> Send for SlicePtr<'_, T> {}
40
41// SAFETY: Slices are `Sync` when `T: Sync`.
42unsafe impl<T: Sync> Sync for SlicePtr<'_, T> {}
43
/// A mutable pointer with an associated lifetime.
///
/// The `PhantomData<&'a mut T>` marker ties the pointer to an exclusive borrow of
/// lifetime `'a`. Unlike `SlicePtr`, this type is intentionally neither `Clone` nor
/// `Copy`, so the exclusive access it represents cannot be duplicated.
///
/// The only safe way to construct this type is through implementations
/// of the [`Precursor`] trait or by reborrowing a bitslice.
#[derive(Debug)]
pub struct MutSlicePtr<'a, T> {
    ptr: NonNull<T>,
    lifetime: PhantomData<&'a mut T>,
}
53
54impl<T> MutSlicePtr<'_, T> {
55    /// # Safety
56    ///
57    /// It's the callers responsibility to ensure the correct lifetime is attached to
58    /// the underlying pointer and that the constraints of any unsafe traits implemented by
59    /// this type are upheld.
60    pub(super) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
61        Self {
62            ptr,
63            lifetime: PhantomData,
64        }
65    }
66}
67
68// SAFETY: Mutable slices are `Send` when `T: Send`.
69unsafe impl<T: Send> Send for MutSlicePtr<'_, T> {}
70
71// SAFETY: Mutable slices are `Sync` when `T: Sync`.
72unsafe impl<T: Sync> Sync for MutSlicePtr<'_, T> {}
73
74mod sealed {
75    use std::{marker::PhantomData, ptr::NonNull};
76
77    use super::{MutSlicePtr, SlicePtr};
78    use crate::alloc::{AllocatorCore, Poly};
79
80    /// A precursor for a pointer type that is used as the base of a `BitSlice`.
81    ///
82    /// This trait is *unsafe* because implementing it incorrectly will lead to lifetime
83    /// violations, out-of-bounds accesses, and other undefined behavior.
84    ///
85    /// # Safety
86    ///
87    /// There are two components to a safe implementation of `Precursor`.
88    ///
89    /// ## Memory Safety
90    ///
91    /// It is the implementors responsibility to ensure all the preconditions necessary so
92    /// the following expression is safe:
93    /// ```ignore
94    /// let x: impl Precursor ...
95    /// let len = x.len();
96    /// // The length and derived pointer **must** maintain the following:
97    /// unsafe { std::slice::from_raw_parts(x.precursor_into().as_ptr(), len) }
98    /// ```
99    /// Furthermore, if `Target: AsMutPtr`, then the preconditions for
100    /// `std::slice::from_raw_parts_mut` must be upheld.
101    ///
102    /// ## Lifetime Safety
103    ///
104    /// If is the implementors responsibility that `Target` correctly captures lifetimes
105    /// so the slice obtained from the precursor obeys Rust's rules for references.
106    pub unsafe trait Precursor<Target>
107    where
108        Target: AsPtr,
109    {
110        /// Consume `self` and convert into `Target`.
111        fn precursor_into(self) -> Target;
112
113        /// Return the number of elements of type `<Target as AsPtr>::Type` that are valid
114        /// from the pointer returned by `self.precursor_into().as_ptr()`.
115        fn precursor_len(&self) -> usize;
116    }
117
118    /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
119    /// meaning in particular they will not return null pointers, even with zero size.
120    ///
121    /// This implementation simply breaks the slice into its raw parts.
122    ///
123    /// The `SlicePtr` captures the reference lifetime.
124    unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a [T] {
125        fn precursor_into(self) -> SlicePtr<'a, T> {
126            SlicePtr {
127                // Safety: Slices cannot yield null pointers.
128                //
129                // The `cast_mut()` is safe because we do not provide a way to retrieve a
130                // mutable pointer from `SlicePtr`.
131                ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
132                lifetime: PhantomData,
133            }
134        }
135        fn precursor_len(&self) -> usize {
136            <[T]>::len(self)
137        }
138    }
139
140    /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
141    /// meaning in particular they will not return null pointers, even with zero size.
142    ///
143    /// This implementation simply breaks the slice into its raw parts.
144    ///
145    /// The `SlicePtr` captures the reference lifetime. Decaying a mutable borrow to a
146    /// normal borrow is allowed.
147    unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a mut [T] {
148        fn precursor_into(self) -> SlicePtr<'a, T> {
149            SlicePtr {
150                // Safety: Slices cannot yield null pointers.
151                //
152                // The `cast_mut()` is safe because we do not provide a way to retrieve a
153                // mutable pointer from `SlicePtr`.
154                ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
155                lifetime: PhantomData,
156            }
157        }
158        fn precursor_len(&self) -> usize {
159            self.len()
160        }
161    }
162
163    /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`
164    /// and `std::slice::from_raw_parts_mut` meaning in particular they will not return
165    /// null pointers, even with zero size.
166    ///
167    /// This implementation simply breaks the slice into its raw parts.
168    ///
169    /// The `SlicePtr` captures the mutable reference lifetime.
170    unsafe impl<'a, T> Precursor<MutSlicePtr<'a, T>> for &'a mut [T] {
171        fn precursor_into(self) -> MutSlicePtr<'a, T> {
172            MutSlicePtr {
173                // Safety: Slices cannot yield null pointers.
174                ptr: unsafe { NonNull::new_unchecked(self.as_mut_ptr()) },
175                lifetime: PhantomData,
176            }
177        }
178        fn precursor_len(&self) -> usize {
179            self.len()
180        }
181    }
182
183    /// Safety: This implementation forwards through the [`Poly`], whose implementation of
184    /// `AsPtr` and `AsPtrMut` follow the same logic as those of captured slices.
185    ///
186    /// Since a [`Poly`] owns its contents, the lifetime requirements are satisfied.
187    unsafe impl<T, A> Precursor<Poly<[T], A>> for Poly<[T], A>
188    where
189        A: AllocatorCore,
190    {
191        fn precursor_into(self) -> Poly<[T], A> {
192            self
193        }
194        fn precursor_len(&self) -> usize {
195            self.len()
196        }
197    }
198
199    /// Obtain a constant base pointer for a slice of data.
200    ///
201    /// # Safety
202    ///
203    /// The returned pointer must never be null!
204    pub unsafe trait AsPtr {
205        type Type;
206        fn as_ptr(&self) -> *const Self::Type;
207    }
208
209    /// Obtain a mutable base pointer for a slice of data.
210    ///
211    /// # Safety
212    ///
213    /// The returned pointer must never be null! Furthermore, the mutable pointer must
214    /// originally be derived from a mutable pointer.
215    pub unsafe trait AsMutPtr: AsPtr {
216        fn as_mut_ptr(&mut self) -> *mut Self::Type;
217    }
218
219    /// Safety: SlicePtr may only contain non-null pointers.
220    unsafe impl<T> AsPtr for SlicePtr<'_, T> {
221        type Type = T;
222        fn as_ptr(&self) -> *const T {
223            self.ptr.as_ptr().cast_const()
224        }
225    }
226
227    /// Safety: SlicePtr may only contain non-null pointers.
228    unsafe impl<T> AsPtr for MutSlicePtr<'_, T> {
229        type Type = T;
230        fn as_ptr(&self) -> *const T {
231            // The const-cast is allowed by variance.
232            self.ptr.as_ptr().cast_const()
233        }
234    }
235
236    /// Safety: SlicePtr may only contain non-null pointers. The only way to construct
237    /// a `MutSlicePtr` is from a mutable reference, so the underlying pointer is indeed
238    /// mutable.
239    unsafe impl<T> AsMutPtr for MutSlicePtr<'_, T> {
240        fn as_mut_ptr(&mut self) -> *mut T {
241            self.ptr.as_ptr()
242        }
243    }
244
245    /// Safety: Slices never return a null pointer.
246    unsafe impl<T, A> AsPtr for Poly<[T], A>
247    where
248        A: AllocatorCore,
249    {
250        type Type = T;
251        fn as_ptr(&self) -> *const T {
252            <[T]>::as_ptr(self)
253        }
254    }
255
256    /// Safety: Slices never return a null pointer. A mutable reference to `self` signals
257    /// an exclusive borrow - so the underlying pointer is indeed mutable.
258    unsafe impl<T, A> AsMutPtr for Poly<[T], A>
259    where
260        A: AllocatorCore,
261    {
262        fn as_mut_ptr(&mut self) -> *mut T {
263            <[T]>::as_mut_ptr(&mut *self)
264        }
265    }
266}
267
268///////////
269// Tests //
270///////////
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::{GlobalAllocator, Poly};

    ///////////////////////////////
    // SlicePtr from Const Slice //
    ///////////////////////////////

    /// Check that a shared slice converts into a `SlicePtr` with the same base pointer and
    /// that `precursor_len` reports the slice length.
    fn slice_ptr_from_const_slice(base: &[u8]) {
        let ptr = base.as_ptr();
        assert!(!ptr.is_null(), "slices must not return null pointers");

        let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr);
        assert_eq!(base.precursor_len(), base.len());

        // Safety: Check with Miri.
        let derived = unsafe {
            std::slice::from_raw_parts(base.precursor_into().as_ptr(), base.precursor_len())
        };
        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), base.len());
    }

    #[test]
    fn test_slice_ptr_from_const_slice() {
        slice_ptr_from_const_slice(&[]);
        slice_ptr_from_const_slice(&[1]);
        slice_ptr_from_const_slice(&[1, 2]);

        for len in 0..10 {
            let base: Vec<u8> = vec![0; len];
            slice_ptr_from_const_slice(&base);
        }
    }

    /////////////////////////////
    // SlicePtr from Mut Slice //
    /////////////////////////////

    /// Check that a mutable slice decays into a `SlicePtr` with the same base pointer and
    /// that the `SlicePtr` precursor reports the slice length.
    fn slice_ptr_from_mut_slice(base: &mut [u8]) {
        let ptr = base.as_ptr();
        let len = base.len();
        assert!(!ptr.is_null(), "slices must not return null pointers");

        let precursor_len = <&mut [u8] as Precursor<SlicePtr<'_, u8>>>::precursor_len(&base);
        assert_eq!(precursor_len, base.len());

        let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr);

        // Safety: Check with Miri.
        let derived = unsafe { std::slice::from_raw_parts(slice_ptr.as_ptr(), precursor_len) };

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_slice_ptr_from_mut_slice() {
        slice_ptr_from_mut_slice(&mut []);
        slice_ptr_from_mut_slice(&mut [1]);
        slice_ptr_from_mut_slice(&mut [1, 2]);

        for len in 0..10 {
            let mut base: Vec<u8> = vec![0; len];
            slice_ptr_from_mut_slice(&mut base);
        }
    }

    ////////////////////////////////
    // MutSlicePtr from Mut Slice //
    ////////////////////////////////

    /// Check that a mutable slice converts into a `MutSlicePtr` whose const and mutable
    /// base pointers match the slice, and that the `MutSlicePtr` precursor reports the
    /// slice length.
    fn mut_slice_ptr_from_mut_slice(base: &mut [u8]) {
        let ptr = base.as_mut_ptr();
        let len = base.len();

        assert!(!ptr.is_null(), "slices must not return null pointers");

        // Query the `MutSlicePtr` implementation, since that is the conversion under test.
        let precursor_len = <&mut [u8] as Precursor<MutSlicePtr<'_, u8>>>::precursor_len(&base);
        assert_eq!(precursor_len, base.len());

        let mut slice_ptr: MutSlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr.cast_const());
        assert_eq!(slice_ptr.as_mut_ptr(), ptr);

        // Safety: Check with Miri.
        let derived =
            unsafe { std::slice::from_raw_parts_mut(slice_ptr.as_mut_ptr(), precursor_len) };

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_mut_slice_ptr_from_mut_slice() {
        mut_slice_ptr_from_mut_slice(&mut []);
        mut_slice_ptr_from_mut_slice(&mut [1]);
        mut_slice_ptr_from_mut_slice(&mut [1, 2]);

        for len in 0..10 {
            let mut base: Vec<u8> = vec![0; len];
            mut_slice_ptr_from_mut_slice(&mut base);
        }
    }

    /////////
    // Box //
    /////////

    /// Check that an owned `Poly` slice is its own precursor and exposes its buffer.
    fn box_from_box(base: Poly<[u8], GlobalAllocator>) {
        let ptr = base.as_ptr();
        let len = base.len();

        assert!(!ptr.is_null(), "slices must not return null pointers");

        assert_eq!(base.precursor_len(), len);
        let mut derived = base.precursor_into();

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.as_mut_ptr(), ptr.cast_mut());
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_box() {
        box_from_box(Poly::from_iter([].into_iter(), GlobalAllocator).unwrap());
        box_from_box(Poly::from_iter([1].into_iter(), GlobalAllocator).unwrap());
        box_from_box(Poly::from_iter([1, 2].into_iter(), GlobalAllocator).unwrap());

        for len in 0..10 {
            let base = Poly::broadcast(0, len, GlobalAllocator).unwrap();
            box_from_box(base);
        }
    }

    /////////////////
    // Auto Traits //
    /////////////////

    /// Compile-time checks that the pointer types over plain data are thread-safe and that
    /// `SlicePtr` is cheaply copyable.
    #[test]
    fn test_auto_traits() {
        fn assert_send_sync<T: Send + Sync>() {}
        fn assert_copy<T: Copy>() {}

        assert_send_sync::<SlicePtr<'static, u8>>();
        assert_send_sync::<MutSlicePtr<'static, u8>>();
        assert_copy::<SlicePtr<'static, u8>>();
    }
}