rkyv_owned_archive/
lib.rs

1#[cfg(feature = "pool")]
2mod pool;
3
4use aligned_buffer::{
5	alloc::{BufferAllocator, Global},
6	SharedAlignedBuffer, DEFAULT_BUFFER_ALIGNMENT,
7};
8use rkyv::Portable;
9use std::{fmt, marker::PhantomData, mem, ops};
10
11#[cfg(feature = "bytecheck")]
12use rkyv::{
13	api::{access_pos_with_context, high::HighValidator},
14	bytecheck::CheckBytes,
15	ptr_meta::Pointee,
16	rancor::{self, Strategy},
17	validation::{archive::ArchiveValidator, shared::SharedValidator, ArchiveContext, Validator},
18};
19
20#[cfg(feature = "pool")]
21pub use pool::PooledArchive;
22
23#[cfg(feature = "pool")]
24pub use aligned_buffer_pool;
25
26pub use aligned_buffer;
27
28pub struct OwnedArchive<T: Portable, const ALIGNMENT: usize = DEFAULT_BUFFER_ALIGNMENT, A = Global>
29where
30	A: BufferAllocator<ALIGNMENT>,
31{
32	buffer: SharedAlignedBuffer<ALIGNMENT, A>,
33	pos: usize,
34	_phantom: PhantomData<T>,
35}
36
37impl<T: Portable, const ALIGNMENT: usize, A> OwnedArchive<T, ALIGNMENT, A>
38where
39	A: BufferAllocator<ALIGNMENT>,
40{
41	#[allow(dead_code)]
42	const ALIGNMENT_OK: () = assert!(mem::size_of::<T>() <= ALIGNMENT);
43}
44
45// It's safe to clone an `OwnedArchive` because the inner buffer is shared and guarantees
46// that clones returns the a stable reference to the same data.
47impl<T: Portable, const ALIGNMENT: usize, A> Clone for OwnedArchive<T, ALIGNMENT, A>
48where
49	A: BufferAllocator<ALIGNMENT> + Clone,
50{
51	#[inline]
52	fn clone(&self) -> Self {
53		Self {
54			buffer: self.buffer.clone(),
55			pos: self.pos,
56			_phantom: PhantomData,
57		}
58	}
59}
60
#[cfg(feature = "bytecheck")]
impl<T: Portable, const ALIGNMENT: usize, A> OwnedArchive<T, ALIGNMENT, A>
where
	A: BufferAllocator<ALIGNMENT>,
{
	/// Validates `buffer` and wraps it, assuming the root `T` is stored at the end of the buffer
	/// (rkyv's default layout).
	///
	/// The root is looked up at `buffer.len() - size_of::<T>()`, saturating at `0`.
	///
	/// # Errors
	///
	/// Returns an `E` if the bytes at that position fail `T`'s validation.
	pub fn new<E>(buffer: SharedAlignedBuffer<ALIGNMENT, A>) -> Result<Self, E>
	where
		E: rancor::Source,
		T: for<'a> CheckBytes<HighValidator<'a, E>>,
	{
		let pos = buffer.len().saturating_sub(mem::size_of::<T>());
		Self::new_with_pos(buffer, pos)
	}

	/// Validates that a `T` is stored at byte offset `pos` of `buffer` and wraps it.
	///
	/// Validation uses a fresh context made of an [`ArchiveValidator`] over `buffer` and a
	/// [`SharedValidator`].
	///
	/// # Errors
	///
	/// Returns an `E` if the bytes at `pos` fail `T`'s validation.
	pub fn new_with_pos<E>(buffer: SharedAlignedBuffer<ALIGNMENT, A>, pos: usize) -> Result<Self, E>
	where
		E: rancor::Source,
		T: for<'a> CheckBytes<HighValidator<'a, E>>,
	{
		let mut validator = Validator::new(ArchiveValidator::new(&buffer), SharedValidator::new());
		access_pos_with_context::<T, _, E>(&buffer, pos, &mut validator)?;
		Ok(Self {
			buffer,
			pos,
			_phantom: PhantomData,
		})
	}

	/// Validates `buffer` using a caller-supplied validation `context`, assuming the root `T` is
	/// stored at the end of the buffer (rkyv's default layout).
	///
	/// # Errors
	///
	/// Returns an `E` if the bytes at the root position fail `T`'s validation under `context`.
	pub fn new_with_context<C, E>(
		buffer: SharedAlignedBuffer<ALIGNMENT, A>,
		context: &mut C,
	) -> Result<Self, E>
	where
		T: CheckBytes<Strategy<C, E>> + Pointee<Metadata = ()>,
		C: ArchiveContext<E> + ?Sized,
		E: rancor::Source,
	{
		let pos = buffer.len().saturating_sub(mem::size_of::<T>());
		Self::new_with_pos_and_context(buffer, pos, context)
	}

	/// Validates that a `T` is stored at byte offset `pos` of `buffer`, using a caller-supplied
	/// validation `context`, and wraps it.
	///
	/// # Errors
	///
	/// Returns an `E` if the bytes at `pos` fail `T`'s validation under `context`.
	pub fn new_with_pos_and_context<C, E>(
		buffer: SharedAlignedBuffer<ALIGNMENT, A>,
		pos: usize,
		context: &mut C,
	) -> Result<Self, E>
	where
		T: CheckBytes<Strategy<C, E>> + Pointee<Metadata = ()>,
		C: ArchiveContext<E> + ?Sized,
		E: rancor::Source,
	{
		access_pos_with_context::<T, C, E>(&buffer, pos, context)?;
		Ok(Self {
			buffer,
			pos,
			_phantom: PhantomData,
		})
	}

	/// Re-roots the archive at a `U` reachable from the current root, keeping the same buffer.
	///
	/// # Panics
	///
	/// Panics if the reference returned by `f` does not lie entirely within this archive's
	/// buffer.
	#[track_caller]
	pub fn map<U: Portable, F>(self, f: F) -> OwnedArchive<U, ALIGNMENT, A>
	where
		F: for<'a> FnOnce(&'a T) -> &'a U,
	{
		self.map_with_buffer(|a, _| f(a))
	}

	/// Like [`map`](Self::map), but `f` also receives the archive's whole byte buffer.
	///
	/// # Panics
	///
	/// Panics if the reference returned by `f` does not lie entirely within this archive's
	/// buffer.
	#[track_caller]
	pub fn map_with_buffer<U: Portable, F>(self, f: F) -> OwnedArchive<U, ALIGNMENT, A>
	where
		F: for<'a> FnOnce(&'a T, &'a [u8]) -> &'a U,
	{
		let mapped = f(&*self, &self.buffer);
		let pos = Self::offset_in_buffer(&self.buffer, mapped);

		// SAFETY: `mapped` is a valid (hence aligned) `&U`, and `offset_in_buffer` verified that
		// all of its bytes lie inside `self.buffer`, starting at offset `pos`.
		unsafe { OwnedArchive::new_unchecked_with_pos(self.buffer, pos) }
	}

	/// Fallible version of [`map`](Self::map).
	///
	/// # Errors
	///
	/// Returns the error produced by `f`, dropping this archive.
	///
	/// # Panics
	///
	/// Panics if `f` succeeds with a reference that does not lie entirely within this archive's
	/// buffer.
	#[track_caller]
	pub fn try_map<U: Portable, E, F>(self, f: F) -> Result<OwnedArchive<U, ALIGNMENT, A>, E>
	where
		F: for<'a> FnOnce(&'a T) -> Result<&'a U, E>,
	{
		self.try_map_with_buffer(|a, _| f(a))
	}

	/// Fallible version of [`map_with_buffer`](Self::map_with_buffer).
	///
	/// # Errors
	///
	/// Returns the error produced by `f`, dropping this archive.
	///
	/// # Panics
	///
	/// Panics if `f` succeeds with a reference that does not lie entirely within this archive's
	/// buffer.
	#[track_caller]
	pub fn try_map_with_buffer<U: Portable, E, F>(
		self,
		f: F,
	) -> Result<OwnedArchive<U, ALIGNMENT, A>, E>
	where
		F: for<'a> FnOnce(&'a T, &'a [u8]) -> Result<&'a U, E>,
	{
		let mapped = f(&*self, &self.buffer)?;
		let pos = Self::offset_in_buffer(&self.buffer, mapped);

		// SAFETY: `mapped` is a valid (hence aligned) `&U`, and `offset_in_buffer` verified that
		// all of its bytes lie inside `self.buffer`, starting at offset `pos`.
		Ok(unsafe { OwnedArchive::new_unchecked_with_pos(self.buffer, pos) })
	}

	/// Returns the byte offset of `value` from the start of `buffer`.
	///
	/// # Panics
	///
	/// Panics unless all `size_of::<U>()` bytes of `value` lie within `buffer`.
	#[track_caller]
	fn offset_in_buffer<U>(buffer: &[u8], value: &U) -> usize {
		let start = buffer.as_ptr() as usize;
		let addr = value as *const U as usize;
		// Subtract first, then compare against the remaining length: this cannot overflow,
		// unlike computing `addr + size_of::<U>()`.
		match addr.checked_sub(start) {
			Some(offset)
				if offset <= buffer.len() && mem::size_of::<U>() <= buffer.len() - offset =>
			{
				offset
			}
			_ => panic!("mapped value does not lie within the archive's buffer"),
		}
	}
}
177
178impl<T: Portable, const ALIGNMENT: usize, A> OwnedArchive<T, ALIGNMENT, A>
179where
180	A: BufferAllocator<ALIGNMENT>,
181{
182	/// # Safety
183	///
184	/// - The byte slice must represent an archived object.
185	/// - The root of the object must be stored at the end of the slice (this is the
186	///   default behavior).
187	pub unsafe fn new_unchecked(buffer: SharedAlignedBuffer<ALIGNMENT, A>) -> Self {
188		let pos = buffer.len().saturating_sub(mem::size_of::<T>());
189		Self::new_unchecked_with_pos(buffer, pos)
190	}
191
192	/// # Safety
193	///
194	/// A `T::Archived` must be located at the given position in the byte slice.
195	pub unsafe fn new_unchecked_with_pos(
196		buffer: SharedAlignedBuffer<ALIGNMENT, A>,
197		pos: usize,
198	) -> Self {
199		Self {
200			buffer,
201			pos,
202			_phantom: PhantomData,
203		}
204	}
205}
206
207impl<T: Portable, const ALIGNMENT: usize, A> ops::Deref for OwnedArchive<T, ALIGNMENT, A>
208where
209	A: BufferAllocator<ALIGNMENT>,
210{
211	type Target = T;
212
213	#[inline]
214	fn deref(&self) -> &Self::Target {
215		// SAFETY: `buffer` is required to contain a representation of T::Archived at `pos`.
216		// This is checked by the safe constructors, and required by the unsafe constructors.
217		unsafe { rkyv::api::access_pos_unchecked::<T>(&self.buffer, self.pos) }
218	}
219}
220
221impl<T: Portable, const ALIGNMENT: usize, A> AsRef<T> for OwnedArchive<T, ALIGNMENT, A>
222where
223	A: BufferAllocator<ALIGNMENT>,
224{
225	#[inline]
226	fn as_ref(&self) -> &T {
227		self
228	}
229}
230
231impl<T: Portable, const ALIGNMENT: usize, A> fmt::Debug for OwnedArchive<T, ALIGNMENT, A>
232where
233	A: BufferAllocator<ALIGNMENT>,
234	T: fmt::Debug,
235{
236	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237		fmt::Debug::fmt(&**self, f)
238	}
239}
240
241impl<T: Portable, const ALIGNMENT: usize, A> fmt::Display for OwnedArchive<T, ALIGNMENT, A>
242where
243	A: BufferAllocator<ALIGNMENT>,
244	T: fmt::Display,
245{
246	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247		fmt::Display::fmt(&**self, f)
248	}
249}
250
#[cfg(all(test, feature = "bytecheck"))]
mod tests {
	// NOTE(review): these tests use `aligned_buffer_pool`, which the crate root only re-exports
	// under the "pool" feature. Confirm it is also a dev-dependency; otherwise gate this module
	// on `feature = "pool"` as well.
	use super::*;
	use aligned_buffer_pool::{RetainAllRetentionPolicy, SerializerPool};
	use rkyv::{Archive, Deserialize, Serialize};

