rkyv_codec/
rkyv_codec.rs

1use futures::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Sink, ready};
2use std::{
3	ops::Range,
4	pin::Pin,
5	task::{Context, Poll},
6};
7
8use pin_project::pin_project;
9
10use rkyv::{
11	Archive, Archived, Portable, Serialize,
12	api::{
13		high::{HighSerializer, HighValidator},
14		serialize_using,
15	},
16	rancor,
17	ser::{
18		Serializer,
19		allocator::{Arena, ArenaHandle},
20		sharing::Share,
21	},
22	util::AlignedVec,
23};
24
25use crate::{RkyvCodecError, length_codec::LengthCodec};
26
27/// Rewrites a single buffer representing an Archive to an `AsyncWrite`
28pub async fn archive_sink<'b, Inner: AsyncWrite + Unpin, L: LengthCodec>(
29	inner: &mut Inner,
30	archived: &[u8],
31) -> Result<(), RkyvCodecError<L>> {
32	let length_buf = &mut L::Buffer::default();
33	let length_buf = L::encode(archived.len(), length_buf);
34	inner.write_all(length_buf).await?;
35	inner.write_all(archived).await?;
36	Ok(())
37}
38/// Reads a single `&Archived<Object>` from an `AsyncRead` without checking for correct byte formatting
39/// # Safety
40/// This may cause undefined behavior if the bytestream is not a valid archive (i.e. not generated through `archive_sink[_bytes]`, or `RkyvWriter`)
41///
42/// As an optimisation, this function may pass uninitialized bytes to the reader for the reader to read into. Make sure the particular reader in question is implemented correctly and does not read from its passed buffer in the poll_read() function without first writing to it.
43/// # Warning
44/// Passed buffer is reallocated so it may fit the size of the packet being written. This may allow for DOS attacks if remote sends too large a length encoding
45/// # Errors
46/// Will return an error if there are not enough bytes to read to read the length of the packet, or read the packet itself. Will also return an error if the length encoding format is invalid.
47pub async unsafe fn archive_stream_unsafe<
48	'b,
49	Inner: AsyncBufRead + Unpin,
50	Packet: Archive + Portable + 'b,
51	L: LengthCodec,
52>(
53	inner: &mut Inner,
54	buffer: &'b mut AlignedVec,
55) -> Result<&'b Archived<Packet>, RkyvCodecError<L>> {
56	buffer.clear();
57
58	// parse archive length
59	let archive_len = L::decode_async(inner).await?;
60
61	// If not enough capacity in buffer to fit `archive_len`, reserve more.
62	if buffer.capacity() < archive_len {
63		buffer.reserve(archive_len - buffer.capacity())
64	}
65
66	// Safety: Caller should make sure that reader does not read from this potentially uninitialized buffer passed to poll_read()
67	unsafe { buffer.set_len(archive_len) }
68
69	// read exactly amount specified by archive_len into buffer
70	inner.read_exact(buffer).await?;
71
72	// Safety: Caller should make sure that reader does not produce invalid packets.
73	unsafe { Ok(rkyv::access_unchecked(buffer)) }
74}
75
76/// Reads a single `&Archived<Object>` from an `AsyncRead` using the passed buffer.
77///
78/// Until streaming iterators (and streaming futures) are implemented in rust, this currently the fastest method I could come up with that requires no recurring heap allocations.
79///
80/// Requires rkyv "validation" feature
81/// # Safety
82/// As an optimisation, this function may pass uninitialized bytes to the reader for the reader to read into. Make sure the particular reader in question is implemented correctly and does not read from its passed buffer in the poll_read() function without first writing to it.
83/// # Warning
84/// Passed buffer is reallocated so it may fit the size of the packet being written. This may allow for DOS attacks if remote sends too large a length encoding
85/// # Errors
86/// Will return an error if there are not enough bytes to read to read the length of the packet, or read the packet itself. Will also return an error if the length encoding format is invalid or the packet archive itself is invalid.
87pub async fn archive_stream<'b, Inner: AsyncBufRead + Unpin, Packet, L: LengthCodec>(
88	inner: &mut Inner,
89	buffer: &'b mut AlignedVec,
90) -> Result<&'b Archived<Packet>, RkyvCodecError<L>>
91where
92	Packet: rkyv::Archive + 'b,
93	Packet::Archived: for<'a> rkyv::bytecheck::CheckBytes<HighValidator<'a, rancor::Error>>,
94{
95	buffer.clear();
96
97	let archive_len = L::decode_async(inner).await?;
98
99	// If not enough capacity in buffer to fit `archive_len`, reserve more.
100	if buffer.capacity() < archive_len {
101		buffer.reserve(archive_len - buffer.capacity())
102	}
103
104	// Safety: Caller should make sure that reader does not read from this potentially uninitialized buffer passed to poll_read()
105	unsafe { buffer.set_len(archive_len) }
106
107	inner.read_exact(buffer).await?;
108
109	let archive = rkyv::access::<Packet::Archived, rancor::Error>(buffer)?;
110
111	Ok(archive)
112}
113
114/// Wraps an `AsyncWrite` and implements `Sink` to serialize `Archive` objects.
115#[pin_project]
116pub struct RkyvWriter<Writer: AsyncWrite, L: LengthCodec> {
117	#[pin]
118	writer: Writer,
119	length_buffer: L::Buffer,
120	len_state: Range<usize>, // How much of the length buffer has been written
121	buf_state: usize, // Whether or not the aligned buf is being written and if so, how much so far
122	buffer: Option<AlignedVec>,
123	arena: Arena,
124	share: Option<Share>,
125}
126
127// Safety: Arena is Send and Share is Send, if Writer is Send RkyvWriter should be Send.
128unsafe impl<Writer: AsyncWrite + Send, L: LengthCodec> Send for RkyvWriter<Writer, L> {}
129
130impl<Writer: AsyncWrite, L: LengthCodec> RkyvWriter<Writer, L> {
131	pub fn new(writer: Writer) -> Self {
132		Self {
133			writer,
134			length_buffer: L::Buffer::default(),
135			len_state: Default::default(),
136			buf_state: 0,
137			buffer: Some(AlignedVec::new()),
138			arena: Arena::new(),
139			share: Some(Share::new()),
140		}
141	}
142	pub fn inner(self) -> Writer {
143		self.writer
144	}
145}
146
147impl<Writer: AsyncWrite, Packet: std::fmt::Debug, L: LengthCodec> Sink<&Packet>
148	for RkyvWriter<Writer, L>
149where
150	Packet: Archive + for<'b> Serialize<HighSerializer<AlignedVec, ArenaHandle<'b>, rancor::Error>>,
151{
152	type Error = RkyvCodecError<L>;
153
154	fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
155		self.project()
156			.writer
157			.poll_flush(cx)
158			.map_err(RkyvCodecError::IoError)
159	}
160
161	fn start_send(self: Pin<&mut Self>, item: &Packet) -> Result<(), Self::Error> {
162		let this = self.project();
163		let buffer_len = {
164			// Serializer
165			let mut buffer = this.buffer.take().unwrap();
166			buffer.clear();
167			let share = this.share.take().unwrap();
168			let mut serializer = Serializer::new(buffer, this.arena.acquire(), share);
169			// serialize
170			let _ = serialize_using(item, &mut serializer)?;
171
172			let (buffer, _, share) = serializer.into_raw_parts();
173			let buffer_len = buffer.len();
174			*this.buffer = Some(buffer);
175			*this.share = Some(share);
176			buffer_len
177		};
178
179		*this.len_state = 0..L::encode(buffer_len, this.length_buffer).len();
180		*this.buf_state = 0;
181
182		Ok(())
183	}
184
185	fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
186		let mut this = self.project();
187
188		// keep writing length buffer for as long as is required
189		let len_state = this.len_state;
190		if len_state.start <= len_state.end {
191			let length_buffer = L::as_slice(this.length_buffer);
192			let length_buffer = &mut length_buffer[len_state.clone()];
193
194			let written = ready!(Pin::new(&mut this.writer).poll_write(cx, length_buffer)?);
195			len_state.start += written;
196		}
197		let buffer = this.buffer.take().unwrap();
198
199		while *this.buf_state < buffer.len() {
200			let buffer_left = &buffer[*this.buf_state..buffer.len()];
201			let bytes_written = ready!(Pin::new(&mut this.writer).poll_write(cx, buffer_left))?;
202			if bytes_written == 0 {
203				return Poll::Ready(Err(RkyvCodecError::LengthTooLong {
204					requested: buffer.capacity(),
205					available: buffer.len(),
206				}));
207			}
208			*this.buf_state += bytes_written;
209		}
210
211		*this.buffer = Some(buffer);
212
213		ready!(this.writer.poll_flush(cx)?);
214		Poll::Ready(Ok(()))
215	}
216
217	fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
218		self.project()
219			.writer
220			.poll_close(cx)
221			.map_err(RkyvCodecError::IoError)
222	}
223}