rkyv/impls/alloc/rc/
mod.rs

1#[cfg(target_has_atomic = "ptr")]
2mod atomic;
3
use core::alloc::LayoutError;

use ptr_meta::{from_raw_parts_mut, Pointee};
use rancor::{Fallible, Source};

use crate::{
    alloc::{
        alloc::{alloc, handle_alloc_error},
        boxed::Box,
        rc,
    },
    de::{FromMetadata, Metadata, Pooling, PoolingExt as _, SharedPointer},
    rc::{ArchivedRc, ArchivedRcWeak, RcFlavor, RcResolver, RcWeakResolver},
    ser::{Sharing, Writer},
    traits::{ArchivePointee, LayoutRaw},
    Archive, ArchiveUnsized, Deserialize, DeserializeUnsized, Place, Serialize,
    SerializeUnsized,
};
18
19// Rc
20
21impl<T: ArchiveUnsized + ?Sized> Archive for rc::Rc<T> {
22    type Archived = ArchivedRc<T::Archived, RcFlavor>;
23    type Resolver = RcResolver;
24
25    fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
26        ArchivedRc::resolve_from_ref(self.as_ref(), resolver, out);
27    }
28}
29
30impl<T, S> Serialize<S> for rc::Rc<T>
31where
32    T: SerializeUnsized<S> + ?Sized + 'static,
33    S: Fallible + Writer + Sharing + ?Sized,
34    S::Error: Source,
35{
36    fn serialize(
37        &self,
38        serializer: &mut S,
39    ) -> Result<Self::Resolver, S::Error> {
40        ArchivedRc::<T::Archived, RcFlavor>::serialize_from_ref(
41            self.as_ref(),
42            serializer,
43        )
44    }
45}
46
47unsafe impl<T: LayoutRaw + Pointee + ?Sized> SharedPointer<T> for rc::Rc<T> {
48    fn alloc(metadata: T::Metadata) -> Result<*mut T, LayoutError> {
49        let layout = T::layout_raw(metadata)?;
50        let data_address = if layout.size() > 0 {
51            unsafe { alloc(layout) }
52        } else {
53            crate::polyfill::dangling(&layout).as_ptr()
54        };
55        let ptr = from_raw_parts_mut(data_address.cast(), metadata);
56        Ok(ptr)
57    }
58
59    unsafe fn from_value(ptr: *mut T) -> *mut T {
60        let rc = rc::Rc::<T>::from(unsafe { Box::from_raw(ptr) });
61        rc::Rc::into_raw(rc).cast_mut()
62    }
63
64    unsafe fn drop(ptr: *mut T) {
65        drop(unsafe { rc::Rc::from_raw(ptr) });
66    }
67}
68
69impl<T, D> Deserialize<rc::Rc<T>, D> for ArchivedRc<T::Archived, RcFlavor>
70where
71    T: ArchiveUnsized + LayoutRaw + Pointee + ?Sized + 'static,
72    T::Archived: DeserializeUnsized<T, D>,
73    T::Metadata: Into<Metadata> + FromMetadata,
74    D: Fallible + Pooling + ?Sized,
75    D::Error: Source,
76{
77    fn deserialize(&self, deserializer: &mut D) -> Result<rc::Rc<T>, D::Error> {
78        let raw_shared_ptr =
79            deserializer.deserialize_shared::<_, rc::Rc<T>>(self.get())?;
80        unsafe {
81            rc::Rc::<T>::increment_strong_count(raw_shared_ptr);
82        }
83        unsafe { Ok(rc::Rc::<T>::from_raw(raw_shared_ptr)) }
84    }
85}
86
87impl<T, U> PartialEq<rc::Rc<U>> for ArchivedRc<T, RcFlavor>
88where
89    T: ArchivePointee + PartialEq<U> + ?Sized,
90    U: ?Sized,
91{
92    fn eq(&self, other: &rc::Rc<U>) -> bool {
93        self.get().eq(other.as_ref())
94    }
95}
96
97// rc::Weak
98
99impl<T: ArchiveUnsized + ?Sized> Archive for rc::Weak<T> {
100    type Archived = ArchivedRcWeak<T::Archived, RcFlavor>;
101    type Resolver = RcWeakResolver;
102
103    fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
104        ArchivedRcWeak::resolve_from_ref(
105            self.upgrade().as_ref().map(|v| v.as_ref()),
106            resolver,
107            out,
108        );
109    }
110}
111
112impl<T, S> Serialize<S> for rc::Weak<T>
113where
114    T: SerializeUnsized<S> + ?Sized + 'static,
115    S: Fallible + Writer + Sharing + ?Sized,
116    S::Error: Source,
117{
118    fn serialize(
119        &self,
120        serializer: &mut S,
121    ) -> Result<Self::Resolver, S::Error> {
122        ArchivedRcWeak::<T::Archived, RcFlavor>::serialize_from_ref(
123            self.upgrade().as_ref().map(|v| v.as_ref()),
124            serializer,
125        )
126    }
127}
128
129impl<T, D> Deserialize<rc::Weak<T>, D> for ArchivedRcWeak<T::Archived, RcFlavor>
130where
131    // Deserialize can only be implemented for sized types because weak pointers
132    // to unsized types don't have `new` functions.
133    T: ArchiveUnsized
134        + LayoutRaw
135        + Pointee // + ?Sized
136        + 'static,
137    T::Archived: DeserializeUnsized<T, D>,
138    T::Metadata: Into<Metadata> + FromMetadata,
139    D: Fallible + Pooling + ?Sized,
140    D::Error: Source,
141{
142    fn deserialize(
143        &self,
144        deserializer: &mut D,
145    ) -> Result<rc::Weak<T>, D::Error> {
146        Ok(match self.upgrade() {
147            None => rc::Weak::new(),
148            Some(r) => rc::Rc::downgrade(&r.deserialize(deserializer)?),
149        })
150    }
151}
152
#[cfg(test)]
mod tests {
    use munge::munge;
    use rancor::{Failure, Panic};

