// facet_core/impls_alloc/arc.rs

1use crate::{
2    Def, Facet, KnownSmartPointer, PtrConst, PtrMut, PtrUninit, Shape, SmartPointerDef,
3    SmartPointerFlags, SmartPointerVTable, TryBorrowInnerError, TryFromError, TryIntoInnerError,
4    Type, UserType, ValueVTable, value_vtable,
5};
6
7unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Arc<T> {
8    const VTABLE: &'static ValueVTable = &const {
9        // Define the functions for transparent conversion between Arc<T> and T
10        unsafe fn try_from<'a, 'shape, 'src, 'dst, T: Facet<'a>>(
11            src_ptr: PtrConst<'src>,
12            src_shape: &'shape Shape<'shape>,
13            dst: PtrUninit<'dst>,
14        ) -> Result<PtrMut<'dst>, TryFromError<'shape>> {
15            if src_shape.id != T::SHAPE.id {
16                return Err(TryFromError::UnsupportedSourceShape {
17                    src_shape,
18                    expected: &[T::SHAPE],
19                });
20            }
21            let t = unsafe { src_ptr.read::<T>() };
22            let arc = alloc::sync::Arc::new(t);
23            Ok(unsafe { dst.put(arc) })
24        }
25
26        unsafe fn try_into_inner<'a, 'src, 'dst, T: Facet<'a>>(
27            src_ptr: PtrMut<'src>,
28            dst: PtrUninit<'dst>,
29        ) -> Result<PtrMut<'dst>, TryIntoInnerError> {
30            use alloc::sync::Arc;
31
32            // Read the Arc from the source pointer
33            let arc = unsafe { src_ptr.read::<Arc<T>>() };
34
35            // Try to unwrap the Arc to get exclusive ownership
36            match Arc::try_unwrap(arc) {
37                Ok(inner) => Ok(unsafe { dst.put(inner) }),
38                Err(arc) => {
39                    // Arc is shared, so we can't extract the inner value
40                    core::mem::forget(arc);
41                    Err(TryIntoInnerError::Unavailable)
42                }
43            }
44        }
45
46        unsafe fn try_borrow_inner<'a, 'src, T: Facet<'a>>(
47            src_ptr: PtrConst<'src>,
48        ) -> Result<PtrConst<'src>, TryBorrowInnerError> {
49            let arc = unsafe { src_ptr.get::<alloc::sync::Arc<T>>() };
50            Ok(PtrConst::new(&**arc))
51        }
52
53        let mut vtable = value_vtable!(alloc::sync::Arc<T>, |f, opts| {
54            write!(f, "{}", Self::SHAPE.type_identifier)?;
55            if let Some(opts) = opts.for_children() {
56                write!(f, "<")?;
57                (T::SHAPE.vtable.type_name)(f, opts)?;
58                write!(f, ">")?;
59            } else {
60                write!(f, "<…>")?;
61            }
62            Ok(())
63        });
64
65        vtable.try_from = || Some(try_from::<T>);
66        vtable.try_into_inner = || Some(try_into_inner::<T>);
67        vtable.try_borrow_inner = || Some(try_borrow_inner::<T>);
68        vtable
69    };
70
71    const SHAPE: &'static crate::Shape<'static> = &const {
72        // Function to return inner type's shape
73        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
74            T::SHAPE
75        }
76
77        crate::Shape::builder_for_sized::<Self>()
78            .type_identifier("Arc")
79            .type_params(&[crate::TypeParam {
80                name: "T",
81                shape: || T::SHAPE,
82            }])
83            .ty(Type::User(UserType::Opaque))
84            .def(Def::SmartPointer(
85                SmartPointerDef::builder()
86                    .pointee(|| T::SHAPE)
87                    .flags(SmartPointerFlags::ATOMIC)
88                    .known(KnownSmartPointer::Arc)
89                    .weak(|| <alloc::sync::Weak<T> as Facet>::SHAPE)
90                    .vtable(
91                        &const {
92                            SmartPointerVTable::builder()
93                                .borrow_fn(|this| {
94                                    let arc_ptr = unsafe { this.as_ptr::<alloc::sync::Arc<T>>() };
95                                    let ptr = unsafe { alloc::sync::Arc::as_ptr(&*arc_ptr) };
96                                    PtrConst::new(ptr)
97                                })
98                                .new_into_fn(|this, ptr| {
99                                    let t = unsafe { ptr.read::<T>() };
100                                    let arc = alloc::sync::Arc::new(t);
101                                    unsafe { this.put(arc) }
102                                })
103                                .downgrade_into_fn(|strong, weak| unsafe {
104                                    weak.put(alloc::sync::Arc::downgrade(strong.get::<Self>()))
105                                })
106                                .build()
107                        },
108                    )
109                    .build(),
110            ))
111            .inner(inner_shape::<T>)
112            .build()
113    };
114}
115
116unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Weak<T> {
117    const VTABLE: &'static ValueVTable = &const {
118        value_vtable!(alloc::sync::Weak<T>, |f, opts| {
119            write!(f, "{}", Self::SHAPE.type_identifier)?;
120            if let Some(opts) = opts.for_children() {
121                write!(f, "<")?;
122                (T::SHAPE.vtable.type_name)(f, opts)?;
123                write!(f, ">")?;
124            } else {
125                write!(f, "<…>")?;
126            }
127            Ok(())
128        })
129    };
130
131    const SHAPE: &'static crate::Shape<'static> = &const {
132        // Function to return inner type's shape
133        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
134            T::SHAPE
135        }
136
137        crate::Shape::builder_for_sized::<Self>()
138            .type_identifier("Weak")
139            .type_params(&[crate::TypeParam {
140                name: "T",
141                shape: || T::SHAPE,
142            }])
143            .ty(Type::User(UserType::Opaque))
144            .def(Def::SmartPointer(
145                SmartPointerDef::builder()
146                    .pointee(|| T::SHAPE)
147                    .flags(SmartPointerFlags::ATOMIC.union(SmartPointerFlags::WEAK))
148                    .known(KnownSmartPointer::ArcWeak)
149                    .strong(|| <alloc::sync::Arc<T> as Facet>::SHAPE)
150                    .vtable(
151                        &const {
152                            SmartPointerVTable::builder()
153                                .upgrade_into_fn(|weak, strong| unsafe {
154                                    Some(strong.put(weak.get::<Self>().upgrade()?))
155                                })
156                                .build()
157                        },
158                    )
159                    .build(),
160            ))
161            .inner(inner_shape::<T>)
162            .build()
163    };
164}
165
// Exercises the `Arc<T>` / `Weak<T>` reflection through their type-erased vtables.
// This covers construction, borrowing, downgrade/upgrade, `try_from` (success and shape
// mismatch), `try_into_inner` (exclusive and shared) and `try_borrow_inner`.
#[cfg(test)]
mod tests {
    use alloc::string::String;
    use alloc::sync::{Arc, Weak as ArcWeak};

