facet_core/impls_alloc/
rc.rs

1use crate::{
2    Def, Facet, KnownSmartPointer, PtrConst, PtrMut, PtrUninit, Shape, SmartPointerDef,
3    SmartPointerFlags, SmartPointerVTable, TryBorrowInnerError, TryFromError, TryIntoInnerError,
4    Type, UserType, ValueVTable, value_vtable,
5};
6
7unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::rc::Rc<T> {
8    const VTABLE: &'static ValueVTable = &const {
9        // Define the functions for transparent conversion between Rc<T> and T
10        unsafe fn try_from<'a, 'shape, 'src, 'dst, T: Facet<'a>>(
11            src_ptr: PtrConst<'src>,
12            src_shape: &'shape Shape<'shape>,
13            dst: PtrUninit<'dst>,
14        ) -> Result<PtrMut<'dst>, TryFromError<'shape>> {
15            if src_shape.id != T::SHAPE.id {
16                return Err(TryFromError::UnsupportedSourceShape {
17                    src_shape,
18                    expected: &[T::SHAPE],
19                });
20            }
21            let t = unsafe { src_ptr.read::<T>() };
22            let rc = alloc::rc::Rc::new(t);
23            Ok(unsafe { dst.put(rc) })
24        }
25
26        unsafe fn try_into_inner<'a, 'src, 'dst, T: Facet<'a>>(
27            src_ptr: PtrMut<'src>,
28            dst: PtrUninit<'dst>,
29        ) -> Result<PtrMut<'dst>, TryIntoInnerError> {
30            let rc = unsafe { src_ptr.get::<alloc::rc::Rc<T>>() };
31            match alloc::rc::Rc::try_unwrap(rc.clone()) {
32                Ok(t) => Ok(unsafe { dst.put(t) }),
33                Err(_) => Err(TryIntoInnerError::Unavailable),
34            }
35        }
36
37        unsafe fn try_borrow_inner<'a, 'src, T: Facet<'a>>(
38            src_ptr: PtrConst<'src>,
39        ) -> Result<PtrConst<'src>, TryBorrowInnerError> {
40            let rc = unsafe { src_ptr.get::<alloc::rc::Rc<T>>() };
41            Ok(PtrConst::new(&**rc))
42        }
43
44        let mut vtable = value_vtable!(alloc::rc::Rc<T>, |f, opts| {
45            write!(f, "Rc")?;
46            if let Some(opts) = opts.for_children() {
47                write!(f, "<")?;
48                (T::SHAPE.vtable.type_name)(f, opts)?;
49                write!(f, ">")?;
50            } else {
51                write!(f, "<…>")?;
52            }
53            Ok(())
54        });
55        vtable.try_from = Some(try_from::<T>);
56        vtable.try_into_inner = Some(try_into_inner::<T>);
57        vtable.try_borrow_inner = Some(try_borrow_inner::<T>);
58        vtable
59    };
60
61    const SHAPE: &'static crate::Shape<'static> = &const {
62        // Function to return inner type's shape
63        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
64            T::SHAPE
65        }
66
67        crate::Shape::builder_for_sized::<Self>()
68            .type_params(&[crate::TypeParam {
69                name: "T",
70                shape: || T::SHAPE,
71            }])
72            .ty(Type::User(UserType::Opaque))
73            .def(Def::SmartPointer(
74                SmartPointerDef::builder()
75                    .pointee(|| T::SHAPE)
76                    .flags(SmartPointerFlags::EMPTY)
77                    .known(KnownSmartPointer::Rc)
78                    .weak(|| <alloc::rc::Weak<T> as Facet>::SHAPE)
79                    .vtable(
80                        &const {
81                            SmartPointerVTable::builder()
82                                .borrow_fn(|this| {
83                                    let ptr = Self::as_ptr(unsafe { this.get() });
84                                    PtrConst::new(ptr)
85                                })
86                                .new_into_fn(|this, ptr| {
87                                    let t = unsafe { ptr.read::<T>() };
88                                    let rc = alloc::rc::Rc::new(t);
89                                    unsafe { this.put(rc) }
90                                })
91                                .downgrade_into_fn(|strong, weak| unsafe {
92                                    weak.put(alloc::rc::Rc::downgrade(strong.get::<Self>()))
93                                })
94                                .build()
95                        },
96                    )
97                    .build(),
98            ))
99            .inner(inner_shape::<T>)
100            .build()
101    };
102}
103
104unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::rc::Weak<T> {
105    const VTABLE: &'static ValueVTable = &const {
106        value_vtable!(alloc::rc::Weak<T>, |f, opts| {
107            write!(f, "Weak")?;
108            if let Some(opts) = opts.for_children() {
109                write!(f, "<")?;
110                (T::SHAPE.vtable.type_name)(f, opts)?;
111                write!(f, ">")?;
112            } else {
113                write!(f, "<…>")?;
114            }
115            Ok(())
116        })
117    };
118
119    const SHAPE: &'static crate::Shape<'static> = &const {
120        // Function to return inner type's shape
121        fn inner_shape<'a, T: Facet<'a>>() -> &'static Shape<'static> {
122            T::SHAPE
123        }
124
125        crate::Shape::builder_for_sized::<Self>()
126            .type_params(&[crate::TypeParam {
127                name: "T",
128                shape: || T::SHAPE,
129            }])
130            .ty(Type::User(UserType::Opaque))
131            .def(Def::SmartPointer(
132                SmartPointerDef::builder()
133                    .pointee(|| T::SHAPE)
134                    .flags(SmartPointerFlags::WEAK)
135                    .known(KnownSmartPointer::RcWeak)
136                    .strong(|| <alloc::rc::Rc<T> as Facet>::SHAPE)
137                    .vtable(
138                        &const {
139                            SmartPointerVTable::builder()
140                                .upgrade_into_fn(|weak, strong| unsafe {
141                                    Some(strong.put(weak.get::<Self>().upgrade()?))
142                                })
143                                .build()
144                        },
145                    )
146                    .build(),
147            ))
148            .inner(inner_shape::<T>)
149            .build()
150    };
151}
152
#[cfg(test)]
mod tests {
    use alloc::rc::{Rc, Weak as RcWeak};
    use alloc::string::String;

