facet_core/impls_alloc/
smartptr.rs

1use core::alloc::Layout;
2
3use crate::{
4    ConstTypeId, Def, Facet, KnownSmartPointer, Opaque, PtrConst, SmartPointerDef,
5    SmartPointerFlags, SmartPointerVTable, value_vtable,
6};
7
8unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Arc<T> {
9    const SHAPE: &'static crate::Shape = &const {
10        crate::Shape::builder()
11            .id(ConstTypeId::of::<Self>())
12            .layout(Layout::new::<Self>())
13            .type_params(&[crate::TypeParam {
14                name: "T",
15                shape: || T::SHAPE,
16            }])
17            .def(Def::SmartPointer(
18                SmartPointerDef::builder()
19                    .pointee(T::SHAPE)
20                    .flags(SmartPointerFlags::ATOMIC)
21                    .known(KnownSmartPointer::Arc)
22                    .weak(|| <alloc::sync::Weak<T> as Facet>::SHAPE)
23                    .vtable(
24                        &const {
25                            SmartPointerVTable::builder()
26                                .borrow_fn(|this| {
27                                    let ptr = Self::as_ptr(unsafe { this.get() });
28                                    PtrConst::new(ptr)
29                                })
30                                .new_into_fn(|this, ptr| {
31                                    let t = unsafe { ptr.read::<T>() };
32                                    let arc = alloc::sync::Arc::new(t);
33                                    unsafe { this.put(arc) }
34                                })
35                                .downgrade_into_fn(|strong, weak| unsafe {
36                                    weak.put(alloc::sync::Arc::downgrade(strong.get::<Self>()))
37                                })
38                                .build()
39                        },
40                    )
41                    .build(),
42            ))
43            .vtable(value_vtable!(alloc::sync::Arc<T>, |f, _opts| write!(
44                f,
45                "Arc"
46            )))
47            .build()
48    };
49}
50
51unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::sync::Weak<T> {
52    const SHAPE: &'static crate::Shape = &const {
53        crate::Shape::builder()
54            .id(ConstTypeId::of::<Self>())
55            .layout(Layout::new::<Self>())
56            .type_params(&[crate::TypeParam {
57                name: "T",
58                shape: || T::SHAPE,
59            }])
60            .def(Def::SmartPointer(
61                SmartPointerDef::builder()
62                    .pointee(T::SHAPE)
63                    .flags(SmartPointerFlags::ATOMIC.union(SmartPointerFlags::WEAK))
64                    .known(KnownSmartPointer::ArcWeak)
65                    .strong(|| <alloc::sync::Arc<T> as Facet>::SHAPE)
66                    .vtable(
67                        &const {
68                            SmartPointerVTable::builder()
69                                .upgrade_into_fn(|weak, strong| unsafe {
70                                    Some(strong.put(weak.get::<Self>().upgrade()?))
71                                })
72                                .build()
73                        },
74                    )
75                    .build(),
76            ))
77            .vtable(value_vtable!(alloc::sync::Arc<T>, |f, _opts| write!(
78                f,
79                "Arc"
80            )))
81            .build()
82    };
83}
84
85unsafe impl<'a, T: 'a> Facet<'a> for Opaque<alloc::sync::Arc<T>> {
86    const SHAPE: &'static crate::Shape = &const {
87        crate::Shape::builder()
88            .id(ConstTypeId::of::<Self>())
89            .layout(Layout::new::<Self>())
90            .def(Def::SmartPointer(
91                SmartPointerDef::builder()
92                    .flags(SmartPointerFlags::ATOMIC)
93                    .known(KnownSmartPointer::Arc)
94                    .vtable(
95                        &const {
96                            SmartPointerVTable::builder()
97                                .borrow_fn(|this| {
98                                    let ptr = alloc::sync::Arc::<T>::as_ptr(unsafe { this.get() });
99                                    PtrConst::new(ptr)
100                                })
101                                .new_into_fn(|this, ptr| {
102                                    let t = unsafe { ptr.read::<T>() };
103                                    let arc = alloc::sync::Arc::new(t);
104                                    unsafe { this.put(arc) }
105                                })
106                                .build()
107                        },
108                    )
109                    .build(),
110            ))
111            .vtable(value_vtable!(alloc::sync::Arc<T>, |f, _opts| write!(
112                f,
113                "Arc"
114            )))
115            .build()
116    };
117}
118
119unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::rc::Rc<T> {
120    const SHAPE: &'static crate::Shape = &const {
121        crate::Shape::builder()
122            .id(ConstTypeId::of::<Self>())
123            .layout(Layout::new::<Self>())
124            .type_params(&[crate::TypeParam {
125                name: "T",
126                shape: || T::SHAPE,
127            }])
128            .def(Def::SmartPointer(
129                SmartPointerDef::builder()
130                    .pointee(T::SHAPE)
131                    .flags(SmartPointerFlags::EMPTY)
132                    .known(KnownSmartPointer::Rc)
133                    .weak(|| <alloc::rc::Weak<T> as Facet>::SHAPE)
134                    .vtable(
135                        &const {
136                            SmartPointerVTable::builder()
137                                .borrow_fn(|this| {
138                                    let ptr = Self::as_ptr(unsafe { this.get() });
139                                    PtrConst::new(ptr)
140                                })
141                                .new_into_fn(|this, ptr| {
142                                    let t = unsafe { ptr.read::<T>() };
143                                    let rc = alloc::rc::Rc::new(t);
144                                    unsafe { this.put(rc) }
145                                })
146                                .downgrade_into_fn(|strong, weak| unsafe {
147                                    weak.put(alloc::rc::Rc::downgrade(strong.get::<Self>()))
148                                })
149                                .build()
150                        },
151                    )
152                    .build(),
153            ))
154            .vtable(value_vtable!(alloc::rc::Rc<T>, |f, _opts| write!(f, "Rc")))
155            .build()
156    };
157}
158
159unsafe impl<'a, T: Facet<'a>> Facet<'a> for alloc::rc::Weak<T> {
160    const SHAPE: &'static crate::Shape = &const {
161        crate::Shape::builder()
162            .id(ConstTypeId::of::<Self>())
163            .layout(Layout::new::<Self>())
164            .type_params(&[crate::TypeParam {
165                name: "T",
166                shape: || T::SHAPE,
167            }])
168            .def(Def::SmartPointer(
169                SmartPointerDef::builder()
170                    .pointee(T::SHAPE)
171                    .flags(SmartPointerFlags::WEAK)
172                    .known(KnownSmartPointer::RcWeak)
173                    .strong(|| <alloc::rc::Rc<T> as Facet>::SHAPE)
174                    .vtable(
175                        &const {
176                            SmartPointerVTable::builder()
177                                .upgrade_into_fn(|weak, strong| unsafe {
178                                    Some(strong.put(weak.get::<Self>().upgrade()?))
179                                })
180                                .build()
181                        },
182                    )
183                    .build(),
184            ))
185            .vtable(value_vtable!(alloc::rc::Rc<T>, |f, _opts| write!(f, "Rc")))
186            .build()
187    };
188}
189
190unsafe impl<'a, T: 'a> Facet<'a> for Opaque<alloc::rc::Rc<T>> {
191    const SHAPE: &'static crate::Shape = &const {
192        crate::Shape::builder()
193            .id(ConstTypeId::of::<Self>())
194            .layout(Layout::new::<Self>())
195            .def(Def::SmartPointer(
196                SmartPointerDef::builder()
197                    .known(KnownSmartPointer::Rc)
198                    .vtable(
199                        &const {
200                            SmartPointerVTable::builder()
201                                .borrow_fn(|this| {
202                                    let ptr = alloc::rc::Rc::<T>::as_ptr(unsafe { this.get() });
203                                    PtrConst::new(ptr)
204                                })
205                                .new_into_fn(|this, ptr| {
206                                    let t = unsafe { ptr.read::<T>() };
207                                    let rc = alloc::rc::Rc::new(t);
208                                    unsafe { this.put(rc) }
209                                })
210                                .build()
211                        },
212                    )
213                    .build(),
214            ))
215            .vtable(value_vtable!(alloc::rc::Rc<T>, |f, _opts| write!(f, "Rc")))
216            .build()
217    };
218}
219
// Tests for the `Arc`/`Rc` and `Weak` shapes: type parameters, and a full
// lifecycle through the smart-pointer vtables (new -> borrow -> downgrade ->
// upgrade -> drop -> failed upgrade -> drop weak).
#[cfg(test)]
mod tests {
    use alloc::rc::{Rc, Weak as RcWeak};
    use alloc::string::String;
    use alloc::sync::{Arc, Weak as ArcWeak};
    use core::mem::MaybeUninit;