    use crate::{
        access_unchecked, access_unchecked_mut,
        alloc::{
            rc::{Rc, Weak},
            string::{String, ToString},
            vec,
        },
        api::{
            deserialize_using,
            test::{roundtrip, to_archived},
        },
        de::Pool,
        rc::{ArchivedRc, ArchivedRcWeak},
        to_bytes, Archive, Deserialize, Serialize,
    };

    #[test]
    fn roundtrip_rc() {
        #[derive(Debug, Eq, PartialEq, Archive, Deserialize, Serialize)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<u32>,
            b: Rc<u32>,
        }

        // Checks that both fields point at one shared `17` with the given
        // strong count and no weak references.
        fn assert_shared_seventeen(test: &Test, strong: usize) {
            assert_eq!(*test.a, 17);
            assert_eq!(*test.b, 17);
            assert_eq!(&*test.a as *const u32, &*test.b as *const u32);
            assert_eq!(Rc::strong_count(&test.a), strong);
            assert_eq!(Rc::strong_count(&test.b), strong);
            assert_eq!(Rc::weak_count(&test.a), 0);
            assert_eq!(Rc::weak_count(&test.b), 0);
        }

        let original = Rc::new(10);
        let value = Test {
            a: Rc::clone(&original),
            b: Rc::clone(&original),
        };

        to_archived(&value, |mut archived| {
            assert_eq!(*archived, value);

            // A write through `a` must be visible through `b`, because both
            // archived pointers share a single target.
            munge!(let ArchivedTest { a, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
            }
            assert_eq!(*archived.a, 42);
            assert_eq!(*archived.b, 42);

            // ...and a write through `b` must be visible through `a`.
            munge!(let ArchivedTest { b, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(b) = 17u32.into();
            }
            assert_eq!(*archived.a, 17);
            assert_eq!(*archived.b, 17);

            let mut pool = Pool::new();
            let deserialized =
                deserialize_using::<Test, _, Panic>(&*archived, &mut pool)
                    .unwrap();

            // The pool holds one extra strong reference until it is dropped.
            assert_shared_seventeen(&deserialized, 3);
            core::mem::drop(pool);
            assert_shared_seventeen(&deserialized, 2);
        });
    }

    #[test]
    fn roundtrip_rc_zst() {
        #[derive(Archive, Deserialize, Serialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct TestRcZST {
            a: Rc<()>,
            b: Rc<()>,
        }

        let unit = Rc::new(());
        let value = TestRcZST {
            a: Rc::clone(&unit),
            b: Rc::clone(&unit),
        };
        roundtrip(&value);
    }

