diskann_quantization/bits/ptr.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{marker::PhantomData, ptr::NonNull};
7
8pub use sealed::{AsMutPtr, AsPtr, Precursor};
9
10////////////////////
11// Pointer Traits //
12////////////////////
13
/// A constant pointer with an associated lifetime.
///
/// Semantically this behaves like the base pointer of a `&'a [T]`: it grants shared
/// (read-only) access for the lifetime `'a`.
///
/// The only safe way to construct this type is through implementations
/// of the [`Precursor`] trait or by reborrowing a bitslice.
#[derive(Debug)]
pub struct SlicePtr<'a, T> {
    ptr: NonNull<T>,
    lifetime: PhantomData<&'a T>,
}

// `Clone` and `Copy` are implemented manually: `#[derive]` would add `T: Clone` and
// `T: Copy` bounds, but copying a `SlicePtr` only duplicates a shared-borrow pointer
// (just like copying a `&'a [T]`) and never copies a `T`.
impl<T> Clone for SlicePtr<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SlicePtr<'_, T> {}

impl<T> SlicePtr<'_, T> {
    /// Construct a `SlicePtr` directly from a raw non-null pointer.
    ///
    /// # Safety
    ///
    /// It is the caller's responsibility to ensure the correct lifetime is attached to
    /// the underlying pointer and that the constraints of any unsafe traits implemented by
    /// this type are upheld.
    pub(crate) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
        Self {
            ptr,
            lifetime: PhantomData,
        }
    }
}

// SAFETY: `SlicePtr` has shared-reference (`&'a T`) semantics, and `&T: Send` iff
// `T: Sync`.
unsafe impl<T: Sync> Send for SlicePtr<'_, T> {}

// SAFETY: `SlicePtr` has shared-reference (`&'a T`) semantics, and `&T: Sync` iff
// `T: Sync`.
unsafe impl<T: Sync> Sync for SlicePtr<'_, T> {}
44
/// A mutable pointer tagged with the lifetime of the exclusive borrow it came from.
///
/// This type carries `&'a mut T` semantics (see its `PhantomData` marker). It can only be
/// built safely via implementations of the [`Precursor`] trait or by reborrowing a
/// bitslice.
#[derive(Debug)]
pub struct MutSlicePtr<'a, T> {
    ptr: NonNull<T>,
    lifetime: PhantomData<&'a mut T>,
}

impl<T> MutSlicePtr<'_, T> {
    /// Wrap `ptr` without performing any checks.
    ///
    /// # Safety
    ///
    /// The caller must attach the correct lifetime to `ptr` and must uphold the
    /// requirements of every unsafe trait implemented by this type.
    pub(crate) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
        let lifetime = PhantomData;
        Self { ptr, lifetime }
    }
}

// SAFETY: `MutSlicePtr` behaves like `&'a mut T`, and `&mut T: Send` iff `T: Send`.
unsafe impl<T> Send for MutSlicePtr<'_, T> where T: Send {}

