pub(crate) mod frame;
mod io;
use std::{
collections::VecDeque,
io::Result,
net::{IpAddr, SocketAddr},
pin::Pin,
task::{ready, Context, Poll},
};
use futures_sink::Sink;
use log::{info, warn};
use pin_project_lite::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite, BufReader, BufWriter},
net::{TcpStream, ToSocketAddrs},
};
use tokio_stream::Stream;
use tokio_util::codec::{Encoder, FramedRead};
use frame::{Frame, FrameCodec, MaybeFrame};
/// TCP port used by `Connection::with_ipaddr`, `Connection::with_ipaddrs` and
/// `Connection::with_str`, which do not take a port of their own.
const DEFAULT_PORT: u16 = 9005;
pin_project! {
/// A bidirectional framed connection: a [`Stream`] of decoded incoming
/// [`Frame`]s and a [`Sink`] accepting outgoing [`Frame`]s.
///
/// Both halves are type-erased so the same type can wrap a TCP socket or any
/// other `AsyncRead`/`AsyncWrite` pair (see `from_io`).
pub(crate) struct Connection {
// Incoming frames. As built by `from_io`, frames with a bad CRC or an
// unknown message kind are filtered out before they reach this stream.
#[pin]
reader: Box<dyn Stream<Item = Result<Frame>> + Send + Sync + Unpin>,
// Outgoing frames are encoded and written through this sink.
#[pin]
writer: Box<dyn Sink<Frame, Error = std::io::Error> + Send + Sync + Unpin>,
}
}
impl Connection {
    /// Connects over TCP to `addr` on the default port (`DEFAULT_PORT`).
    pub async fn with_ipaddr(addr: IpAddr) -> Result<Self> {
        let target = (addr, DEFAULT_PORT);
        Self::new(target).await
    }

    /// Connects over TCP to the first reachable address in `addrs`, each on the
    /// default port (`DEFAULT_PORT`). Address selection is delegated to
    /// `TcpStream::connect`.
    pub async fn with_ipaddrs(addrs: &[IpAddr]) -> Result<Self> {
        let candidates = addrs
            .iter()
            .copied()
            .map(|ip| SocketAddr::new(ip, DEFAULT_PORT))
            .collect::<Vec<SocketAddr>>();
        Self::new(candidates.as_slice()).await
    }

    /// Connects over TCP to the host named by `addr` (an IP literal or a name to
    /// resolve) on the default port (`DEFAULT_PORT`).
    pub async fn with_str(addr: &str) -> Result<Self> {
        let target = (addr, DEFAULT_PORT);
        Self::new(target).await
    }

    /// Connects over TCP to `addr` and wraps the resulting socket.
    ///
    /// # Errors
    /// Returns any error produced by `TcpStream::connect`.
    pub async fn new<A: ToSocketAddrs>(addr: A) -> Result<Self> {
        TcpStream::connect(addr).await.map(Self::from_socket)
    }

    /// Wraps an already-connected TCP socket, splitting it into owned read and
    /// write halves.
    pub fn from_socket(socket: TcpStream) -> Self {
        let (read_half, write_half) = socket.into_split();
        Self::from_io(read_half, write_half)
    }

    /// Builds a connection from an arbitrary reader/writer pair.
    ///
    /// The reader is buffered, passed through `io::Reader` and decoded with
    /// `FrameCodec`; decoder results are filtered by `filter_maybe_frame`. The
    /// writer is buffered, passed through `io::Writer` and driven by the local
    /// `FramedWrite` sink using the same codec value.
    fn from_io<R, W>(read: R, write: W) -> Self
    where
        R: AsyncRead + Send + Sync + Unpin + 'static,
        W: AsyncWrite + Send + Sync + Unpin + 'static,
    {
        use tokio_stream::StreamExt;

        let codec = FrameCodec::new();

        let decoded = FramedRead::new(io::Reader::new(BufReader::new(read)), codec);
        let reader = Box::new(decoded.filter_map(filter_maybe_frame));

        let sink = FramedWrite::new(io::Writer::new(BufWriter::new(write)), codec);

        Self {
            reader,
            writer: Box::new(sink),
        }
    }
}
/// Maps one decoder result to the item the connection stream should yield.
///
/// I/O errors pass through unchanged. A frame whose CRC did not match is dropped
/// after a `warn!` log. A frame of kind `MessageKind::Unknown` is dropped after an
/// `info!` log. Every other frame is yielded. Returning `None` tells
/// `filter_map` to skip the item.
fn filter_maybe_frame(f: Result<MaybeFrame>) -> Option<Result<Frame>> {
    let frame = match f {
        Err(e) => return Some(Err(e)),
        Ok(MaybeFrame::CrcError(calculated, expected)) => {
            warn!(
                "Received frame with bad CRC: calculated {:#06x}, expected {:#06x}",
                calculated, expected
            );
            return None;
        }
        Ok(MaybeFrame::Frame(frame)) => frame,
    };

    if frame.kind != MessageKind::Unknown {
        return Some(Ok(frame));
    }

    info!(
        "Ignoring unknown message, type {:#04x} address {:#06x}",
        frame.msg_type, frame.address
    );
    None
}
impl Stream for Connection {
    type Item = Result<Frame>;

