// facet_core/impls_alloc/arc.rs

1use core::alloc::Layout;
2
3use crate::{
4    Def, Facet, KnownSmartPointer, PtrConst, PtrMut, PtrUninit, Shape, ShapeLayout,
5    SmartPointerDef, SmartPointerFlags, SmartPointerVTable, TryBorrowInnerError, TryFromError,
6    TryIntoInnerError, Type, UserType, ValueVTable, value_vtable,
7};
8
9unsafe impl<'a, T: Facet<'a> + ?Sized> Facet<'a> for alloc::sync::Arc<T> {
10    const VTABLE: &'static ValueVTable = &const {
11        // Define the functions for transparent conversion between Arc<T> and T
12        unsafe fn try_from<'a, 'src, 'dst, T: Facet<'a> + ?Sized>(
13            src_ptr: PtrConst<'src>,
14            src_shape: &'static Shape,
15            dst: PtrUninit<'dst>,
16        ) -> Result<PtrMut<'dst>, TryFromError> {
17            if src_shape.id != T::SHAPE.id {
18                return Err(TryFromError::UnsupportedSourceShape {
19                    src_shape,
20                    expected: &[T::SHAPE],
21                });
22            }
23
24            if let ShapeLayout::Unsized = T::SHAPE.layout {
25                panic!("can't try_from with unsized type");
26            }
27
28            use alloc::sync::Arc;
29
30            // Get the layout for T
31            let layout = match T::SHAPE.layout {
32                ShapeLayout::Sized(layout) => layout,
33                ShapeLayout::Unsized => panic!("Unsized type not supported"),
34            };
35
36            // We'll create a new memory location, copy the value, then create an Arc from it
37            let size_of_arc_header = core::mem::size_of::<usize>() * 2;
38
39            // Allocate memory for the Arc header plus the value
40            let arc_layout = unsafe {
41                Layout::from_size_align_unchecked(
42                    size_of_arc_header + layout.size(),
43                    layout.align(),
44                )
45            };
46            let mem = unsafe { alloc::alloc::alloc(arc_layout) };
47
48            unsafe {
49                // Copy the Arc header (refcounts, vtable pointer, etc.) from a dummy Arc<()>
50                let dummy_arc = Arc::new(());
51                let header_start = (Arc::as_ptr(&dummy_arc) as *const u8).sub(size_of_arc_header);
52                core::ptr::copy_nonoverlapping(header_start, mem, size_of_arc_header);
53
54                // Copy the source value into the memory area just after the Arc header
55                core::ptr::copy_nonoverlapping(
56                    src_ptr.as_byte_ptr(),
57                    mem.add(size_of_arc_header),
58                    layout.size(),
59                );
60            }
61
62            // Create an Arc from our allocated and initialized memory
63            let ptr = unsafe { mem.add(size_of_arc_header) };
64            let t_ptr: *mut T = unsafe { core::mem::transmute_copy(&ptr) };
65            let arc = unsafe { Arc::from_raw(t_ptr) };
66
67            // Move the Arc into the destination and return a PtrMut for it
68            Ok(unsafe { dst.put(arc) })
69        }
70
71        unsafe fn try_into_inner<'a, 'src, 'dst, T: Facet<'a> + ?Sized>(
72            src_ptr: PtrMut<'src>,
73            dst: PtrUninit<'dst>,
74        ) -> Result<PtrMut<'dst>, TryIntoInnerError> {
75            use alloc::sync::Arc;
76
77            // Read the Arc from the source pointer
78            let mut arc = unsafe { src_ptr.read::<Arc<T>>() };
79
80            // For unsized types, we need to know how many bytes to copy.
81            let size = match T::SHAPE.layout {
82                ShapeLayout::Sized(layout) => layout.size(),
83                _ => panic!("cannot try_into_inner with unsized type"),
84            };
85
86            // Check if we have exclusive access to the Arc (strong count = 1)
87            if let Some(inner_ref) = Arc::get_mut(&mut arc) {
88                // We have exclusive access, so we can safely copy the inner value
89                let inner_ptr = inner_ref as *const T as *const u8;
90
91                unsafe {
92                    // Copy the inner value to the destination
93                    core::ptr::copy_nonoverlapping(inner_ptr, dst.as_mut_byte_ptr(), size);
94
95                    // Prevent dropping the Arc normally which would also drop the inner value
96                    // that we've already copied
97                    let raw_ptr = Arc::into_raw(arc);
98
99                    // We need to deallocate the Arc without running destructors
100                    // Get the Arc layout
101                    let size_of_arc_header = core::mem::size_of::<usize>() * 2;
102                    let layout = match T::SHAPE.layout {
103                        ShapeLayout::Sized(layout) => layout,
104                        _ => unreachable!("We already checked that T is sized"),
105                    };
106                    let arc_layout = Layout::from_size_align_unchecked(
107                        size_of_arc_header + size,
108                        layout.align(),
109                    );
110
111                    // Get the start of the allocation (header is before the data)
112                    let allocation_start = (raw_ptr as *mut u8).sub(size_of_arc_header);
113
114                    // Deallocate the memory without running any destructors
115                    alloc::alloc::dealloc(allocation_start, arc_layout);
116
117                    // Return a PtrMut to the destination, which now owns the value
118                    Ok(PtrMut::new(dst.as_mut_byte_ptr()))
119                }
120            } else {
121                // Arc is shared, so we can't extract the inner value
122                core::mem::forget(arc);
123                Err(TryIntoInnerError::Unavailable)
124            }
125        }
126
127        unsafe fn try_borrow_inner<'a, 'src, T: Facet<'a> + ?Sized>(
128            src_ptr: PtrConst<'src>,
129        ) -> Result<PtrConst<'src>, TryBorrowInnerError> {
130            let arc = unsafe { src_ptr.get::<alloc::sync::Arc<T>>() };
131            Ok(PtrConst::new(&**arc))
132        }
133
134        let mut vtable = value_vtable!(alloc::sync::Arc<T>, |f, opts| {
135            write!(f, "Arc")?;
136            if let Some(opts) = opts.for_children() {
137                write!(f, "<")?;
138                (T::SHAPE.vtable.type_name)(f, opts)?;
139                write!(f, ">")?;
140            } else {
141                write!(f, "<…>")?;
142            }
143            Ok(())
144        });
145
146        vtable.try_from = Some(try_from::<T>);
147        vtable.try_into_inner = Some(try_into_inner::<T>);
148        vtable.try_borrow_inner = Some(try_borrow_inner::<T>);
149        vtable
150    };
151
152    const SHAPE: &'static crate::Shape = &const {
153        // Function to return inner type's shape
154        fn inner_shape<'a, T: Facet<'a> + ?Sized>() -> &'static Shape {
155            T::SHAPE
156        }
157
158        crate::Shape::builder_for_sized::<Self>()
159            .type_params(&[crate::TypeParam {
160                name: "T",
161                shape: || T::SHAPE,
162            }])
163            .ty(Type::User(UserType::Opaque))
164            .def(Def::SmartPointer(
165                SmartPointerDef::builder()
166                    .pointee(|| T::SHAPE)
167                    .flags(SmartPointerFlags::ATOMIC)
168                    .known(KnownSmartPointer::Arc)
169                    .weak(|| <alloc::sync::Weak<T> as Facet>::SHAPE)
170                    .vtable(
171                        &const {
172                            SmartPointerVTable::builder()
173                                .borrow_fn(|this| {
174                                    let ptr = Self::as_ptr(unsafe { this.get() });
175                                    PtrConst::new(ptr)
176                                })
177                                .new_into_fn(|this, ptr| {
178                                    use alloc::sync::Arc;
179
180                                    let layout = match T::SHAPE.layout {
181                                        ShapeLayout::Sized(layout) => layout,
182                                        ShapeLayout::Unsized => panic!("nope"),
183                                    };
184
185                                    let size_of_arc_header = core::mem::size_of::<usize>() * 2;
186
187                                    // we don't know the layout of dummy_arc, but we can tell its size and we can copy it
188                                    // in front of the `PtrMut`
189                                    let arc_layout = unsafe {
190                                        Layout::from_size_align_unchecked(
191                                            size_of_arc_header + layout.size(),
192                                            layout.align(),
193                                        )
194                                    };
195                                    let mem = unsafe { alloc::alloc::alloc(arc_layout) };
196
197                                    unsafe {
198                                        // Copy the Arc header (including refcounts, vtable pointers, etc.) from a freshly-allocated Arc<()>
199                                        // so that the struct before the T value is a valid Arc header.
200                                        let dummy_arc = alloc::sync::Arc::new(());
201                                        let header_start = (Arc::as_ptr(&dummy_arc) as *const u8)
202                                            .sub(size_of_arc_header);
203                                        core::ptr::copy_nonoverlapping(
204                                            header_start,
205                                            mem,
206                                            size_of_arc_header,
207                                        );
208
209                                        // Copy the value for T, pointed to by `ptr`, into the bytes just after the Arc header
210                                        core::ptr::copy_nonoverlapping(
211                                            ptr.as_byte_ptr(),
212                                            mem.add(size_of_arc_header),
213                                            layout.size(),
214                                        );
215                                    }
216
217                                    // Safety: `mem` is valid and contains a valid Arc header and valid T.
218                                    let ptr = unsafe { mem.add(size_of_arc_header) };
219                                    let t_ptr: *mut T = unsafe { core::mem::transmute_copy(&ptr) };
220                                    // Safety: This is the pointer to the Arc header + value; from_raw assumes a pointer to T located immediately after the Arc header.
221                                    let arc = unsafe { Arc::from_raw(t_ptr) };
222                                    // Move the Arc into the destination (this) and return a PtrMut for it.
223                                    unsafe { this.put(arc) }
224                                })
225                                .downgrade_into_fn(|strong, weak| unsafe {
226                                    weak.put(alloc::sync::Arc::downgrade(strong.get::<Self>()))
227                                })
228                                .build()
229                        },
230                    )
231                    .build(),
232            ))
233            .inner(inner_shape::<T>)
234            .build()
235    };
236}
237
238unsafe impl<'a, T: Facet<'a> + ?Sized> Facet<'a> for alloc::sync::Weak<T> {
239    const VTABLE: &'static ValueVTable = &const {
240        value_vtable!(alloc::sync::Weak<T>, |f, opts| {
241            write!(f, "Weak")?;
242            if let Some(opts) = opts.for_children() {
243                write!(f, "<")?;
244                (T::SHAPE.vtable.type_name)(f, opts)?;
245                write!(f, ">")?;
246            } else {
247                write!(f, "<…>")?;
248            }
249            Ok(())
250        })
251    };
252
253    const SHAPE: &'static crate::Shape = &const {
254        // Function to return inner type's shape
255        fn inner_shape<'a, T: Facet<'a> + ?Sized>() -> &'static Shape {
256            T::SHAPE
257        }
258
259        crate::Shape::builder_for_sized::<Self>()
260            .type_params(&[crate::TypeParam {
261                name: "T",
262                shape: || T::SHAPE,
263            }])
264            .ty(Type::User(UserType::Opaque))
265            .def(Def::SmartPointer(
266                SmartPointerDef::builder()
267                    .pointee(|| T::SHAPE)
268                    .flags(SmartPointerFlags::ATOMIC.union(SmartPointerFlags::WEAK))
269                    .known(KnownSmartPointer::ArcWeak)
270                    .strong(|| <alloc::sync::Arc<T> as Facet>::SHAPE)
271                    .vtable(
272                        &const {
273                            SmartPointerVTable::builder()
274                                .upgrade_into_fn(|weak, strong| unsafe {
275                                    Some(strong.put(weak.get::<Self>().upgrade()?))
276                                })
277                                .build()
278                        },
279                    )
280                    .build(),
281            ))
282            .inner(inner_shape::<T>)
283            .build()
284    };
285}
286
#[cfg(test)]
mod tests {
    //! Exercises the `Arc<T>` / `Weak<T>` value and smart-pointer vtables.

