#[cfg(target_has_atomic = "ptr")]
mod atomic;

use core::alloc::LayoutError;

use ptr_meta::{from_raw_parts_mut, Pointee};
use rancor::{Fallible, Source};

use crate::{
    alloc::{alloc::alloc, boxed::Box, rc},
    de::{FromMetadata, Metadata, Pooling, PoolingExt as _, SharedPointer},
    rc::{ArchivedRc, ArchivedRcWeak, RcFlavor, RcResolver, RcWeakResolver},
    ser::{Sharing, Writer},
    traits::{ArchivePointee, LayoutRaw},
    Archive, ArchiveUnsized, Deserialize, DeserializeUnsized, Place, Serialize,
    SerializeUnsized,
};

impl<T: ArchiveUnsized + ?Sized> Archive for rc::Rc<T> {
    type Archived = ArchivedRc<T::Archived, RcFlavor>;
    type Resolver = RcResolver;

    fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
        ArchivedRc::resolve_from_ref(self.as_ref(), resolver, out);
    }
}

impl<T, S> Serialize<S> for rc::Rc<T>
where
    T: SerializeUnsized<S> + ?Sized + 'static,
    S: Fallible + Writer + Sharing + ?Sized,
    S::Error: Source,
{
    fn serialize(
        &self,
        serializer: &mut S,
    ) -> Result<Self::Resolver, S::Error> {
        ArchivedRc::<T::Archived, RcFlavor>::serialize_from_ref(
            self.as_ref(),
            serializer,
        )
    }
}

unsafe impl<T: LayoutRaw + Pointee + ?Sized> SharedPointer<T> for rc::Rc<T> {
    fn alloc(metadata: T::Metadata) -> Result<*mut T, LayoutError> {
        let layout = T::layout_raw(metadata)?;
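        // Zero-sized layouts must not be passed to the global allocator, so
        // use a dangling but properly aligned pointer for them instead.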
        let data_address = if layout.size() > 0 {
            unsafe { alloc(layout) }
        } else {
            crate::polyfill::dangling(&layout).as_ptr()
        };
        let ptr = from_raw_parts_mut(data_address.cast(), metadata);
        Ok(ptr)
    }

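    // Moves the value out of its `Box` into a new `Rc` allocation and leaks it
    // to a raw pointer; the matching `drop` below releases that strong count.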
    unsafe fn from_value(ptr: *mut T) -> *mut T {
        let rc = rc::Rc::<T>::from(unsafe { Box::from_raw(ptr) });
        rc::Rc::into_raw(rc).cast_mut()
    }

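    // Reconstructs the `Rc` from the raw pointer produced above and drops it,
    // releasing the strong count held for the shared value.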
    unsafe fn drop(ptr: *mut T) {
        drop(unsafe { rc::Rc::from_raw(ptr) });
    }
}

impl<T, D> Deserialize<rc::Rc<T>, D> for ArchivedRc<T::Archived, RcFlavor>
where
    T: ArchiveUnsized + LayoutRaw + Pointee + ?Sized + 'static,
    T::Archived: DeserializeUnsized<T, D>,
    T::Metadata: Into<Metadata> + FromMetadata,
    D: Fallible + Pooling + ?Sized,
    D::Error: Source,
{
    fn deserialize(&self, deserializer: &mut D) -> Result<rc::Rc<T>, D::Error> {
        let raw_shared_ptr =
            deserializer.deserialize_shared::<_, rc::Rc<T>>(self.get())?;
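        // The pooling deserializer still owns the `Rc` behind
        // `raw_shared_ptr`, so bump the strong count before materializing
        // another `Rc` from the same raw pointer.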
        unsafe {
            rc::Rc::<T>::increment_strong_count(raw_shared_ptr);
        }
        unsafe { Ok(rc::Rc::<T>::from_raw(raw_shared_ptr)) }
    }
}

impl<T, U> PartialEq<rc::Rc<U>> for ArchivedRc<T, RcFlavor>
where
    T: ArchivePointee + PartialEq<U> + ?Sized,
    U: ?Sized,
{
    fn eq(&self, other: &rc::Rc<U>) -> bool {
        self.get().eq(other.as_ref())
    }
}

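// A `Weak<T>` is archived through `upgrade()`: a live weak pointer archives
// the value it points to, while a dangling one archives as unupgradable.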
impl<T: ArchiveUnsized + ?Sized> Archive for rc::Weak<T> {
    type Archived = ArchivedRcWeak<T::Archived, RcFlavor>;
    type Resolver = RcWeakResolver;

    fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
        ArchivedRcWeak::resolve_from_ref(
            self.upgrade().as_ref().map(|v| v.as_ref()),
            resolver,
            out,
        );
    }
}

impl<T, S> Serialize<S> for rc::Weak<T>
where
    T: SerializeUnsized<S> + ?Sized + 'static,
    S: Fallible + Writer + Sharing + ?Sized,
    S::Error: Source,
{
    fn serialize(
        &self,
        serializer: &mut S,
    ) -> Result<Self::Resolver, S::Error> {
        ArchivedRcWeak::<T::Archived, RcFlavor>::serialize_from_ref(
            self.upgrade().as_ref().map(|v| v.as_ref()),
            serializer,
        )
    }
}

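// Deserialization requires `T: Sized` because an unupgradable weak pointer is
// recreated with `rc::Weak::new()`, which only exists for sized types.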
impl<T, D> Deserialize<rc::Weak<T>, D> for ArchivedRcWeak<T::Archived, RcFlavor>
where
    T: ArchiveUnsized + LayoutRaw + Pointee + 'static,
    T::Archived: DeserializeUnsized<T, D>,
    T::Metadata: Into<Metadata> + FromMetadata,
    D: Fallible + Pooling + ?Sized,
    D::Error: Source,
{
    fn deserialize(
        &self,
        deserializer: &mut D,
    ) -> Result<rc::Weak<T>, D::Error> {
        Ok(match self.upgrade() {
            None => rc::Weak::new(),
            Some(r) => rc::Rc::downgrade(&r.deserialize(deserializer)?),
        })
    }
}

#[cfg(test)]
mod tests {
    use munge::munge;
    use rancor::{Failure, Panic};

    use crate::{
        access_unchecked, access_unchecked_mut,
        alloc::{
            rc::{Rc, Weak},
            string::{String, ToString},
            vec,
        },
        api::{
            deserialize_using,
            test::{roundtrip, to_archived},
        },
        de::Pool,
        rc::{ArchivedRc, ArchivedRcWeak},
        to_bytes, Archive, Deserialize, Serialize,
    };

    #[test]
    fn roundtrip_rc() {
        #[derive(Debug, Eq, PartialEq, Archive, Deserialize, Serialize)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<u32>,
            b: Rc<u32>,
        }

        let shared = Rc::new(10);
        let value = Test {
            a: shared.clone(),
            b: shared.clone(),
        };

        to_archived(&value, |mut archived| {
            assert_eq!(*archived, value);

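            // `a` and `b` point at the same archived value, so a write through
            // one is observed through the other.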
            munge!(let ArchivedTest { a, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
            }

            assert_eq!(*archived.a, 42);
            assert_eq!(*archived.b, 42);

            munge!(let ArchivedTest { b, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(b) = 17u32.into();
            }

            assert_eq!(*archived.a, 17);
            assert_eq!(*archived.b, 17);

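            // While the pool is alive it holds a strong count of its own, so
            // `a`, `b`, and the pool account for a strong count of 3.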
            let mut deserializer = Pool::new();
            let deserialized = deserialize_using::<Test, _, Panic>(
                &*archived,
                &mut deserializer,
            )
            .unwrap();

            assert_eq!(*deserialized.a, 17);
            assert_eq!(*deserialized.b, 17);
            assert_eq!(
                &*deserialized.a as *const u32,
                &*deserialized.b as *const u32
            );
            assert_eq!(Rc::strong_count(&deserialized.a), 3);
            assert_eq!(Rc::strong_count(&deserialized.b), 3);
            assert_eq!(Rc::weak_count(&deserialized.a), 0);
            assert_eq!(Rc::weak_count(&deserialized.b), 0);

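            // Dropping the pool releases its strong count; only `a` and `b`
            // keep the shared value alive.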
            core::mem::drop(deserializer);

            assert_eq!(*deserialized.a, 17);
            assert_eq!(*deserialized.b, 17);
            assert_eq!(
                &*deserialized.a as *const u32,
                &*deserialized.b as *const u32
            );
            assert_eq!(Rc::strong_count(&deserialized.a), 2);
            assert_eq!(Rc::strong_count(&deserialized.b), 2);
            assert_eq!(Rc::weak_count(&deserialized.a), 0);
            assert_eq!(Rc::weak_count(&deserialized.b), 0);
        });
    }

    #[test]
    fn roundtrip_rc_zst() {
        #[derive(Archive, Deserialize, Serialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct TestRcZST {
            a: Rc<()>,
            b: Rc<()>,
        }

        let rc_zst = Rc::new(());
        roundtrip(&TestRcZST {
            a: rc_zst.clone(),
            b: rc_zst.clone(),
        });
    }

    #[test]
    fn roundtrip_unsized_shared_ptr() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<[String]>,
            b: Rc<[String]>,
        }

        let rc_slice = Rc::<[String]>::from(
            vec!["hello".to_string(), "world".to_string()].into_boxed_slice(),
        );
        let value = Test {
            a: rc_slice.clone(),
            b: rc_slice,
        };

        roundtrip(&value);
    }

    #[test]
    fn roundtrip_unsized_shared_ptr_empty() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<[u32]>,
            b: Rc<[u32]>,
        }

        let a_rc_slice = Rc::<[u32]>::from(vec![].into_boxed_slice());
        let b_rc_slice = Rc::<[u32]>::from(vec![100].into_boxed_slice());
        let value = Test {
            a: a_rc_slice,
            b: b_rc_slice.clone(),
        };

        roundtrip(&value);
    }

    #[test]
    fn roundtrip_weak_ptr() {
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Test {
            a: Rc<u32>,
            b: Weak<u32>,
        }

        let shared = Rc::new(10);
        let value = Test {
            a: shared.clone(),
            b: Rc::downgrade(&shared),
        };

        let mut buf = to_bytes::<Panic>(&value).unwrap();

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 10);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 10);

        let mut mutable_archived =
            unsafe { access_unchecked_mut::<ArchivedTest>(buf.as_mut()) };

        munge!(let ArchivedTest { a, .. } = mutable_archived.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
        }

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 42);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 42);

        let mut mutable_archived =
            unsafe { access_unchecked_mut::<ArchivedTest>(buf.as_mut()) };
        munge!(let ArchivedTest { b, .. } = mutable_archived.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(
                ArchivedRcWeak::upgrade_seal(b).unwrap(),
            ) = 17u32.into();
        }

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 17);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 17);

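        // Deserializing through a pool: `a` and the pool each hold a strong
        // count, and the surviving weak pointer accounts for the weak count.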
        let mut deserializer = Pool::new();
        let deserialized =
            deserialize_using::<Test, _, Panic>(archived, &mut deserializer)
                .unwrap();

        assert_eq!(*deserialized.a, 17);
        assert!(deserialized.b.upgrade().is_some());
        assert_eq!(*deserialized.b.upgrade().unwrap(), 17);
        assert_eq!(
            &*deserialized.a as *const u32,
            &*deserialized.b.upgrade().unwrap() as *const u32
        );
        assert_eq!(Rc::strong_count(&deserialized.a), 2);
        assert_eq!(Weak::strong_count(&deserialized.b), 2);
        assert_eq!(Rc::weak_count(&deserialized.a), 1);
        assert_eq!(Weak::weak_count(&deserialized.b), 1);

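        // After the pool is dropped, only `a` keeps the value alive; the weak
        // pointer still contributes the weak count.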
        core::mem::drop(deserializer);

        assert_eq!(*deserialized.a, 17);
        assert!(deserialized.b.upgrade().is_some());
        assert_eq!(*deserialized.b.upgrade().unwrap(), 17);
        assert_eq!(
            &*deserialized.a as *const u32,
            &*deserialized.b.upgrade().unwrap() as *const u32
        );
        assert_eq!(Rc::strong_count(&deserialized.a), 1);
        assert_eq!(Weak::strong_count(&deserialized.b), 1);
        assert_eq!(Rc::weak_count(&deserialized.a), 1);
        assert_eq!(Weak::weak_count(&deserialized.b), 1);
    }

    #[test]
    fn serialize_cyclic_error() {
        use rancor::{Fallible, Source};

        use crate::{
            de::Pooling,
            ser::{Sharing, Writer},
        };

        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(
            crate,
            serialize_bounds(
                __S: Sharing + Writer,
                <__S as Fallible>::Error: Source,
            ),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            )
        )]
        #[cfg_attr(
            feature = "bytecheck",
            rkyv(bytecheck(bounds(
                __C: crate::validation::ArchiveContext
                    + crate::validation::SharedContext,
                <__C as Fallible>::Error: Source,
            ))),
        )]
        struct Inner {
            #[rkyv(omit_bounds)]
            weak: Weak<Self>,
        }

        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Outer {
            inner: Rc<Inner>,
        }

        let value = Outer {
            inner: Rc::new_cyclic(|weak| Inner { weak: weak.clone() }),
        };

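        // Serializing this cyclic `Rc`/`Weak` graph is expected to fail with
        // an error rather than loop or overflow the stack.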
        assert!(to_bytes::<Failure>(&value).is_err());
    }

    #[cfg(all(
        feature = "bytecheck",
        not(feature = "big_endian"),
        not(any(feature = "pointer_width_16", feature = "pointer_width_64")),
    ))]
    #[test]
    fn recursive_stack_overflow() {
        use rancor::{Fallible, Source};

        use crate::{
            access,
            de::Pooling,
            util::Align,
            validation::{ArchiveContext, SharedContext},
        };

        #[derive(Archive, Deserialize)]
        #[rkyv(
            crate,
            bytecheck(bounds(__C: ArchiveContext + SharedContext)),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            ),
            derive(Debug),
        )]
        enum AllValues {
            Rc(#[rkyv(omit_bounds)] Rc<AllValues>),
        }

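        // Hand-crafted archive bytes (little-endian, 32-bit pointer width,
        // hence the `cfg` above) describing a self-referential `Rc` chain.
        // Validation must report an error instead of recursing without bound.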
        let data = Align([
            0x00, 0x00, 0x00, 0xff, 0xfc, 0xff, 0xff, 0xff, 0x00, 0x00, 0xf6,
            0xff, 0xf4, 0xff, 0xff, 0xff,
        ]);
        access::<ArchivedAllValues, Failure>(&*data).unwrap_err();
    }
}
465}