    #[test]
    fn roundtrip_unsized_shared_ptr() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<[String]>,
            b: Rc<[String]>,
        }

        let words = Rc::<[String]>::from(
            vec!["hello".to_string(), "world".to_string()].into_boxed_slice(),
        );
        let value = Test {
            a: Rc::clone(&words),
            b: words,
        };
        roundtrip(&value);
    }

    #[test]
    fn roundtrip_unsized_shared_ptr_empty() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<[u32]>,
            b: Rc<[u32]>,
        }

        let empty = Rc::<[u32]>::from(vec![].into_boxed_slice());
        let single = Rc::<[u32]>::from(vec![100].into_boxed_slice());
        let value = Test {
            a: empty,
            b: Rc::clone(&single),
        };
        roundtrip(&value);
    }

    #[test]
    fn roundtrip_weak_ptr() {
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Test {
            a: Rc<u32>,
            b: Weak<u32>,
        }

        // Reads `bytes` as an `ArchivedTest` and checks that both the strong
        // and the weak pointer observe `expected`.
        fn assert_archived(bytes: &[u8], expected: u32) {
            let archived = unsafe { access_unchecked::<ArchivedTest>(bytes) };
            assert_eq!(*archived.a, expected);
            assert!(archived.b.upgrade().is_some());
            assert_eq!(**archived.b.upgrade().unwrap(), expected);
        }

        // Checks that `b` upgrades to the same `17` owned by `a`, with the
        // given strong count and exactly one weak reference.
        fn assert_deserialized(test: &Test, strong: usize) {
            assert_eq!(*test.a, 17);
            assert!(test.b.upgrade().is_some());
            assert_eq!(*test.b.upgrade().unwrap(), 17);
            assert_eq!(
                &*test.a as *const u32,
                &*test.b.upgrade().unwrap() as *const u32
            );
            assert_eq!(Rc::strong_count(&test.a), strong);
            assert_eq!(Weak::strong_count(&test.b), strong);
            assert_eq!(Rc::weak_count(&test.a), 1);
            assert_eq!(Weak::weak_count(&test.b), 1);
        }

        let original = Rc::new(10);
        let value = Test {
            a: Rc::clone(&original),
            b: Rc::downgrade(&original),
        };

        let mut bytes = to_bytes::<Panic>(&value).unwrap();
        assert_archived(bytes.as_ref(), 10);

        // Overwrite the shared value through the strong pointer.
        let mut sealed =
            unsafe { access_unchecked_mut::<ArchivedTest>(bytes.as_mut()) };
        munge!(let ArchivedTest { a, .. } = sealed.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
        }
        assert_archived(bytes.as_ref(), 42);

        // Overwrite it again, this time through the upgraded weak pointer.
        let mut sealed =
            unsafe { access_unchecked_mut::<ArchivedTest>(bytes.as_mut()) };
        munge!(let ArchivedTest { b, .. } = sealed.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(
                ArchivedRcWeak::upgrade_seal(b).unwrap(),
            ) = 17u32.into();
        }
        assert_archived(bytes.as_ref(), 17);

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(bytes.as_ref()) };
        let mut pool = Pool::new();
        let deserialized =
            deserialize_using::<Test, _, Panic>(archived, &mut pool).unwrap();

        // While alive, the pool contributes one strong reference.
        assert_deserialized(&deserialized, 2);
        core::mem::drop(pool);
        assert_deserialized(&deserialized, 1);
    }

    #[test]
    fn serialize_cyclic_error() {
        use rancor::{Fallible, Source};

        use crate::{
            de::Pooling,
            ser::{Sharing, Writer},
        };

        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(
            crate,
            serialize_bounds(
                __S: Sharing + Writer,
                <__S as Fallible>::Error: Source,
            ),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            )
        )]
        #[cfg_attr(
            feature = "bytecheck",
            rkyv(bytecheck(bounds(
                __C: crate::validation::ArchiveContext
                    + crate::validation::SharedContext,
                <__C as Fallible>::Error: Source,
            ))),
        )]
        struct Inner {
            #[rkyv(omit_bounds)]
            weak: Weak<Self>,
        }

        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Outer {
            inner: Rc<Inner>,
        }

        // `Inner` holds a weak pointer back to itself, which the sharing
        // serializer must reject rather than recurse on.
        let cyclic = Outer {
            inner: Rc::new_cyclic(|weak| Inner { weak: weak.clone() }),
        };

        let outcome = to_bytes::<Failure>(&cyclic);
        assert!(outcome.is_err());
    }

    #[cfg(all(
        feature = "bytecheck",
        not(feature = "big_endian"),
        not(any(feature = "pointer_width_16", feature = "pointer_width_64")),
    ))]
    #[test]
    fn recursive_stack_overflow() {
        use rancor::{Fallible, Source};

        use crate::{
            access,
            de::Pooling,
            util::Align,
            validation::{ArchiveContext, SharedContext},
        };

        #[derive(Archive, Deserialize)]
        #[rkyv(
            crate,
            bytecheck(bounds(__C: ArchiveContext + SharedContext)),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            ),
            derive(Debug),
        )]
        enum AllValues {
            Rc(#[rkyv(omit_bounds)] Rc<AllValues>),
        }

        let bytes = Align([
            0x00, 0x00, 0x00, 0xff, // B: AllValues::Rc
            0xfc, 0xff, 0xff, 0xff, // RelPtr with offset -4 (B)
            0x00, 0x00, 0xf6, 0xff, // A: AllValues::Rc
            0xf4, 0xff, 0xff, 0xff, // RelPtr with offset -12 (B)
        ]);
        // Validation must fail instead of overflowing the stack.
        access::<ArchivedAllValues, Failure>(&*bytes).unwrap_err();
    }
}