use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use super::serial::{SerialConfig, VirtualSerial, VirtualSerialConfig};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TransportConfig {
VirtualSerial(VirtualSerialConfig),
TcpBridge(TcpBridgeConfig),
Channel(ChannelConfig),
}
impl Default for TransportConfig {
fn default() -> Self {
Self::VirtualSerial(VirtualSerialConfig::default())
}
}
/// Settings for `TcpBridgeTransport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TcpBridgeConfig {
    /// Address the TCP listener binds to.
    pub bind_address: SocketAddr,
    /// Serial line parameters used by the transport's timing helpers.
    pub serial: SerialConfig,
    /// Intended cap on simultaneous clients.
    /// NOTE(review): `TcpBridgeTransport` never reads this and serves one
    /// client at a time — confirm whether it is consumed elsewhere.
    pub max_connections: usize,
    /// NOTE(review): not read by `TcpBridgeTransport`; presumably an idle or
    /// accept timeout — confirm intended semantics.
    pub connection_timeout: Duration,
    /// NOTE(review): not read by `TcpBridgeTransport`; presumably selects raw
    /// RTU bytes over TCP — confirm.
    pub raw_mode: bool,
}

impl Default for TcpBridgeConfig {
    /// Listens on `0.0.0.0:5020` with default serial settings, 10 connections,
    /// a 60 second timeout and raw mode enabled.
    fn default() -> Self {
        // Built from parts rather than parsed from a string, so no unwrap is needed.
        let bind_address = SocketAddr::from(([0, 0, 0, 0], 5020));
        TcpBridgeConfig {
            bind_address,
            serial: SerialConfig::default(),
            max_connections: 10,
            connection_timeout: Duration::from_secs(60),
            raw_mode: true,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelConfig {
pub buffer_size: usize,
pub serial: SerialConfig,
pub simulate_delays: bool,
}
impl Default for ChannelConfig {
fn default() -> Self {
Self {
buffer_size: 256,
serial: SerialConfig::default(),
simulate_delays: false,
}
}
}
/// Identifies the concrete kind of an `RtuTransport`.
///
/// Serialized as a snake_case string: `virtual_serial`, `tcp_bridge` or `channel`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransportType {
    /// `VirtualSerialTransport`.
    VirtualSerial,
    /// `TcpBridgeTransport`.
    TcpBridge,
    /// `ChannelTransport`.
    Channel,
}
/// Asynchronous byte transport used to carry RTU traffic.
///
/// The `Send + Sync` bound lets implementations be handed out as
/// `Box<dyn RtuTransport>`, as `TransportFactory::create` does.
#[async_trait]
pub trait RtuTransport: Send + Sync {
    /// Reports which kind of transport this is.
    fn transport_type(&self) -> TransportType;

    /// Reports whether the transport can currently exchange data.
    ///
    /// The exact meaning is implementation-specific.
    fn is_ready(&self) -> bool;

    /// Reads up to `buf.len()` bytes into `buf` and returns how many were copied.
    ///
    /// The implementations in this module return `Ok(0)` once the other side
    /// has gone away.
    async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize>;

    /// Writes bytes from `data` and returns how many were accepted, which may
    /// be fewer than `data.len()`.
    async fn write(&mut self, data: &[u8]) -> io::Result<usize>;

    /// Flushes any output buffered by the transport.
    async fn flush(&mut self) -> io::Result<()>;

    /// Closes the transport; what this entails is implementation-specific.
    async fn close(&mut self) -> io::Result<()>;

    /// Serial line settings backing the timing helpers below.
    fn serial_config(&self) -> &SerialConfig;

    /// Time to transmit `bytes` bytes, as computed by `SerialConfig::transmission_time`.
    fn transmission_delay(&self, bytes: usize) -> Duration {
        let settings = self.serial_config();
        settings.transmission_time(bytes)
    }

    /// Inter-frame timeout, as computed by `SerialConfig::inter_frame_timeout`.
    fn inter_frame_timeout(&self) -> Duration {
        let settings = self.serial_config();
        settings.inter_frame_timeout()
    }
}
/// One endpoint of an in-process, message-based byte channel.
///
/// Each `write` sends one `Bytes` message to the peer. `read` copies bytes out
/// of received messages and keeps any that did not fit for the next call.
pub struct ChannelTransport {
    /// Messages arriving from the peer.
    rx: mpsc::Receiver<Bytes>,
    /// Messages going to the peer.
    tx: mpsc::Sender<Bytes>,
    /// Settings supplied at construction.
    config: ChannelConfig,
    /// Leftover bytes from a message that was larger than the caller's buffer.
    read_buffer: BytesMut,
}
impl ChannelTransport {
    /// Creates two transports wired back to back: bytes written to one are read
    /// from the other.
    ///
    /// `config.buffer_size` is the capacity, in messages, of each direction's
    /// channel. A capacity of 0 is raised to 1, because `mpsc::channel` panics
    /// on zero, and a deserialized config can contain 0.
    pub fn pair(config: ChannelConfig) -> (Self, Self) {
        let capacity = config.buffer_size.max(1);
        let (tx1, rx1) = mpsc::channel(capacity);
        let (tx2, rx2) = mpsc::channel(capacity);
        // Each endpoint sends on the channel the other one receives from.
        let transport1 = Self::new(tx2, rx1, config.clone());
        let transport2 = Self::new(tx1, rx2, config);
        (transport1, transport2)
    }

    /// Builds an endpoint from an existing sender/receiver pair.
    ///
    /// `tx` carries outgoing messages and `rx` carries incoming ones.
    pub fn new(tx: mpsc::Sender<Bytes>, rx: mpsc::Receiver<Bytes>, config: ChannelConfig) -> Self {
        Self {
            rx,
            tx,
            config,
            read_buffer: BytesMut::new(),
        }
    }
}
#[async_trait]
impl RtuTransport for ChannelTransport {
    fn transport_type(&self) -> TransportType {
        TransportType::Channel
    }

    /// Ready while the peer's receiving half is still open.
    fn is_ready(&self) -> bool {
        !self.tx.is_closed()
    }

    /// Copies the next chunk of incoming bytes into `buf`.
    ///
    /// Bytes left over from an earlier message that did not fit are returned
    /// first. Empty messages are skipped so they are never mistaken for end of
    /// stream. `Ok(0)` is returned only for an empty `buf`, or once the channel
    /// is closed and fully drained.
    async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if !self.read_buffer.is_empty() {
            let len = buf.len().min(self.read_buffer.len());
            buf[..len].copy_from_slice(&self.read_buffer.split_to(len));
            return Ok(len);
        }
        let data = loop {
            match self.rx.recv().await {
                Some(data) if data.is_empty() => continue,
                Some(data) => break data,
                None => return Ok(0),
            }
        };
        let len = buf.len().min(data.len());
        buf[..len].copy_from_slice(&data[..len]);
        if data.len() > len {
            // Keep the overflow for the next read instead of discarding it.
            self.read_buffer.extend_from_slice(&data[len..]);
        }
        Ok(len)
    }

    /// Sends `data` to the peer as one message. When `simulate_delays` is set,
    /// it first sleeps for the serial transmission time.
    ///
    /// Empty writes return `Ok(0)` without sending anything. An empty message
    /// carries no data, and a reader that does not skip it would report it as
    /// end of stream.
    ///
    /// # Errors
    /// `BrokenPipe` if the peer's receiving half is closed.
    async fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        if data.is_empty() {
            return Ok(0);
        }
        if self.config.simulate_delays {
            let delay = self.config.serial.transmission_time(data.len());
            tokio::time::sleep(delay).await;
        }
        self.tx
            .send(Bytes::copy_from_slice(data))
            .await
            .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "Channel closed"))?;
        Ok(data.len())
    }

    /// Messages are handed to the channel on `write`, so there is nothing to flush.
    async fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }

    fn serial_config(&self) -> &SerialConfig {
        &self.config.serial
    }

    /// Closes the receiving half. The peer's writes then fail with `BrokenPipe`
    /// and its `is_ready` turns false. Messages already queued can still be read.
    async fn close(&mut self) -> io::Result<()> {
        self.rx.close();
        Ok(())
    }
}
/// Transport backed by the async file handle of a virtual serial device.
///
/// Only compiled on unix targets.
#[cfg(unix)]
pub struct VirtualSerialTransport {
    /// Open handle used for all reads and writes.
    io: tokio::fs::File,
    /// Settings supplied at construction; also control simulated write delays.
    config: VirtualSerialConfig,
}
#[cfg(unix)]
impl VirtualSerialTransport {
fn new(io: tokio::fs::File, config: VirtualSerialConfig) -> Self {
Self { io, config }
}
}
#[cfg(unix)]
#[async_trait]
impl RtuTransport for VirtualSerialTransport {
    fn transport_type(&self) -> TransportType {
        TransportType::VirtualSerial
    }

