facet_core/impls_alloc/rc.rs

1use crate::{
2    Def, Facet, KnownSmartPointer, PtrConst, PtrMut, PtrUninit, Shape, SmartPointerDef,
3    SmartPointerFlags, SmartPointerVTable, TryBorrowInnerError, TryFromError, TryIntoInnerError,
4    Type, UserType, ValueVTable, value_vtable,
5};
6
7unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::rc::Rc<T> {
8    const VTABLE: &'static ValueVTable = &const {
9        // Define the functions for transparent conversion between Rc<T> and T
10        unsafe fn try_from<'a, 'shape, 'src, 'dst, T: Facet<'a>>(
11            src_ptr: PtrConst<'src>,
12            src_shape: &'shape Shape<'shape>,
13            dst: PtrUninit<'dst>,
14        ) -> Result<PtrMut<'dst>, TryFromError<'shape>> {
15            if src_shape.id != T::SHAPE.id {
16                return Err(TryFromError::UnsupportedSourceShape {
17                    src_shape,
18                    expected: &[T::SHAPE],
19                });
20            }
21            let t = unsafe { src_ptr.read::<T>() };
22            let rc = alloc::rc::Rc::new(t);
23            Ok(unsafe { dst.put(rc) })
24        }
25
26        unsafe fn try_into_inner<'a, 'src, 'dst, T: Facet<'a>>(
27            src_ptr: PtrMut<'src>,
28            dst: PtrUninit<'dst>,
29        ) -> Result<PtrMut<'dst>, TryIntoInnerError> {
30            let rc = unsafe { src_ptr.get::<alloc::rc::Rc<T>>() };
31            match alloc::rc::Rc::try_unwrap(rc.clone()) {
32                Ok(t) => Ok(unsafe { dst.put(t) }),
33                Err(_) => Err(TryIntoInnerError::Unavailable),
34            }
35        }
36
37        unsafe fn try_borrow_inner<'a, 'src, T: Facet<'a>>(
38            src_ptr: PtrConst<'src>,
39        ) -> Result<PtrConst<'src>, TryBorrowInnerError> {
40            let rc = unsafe { src_ptr.get::<alloc::rc::Rc<T>>() };
41            Ok(PtrConst::new(&**rc))
42        }
43
44        let mut vtable = value_vtable!(alloc::rc::Rc<T>, |f, opts| {
45            write!(f, "{}", Self::SHAPE.type_identifier)?;
46            if let Some(opts) = opts.for_children() {
47                write!(f, "<")?;
48                (T::SHAPE.vtable.type_name)(f, opts)?;
49                write!(f, ">")?;
50            } else {
51                write!(f, "<…>")?;
52            }
53            Ok(())
54        });
55        vtable.try_from = || Some(try_from::<T>);
56        vtable.try_into_inner = || Some(try_into_inner::<T>);
57        vtable.try_borrow_inner = || Some(try_borrow_inner::<T>);
58        vtable
59    };
60
61    const SHAPE: &'static crate::Shape<'static> = &const {
62        // Function to return inner type's shape
63        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
64            T::SHAPE
65        }
66
67        crate::Shape::builder_for_sized::<Self>()
68            .type_identifier("Rc")
69            .type_params(&[crate::TypeParam {
70                name: "T",
71                shape: || T::SHAPE,
72            }])
73            .ty(Type::User(UserType::Opaque))
74            .def(Def::SmartPointer(
75                SmartPointerDef::builder()
76                    .pointee(|| T::SHAPE)
77                    .flags(SmartPointerFlags::EMPTY)
78                    .known(KnownSmartPointer::Rc)
79                    .weak(|| <alloc::rc::Weak<T> as Facet>::SHAPE)
80                    .vtable(
81                        &const {
82                            SmartPointerVTable::builder()
83                                .borrow_fn(|this| {
84                                    let ptr = Self::as_ptr(unsafe { this.get() });
85                                    PtrConst::new(ptr)
86                                })
87                                .new_into_fn(|this, ptr| {
88                                    let t = unsafe { ptr.read::<T>() };
89                                    let rc = alloc::rc::Rc::new(t);
90                                    unsafe { this.put(rc) }
91                                })
92                                .downgrade_into_fn(|strong, weak| unsafe {
93                                    weak.put(alloc::rc::Rc::downgrade(strong.get::<Self>()))
94                                })
95                                .build()
96                        },
97                    )
98                    .build(),
99            ))
100            .inner(inner_shape::<T>)
101            .build()
102    };
103}
104
105unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::rc::Weak<T> {
106    const VTABLE: &'static ValueVTable = &const {
107        value_vtable!(alloc::rc::Weak<T>, |f, opts| {
108            write!(f, "{}", Self::SHAPE.type_identifier)?;
109            if let Some(opts) = opts.for_children() {
110                write!(f, "<")?;
111                (T::SHAPE.vtable.type_name)(f, opts)?;
112                write!(f, ">")?;
113            } else {
114                write!(f, "<…>")?;
115            }
116            Ok(())
117        })
118    };
119
120    const SHAPE: &'static crate::Shape<'static> = &const {
121        // Function to return inner type's shape
122        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
123            T::SHAPE
124        }
125
126        crate::Shape::builder_for_sized::<Self>()
127            .type_identifier("Weak")
128            .type_params(&[crate::TypeParam {
129                name: "T",
130                shape: || T::SHAPE,
131            }])
132            .ty(Type::User(UserType::Opaque))
133            .def(Def::SmartPointer(
134                SmartPointerDef::builder()
135                    .pointee(|| T::SHAPE)
136                    .flags(SmartPointerFlags::WEAK)
137                    .known(KnownSmartPointer::RcWeak)
138                    .strong(|| <alloc::rc::Rc<T> as Facet>::SHAPE)
139                    .vtable(
140                        &const {
141                            SmartPointerVTable::builder()
142                                .upgrade_into_fn(|weak, strong| unsafe {
143                                    Some(strong.put(weak.get::<Self>().upgrade()?))
144                                })
145                                .build()
146                        },
147                    )
148                    .build(),
149            ))
150            .inner(inner_shape::<T>)
151            .build()
152    };
153}
154
#[cfg(test)]
mod tests {
    use alloc::rc::{Rc, Weak as RcWeak};
    use alloc::string::String;

