rs-modbus 2.0.0

A pure Rust implementation of the MODBUS protocol.
Documentation
use crate::error::ModbusError;
use crate::layers::application::{ApplicationLayer, ApplicationProtocol, ApplicationRole, Framing};
use crate::layers::physical::{ConnectionId, PhysicalLayer, ResponseFn};
use crate::types::{ApplicationDataUnit, FramedDataUnit};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
use std::sync::{Arc, Mutex};
use tokio::sync::broadcast;
use tokio::task::JoinHandle;

/// Largest complete frame, in bytes, that the stream parser accepts.
///
/// `try_extract_frame` compares this against `6 + <length field>`; 260 matches the
/// 7-byte MBAP header plus the 253-byte maximum PDU of the MODBUS specification.
const MAX_TCP_FRAME: usize = 260;

/// MODBUS TCP application layer: frames outgoing ADUs with an MBAP header and
/// reassembles incoming byte streams into frames, one receive buffer per connection.
pub struct TcpApplicationLayer {
    /// Role assigned through `set_role`; `None` until one has been set.
    role: Mutex<Option<ApplicationRole>>,
    /// Counter that `encode` draws from when an ADU carries no transaction id.
    transaction_id: AtomicU16,
    /// Broadcast channel on which every successfully extracted frame is published.
    framing_tx: broadcast::Sender<Framing>,
    /// Broadcast channel on which framing failures are published.
    framing_error_tx: broadcast::Sender<ModbusError>,
    /// Bytes received but not yet forming a complete frame, keyed by connection.
    /// Shared with the background tasks spawned in `new`.
    buffers: Arc<Mutex<HashMap<ConnectionId, Vec<u8>>>>,
    /// Handles of the background tasks spawned in `new`; aborted by `destroy`.
    tasks: Mutex<Vec<JoinHandle<()>>>,
    /// Set by the first `destroy` call so that later calls return immediately.
    destroyed: AtomicBool,
}

impl TcpApplicationLayer {
    /// Builds a MODBUS TCP application layer that listens to `physical`.
    ///
    /// Two background tasks are spawned and their handles kept for `destroy`:
    ///
    /// * a data task that feeds every received chunk into `process_data_event`,
    ///   which reassembles frames per connection and publishes them;
    /// * a close task that drops the receive buffer of a connection once the
    ///   physical layer reports it closed.
    ///
    /// Both tasks skip over `Lagged` notifications and stop when the corresponding
    /// physical-layer channel is closed.
    ///
    /// # Panics
    ///
    /// Calls `tokio::spawn`, so it must be invoked from within a Tokio runtime.
    pub fn new<P: PhysicalLayer + 'static>(physical: Arc<P>) -> Arc<Self> {
        let (framing_tx, _) = broadcast::channel(64);
        let (framing_error_tx, _) = broadcast::channel(64);
        let buffers: Arc<Mutex<HashMap<ConnectionId, Vec<u8>>>> = Arc::default();

        // Data path: raw bytes in, framed ADUs (or framing errors) out.
        let mut data_rx = physical.subscribe_data();
        let task_buffers = Arc::clone(&buffers);
        let task_framing_tx = framing_tx.clone();
        let task_error_tx = framing_error_tx.clone();
        let data_task = tokio::spawn(async move {
            loop {
                let event = match data_rx.recv().await {
                    Ok(event) => event,
                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
                    Err(broadcast::error::RecvError::Closed) => break,
                };
                process_data_event(
                    &task_buffers,
                    &task_framing_tx,
                    &task_error_tx,
                    event.data,
                    event.response,
                    event.connection,
                );
            }
        });

        // Close path: forget partial data of connections that went away.
        let mut close_rx = physical.subscribe_connection_close();
        let close_buffers = Arc::clone(&buffers);
        let close_task = tokio::spawn(async move {
            loop {
                let closed = match close_rx.recv().await {
                    Ok(conn_id) => conn_id,
                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
                    Err(broadcast::error::RecvError::Closed) => break,
                };
                close_buffers.lock().unwrap().remove(&closed);
            }
        });

        Arc::new(Self {
            role: Mutex::new(None),
            transaction_id: AtomicU16::new(0),
            framing_tx,
            framing_error_tx,
            buffers,
            tasks: Mutex::new(vec![data_task, close_task]),
            destroyed: AtomicBool::new(false),
        })
    }
}

