diskann_quantization/bits/ptr.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{marker::PhantomData, ptr::NonNull};
7
8pub use sealed::{AsMutPtr, AsPtr, Precursor};
9
10////////////////////
11// Pointer Traits //
12////////////////////
13
14/// A constant pointer with an associated lifetime.
15///
16/// The only safe way to construct this type is through implementations
17/// of the `crate::bits::slice::sealed::Precursor` trait or by reborrowing a bitslice.
18#[derive(Debug, Clone, Copy)]
19pub struct SlicePtr<'a, T> {
20 ptr: NonNull<T>,
21 lifetime: PhantomData<&'a T>,
22}
23
24impl<T> SlicePtr<'_, T> {
25 /// # Safety
26 ///
27 /// It's the callers responsibility to ensure the correct lifetime is attached to
28 /// the underlying pointer and that the constraints of any unsafe traits implemented by
29 /// this type are upheld.
30 pub(super) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
31 Self {
32 ptr,
33 lifetime: PhantomData,
34 }
35 }
36}
37
38// SAFETY: Slices are `Send` when `T: Send`.
39unsafe impl<T: Send> Send for SlicePtr<'_, T> {}
40
41// SAFETY: Slices are `Sync` when `T: Sync`.
42unsafe impl<T: Sync> Sync for SlicePtr<'_, T> {}
43
44/// A mutable pointer with an associated lifetime.
45///
46/// The only safe way to construct this type is through implementations
47/// of the `crate::bits::slice::sealed::Precursor` trait or by reborrowing a bitslice.
48#[derive(Debug)]
49pub struct MutSlicePtr<'a, T> {
50 ptr: NonNull<T>,
51 lifetime: PhantomData<&'a mut T>,
52}
53
54impl<T> MutSlicePtr<'_, T> {
55 /// # Safety
56 ///
57 /// It's the callers responsibility to ensure the correct lifetime is attached to
58 /// the underlying pointer and that the constraints of any unsafe traits implemented by
59 /// this type are upheld.
60 pub(super) unsafe fn new_unchecked(ptr: NonNull<T>) -> Self {
61 Self {
62 ptr,
63 lifetime: PhantomData,
64 }
65 }
66}
67
68// SAFETY: Mutable slices are `Send` when `T: Send`.
69unsafe impl<T: Send> Send for MutSlicePtr<'_, T> {}
70
71// SAFETY: Mutable slices are `Sync` when `T: Sync`.
72unsafe impl<T: Sync> Sync for MutSlicePtr<'_, T> {}
73
74mod sealed {
75 use std::{marker::PhantomData, ptr::NonNull};
76
77 use super::{MutSlicePtr, SlicePtr};
78 use crate::alloc::{AllocatorCore, Poly};
79
80 /// A precursor for a pointer type that is used as the base of a `BitSlice`.
81 ///
82 /// This trait is *unsafe* because implementing it incorrectly will lead to lifetime
83 /// violations, out-of-bounds accesses, and other undefined behavior.
84 ///
85 /// # Safety
86 ///
87 /// There are two components to a safe implementation of `Precursor`.
88 ///
89 /// ## Memory Safety
90 ///
91 /// It is the implementors responsibility to ensure all the preconditions necessary so
92 /// the following expression is safe:
93 /// ```ignore
94 /// let x: impl Precursor ...
95 /// let len = x.len();
96 /// // The length and derived pointer **must** maintain the following:
97 /// unsafe { std::slice::from_raw_parts(x.precursor_into().as_ptr(), len) }
98 /// ```
99 /// Furthermore, if `Target: AsMutPtr`, then the preconditions for
100 /// `std::slice::from_raw_parts_mut` must be upheld.
101 ///
102 /// ## Lifetime Safety
103 ///
104 /// If is the implementors responsibility that `Target` correctly captures lifetimes
105 /// so the slice obtained from the precursor obeys Rust's rules for references.
106 pub unsafe trait Precursor<Target>
107 where
108 Target: AsPtr,
109 {
110 /// Consume `self` and convert into `Target`.
111 fn precursor_into(self) -> Target;
112
113 /// Return the number of elements of type `<Target as AsPtr>::Type` that are valid
114 /// from the pointer returned by `self.precursor_into().as_ptr()`.
115 fn precursor_len(&self) -> usize;
116 }
117
118 /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
119 /// meaning in particular they will not return null pointers, even with zero size.
120 ///
121 /// This implementation simply breaks the slice into its raw parts.
122 ///
123 /// The `SlicePtr` captures the reference lifetime.
124 unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a [T] {
125 fn precursor_into(self) -> SlicePtr<'a, T> {
126 SlicePtr {
127 // Safety: Slices cannot yield null pointers.
128 //
129 // The `cast_mut()` is safe because we do not provide a way to retrieve a
130 // mutable pointer from `SlicePtr`.
131 ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
132 lifetime: PhantomData,
133 }
134 }
135 fn precursor_len(&self) -> usize {
136 <[T]>::len(self)
137 }
138 }
139
140 /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`,
141 /// meaning in particular they will not return null pointers, even with zero size.
142 ///
143 /// This implementation simply breaks the slice into its raw parts.
144 ///
145 /// The `SlicePtr` captures the reference lifetime. Decaying a mutable borrow to a
146 /// normal borrow is allowed.
147 unsafe impl<'a, T> Precursor<SlicePtr<'a, T>> for &'a mut [T] {
148 fn precursor_into(self) -> SlicePtr<'a, T> {
149 SlicePtr {
150 // Safety: Slices cannot yield null pointers.
151 //
152 // The `cast_mut()` is safe because we do not provide a way to retrieve a
153 // mutable pointer from `SlicePtr`.
154 ptr: unsafe { NonNull::new_unchecked(self.as_ptr().cast_mut()) },
155 lifetime: PhantomData,
156 }
157 }
158 fn precursor_len(&self) -> usize {
159 self.len()
160 }
161 }
162
163 /// Safety: Slices implicitly fit the pre-conditions for `std::slice::from_raw_parts`
164 /// and `std::slice::from_raw_parts_mut` meaning in particular they will not return
165 /// null pointers, even with zero size.
166 ///
167 /// This implementation simply breaks the slice into its raw parts.
168 ///
169 /// The `SlicePtr` captures the mutable reference lifetime.
170 unsafe impl<'a, T> Precursor<MutSlicePtr<'a, T>> for &'a mut [T] {
171 fn precursor_into(self) -> MutSlicePtr<'a, T> {
172 MutSlicePtr {
173 // Safety: Slices cannot yield null pointers.
174 ptr: unsafe { NonNull::new_unchecked(self.as_mut_ptr()) },
175 lifetime: PhantomData,
176 }
177 }
178 fn precursor_len(&self) -> usize {
179 self.len()
180 }
181 }
182
183 /// Safety: This implementation forwards through the [`Poly`], whose implementation of
184 /// `AsPtr` and `AsPtrMut` follow the same logic as those of captured slices.
185 ///
186 /// Since a [`Poly`] owns its contents, the lifetime requirements are satisfied.
187 unsafe impl<T, A> Precursor<Poly<[T], A>> for Poly<[T], A>
188 where
189 A: AllocatorCore,
190 {
191 fn precursor_into(self) -> Poly<[T], A> {
192 self
193 }
194 fn precursor_len(&self) -> usize {
195 self.len()
196 }
197 }
198
199 /// Obtain a constant base pointer for a slice of data.
200 ///
201 /// # Safety
202 ///
203 /// The returned pointer must never be null!
204 pub unsafe trait AsPtr {
205 type Type;
206 fn as_ptr(&self) -> *const Self::Type;
207 }
208
209 /// Obtain a mutable base pointer for a slice of data.
210 ///
211 /// # Safety
212 ///
213 /// The returned pointer must never be null! Furthermore, the mutable pointer must
214 /// originally be derived from a mutable pointer.
215 pub unsafe trait AsMutPtr: AsPtr {
216 fn as_mut_ptr(&mut self) -> *mut Self::Type;
217 }
218
219 /// Safety: SlicePtr may only contain non-null pointers.
220 unsafe impl<T> AsPtr for SlicePtr<'_, T> {
221 type Type = T;
222 fn as_ptr(&self) -> *const T {
223 self.ptr.as_ptr().cast_const()
224 }
225 }
226
227 /// Safety: SlicePtr may only contain non-null pointers.
228 unsafe impl<T> AsPtr for MutSlicePtr<'_, T> {
229 type Type = T;
230 fn as_ptr(&self) -> *const T {
231 // The const-cast is allowed by variance.
232 self.ptr.as_ptr().cast_const()
233 }
234 }
235
236 /// Safety: SlicePtr may only contain non-null pointers. The only way to construct
237 /// a `MutSlicePtr` is from a mutable reference, so the underlying pointer is indeed
238 /// mutable.
239 unsafe impl<T> AsMutPtr for MutSlicePtr<'_, T> {
240 fn as_mut_ptr(&mut self) -> *mut T {
241 self.ptr.as_ptr()
242 }
243 }
244
245 /// Safety: Slices never return a null pointer.
246 unsafe impl<T, A> AsPtr for Poly<[T], A>
247 where
248 A: AllocatorCore,
249 {
250 type Type = T;
251 fn as_ptr(&self) -> *const T {
252 <[T]>::as_ptr(self)
253 }
254 }
255
256 /// Safety: Slices never return a null pointer. A mutable reference to `self` signals
257 /// an exclusive borrow - so the underlying pointer is indeed mutable.
258 unsafe impl<T, A> AsMutPtr for Poly<[T], A>
259 where
260 A: AllocatorCore,
261 {
262 fn as_mut_ptr(&mut self) -> *mut T {
263 <[T]>::as_mut_ptr(&mut *self)
264 }
265 }
266}
267
268///////////
269// Tests //
270///////////
271
272#[cfg(test)]
273mod tests {
274 use super::*;
275 use crate::alloc::{GlobalAllocator, Poly};
276
277 ///////////////////////////////
278 // SlicePtr from Const Slice //
279 ///////////////////////////////
280
281 fn slice_ptr_from_const_slice(base: &[u8]) {
282 let ptr = base.as_ptr();
283 assert!(!ptr.is_null(), "slices must not return null pointers");
284
285 let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
286 assert_eq!(slice_ptr.as_ptr(), ptr);
287 assert_eq!(base.precursor_len(), base.len());
288
289 // Safety: Check with Miri.
290 let derived = unsafe {
291 std::slice::from_raw_parts(base.precursor_into().as_ptr(), base.precursor_len())
292 };
293 assert_eq!(derived.as_ptr(), ptr);
294 assert_eq!(derived.len(), base.len());
295 }
296
297 #[test]
298 fn test_slice_ptr_from_const_slice() {
299 slice_ptr_from_const_slice(&[]);
300 slice_ptr_from_const_slice(&[1]);
301 slice_ptr_from_const_slice(&[1, 2]);
302
303 for len in 0..10 {
304 let base: Vec<u8> = vec![0; len];
305 slice_ptr_from_const_slice(&base);
306 }
307 }
308
309 /////////////////////////////
310 // SlicePtr from Mut Slice //
311 /////////////////////////////
312
313 fn slice_ptr_from_mut_slice(base: &mut [u8]) {
314 let ptr = base.as_ptr();
315 let len = base.len();
316 assert!(!ptr.is_null(), "slices must not return null pointers");
317
318 let precursor_len = <&mut [u8] as Precursor<SlicePtr<'_, u8>>>::precursor_len(&base);
319 assert_eq!(precursor_len, base.len());
320
321 let slice_ptr: SlicePtr<'_, u8> = base.precursor_into();
322 assert_eq!(slice_ptr.as_ptr(), ptr);
323
324 // Safety: Check with Miri.
325 let derived = unsafe { std::slice::from_raw_parts(slice_ptr.as_ptr(), precursor_len) };
326
327 assert_eq!(derived.as_ptr(), ptr);
328 assert_eq!(derived.len(), len);
329 }
330
331 #[test]
332 fn test_slice_ptr_from_mut_slice() {
333 slice_ptr_from_mut_slice(&mut []);
334 slice_ptr_from_mut_slice(&mut [1]);
335 slice_ptr_from_mut_slice(&mut [1, 2]);
336
337 for len in 0..10 {
338 let mut base: Vec<u8> = vec![0; len];
339 slice_ptr_from_mut_slice(&mut base);
340 }
341 }
342
343 /////////////////////////////
344 // SlicePtr from Mut Slice //
345 /////////////////////////////
346
347 fn mut_slice_ptr_from_mut_slice(base: &mut [u8]) {
348 let ptr = base.as_mut_ptr();
349 let len = base.len();
350
351 assert!(!ptr.is_null(), "slices must not return null pointers");
352
353 let precursor_len = <&mut [u8] as Precursor<SlicePtr<'_, u8>>>::precursor_len(&base);
354 assert_eq!(precursor_len, base.len());
355
356 let mut slice_ptr: MutSlicePtr<'_, u8> = base.precursor_into();
357 assert_eq!(slice_ptr.as_ptr(), ptr.cast_const());
358 assert_eq!(slice_ptr.as_mut_ptr(), ptr);
359
360 // Safety: Check with Miri.
361 let derived =
362 unsafe { std::slice::from_raw_parts_mut(slice_ptr.as_mut_ptr(), precursor_len) };
363
364 assert_eq!(derived.as_ptr(), ptr);
365 assert_eq!(derived.len(), len);
366 }
367
368 #[test]
369 fn test_mut_slice_ptr_from_mut_slice() {
370 mut_slice_ptr_from_mut_slice(&mut []);
371 mut_slice_ptr_from_mut_slice(&mut [1]);
372 mut_slice_ptr_from_mut_slice(&mut [1, 2]);
373
374 for len in 0..10 {
375 let mut base: Vec<u8> = vec![0; len];
376 mut_slice_ptr_from_mut_slice(&mut base);
377 }
378 }
379
380 /////////
381 // Box //
382 /////////
383
384 fn box_from_box(base: Poly<[u8], GlobalAllocator>) {
385 let ptr = base.as_ptr();
386 let len = base.len();
387
388 assert!(!ptr.is_null(), "slices must not return null pointers");
389
390 assert_eq!(base.precursor_len(), len);
391 let mut derived = base.precursor_into();
392
393 assert_eq!(derived.as_ptr(), ptr);
394 assert_eq!(derived.as_mut_ptr(), ptr.cast_mut());
395 assert_eq!(derived.len(), len);
396 }
397
398 #[test]
399 fn test_box() {
400 box_from_box(Poly::from_iter([].into_iter(), GlobalAllocator).unwrap());
401 box_from_box(Poly::from_iter([1].into_iter(), GlobalAllocator).unwrap());
402 box_from_box(Poly::from_iter([1, 2].into_iter(), GlobalAllocator).unwrap());
403
404 for len in 0..10 {
405 let base = Poly::broadcast(0, len, GlobalAllocator).unwrap();
406 box_from_box(base);
407 }
408 }
409}