// SAFETY: `MutSlicePtr` behaves like `&'a mut T`, and `&mut T: Sync` iff `T: Sync`.
unsafe impl<T> Sync for MutSlicePtr<'_, T> where T: Sync {}
74
75mod sealed {
76 use std::{marker::PhantomData, ptr::NonNull};
77
78 use super::{MutSlicePtr, SlicePtr};
79 use crate::alloc::{AllocatorCore, Poly};
80
81 /// A precursor for a pointer type that is used as the base of a `BitSlice`.
82 ///
83 /// This trait is *unsafe* because implementing it incorrectly will lead to lifetime
84 /// violations, out-of-bounds accesses, and other undefined behavior.
85 ///
86 /// # Safety
87 ///
88 /// There are two components to a safe implementation of `Precursor`.
89 ///
90 /// ## Memory Safety
91 ///
92 /// It is the implementors responsibility to ensure all the preconditions necessary so
93 /// the following expression is safe:
94 /// ```ignore
95 /// let x: impl Precursor ...
96 /// let len = x.len();
97 /// // The length and derived pointer **must** maintain the following:
98 /// unsafe { std::slice::from_raw_parts(x.precursor_into().as_ptr(), len) }
99 /// ```
100 /// Furthermore, if `Target: AsMutPtr`, then the preconditions for
101 /// `std::slice::from_raw_parts_mut` must be upheld.
102 ///
103 /// ## Lifetime Safety
104 ///
105 /// If is the implementors responsibility that `Target` correctly captures lifetimes
106 /// so the slice obtained from the precursor obeys Rust's rules for references.
107 pub unsafe trait Precursor<Target>
108 where
109 Target: AsPtr,
110 {
111 /// Consume `self` and convert into `Target`.
112 fn precursor_into(self) -> Target;
113
114 /// Return the number of elements of type `<Target as AsPtr>::Type` that are valid
115 /// from the pointer returned by `self.precursor_into().as_ptr()`.
116 fn precursor_len(&self) -> usize;
117 }
118
119 /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
120 /// meaning in particular they will not return null pointers, even with zero size.
121 ///
122 /// This implementation simply breaks the slice into its raw parts.
123 ///
124 /// The `SlicePtr` captures the reference lifetime.
125 unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a [T] {
126 fn precursor_into(self) -> SlicePtr<'a, T> {
127 SlicePtr {
128 // Safety: Slices cannot yield null pointers.
129 //
130 // The `cast_mut()` is safe because we do not provide a way to retrieve a
131 // mutable pointer from `SlicePtr`.
132 ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
133 lifetime: PhantomData,
134 }
135 }
136 fn precursor_len(&self) -> usize {
137 <[T]>::len(self)
138 }
139 }
140
141 /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
142 /// meaning in particular they will not return null pointers, even with zero size.
143 ///
144 /// This implementation simply breaks the slice into its raw parts.
145 ///
146 /// The `SlicePtr` captures the reference lifetime. Decaying a mutable borrow to a
147 /// normal borrow is allowed.
148 unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a mut [T] {
149 fn precursor_into(self) -> SlicePtr<'a, T> {
150 SlicePtr {
151 // Safety: Slices cannot yield null pointers.
152 //
153 // The `cast_mut()` is safe because we do not provide a way to retrieve a
154 // mutable pointer from `SlicePtr`.
155 ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
156 lifetime: PhantomData,
157 }
158 }
159 fn precursor_len(&self) -> usize {
160 self.len()
161 }
162 }
163
164 /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`
165 /// and `std::slice::from_raw_parts_mut` meaning in particular they will not return
166 /// null pointers, even with zero size.
167 ///
168 /// This implementation simply breaks the slice into its raw parts.
169 ///
170 /// The `SlicePtr` captures the mutable reference lifetime.
171 unsafe impl<'a, T> Precursor<MutSlicePtr<'a, T>> for &'a mut [T] {
172 fn precursor_into(self) -> MutSlicePtr<'a, T> {
173 MutSlicePtr {
174 // Safety: Slices cannot yield null pointers.
175 ptr: unsafe { NonNull::new_unchecked(self.as_mut_ptr()) },
176 lifetime: PhantomData,
177 }
178 }
179 fn precursor_len(&self) -> usize {
180 self.len()
181 }
182 }
183
184 /// Safety: This implementation forwards through the [`Poly`], whose implementation of
185 /// `AsPtr` and `AsPtrMut` follow the same logic as those of captured slices.
186 ///
187 /// Since a [`Poly`] owns its contents, the lifetime requirements are satisfied.
188 unsafe impl<T, A> Precursor<Poly<[T], A>> for Poly<[T], A>
189 where
190 A: AllocatorCore,
191 {
192 fn precursor_into(self) -> Poly<[T], A> {
193 self
194 }
195 fn precursor_len(&self) -> usize {
196 self.len()
197 }
198 }
199
200 /// Obtain a constant base pointer for a slice of data.
201 ///
202 /// # Safety
203 ///
204 /// The returned pointer must never be null!
205 pub unsafe trait AsPtr {
206 type Type;
207 fn as_ptr(&self) -> *const Self::Type;
208 }
209
210 /// Obtain a mutable base pointer for a slice of data.
211 ///
212 /// # Safety
213 ///
214 /// The returned pointer must never be null! Furthermore, the mutable pointer must
215 /// originally be derived from a mutable pointer.
216 pub unsafe trait AsMutPtr: AsPtr {
217 fn as_mut_ptr(&mut self) -> *mut Self::Type;
218 }
219
220 /// Safety: SlicePtr may only contain non-null pointers.
221 unsafe impl<T> AsPtr for SlicePtr<'_, T> {
222 type Type = T;
223 fn as_ptr(&self) -> *const T {
224 self.ptr.as_ptr().cast_const()
225 }
226 }
227
228 /// Safety: SlicePtr may only contain non-null pointers.
229 unsafe impl<T> AsPtr for MutSlicePtr<'_, T> {
230 type Type = T;
231 fn as_ptr(&self) -> *const T {
232 // The const-cast is allowed by variance.
233 self.ptr.as_ptr().cast_const()
234 }
235 }
236
237 /// Safety: SlicePtr may only contain non-null pointers. The only way to construct
238 /// a `MutSlicePtr` is from a mutable reference, so the underlying pointer is indeed
239 /// mutable.
240 unsafe impl<T> AsMutPtr for MutSlicePtr<'_, T> {
241 fn as_mut_ptr(&mut self) -> *mut T {
242 self.ptr.as_ptr()
243 }
244 }
245
246 /// Safety: Slices never return a null pointer.
247 unsafe impl<T, A> AsPtr for Poly<[T], A>
248 where
249 A: AllocatorCore,
250 {
251 type Type = T;
252 fn as_ptr(&self) -> *const T {
253 <[T]>::as_ptr(self)
254 }
255 }
256
257 /// Safety: Slices never return a null pointer. A mutable reference to `self` signals
258 /// an exclusive borrow - so the underlying pointer is indeed mutable.
259 unsafe impl<T, A> AsMutPtr for Poly<[T], A>
260 where
261 A: AllocatorCore,
262 {
263 fn as_mut_ptr(&mut self) -> *mut T {
264 <[T]>::as_mut_ptr(&mut *self)
265 }
266 }
267}
268
269///////////
270// Tests //
271///////////
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::{GlobalAllocator, Poly};

    ///////////////////////////////
    // SlicePtr from Const Slice //
    ///////////////////////////////

    fn slice_ptr_from_const_slice(base: &[u8]) {
        let ptr = base.as_ptr();
        assert!(!ptr.is_null(), "slices must not return null pointers");

        let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr);
        assert_eq!(base.precursor_len(), base.len());

        // SAFETY: The `Precursor` contract guarantees the pointer/length pair is valid for
        // `from_raw_parts`. Run under Miri to check.
        let derived = unsafe {
            std::slice::from_raw_parts(base.precursor_into().as_ptr(), base.precursor_len())
        };
        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), base.len());
    }

    #[test]
    fn test_slice_ptr_from_const_slice() {
        slice_ptr_from_const_slice(&[]);
        slice_ptr_from_const_slice(&[1]);
        slice_ptr_from_const_slice(&[1, 2]);

        for len in 0..10 {
            let base: Vec<u8> = vec![0; len];
            slice_ptr_from_const_slice(&base);
        }
    }

    /////////////////////////////
    // SlicePtr from Mut Slice //
    /////////////////////////////

    fn slice_ptr_from_mut_slice(base: &mut [u8]) {
        let ptr = base.as_ptr();
        let len = base.len();
        assert!(!ptr.is_null(), "slices must not return null pointers");

        let precursor_len = <&mut [u8] as Precursor<SlicePtr<'_, u8>>>::precursor_len(&base);
        assert_eq!(precursor_len, base.len());

        let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr);

        // SAFETY: The `Precursor` contract guarantees the pointer/length pair is valid for
        // `from_raw_parts`. Run under Miri to check.
        let derived = unsafe { std::slice::from_raw_parts(slice_ptr.as_ptr(), precursor_len) };

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_slice_ptr_from_mut_slice() {
        slice_ptr_from_mut_slice(&mut []);
        slice_ptr_from_mut_slice(&mut [1]);
        slice_ptr_from_mut_slice(&mut [1, 2]);

        for len in 0..10 {
            let mut base: Vec<u8> = vec![0; len];
            slice_ptr_from_mut_slice(&mut base);
        }
    }

    ////////////////////////////////
    // MutSlicePtr from Mut Slice //
    ////////////////////////////////

    fn mut_slice_ptr_from_mut_slice(base: &mut [u8]) {
        let ptr = base.as_mut_ptr();
        let len = base.len();

        assert!(!ptr.is_null(), "slices must not return null pointers");

        // Exercise the `Precursor<MutSlicePtr>` implementation (not the `SlicePtr` one).
        let precursor_len =
            <&mut [u8] as Precursor<MutSlicePtr<'_, u8>>>::precursor_len(&base);
        assert_eq!(precursor_len, base.len());

        let mut slice_ptr: MutSlicePtr<'_, u8> = base.precursor_into();
        assert_eq!(slice_ptr.as_ptr(), ptr.cast_const());
        assert_eq!(slice_ptr.as_mut_ptr(), ptr);

        // SAFETY: The `Precursor` contract guarantees the pointer/length pair is valid for
        // `from_raw_parts_mut`. Run under Miri to check.
        let derived =
            unsafe { std::slice::from_raw_parts_mut(slice_ptr.as_mut_ptr(), precursor_len) };

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_mut_slice_ptr_from_mut_slice() {
        mut_slice_ptr_from_mut_slice(&mut []);
        mut_slice_ptr_from_mut_slice(&mut [1]);
        mut_slice_ptr_from_mut_slice(&mut [1, 2]);

        for len in 0..10 {
            let mut base: Vec<u8> = vec![0; len];
            mut_slice_ptr_from_mut_slice(&mut base);
        }
    }

    /////////
    // Box //
    /////////

    fn box_from_box(base: Poly<[u8], GlobalAllocator>) {
        let ptr = base.as_ptr();
        let len = base.len();

        assert!(!ptr.is_null(), "slices must not return null pointers");

        assert_eq!(base.precursor_len(), len);
        let mut derived = base.precursor_into();

        assert_eq!(derived.as_ptr(), ptr);
        assert_eq!(derived.as_mut_ptr(), ptr.cast_mut());
        assert_eq!(derived.len(), len);
    }

    #[test]
    fn test_box() {
        box_from_box(Poly::from_iter([].into_iter(), GlobalAllocator).unwrap());
        box_from_box(Poly::from_iter([1].into_iter(), GlobalAllocator).unwrap());
        box_from_box(Poly::from_iter([1, 2].into_iter(), GlobalAllocator).unwrap());

        for len in 0..10 {
            let base = Poly::broadcast(0, len, GlobalAllocator).unwrap();
            box_from_box(base);
        }
    }
}