use crate::{
codec::RxPacket,
core::{
base_types::VarSizeInt,
error::{CodecError, ConversionError},
utils::TryDecode,
},
};
use bytes::BytesMut;
use core::{
ops::Range,
pin::Pin,
task::{Context, Poll},
};
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, Stream};
use std::{io, mem};
enum PacketStreamState {
Idle,
ReadPacketLen,
ReadPacketData,
}
pub(crate) struct RxPacketStream<StreamT> {
stream: StreamT,
buf: BytesMut,
size: usize,
packet: Range<usize>,
state: PacketStreamState,
}
impl<StreamT> From<StreamT> for RxPacketStream<StreamT> {
fn from(stream: StreamT) -> Self {
Self {
stream,
buf: BytesMut::with_capacity(1024),
size: 0,
packet: 0..0,
state: PacketStreamState::Idle,
}
}
}
impl<StreamT> RxPacketStream<StreamT> {
fn split_borrows_mut(
&mut self,
) -> (
&mut StreamT,
&mut BytesMut,
&mut usize,
&mut Range<usize>,
&mut PacketStreamState,
) {
(
&mut self.stream,
&mut self.buf,
&mut self.size,
&mut self.packet,
&mut self.state,
)
}
}
impl<StreamT> Stream for RxPacketStream<StreamT>
where
StreamT: AsyncRead + Unpin,
{
type Item = Result<RxPacket, CodecError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
const DEFAULT_CHUNK_SIZE: usize = 512;
let (mut stream, buf, size, packet, state) = self.split_borrows_mut();
match *state {
PacketStreamState::Idle => {
let chunk_size = if packet.end - *size < DEFAULT_CHUNK_SIZE {
DEFAULT_CHUNK_SIZE
} else {
packet.end
};
buf.resize(*size + chunk_size, 0);
if let Poll::Ready(result) = Pin::new(&mut stream)
.poll_read(cx, &mut buf[*size..*size + chunk_size])
.map(|res| res.ok().filter(|&size| size != 0 ))
{
if result.is_none() {
return Poll::Ready(None);
}
*size += result.unwrap();
if *size >= 2 {
*state = PacketStreamState::ReadPacketLen;
return self.poll_next(cx);
}
}
Poll::Pending
}
PacketStreamState::ReadPacketLen => {
let maybe_remaining_len =
VarSizeInt::try_from(&buf[1..]).map(Some).or_else(|err| {
if let ConversionError::InsufficientBufferSize(_) = err {
return Ok(None); }
Err(err)
});
if maybe_remaining_len.is_err() {
return Poll::Ready(None);
}
if let Some(remaining_len) = maybe_remaining_len.unwrap() {
packet.start = 0;
packet.end = 1 + remaining_len.len() + remaining_len.value() as usize;
*state = PacketStreamState::ReadPacketData;
return self.poll_next(cx);
}
*state = PacketStreamState::Idle;
self.poll_next(cx)
}
PacketStreamState::ReadPacketData => {
if *size < packet.end {
*state = PacketStreamState::Idle;
return self.poll_next(cx);
}
*size -= packet.len();
if *size != 0 {
*state = PacketStreamState::ReadPacketLen;
} else {
*state = PacketStreamState::Idle;
}
Poll::Ready(Some(RxPacket::try_decode(
buf.split_to(mem::replace(&mut packet.end, 0)).freeze(),
)))
}
}
}
}
pub(crate) struct TxPacketStream<TxStreamT> {
stream: TxStreamT,
}
impl<TxStreamT> From<TxStreamT> for TxPacketStream<TxStreamT> {
fn from(inner: TxStreamT) -> Self {
Self { stream: inner }
}
}
impl<TxStreamT> TxPacketStream<TxStreamT> {
pub(crate) async fn write(&mut self, packet: &[u8]) -> Result<(), io::Error>
where
TxStreamT: AsyncWrite + Unpin,
{
self.stream.write_all(&packet[0..packet.len()]).await
}
}