use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};

use bytemuck::Pod;

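/// Type-erased description of how to free the buffer of a `Vec<T>` once the
/// element type is no longer statically known: the element size and alignment,
/// plus a function that rebuilds a zero-length `Vec<T>` from the pointer and
/// capacity and drops it to release the allocation.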
struct VecVTable {
    size: usize,
    align: usize,
    drop_buffer: unsafe fn(*mut (), usize),
}

impl VecVTable {
    const fn new<T>() -> Self {
        unsafe fn drop_buffer<T>(ptr: *mut (), cap: usize) {
            unsafe { drop(Vec::from_raw_parts(ptr.cast::<T>(), 0, cap)) }
        }

        Self {
            size: size_of::<T>(),
            align: align_of::<T>(),
            drop_buffer: drop_buffer::<T>,
        }
    }

    fn new_static<T>() -> &'static Self {
        const { &Self::new::<T>() }
    }
}

use crate::ffi::InternalArrowArray;

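/// Tracks what owns the underlying buffer of a `SharedStorageInner` and hence
/// how (or whether) it must be freed when the last reference goes away.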
enum BackingStorage {
    /// The buffer was taken from a `Vec<T>`; it is freed through the stored
    /// vtable so the original element type does not need to be known here.
    Vec {
        original_capacity: usize,
        vtable: &'static VecVTable,
    },
    /// The buffer is kept alive by an imported Arrow array; dropping the
    /// `InternalArrowArray` releases it.
    InternalArrowArray(InternalArrowArray),
    /// The buffer is borrowed (e.g. from a `'static` slice) and must not be freed.
    External,
    /// The buffer is intentionally never freed, and reference counting is skipped.
    Leaked,
}

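/// Heap-allocated state shared by all `SharedStorage` handles: the reference
/// count, the raw pointer and length of the buffer, and what backs it.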
struct SharedStorageInner<T> {
    ref_count: AtomicU64,
    ptr: *mut T,
    length_in_bytes: usize,
    backing: BackingStorage,
    phantom: PhantomData<T>,
}

unsafe impl<T: Sync + Send> Sync for SharedStorageInner<T> {}

impl<T> SharedStorageInner<T> {
    pub fn from_vec(mut v: Vec<T>) -> Self {
        let length_in_bytes = v.len() * size_of::<T>();
        let original_capacity = v.capacity();
        let ptr = v.as_mut_ptr();
        core::mem::forget(v);
        Self {
            ref_count: AtomicU64::new(1),
            ptr,
            length_in_bytes,
            backing: BackingStorage::Vec {
                original_capacity,
                vtable: VecVTable::new_static::<T>(),
            },
            phantom: PhantomData,
        }
    }
}

impl<T> Drop for SharedStorageInner<T> {
    fn drop(&mut self) {
        match core::mem::replace(&mut self.backing, BackingStorage::External) {
            BackingStorage::InternalArrowArray(a) => drop(a),
            BackingStorage::Vec {
                original_capacity,
                vtable,
            } => unsafe {
                // Drop the initialized elements in place first.
                if std::mem::needs_drop::<T>() {
                    core::ptr::drop_in_place(core::ptr::slice_from_raw_parts_mut(
                        self.ptr,
                        self.length_in_bytes / size_of::<T>(),
                    ));
                }

                // Then free the allocation through the type-erased vtable.
                if original_capacity > 0 {
                    (vtable.drop_buffer)(self.ptr.cast(), original_capacity);
                }
            },
            BackingStorage::External | BackingStorage::Leaked => {},
        }
    }
}

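/// A reference-counted shared buffer of `T`, similar in spirit to `Arc<[T]>`,
/// but able to borrow its allocation from a `Vec`, a foreign Arrow array, or
/// static/leaked memory (see `BackingStorage`).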
pub struct SharedStorage<T> {
    inner: NonNull<SharedStorageInner<T>>,
    phantom: PhantomData<SharedStorageInner<T>>,
}

unsafe impl<T: Sync + Send> Send for SharedStorage<T> {}
unsafe impl<T: Sync + Send> Sync for SharedStorage<T> {}

impl<T> Default for SharedStorage<T> {
    fn default() -> Self {
        Self::empty()
    }
}

impl<T> SharedStorage<T> {
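    /// Returns a zero-length storage backed by a single static inner that is
    /// shared by every empty `SharedStorage`. Its pointer is a dangling
    /// address (`1 << 30`) that is suitably aligned for any `T` accepted by
    /// the assert below, and the `Leaked` backing means it is never
    /// reference-counted or freed.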
    const fn empty() -> Self {
        assert!(align_of::<T>() <= 1 << 30);
        static INNER: SharedStorageInner<()> = SharedStorageInner {
            ref_count: AtomicU64::new(1),
            ptr: core::ptr::without_provenance_mut(1 << 30),
            length_in_bytes: 0,
            backing: BackingStorage::Leaked,
            phantom: PhantomData,
        };

        Self {
            inner: NonNull::new(&raw const INNER as *mut SharedStorageInner<T>).unwrap(),
            phantom: PhantomData,
        }
    }

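    /// Wraps a `'static` slice without copying; the backing memory is external
    /// and is never freed by this storage.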
    pub fn from_static(slice: &'static [T]) -> Self {
        #[expect(clippy::manual_slice_size_calculation)]
        let length_in_bytes = slice.len() * size_of::<T>();
        let ptr = slice.as_ptr().cast_mut();
        let inner = SharedStorageInner {
            ref_count: AtomicU64::new(1),
            ptr,
            length_in_bytes,
            backing: BackingStorage::External,
            phantom: PhantomData,
        };
        Self {
            inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(),
            phantom: PhantomData,
        }
    }

    pub fn from_vec(v: Vec<T>) -> Self {
        Self {
            inner: NonNull::new(Box::into_raw(Box::new(SharedStorageInner::from_vec(v)))).unwrap(),
            phantom: PhantomData,
        }
    }

    pub fn from_internal_arrow_array(ptr: *const T, len: usize, arr: InternalArrowArray) -> Self {
        let inner = SharedStorageInner {
            ref_count: AtomicU64::new(1),
            ptr: ptr.cast_mut(),
            length_in_bytes: len * size_of::<T>(),
            backing: BackingStorage::InternalArrowArray(arr),
            phantom: PhantomData,
        };
        Self {
            inner: NonNull::new(Box::into_raw(Box::new(inner))).unwrap(),
            phantom: PhantomData,
        }
    }

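    /// Marks this storage as leaked: the original backing is forgotten so the
    /// buffer is never freed, and clones/drops stop touching the reference
    /// count. Panics if the storage is not exclusively owned.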
    pub fn leak(&mut self) {
        assert!(self.is_exclusive());
        unsafe {
            let inner = &mut *self.inner.as_ptr();
            core::mem::forget(core::mem::replace(
                &mut inner.backing,
                BackingStorage::Leaked,
            ));
        }
    }
}

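/// Guard returned by `SharedStorage::try_as_mut_vec`: it derefs to a `Vec<T>`
/// built from the storage's buffer, and on drop writes the (possibly
/// reallocated) `Vec` back into the storage.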
pub struct SharedStorageAsVecMut<'a, T> {
    ss: &'a mut SharedStorage<T>,
    vec: ManuallyDrop<Vec<T>>,
}

impl<T> Deref for SharedStorageAsVecMut<'_, T> {
    type Target = Vec<T>;

    fn deref(&self) -> &Self::Target {
        &self.vec
    }
}

impl<T> DerefMut for SharedStorageAsVecMut<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.vec
    }
}

impl<T> Drop for SharedStorageAsVecMut<'_, T> {
    fn drop(&mut self) {
        unsafe {
            // Write the Vec back into the storage; it may have been reallocated
            // while the guard was alive, so rebuild the inner from scratch.
            let vec = ManuallyDrop::take(&mut self.vec);
            let inner = self.ss.inner.as_ptr();
            inner.write(SharedStorageInner::from_vec(vec));
        }
    }
}

impl<T> SharedStorage<T> {
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.inner().length_in_bytes / size_of::<T>()
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> *const T {
        self.inner().ptr
    }

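    /// Returns whether this handle holds the only reference to the storage.
    /// The `Acquire` load synchronizes with the `Release` decrement performed
    /// in `Drop`, the same protocol `Arc` uses for its exclusivity checks.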
    #[inline(always)]
    pub fn is_exclusive(&mut self) -> bool {
        self.inner().ref_count.load(Ordering::Acquire) == 1
    }

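    /// Returns the current reference count. With other handles live on other
    /// threads this is only a snapshot and may be stale by the time it is used.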
    #[inline(always)]
    pub fn refcount(&self) -> u64 {
        self.inner().ref_count.load(Ordering::Acquire)
    }

    pub fn try_as_mut_slice(&mut self) -> Option<&mut [T]> {
        self.is_exclusive().then(|| {
            let inner = self.inner();
            let len = inner.length_in_bytes / size_of::<T>();
            unsafe { core::slice::from_raw_parts_mut(inner.ptr, len) }
        })
    }

    pub fn try_take_vec(&mut self) -> Option<Vec<T>> {
        // Only the sole owner may take the buffer out.
        if !self.is_exclusive() {
            return None;
        }

        let ret;
        unsafe {
            let inner = &mut *self.inner.as_ptr();

            // The storage must be backed by a Vec in the first place.
            let BackingStorage::Vec {
                original_capacity,
                vtable,
            } = &mut inner.backing
            else {
                return None;
            };

            // The buffer may have been transmuted since it was created, so the
            // vtable must still describe the current element type T.
            if vtable.size != size_of::<T>() || vtable.align != align_of::<T>() {
                return None;
            }

            // Take ownership of the buffer and leave the inner empty so its
            // Drop impl does not free it a second time.
            let len = inner.length_in_bytes / size_of::<T>();
            ret = Vec::from_raw_parts(inner.ptr, len, *original_capacity);
            *original_capacity = 0;
            inner.length_in_bytes = 0;
        }
        Some(ret)
    }

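    /// Tries to view the storage as a mutable `Vec<T>`. Any changes made
    /// through the returned guard (including reallocation) are written back to
    /// the storage when the guard is dropped. Fails under the same conditions
    /// as `try_take_vec`.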
    pub fn try_as_mut_vec(&mut self) -> Option<SharedStorageAsVecMut<'_, T>> {
        Some(SharedStorageAsVecMut {
            vec: ManuallyDrop::new(self.try_take_vec()?),
            ss: self,
        })
    }

    pub fn try_into_vec(mut self) -> Result<Vec<T>, Self> {
        self.try_take_vec().ok_or(self)
    }

    #[inline(always)]
    fn inner(&self) -> &SharedStorageInner<T> {
        unsafe { &*self.inner.as_ptr() }
    }

    /// Slow path of `Drop`: the last reference is gone, so free the
    /// heap-allocated inner (and, through its own `Drop`, the buffer).
    #[cold]
    unsafe fn drop_slow(&mut self) {
        unsafe { drop(Box::from_raw(self.inner.as_ptr())) }
    }
}

impl<T: Pod> SharedStorage<T> {
    fn try_transmute<U: Pod>(self) -> Result<SharedStorage<U>, Self> {
        let inner = self.inner();

        // The length of the buffer in bytes must be a multiple of U's size;
        // when U's size divides T's size this holds automatically.
        if size_of::<T>() % size_of::<U>() != 0 && inner.length_in_bytes % size_of::<U>() != 0 {
            return Err(self);
        }

        // The pointer must be aligned for U; when U's alignment divides T's
        // alignment this holds automatically.
        if align_of::<T>() % align_of::<U>() != 0 && !inner.ptr.cast::<U>().is_aligned() {
            return Err(self);
        }

        let storage = SharedStorage {
            inner: self.inner.cast(),
            phantom: PhantomData,
        };
        std::mem::forget(self);
        Ok(storage)
    }
}

impl SharedStorage<u8> {
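    /// Reinterprets a `Vec<T: Pod>` as byte storage. The transmute to `u8`
    /// cannot fail because `u8` has size and alignment 1, hence the
    /// `unreachable!()` in the error arm.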
    pub fn bytes_from_pod_vec<T: Pod>(v: Vec<T>) -> Self {
        SharedStorage::from_vec(v)
            .try_transmute::<u8>()
            .unwrap_or_else(|_| unreachable!())
    }
}

impl<T> Deref for SharedStorage<T> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe {
            let inner = self.inner();
            let len = inner.length_in_bytes / size_of::<T>();
            core::slice::from_raw_parts(inner.ptr, len)
        }
    }
}

impl<T> Clone for SharedStorage<T> {
    fn clone(&self) -> Self {
        let inner = self.inner();
        if !matches!(inner.backing, BackingStorage::Leaked) {
            // As with Arc, a Relaxed increment is sufficient when creating a
            // new reference from an existing one.
            inner.ref_count.fetch_add(1, Ordering::Relaxed);
        }
        Self {
            inner: self.inner,
            phantom: PhantomData,
        }
    }
}

impl<T> Drop for SharedStorage<T> {
    fn drop(&mut self) {
        let inner = self.inner();
        if matches!(inner.backing, BackingStorage::Leaked) {
            return;
        }

        // As with Arc, a Release decrement paired with an Acquire fence before
        // destruction ensures all prior accesses to the buffer happen-before it
        // is freed.
        if inner.ref_count.fetch_sub(1, Ordering::Release) == 1 {
            std::sync::atomic::fence(Ordering::Acquire);
            unsafe {
                self.drop_slow();
            }
        }
    }
}