// facet_core/impls_alloc/arc.rs

1use core::alloc::Layout;
2
3use crate::{
4    Def, Facet, KnownSmartPointer, PtrConst, PtrMut, PtrUninit, Shape, ShapeLayout,
5    SmartPointerDef, SmartPointerFlags, SmartPointerVTable, TryBorrowInnerError, TryFromError,
6    TryIntoInnerError, Type, UserType, ValueVTable, value_vtable,
7};
8
9unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Arc<T> {
10    const VTABLE: &'static ValueVTable = &const {
11        // Define the functions for transparent conversion between Arc<T> and T
12        unsafe fn try_from<'a, 'shape, 'src, 'dst, T: Facet<'a> + ?Sized>(
13            src_ptr: PtrConst<'src>,
14            src_shape: &'shape Shape<'shape>,
15            dst: PtrUninit<'dst>,
16        ) -> Result<PtrMut<'dst>, TryFromError<'shape>> {
17            if src_shape.id != T::SHAPE.id {
18                return Err(TryFromError::UnsupportedSourceShape {
19                    src_shape,
20                    expected: &[T::SHAPE],
21                });
22            }
23
24            if let ShapeLayout::Unsized = T::SHAPE.layout {
25                panic!("can't try_from with unsized type");
26            }
27
28            use alloc::sync::Arc;
29
30            // Get the layout for T
31            let layout = match T::SHAPE.layout {
32                ShapeLayout::Sized(layout) => layout,
33                ShapeLayout::Unsized => panic!("Unsized type not supported"),
34            };
35
36            // We'll create a new memory location, copy the value, then create an Arc from it
37            let size_of_arc_header = core::mem::size_of::<usize>() * 2;
38
39            // Use Layout::extend to combine header and value layout with correct alignment and padding
40            let header_layout =
41                Layout::from_size_align(size_of_arc_header, core::mem::align_of::<usize>())
42                    .unwrap();
43            let (arc_layout, value_offset) = header_layout.extend(layout).unwrap();
44
45            // To ensure that our allocation is correct for the Arc memory model,
46            // round up the allocation to the next multiple of 8 (Arc's alignment)
47            let adjusted_size = (arc_layout.size() + 7) & !7;
48            let final_layout =
49                unsafe { Layout::from_size_align_unchecked(adjusted_size, arc_layout.align()) };
50
51            let mem = unsafe { alloc::alloc::alloc(final_layout) };
52
53            unsafe {
54                // Copy the Arc header (refcounts, vtable pointer, etc.) from a dummy Arc<()>
55                let dummy_arc = Arc::new(());
56                let header_start = (Arc::as_ptr(&dummy_arc) as *const u8).sub(size_of_arc_header);
57                core::ptr::copy_nonoverlapping(header_start, mem, size_of_arc_header);
58
59                // Copy the source value into the memory area at the correct value offset after the Arc header
60                core::ptr::copy_nonoverlapping(
61                    src_ptr.as_byte_ptr(),
62                    mem.add(value_offset),
63                    layout.size(),
64                );
65            }
66
67            // Create an Arc from our allocated and initialized memory
68            let ptr = unsafe { mem.add(value_offset) };
69            let t_ptr: *mut T = unsafe { core::mem::transmute_copy(&ptr) };
70            let arc = unsafe { Arc::from_raw(t_ptr) };
71
72            // Move the Arc into the destination and return a PtrMut for it
73            Ok(unsafe { dst.put(arc) })
74        }
75
76        unsafe fn try_into_inner<'a, 'src, 'dst, T: Facet<'a> + ?Sized>(
77            src_ptr: PtrMut<'src>,
78            dst: PtrUninit<'dst>,
79        ) -> Result<PtrMut<'dst>, TryIntoInnerError> {
80            use alloc::sync::Arc;
81
82            // Read the Arc from the source pointer
83            let mut arc = unsafe { src_ptr.read::<Arc<T>>() };
84
85            // For unsized types, we need to know how many bytes to copy.
86            let size = match T::SHAPE.layout {
87                ShapeLayout::Sized(layout) => layout.size(),
88                _ => panic!("cannot try_into_inner with unsized type"),
89            };
90
91            // Check if we have exclusive access to the Arc (strong count = 1)
92            if let Some(inner_ref) = Arc::get_mut(&mut arc) {
93                // We have exclusive access, so we can safely copy the inner value
94                let inner_ptr = inner_ref as *const T as *const u8;
95
96                unsafe {
97                    // Copy the inner value to the destination
98                    core::ptr::copy_nonoverlapping(inner_ptr, dst.as_mut_byte_ptr(), size);
99
100                    // Prevent dropping the Arc normally which would also drop the inner value
101                    // that we've already copied
102                    let raw_ptr = Arc::into_raw(arc);
103
104                    // We need to deallocate the Arc without running destructors
105                    // Get the Arc layout
106                    let size_of_arc_header = core::mem::size_of::<usize>() * 2;
107                    let layout = match T::SHAPE.layout {
108                        ShapeLayout::Sized(layout) => layout,
109                        _ => unreachable!("We already checked that T is sized"),
110                    };
111                    let arc_layout = Layout::from_size_align_unchecked(
112                        size_of_arc_header + size,
113                        layout.align(),
114                    );
115
116                    // Get the start of the allocation (header is before the data)
117                    let allocation_start = (raw_ptr as *mut u8).sub(size_of_arc_header);
118
119                    // Deallocate the memory without running any destructors
120                    alloc::alloc::dealloc(allocation_start, arc_layout);
121
122                    // Return a PtrMut to the destination, which now owns the value
123                    Ok(PtrMut::new(dst.as_mut_byte_ptr()))
124                }
125            } else {
126                // Arc is shared, so we can't extract the inner value
127                core::mem::forget(arc);
128                Err(TryIntoInnerError::Unavailable)
129            }
130        }
131
132        unsafe fn try_borrow_inner<'a, 'src, T: Facet<'a>>(
133            src_ptr: PtrConst<'src>,
134        ) -> Result<PtrConst<'src>, TryBorrowInnerError> {
135            let arc = unsafe { src_ptr.get::<alloc::sync::Arc<T>>() };
136            Ok(PtrConst::new(&**arc))
137        }
138
139        let mut vtable = value_vtable!(alloc::sync::Arc<T>, |f, opts| {
140            write!(f, "{}", Self::SHAPE.type_identifier)?;
141            if let Some(opts) = opts.for_children() {
142                write!(f, "<")?;
143                (T::SHAPE.vtable.type_name)(f, opts)?;
144                write!(f, ">")?;
145            } else {
146                write!(f, "<…>")?;
147            }
148            Ok(())
149        });
150
151        vtable.try_from = || Some(try_from::<T>);
152        vtable.try_into_inner = || Some(try_into_inner::<T>);
153        vtable.try_borrow_inner = || Some(try_borrow_inner::<T>);
154        vtable
155    };
156
157    const SHAPE: &'static crate::Shape<'static> = &const {
158        // Function to return inner type's shape
159        fn inner_shape<'a, T: Facet<'a> + ?Sized>() -> &'static Shape<'static> {
160            T::SHAPE
161        }
162
163        crate::Shape::builder_for_sized::<Self>()
164            .type_identifier("Arc")
165            .type_params(&[crate::TypeParam {
166                name: "T",
167                shape: || T::SHAPE,
168            }])
169            .ty(Type::User(UserType::Opaque))
170            .def(Def::SmartPointer(
171                SmartPointerDef::builder()
172                    .pointee(|| T::SHAPE)
173                    .flags(SmartPointerFlags::ATOMIC)
174                    .known(KnownSmartPointer::Arc)
175                    .weak(|| <alloc::sync::Weak<T> as Facet>::SHAPE)
176                    .vtable(
177                        &const {
178                            SmartPointerVTable::builder()
179                                .borrow_fn(|this| {
180                                    let ptr = Self::as_ptr(unsafe { this.get() });
181                                    PtrConst::new(ptr)
182                                })
183                                .new_into_fn(|this, ptr| {
184                                    use alloc::sync::Arc;
185
186                                    let layout = match T::SHAPE.layout {
187                                        ShapeLayout::Sized(layout) => layout,
188                                        ShapeLayout::Unsized => panic!("nope"),
189                                    };
190
191                                    let size_of_arc_header = core::mem::size_of::<usize>() * 2;
192
193                                    // we don't know the layout of dummy_arc, but we can tell its size and we can copy it
194                                    // in front of the `PtrMut`
195                                    let arc_layout = unsafe {
196                                        Layout::from_size_align_unchecked(
197                                            size_of_arc_header + layout.size(),
198                                            layout.align(),
199                                        )
200                                    };
201                                    let mem = unsafe { alloc::alloc::alloc(arc_layout) };
202
203                                    unsafe {
204                                        // Copy the Arc header (including refcounts, vtable pointers, etc.) from a freshly-allocated Arc<()>
205                                        // so that the struct before the T value is a valid Arc header.
206                                        let dummy_arc = alloc::sync::Arc::new(());
207                                        let header_start = (Arc::as_ptr(&dummy_arc) as *const u8)
208                                            .sub(size_of_arc_header);
209                                        core::ptr::copy_nonoverlapping(
210                                            header_start,
211                                            mem,
212                                            size_of_arc_header,
213                                        );
214
215                                        // Copy the value for T, pointed to by `ptr`, into the bytes just after the Arc header
216                                        core::ptr::copy_nonoverlapping(
217                                            ptr.as_byte_ptr(),
218                                            mem.add(size_of_arc_header),
219                                            layout.size(),
220                                        );
221                                    }
222
223                                    // Safety: `mem` is valid and contains a valid Arc header and valid T.
224                                    let ptr = unsafe { mem.add(size_of_arc_header) };
225                                    let t_ptr: *mut T = unsafe { core::mem::transmute_copy(&ptr) };
226                                    // Safety: This is the pointer to the Arc header + value; from_raw assumes a pointer to T located immediately after the Arc header.
227                                    let arc = unsafe { Arc::from_raw(t_ptr) };
228                                    // Move the Arc into the destination (this) and return a PtrMut for it.
229                                    unsafe { this.put(arc) }
230                                })
231                                .downgrade_into_fn(|strong, weak| unsafe {
232                                    weak.put(alloc::sync::Arc::downgrade(strong.get::<Self>()))
233                                })
234                                .build()
235                        },
236                    )
237                    .build(),
238            ))
239            .inner(inner_shape::<T>)
240            .build()
241    };
242}
243
244unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Weak<T> {
245    const VTABLE: &'static ValueVTable = &const {
246        value_vtable!(alloc::sync::Weak<T>, |f, opts| {
247            write!(f, "{}", Self::SHAPE.type_identifier)?;
248            if let Some(opts) = opts.for_children() {
249                write!(f, "<")?;
250                (T::SHAPE.vtable.type_name)(f, opts)?;
251                write!(f, ">")?;
252            } else {
253                write!(f, "<…>")?;
254            }
255            Ok(())
256        })
257    };
258
259    const SHAPE: &'static crate::Shape<'static> = &const {
260        // Function to return inner type's shape
261        fn inner_shape<'a, T: Facet<'a> + ?Sized>() -> &'static Shape<'static> {
262            T::SHAPE
263        }
264
265        crate::Shape::builder_for_sized::<Self>()
266            .type_identifier("Weak")
267            .type_params(&[crate::TypeParam {
268                name: "T",
269                shape: || T::SHAPE,
270            }])
271            .ty(Type::User(UserType::Opaque))
272            .def(Def::SmartPointer(
273                SmartPointerDef::builder()
274                    .pointee(|| T::SHAPE)
275                    .flags(SmartPointerFlags::ATOMIC.union(SmartPointerFlags::WEAK))
276                    .known(KnownSmartPointer::ArcWeak)
277                    .strong(|| <alloc::sync::Arc<T> as Facet>::SHAPE)
278                    .vtable(
279                        &const {
280                            SmartPointerVTable::builder()
281                                .upgrade_into_fn(|weak, strong| unsafe {
282                                    Some(strong.put(weak.get::<Self>().upgrade()?))
283                                })
284                                .build()
285                        },
286                    )
287                    .build(),
288            ))
289            .inner(inner_shape::<T>)
290            .build()
291    };
292}
293
// Tests driving the `Arc<T>` / `Weak<T>` vtables through raw pointers, the way
// reflection-based code would. Each test allocates storage via the shapes, runs
// the vtable functions, and frees everything it allocated.
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use alloc::sync::{Arc, Weak as ArcWeak};

    use super::*;

    // `Arc<i32>` exposes exactly one type parameter, and its shape is `i32`'s.
    #[test]
    fn test_arc_type_params() {
        let [type_param_1] = <Arc<i32>>::SHAPE.type_params else {
            panic!("Arc<T> should only have 1 type param")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
    }

    // new_into_fn -> borrow_fn -> drop_in_place -> deallocate.
    #[test]
    fn test_arc_vtable_1_new_borrow_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // Allocate memory for the Arc
        let arc_uninit_ptr = arc_shape.allocate()?;

        // Get the function pointer for creating a new Arc from a value
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        // Create the value and initialize the Arc
        let mut value = String::from("example");
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        // The value now belongs to the Arc, prevent its drop
        core::mem::forget(value);

        // Get the function pointer for borrowing the inner value
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");

        // Borrow the inner value and check it
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // Get the function pointer for dropping the Arc
        let drop_fn = (arc_shape.vtable.drop_in_place)().expect("Arc<T> should have drop_in_place");

        // Drop the Arc in place
        // SAFETY: arc_ptr points to a valid Arc<String>
        unsafe { drop_fn(arc_ptr) };

        // Deallocate the memory
        // SAFETY: arc_ptr was allocated by arc_shape and is now dropped (but memory is still valid)
        unsafe { arc_shape.deallocate_mut(arc_ptr)? };

        Ok(())
    }

    // Downgrade an Arc, upgrade the Weak while the Arc is alive, and check both
    // strong handles see the same value.
    #[test]
    fn test_arc_vtable_2_downgrade_upgrade_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the first Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc1

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr points to a valid Arc, weak1_uninit_ptr is allocated for a Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Upgrade weak1 to create a second Arc (arc2)
        let arc2_uninit_ptr = arc_shape.allocate()?;
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        // SAFETY: weak1_ptr points to a valid Weak, arc2_uninit_ptr is allocated for an Arc.
        // Upgrade should succeed as arc1 still exists.
        let arc2_ptr = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) }
            .expect("Upgrade should succeed while original Arc exists");

        // Check the content of the upgraded Arc
        let borrow_fn = arc_def.vtable.borrow_fn.unwrap();
        // SAFETY: arc2_ptr points to a valid Arc<String>
        let borrowed_ptr = unsafe { borrow_fn(arc2_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // 4. Drop everything and free memory
        let arc_drop_fn = (arc_shape.vtable.drop_in_place)().unwrap();
        let weak_drop_fn = (weak_shape.vtable.drop_in_place)().unwrap();

        unsafe {
            // Drop Arcs
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
            arc_drop_fn(arc2_ptr);
            arc_shape.deallocate_mut(arc2_ptr)?;

            // Drop Weak
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    // Once the only strong Arc is dropped, upgrading the Weak must fail and leave
    // the destination storage uninitialized.
    #[test]
    fn test_arc_vtable_3_downgrade_drop_try_upgrade() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the strong Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value);

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr is valid, weak1_uninit_ptr is allocated for Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Drop and free the strong pointer (arc1)
        let arc_drop_fn = (arc_shape.vtable.drop_in_place)().unwrap();
        unsafe {
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
        }

        // 4. Attempt to upgrade the weak pointer (weak1)
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        let arc2_uninit_ptr = arc_shape.allocate()?;
        // SAFETY: weak1_ptr is valid (though points to dropped data), arc2_uninit_ptr is allocated for Arc
        let upgrade_result = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) };

        // Assert that the upgrade failed
        assert!(
            upgrade_result.is_none(),
            "Upgrade should fail after the strong Arc is dropped"
        );

        // 5. Clean up: Deallocate the memory intended for the failed upgrade and drop/deallocate the weak pointer
        let weak_drop_fn = (weak_shape.vtable.drop_in_place)().unwrap();
        unsafe {
            // Deallocate the *uninitialized* memory allocated for the failed upgrade attempt
            arc_shape.deallocate_uninit(arc2_uninit_ptr)?;

            // Drop and deallocate the weak pointer
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    // `String -> Arc<String>` via the ValueVTable `try_from` hook. The source String
    // is moved into the Arc, so it is forgotten after a successful conversion.
    #[test]
    fn test_arc_vtable_4_try_from() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create a String value
        let value = String::from("try_from test");
        let value_ptr = PtrConst::new(&value as *const String as *const u8);

        // 2. Allocate memory for the Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;

        // 3. Get the try_from function from the Arc<String> shape's ValueVTable
        let try_from_fn = (arc_shape.vtable.try_from)().expect("Arc<T> should have try_from");

        // 4. Try to convert String to Arc<String>
        let arc_ptr = unsafe { try_from_fn(value_ptr, string_shape, arc_uninit_ptr) }
            .expect("try_from should succeed");
        // Ownership moved into the Arc; dropping `value` here would double-free.
        core::mem::forget(value);

        // 5. Borrow the inner value and verify it's correct
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };

        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "try_from test");

        // 6. Clean up
        let drop_fn = (arc_shape.vtable.drop_in_place)().expect("Arc<T> should have drop_in_place");

        unsafe {
            drop_fn(arc_ptr);
            arc_shape.deallocate_mut(arc_ptr)?;
        }

        Ok(())
    }

    // `Arc<String> -> String` via the ValueVTable `try_into_inner` hook, with a
    // uniquely-owned Arc so extraction is allowed.
    #[test]
    fn test_arc_vtable_5_try_into_inner() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create an Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        let mut value = String::from("try_into_inner test");
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc

        // 2. Allocate memory for the extracted String
        let string_uninit_ptr = string_shape.allocate()?;

        // 3. Get the try_into_inner function from the Arc<String>'s ValueVTable
        let try_into_inner_fn =
            (arc_shape.vtable.try_into_inner)().expect("Arc<T> Shape should have try_into_inner");

        // 4. Try to extract the String from the Arc<String>
        // This should succeed because we have exclusive access to the Arc (strong count = 1)
        let string_ptr = unsafe { try_into_inner_fn(arc_ptr, string_uninit_ptr) }
            .expect("try_into_inner should succeed with exclusive access");

        // 5. Verify the extracted String
        assert_eq!(
            unsafe { string_ptr.as_const().get::<String>() },
            "try_into_inner test"
        );

        // 6. Clean up
        let string_drop_fn =
            (string_shape.vtable.drop_in_place)().expect("String should have drop_in_place");

        unsafe {
            // try_into_inner consumed the Arc handle stored at arc_ptr, so it must NOT be
            // dropped again; only the slot obtained from arc_shape.allocate() is freed here.
            arc_shape.deallocate_mut(arc_ptr)?;

            // Drop and deallocate the extracted String
            string_drop_fn(string_ptr);
            string_shape.deallocate_mut(string_ptr)?;
        }

        Ok(())
    }
}