use std::{
io::Result,
mem::MaybeUninit,
pin::Pin,
task::{ready, Context, Poll},
};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pin_project! {
    /// Async reader adapter that removes the `0x00` byte following each run of three `0x55`
    /// bytes (`55 55 55 00` on the wire becomes `55 55 55`).
    pub(crate) struct Reader<R> {
        #[pin]
        inner: R,
        // Scratch space for raw bytes read from `inner` before unstuffing; each `poll_read`
        // takes at most this many bytes (also capped by the caller's remaining space).
        storage: [MaybeUninit<u8>; 128],
        // Length (0..=3) of the prefix of `55 55 55 00` that the bytes returned so far end
        // with, so a sequence split across reads is still recognised.
        trailing_fives: usize,
    }
}
impl<R: AsyncRead + Unpin> Reader<R> {
    /// Wraps `inner` in an unstuffing reader with an empty scratch buffer and no partially
    /// matched `0x55` run.
    pub(crate) fn new(inner: R) -> Self {
        let storage = [MaybeUninit::<u8>::uninit(); 128];
        Self {
            trailing_fives: 0,
            storage,
            inner,
        }
    }

    /// Consumes the wrapper and hands back the inner reader (test-only).
    ///
    /// Any partially matched run tracked by the wrapper is discarded.
    #[cfg(test)]
    pub(crate) fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }
}
impl<R: AsyncRead + Unpin> AsyncRead for Reader<R> {
    /// Reads from the inner reader into `buf`, dropping the `0x00` byte that follows every
    /// run of three `0x55` bytes.
    ///
    /// At most `min(128, buf.remaining())` raw bytes are taken from the inner reader per call,
    /// so the unstuffed result always fits into `buf`. When a read ends part-way through a
    /// `0x55 0x55 0x55 0x00` sequence, the matched prefix length is kept in `trailing_fives`
    /// so the stuffing byte is still recognised when it arrives in a later read.
    ///
    /// If a read yields nothing but the stuffing byte, the inner reader is polled once more,
    /// so that an empty fill (which callers treat as end of stream) is not reported for it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner reader.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        const SENTINEL: &[u8; 4] = &[0x55, 0x55, 0x55, 0x00];

        // Length of the longest proper prefix of `SENTINEL` (one to three 0x55 bytes) that
        // `s` ends with, or 0.
        fn trailing_prefix(s: &[u8]) -> usize {
            (1..SENTINEL.len())
                .rev()
                .find(|&i| s.ends_with(&SENTINEL[..i]))
                .unwrap_or(0)
        }

        let mut me = self.project();
        let mut storage = ReadBuf::uninit(me.storage);
        let mut b = storage.take(std::cmp::min(storage.capacity(), buf.remaining()));
        ready!(me.inner.as_mut().poll_read(cx, &mut b))?;
        let mut s = b.filled();

        let tf = *me.trailing_fives;
        if tf > 0 {
            if tf + s.len() < SENTINEL.len() {
                // Too short to complete the pending sequence: pass it through unchanged.
                buf.put_slice(s);
                *me.trailing_fives = if s == &SENTINEL[tf..tf + s.len()] {
                    tf + s.len()
                } else {
                    // The pending run was broken, but `s` may itself end with a new run.
                    trailing_prefix(s)
                };
                return Poll::Ready(Ok(()));
            } else if tf == SENTINEL.len() - 1 && s.len() == 1 && s[0] == SENTINEL[tf] {
                // This read was only the stuffing byte. Forget the run *before* polling
                // again: the byte is already consumed even if the next poll is pending or
                // fails, and a later leading 0x00 must then be kept as data.
                *me.trailing_fives = 0;
                b.clear();
                ready!(me.inner.as_mut().poll_read(cx, &mut b))?;
                s = b.filled();
            } else if s[..SENTINEL.len() - tf] == SENTINEL[tf..] {
                // The start of `s` completes the sequence: keep its remaining 0x55 bytes and
                // skip the stuffing byte.
                buf.put_slice(&s[..SENTINEL.len() - tf - 1]);
                s = &s[SENTINEL.len() - tf..];
            }
        }