    /// Always `true`: the device handle is opened before construction.
    fn is_ready(&self) -> bool {
        true
    }

    /// Reads directly from the device handle.
    async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let device = &mut self.io;
        device.read(buf).await
    }

    /// Writes to the device handle, first sleeping for the serial transmission
    /// time when `simulate_delays` is set. This may be a partial write.
    async fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        let config = &self.config;
        if config.simulate_delays {
            let delay = config.serial.transmission_time(data.len());
            tokio::time::sleep(delay).await;
        }
        let device = &mut self.io;
        device.write(data).await
    }

    async fn flush(&mut self) -> io::Result<()> {
        let device = &mut self.io;
        device.flush().await
    }

    fn serial_config(&self) -> &SerialConfig {
        &self.config.serial
    }

    /// Flushes pending output; the handle itself is released on drop.
    async fn close(&mut self) -> io::Result<()> {
        let device = &mut self.io;
        device.flush().await
    }
}
/// TCP listener that serves one client connection at a time.
///
/// A client is accepted lazily on the first read or write. The stream is
/// dropped when a read reports end of stream, and the next I/O call accepts a
/// new client.
pub struct TcpBridgeTransport {
    /// Listener bound to `config.bind_address`.
    listener: TcpListener,
    /// Currently connected client, if any.
    stream: Option<TcpStream>,
    /// Settings supplied to `bind`.
    config: TcpBridgeConfig,
}
impl TcpBridgeTransport {
    /// Binds a listener on `config.bind_address`.
    ///
    /// No client is accepted here; that happens on the first read or write.
    ///
    /// # Errors
    /// Propagates any error from binding the listener.
    pub async fn bind(config: TcpBridgeConfig) -> io::Result<Self> {
        let address = config.bind_address;
        TcpListener::bind(address).await.map(|listener| Self {
            listener,
            stream: None,
            config,
        })
    }