    use super::*;

    use crate::PtrUninit;

    /// `Arc<i32>` exposes exactly one type parameter whose shape is `i32`'s.
    #[test]
    fn test_arc_type_params() {
        let [type_param_1] = <Arc<i32>>::SHAPE.type_params else {
            panic!("Arc<T> should only have 1 type param")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
    }

    /// Drives an `Arc<String>` / `ArcWeak<String>` through every vtable entry.
    #[test]
    fn test_arc_vtable() {
        facet_testhelpers::setup();

        let arc_shape = <Arc<String>>::SHAPE;
        let arc_def = arc_shape
            .def
            .into_smart_pointer()
            .expect("Arc<T> should have a smart pointer definition");

        let weak_shape = <ArcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("ArcWeak<T> should have a smart pointer definition");

        // Keep this alive as long as the Arc inside it is used
        let mut arc_storage = MaybeUninit::<Arc<String>>::zeroed();
        let arc_ptr = unsafe {
            let arc_uninit_ptr = PtrUninit::from_maybe_uninit(&mut arc_storage);

            let value = String::from("example");
            let value_ptr = PtrConst::new(&raw const value);

            // SAFETY:
            // - `arc_uninit_ptr` has layout Arc<String>
            // - `value_ptr` points to an initialized String
            // - `new_into_fn` moves the String out of `value`, which is forgotten
            //   below so it is not dropped twice
            let returned_ptr = arc_def
                .vtable
                .new_into_fn
                .expect("Arc<T> should have new_into_fn vtable function")(
                arc_uninit_ptr,
                value_ptr,
            );

            // The String now lives inside the Arc; don't run its destructor here
            core::mem::forget(value);

            // Test correctness of the return value of new_into_fn
            // SAFETY: Using correct type Arc<String>
            assert_eq!(
                returned_ptr.as_ptr(),
                arc_uninit_ptr.as_byte_ptr() as *const Arc<String>
            );

            returned_ptr
        };

        unsafe {
            // SAFETY: `arc_ptr` is valid
            let borrowed = arc_def
                .vtable
                .borrow_fn
                .expect("Arc<T> should have borrow_fn vtable function")(
                arc_ptr.as_const()
            );
            assert_eq!(borrowed.get::<String>(), "example");
        }

        // Keep this alive as long as the ArcWeak inside it is used
        // (despite its name, this storage holds an ArcWeak<String>)
        let mut new_arc_storage = MaybeUninit::<ArcWeak<String>>::zeroed();
        let weak_ptr = unsafe {
            let weak_uninit_ptr = PtrUninit::from_maybe_uninit(&mut new_arc_storage);

            // SAFETY: `arc_ptr` is a valid Arc<String> and `weak_uninit_ptr` has
            // layout ArcWeak<String>
            let returned_ptr = arc_def
                .vtable
                .downgrade_into_fn
                .expect("Arc<T> should have downgrade_into_fn vtable function")(
                arc_ptr,
                weak_uninit_ptr,
            );

            // Test correctness of the return value of downgrade_into_fn
            // SAFETY: Using correct type ArcWeak<String>
            assert_eq!(
                returned_ptr.as_ptr(),
                weak_uninit_ptr.as_byte_ptr() as *const ArcWeak<String>
            );

            returned_ptr
        };

        // While the original Arc is alive, upgrading the weak pointer succeeds.
        {
            let mut new_arc_storage = MaybeUninit::<Arc<String>>::zeroed();
            let new_arc_ptr = unsafe {
                let new_arc_uninit_ptr = PtrUninit::from_maybe_uninit(&mut new_arc_storage);

                // SAFETY: `weak_ptr` is valid and `new_arc_uninit_ptr` has layout Arc<String>
                let returned_ptr = weak_def
                    .vtable
                    .upgrade_into_fn
                    .expect("ArcWeak<T> should have upgrade_into_fn vtable function")(
                    weak_ptr,
                    new_arc_uninit_ptr,
                )
                .expect("Upgrade should be successful");

                // Test correctness of the return value of upgrade_into_fn
                // SAFETY: Using correct type Arc<String>
                assert_eq!(
                    returned_ptr.as_ptr(),
                    new_arc_uninit_ptr.as_byte_ptr() as *const Arc<String>
                );

                returned_ptr
            };

            unsafe {
                // SAFETY: `new_arc_ptr` is valid
                let borrowed = arc_def
                    .vtable
                    .borrow_fn
                    .expect("Arc<T> should have borrow_fn vtable function")(
                    new_arc_ptr.as_const()
                );
                assert_eq!(borrowed.get::<String>(), "example");
            }

            unsafe {
                // SAFETY: Proper value at `new_arc_ptr`, which is not accessed after this
                arc_shape
                    .vtable
                    .drop_in_place
                    .expect("Arc<T> should have drop_in_place vtable function")(
                    new_arc_ptr
                );
            }
        }

        unsafe {
            // SAFETY: Proper value at `arc_ptr`, which is not accessed after this.
            // This drops the last strong reference, so later upgrades must fail.
            arc_shape
                .vtable
                .drop_in_place
                .expect("Arc<T> should have drop_in_place vtable function")(arc_ptr);
        }

        unsafe {
            let mut new_arc_storage = MaybeUninit::<Arc<String>>::zeroed();
            let new_arc_uninit_ptr = PtrUninit::from_maybe_uninit(&mut new_arc_storage);

            // SAFETY: `weak_ptr` is valid and `new_arc_uninit_ptr` has layout Arc<String>
            if weak_def
                .vtable
                .upgrade_into_fn
                .expect("ArcWeak<T> should have upgrade_into_fn vtable function")(
                weak_ptr,
                new_arc_uninit_ptr,
            )
            .is_some()
            {
                panic!("Upgrade should be unsuccessful")
            }
        };

        unsafe {
            // SAFETY: Proper ArcWeak<String> at `weak_ptr`, which is not accessed after this
            weak_shape
                .vtable
                .drop_in_place
                .expect("ArcWeak<T> should have drop_in_place vtable function")(
                weak_ptr
            );
        }
    }

    /// `Rc<i32>` exposes exactly one type parameter whose shape is `i32`'s.
    #[test]
    fn test_rc_type_params() {
        let [type_param_1] = <Rc<i32>>::SHAPE.type_params else {
            panic!("Rc<T> should only have 1 type param")
        };
        assert_eq!(type_param_1.shape(), i32::SHAPE);
    }

    /// Drives an `Rc<String>` / `RcWeak<String>` through every vtable entry.
    #[test]
    fn test_rc_vtable() {
        facet_testhelpers::setup();

        let rc_shape = <Rc<String>>::SHAPE;
        let rc_def = rc_shape
            .def
            .into_smart_pointer()
            .expect("Rc<T> should have a smart pointer definition");

        let weak_shape = <RcWeak<String>>::SHAPE;
        let weak_def = weak_shape
            .def
            .into_smart_pointer()
            .expect("RcWeak<T> should have a smart pointer definition");

        // Keep this alive as long as the Rc inside it is used
        let mut rc_storage = MaybeUninit::<Rc<String>>::zeroed();
        let rc_ptr = unsafe {
            let rc_uninit_ptr = PtrUninit::from_maybe_uninit(&mut rc_storage);

            let value = String::from("example");
            let value_ptr = PtrConst::new(&raw const value);

            // SAFETY:
            // - `rc_uninit_ptr` has layout Rc<String>
            // - `value_ptr` points to an initialized String
            // - `new_into_fn` moves the String out of `value`, which is forgotten
            //   below so it is not dropped twice
            let returned_ptr = rc_def
                .vtable
                .new_into_fn
                .expect("Rc<T> should have new_into_fn vtable function")(
                rc_uninit_ptr, value_ptr
            );

            // The String now lives inside the Rc; don't run its destructor here
            core::mem::forget(value);

            // Test correctness of the return value of new_into_fn
            // SAFETY: Using correct type Rc<String>
            assert_eq!(
                returned_ptr.as_ptr(),
                rc_uninit_ptr.as_byte_ptr() as *const Rc<String>
            );

            returned_ptr
        };

        unsafe {
            // SAFETY: `rc_ptr` is valid
            let borrowed = rc_def
                .vtable
                .borrow_fn
                .expect("Rc<T> should have borrow_fn vtable function")(
                rc_ptr.as_const()
            );
            assert_eq!(borrowed.get::<String>(), "example");
        }

        // Keep this alive as long as the RcWeak inside it is used
        // (despite its name, this storage holds an RcWeak<String>)
        let mut new_rc_storage = MaybeUninit::<RcWeak<String>>::zeroed();
        let weak_ptr = unsafe {
            let weak_uninit_ptr = PtrUninit::from_maybe_uninit(&mut new_rc_storage);

            // SAFETY: `rc_ptr` is a valid Rc<String> and `weak_uninit_ptr` has
            // layout RcWeak<String>
            let returned_ptr = rc_def
                .vtable
                .downgrade_into_fn
                .expect("Rc<T> should have downgrade_into_fn vtable function")(
                rc_ptr,
                weak_uninit_ptr,
            );

            // Test correctness of the return value of downgrade_into_fn
            // SAFETY: Using correct type RcWeak<String>
            assert_eq!(
                returned_ptr.as_ptr(),
                weak_uninit_ptr.as_byte_ptr() as *const RcWeak<String>
            );

            returned_ptr
        };

        // While the original Rc is alive, upgrading the weak pointer succeeds.
        {
            let mut new_rc_storage = MaybeUninit::<Rc<String>>::zeroed();
            let new_rc_ptr = unsafe {
                let new_rc_uninit_ptr = PtrUninit::from_maybe_uninit(&mut new_rc_storage);

                // SAFETY: `weak_ptr` is valid and `new_rc_uninit_ptr` has layout Rc<String>
                let returned_ptr = weak_def
                    .vtable
                    .upgrade_into_fn
                    .expect("RcWeak<T> should have upgrade_into_fn vtable function")(
                    weak_ptr,
                    new_rc_uninit_ptr,
                )
                .expect("Upgrade should be successful");

                // Test correctness of the return value of upgrade_into_fn
                // SAFETY: Using correct type Rc<String>
                assert_eq!(
                    returned_ptr.as_ptr(),
                    new_rc_uninit_ptr.as_byte_ptr() as *const Rc<String>
                );

                returned_ptr
            };

            unsafe {
                // SAFETY: `new_rc_ptr` is valid
                let borrowed = rc_def
                    .vtable
                    .borrow_fn
                    .expect("Rc<T> should have borrow_fn vtable function")(
                    new_rc_ptr.as_const()
                );
                assert_eq!(borrowed.get::<String>(), "example");
            }

            unsafe {
                // SAFETY: Proper value at `new_rc_ptr`, which is not accessed after this
                rc_shape
                    .vtable
                    .drop_in_place
                    .expect("Rc<T> should have drop_in_place vtable function")(
                    new_rc_ptr
                );
            }
        }

        unsafe {
            // SAFETY: Proper value at `rc_ptr`, which is not accessed after this.
            // This drops the last strong reference, so later upgrades must fail.
            rc_shape
                .vtable
                .drop_in_place
                .expect("Rc<T> should have drop_in_place vtable function")(rc_ptr);
        }

        unsafe {
            let mut new_rc_storage = MaybeUninit::<Rc<String>>::zeroed();
            let new_rc_uninit_ptr = PtrUninit::from_maybe_uninit(&mut new_rc_storage);

            // SAFETY: `weak_ptr` is valid and `new_rc_uninit_ptr` has layout Rc<String>
            if weak_def
                .vtable
                .upgrade_into_fn
                .expect("RcWeak<T> should have upgrade_into_fn vtable function")(
                weak_ptr,
                new_rc_uninit_ptr,
            )
            .is_some()
            {
                panic!("Upgrade should be unsuccessful")
            }
        };

        unsafe {
            // SAFETY: Proper RcWeak<String> at `weak_ptr`, which is not accessed after this
            weak_shape
                .vtable
                .drop_in_place
                .expect("RcWeak<T> should have drop_in_place vtable function")(weak_ptr);
        }
    }
}