	#[derive(Archive, Serialize, Deserialize)]
	struct TestStruct1 {
		name: String,
		boxed_name: Box<str>,
		age: u16,
	}

	impl TestStruct1 {
		fn new(name: &str, age: u16) -> Self {
			Self {
				name: name.to_string(),
				boxed_name: name.into(),
				age,
			}
		}
	}

	#[test]
	fn owned_archive() {
		let pool = SerializerPool::<RetainAllRetentionPolicy, 64>::with_capacity(10);
		let original = TestStruct1::new("test1", 10);
		let buffer = pool.serialize(&original).expect("failed to serialize");

		let owned_archive =
			OwnedArchive::<ArchivedTestStruct1, 64, _>::new::<rancor::BoxedError>(buffer)
				.expect("failed to create");

		assert_eq!(owned_archive.name, original.name);
		assert_eq!(owned_archive.boxed_name, original.boxed_name);
		assert_eq!(owned_archive.age, original.age);
	}

	#[test]
	fn clone_owned_archive() {
		let pool = SerializerPool::<RetainAllRetentionPolicy, 64>::with_capacity(10);
		let original = TestStruct1::new("test1", 10);
		let buffer = pool.serialize(&original).expect("failed to serialize");

		let owned_archive =
			OwnedArchive::<ArchivedTestStruct1, 64, _>::new::<rancor::BoxedError>(buffer)
				.expect("failed to create");

		let clone = owned_archive.clone();
		assert_eq!(owned_archive.name, clone.name);
		assert_eq!(owned_archive.boxed_name, clone.boxed_name);
		assert_eq!(owned_archive.age, clone.age);
	}