    /// Returns the connected client, waiting to accept one if none is connected.
    async fn ensure_stream(&mut self) -> io::Result<&mut TcpStream> {
        let stream = match self.stream.take() {
            Some(existing) => existing,
            None => self.listener.accept().await?.0,
        };
        Ok(self.stream.insert(stream))
    }
}
/// Returns `true` when `error` means the client connection is unusable. The
/// bridge then drops the connection and accepts a new client on the next I/O call.
fn is_disconnect_error(error: &io::Error) -> bool {
    matches!(
        error.kind(),
        io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::NotConnected
            | io::ErrorKind::UnexpectedEof
    )
}

#[async_trait]
impl RtuTransport for TcpBridgeTransport {
    fn transport_type(&self) -> TransportType {
        TransportType::TcpBridge
    }

    /// Always `true`: `read` and `write` accept a client on demand.
    fn is_ready(&self) -> bool {
        true
    }

    /// Reads from the current client, accepting one first if none is connected.
    ///
    /// The stream is dropped when the client closes it (`Ok(0)`) or when a
    /// disconnect error occurs, so the next call accepts a fresh client instead
    /// of failing forever. A zero-length `buf` returns `Ok(0)` immediately.
    /// A zero-byte socket read would otherwise look like end of stream and
    /// drop a live client.
    async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        let result = {
            let stream = self.ensure_stream().await?;
            stream.read(buf).await
        };
        let disconnected = match &result {
            Ok(0) => true,
            Ok(_) => false,
            Err(error) => is_disconnect_error(error),
        };
        if disconnected {
            self.stream = None;
        }
        result
    }

    /// Writes to the current client, accepting one first if none is connected.
    ///
    /// On a disconnect error the stream is dropped so a later call can accept a
    /// new client. Empty writes return `Ok(0)` without waiting for a client.
    async fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        if data.is_empty() {
            return Ok(0);
        }
        let result = {
            let stream = self.ensure_stream().await?;
            stream.write(data).await
        };
        if matches!(&result, Err(error) if is_disconnect_error(error)) {
            self.stream = None;
        }
        result
    }

    /// Flushes the current client stream; does nothing when no client is connected.
    async fn flush(&mut self) -> io::Result<()> {
        match self.stream.as_mut() {
            Some(stream) => stream.flush().await,
            None => Ok(()),
        }
    }

    fn serial_config(&self) -> &SerialConfig {
        &self.config.serial
    }

    /// Drops the current client connection, if any. The listener stays bound.
    async fn close(&mut self) -> io::Result<()> {
        self.stream = None;
        Ok(())
    }
}
/// Adapter meant to expose an `RtuTransport` through tokio's `AsyncRead` and
/// `AsyncWrite` traits.
pub struct TransportIo<T: RtuTransport> {
    /// The wrapped transport.
    transport: T,
    /// Bytes waiting to be handed out by `poll_read`.
    /// NOTE(review): nothing visible in this module ever fills this buffer.
    read_buffer: BytesMut,
}
impl<T: RtuTransport> TransportIo<T> {
    /// Wraps `transport` with an empty read buffer that has 256 bytes preallocated.
    pub fn new(transport: T) -> Self {
        let read_buffer = BytesMut::with_capacity(256);
        Self {
            transport,
            read_buffer,
        }
    }

