// facet_core/impls_alloc/arc.rs

1use crate::{
2    Def, Facet, KnownSmartPointer, PtrConst, PtrMut, PtrUninit, Shape, SmartPointerDef,
3    SmartPointerFlags, SmartPointerVTable, TryBorrowInnerError, TryFromError, TryIntoInnerError,
4    Type, UserType, ValueVTable, value_vtable,
5};
6
7unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Arc<T> {
8    const VTABLE: &'static ValueVTable = &const {
9        // Define the functions for transparent conversion between Arc<T> and T
10        unsafe fn try_from<'a, 'shape, 'src, 'dst, T: Facet<'a>>(
11            src_ptr: PtrConst<'src>,
12            src_shape: &'shape Shape<'shape>,
13            dst: PtrUninit<'dst>,
14        ) -> Result<PtrMut<'dst>, TryFromError<'shape>> {
15            if src_shape.id != T::SHAPE.id {
16                return Err(TryFromError::UnsupportedSourceShape {
17                    src_shape,
18                    expected: &[T::SHAPE],
19                });
20            }
21            let t = unsafe { src_ptr.read::<T>() };
22            let arc = alloc::sync::Arc::new(t);
23            Ok(unsafe { dst.put(arc) })
24        }
25
26        unsafe fn try_into_inner<'a, 'src, 'dst, T: Facet<'a>>(
27            src_ptr: PtrMut<'src>,
28            dst: PtrUninit<'dst>,
29        ) -> Result<PtrMut<'dst>, TryIntoInnerError> {
30            use alloc::sync::Arc;
31
32            // Read the Arc from the source pointer
33            let arc = unsafe { src_ptr.read::<Arc<T>>() };
34
35            // Try to unwrap the Arc to get exclusive ownership
36            match Arc::try_unwrap(arc) {
37                Ok(inner) => Ok(unsafe { dst.put(inner) }),
38                Err(arc) => {
39                    // Arc is shared, so we can't extract the inner value
40                    core::mem::forget(arc);
41                    Err(TryIntoInnerError::Unavailable)
42                }
43            }
44        }
45
46        unsafe fn try_borrow_inner<'a, 'src, T: Facet<'a>>(
47            src_ptr: PtrConst<'src>,
48        ) -> Result<PtrConst<'src>, TryBorrowInnerError> {
49            let arc = unsafe { src_ptr.get::<alloc::sync::Arc<T>>() };
50            Ok(PtrConst::new(&**arc))
51        }
52
53        let mut vtable = value_vtable!(alloc::sync::Arc<T>, |f, opts| {
54            write!(f, "{}", Self::SHAPE.type_identifier)?;
55            if let Some(opts) = opts.for_children() {
56                write!(f, "<")?;
57                (T::SHAPE.vtable.type_name())(f, opts)?;
58                write!(f, ">")?;
59            } else {
60                write!(f, "<…>")?;
61            }
62            Ok(())
63        });
64
65        {
66            let vtable = vtable.sized_mut().unwrap();
67            vtable.try_from = || Some(try_from::<T>);
68            vtable.try_into_inner = || Some(try_into_inner::<T>);
69            vtable.try_borrow_inner = || Some(try_borrow_inner::<T>);
70        }
71        vtable
72    };
73
74    const SHAPE: &'static crate::Shape<'static> = &const {
75        // Function to return inner type's shape
76        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
77            T::SHAPE
78        }
79
80        crate::Shape::builder_for_sized::<Self>()
81            .type_identifier("Arc")
82            .type_params(&[crate::TypeParam {
83                name: "T",
84                shape: || T::SHAPE,
85            }])
86            .ty(Type::User(UserType::Opaque))
87            .def(Def::SmartPointer(
88                SmartPointerDef::builder()
89                    .pointee(|| T::SHAPE)
90                    .flags(SmartPointerFlags::ATOMIC)
91                    .known(KnownSmartPointer::Arc)
92                    .weak(|| <alloc::sync::Weak<T> as Facet>::SHAPE)
93                    .vtable(
94                        &const {
95                            SmartPointerVTable::builder()
96                                .borrow_fn(|this| {
97                                    let arc_ptr = unsafe { this.as_ptr::<alloc::sync::Arc<T>>() };
98                                    let ptr = unsafe { alloc::sync::Arc::as_ptr(&*arc_ptr) };
99                                    PtrConst::new(ptr)
100                                })
101                                .new_into_fn(|this, ptr| {
102                                    let t = unsafe { ptr.read::<T>() };
103                                    let arc = alloc::sync::Arc::new(t);
104                                    unsafe { this.put(arc) }
105                                })
106                                .downgrade_into_fn(|strong, weak| unsafe {
107                                    weak.put(alloc::sync::Arc::downgrade(strong.get::<Self>()))
108                                })
109                                .build()
110                        },
111                    )
112                    .build(),
113            ))
114            .inner(inner_shape::<T>)
115            .build()
116    };
117}
118
119unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Weak<T> {
120    const VTABLE: &'static ValueVTable = &const {
121        value_vtable!(alloc::sync::Weak<T>, |f, opts| {
122            write!(f, "{}", Self::SHAPE.type_identifier)?;
123            if let Some(opts) = opts.for_children() {
124                write!(f, "<")?;
125                (T::SHAPE.vtable.type_name())(f, opts)?;
126                write!(f, ">")?;
127            } else {
128                write!(f, "<…>")?;
129            }
130            Ok(())
131        })
132    };
133
134    const SHAPE: &'static crate::Shape<'static> = &const {
135        // Function to return inner type's shape
136        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
137            T::SHAPE
138        }
139
140        crate::Shape::builder_for_sized::<Self>()
141            .type_identifier("Weak")
142            .type_params(&[crate::TypeParam {
143                name: "T",
144                shape: || T::SHAPE,
145            }])
146            .ty(Type::User(UserType::Opaque))
147            .def(Def::SmartPointer(
148                SmartPointerDef::builder()
149                    .pointee(|| T::SHAPE)
150                    .flags(SmartPointerFlags::ATOMIC.union(SmartPointerFlags::WEAK))
151                    .known(KnownSmartPointer::ArcWeak)
152                    .strong(|| <alloc::sync::Arc<T> as Facet>::SHAPE)
153                    .vtable(
154                        &const {
155                            SmartPointerVTable::builder()
156                                .upgrade_into_fn(|weak, strong| unsafe {
157                                    Some(strong.put(weak.get::<Self>().upgrade()?))
158                                })
159                                .build()
160                        },
161                    )
162                    .build(),
163            ))
164            .inner(inner_shape::<T>)
165            .build()
166    };
167}
168
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use alloc::sync::{Arc, Weak as ArcWeak};

    use super::*;

    #[test]
    fn test_arc_type_params() {
        let [type_param_1] = <Arc<i32>>::SHAPE.type_params else {
            panic!("Arc<T> should only have 1 type param")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
    }

    #[test]
    fn test_arc_vtable_1_new_borrow_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // Allocate memory for the Arc
        let arc_uninit_ptr = arc_shape.allocate()?;

        // Get the function pointer for creating a new Arc from a value
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        // Create the value and initialize the Arc
        let mut value = String::from("example");
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        // The value now belongs to the Arc, prevent its drop
        core::mem::forget(value);

        // Get the function pointer for borrowing the inner value
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");

        // Borrow the inner value and check it
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // Get the function pointer for dropping the Arc
        let drop_fn = (arc_shape.vtable.sized().unwrap().drop_in_place)()
            .expect("Arc<T> should have drop_in_place");

        // Drop the Arc in place
        // SAFETY: arc_ptr points to a valid Arc<String>
        unsafe { drop_fn(arc_ptr) };

        // Deallocate the memory
        // SAFETY: arc_ptr was allocated by arc_shape and is now dropped (but memory is still valid)
        unsafe { arc_shape.deallocate_mut(arc_ptr)? };

        Ok(())
    }

    #[test]
    fn test_arc_vtable_2_downgrade_upgrade_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the first Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc1

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr points to a valid Arc, weak1_uninit_ptr is allocated for a Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Upgrade weak1 to create a second Arc (arc2)
        let arc2_uninit_ptr = arc_shape.allocate()?;
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        // SAFETY: weak1_ptr points to a valid Weak, arc2_uninit_ptr is allocated for an Arc.
        // Upgrade should succeed as arc1 still exists.
        let arc2_ptr = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) }
            .expect("Upgrade should succeed while original Arc exists");

        // Check the content of the upgraded Arc
        let borrow_fn = arc_def.vtable.borrow_fn.unwrap();
        // SAFETY: arc2_ptr points to a valid Arc<String>
        let borrowed_ptr = unsafe { borrow_fn(arc2_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // 4. Drop everything and free memory
        let arc_drop_fn = (arc_shape.vtable.sized().unwrap().drop_in_place)().unwrap();
        let weak_drop_fn = (weak_shape.vtable.sized().unwrap().drop_in_place)().unwrap();

        unsafe {
            // Drop Arcs
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
            arc_drop_fn(arc2_ptr);
            arc_shape.deallocate_mut(arc2_ptr)?;

            // Drop Weak
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_3_downgrade_drop_try_upgrade() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the strong Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value);

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr is valid, weak1_uninit_ptr is allocated for Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Drop and free the strong pointer (arc1)
        let arc_drop_fn = (arc_shape.vtable.sized().unwrap().drop_in_place)().unwrap();
        unsafe {
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
        }

        // 4. Attempt to upgrade the weak pointer (weak1)
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        let arc2_uninit_ptr = arc_shape.allocate()?;
        // SAFETY: weak1_ptr is a valid Weak (its pointee has been dropped), arc2_uninit_ptr is allocated for Arc
        let upgrade_result = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) };

        // Assert that the upgrade failed
        assert!(
            upgrade_result.is_none(),
            "Upgrade should fail after the strong Arc is dropped"
        );

        // 5. Clean up: Deallocate the memory intended for the failed upgrade and drop/deallocate the weak pointer
        let weak_drop_fn = (weak_shape.vtable.sized().unwrap().drop_in_place)().unwrap();
        unsafe {
            // Deallocate the *uninitialized* memory allocated for the failed upgrade attempt
            arc_shape.deallocate_uninit(arc2_uninit_ptr)?;

            // Drop and deallocate the weak pointer
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_4_try_from() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create a String value
        let value = String::from("try_from test");
        let value_ptr = PtrConst::new(&raw const value);

        // 2. Allocate memory for the Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;

        // 3. Get the try_from function from the Arc<String> shape's ValueVTable
        let try_from_fn =
            (arc_shape.vtable.sized().unwrap().try_from)().expect("Arc<T> should have try_from");

        // 4. Try to convert String to Arc<String>
        let arc_ptr = unsafe { try_from_fn(value_ptr, string_shape, arc_uninit_ptr) }
            .expect("try_from should succeed");
        // The String was moved into the Arc, prevent a double drop
        core::mem::forget(value);

        // 5. Borrow the inner value and verify it's correct
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };

        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "try_from test");

        // 6. Clean up
        let drop_fn = (arc_shape.vtable.sized().unwrap().drop_in_place)()
            .expect("Arc<T> should have drop_in_place");

        unsafe {
            drop_fn(arc_ptr);
            arc_shape.deallocate_mut(arc_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_5_try_into_inner() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create an Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        let mut value = String::from("try_into_inner test");
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc

        // 2. Allocate memory for the extracted String
        let string_uninit_ptr = string_shape.allocate()?;

        // 3. Get the try_into_inner function from the Arc<String>'s ValueVTable
        let try_into_inner_fn = (arc_shape.vtable.sized().unwrap().try_into_inner)()
            .expect("Arc<T> Shape should have try_into_inner");

        // 4. Try to extract the String from the Arc<String>
        // This should succeed because we have exclusive access to the Arc (strong count = 1)
        let string_ptr = unsafe { try_into_inner_fn(arc_ptr, string_uninit_ptr) }
            .expect("try_into_inner should succeed with exclusive access");

        // 5. Verify the extracted String
        assert_eq!(
            unsafe { string_ptr.as_const().get::<String>() },
            "try_into_inner test"
        );

        // 6. Clean up
        let string_drop_fn = (string_shape.vtable.sized().unwrap().drop_in_place)()
            .expect("String should have drop_in_place");

        unsafe {
            // On success try_into_inner consumed the Arc (moved it out of arc_ptr),
            // so only its backing memory is left to deallocate.
            arc_shape.deallocate_mut(arc_ptr)?;

            // Drop and deallocate the extracted String
            string_drop_fn(string_ptr);
            string_shape.deallocate_mut(string_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_6_try_into_inner_shared() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;

        // Two strong references: try_into_inner must refuse to unwrap and must leave the
        // source Arc (and the shared strong count) untouched.
        let mut arc1 = Arc::new(String::from("shared"));
        let arc2 = Arc::clone(&arc1);

        let string_uninit_ptr = string_shape.allocate()?;
        let try_into_inner_fn = (arc_shape.vtable.sized().unwrap().try_into_inner)()
            .expect("Arc<T> Shape should have try_into_inner");

        // SAFETY: arc1 is an initialized Arc<String>; string_uninit_ptr is allocated for a String
        let result = unsafe { try_into_inner_fn(PtrMut::new(&raw mut arc1), string_uninit_ptr) };
        assert!(
            matches!(result, Err(TryIntoInnerError::Unavailable)),
            "try_into_inner should fail while the Arc is shared"
        );

        // The failed attempt must neither consume arc1 nor release its reference.
        assert_eq!(Arc::strong_count(&arc1), 2);
        assert_eq!(*arc1, "shared");
        assert_eq!(*arc2, "shared");

        // SAFETY: string_uninit_ptr was never initialized by the failed call
        unsafe { string_shape.deallocate_uninit(string_uninit_ptr)? };

        Ok(())
    }

    #[test]
    fn test_arc_vtable_7_try_from_wrong_shape() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let value: i32 = 42;
        let value_ptr = PtrConst::new(&raw const value);

        let arc_uninit_ptr = arc_shape.allocate()?;
        let try_from_fn =
            (arc_shape.vtable.sized().unwrap().try_from)().expect("Arc<T> should have try_from");

        // An i32 source is not a String, so the conversion must be rejected.
        let result = unsafe { try_from_fn(value_ptr, i32::SHAPE, arc_uninit_ptr) };
        assert!(
            matches!(result, Err(TryFromError::UnsupportedSourceShape { .. })),
            "try_from should reject a source shape other than T"
        );

        // SAFETY: arc_uninit_ptr was never initialized by the rejected call
        unsafe { arc_shape.deallocate_uninit(arc_uninit_ptr)? };

        Ok(())
    }

    #[test]
    fn test_arc_vtable_8_try_borrow_inner() {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc = Arc::new(String::from("borrowed"));

        let try_borrow_inner_fn = (arc_shape.vtable.sized().unwrap().try_borrow_inner)()
            .expect("Arc<T> should have try_borrow_inner");

        // SAFETY: the pointer refers to an initialized Arc<String> that outlives the borrow
        let Ok(inner_ptr) = (unsafe { try_borrow_inner_fn(PtrConst::new(&raw const arc)) }) else {
            panic!("try_borrow_inner should always succeed for Arc");
        };

        // SAFETY: inner_ptr points at the String owned by `arc`
        let inner = unsafe { inner_ptr.get::<String>() };
        assert_eq!(inner, "borrowed");
        // The borrow must point into the Arc's allocation, not at a copy.
        assert!(core::ptr::eq(inner, Arc::as_ptr(&arc)));
    }
}