        // Strip the stuffing byte from every complete sequence inside `s`.
        while let Some(p) = s.windows(SENTINEL.len()).position(|w| w == SENTINEL) {
            let (l, r) = s.split_at(p + SENTINEL.len() - 1);
            buf.put_slice(l);
            s = &r[1..];
        }
        *me.trailing_fives = trailing_prefix(s);
        buf.put_slice(s);
        Poll::Ready(Ok(()))
    }
}
pin_project! {
    /// Async writer adapter that inserts a `0x00` byte after each run of three `0x55` bytes
    /// written through `poll_write` (`55 55 55` becomes `55 55 55 00` on the wire).
    pub(crate) struct Writer<W> {
        #[pin]
        inner: W,
        // Count of consecutive `0x55` bytes at the end of the stuffed output that are not yet
        // followed by a stuffing byte; when it is 3, `poll_write` first writes the owed `0x00`.
        trailing_fives: usize,
    }
}
impl<W: AsyncWrite + Unpin> Writer<W> {
    /// Wraps `inner` in a stuffing writer with no pending `0x55` run.
    pub(crate) fn new(inner: W) -> Self {
        Self {
            trailing_fives: 0,
            inner,
        }
    }

    /// Consumes the wrapper and hands back the inner writer (test-only).
    #[cfg(test)]
    pub(crate) fn into_inner(self) -> W {
        let Self { inner, .. } = self;
        inner
    }

    /// Writes `src` directly to the inner writer, bypassing byte stuffing (test-only).
    #[cfg(test)]
    pub(crate) async fn write_magic(&mut self, src: &[u8]) -> Result<usize> {
        use tokio::io::AsyncWriteExt;
        let raw = &mut self.inner;
        raw.write(src).await
    }