    /// Borrows the wrapped transport.
    pub fn transport(&self) -> &T {
        &self.transport
    }

    /// Mutably borrows the wrapped transport.
    pub fn transport_mut(&mut self) -> &mut T {
        &mut self.transport
    }

    /// Consumes the adapter and returns the wrapped transport. Buffered bytes are discarded.
    pub fn into_inner(self) -> T {
        let Self { transport, .. } = self;
        transport
    }
}
impl<T: RtuTransport + Unpin> AsyncRead for TransportIo<T> {
    /// Hands out bytes already held in the adapter's internal buffer.
    ///
    /// Pulling fresh bytes from the wrapped transport is not supported here.
    /// Its `read` is an `async fn` that mutably borrows the transport, and this
    /// struct has nowhere to keep that future between polls. Instead of
    /// spinning forever (self-wake + `Pending`), the call fails immediately
    /// with `ErrorKind::Unsupported` when the buffer is empty.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if buf.remaining() == 0 {
            // Nothing can be written into a full buffer; report success per the AsyncRead contract.
            return Poll::Ready(Ok(()));
        }
        if !this.read_buffer.is_empty() {
            let len = buf.remaining().min(this.read_buffer.len());
            buf.put_slice(&this.read_buffer.split_to(len));
            return Poll::Ready(Ok(()));
        }
        Poll::Ready(Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "TransportIo cannot read from the underlying transport; use RtuTransport::read directly",
        )))
    }
}
impl<T: RtuTransport + Unpin> AsyncWrite for TransportIo<T> {
    /// Writing through the adapter is not supported.
    ///
    /// The wrapped transport's `write` is an `async fn` that mutably borrows the
    /// transport, and this struct has nowhere to keep that future between polls.
    /// Returning `Pending` without registering a waker would hang the calling
    /// task forever. Instead, non-empty writes fail immediately with
    /// `ErrorKind::Unsupported`, and empty writes report `Ok(0)`.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        Poll::Ready(Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "TransportIo cannot write to the underlying transport; use RtuTransport::write directly",
        )))
    }

    /// Nothing is buffered by the adapter, so flushing always succeeds.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shutdown is a no-op; use `RtuTransport::close` on the inner transport.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
