use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem::MaybeUninit};

use bytes::{Buf, BufMut, BytesMut};
use futures::ready;
use futures::{Sink, Stream};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::codec::{Decoder, Encoder};

use super::SerialStream;
/// Adapter that combines a [`SerialStream`] with a codec.
///
/// Decoded items are exposed through the [`Stream`] impl (using a `Decoder`), and
/// items are accepted for encoding through the [`Sink`] impl (using an `Encoder`).
/// Incoming bytes are staged in `rd`, outgoing encoded bytes in `wr`.
#[must_use = "sinks do nothing unless polled"]
#[derive(Debug)]
pub struct SerialFramed<C> {
    /// The underlying serial port that bytes are read from and written to.
    port: SerialStream,
    /// Codec used to decode frames out of `rd` and encode sink items into `wr`.
    codec: C,
    /// Read buffer: bytes read from `port` that are waiting to be handed to the codec.
    rd: BytesMut,
    /// Write buffer: encoded bytes waiting to be written to `port`.
    wr: BytesMut,
    /// Set to `false` by `start_send` after encoding into `wr`; set back to `true`
    /// by `poll_flush` once it has finished writing.
    flushed: bool,
    /// Set after a successful read into `rd`; cleared once the codec yields no
    /// further frame, which tells `poll_next` to read from the port again.
    is_readable: bool,
}
/// Initial capacity of the read buffer; `poll_next` also reserves this much room in
/// the read buffer before reading from the port.
const INITIAL_RD_CAPACITY: usize = 64 * 1024;
/// Initial capacity of the write buffer.
const INITIAL_WR_CAPACITY: usize = 8 * 1024;
impl<C: Decoder + Unpin> Stream for SerialFramed<C> {
    type Item = Result<C::Item, C::Error>;

    /// Yields the next frame decoded from the serial port.
    ///
    /// A serial port is a byte stream, not a datagram transport: one read may
    /// deliver only part of a frame, or several frames at once. Bytes that the
    /// codec has not consumed therefore stay in the read buffer and are combined
    /// with the next read instead of being discarded.
    ///
    /// When the port reports end-of-file (a read that fills no bytes), whatever is
    /// still buffered is passed to [`Decoder::decode_eof`]. Once that yields no
    /// more frames, the stream terminates with `None`.
    ///
    /// # Errors
    ///
    /// Yields `Some(Err(_))` if reading from the port fails or the codec reports
    /// an error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pin = self.get_mut();

        loop {
            // Drain every complete frame already buffered before touching the port.
            if pin.is_readable {
                if let Some(frame) = pin.codec.decode(&mut pin.rd)? {
                    return Poll::Ready(Some(Ok(frame)));
                }
                // Only a partial frame (or nothing) remains. Keep it in `rd` so it
                // can be completed by the bytes of the next read.
                pin.is_readable = false;
            }

            // Guarantee spare capacity so that a zero-length read can only mean EOF.
            pin.rd.reserve(INITIAL_RD_CAPACITY);

            let n = unsafe {
                // SAFETY: `chunk_mut` exposes the spare (possibly uninitialized)
                // capacity of `rd` as an `UninitSlice`, which wraps
                // `[MaybeUninit<u8>]`. We only write into it through `ReadBuf`,
                // which never de-initializes memory. `advance_mut` is called solely
                // for the `n` bytes `ReadBuf` reports as filled, and those bytes
                // are initialized.
                let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]);
                let mut read = ReadBuf::uninit(buf);
                let ptr = read.filled().as_ptr();
                ready!(Pin::new(&mut pin.port).poll_read(cx, &mut read))?;

                // The reader must have filled our buffer, not swapped in its own.
                assert_eq!(ptr, read.filled().as_ptr());
                let n = read.filled().len();
                pin.rd.advance_mut(n);
                n
            };

            if n == 0 {
                // EOF: give the codec a final chance to emit a trailing frame,
                // then end the stream.
                return Poll::Ready(pin.codec.decode_eof(&mut pin.rd)?.map(Ok));
            }

            pin.is_readable = true;
        }
    }
}
impl<I, C: Encoder<I> + Unpin> Sink<I> for SerialFramed<C> {
    type Error = C::Error;

    /// Ready to accept a new item once every previously encoded byte has been
    /// written to the port.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if !self.flushed {
            ready!(self.poll_flush(cx))?;
        }
        Poll::Ready(Ok(()))
    }

    /// Encodes `item` into the write buffer. Nothing is written to the port until
    /// the sink is flushed.
    ///
    /// # Errors
    ///
    /// Returns the codec's error if encoding fails.
    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let pin = self.get_mut();
        pin.codec.encode(item, &mut pin.wr)?;
        pin.flushed = false;
        Ok(())
    }

    /// Writes the whole write buffer to the serial port.
    ///
    /// A serial port is a byte stream, so a single write may accept only part of
    /// the buffer. Written bytes are removed from the buffer and the remainder is
    /// retried. If the port is not ready, the unwritten tail is kept and the call
    /// returns `Pending`.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the port. Returns an `io::ErrorKind::WriteZero`
    /// error (converted into the codec's error type) if the port accepts zero bytes
    /// while data is still pending; the unwritten bytes are kept in that case.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let pin = self.get_mut();
        if pin.flushed {
            return Poll::Ready(Ok(()));
        }

        while !pin.wr.is_empty() {
            let n = ready!(Pin::new(&mut pin.port).poll_write(cx, &pin.wr))?;
            if n == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to serial port",
                )
                .into()));
            }
            // Drop only what the port actually accepted; keep the rest for the next write.
            pin.wr.advance(n);
        }

        pin.flushed = true;
        Poll::Ready(Ok(()))
    }

    /// Flushes any buffered bytes; the port itself is left open.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.poll_flush(cx)
    }
}
// Presumably not every accessor is used inside this crate; silence dead-code
// warnings for the whole public API in one place rather than per method.
#[allow(dead_code)]
impl<C> SerialFramed<C> {
    /// Wraps `port` so that it reads and writes frames through `codec`.
    ///
    /// The read and write buffers start with `INITIAL_RD_CAPACITY` and
    /// `INITIAL_WR_CAPACITY` bytes of capacity respectively.
    pub fn new(port: SerialStream, codec: C) -> Self {
        let read_buf = BytesMut::with_capacity(INITIAL_RD_CAPACITY);
        let write_buf = BytesMut::with_capacity(INITIAL_WR_CAPACITY);
        Self {
            port,
            codec,
            rd: read_buf,
            wr: write_buf,
            flushed: true,
            is_readable: false,
        }
    }

    /// Borrows the underlying serial port.
    pub fn get_ref(&self) -> &SerialStream {
        &self.port
    }

    /// Mutably borrows the underlying serial port.
    ///
    /// Bytes already held in this adapter's read or write buffers are not visible
    /// through the port, so direct I/O may interleave with buffered data.
    pub fn get_mut(&mut self) -> &mut SerialStream {
        &mut self.port
    }

    /// Consumes the adapter and returns the serial port. Any bytes still in the
    /// read or write buffers are dropped.
    pub fn into_inner(self) -> SerialStream {
        let Self { port, .. } = self;
        port
    }

    /// Borrows the codec.
    pub fn codec(&self) -> &C {
        &self.codec
    }

    /// Mutably borrows the codec.
    pub fn codec_mut(&mut self) -> &mut C {
        &mut self.codec
    }

    /// Borrows the read buffer (bytes received but not yet consumed by the codec).
    pub fn read_buffer(&self) -> &BytesMut {
        &self.rd
    }

    /// Mutably borrows the read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.rd
    }
}