    use super::*;

    #[test]
    fn test_arc_type_params() {
        let [type_param_1] = <Arc<i32>>::SHAPE.type_params else {
            panic!("Arc<T> should only have 1 type param")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
    }

    #[test]
    fn test_arc_vtable_1_new_borrow_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // Allocate memory for the Arc
        let arc_uninit_ptr = arc_shape.allocate()?;

        // Get the function pointer for creating a new Arc from a value
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        // Create the value and move it into the Arc
        let mut value = String::from("example");
        // SAFETY: arc_uninit_ptr was allocated for an Arc<String>; value is a live String
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        // The value now belongs to the Arc, prevent its drop
        core::mem::forget(value);

        // Get the function pointer for borrowing the inner value
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");

        // Borrow the inner value and check it
        // SAFETY: arc_ptr points to a valid Arc<String>
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // Get the function pointer for dropping the Arc
        let drop_fn = (arc_shape.vtable.drop_in_place)().expect("Arc<T> should have drop_in_place");

        // Drop the Arc in place
        // SAFETY: arc_ptr points to a valid Arc<String>
        unsafe { drop_fn(arc_ptr) };

        // Deallocate the memory
        // SAFETY: arc_ptr was allocated by arc_shape and is now dropped (but memory is still valid)
        unsafe { arc_shape.deallocate_mut(arc_ptr)? };

        Ok(())
    }