    /// Polls the boxed reader for the next (already filtered) incoming frame.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        projected.reader.poll_next(cx)
    }
}
/// Every `Sink` operation is forwarded unchanged to the boxed writer.
impl Sink<Frame> for Connection {
    type Error = std::io::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let projected = self.project();
        projected.writer.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Frame) -> Result<()> {
        let projected = self.project();
        projected.writer.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let projected = self.project();
        projected.writer.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let projected = self.project();
        projected.writer.poll_close(cx)
    }
}
pin_project! {
/// A [`Sink`] of [`Frame`]s that encodes each frame with codec `C` into its own
/// buffer, queues the buffers, and writes them out through an `io::Writer`.
struct FramedWrite<W, C> {
// Destination for encoded bytes.
#[pin]
inner: io::Writer<W>,
// Encoder used by `start_send`.
codec: C,
// Encoded frames not yet fully written, oldest first.
bufs: VecDeque<tokio_util::bytes::BytesMut>,
// Number of bytes of the front buffer already written; reset to 0 when that
// buffer is popped.
idx: usize,
}
}
impl<W: AsyncWrite + Unpin, C: Encoder<Frame>> FramedWrite<W, C> {
    /// Creates a sink that encodes frames with `codec` and writes them to `writer`.
    fn new(writer: io::Writer<W>, codec: C) -> Self {
        Self {
            inner: writer,
            codec,
            bufs: VecDeque::with_capacity(4),
            idx: 0,
        }
    }

    /// Total size in bytes of every queued encoded frame, including the part of
    /// the front frame that has already been written.
    fn total_buffered(&self) -> usize {
        self.bufs.iter().map(|buf| buf.len()).sum()
    }

    /// Writes queued frames to the inner writer until the queue is empty or the
    /// writer stops accepting data.
    ///
    /// The first `LEN_MAGIC` bytes of each frame are written with
    /// `io::Writer::poll_write_magic` and the rest with `AsyncWrite::poll_write`.
    /// NOTE(review): presumably the magic path bypasses whatever transformation
    /// `poll_write` applies; confirm in `io.rs`. `idx` records progress into the
    /// front frame, so a write that returns `Pending` resumes at the same offset on
    /// the next call.
    ///
    /// Returns `Pending` when a write is pending. Once the queue is drained, the
    /// inner writer is flushed. A pending flush is *not* waited for; callers
    /// needing a complete flush poll it again. A flush error is reported.
    ///
    /// # Errors
    /// Propagates write/flush errors. Returns `ErrorKind::WriteZero` if the
    /// writer accepts zero bytes of a non-empty write; previously this spun
    /// forever.
    fn poll_flush_some(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        const LEN_MAGIC: usize = std::mem::size_of::<u32>();

        /// Error reported when the writer makes no progress on a non-empty write.
        fn write_zero() -> std::io::Error {
            std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "failed to write frame to transport",
            )
        }

        let mut me = self.project();
        while let Some(first) = me.bufs.front() {
            // Clamp so an encoder emitting fewer than LEN_MAGIC bytes cannot make
            // the slice below panic.
            let magic_end = LEN_MAGIC.min(first.len());
            while *me.idx < magic_end {
                let n = ready!(me
                    .inner
                    .as_mut()
                    .poll_write_magic(cx, &first[*me.idx..magic_end]))?;
                if n == 0 {
                    return Poll::Ready(Err(write_zero()));
                }
                *me.idx += n;
            }
            while *me.idx < first.len() {
                let n = ready!(me.inner.as_mut().poll_write(cx, &first[*me.idx..]))?;
                if n == 0 {
                    return Poll::Ready(Err(write_zero()));
                }
                *me.idx += n;
            }
            *me.idx = 0;
            me.bufs.pop_front();
        }