    use alloc::string::String;
    use alloc::sync::{Arc, Weak as ArcWeak};

    use super::*;

    #[test]
    fn test_arc_type_params() {
        let [type_param_1] = <Arc<i32>>::SHAPE.type_params else {
            panic!("Arc<T> should only have 1 type param")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
    }

    #[test]
    fn test_arc_vtable_1_new_borrow_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // Allocate memory for the Arc
        let arc_uninit_ptr = arc_shape.allocate()?;

        // Get the function pointer for creating a new Arc from a value
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        // Create the value and initialize the Arc
        let mut value = String::from("example");
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        // The value now belongs to the Arc, prevent its drop
        core::mem::forget(value);

        // Get the function pointer for borrowing the inner value
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");

        // Borrow the inner value and check it
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // Get the function pointer for dropping the Arc
        let drop_fn = arc_shape
            .vtable
            .drop_in_place
            .expect("Arc<T> should have drop_in_place");

        // Drop the Arc in place
        // SAFETY: arc_ptr points to a valid Arc<String>
        unsafe { drop_fn(arc_ptr) };

        // Deallocate the memory
        // SAFETY: arc_ptr was allocated by arc_shape and is now dropped (but memory is still valid)
        unsafe { arc_shape.deallocate_mut(arc_ptr)? };

        Ok(())
    }

