use std::marker::PhantomData;
use asynchronous_codec::{BytesMut, Decoder, Encoder};
use bytes::Buf;
use rkyv::{
Archive, Archived, Deserialize, Portable, Serialize,
api::{
high::{HighDeserializer, HighSerializer, HighValidator},
serialize_using,
},
de::Pool,
rancor,
ser::{
Serializer,
allocator::{Arena, ArenaHandle},
sharing::Share,
},
util::AlignedVec,
};
use crate::{RkyvCodecError, length_codec::LengthCodec};
/// A length-delimited framing codec for rkyv-archived `Packet`s.
///
/// Each frame is the length prefix written by `L` followed by the archived
/// bytes of one `Packet`. Scratch state for serialization and deserialization
/// lives here so it can be reused from one frame to the next.
pub struct RkyvCodec<Packet, L>
where
    Packet: Archive,
    L: LengthCodec,
{
    /// Output buffer the serializer writes into. It is wrapped in an `Option`
    /// so `encode` can move it into the serializer and store it back afterwards.
    encode_buffer: Option<AlignedVec>,
    /// Arena the serializer draws scratch allocations from (via `acquire`).
    ser_arena: Arena,
    /// Shared-pointer state handed to the serializer; moved in and out like
    /// `encode_buffer`.
    ser_share: Option<Share>,
    /// Aligned copy of the current frame's payload, read by `rkyv::access`.
    decode_buffer: AlignedVec,
    /// Shared-pointer pool used by the deserializer.
    des_pool: Pool,
    /// Marks the packet type this codec handles.
    _data: PhantomData<Packet>,
    /// Marks the length-prefix codec.
    _length: PhantomData<L>,
}
impl<Packet: Archive, L: LengthCodec> Default for RkyvCodec<Packet, L> {
    /// Builds a codec with empty encode and decode scratch state.
    fn default() -> Self {
        let encode_buffer = Some(AlignedVec::new());
        let ser_share = Some(Share::new());
        Self {
            encode_buffer,
            ser_arena: Arena::new(),
            ser_share,
            decode_buffer: AlignedVec::new(),
            des_pool: Pool::new(),
            _data: PhantomData,
            _length: PhantomData,
        }
    }
}
impl<Packet, L: LengthCodec> Encoder for RkyvCodec<Packet, L>
where
Packet: Archive + for<'b> Serialize<HighSerializer<AlignedVec, ArenaHandle<'b>, rancor::Error>>,
{
type Item<'a> = Packet;
type Error = RkyvCodecError<L>;
fn encode<'a>(&mut self, data: Self::Item<'a>, buf: &mut BytesMut) -> Result<(), Self::Error> {
let mut encode_buffer = self.encode_buffer.take().unwrap();
let share = self.ser_share.take().unwrap();
encode_buffer.clear();
let mut serializer = Serializer::new(encode_buffer, self.ser_arena.acquire(), share);
let _ = serialize_using(&data, &mut serializer)?;
let (encode_buffer, _, share) = serializer.into_raw_parts();
let mut length_buffer = L::Buffer::default();
let length_buffer = L::encode(encode_buffer.len(), &mut length_buffer);
buf.extend_from_slice(length_buffer);
buf.extend_from_slice(&encode_buffer[..]);
self.encode_buffer = Some(encode_buffer);
self.ser_share = Some(share);
Ok(())
}
}
impl<Packet, L: LengthCodec> Decoder for RkyvCodec<Packet, L>
where
Packet: Archive + 'static,
Packet::Archived: Portable
+ for<'b> rkyv::bytecheck::CheckBytes<HighValidator<'b, rancor::Error>>
+ Deserialize<Packet, HighDeserializer<rancor::Error>>,
{
type Item = Packet;
type Error = RkyvCodecError<L>;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if buf.is_empty() {
return Ok(None);
}
self.decode_buffer.clear();
let (length, remaining) = L::decode(buf).map_err(RkyvCodecError::ReadLengthError)?;
self.decode_buffer.extend_from_slice(&remaining[0..length]);
let archive: &Archived<Packet> = rkyv::access::<_, rancor::Error>(&self.decode_buffer)?;
let packet: Packet =
archive.deserialize(&mut HighDeserializer::wrap(&mut self.des_pool))?;
let amount_read = length + buf.len() - remaining.len();
buf.advance(amount_read);
Ok(Some(packet))
}
}