    use super::*;

    // `Rc<T>` must expose exactly one type parameter, and it must be `T`'s shape.
    #[test]
    fn test_rc_type_params() {
        match <Rc<i32>>::SHAPE.type_params {
            [only_param] => assert_eq!(only_param.shape(), i32::SHAPE),
            _ => panic!("Rc<T> should only have 1 type param"),
        }
    }

    // Build an `Rc<String>` through the vtable, borrow its contents, then drop and free it.
    #[test]
    fn test_rc_vtable_1_new_borrow_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let shape = <Rc<String>>::SHAPE;
        let def = shape
            .def
            .into_smart_pointer()
            .expect("Rc<T> should have a smart pointer definition");

        // Backing storage for the Rc handle itself.
        let uninit = shape.allocate()?;

        // Constructor taken from the smart-pointer vtable.
        let make_rc = def
            .vtable
            .new_into_fn
            .expect("Rc<T> should have new_into_fn");

        // Move a String into a brand-new Rc.
        let mut payload = String::from("example");
        let rc = unsafe { make_rc(uninit, PtrMut::new(&raw mut payload)) };
        // Ownership of the String moved into the Rc; skip the local's destructor.
        core::mem::forget(payload);

        // Accessor for the pointee.
        let borrow = def
            .vtable
            .borrow_fn
            .expect("Rc<T> should have borrow_fn");

        // The borrowed pointer must see the String we put in.
        let pointee = unsafe { borrow(rc.as_const()) };
        // SAFETY: `pointee` refers to the String living inside the Rc.
        assert_eq!(unsafe { pointee.get::<String>() }, "example");

        // Destructor taken from the value vtable.
        let drop_rc = shape
            .vtable
            .drop_in_place
            .expect("Rc<T> should have drop_in_place");

        // SAFETY: `rc` refers to a live Rc<String>.
        unsafe { drop_rc(rc) };

        // SAFETY: the storage came from `shape.allocate()` and its contents were just dropped.
        unsafe { shape.deallocate_mut(rc)? };