    #[test]
    fn test_arc_vtable_2_downgrade_upgrade_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the first Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc1

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr points to a valid Arc, weak1_uninit_ptr is allocated for a Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Upgrade weak1 to create a second Arc (arc2)
        let arc2_uninit_ptr = arc_shape.allocate()?;
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        // SAFETY: weak1_ptr points to a valid Weak, arc2_uninit_ptr is allocated for an Arc.
        // Upgrade should succeed as arc1 still exists.
        let arc2_ptr = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) }
            .expect("Upgrade should succeed while original Arc exists");

        // Check the content of the upgraded Arc
        let borrow_fn = arc_def.vtable.borrow_fn.unwrap();
        // SAFETY: arc2_ptr points to a valid Arc<String>
        let borrowed_ptr = unsafe { borrow_fn(arc2_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // 4. Drop everything and free memory
        let arc_drop_fn = arc_shape.vtable.drop_in_place.unwrap();
        let weak_drop_fn = weak_shape.vtable.drop_in_place.unwrap();

        unsafe {
            // Drop Arcs
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
            arc_drop_fn(arc2_ptr);
            arc_shape.deallocate_mut(arc2_ptr)?;

            // Drop Weak
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_3_downgrade_drop_try_upgrade() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the strong Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value);

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr is valid, weak1_uninit_ptr is allocated for Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Drop and free the strong pointer (arc1)
        let arc_drop_fn = arc_shape.vtable.drop_in_place.unwrap();
        unsafe {
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
        }

        // 4. Attempt to upgrade the weak pointer (weak1)
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        let arc2_uninit_ptr = arc_shape.allocate()?;
        // SAFETY: weak1_ptr is valid (though points to dropped data), arc2_uninit_ptr is allocated for Arc
        let upgrade_result = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) };

        // Assert that the upgrade failed
        assert!(
            upgrade_result.is_none(),
            "Upgrade should fail after the strong Arc is dropped"
        );

        // 5. Clean up: Deallocate the memory intended for the failed upgrade and drop/deallocate the weak pointer
        let weak_drop_fn = weak_shape.vtable.drop_in_place.unwrap();
        unsafe {
            // Deallocate the *uninitialized* memory allocated for the failed upgrade attempt
            arc_shape.deallocate_uninit(arc2_uninit_ptr)?;

            // Drop and deallocate the weak pointer
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_4_try_from() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create a String value
        let value = String::from("try_from test");
        let value_ptr = PtrConst::new(&value as *const String as *const u8);

        // 2. Allocate memory for the Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;

        // 3. Get the try_from function from the Arc<String> shape's ValueVTable
        let try_from_fn = arc_shape
            .vtable
            .try_from
            .expect("Arc<T> should have try_from");

        // 4. Try to convert String to Arc<String>
        let arc_ptr = unsafe { try_from_fn(value_ptr, string_shape, arc_uninit_ptr) }
            .expect("try_from should succeed");
        core::mem::forget(value);

        // 5. Borrow the inner value and verify it's correct
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };

        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "try_from test");

        // 6. Clean up
        let drop_fn = arc_shape
            .vtable
            .drop_in_place
            .expect("Arc<T> should have drop_in_place");

        unsafe {
            drop_fn(arc_ptr);
            arc_shape.deallocate_mut(arc_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_5_try_into_inner() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create an Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        let mut value = String::from("try_into_inner test");
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc

        // 2. Allocate memory for the extracted String
        let string_uninit_ptr = string_shape.allocate()?;

        // 3. Get the try_into_inner function from the Arc<String>'s ValueVTable
        let try_into_inner_fn = arc_shape
            .vtable
            .try_into_inner
            .expect("Arc<T> Shape should have try_into_inner");

        // 4. Try to extract the String from the Arc<String>
        // This should succeed because we have exclusive access to the Arc (strong count = 1)
        let string_ptr = unsafe { try_into_inner_fn(arc_ptr, string_uninit_ptr) }
            .expect("try_into_inner should succeed with exclusive access");

        // 5. Verify the extracted String
        assert_eq!(
            unsafe { string_ptr.as_const().get::<String>() },
            "try_into_inner test"
        );

        // 6. Clean up
        let string_drop_fn = string_shape
            .vtable
            .drop_in_place
            .expect("String should have drop_in_place");

        unsafe {
            // The Arc should already be dropped by try_into_inner
            // But we still need to deallocate its memory
            arc_shape.deallocate_mut(arc_ptr)?;

            // Drop and deallocate the extracted String
            string_drop_fn(string_ptr);
            string_shape.deallocate_mut(string_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_6_small_alignment_round_trip() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // `u8` has a smaller size and alignment than the Arc's reference counts. This exercises
        // the header/value layout that `new_into_fn` builds and `try_into_inner` frees.
        let u8_shape = u8::SHAPE;
        let arc_shape = <Arc<u8>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create an Arc<u8>
        let arc_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");
        let mut value: u8 = 42;
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };

        // 2. The value is readable through the Arc
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };
        assert_eq!(unsafe { *borrowed_ptr.get::<u8>() }, 42);

        // 3. Move the value back out; this frees the Arc's allocation
        let u8_uninit_ptr = u8_shape.allocate()?;
        let try_into_inner_fn = arc_shape
            .vtable
            .try_into_inner
            .expect("Arc<T> Shape should have try_into_inner");
        let u8_ptr = unsafe { try_into_inner_fn(arc_ptr, u8_uninit_ptr) }
            .expect("try_into_inner should succeed with exclusive access");
        assert_eq!(unsafe { *u8_ptr.as_const().get::<u8>() }, 42);

        // 4. Clean up (u8 has no destructor)
        unsafe {
            arc_shape.deallocate_mut(arc_ptr)?;
            u8_shape.deallocate_mut(u8_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_7_try_into_inner_unavailable_with_weak() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");
        let weak_shape = <ArcWeak<String>>::SHAPE;

        // 1. Create an Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("shared");
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by the Arc

        // 2. Keep a Weak alive so the Arc is no longer uniquely owned
        let weak_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        let weak_ptr = unsafe { downgrade_into_fn(arc_ptr, weak_uninit_ptr) };

        // 3. try_into_inner must refuse and leave the Arc untouched
        let string_uninit_ptr = string_shape.allocate()?;
        let try_into_inner_fn = arc_shape.vtable.try_into_inner.unwrap();
        let result = unsafe { try_into_inner_fn(arc_ptr, string_uninit_ptr) };
        assert!(
            matches!(result, Err(TryIntoInnerError::Unavailable)),
            "try_into_inner should be unavailable while a Weak exists"
        );

        let borrow_fn = arc_def.vtable.borrow_fn.unwrap();
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "shared");

        // 4. Clean up
        let arc_drop_fn = arc_shape.vtable.drop_in_place.unwrap();
        let weak_drop_fn = weak_shape.vtable.drop_in_place.unwrap();
        unsafe {
            // The String slot was never initialized
            string_shape.deallocate_uninit(string_uninit_ptr)?;

            arc_drop_fn(arc_ptr);
            arc_shape.deallocate_mut(arc_ptr)?;

            weak_drop_fn(weak_ptr);
            weak_shape.deallocate_mut(weak_ptr)?;
        }

        Ok(())
    }
}