    use super::*;

    #[test]
    fn test_rc_type_params() {
        // `Rc<T>` must expose exactly one generic parameter, resolving to `T`.
        match <Rc<i32>>::SHAPE.type_params {
            [only] => assert_eq!(only.shape(), i32::SHAPE),
            _ => panic!("Rc<T> should only have 1 type param"),
        }
    }

    /// Round-trips an `Rc<String>` through the vtable: construct, borrow, drop, free.
    #[test]
    fn test_rc_vtable_1_new_borrow_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let shape = <Rc<String>>::SHAPE;
        let sp = shape
            .def
            .into_smart_pointer()
            .expect("Rc<T> should have a smart pointer definition");

        // Resolve every vtable entry up front; none of these touch memory.
        let new_into = sp.vtable.new_into_fn.expect("Rc<T> should have new_into_fn");
        let borrow = sp.vtable.borrow_fn.expect("Rc<T> should have borrow_fn");
        let drop_rc = (shape.vtable.drop_in_place)().expect("Rc<T> should have drop_in_place");

        // Build the Rc in freshly allocated storage. The String's ownership moves
        // into the Rc, so the local must be forgotten rather than dropped.
        let storage = shape.allocate()?;
        let mut payload = String::from("example");
        let strong = unsafe { new_into(storage, PtrMut::new(&raw mut payload)) };
        core::mem::forget(payload);

        // SAFETY: `strong` holds the live Rc<String> built above, so the borrowed
        // pointer refers to its String.
        let inner = unsafe { borrow(strong.as_const()) };
        assert_eq!(unsafe { inner.get::<String>() }, "example");

        // SAFETY: `strong` holds a live Rc<String> allocated via `shape`. Drop it
        // once, then hand its (still valid) storage back exactly once.
        unsafe {
            drop_rc(strong);
            shape.deallocate_mut(strong)?;
        }