        Ok(())
    }

    // Downgrade an Rc to a Weak, upgrade it back while the Rc is alive, then clean up.
    #[test]
    fn test_rc_vtable_2_downgrade_upgrade_drop() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let strong_shape = <Rc<String>>::SHAPE;
        let strong_def = strong_shape
            .def
            .into_smart_pointer()
            .expect("Rc<T> should have a smart pointer definition");

        let weak_shape = <RcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("RcWeak<T> should have a smart pointer definition");

        // Step 1: first strong handle.
        let first_uninit = strong_shape.allocate()?;
        let make_rc = strong_def.vtable.new_into_fn.unwrap();
        let mut payload = String::from("example");
        let first = unsafe { make_rc(first_uninit, PtrMut::new(&raw mut payload)) };
        // The String is owned by `first` from here on.
        core::mem::forget(payload);

        // Step 2: weak handle derived from the first strong handle.
        let weak_uninit = weak_shape.allocate()?;
        let downgrade = strong_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: `first` is a live Rc and `weak_uninit` was allocated for a Weak.
        let weak = unsafe { downgrade(first, weak_uninit) };

        // Step 3: second strong handle obtained by upgrading the weak one.
        let second_uninit = strong_shape.allocate()?;
        let upgrade = weak_def.vtable.upgrade_into_fn.unwrap();
        // SAFETY: `weak` is a live Weak and `second_uninit` was allocated for an Rc.
        // `first` has not been dropped yet, so the upgrade is expected to work.
        let second = unsafe { upgrade(weak, second_uninit) }
            .expect("Upgrade should succeed while original Rc exists");

        // The upgraded handle must expose the same String.
        let borrow = strong_def.vtable.borrow_fn.unwrap();
        // SAFETY: `second` is a live Rc<String>.
        let pointee = unsafe { borrow(second.as_const()) };
        // SAFETY: `pointee` refers to a valid String.
        assert_eq!(unsafe { pointee.get::<String>() }, "example");

        // Step 4: tear everything down.
        let drop_strong = strong_shape.vtable.drop_in_place.unwrap();
        let drop_weak = weak_shape.vtable.drop_in_place.unwrap();

        unsafe {
            // Strong handles first...
            drop_strong(first);
            strong_shape.deallocate_mut(first)?;
            drop_strong(second);
            strong_shape.deallocate_mut(second)?;

            // ...then the weak handle.
            drop_weak(weak);
            weak_shape.deallocate_mut(weak)?;
        }

        Ok(())
    }

    // Upgrading a Weak after its only Rc was dropped must yield `None`.
    #[test]
    fn test_rc_vtable_3_downgrade_drop_try_upgrade() -> eyre::Result<()> {
        facet_testhelpers::setup();

        let strong_shape = <Rc<String>>::SHAPE;
        let strong_def = strong_shape
            .def
            .into_smart_pointer()
            .expect("Rc<T> should have a smart pointer definition");

        let weak_shape = <RcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("RcWeak<T> should have a smart pointer definition");

        // Step 1: the only strong handle.
        let strong_uninit = strong_shape.allocate()?;
        let make_rc = strong_def.vtable.new_into_fn.unwrap();
        let mut payload = String::from("example");
        let strong = unsafe { make_rc(strong_uninit, PtrMut::new(&raw mut payload)) };
        core::mem::forget(payload);

        // Step 2: weak handle derived from it.
        let weak_uninit = weak_shape.allocate()?;
        let downgrade = strong_def.vtable.downgrade_into_fn.unwrap();
        // SAFETY: `strong` is a live Rc and `weak_uninit` was allocated for a Weak.
        let weak = unsafe { downgrade(strong, weak_uninit) };

        // Step 3: drop the strong handle and release its storage.
        let drop_strong = strong_shape.vtable.drop_in_place.unwrap();
        unsafe {
            drop_strong(strong);
            strong_shape.deallocate_mut(strong)?;
        }

        // Step 4: the weak handle can no longer be upgraded.
        let upgrade = weak_def.vtable.upgrade_into_fn.unwrap();
        let retry_uninit = strong_shape.allocate()?;
        // SAFETY: `weak` is still a live Weak (its target has been dropped) and
        // `retry_uninit` was allocated for an Rc.
        let outcome = unsafe { upgrade(weak, retry_uninit) };

        assert!(
            outcome.is_none(),
            "Upgrade should fail after the strong Rc is dropped"
        );

        // Step 5: release the never-initialized Rc slot, then drop and free the weak handle.
        let drop_weak = weak_shape.vtable.drop_in_place.unwrap();
        unsafe {
            // Nothing was written here because the upgrade failed.
            strong_shape.deallocate_uninit(retry_uninit)?;

            drop_weak(weak);
            weak_shape.deallocate_mut(weak)?;
        }

        Ok(())
    }
}
345}