        // Best-effort flush: only an error is reported; Pending counts as progress.
        if let Poll::Ready(Err(e)) = me.inner.poll_flush(cx) {
            Poll::Ready(Err(e))
        } else {
            Poll::Ready(Ok(()))
        }
    }
}
impl<W: AsyncWrite + Unpin, C: Encoder<Frame, Error = std::io::Error>> Sink<Frame>
    for FramedWrite<W, C>
{
    type Error = std::io::Error;

    /// Applies back-pressure.
    ///
    /// When the queue already holds too many frames or too many bytes, tries to
    /// drain it. Reports readiness only once it is back under both limits.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        // Explicit limits instead of `VecDeque::capacity()`. `with_capacity` only
        // guarantees *at least* the requested capacity, so the old
        // `len == capacity` test did not reliably mean "4 frames queued".
        const MAX_QUEUED_FRAMES: usize = 4;
        const MAX_QUEUED_BYTES: usize = 512;
        let is_full = |w: &Self| {
            w.bufs.len() >= MAX_QUEUED_FRAMES || w.total_buffered() > MAX_QUEUED_BYTES
        };

        if is_full(&*self) {
            if let Poll::Ready(Err(e)) = self.as_mut().poll_flush_some(cx) {
                return Poll::Ready(Err(e));
            }
        }
        // If the flush is still pending, the waker it registered re-polls us.
        if is_full(&*self) {
            Poll::Pending
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Encodes `item` into a new buffer and queues it.
    ///
    /// Bytes reach the writer only when `poll_ready`, `poll_flush` or
    /// `poll_close` drains the queue.
    fn start_send(self: Pin<&mut Self>, item: Frame) -> Result<()> {
        let mut b = tokio_util::bytes::BytesMut::with_capacity(64);
        let me = self.project();
        me.codec.encode(item, &mut b)?;
        me.bufs.push_back(b);
        Ok(())
    }

    /// Writes every queued frame, then waits for the inner writer to flush.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        while !self.bufs.is_empty() {
            ready!(self.as_mut().poll_flush_some(cx))?;
        }
        ready!(self.project().inner.poll_flush(cx))?;
        Poll::Ready(Ok(()))
    }

    /// Flushes all queued frames, then shuts down the inner writer.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        ready!(self.as_mut().poll_flush(cx))?;
        ready!(self.project().inner.poll_shutdown(cx))?;
        Poll::Ready(Ok(()))
    }
}
/// Classification of a frame by its `(message type, address)` pair.
///
/// `Eq`/`Hash` are derived in addition to `PartialEq` so kinds can be used as
/// map/set keys. A fieldless enum has a total equality.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MessageKind {
    /// Message type 0xc0, address 0x80b0.
    ControlRequest,
    /// Message type 0xc0, any address whose low byte is 0x80 (e.g. 0xb080).
    StatusResponse,
    /// Message type 0x1f, address 0x90b0.
    ExtendedRequest,
    /// Message type 0x1f, address 0xb090 or 0xb091.
    ExtendedResponse,
    /// Any pair not recognised as one of the kinds above.
    Unknown,
}
impl MessageKind {
    /// Reports whether message type `t` with address `a` is a well-formed
    /// instance of this kind. `MessageKind::Unknown` never is.
    fn is_valid(&self, t: u8, a: u16) -> bool {
        match (*self, t) {
            (Self::ControlRequest, 0xc0) => a == 0x80b0,
            // Only the low byte of a status response address is fixed.
            (Self::StatusResponse, 0xc0) => a & 0x00ff == 0x0080,
            (Self::ExtendedRequest, 0x1f) => a == 0x90b0,
            // Accepts both 0xb090 and 0xb091.
            (Self::ExtendedResponse, 0x1f) => a & 0xfffe == 0xb090,
            _ => false,
        }
    }
}
impl From<(u8, u16)> for MessageKind {
    /// Classifies a `(message type, address)` pair. Pairs that match no known
    /// kind become `MessageKind::Unknown`.
    fn from((msg_type, address): (u8, u16)) -> Self {
        match msg_type {
            0xc0 if address == 0x80b0 => Self::ControlRequest,
            0xc0 if address & 0x00ff == 0x0080 => Self::StatusResponse,
            0x1f => match address {
                0x90b0 => Self::ExtendedRequest,
                0xb090 | 0xb091 => Self::ExtendedResponse,
                _ => Self::Unknown,
            },
            _ => Self::Unknown,
        }
    }
}
impl From<MessageKind> for (u8, u16) {
    /// Returns the canonical `(message type, address)` pair for a kind.
    ///
    /// `MessageKind::Unknown` has no canonical pair and maps to `(0x00, 0x0000)`.
    /// The match is exhaustive (no `_` arm), so adding a variant forces this
    /// mapping to be updated instead of silently encoding as `(0, 0)`.
    fn from(value: MessageKind) -> (u8, u16) {
        match value {
            MessageKind::ControlRequest => (0xc0, 0x80b0),
            MessageKind::StatusResponse => (0xc0, 0xb080),
            MessageKind::ExtendedRequest => (0x1f, 0x90b0),
            MessageKind::ExtendedResponse => (0x1f, 0xb090),
            MessageKind::Unknown => (0x00, 0x0000),
        }
    }
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use rstest::rstest;
pub(crate) mod data {
// Captured raw messages as they appear on the wire. Each starts with the
// 0x55 0x55 0x55 0xAA header and ends with a two-byte big-endian CRC (checked
// for MSG_RESP_AC_CAP by `test_decode`). Requests carry address 0x80b0/0x90b0
// and responses 0xb080/0xb090 in bytes 4..6; byte 7 is the message type.
/// Status request for zones (message type 0xc0).
#[rustfmt::skip]
pub(crate) const MSG_REQ_STATUS_ZONES: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0x80, 0xB0, 0x01, 0xC0, 0x00, 0x08, 0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xA4, 0x31, ];
/// Status response for zones.
#[rustfmt::skip]
pub(crate) const MSG_RESP_STATUS_ZONES: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0xB0, 0x80, 0x01, 0xC0, 0x00, 0x18, 0x21, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x40, 0x80, 0x96, 0x80, 0x02, 0xE7, 0x00, 0x00, 0x01, 0x64, 0xFF, 0x00, 0x07, 0xFF, 0x00, 0x00, 0xB9, 0xEF ];
/// Status request for ACs.
#[rustfmt::skip]
pub(crate) const MSG_REQ_STATUS_ACS: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0x80, 0xB0, 0x01, 0xC0, 0x00, 0x08, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7D, 0xB0 ];
/// Status response for ACs.
#[rustfmt::skip]
pub(crate) const MSG_RESP_STATUS_ACS: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0xB0, 0x80, 0x01, 0xC0, 0x00, 0x1C, 0x23, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x02, 0x10, 0x12, 0x78, 0xC0, 0x02, 0xDA, 0x00, 0x00, 0x80, 0x00, 0x01, 0x42, 0x64, 0xC0, 0x02, 0xE4, 0x00, 0x00, 0x80, 0x00, 0x3D, 0x79 ];
/// Extended (type 0x1f) AC capability request for a single AC (index byte 0x00).
#[rustfmt::skip]
pub(crate) const MSG_REQ_AC_CAP_ONE: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0x90, 0xB0, 0x01, 0x1F, 0x00, 0x03, 0xFF, 0x11, 0x00, 0x09, 0x83 ];
/// Extended AC capability request with no index byte.
#[rustfmt::skip]
pub(crate) const MSG_REQ_AC_CAP_ALL: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0x90, 0xB0, 0x01, 0x1F, 0x00, 0x02, 0xFF, 0x11, 0x83, 0x4C ];
/// Extended AC capability response. The name field "UUUNIT 01 UUU" contains runs
/// of three 0x55 bytes, each followed by a stuffed 0x00 that `decode` removes.
#[rustfmt::skip]
pub(crate) const MSG_RESP_AC_CAP: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0xB0, 0x90, 0x01, 0x1F, 0x00, 0x1C, 0xFF, 0x11, 0x00, 0x18, 0x55, 0x55, 0x55, 0x00, 0x4E, 0x49, 0x54, 0x20, 0x30, 0x31, 0x20, 0x55, 0x55, 0x55, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x17, 0x1D, 0x10, 0x1f, 0x12, 0x1f, 0x70, 0xF0, ];
/// Extended zone-name request for zone 0.
#[rustfmt::skip]
pub(crate) const MSG_REQ_ZONE_NAME_ONE: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0x90, 0xB0, 0x01, 0x1F, 0x00, 0x03, 0xFF, 0x13, 0x00, 0x69, 0x82, ];
/// Extended zone-name response: zone 0 named "Living".
#[rustfmt::skip]
pub(crate) const MSG_RESP_ZONE_NAME_ONE: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0xB0, 0x90, 0x01, 0x1F, 0x00, 0x0A, 0xFF, 0x13, 0x00, 0x06, 0x4C, 0x69, 0x76, 0x69, 0x6E, 0x67, 0xB6, 0x2F, ];
/// Extended zone-name request with no zone index.
#[rustfmt::skip]
pub(crate) const MSG_REQ_ZONE_NAME_ALL: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0x90, 0xB0, 0x01, 0x1F, 0x00, 0x02, 0xFF, 0x13, 0x42, 0xCD, ];
/// Extended zone-name response listing "Living", "Kitchen" and "Bedroom".
#[rustfmt::skip]
pub(crate) const MSG_RESP_ZONE_NAME_ALL: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0xB0, 0x90, 0x01, 0x1F, 0x00, 0x1C, 0xFF, 0x13, 0x00, 0x06, 0x4C, 0x69, 0x76, 0x69, 0x6E, 0x67, 0x01, 0x07, 0x4B, 0x69, 0x74, 0x63, 0x68, 0x65, 0x6E, 0x02, 0x07, 0x42, 0x65, 0x64, 0x72, 0x6F, 0x6F, 0x6D, 0xAE, 0x8B, ];
/// Extended request (subtype 0xFF 0x30), presumably for the console version.
#[rustfmt::skip]
pub(crate) const MSG_REQ_CON_VERS: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0x90, 0xB0, 0x01, 0x1F, 0x00, 0x02, 0xFF, 0x30, 0x9B, 0x8C, ];
/// Response to MSG_REQ_CON_VERS; its payload includes the ASCII text "1.0.3,1.0.3".
#[rustfmt::skip]
pub(crate) const MSG_RESP_CON_VERS: &[u8] = &[
0x55, 0x55, 0x55, 0xAA, 0xB0, 0x90, 0x01, 0x1F, 0x00, 0x0F, 0xFF, 0x30, 0x00, 0x0B, 0x31, 0x2E, 0x30, 0x2E, 0x33, 0x2C, 0x31, 0x2E, 0x30, 0x2E, 0x33, 0x13, 0x28, ];
// Turns a byte array/slice expression into a `&[u8]`. A `&[u8]` implements both
// `std::io::Read` and `tokio::io::AsyncRead`, so it serves as an in-memory
// reader in the tests below.
macro_rules! bytes_reader {
( $x:expr) => {
&($x)[..]
};
}
// Re-exported so the parent test module (via `use data::*`) can use it too.
pub(crate) use bytes_reader;
/// Test reader whose every read fails with an error of `kind`.
///
/// When `payload` is set it becomes the error message. Otherwise the error is
/// built from `kind` alone.
pub(crate) struct ErroringReader {
    pub kind: std::io::ErrorKind,
    pub payload: Option<String>,
}