    /// Polls a direct write of `buf` to the inner writer, bypassing byte stuffing.
    ///
    /// NOTE(review): `trailing_fives` is neither consulted nor updated here, so a `0x55` run
    /// (or an owed stuffing byte) left by an earlier `poll_write` carries over past these raw
    /// bytes. Confirm this is intended by the callers.
    pub(crate) fn poll_write_magic(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize>> {
        let projected = self.project();
        projected.inner.poll_write(cx, buf)
    }
}
/// Issues one `poll_write` of `$data` on `$this.inner` and dispatches on the outcome.
///
/// * `Err(e)` is returned from the enclosing function immediately.
/// * `Pending` is returned as-is when nothing has been accepted in this call yet
///   (`$done == 0`); otherwise the enclosing function returns `Ready(Ok($done))`.
/// * A short write (fewer than `$data.len()` bytes, bound to `$wrote`) runs `$on_short`;
///   a complete write runs `$on_full`.
///
/// Shorter forms fill in defaults: `$on_full` adds `$wrote` to `$done`, `$on_short` returns
/// `Ready(Ok($done + $wrote))`, and `$wrote` gets a macro-internal name.
macro_rules! partial_write {
    (
        $this:expr, $ctx:expr, $data:expr, $done:ident, $wrote:ident,
        $on_short:block, $on_full:block
    ) => {
        match ($this).inner.as_mut().poll_write(($ctx), ($data)) {
            Poll::Ready(Ok($wrote)) if $wrote < ($data).len() => $on_short,
            Poll::Ready(Ok($wrote)) => $on_full,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Pending if $done == 0 => return Poll::Pending,
            Poll::Pending => return Poll::Ready(Ok($done)),
        }
    };
    ( $this:expr, $ctx:expr, $data:expr, $done:ident, $wrote:ident, $on_short:block ) => {
        partial_write!($this, $ctx, $data, $done, $wrote, $on_short, { $done += $wrote })
    };
    ( $this:expr, $ctx:expr, $data:expr, $done:ident, $wrote:ident ) => {
        partial_write!($this, $ctx, $data, $done, $wrote, {
            return Poll::Ready(Ok($done + $wrote))
        })
    };
    ( $this:expr, $ctx:expr, $data:expr, $done:ident ) => {
        partial_write!($this, $ctx, $data, $done, written)
    };
}
impl<W: AsyncWrite + Unpin> AsyncWrite for Writer<W> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let mut me = self.project();
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
const SENTINEL_ENCODED: &[u8; 4] = &[0x55, 0x55, 0x55, 0x00];
const SENTINEL: &[u8; 3] = &[0x55, 0x55, 0x55];
let mut n: usize = 0;
let mut s = buf;
if *me.trailing_fives > 0 {
if *me.trailing_fives + s.len() < SENTINEL.len() {
partial_write!(
me,
cx,
s,
n,
x,
{
*me.trailing_fives =
if s[..x] == SENTINEL[*me.trailing_fives..*me.trailing_fives + x] {
*me.trailing_fives + x
} else {
0
};
return Poll::Ready(Ok(n + x));
},
{
*me.trailing_fives =
if s[..x] == SENTINEL[*me.trailing_fives..*me.trailing_fives + x] {
*me.trailing_fives + x
} else {
0
};
return Poll::Ready(Ok(n + x));
}
);
} else if s[..SENTINEL.len() - *me.trailing_fives] == SENTINEL[*me.trailing_fives..] {
partial_write!(
me,
cx,
&SENTINEL_ENCODED[*me.trailing_fives..],
n,
_x,
{
*me.trailing_fives += _x;
return Poll::Ready(Ok(n + _x));
},
{
n += SENTINEL.len() - *me.trailing_fives;
}
);
s = &s[SENTINEL.len() - *me.trailing_fives..];
}
}
while let Some(p) = s.windows(SENTINEL.len()).position(|w| w == SENTINEL) {
let (l, m) = s.split_at(p);
let (m, r) = m.split_at(SENTINEL.len());
assert_eq!(m, SENTINEL);
partial_write!(me, cx, l, n);
partial_write!(
me,
cx,
SENTINEL_ENCODED,
n,
_x,
{
*me.trailing_fives = _x;
return Poll::Ready(Ok(n + _x));
},
{
n += SENTINEL.len();
}
);
s = r;
}
*me.trailing_fives = 0;
for i in (1..SENTINEL.len()).rev() {
if s.ends_with(&SENTINEL[0..i]) {
*me.trailing_fives = i;
break;
}
}
partial_write!(me, cx, s, n);
Poll::Ready(Ok(n))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.project().inner.poll_shutdown(cx)
}
}
#[cfg(test)]
mod tests {
    // NOTE(review): `MSG_RESP_AC_CAP`, `bytes_reader!` and `decode` presumably come from this
    // fixture module (not visible here). The offsets used below (header 0..4, `UUU` runs,
    // stuffing bytes at wire offsets 17 and 28) are tied to that fixture's content.
    use super::super::tests::data::*;
    use super::*;
    use crc16::{State as Crc16, MODBUS};
    use rstest::rstest;
    // Reads the whole wire frame through `Reader` in a single `read` call and checks the
    // unstuffed result.
    #[tokio::test]
    async fn test_reader() {
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; 1024];
        let mut r = Reader::new(bytes_reader!(MSG_RESP_AC_CAP));
        let len = r.read(&mut buf).await.expect("Couldn't read");
        println!("{}: {:?}", len, &buf[..len]);
        // The raw `55 55 55 AA` header is passed through untouched.
        assert_eq!(buf[..4], [0x55, 0x55, 0x55, 0xAA]);
        // Two bytes shorter than the wire frame: one stuffing byte is removed after each of
        // the two `UUU` (0x55 0x55 0x55) runs checked at the end.
        assert_eq!(len, MSG_RESP_AC_CAP.len() - 2);
        // The last two unstuffed bytes are a big-endian CRC-16/MODBUS of everything between
        // the header and themselves.
        let mut crc = Crc16::<MODBUS>::new();
        crc.update(&buf[4..len - 2]);
        assert_eq!(
            crc.get(),
            u16::from_be_bytes(buf[len - 2..len].try_into().unwrap())
        );
        assert_eq!(&buf[14..27], "UUUNIT 01 UUU".as_bytes());
    }
    // Same checks, with the wire bytes delivered in two chunks split at offsets around the
    // first `UUU` run and its stuffing byte, so partial sequences cross read boundaries.
    #[rstest]
    #[tokio::test]
    async fn test_reader_split(#[values(14, 15, 16, 17, 18, 19)] at: usize) {
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; 1024];
        let (first, second) = bytes_reader!(MSG_RESP_AC_CAP).split_at(at);
        let mut r = Reader::new(first.chain(second));
        let mut len = 0usize;
        // Keep reading until the expected unstuffed length has been produced.
        while len < MSG_RESP_AC_CAP.len() - 2 {
            let l = r.read(&mut buf[len..]).await.expect("Couldn't read");
            assert!(l > 0, "unexpected eof");
            len += l;
        }
        println!("{}: {:?}", len, &buf[..len]);
        assert_eq!(buf[..4], [0x55, 0x55, 0x55, 0xAA]);
        assert_eq!(len, MSG_RESP_AC_CAP.len() - 2);
        let mut crc = Crc16::<MODBUS>::new();
        crc.update(&buf[4..len - 2]);
        assert_eq!(
            crc.get(),
            u16::from_be_bytes(buf[len - 2..len].try_into().unwrap())
        );
        assert_eq!(&buf[14..27], "UUUNIT 01 UUU".as_bytes());
    }
    // Three-chunk variant: the middle chunk is only `n` bytes long, so a single sequence can
    // span three reads.
    #[rstest]
    #[tokio::test]
    async fn test_reader_split_len(
        #[values(14, 15, 16, 17, 18, 19)] at: usize,
        #[values(1, 2, 3, 4)] n: usize,
    ) {
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; 1024];
        let (first, second) = bytes_reader!(MSG_RESP_AC_CAP).split_at(at);
        let (second, third) = second.split_at(n);
        let mut r = Reader::new(first.chain(second).chain(third));
        let mut len = 0usize;
        while len < MSG_RESP_AC_CAP.len() - 2 {
            let l = r.read(&mut buf[len..]).await.expect("Couldn't read");
            assert!(l > 0, "unexpected eof");
            len += l;
        }
        println!("{}: {:?}", len, &buf[..len]);
        assert_eq!(buf[..4], [0x55, 0x55, 0x55, 0xAA]);
        assert_eq!(len, MSG_RESP_AC_CAP.len() - 2);
        let mut crc = Crc16::<MODBUS>::new();
        crc.update(&buf[4..len - 2]);
        assert_eq!(
            crc.get(),
            u16::from_be_bytes(buf[len - 2..len].try_into().unwrap())
        );
        assert_eq!(&buf[14..27], "UUUNIT 01 UUU".as_bytes());
    }
    // After one read, the unwrapped inner reader continues exactly at wire offset `l`.
    // (`l` equals the wire bytes consumed here because the first 16 bytes contain no
    // stuffing byte.)
    #[tokio::test]
    async fn test_reader_into_inner() {
        use tokio::io::AsyncReadExt;
        let mut buf = [0u8; 1024];
        let (first, second) = bytes_reader!(MSG_RESP_AC_CAP).split_at(16);
        let mut r = Reader::new(first.chain(second));
        let l = r.read(&mut buf).await.expect("couldn't read");
        let mut rr = r.into_inner();
        let k = rr.read(&mut buf).await.expect("couldn't read inner");
        assert_eq!(buf[..k], MSG_RESP_AC_CAP[l..]);
    }
    // Writes the unstuffed frame body (no 4-byte header, no last two wire bytes) and checks
    // that the stuffed output matches the wire bytes. The reported count is in unstuffed
    // bytes, so it excludes the two stuffing bytes as well.
    // NOTE(review): assumes `decode` removes stuffing — confirm against the fixture.
    #[tokio::test]
    async fn test_writer() {
        use tokio::io::AsyncWriteExt;
        let cursor = std::io::Cursor::new(vec![0; 1024]);
        let mut w = Writer::new(cursor);
        let l = w
            .write(&decode(&MSG_RESP_AC_CAP[4..MSG_RESP_AC_CAP.len() - 2]))
            .await
            .expect("couldn't write");
        assert_eq!(l, MSG_RESP_AC_CAP.len() - 4 - 2 - 2);
        assert_eq!(w.into_inner().get_ref()[..l], MSG_RESP_AC_CAP[4..4 + l]);
    }
    // As above, but the raw header is first written through `write_magic` (no stuffing).
    #[tokio::test]
    async fn test_writer_magic() {
        use tokio::io::AsyncWriteExt;
        let cursor = std::io::Cursor::new(vec![0; 1024]);
        let mut w = Writer::new(cursor);
        let mut l = w
            .write_magic(&MSG_RESP_AC_CAP[..4])
            .await
            .expect("couldn't write magic");
        l += w
            .write(&decode(&MSG_RESP_AC_CAP[4..MSG_RESP_AC_CAP.len() - 2]))
            .await
            .expect("couldn't write");
        assert_eq!(l, MSG_RESP_AC_CAP.len() - 2 - 2);
        assert_eq!(w.into_inner().get_ref()[..l], MSG_RESP_AC_CAP[..l]);
    }
    // The unstuffed body is written in two `write` calls split near the first `UUU` run, so
    // the writer must carry the run length across calls.
    #[rstest]
    #[tokio::test]
    async fn test_writer_split(#[values(14, 15, 16, 17, 18, 19)] at: usize) {
        use tokio::io::AsyncWriteExt;
        let cursor = std::io::Cursor::new(vec![0; 1024]);
        let mut w = Writer::new(cursor);
        let src = decode(&MSG_RESP_AC_CAP[..MSG_RESP_AC_CAP.len() - 2]);
        let (first, second) = src.split_at(at);
        let mut l = w
            .write_magic(&first[..4])
            .await
            .expect("couldn't write magic");
        l += w.write(&first[4..]).await.expect("couldn't write");
        l += w.write(second).await.expect("couldn't write");
        assert_eq!(l, MSG_RESP_AC_CAP.len() - 2 - 2);
        assert_eq!(w.into_inner().get_ref()[..l], MSG_RESP_AC_CAP[..l]);
    }
    // Three-call variant: the middle `write` is only `n` bytes long.
    #[rstest]
    #[tokio::test]
    async fn test_writer_split_len(
        #[values(14, 15, 16, 17, 18, 19)] at: usize,
        #[values(1, 2, 3, 4)] n: usize,
    ) {
        use tokio::io::AsyncWriteExt;
        let cursor = std::io::Cursor::new(vec![0; 1024]);
        let mut w = Writer::new(cursor);
        let src = decode(&MSG_RESP_AC_CAP[..MSG_RESP_AC_CAP.len() - 2]);
        let (first, second) = src.split_at(at);
        let (second, third) = second.split_at(n);
        let mut l = w
            .write_magic(&first[..4])
            .await
            .expect("couldn't write magic");
        l += w.write(&first[4..]).await.expect("couldn't write");
        l += w.write(second).await.expect("couldn't write");
        l += w.write(third).await.expect("couldn't write");
        assert_eq!(l, MSG_RESP_AC_CAP.len() - 2 - 2);
        assert_eq!(w.into_inner().get_ref()[..l], MSG_RESP_AC_CAP[..l]);
    }
    // Feeds wire bytes minus the stuffing bytes at offsets 17 and 28, letting the Writer
    // re-insert them. The tail is then written raw through the unwrapped inner writer.
    #[tokio::test]
    async fn test_writer_into_inner() {
        use tokio::io::AsyncWriteExt;
        let cursor = std::io::Cursor::new(vec![0; 1024]);
        let mut w = Writer::new(cursor);
        let mut l = w
            .write(&MSG_RESP_AC_CAP[4..17])
            .await
            .expect("couldn't write");
        l += w
            .write(&MSG_RESP_AC_CAP[18..28])
            .await
            .expect("couldn't write");
        let mut ww = w.into_inner();
        l += ww
            .write(&MSG_RESP_AC_CAP[29..MSG_RESP_AC_CAP.len() - 2])
            .await
            .expect("couldn't write inner");
        assert_eq!(l, MSG_RESP_AC_CAP.len() - 4 - 2 - 2);
        assert_eq!(ww.get_ref()[..l], MSG_RESP_AC_CAP[4..4 + l]);
    }
}