    #[test]
    fn test_arc_vtable_2_downgrade_upgrade_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the first Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        // SAFETY: arc1_uninit_ptr was allocated for an Arc<String>; value is a live String
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc1

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr points to a valid Arc, weak1_uninit_ptr is allocated for a Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Upgrade weak1 to create a second Arc (arc2)
        let arc2_uninit_ptr = arc_shape.allocate()?;
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        // SAFETY: weak1_ptr points to a valid Weak, arc2_uninit_ptr is allocated for an Arc.
        // Upgrade should succeed as arc1 still exists.
        let arc2_ptr = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) }
            .expect("Upgrade should succeed while original Arc exists");

        // Check the content of the upgraded Arc
        let borrow_fn = arc_def.vtable.borrow_fn.unwrap();
        // SAFETY: arc2_ptr points to a valid Arc<String>
        let borrowed_ptr = unsafe { borrow_fn(arc2_ptr.as_const()) };
        // SAFETY: borrowed_ptr points to a valid String
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "example");

        // 4. Drop everything and free memory
        let arc_drop_fn = (arc_shape.vtable.drop_in_place)().unwrap();
        let weak_drop_fn = (weak_shape.vtable.drop_in_place)().unwrap();

        // SAFETY: every pointer below is initialized, dropped exactly once, then deallocated
        unsafe {
            // Drop Arcs
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
            arc_drop_fn(arc2_ptr);
            arc_shape.deallocate_mut(arc2_ptr)?;

            // Drop Weak
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_3_downgrade_drop_try_upgrade() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // 1. Create the strong Arc (arc1)
        let arc1_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def.vtable.new_into_fn.unwrap();
        let mut value = String::from("example");
        // SAFETY: arc1_uninit_ptr was allocated for an Arc<String>; value is a live String
        let arc1_ptr = unsafe { new_into_fn(arc1_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value);

        // 2. Downgrade arc1 to create a weak pointer (weak1)
        let weak1_uninit_ptr = weak_shape.allocate()?;
        let downgrade_into_fn = arc_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: arc1_ptr is valid, weak1_uninit_ptr is allocated for Weak
        let weak1_ptr = unsafe { downgrade_into_fn(arc1_ptr, weak1_uninit_ptr) };

        // 3. Drop and free the strong pointer (arc1)
        let arc_drop_fn = (arc_shape.vtable.drop_in_place)().unwrap();
        // SAFETY: arc1_ptr is a valid Arc that is dropped once and then deallocated
        unsafe {
            arc_drop_fn(arc1_ptr);
            arc_shape.deallocate_mut(arc1_ptr)?;
        }

        // 4. Attempt to upgrade the weak pointer (weak1)
        let upgrade_into_fn = weak_def.vtable.upgrade_into_fn.unwrap();
        let arc2_uninit_ptr = arc_shape.allocate()?;
        // SAFETY: weak1_ptr is valid (though points to dropped data), arc2_uninit_ptr is allocated for Arc
        let upgrade_result = unsafe { upgrade_into_fn(weak1_ptr, arc2_uninit_ptr) };

        // Assert that the upgrade failed
        assert!(
            upgrade_result.is_none(),
            "Upgrade should fail after the strong Arc is dropped"
        );

        // 5. Clean up: Deallocate the memory intended for the failed upgrade and drop/deallocate the weak pointer
        let weak_drop_fn = (weak_shape.vtable.drop_in_place)().unwrap();
        // SAFETY: arc2 storage was never initialized; weak1_ptr is a valid Weak dropped once
        unsafe {
            // Deallocate the *uninitialized* memory allocated for the failed upgrade attempt
            arc_shape.deallocate_uninit(arc2_uninit_ptr)?;

            // Drop and deallocate the weak pointer
            weak_drop_fn(weak1_ptr);
            weak_shape.deallocate_mut(weak1_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_4_try_from() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create a String value
        let value = String::from("try_from test");
        let value_ptr = PtrConst::new(&raw const value);

        // 2. Allocate memory for the Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;

        // 3. Get the try_from function from the Arc<String> shape's ValueVTable
        let try_from_fn = (arc_shape.vtable.try_from)().expect("Arc<T> should have try_from");

        // 4. Try to convert String to Arc<String>
        // SAFETY: value_ptr points to a live String whose ownership moves into the Arc
        let arc_ptr = unsafe { try_from_fn(value_ptr, string_shape, arc_uninit_ptr) }
            .expect("try_from should succeed");
        // The String was moved into the Arc by try_from; don't drop it twice
        core::mem::forget(value);

        // 5. Borrow the inner value and verify it's correct
        let borrow_fn = arc_def
            .vtable
            .borrow_fn
            .expect("Arc<T> should have borrow_fn");
        // SAFETY: arc_ptr points to a valid Arc<String>
        let borrowed_ptr = unsafe { borrow_fn(arc_ptr.as_const()) };

        // SAFETY: borrowed_ptr points to a valid String within the Arc
        assert_eq!(unsafe { borrowed_ptr.get::<String>() }, "try_from test");

        // 6. Clean up
        let drop_fn = (arc_shape.vtable.drop_in_place)().expect("Arc<T> should have drop_in_place");

        // SAFETY: arc_ptr is a valid Arc dropped once and then deallocated
        unsafe {
            drop_fn(arc_ptr);
            arc_shape.deallocate_mut(arc_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_5_try_into_inner() -> eyre::Result<()> {
        facet_testhelpers::setup();

        // Get the shapes we'll be working with
        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        // 1. Create an Arc<String>
        let arc_uninit_ptr = arc_shape.allocate()?;
        let new_into_fn = arc_def
            .vtable
            .new_into_fn
            .expect("Arc<T> should have new_into_fn");

        let mut value = String::from("try_into_inner test");
        // SAFETY: arc_uninit_ptr was allocated for an Arc<String>; value is a live String
        let arc_ptr = unsafe { new_into_fn(arc_uninit_ptr, PtrMut::new(&raw mut value)) };
        core::mem::forget(value); // Value now owned by arc

        // 2. Allocate memory for the extracted String
        let string_uninit_ptr = string_shape.allocate()?;

        // 3. Get the try_into_inner function from the Arc<String>'s ValueVTable
        let try_into_inner_fn =
            (arc_shape.vtable.try_into_inner)().expect("Arc<T> Shape should have try_into_inner");

        // 4. Try to extract the String from the Arc<String>
        // This should succeed because we have exclusive access to the Arc (strong count = 1)
        // SAFETY: arc_ptr is a valid Arc<String>; string_uninit_ptr is allocated for a String
        let string_ptr = unsafe { try_into_inner_fn(arc_ptr, string_uninit_ptr) }
            .expect("try_into_inner should succeed with exclusive access");

        // 5. Verify the extracted String
        assert_eq!(
            unsafe { string_ptr.as_const().get::<String>() },
            "try_into_inner test"
        );

        // 6. Clean up
        let string_drop_fn =
            (string_shape.vtable.drop_in_place)().expect("String should have drop_in_place");

        unsafe {
            // The Arc was consumed by try_into_inner, so it must not be dropped again;
            // only its storage still needs to be deallocated.
            arc_shape.deallocate_mut(arc_ptr)?;

            // Drop and deallocate the extracted String
            string_drop_fn(string_ptr);
            string_shape.deallocate_mut(string_ptr)?;
        }

        Ok(())
    }

    #[test]
    fn test_arc_vtable_6_try_into_inner_shared() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let string_shape = <String>::SHAPE;
        let arc_shape = <Arc<String>>::SHAPE;

        // 1. Create an Arc with a second strong reference, so it cannot be unwrapped
        let mut arc1 = Arc::new(String::from("shared"));
        let arc2 = Arc::clone(&arc1);

        // 2. Allocate memory the extracted String would have gone into
        let string_uninit_ptr = string_shape.allocate()?;

        let try_into_inner_fn =
            (arc_shape.vtable.try_into_inner)().expect("Arc<T> Shape should have try_into_inner");

        // 3. Extraction must fail because the Arc is shared
        // SAFETY: arc1 is a live Arc<String>; string_uninit_ptr is allocated for a String
        let result = unsafe { try_into_inner_fn(PtrMut::new(&raw mut arc1), string_uninit_ptr) };
        assert!(
            matches!(result, Err(TryIntoInnerError::Unavailable)),
            "try_into_inner should report Unavailable for a shared Arc"
        );

        // 4. The failed attempt must neither consume arc1 nor touch the refcount
        assert_eq!(Arc::strong_count(&arc1), 2);
        assert_eq!(*arc1, "shared");
        drop(arc2);
        assert_eq!(Arc::strong_count(&arc1), 1);

        // 5. Clean up the never-initialized destination
        // SAFETY: string_uninit_ptr was allocated by string_shape and never initialized
        unsafe { string_shape.deallocate_uninit(string_uninit_ptr)? };

        Ok(())
    }

    #[test]
    fn test_arc_vtable_7_try_borrow_inner() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc = Arc::new(String::from("try_borrow_inner test"));

        let try_borrow_inner_fn = (arc_shape.vtable.try_borrow_inner)()
            .expect("Arc<T> Shape should have try_borrow_inner");

        // SAFETY: the pointer refers to a live Arc<String> owned by this frame
        let result = unsafe { try_borrow_inner_fn(PtrConst::new(&raw const arc)) };
        let Ok(inner_ptr) = result else {
            panic!("try_borrow_inner should always succeed for Arc<T>");
        };

        // SAFETY: inner_ptr points to the String owned by `arc`, which is still alive
        let inner: &String = unsafe { inner_ptr.get::<String>() };
        assert_eq!(inner, "try_borrow_inner test");
        // The borrow must alias the Arc's allocation rather than a copy
        assert!(core::ptr::eq(inner, &*arc));
        assert_eq!(Arc::strong_count(&arc), 1);

        Ok(())
    }

    #[test]
    fn test_arc_vtable_8_try_from_wrong_shape() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let value: i32 = 42;
        let arc_uninit_ptr = arc_shape.allocate()?;

        let try_from_fn = (arc_shape.vtable.try_from)().expect("Arc<T> should have try_from");

        // SAFETY: the source is a live i32; a mismatched shape must be rejected before any read
        let result =
            unsafe { try_from_fn(PtrConst::new(&raw const value), i32::SHAPE, arc_uninit_ptr) };
        assert!(
            matches!(result, Err(TryFromError::UnsupportedSourceShape { .. })),
            "try_from should reject a source whose shape is not T"
        );

        // SAFETY: arc_uninit_ptr was allocated by arc_shape and never initialized
        unsafe { arc_shape.deallocate_uninit(arc_uninit_ptr)? };

        Ok(())
    }
}