use futures::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Sink, ready};
use std::{
ops::Range,
pin::Pin,
task::{Context, Poll},
};
use pin_project::pin_project;
use rkyv::{
Archive, Archived, Portable, Serialize,
api::{
high::{HighSerializer, HighValidator},
serialize_using,
},
rancor,
ser::{
Serializer,
allocator::{Arena, ArenaHandle},
sharing::Share,
},
util::AlignedVec,
};
use crate::{RkyvCodecError, length_codec::LengthCodec};
/// Writes `archived` to `inner` as a single frame.
///
/// The frame is the `L`-encoded byte length of `archived`, followed by the archived bytes
/// themselves. `inner` is not flushed; callers that need the bytes on the wire must flush
/// it themselves.
///
/// # Errors
///
/// Propagates any I/O error raised while writing either part of the frame.
pub async fn archive_sink<'b, Inner: AsyncWrite + Unpin, L: LengthCodec>(
    inner: &mut Inner,
    archived: &[u8],
) -> Result<(), RkyvCodecError<L>> {
    let mut header = L::Buffer::default();
    let prefix: &[u8] = L::encode(archived.len(), &mut header);
    // Header first, then payload — the reader decodes the length before the body.
    for part in [prefix, archived] {
        inner.write_all(part).await?;
    }
    Ok(())
}
/// Reads one length-prefixed archive from `inner` into `buffer` **without** validating it.
///
/// The expected frame is an `L`-encoded length followed by that many archived bytes (the
/// layout written by [`archive_sink`]). `buffer` is reused across calls, and the returned
/// reference borrows it.
///
/// The announced length is trusted: `buffer` is grown to whatever length the peer sends.
///
/// # Safety
///
/// The bytes read from `inner` are passed to [`rkyv::access_unchecked`]. The caller must
/// guarantee that they form a valid archive of `Packet` (for example, because the stream
/// comes from a trusted peer). Malformed input is undefined behavior.
///
/// # Errors
///
/// Returns an error if the length prefix cannot be decoded, or if the stream fails or ends
/// before the announced number of bytes has been read.
pub async unsafe fn archive_stream_unsafe<
    'b,
    Inner: AsyncBufRead + Unpin,
    Packet: Archive + Portable + 'b,
    L: LengthCodec,
>(
    inner: &mut Inner,
    buffer: &'b mut AlignedVec,
) -> Result<&'b Archived<Packet>, RkyvCodecError<L>> {
    buffer.clear();
    let archive_len = L::decode_async(inner).await?;
    // `resize` guarantees capacity for exactly `archive_len` bytes AND initializes them.
    // This matters because `read_exact` takes `&mut [u8]`, which must never point at
    // uninitialized memory. Note that `reserve(n)` counts from `len` (0 here), not from
    // the current capacity.
    buffer.resize(archive_len, 0);
    inner.read_exact(buffer).await?;
    // SAFETY: the caller guarantees (see `# Safety`) that these bytes are a valid archive
    // of `Packet`. `AlignedVec` provides the alignment rkyv expects.
    unsafe { Ok(rkyv::access_unchecked(buffer)) }
}
/// Reads one length-prefixed archive from `inner` into `buffer` and validates it.
///
/// The expected frame is an `L`-encoded length followed by that many archived bytes (the
/// layout written by [`archive_sink`]). `buffer` is reused across calls, and the returned
/// reference borrows it.
///
/// The announced length is trusted: `buffer` is grown to whatever length the peer sends.
///
/// # Errors
///
/// Returns an error if the length prefix cannot be decoded, if the stream fails or ends
/// before the announced number of bytes has been read, or if `bytecheck` validation of
/// the archived bytes fails.
pub async fn archive_stream<'b, Inner: AsyncBufRead + Unpin, Packet, L: LengthCodec>(
    inner: &mut Inner,
    buffer: &'b mut AlignedVec,
) -> Result<&'b Archived<Packet>, RkyvCodecError<L>>
where
    Packet: rkyv::Archive + 'b,
    Packet::Archived: for<'a> rkyv::bytecheck::CheckBytes<HighValidator<'a, rancor::Error>>,
{
    buffer.clear();
    let archive_len = L::decode_async(inner).await?;
    // `resize` guarantees capacity for exactly `archive_len` bytes AND initializes them.
    // This matters because `read_exact` takes `&mut [u8]`, which must never point at
    // uninitialized memory. Note that `reserve(n)` counts from `len` (0 here), not from
    // the current capacity.
    buffer.resize(archive_len, 0);
    inner.read_exact(buffer).await?;
    let archive = rkyv::access::<Packet::Archived, rancor::Error>(buffer)?;
    Ok(archive)
}
/// A [`Sink`] that serializes packets with rkyv and writes each one to an [`AsyncWrite`]
/// as a length-prefixed frame.
///
/// At most one encoded frame is held at a time (a single `buffer`).
#[pin_project]
pub struct RkyvWriter<Writer: AsyncWrite, L: LengthCodec> {
    /// Destination stream. It is structurally pinned because `Writer` is not required to be
    /// `Unpin`.
    #[pin]
    writer: Writer,
    /// Encoded length prefix of the current frame.
    length_buffer: L::Buffer,
    /// Byte range of `length_buffer` that is still waiting to be written.
    len_state: Range<usize>,
    /// Number of bytes of `buffer` already written.
    buf_state: usize,
    /// Serialized payload of the current frame. It is an `Option` so that it can be moved
    /// into the serializer and put back afterwards.
    buffer: Option<AlignedVec>,
    /// Scratch allocator handed to the serializer; reused across packets.
    arena: Arena,
    /// Shared-pointer bookkeeping for the serializer. It is an `Option` for the same reason
    /// as `buffer`.
    share: Option<Share>,
}
// SAFETY: NOTE(review): this impl requires only `Writer: Send` and places no bound on
// `L::Buffer`. That suggests the compiler's auto `Send` impl was blocked by some field,
// presumably `Arena` (which hands out `ArenaHandle`s) — TODO confirm. The impl is sound
// only if:
// - the arena's and `Share`'s internal state is not tied to the thread that created it;
// - every `LengthCodec::Buffer` is itself `Send`.
// Verify both, and consider adding an `L::Buffer: Send` bound.
unsafe impl<Writer: AsyncWrite + Send, L: LengthCodec> Send for RkyvWriter<Writer, L> {}
impl<Writer: AsyncWrite, L: LengthCodec> RkyvWriter<Writer, L> {
    /// Wraps `writer` with empty serialization state and no pending frame.
    pub fn new(writer: Writer) -> Self {
        let length_buffer = L::Buffer::default();
        let buffer = Some(AlignedVec::new());
        let arena = Arena::new();
        let share = Some(Share::new());
        Self {
            writer,
            length_buffer,
            len_state: 0..0,
            buf_state: 0,
            buffer,
            arena,
            share,
        }
    }

