1use futures::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Sink, ready};
2use std::{
3 ops::Range,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use pin_project::pin_project;
9
10use rkyv::{
11 Archive, Archived, Portable, Serialize,
12 api::{
13 high::{HighSerializer, HighValidator},
14 serialize_using,
15 },
16 rancor,
17 ser::{
18 Serializer,
19 allocator::{Arena, ArenaHandle},
20 sharing::Share,
21 },
22 util::AlignedVec,
23};
24
25use crate::{RkyvCodecError, length_codec::LengthCodec};
26
27pub async fn archive_sink<'b, Inner: AsyncWrite + Unpin, L: LengthCodec>(
29 inner: &mut Inner,
30 archived: &[u8],
31) -> Result<(), RkyvCodecError<L>> {
32 let length_buf = &mut L::Buffer::default();
33 let length_buf = L::encode(archived.len(), length_buf);
34 inner.write_all(length_buf).await?;
35 inner.write_all(archived).await?;
36 Ok(())
37}
38pub async unsafe fn archive_stream_unsafe<
48 'b,
49 Inner: AsyncBufRead + Unpin,
50 Packet: Archive + Portable + 'b,
51 L: LengthCodec,
52>(
53 inner: &mut Inner,
54 buffer: &'b mut AlignedVec,
55) -> Result<&'b Archived<Packet>, RkyvCodecError<L>> {
56 buffer.clear();
57
58 let archive_len = L::decode_async(inner).await?;
60
61 if buffer.capacity() < archive_len {
63 buffer.reserve(archive_len - buffer.capacity())
64 }
65
66 unsafe { buffer.set_len(archive_len) }
68
69 inner.read_exact(buffer).await?;
71
72 unsafe { Ok(rkyv::access_unchecked(buffer)) }
74}
75
76pub async fn archive_stream<'b, Inner: AsyncBufRead + Unpin, Packet, L: LengthCodec>(
88 inner: &mut Inner,
89 buffer: &'b mut AlignedVec,
90) -> Result<&'b Archived<Packet>, RkyvCodecError<L>>
91where
92 Packet: rkyv::Archive + 'b,
93 Packet::Archived: for<'a> rkyv::bytecheck::CheckBytes<HighValidator<'a, rancor::Error>>,
94{
95 buffer.clear();
96
97 let archive_len = L::decode_async(inner).await?;
98
99 if buffer.capacity() < archive_len {
101 buffer.reserve(archive_len - buffer.capacity())
102 }
103
104 unsafe { buffer.set_len(archive_len) }
106
107 inner.read_exact(buffer).await?;
108
109 let archive = rkyv::access::<Packet::Archived, rancor::Error>(buffer)?;
110
111 Ok(archive)
112}
113
114#[pin_project]
116pub struct RkyvWriter<Writer: AsyncWrite, L: LengthCodec> {
117 #[pin]
118 writer: Writer,
119 length_buffer: L::Buffer,
120 len_state: Range<usize>, buf_state: usize, buffer: Option<AlignedVec>,
123 arena: Arena,
124 share: Option<Share>,
125}
126
127unsafe impl<Writer: AsyncWrite + Send, L: LengthCodec> Send for RkyvWriter<Writer, L> {}
129
130impl<Writer: AsyncWrite, L: LengthCodec> RkyvWriter<Writer, L> {
131 pub fn new(writer: Writer) -> Self {
132 Self {
133 writer,
134 length_buffer: L::Buffer::default(),
135 len_state: Default::default(),
136 buf_state: 0,
137 buffer: Some(AlignedVec::new()),
138 arena: Arena::new(),
139 share: Some(Share::new()),
140 }
141 }
142 pub fn inner(self) -> Writer {
143 self.writer
144 }
145}
146
147impl<Writer: AsyncWrite, Packet: std::fmt::Debug, L: LengthCodec> Sink<&Packet>
148 for RkyvWriter<Writer, L>
149where
150 Packet: Archive + for<'b> Serialize<HighSerializer<AlignedVec, ArenaHandle<'b>, rancor::Error>>,
151{
152 type Error = RkyvCodecError<L>;
153
154 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
155 self.project()
156 .writer
157 .poll_flush(cx)
158 .map_err(RkyvCodecError::IoError)
159 }
160
161 fn start_send(self: Pin<&mut Self>, item: &Packet) -> Result<(), Self::Error> {
162 let this = self.project();
163 let buffer_len = {
164 let mut buffer = this.buffer.take().unwrap();
166 buffer.clear();
167 let share = this.share.take().unwrap();
168 let mut serializer = Serializer::new(buffer, this.arena.acquire(), share);
169 let _ = serialize_using(item, &mut serializer)?;
171
172 let (buffer, _, share) = serializer.into_raw_parts();
173 let buffer_len = buffer.len();
174 *this.buffer = Some(buffer);
175 *this.share = Some(share);
176 buffer_len
177 };
178
179 *this.len_state = 0..L::encode(buffer_len, this.length_buffer).len();
180 *this.buf_state = 0;
181
182 Ok(())
183 }
184
185 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
186 let mut this = self.project();
187
188 let len_state = this.len_state;
190 if len_state.start <= len_state.end {
191 let length_buffer = L::as_slice(this.length_buffer);
192 let length_buffer = &mut length_buffer[len_state.clone()];
193
194 let written = ready!(Pin::new(&mut this.writer).poll_write(cx, length_buffer)?);
195 len_state.start += written;
196 }
197 let buffer = this.buffer.take().unwrap();
198
199 while *this.buf_state < buffer.len() {
200 let buffer_left = &buffer[*this.buf_state..buffer.len()];
201 let bytes_written = ready!(Pin::new(&mut this.writer).poll_write(cx, buffer_left))?;
202 if bytes_written == 0 {
203 return Poll::Ready(Err(RkyvCodecError::LengthTooLong {
204 requested: buffer.capacity(),
205 available: buffer.len(),
206 }));
207 }
208 *this.buf_state += bytes_written;
209 }
210
211 *this.buffer = Some(buffer);
212
213 ready!(this.writer.poll_flush(cx)?);
214 Poll::Ready(Ok(()))
215 }
216
217 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
218 self.project()
219 .writer
220 .poll_close(cx)
221 .map_err(RkyvCodecError::IoError)
222 }
223}