/// Feeds one chunk of received bytes into the reassembly buffer of `connection`.
///
/// The chunk is appended to the connection's buffer, then every complete frame
/// found at the front of the buffer is removed and published on `framing_tx`
/// together with the `response` callback and connection it arrived on.
///
/// If the front of the buffer has an unacceptable header, the whole buffer is
/// discarded and `ModbusError::InvalidData` is published on `framing_error_tx`.
/// A buffer that ends up empty is removed from the map.
///
/// Send failures on either channel (no subscribers) are deliberately ignored.
/// The buffer map stays locked for the duration of the call.
fn process_data_event(
    buffers: &Arc<Mutex<HashMap<ConnectionId, Vec<u8>>>>,
    framing_tx: &broadcast::Sender<Framing>,
    framing_error_tx: &broadcast::Sender<ModbusError>,
    data: Vec<u8>,
    response: ResponseFn,
    connection: ConnectionId,
) {
    let mut pending = buffers.lock().unwrap();
    let stream = pending.entry(Arc::clone(&connection)).or_default();
    stream.extend_from_slice(&data);

    loop {
        // Either obtain the length of the next whole frame or stop scanning.
        let frame_len = match try_extract_frame(stream) {
            ExtractResult::Frame(len) => len,
            ExtractResult::Insufficient => break,
            ExtractResult::Invalid => {
                let _ = framing_error_tx.send(ModbusError::InvalidData);
                stream.clear();
                break;
            }
        };

        let raw: Vec<u8> = stream.drain(..frame_len).collect();
        // MBAP layout: [0..2] transaction id, [6] unit id, then function code and payload.
        let adu = ApplicationDataUnit {
            transaction: Some(u16::from_be_bytes([raw[0], raw[1]])),
            unit: raw[6],
            fc: raw[7],
            data: raw[8..].to_vec(),
        };
        let _ = framing_tx.send(Framing {
            adu,
            raw,
            response: Arc::clone(&response),
            connection: Arc::clone(&connection),
        });
    }

    // Do not keep empty per-connection entries around.
    let fully_consumed = stream.is_empty();
    if fully_consumed {
        pending.remove(&connection);
    }
}

/// Outcome of probing the front of a receive buffer for one complete frame.
enum ExtractResult {
    /// A whole frame is available; the value is its total length in bytes, header included.
    Frame(usize),
    /// Not enough bytes have arrived yet to extract a frame.
    Insufficient,
    /// The header is unacceptable (non-zero protocol id or out-of-range length field).
    Invalid,
}

/// Checks whether `buffer` starts with one complete MODBUS TCP frame.
///
/// Returns:
/// * `Insufficient` when fewer than 8 bytes are buffered, or the header is fine but
///   the frame it announces has not fully arrived;
/// * `Invalid` when the protocol id (bytes 2..4) is non-zero, the length field
///   (bytes 4..6) is below 2, or `6 + length` exceeds `MAX_TCP_FRAME`;
/// * `Frame(total)` with `total = 6 + length` otherwise.
fn try_extract_frame(buffer: &[u8]) -> ExtractResult {
    // At least the MBAP header plus a function code is needed before judging anything.
    let (protocol_id, length_field) = match buffer {
        [_, _, proto_hi, proto_lo, len_hi, len_lo, _, _, ..] => (
            u16::from_be_bytes([*proto_hi, *proto_lo]),
            usize::from(u16::from_be_bytes([*len_hi, *len_lo])),
        ),
        _ => return ExtractResult::Insufficient,
    };

    if protocol_id != 0 {
        return ExtractResult::Invalid;
    }

    // The length field counts everything after itself: unit id, function code, payload.
    let frame_len = length_field + 6;
    if length_field < 2 || frame_len > MAX_TCP_FRAME {
        return ExtractResult::Invalid;
    }

    if buffer.len() >= frame_len {
        ExtractResult::Frame(frame_len)
    } else {
        ExtractResult::Insufficient
    }
}