	#[test]
	fn mapped_owned_archive() {
		let pool = SerializerPool::<RetainAllRetentionPolicy, 64>::with_capacity(10);
		let original = TestStruct1::new("test1", 10);
		let buffer = pool.serialize(&original).expect("failed to serialize");

		let owned_archive =
			OwnedArchive::<ArchivedTestStruct1, 64, _>::new::<rancor::BoxedError>(buffer)
				.expect("failed to create");

		let mapped = owned_archive.map(|a| &a.name);
		assert_eq!(*mapped, original.name);
	}

	#[test]
	fn try_mapped_owned_archive() {
		let pool = SerializerPool::<RetainAllRetentionPolicy, 64>::with_capacity(10);
		let original = TestStruct1::new("test1", 10);
		let buffer = pool.serialize(&original).expect("failed to serialize");

		let owned_archive =
			OwnedArchive::<ArchivedTestStruct1, 64, _>::new::<rancor::BoxedError>(buffer)
				.expect("failed to create");

		let mapped = owned_archive
			.try_map::<_, (), _>(|a| Ok(&a.age))
			.expect("mapping should succeed");
		assert_eq!(*mapped, original.age);
	}

	#[test]
	fn try_map_propagates_error() {
		let pool = SerializerPool::<RetainAllRetentionPolicy, 64>::with_capacity(10);
		let original = TestStruct1::new("test1", 10);
		let buffer = pool.serialize(&original).expect("failed to serialize");

		let owned_archive =
			OwnedArchive::<ArchivedTestStruct1, 64, _>::new::<rancor::BoxedError>(buffer)
				.expect("failed to create");

		let result = owned_archive.try_map::<ArchivedTestStruct1, &str, _>(|_| Err("nope"));
		assert_eq!(result.err(), Some("nope"));
	}

	#[test]
	#[should_panic]
	fn map_outside_buffer_panics() {
		// A value that certainly does not live inside the archive's buffer.
		static OUTSIDE: u8 = 0;

		let pool = SerializerPool::<RetainAllRetentionPolicy, 64>::with_capacity(10);
		let original = TestStruct1::new("test1", 10);
		let buffer = pool.serialize(&original).expect("failed to serialize");

		let owned_archive =
			OwnedArchive::<ArchivedTestStruct1, 64, _>::new::<rancor::BoxedError>(buffer)
				.expect("failed to create");

		let _ = owned_archive.map(|_| &OUTSIDE);
	}
}