impl tokio::io::AsyncRead for ErroringReader {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        _buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let error = match self.payload.as_ref() {
            Some(message) => std::io::Error::new(self.kind, message.clone()),
            None => std::io::Error::from(self.kind),
        };
        std::task::Poll::Ready(Err(error))
    }
}
// Builds an `ErroringReader`. Forms:
//   erroring_reader!()                -> ErrorKind::Other, message "injected error"
//   erroring_reader!(KIND)            -> KIND, no custom message
//   erroring_reader!(KIND, MESSAGE)   -> KIND with MESSAGE
//   erroring_reader!(MESSAGE)         -> ErrorKind::Other with MESSAGE
// Arm order matters: the `path` arms come before the catch-all `expr` arm, so a
// bare ErrorKind path is taken as a kind rather than as a message.
macro_rules! erroring_reader {
() => {
ErroringReader {
kind: std::io::ErrorKind::Other,
payload: Some("injected error".to_string()),
}
};
( $k:path ) => {
ErroringReader {
kind: $k,
payload: None,
}
};
( $k:path, $p:expr ) => {
ErroringReader {
kind: $k,
payload: Some($p.to_string()),
}
};
( $p:expr ) => {
ErroringReader {
kind: std::io::ErrorKind::Other,
payload: Some($p.to_string()),
}
};
}
// Re-exported so the parent test module (via `use data::*`) can use it too.
pub(crate) use erroring_reader;
/// Test helper: removes byte stuffing from a raw on-the-wire message.
///
/// After three *consecutive* 0x55 bytes the sender inserts a 0x00. Presumably this
/// keeps payload bytes from imitating the 0x55 0x55 0x55 0xAA header. This helper
/// drops that 0x00. A non-zero byte after such a run (e.g. the 0xAA of the
/// header) is kept, and counting restarts after it.
///
/// NOTE(review): the previous version never reset its counter on a non-0x55
/// byte, so three *scattered* 0x55 bytes also caused a following 0x00 to be
/// dropped. Confirm this matches the production decoder in `io`/`frame`.
pub(crate) fn decode(src: &[u8]) -> Vec<u8> {
    let mut dst = Vec::with_capacity(src.len());
    // Length of the current run of consecutive 0x55 bytes.
    let mut run = 0;
    for &b in src {
        if run == 3 {
            run = 0;
            if b == 0 {
                // Stuffing byte: not part of the message.
                continue;
            }
        }
        dst.push(b);
        run = if b == 0x55 { run + 1 } else { 0 };
    }
    dst
}
/// Test helper: decodes `src` with a fresh `FrameCodec` and returns the frame.
///
/// # Panics
/// Panics unless decoding yields exactly `Ok(Some(MaybeFrame::Frame(_)))`. This
/// covers an I/O error, incomplete input and a CRC mismatch.
pub(crate) fn frame(src: &[u8]) -> super::frame::Frame {
    use tokio_util::codec::Decoder;
    let mut codec = super::frame::FrameCodec::new();
    let mut bytes = tokio_util::bytes::BytesMut::from(src);
    let decoded = codec.decode(&mut bytes);
    assert_matches!(decoded,
        Ok(Some(super::frame::MaybeFrame::Frame(frame))) => frame
    )
}
// bytes_reader! must produce a std::io::Read that returns the data in order
// across consecutive reads: first the 4-byte header, then the remainder.
#[test]
fn test_bytes_reader() {
use std::io::prelude::*;
let mut head = [0u8; 4];
let mut buf = [0u8; MSG_REQ_STATUS_ZONES.len() - 4];
let mut r = bytes_reader!(MSG_REQ_STATUS_ZONES);
let len = r.read(&mut head).expect("couldn't head");
assert_eq!(len, head.len());
assert_eq!(head[..len], MSG_REQ_STATUS_ZONES[..len]);
// The second read continues where the first left off.
let len = r.read(&mut buf).expect("couldn't read");
assert_eq!(len, MSG_REQ_STATUS_ZONES.len() - 4);
assert_eq!(buf[..len], MSG_REQ_STATUS_ZONES[4..4 + len]);
}
// Async counterpart of test_bytes_reader: the same reader used through
// tokio::io::AsyncReadExt yields the header, then the remainder.
#[tokio::test]
async fn test_bytes_reader_async() {
use tokio::io::AsyncReadExt;
let mut head = [0u8; 4];
let mut buf = [0u8; MSG_REQ_STATUS_ZONES.len() - 4];
let mut r = bytes_reader!(MSG_REQ_STATUS_ZONES);
let len = r.read(&mut head).await.expect("couldn't head");
assert_eq!(len, head.len());
assert_eq!(head[..len], MSG_REQ_STATUS_ZONES[..len]);
let len = r.read(&mut buf).await.expect("couldn't read");
assert_eq!(len, MSG_REQ_STATUS_ZONES.len() - 4);
assert_eq!(buf[..len], MSG_REQ_STATUS_ZONES[4..4 + len]);
}
// Each erroring_reader! form must yield a reader whose first read fails with the
// requested ErrorKind and message. With no payload, the message is the kind's
// default Display text (e.g. "invalid data").
#[tokio::test]
#[super::rstest]
#[case(erroring_reader!(), std::io::ErrorKind::Other, "injected error")]
#[case(erroring_reader!(std::io::ErrorKind::InvalidData),
std::io::ErrorKind::InvalidData, "invalid data")]
#[case(erroring_reader!(std::io::ErrorKind::NotFound, "nope.txt"),
std::io::ErrorKind::NotFound, "nope.txt")]
#[case(erroring_reader!("you have made a fatal mistake"),
std::io::ErrorKind::Other, "you have made a fatal mistake")]
async fn test_erroring_reader(
#[case] mut reader: impl tokio::io::AsyncRead + Unpin,
#[case] expected_kind: std::io::ErrorKind,
#[case] expected_string: &str,
) {
use tokio::io::AsyncReadExt;
assert_matches!(reader.read(&mut [0u8]).await, Err(e) => {
assert_eq!(e.kind(), expected_kind);
assert_eq!(e.to_string(), expected_string);
});
}
// decode() must remove the two stuffed 0x00 bytes from MSG_RESP_AC_CAP. The
// MODBUS CRC-16 of the unstuffed body (after the 4-byte header, before the
// 2-byte CRC) must equal the big-endian trailing CRC, and the unit-name field
// must read "UUUNIT 01 UUU".
#[test]
fn test_decode() {
use crc16::{State as Crc16, MODBUS};
let buf = decode(MSG_RESP_AC_CAP);
assert_eq!(buf.len(), MSG_RESP_AC_CAP.len() - 2);
let mut crc = Crc16::<MODBUS>::new();
crc.update(&buf[4..buf.len() - 2]);
assert_eq!(
crc.get(),
u16::from_be_bytes(buf[buf.len() - 2..buf.len()].try_into().unwrap())
);
assert_eq!(&buf[14..27], "UUUNIT 01 UUU".as_bytes());
}
// The `frame` helper must decode a captured zone-name response into an
// ExtendedResponse frame.
#[test]
fn test_frame() {
let f = frame(MSG_RESP_ZONE_NAME_ALL);
assert_eq!(f.kind, super::MessageKind::ExtendedResponse);
}
// Exploratory helper rather than a real test: prints the MODBUS CRC-16 of the
// unstuffed MSG_RESP_AC_CAP body (between the 4-byte header and the 2-byte CRC).
// It asserts nothing; test_decode performs the checked version of this.
#[test]
fn expr_calc_crc() {
use crc16::{State as Crc16, MODBUS};
let mut crc = Crc16::<MODBUS>::new();
crc.update(&decode(&MSG_RESP_AC_CAP[4..MSG_RESP_AC_CAP.len() - 2])[..]);
println!("{:#06x}", crc.get());
}
}
use data::*;
// Table-driven check of the (msg_type, address) -> MessageKind classification.
// Covers the masked StatusResponse address (only the low byte 0x80 is fixed),
// both ExtendedResponse addresses, and near-miss pairs that must be Unknown.
#[rstest]
#[case(0xc0, 0x80b0, MessageKind::ControlRequest)]
#[case(0xc0, 0xb080, MessageKind::StatusResponse)]
#[case(0xc0, 0xfd80, MessageKind::StatusResponse)]
#[case(0xc1, 0x8080, MessageKind::Unknown)]
#[case(0x1f, 0x8080, MessageKind::Unknown)]
#[case(0x1f, 0x90b0, MessageKind::ExtendedRequest)]
#[case(0x1f, 0xb090, MessageKind::ExtendedResponse)]
#[case(0x1f, 0xb091, MessageKind::ExtendedResponse)]
#[case(0x1f, 0x90b1, MessageKind::Unknown)]
fn test_message_kind_from(
#[case] msg_type: u8,
#[case] address: u16,
#[case] expected: MessageKind,
) {
let kind: MessageKind = (msg_type, address).into();
assert_eq!(kind, expected);
}
// Table-driven check of MessageKind::is_valid. It must agree with the From
// classification, and Unknown must never be valid for any pair.
#[rstest]
#[case(0xc0, 0x80b0, MessageKind::ControlRequest, true)]
#[case(0xc0, 0xb080, MessageKind::StatusResponse, true)]
#[case(0xc0, 0xfd80, MessageKind::StatusResponse, true)]
#[case(0xc1, 0xb080, MessageKind::StatusResponse, false)]
#[case(0x1f, 0x8080, MessageKind::StatusResponse, false)]
#[case(0x1f, 0x90b0, MessageKind::ExtendedRequest, true)]
#[case(0x1f, 0xb090, MessageKind::ExtendedResponse, true)]
#[case(0x1f, 0xb091, MessageKind::ExtendedResponse, true)]
#[case(0x1f, 0x90b1, MessageKind::ExtendedRequest, false)]
#[case(0xc0, 0xb080, MessageKind::Unknown, false)]
#[case(0xc0, 0xfd80, MessageKind::Unknown, false)]
#[case(0x1f, 0x90b0, MessageKind::Unknown, false)]
#[case(0x00, 0xffff, MessageKind::Unknown, false)]
fn test_message_kind_is_valid(
#[case] msg_type: u8,
#[case] address: u16,
#[case] kind: MessageKind,
#[case] expected: bool,
) {
assert_eq!(kind.is_valid(msg_type, address), expected);
}
// Two valid frames back-to-back on the reader must be yielded in order, and the
// stream must end (None) at EOF. The writer is an unused Vec.
#[tokio::test]
async fn test_conn_stream_ok() {
use tokio::io::AsyncReadExt;
use tokio_stream::StreamExt;
let write: Vec<u8> = vec![];
let mut conn = Connection::from_io(
bytes_reader!(MSG_REQ_STATUS_ZONES).chain(bytes_reader!(MSG_RESP_AC_CAP)),
write,
);
assert_matches!(conn.next().await, Some(Ok(frame)) => {
assert_eq!(frame.kind, MessageKind::ControlRequest);
});
assert_matches!(conn.next().await, Some(Ok(frame)) => {
assert_eq!(frame.kind, MessageKind::ExtendedResponse);
});
assert_matches!(conn.next().await, None);
}
// A frame whose CRC bytes are replaced (0xacab instead of the original 0xa431)
// must be dropped, with exactly one warn-level log naming both values. The
// following valid frame must still be delivered.
#[tokio::test]
async fn test_conn_stream_badcrc() {
use tokio::io::AsyncReadExt;
use tokio_stream::StreamExt;
testing_logger::setup();
let write: Vec<u8> = vec![];
let mut conn = Connection::from_io(
bytes_reader!(&MSG_REQ_STATUS_ZONES[..MSG_REQ_STATUS_ZONES.len() - 2])
.chain(bytes_reader!(&[0xac, 0xab]))
.chain(bytes_reader!(MSG_RESP_AC_CAP)),
write,
);
assert_matches!(conn.next().await, Some(Ok(frame)) => {
assert_eq!(frame.kind, MessageKind::ExtendedResponse);
});
assert_matches!(conn.next().await, None);
testing_logger::validate(|logs| {
assert_eq!(logs.len(), 1, "expected exactly one log");
assert_eq!(logs[0].level, log::Level::Warn);
let s = logs[0].body.to_lowercase();
for p in ["bad crc", "expected 0xacab", "calculated 0xa431"] {
assert!(s.contains(p), "incorrect log: {}", logs[0].body);
}
});
}
// A frame whose message-type byte (offset 7) is replaced by 0xff must be
// classified Unknown, dropped with exactly one info-level log, and followed by
// the next valid frame. The CRC bytes are replaced by 0xB0 0xFE. The test expects
// no CRC warning, so these must be the CRC of the modified frame.
#[tokio::test]
async fn test_conn_stream_badtype() {
use tokio::io::AsyncReadExt;
use tokio_stream::StreamExt;
testing_logger::setup();
let write: Vec<u8> = vec![];
let mut conn = Connection::from_io(
bytes_reader!(&MSG_REQ_STATUS_ZONES[..7])
.chain(bytes_reader!(&[0xff]))
.chain(bytes_reader!(
&MSG_REQ_STATUS_ZONES[8..MSG_REQ_STATUS_ZONES.len() - 2]
))
.chain(bytes_reader!(&[0xB0, 0xFE]))
.chain(bytes_reader!(MSG_RESP_AC_CAP)),
write,
);
assert_matches!(conn.next().await, Some(Ok(frame)) => {
assert_eq!(frame.kind, MessageKind::ExtendedResponse);
});
assert_matches!(conn.next().await, None);
testing_logger::validate(|logs| {
assert_eq!(logs.len(), 1, "expected exactly one log");
assert_eq!(logs[0].level, log::Level::Info);
let s = logs[0].body.to_lowercase();
for p in ["unknown message", "type 0xff", "address 0x80b0"] {
assert!(s.contains(p), "incorrect log: {}", logs[0].body);
}
});
}
// An I/O error from the underlying reader between two frames must surface as
// Some(Err) with its kind preserved. After the error the stream must yield None,
// even though more bytes follow the erroring reader in the chain.
#[tokio::test]
async fn test_conn_stream_eio() {
use tokio::io::AsyncReadExt;
use tokio_stream::StreamExt;
let write: Vec<u8> = vec![];
let mut conn = Connection::from_io(
bytes_reader!(MSG_REQ_STATUS_ZONES)
.chain(erroring_reader!(std::io::ErrorKind::InvalidData))
.chain(bytes_reader!(MSG_RESP_AC_CAP)),
write,
);
assert_matches!(conn.next().await, Some(Ok(frame)) => {
assert_eq!(frame.kind, MessageKind::ControlRequest);
});
assert_matches!(conn.next().await, Some(Err(eio)) => {
assert_eq!(eio.kind(), std::io::ErrorKind::InvalidData);
});
assert_matches!(conn.next().await, None);
}
// Sending one frame must write exactly the bytes of MSG_REQ_STATUS_ZONES. The
// 1024-byte pipe holds the whole frame, so `send` completes on its own and a
// single read returns everything.
#[tokio::test]
async fn test_conn_sink_ok() {
use futures_util::sink::SinkExt;
use tokio::io::AsyncReadExt;
let (mut read, write) = tokio::io::simplex(1024);
let mut conn = Connection::from_io(&[0u8; 0][..], write);
let frame = Frame {
msg_id: 1,
msg_type: 0xc0,
address: 0x80b0,
kind: MessageKind::ControlRequest,
data: vec![0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
};
conn.send(frame).await.expect("failed to send");
let mut buf = [0u8; 128];
let l = read.read(&mut buf).await.expect("failed to read");
assert_eq!(&buf[..l], MSG_REQ_STATUS_ZONES);
}
// With a 2-byte pipe the send can only complete while a concurrently spawned
// task drains the read side. Checks that the full frame arrives intact across
// many partial writes.
#[tokio::test]
async fn test_conn_sink_shortpipe_spawn() {
use futures_util::sink::SinkExt;
use tokio::io::AsyncReadExt;
let (mut read, write) = tokio::io::simplex(2);
let mut conn = Connection::from_io(&[0u8; 0][..], write);
let frame = Frame {
msg_id: 1,
msg_type: 0xc0,
address: 0x80b0,
kind: MessageKind::ControlRequest,
data: vec![0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
};
// Reader task: accumulate until a whole frame's worth of bytes has arrived.
let jh: tokio::task::JoinHandle<Result<Vec<u8>>> = tokio::spawn(async move {
let mut l = 0;
let mut buf = [0u8; 128];
while l < MSG_REQ_STATUS_ZONES.len() {
l += read.read(&mut buf[l..]).await?;
}
Ok(buf[..l].to_owned())
});
conn.send(frame).await.expect("failed to send");
let buf = jh.await.expect("subthread panic").expect("couldn't read");
assert_eq!(&buf[..], MSG_REQ_STATUS_ZONES);
}
// Like test_conn_sink_shortpipe_spawn, but the send and the reads are driven
// concurrently on one task with tokio::select!. The 2-byte pipe forces the
// sender to wait for the reader between partial writes.
#[tokio::test]
async fn test_conn_sink_shortpipe_select() {
    use futures_util::sink::SinkExt;
    use tokio::io::AsyncReadExt;
    let (mut read, write) = tokio::io::simplex(2);
    let mut conn = Connection::from_io(&[0u8; 0][..], write);
    let frame = Frame {
        msg_id: 1,
        msg_type: 0xc0,
        address: 0x80b0,
        kind: MessageKind::ControlRequest,
        data: vec![0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
    };
    let mut buf = [0u8; 128];
    let mut l = 0;
    let mut sent = false;
    let mut send = conn.send(frame);
    while l < MSG_REQ_STATUS_ZONES.len() {
        tokio::select! {
            // Disable this branch once the send has completed. Polling a future
            // again after it returned Ready is not allowed by the Future contract.
            res = &mut send, if !sent => {
                assert!(res.is_ok());
                sent = true;
            },
            res = read.read(&mut buf[l..]) => {
                assert!(res.is_ok());
                l += res.unwrap();
            },
        }
    }
    // All bytes were read, but the send may not have been polled to completion
    // (final flush). Finish it so its result is checked too.
    if !sent {
        send.await.expect("failed to send");
    }
    assert_eq!(&buf[..l], MSG_REQ_STATUS_ZONES);
}
// The connection's reader and writer are the two ends of one in-memory pipe, so
// every sent frame must be decoded back unchanged. The second frame's payload
// contains runs of three 0x55 bytes, presumably exercising stuffing on write and
// its removal on read. After close(), the read side must see EOF (None).
#[tokio::test]
async fn test_conn_loopback() {
use futures_util::sink::SinkExt;
use tokio_stream::StreamExt;
let (read, write) = tokio::io::simplex(1024);
let mut conn = Connection::from_io(read, write);
let frame = Frame {
msg_id: 1,
msg_type: 0xc0,
address: 0x80b0,
kind: MessageKind::ControlRequest,
data: vec![0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
};
conn.send(frame.clone()).await.expect("coudln't send");
assert_matches!(conn.next().await,
Some(Ok(received)) => {
assert_eq!(received, frame);
}
);
// Payload is the raw body of MSG_RESP_AC_CAP (after the 10-byte prefix, before the CRC).
let frame = Frame {
msg_id: 17,
msg_type: 0x1f,
address: 0xb090,
kind: MessageKind::ExtendedResponse,
data: MSG_RESP_AC_CAP[10..MSG_RESP_AC_CAP.len() - 2].to_owned(),
};
conn.send(frame.clone()).await.expect("coudln't send");
assert_matches!(conn.next().await,
Some(Ok(received)) => {
assert_eq!(received, frame);
}
);
conn.close().await.expect("couldn't close");
assert_matches!(conn.next().await, None);
}
}