        Ok(())
    }

    /// Downgrade a live Rc, upgrade the Weak back, and check both strong handles.
    #[test]
    fn test_rc_vtable_2_downgrade_upgrade_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let rc_shape = <Rc<String>>::SHAPE;
        let rc_sp = rc_shape
            .def
            .into_smart_pointer()
            .expect("Rc<T> should have a smart pointer definition");

        let weak_shape = <RcWeak<String>>::SHAPE;
        let weak_sp = weak_shape
            .def
            .into_smart_pointer()
            .expect("RcWeak<T> should have a smart pointer definition");

        let new_into = rc_sp.vtable.new_into_fn.unwrap();
        let downgrade = rc_sp.vtable.downgrade_into_fn.unwrap();
        let upgrade = weak_sp.vtable.upgrade_into_fn.unwrap();
        let borrow = rc_sp.vtable.borrow_fn.unwrap();

        // 1. First strong pointer, owning the String.
        let first_storage = rc_shape.allocate()?;
        let mut payload = String::from("example");
        let first = unsafe { new_into(first_storage, PtrMut::new(&raw mut payload)) };
        core::mem::forget(payload);

        // 2. Weak pointer derived from `first`.
        let weak_storage = weak_shape.allocate()?;
        // SAFETY: `first` is a live Rc; `weak_storage` is sized for a Weak.
        let weak = unsafe { downgrade(first, weak_storage) };

        // 3. Second strong pointer, obtained by upgrading; `first` keeps it alive.
        let second_storage = rc_shape.allocate()?;
        // SAFETY: `weak` is a live Weak; `second_storage` is sized for an Rc.
        let second = unsafe { upgrade(weak, second_storage) }
            .expect("Upgrade should succeed while original Rc exists");

        // SAFETY: `second` is a live Rc<String>, so the borrow points at a String.
        let inner = unsafe { borrow(second.as_const()) };
        assert_eq!(unsafe { inner.get::<String>() }, "example");

        // 4. Tear everything down: both strong pointers, then the weak one.
        let drop_rc = (rc_shape.vtable.drop_in_place)().unwrap();
        let drop_weak = (weak_shape.vtable.drop_in_place)().unwrap();

        unsafe {
            for strong in [first, second] {
                drop_rc(strong);
                rc_shape.deallocate_mut(strong)?;
            }

            drop_weak(weak);
            weak_shape.deallocate_mut(weak)?;
        }

        Ok(())
    }

    /// Once the only strong pointer is gone, upgrading the Weak must fail.
    #[test]
    fn test_rc_vtable_3_downgrade_drop_try_upgrade() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let rc_shape = <Rc<String>>::SHAPE;
        let rc_sp = rc_shape
            .def
            .into_smart_pointer()
            .expect("Rc<T> should have a smart pointer definition");

        let weak_shape = <RcWeak<String>>::SHAPE;
        let weak_sp = weak_shape
            .def
            .into_smart_pointer()
            .expect("RcWeak<T> should have a smart pointer definition");

        let new_into = rc_sp.vtable.new_into_fn.unwrap();
        let downgrade = rc_sp.vtable.downgrade_into_fn.unwrap();
        let upgrade = weak_sp.vtable.upgrade_into_fn.unwrap();

        // 1. The sole strong pointer.
        let strong_storage = rc_shape.allocate()?;
        let mut payload = String::from("example");
        let strong = unsafe { new_into(strong_storage, PtrMut::new(&raw mut payload)) };
        core::mem::forget(payload);

        // 2. Weak pointer derived from it.
        let weak_storage = weak_shape.allocate()?;
        // SAFETY: `strong` is live; `weak_storage` is sized for a Weak.
        let weak = unsafe { downgrade(strong, weak_storage) };

        // 3. Destroy and free the only strong pointer.
        let drop_rc = (rc_shape.vtable.drop_in_place)().unwrap();
        unsafe {
            drop_rc(strong);
            rc_shape.deallocate_mut(strong)?;
        }

        // 4. The upgrade attempt must now come back empty.
        let target = rc_shape.allocate()?;
        // SAFETY: `weak` is still a valid Weak (its pointee is gone); `target` is
        // sized for an Rc.
        let outcome = unsafe { upgrade(weak, target) };

        assert!(
            outcome.is_none(),
            "Upgrade should fail after the strong Rc is dropped"
        );

        // 5. Free the never-initialized target, then drop and free the Weak.
        let drop_weak = (weak_shape.vtable.drop_in_place)().unwrap();
        unsafe {
            rc_shape.deallocate_uninit(target)?;

            drop_weak(weak);
            weak_shape.deallocate_mut(weak)?;
        }

        Ok(())
    }
}