/// Parses exactly one MODBUS TCP frame from `data`.
///
/// # Errors
///
/// * `ModbusError::InsufficientData` when `data` is shorter than 8 bytes.
/// * `ModbusError::InvalidData` when the protocol id (bytes 2..4) is non-zero or the
///   length field (bytes 4..6) plus 6 does not equal `data.len()`.
fn decode_inner(data: &[u8]) -> Result<ApplicationDataUnit, ModbusError> {
    if data.len() < 8 {
        return Err(ModbusError::InsufficientData);
    }

    // header = MBAP (7 bytes) + function code; payload = everything after it.
    let (header, payload) = data.split_at(8);
    let protocol_id = u16::from_be_bytes([header[2], header[3]]);
    let declared_len = usize::from(u16::from_be_bytes([header[4], header[5]]));

    // The slice must be one frame, no more and no less.
    if protocol_id != 0 || declared_len + 6 != data.len() {
        return Err(ModbusError::InvalidData);
    }

    Ok(ApplicationDataUnit {
        transaction: Some(u16::from_be_bytes([header[0], header[1]])),
        unit: header[6],
        fc: header[7],
        data: payload.to_vec(),
    })
}

#[async_trait::async_trait]
impl ApplicationLayer for TcpApplicationLayer {
    /// Assigns the role of this layer; the actual update (and any validation) is
    /// delegated to the shared `set_role_impl` helper.
    fn set_role(&self, role: ApplicationRole) -> Result<(), ModbusError> {
        crate::layers::application::set_role_impl(&mut self.role.lock().unwrap(), role)
    }

    /// Returns the currently assigned role, or `None` if none has been set.
    fn role(&self) -> Option<ApplicationRole> {
        *self.role.lock().unwrap()
    }

    /// This layer always speaks MODBUS TCP.
    fn protocol(&self) -> ApplicationProtocol {
        ApplicationProtocol::Tcp
    }

    /// Serializes `adu` into a MODBUS TCP frame (MBAP header + function code + data).
    ///
    /// When `adu.transaction` is `None`, the next id is drawn from the internal
    /// counter; ids run 1, 2, ..., 65535, 0, 1, ... Each id is produced by a single
    /// atomic read-modify-write, so concurrent callers never receive the same counter
    /// value and the counter is never moved backwards.
    fn encode(&self, adu: &ApplicationDataUnit) -> Vec<u8> {
        let data_len = adu.data.len();
        let mut buf = vec![0u8; data_len + 8];
        let tx = adu.transaction.unwrap_or_else(|| {
            // `fetch_add` wraps on overflow and returns the previous value; no
            // follow-up store is needed (a separate store could overwrite an
            // increment made by another thread and cause duplicate ids).
            self.transaction_id
                .fetch_add(1, Ordering::Relaxed)
                .wrapping_add(1)
        });
        buf[0..2].copy_from_slice(&tx.to_be_bytes());
        // Protocol identifier: always zero for MODBUS.
        buf[2..4].copy_from_slice(&[0x00, 0x00]);
        // Length counts unit id + function code + data.
        // NOTE(review): `as u16` truncates for payloads above 65533 bytes; callers are
        // presumably expected to stay within MODBUS PDU limits.
        buf[4..6].copy_from_slice(&((data_len + 2) as u16).to_be_bytes());
        buf[6] = adu.unit;
        buf[7] = adu.fc;
        buf[8..].copy_from_slice(&adu.data);
        buf
    }

    /// Decodes `data` as exactly one MODBUS TCP frame and returns it together with a
    /// copy of the raw bytes.
    ///
    /// # Errors
    ///
    /// Propagates the error of `decode_inner`: `InsufficientData` for fewer than
    /// 8 bytes, `InvalidData` for a non-zero protocol id or a length mismatch.
    fn decode(&self, data: &[u8]) -> Result<FramedDataUnit, ModbusError> {
        let adu = decode_inner(data)?;
        Ok(FramedDataUnit {
            adu,
            raw: data.to_vec(),
        })
    }

    /// Discards all partially received data of every connection.
    fn flush(&self) {
        self.buffers.lock().unwrap().clear();
    }

    /// Returns a new receiver for frames extracted from the incoming byte streams.
    fn subscribe_framing(&self) -> broadcast::Receiver<Framing> {
        self.framing_tx.subscribe()
    }

    /// Returns a new receiver for framing errors.
    fn subscribe_framing_error(&self) -> broadcast::Receiver<ModbusError> {
        self.framing_error_tx.subscribe()
    }

    /// Stops the background tasks and drops all buffered data.
    ///
    /// Only the first call has an effect; later calls return immediately.
    async fn destroy(&self) {
        if self.destroyed.swap(true, Ordering::SeqCst) {
            return;
        }
        let mut tasks = self.tasks.lock().unwrap();
        for task in tasks.drain(..) {
            task.abort();
        }
        self.buffers.lock().unwrap().clear();
    }
}