use std::{fmt, io};
use futures_util::{SinkExt as _, StreamExt as _};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use crate::{
codec,
frame::{rtu::*, *},
slave::*,
Result,
};
use super::verify_response_header;
/// Modbus RTU client: sends request ADUs to the currently selected slave over a
/// framed transport and decodes the replies.
#[derive(Debug)]
pub(crate) struct Client<T> {
    /// Transport `T` wrapped with the RTU client codec; requests are sent and
    /// responses read through this.
    framed: Framed<T, codec::rtu::ClientCodec>,
    /// Slave address written into the header of every outgoing request.
    slave_id: SlaveId,
}
impl<T> Client<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Creates a client that talks to `slave` over `transport`, framing all
    /// traffic with the RTU client codec.
    pub(crate) fn new(transport: T, slave: Slave) -> Self {
        let framed = Framed::new(transport, codec::rtu::ClientCodec::default());
        let slave_id = slave.into();
        Self { framed, slave_id }
    }

    /// Wraps `req` into a request ADU addressed to the currently selected slave.
    ///
    /// `disconnect` is copied into the ADU unchanged.
    fn next_request_adu<'a, R>(&self, req: R, disconnect: bool) -> RequestAdu<'a>
    where
        R: Into<RequestPdu<'a>>,
    {
        RequestAdu {
            hdr: Header {
                slave_id: self.slave_id,
            },
            pdu: req.into(),
            disconnect,
        }
    }

    /// Sends a single request and waits for its response.
    ///
    /// Returns `Ok(Ok(response))` on success and `Ok(Err(exception))` when the
    /// slave answered with a Modbus exception.
    ///
    /// # Errors
    ///
    /// - any error produced while sending the request or decoding the response;
    /// - `io::ErrorKind::BrokenPipe` if the transport ends before a response arrives;
    /// - the header-mismatch error from `verify_response_header` when the response
    ///   header does not match the request header. This applies to normal and
    ///   exception responses alike.
    async fn call(&mut self, req: Request<'_>) -> Result<Response> {
        let disconnect = req == Request::Disconnect;
        let req_adu = self.next_request_adu(req, disconnect);
        let req_hdr = req_adu.hdr;

        // Discard any buffered, not-yet-decoded bytes so they cannot be mistaken
        // for the reply to this request.
        self.framed.read_buffer_mut().clear();

        self.framed.send(req_adu).await?;
        let res_adu = self
            .framed
            .next()
            .await
            .unwrap_or_else(|| Err(io::Error::from(io::ErrorKind::BrokenPipe)))?;

        // Check the header before looking at the PDU. Otherwise an exception
        // response from a different slave would be reported as an exception
        // raised by the slave this request was addressed to.
        verify_response_header(&req_hdr, &res_adu.hdr)?;

        match res_adu.pdu {
            ResponsePdu(Ok(res)) => Ok(Ok(res)),
            ResponsePdu(Err(err)) => Ok(Err(err.exception)),
        }
    }
}
impl<T> SlaveContext for Client<T> {
fn set_slave(&mut self, slave: Slave) {
self.slave_id = slave.into();
}
}
#[async_trait::async_trait]
impl<T> crate::client::Client for Client<T>
where
T: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin,
{
async fn call(&mut self, req: Request<'_>) -> Result<Response> {
self.call(req).await
}
}
#[cfg(test)]
mod tests {
    use core::{
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result};

    use crate::service::{rtu::Header, verify_response_header};

    /// Identical request and response headers pass verification.
    #[test]
    fn validate_same_headers() {
        let hdr = Header { slave_id: 0 };
        assert!(verify_response_header(&hdr, &Header { slave_id: 0 }).is_ok());
    }

    /// A response from a different slave id is rejected with `InvalidData`.
    #[test]
    fn invalid_validate_not_same_slave_id() {
        let outcome = verify_response_header(&Header { slave_id: 0 }, &Header { slave_id: 5 });
        let kind = outcome.map_err(|e| e.kind());
        assert_eq!(kind, Err(std::io::ErrorKind::InvalidData));
    }

    /// Transport whose reads always hit end-of-stream and whose writes always
    /// report two bytes written.
    #[derive(Debug)]
    struct MockTransport;

    impl AsyncRead for MockTransport {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<Result<()>> {
            // Leaving `_buf` untouched signals EOF to the reader.
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockTransport {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _data: &[u8],
        ) -> Poll<Result<usize>> {
            Poll::Ready(Ok(2))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
            unimplemented!()
        }
    }

    /// A transport that closes before replying surfaces as `BrokenPipe`.
    #[tokio::test]
    async fn handle_broken_pipe() {
        use crate::service::rtu::{Client, Request, Slave};

        let mut client = Client::new(MockTransport, Slave::broadcast());
        let err = client
            .call(Request::ReadCoils(0x00, 5))
            .await
            .err()
            .expect("reading from a closed transport must fail");
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
    }
}