polars_arrow/storage.rs

use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};

use bytemuck::Pod;

// Allows us to transmute between types while also keeping the original
// layout (size/align) and drop method of the Vec around.
struct VecVTable {
    size: usize,
    align: usize,
    drop_buffer: unsafe fn(*mut (), usize),
}

impl VecVTable {
    const fn new<T>() -> Self {
        unsafe fn drop_buffer<T>(ptr: *mut (), cap: usize) {
            unsafe { drop(Vec::from_raw_parts(ptr.cast::<T>(), 0, cap)) }
        }

        Self {
            size: size_of::<T>(),
            align: align_of::<T>(),
            drop_buffer: drop_buffer::<T>,
        }
    }

    fn new_static<T>() -> &'static Self {
        const { &Self::new::<T>() }
    }
}

use crate::ffi::InternalArrowArray;

enum BackingStorage {
    Vec {
        original_capacity: usize, // Elements, not bytes.
        vtable: &'static VecVTable,
    },
    InternalArrowArray(InternalArrowArray),

    /// Backed by some external method which we do not need to take care of,
    /// but we should still refcount and drop the SharedStorageInner.
    External,

    /// Both the backing storage and the SharedStorageInner are leaked; no
    /// refcounting is done. This technically should be a flag on
    /// SharedStorageInner instead of being here, but that would add 8 more
    /// bytes to SharedStorageInner, so here it is.
    Leaked,
}

struct SharedStorageInner<T> {
    ref_count: AtomicU64,
    ptr: *mut T,
    length_in_bytes: usize,
    backing: BackingStorage,
    // https://github.com/rust-lang/rfcs/blob/master/text/0769-sound-generic-drop.md#phantom-data
    phantom: PhantomData<T>,
}

unsafe impl<T: Sync + Send> Sync for SharedStorageInner<T> {}

impl<T> SharedStorageInner<T> {
    pub fn from_vec(mut v: Vec<T>) -> Self {
        let length_in_bytes = v.len() * size_of::<T>();
        let original_capacity = v.capacity();
        let ptr = v.as_mut_ptr();
        core::mem::forget(v);
        Self {
            ref_count: AtomicU64::new(1),
            ptr,
            length_in_bytes,
            backing: BackingStorage::Vec {
                original_capacity,
                vtable: VecVTable::new_static::<T>(),
            },
            phantom: PhantomData,
        }
    }
}

impl<T> Drop for SharedStorageInner<T> {
    fn drop(&mut self) {
        match core::mem::replace(&mut self.backing, BackingStorage::External) {
            BackingStorage::InternalArrowArray(a) => drop(a),
            BackingStorage::Vec {
                original_capacity,
                vtable,
            } => unsafe {
                // Drop the elements in our slice.
                if std::mem::needs_drop::<T>() {
                    core::ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut(
                        self.ptr,
                        self.length_in_bytes / size_of::<T>(),
                    ));
                }

                // Free the buffer.
                if original_capacity > 0 {
                    (vtable.drop_buffer)(self.ptr.cast(), original_capacity);
                }
            },
            BackingStorage::External | BackingStorage::Leaked => {},
        }
    }
}

pub struct SharedStorage<T> {
    inner: NonNull<SharedStorageInner<T>>,
    phantom: PhantomData<SharedStorageInner<T>>,
}

unsafe impl<T: Sync + Send> Send for SharedStorage<T> {}
unsafe impl<T: Sync + Send> Sync for SharedStorage<T> {}

impl<T> Default for SharedStorage<T> {
    fn default() -> Self {
        Self::empty()
    }
}

impl<T> SharedStorage<T> {
    const fn empty() -> Self {
        assert!(align_of::<T>() <= 1 << 30);
        static INNER: SharedStorageInner<()> = SharedStorageInner {
            ref_count: AtomicU64::new(1),
            ptr: core::ptr::without_provenance_mut(1 << 30), // Very overaligned for any T.
            length_in_bytes: 0,
            backing: BackingStorage::Leaked,
            phantom: PhantomData,
        };

        Self {
            inner: NonNull::new(&raw const INNER as *mut SharedStorageInner<T>).unwrap(),
            phantom: PhantomData,
        }
    }

    pub fn from_static(slice: &'static [T]) -> Self {
        #[expect(clippy::manual_slice_size_calculation)]
        let length_in_bytes = slice.len() * size_of::<T>();
        let ptr = slice.as_ptr().cast_mut();
        let inner = SharedStorageInner {
            ref_count: AtomicU64::new(1),
            ptr,
            length_in_bytes,
            backing: BackingStorage::External,
            phantom: PhantomData,
        };
        Self {
            inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(),
            phantom: PhantomData,
        }
    }

    pub fn from_vec(v: Vec<T>) -> Self {
        Self {
            inner: NonNull::new(Box::into_raw(Box::new(SharedStorageInner::from_vec(v)))).unwrap(),
            phantom: PhantomData,
        }
    }

    /// # Safety
    /// The range `[ptr, ptr+len)` needs to be valid and aligned for `T`.
    /// `ptr` may not be null.
    pub unsafe fn from_internal_arrow_array(
        ptr: *const T,
        len: usize,
        arr: InternalArrowArray,
    ) -> Self {
        assert!(!ptr.is_null() && ptr.is_aligned());
        let inner = SharedStorageInner {
            ref_count: AtomicU64::new(1),
            ptr: ptr.cast_mut(),
            length_in_bytes: len * size_of::<T>(),
            backing: BackingStorage::InternalArrowArray(arr),
            phantom: PhantomData,
        };
        Self {
            inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(),
            phantom: PhantomData,
        }
    }

    /// Leaks this SharedStorage such that it and its inner value are never
    /// dropped. In return no refcounting needs to be performed.
    ///
    /// The SharedStorage must be exclusive.
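    ///
    /// A sketch of intended usage (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut ss = SharedStorage::from_vec(vec![1u32, 2, 3]);
    /// assert!(ss.is_exclusive());
    /// ss.leak();
    /// // From here on, clones skip refcounting and the buffer is never freed.
    /// let also_leaked = ss.clone();
    /// ```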
    pub fn leak(&mut self) {
        assert!(self.is_exclusive());
        unsafe {
            let inner = &mut *self.inner.as_ptr();
            core::mem::forget(core::mem::replace(
                &mut inner.backing,
                BackingStorage::Leaked,
            ));
        }
    }
}

pub struct SharedStorageAsVecMut<'a, T> {
    ss: &'a mut SharedStorage<T>,
    vec: ManuallyDrop<Vec<T>>,
}

impl<T> Deref for SharedStorageAsVecMut<'_, T> {
    type Target = Vec<T>;

    fn deref(&self) -> &Self::Target {
        &self.vec
    }
}

impl<T> DerefMut for SharedStorageAsVecMut<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.vec
    }
}

impl<T> Drop for SharedStorageAsVecMut<'_, T> {
    fn drop(&mut self) {
        unsafe {
            // Restore the SharedStorage.
            let vec = ManuallyDrop::take(&mut self.vec);
            let inner = self.ss.inner.as_ptr();
            inner.write(SharedStorageInner::from_vec(vec));
        }
    }
}

impl<T> SharedStorage<T> {
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.inner().length_in_bytes / size_of::<T>()
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> *const T {
        self.inner().ptr
    }

    #[inline(always)]
    pub fn is_exclusive(&mut self) -> bool {
        // Ordering semantics copied from Arc<T>.
        self.inner().ref_count.load(Ordering::Acquire) == 1
    }

    /// Gets the reference count of this storage.
    ///
    /// Because this function takes a shared reference, it should not be used
    /// to check whether the refcount is one for safety purposes: someone else
    /// could increment it in the meantime.
    #[inline(always)]
    pub fn refcount(&self) -> u64 {
        // Ordering semantics copied from Arc<T>.
        self.inner().ref_count.load(Ordering::Acquire)
    }

    pub fn try_as_mut_slice(&mut self) -> Option<&mut [T]> {
        self.is_exclusive().then(|| {
            let inner = self.inner();
            let len = inner.length_in_bytes / size_of::<T>();
            unsafe { core::slice::from_raw_parts_mut(inner.ptr, len) }
        })
    }

    /// Try to take the vec backing this SharedStorage, leaving this as an empty slice.
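    ///
    /// A sketch of intended usage (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut ss = SharedStorage::from_vec(vec![1u32, 2, 3]);
    /// let v = ss.try_take_vec().unwrap(); // Exclusive and Vec-backed, so this succeeds.
    /// assert_eq!(v, vec![1, 2, 3]);
    /// assert_eq!(ss.len(), 0); // The storage is left empty.
    /// ```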
    pub fn try_take_vec(&mut self) -> Option<Vec<T>> {
        // If there are other references we can't get an exclusive reference.
        if !self.is_exclusive() {
            return None;
        }

        let ret;
        unsafe {
            let inner = &mut *self.inner.as_ptr();

            // We may only go back to a Vec if we originally came from a Vec
            // where the desired size/align matches the original.
            let BackingStorage::Vec {
                original_capacity,
                vtable,
            } = &mut inner.backing
            else {
                return None;
            };

            if vtable.size != size_of::<T>() || vtable.align != align_of::<T>() {
                return None;
            }

            // Steal vec from inner.
            let len = inner.length_in_bytes / size_of::<T>();
            ret = Vec::from_raw_parts(inner.ptr, len, *original_capacity);
            *original_capacity = 0;
            inner.length_in_bytes = 0;
        }
        Some(ret)
    }

    /// Attempts to expose this SharedStorage as a mutable Vec through a guard
    /// that writes the (possibly modified) Vec back into the SharedStorage when
    /// dropped. Returns None if this SharedStorage can't be converted to a Vec.
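    ///
    /// A sketch of intended usage (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut ss = SharedStorage::from_vec(vec![1u32, 2, 3]);
    /// if let Some(mut guard) = ss.try_as_mut_vec() {
    ///     guard.push(4); // Deref/DerefMut expose the Vec.
    /// } // Dropping the guard writes the Vec back into `ss`.
    /// assert_eq!(&*ss, &[1, 2, 3, 4]);
    /// ```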
    pub fn try_as_mut_vec(&mut self) -> Option<SharedStorageAsVecMut<'_, T>> {
        Some(SharedStorageAsVecMut {
            vec: ManuallyDrop::new(self.try_take_vec()?),
            ss: self,
        })
    }

    pub fn try_into_vec(mut self) -> Result<Vec<T>, Self> {
        self.try_take_vec().ok_or(self)
    }

    #[inline(always)]
    fn inner(&self) -> &SharedStorageInner<T> {
        unsafe { &*self.inner.as_ptr() }
    }

    /// # Safety
    /// May only be called once.
    #[cold]
    unsafe fn drop_slow(&mut self) {
        unsafe { drop(Box::from_raw(self.inner.as_ptr())) }
    }
}

impl<T: Pod> SharedStorage<T> {
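    /// Tries to reinterpret this storage as a storage of `U` without copying.
    ///
    /// The byte length and pointer alignment are checked against `U`; if they
    /// are incompatible, the original storage is returned unchanged as the
    /// error value.
    ///
    /// A sketch of intended usage (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let ss = SharedStorage::from_vec(vec![1u32, 2, 3]);
    /// let bytes = ss.try_transmute::<u8>().unwrap();
    /// assert_eq!(bytes.len(), 12); // Three u32 values viewed as twelve bytes.
    /// ```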
    pub fn try_transmute<U: Pod>(self) -> Result<SharedStorage<U>, Self> {
        let inner = self.inner();

        // The length of the array in bytes must be a multiple of the target size.
        // We can skip this check if the size of U divides the size of T.
        if !size_of::<T>().is_multiple_of(size_of::<U>())
            && !inner.length_in_bytes.is_multiple_of(size_of::<U>())
        {
            return Err(self);
        }

        // The pointer must be properly aligned for U.
        // We can skip this check if the alignment of U divides the alignment of T.
        if !align_of::<T>().is_multiple_of(align_of::<U>()) && !inner.ptr.cast::<U>().is_aligned() {
            return Err(self);
        }

        let storage = SharedStorage {
            inner: self.inner.cast(),
            phantom: PhantomData,
        };
        std::mem::forget(self);
        Ok(storage)
    }
}

impl SharedStorage<u8> {
    /// Create a [`SharedStorage<u8>`][SharedStorage] from a [`Vec`] of [`Pod`].
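    ///
    /// A sketch of intended usage (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let bytes = SharedStorage::bytes_from_pod_vec(vec![1u16, 2, 3]);
    /// assert_eq!(bytes.len(), 6); // Three u16 values viewed as six bytes.
    /// ```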
    pub fn bytes_from_pod_vec<T: Pod>(v: Vec<T>) -> Self {
        // This can't fail; bytes are compatible with everything.
        SharedStorage::from_vec(v)
            .try_transmute::<u8>()
            .unwrap_or_else(|_| unreachable!())
    }
}

impl<T> Deref for SharedStorage<T> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe {
            let inner = self.inner();
            let len = inner.length_in_bytes / size_of::<T>();
            core::slice::from_raw_parts(inner.ptr, len)
        }
    }
}

impl<T> Clone for SharedStorage<T> {
    fn clone(&self) -> Self {
        let inner = self.inner();
        if !matches!(inner.backing, BackingStorage::Leaked) {
            // Ordering semantics copied from Arc<T>.
            inner.ref_count.fetch_add(1, Ordering::Relaxed);
        }
        Self {
            inner: self.inner,
            phantom: PhantomData,
        }
    }
}

impl<T> Drop for SharedStorage<T> {
    fn drop(&mut self) {
        let inner = self.inner();
        if matches!(inner.backing, BackingStorage::Leaked) {
            return;
        }

        // Ordering semantics copied from Arc<T>.
        if inner.ref_count.fetch_sub(1, Ordering::Release) == 1 {
            std::sync::atomic::fence(Ordering::Acquire);
            unsafe {
                self.drop_slow();
            }
        }
    }
}
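
// Minimal, illustrative test sketches for the API above (from_vec, clone and
// refcounting, try_take_vec, and the Pod transmutes). They exercise only the
// code in this file and are not an exhaustive test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vec_roundtrip_and_exclusivity() {
        let mut ss = SharedStorage::from_vec(vec![1u32, 2, 3]);
        assert_eq!(&*ss, &[1, 2, 3]);
        assert!(ss.is_exclusive());

        // A clone bumps the refcount, so taking the Vec back must fail.
        let other = ss.clone();
        assert_eq!(ss.refcount(), 2);
        assert!(ss.try_take_vec().is_none());
        drop(other);

        // Exclusive again: we can steal the Vec, leaving the storage empty.
        let v = ss.try_take_vec().unwrap();
        assert_eq!(v, vec![1, 2, 3]);
        assert_eq!(ss.len(), 0);
    }

    #[test]
    fn transmute_to_bytes() {
        // A Vec<u32> buffer viewed as bytes; the original Vec's layout lives on
        // in the VecVTable, so the buffer is still freed correctly on drop.
        let bytes = SharedStorage::bytes_from_pod_vec(vec![0x0102_0304u32; 2]);
        assert_eq!(bytes.len(), 8);
    }
}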