    /// Consumes the `RkyvWriter` and returns the wrapped writer.
    ///
    /// Any frame bytes that have not been written yet are dropped along with the rest of
    /// the state.
    pub fn inner(self) -> Writer {
        let Self { writer, .. } = self;
        writer
    }
}
impl<Writer: AsyncWrite, Packet: std::fmt::Debug, L: LengthCodec> Sink<&Packet>
for RkyvWriter<Writer, L>
where
Packet: Archive + for<'b> Serialize<HighSerializer<AlignedVec, ArenaHandle<'b>, rancor::Error>>,
{
type Error = RkyvCodecError<L>;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.writer
.poll_flush(cx)
.map_err(RkyvCodecError::IoError)
}
fn start_send(self: Pin<&mut Self>, item: &Packet) -> Result<(), Self::Error> {
let this = self.project();
let buffer_len = {
let mut buffer = this.buffer.take().unwrap();
buffer.clear();
let share = this.share.take().unwrap();
let mut serializer = Serializer::new(buffer, this.arena.acquire(), share);
let _ = serialize_using(item, &mut serializer)?;
let (buffer, _, share) = serializer.into_raw_parts();
let buffer_len = buffer.len();
*this.buffer = Some(buffer);
*this.share = Some(share);
buffer_len
};
*this.len_state = 0..L::encode(buffer_len, this.length_buffer).len();
*this.buf_state = 0;
Ok(())
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut this = self.project();
let len_state = this.len_state;
if len_state.start <= len_state.end {
let length_buffer = L::as_slice(this.length_buffer);
let length_buffer = &mut length_buffer[len_state.clone()];
let written = ready!(Pin::new(&mut this.writer).poll_write(cx, length_buffer)?);
len_state.start += written;
}
let buffer = this.buffer.take().unwrap();
while *this.buf_state < buffer.len() {
let buffer_left = &buffer[*this.buf_state..buffer.len()];
let bytes_written = ready!(Pin::new(&mut this.writer).poll_write(cx, buffer_left))?;
if bytes_written == 0 {
return Poll::Ready(Err(RkyvCodecError::LengthTooLong {
requested: buffer.capacity(),
available: buffer.len(),
}));
}
*this.buf_state += bytes_written;
}
*this.buffer = Some(buffer);
ready!(this.writer.poll_flush(cx)?);
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.writer
.poll_close(cx)
.map_err(RkyvCodecError::IoError)
}
}