pub struct TransportFactory;
impl TransportFactory {
pub async fn create(config: TransportConfig) -> io::Result<Box<dyn RtuTransport>> {
match config {
TransportConfig::Channel(cfg) => {
let (transport, _peer) = ChannelTransport::pair(cfg);
Ok(Box::new(transport))
}
TransportConfig::VirtualSerial(cfg) => {
#[cfg(unix)]
{
let serial = VirtualSerial::create(cfg.clone())
.map_err(|error| io::Error::new(io::ErrorKind::Other, error.to_string()))?;
let io = serial.into_async_io()?;
Ok(Box::new(VirtualSerialTransport::new(io, cfg)))
}
#[cfg(not(unix))]
{
let _ = cfg;
Err(io::Error::new(
io::ErrorKind::Unsupported,
"virtual serial transport is not available on this platform",
))
}
}
TransportConfig::TcpBridge(cfg) => {
let transport = TcpBridgeTransport::bind(cfg).await?;
Ok(Box::new(transport))
}
}
}
}
/// Counters describing traffic and errors, updated through the `record_*` methods.
#[derive(Debug, Clone, Default)]
pub struct TransportMetrics {
    /// Sum of byte counts passed to `record_bytes_received`.
    pub bytes_received: u64,
    /// Sum of byte counts passed to `record_bytes_sent`.
    pub bytes_sent: u64,
    /// Number of `record_frame_received` calls.
    pub frames_received: u64,
    /// Number of `record_frame_sent` calls.
    pub frames_sent: u64,
    /// Number of `record_crc_error` calls.
    pub crc_errors: u64,
    /// Number of `record_framing_error` calls.
    pub framing_errors: u64,
    /// Number of `record_timeout` calls.
    pub timeouts: u64,
}

impl TransportMetrics {
    /// Creates a metrics set with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds `amount` to `counter`.
    fn bump(counter: &mut u64, amount: u64) {
        *counter += amount;
    }

    pub fn record_bytes_received(&mut self, bytes: usize) {
        Self::bump(&mut self.bytes_received, bytes as u64);
    }

    pub fn record_bytes_sent(&mut self, bytes: usize) {
        Self::bump(&mut self.bytes_sent, bytes as u64);
    }

    pub fn record_frame_received(&mut self) {
        Self::bump(&mut self.frames_received, 1);
    }

    pub fn record_frame_sent(&mut self) {
        Self::bump(&mut self.frames_sent, 1);
    }

    pub fn record_crc_error(&mut self) {
        Self::bump(&mut self.crc_errors, 1);
    }

    pub fn record_framing_error(&mut self) {
        Self::bump(&mut self.framing_errors, 1);
    }

    pub fn record_timeout(&mut self) {
        Self::bump(&mut self.timeouts, 1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes written on one endpoint of a pair arrive unchanged at the other.
    #[tokio::test]
    async fn test_channel_transport_pair() {
        let (mut left, mut right) = ChannelTransport::pair(ChannelConfig::default());
        let message = b"Hello, RTU!";

        left.write(message).await.unwrap();

        let mut inbox = [0u8; 32];
        let count = right.read(&mut inbox).await.unwrap();
        assert_eq!(&inbox[..count], message);
    }

    /// A request/response exchange works in both directions of a pair.
    #[tokio::test]
    async fn test_channel_transport_bidirectional() {
        let (mut server, mut client) = ChannelTransport::pair(ChannelConfig::default());
        let mut inbox = [0u8; 32];

        let request: [u8; 6] = [0x01, 0x03, 0x00, 0x00, 0x00, 0x0A];
        client.write(&request).await.unwrap();
        let count = server.read(&mut inbox).await.unwrap();
        assert_eq!(&inbox[..count], &request);

        let response: [u8; 3] = [0x01, 0x03, 0x14];
        server.write(&response).await.unwrap();
        let count = client.read(&mut inbox).await.unwrap();
        assert_eq!(&inbox[..count], &response);
    }

    /// Each `record_*` call updates only its own counter.
    #[test]
    fn test_transport_metrics() {
        let mut metrics = TransportMetrics::new();
        metrics.record_bytes_received(100);
        metrics.record_bytes_sent(50);
        metrics.record_frame_received();
        metrics.record_crc_error();

        assert_eq!(
            (
                metrics.bytes_received,
                metrics.bytes_sent,
                metrics.frames_received,
                metrics.crc_errors,
            ),
            (100, 50, 1, 1)
        );
    }

    /// The default TCP bridge settings match the documented values.
    #[test]
    fn test_tcp_bridge_config_default() {
        let TcpBridgeConfig {
            bind_address,
            max_connections,
            raw_mode,
            ..
        } = TcpBridgeConfig::default();
        assert_eq!(bind_address.port(), 5020);
        assert_eq!(max_connections, 10);
        assert!(raw_mode);
    }
}