use crate::error::ModbusError;
use crate::layers::application::Framing;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::oneshot;
/// Verdict returned by a [`PreCheck`] for a response frame routed to a waiter.
#[derive(Debug, Clone)]
pub enum PreCheckOutcome {
    /// The frame is acceptable; evaluation moves on to the next check.
    Pass,
    /// The frame's `adu.data` must be exactly this many bytes long. Shorter data is rejected
    /// with `ModbusError::InsufficientData`, longer data with `ModbusError::InvalidResponse`.
    NeedLength(usize),
    /// Reject the frame with the carried error.
    Fail(ModbusError),
    /// Reject the frame with `ModbusError::InsufficientData`.
    InsufficientData,
}

/// Shared, `Send + Sync` validator applied to a response frame before it is handed to the
/// waiter that registered it.
pub type PreCheck = Arc<dyn Fn(&Framing) -> PreCheckOutcome + Sync + Send>;
/// Routing key pairing an incoming response frame with the request waiting for it.
///
/// A frame whose `adu.transaction` is `Some(id)` is routed to `Tid(id)`; a frame without a
/// transaction id is routed to `Fifo`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaiterKey {
    /// Waiter keyed by transaction identifier.
    Tid(u16),
    /// Waiter for frames that carry no transaction identifier.
    Fifo,
}
/// A pending request: the checks its response must pass and the channel it is delivered on.
struct WaitingState {
    // Evaluated in order by `run_pre_checks` against the frame routed to this waiter.
    pre_check: Vec<PreCheck>,
    // One-shot reply channel; consumed when the waiter is resolved or rejected.
    sender: oneshot::Sender<Result<Framing, ModbusError>>,
}
/// Routes response frames and errors to the requests awaiting them.
///
/// At most one waiter is registered per [`WaiterKey`]. The table sits behind a `Mutex`, so
/// every method takes `&self`.
pub struct MasterSession {
    waiters: Mutex<HashMap<WaiterKey, WaitingState>>,
}
impl MasterSession {
    /// Creates a session with no registered waiters.
    pub fn new() -> Self {
        Self {
            waiters: Mutex::new(HashMap::new()),
        }
    }

    /// Locks the waiter table, recovering the guard if the mutex is poisoned.
    ///
    /// Each critical section in this impl performs one `HashMap` operation, and waiters are
    /// only dropped after the guard is released. A panic elsewhere therefore cannot leave the
    /// table unusable, and recovering avoids turning one panic into a panic on every later call.
    fn lock_waiters(&self) -> std::sync::MutexGuard<'_, HashMap<WaiterKey, WaitingState>> {
        self.waiters
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers a waiter under `key` and returns the receiver its outcome is delivered on.
    ///
    /// The receiver yields one of:
    /// * `Ok(frame)` from [`handle_frame`](Self::handle_frame) when every `pre_check` passes;
    /// * `Err(e)` with the first rejecting pre-check's error;
    /// * `Err(err)` from [`stop_all`](Self::stop_all) / [`handle_error`](Self::handle_error).
    ///
    /// If a waiter is already registered under `key`, it is replaced. Its sender is dropped
    /// without a value, so its receiver resolves with a receive error.
    pub fn start(
        &self,
        key: WaiterKey,
        pre_check: Vec<PreCheck>,
    ) -> oneshot::Receiver<Result<Framing, ModbusError>> {
        let (sender, receiver) = oneshot::channel();
        let evicted = self
            .lock_waiters()
            .insert(key, WaitingState { pre_check, sender });
        // Drop any replaced waiter only after the guard above is released: it owns
        // caller-supplied closures whose destructors must not run inside the critical section.
        drop(evicted);
        receiver
    }

    /// Removes the waiter under `key` without sending anything.
    ///
    /// The waiter's receiver then resolves with a receive error. This is a no-op when no
    /// waiter is registered under `key`.
    pub fn stop(&self, key: WaiterKey) {
        let removed = self.lock_waiters().remove(&key);
        // Dropped outside the lock for the same reason as in `start`.
        drop(removed);
    }

    /// Removes every waiter and rejects each one with a clone of `err`.
    pub fn stop_all(&self, err: ModbusError) {
        // Take everything out under the lock, then notify with the lock released.
        let drained: Vec<WaitingState> = self.lock_waiters().drain().map(|(_, w)| w).collect();
        for waiter in drained {
            // A closed receiver means the requester stopped waiting; nothing to report.
            let _ = waiter.sender.send(Err(err.clone()));
        }
    }

    /// Returns whether a waiter is currently registered under `key`.
    pub fn has(&self, key: WaiterKey) -> bool {
        self.lock_waiters().contains_key(&key)
    }

    /// Delivers `frame` to the waiter it is addressed to, consuming that waiter.
    ///
    /// The waiter is looked up under `WaiterKey::Tid(id)` when `frame.adu.transaction` is
    /// `Some(id)`, and under `WaiterKey::Fifo` otherwise. Frames with no matching waiter are
    /// discarded. The waiter receives `Ok(frame)` if all its pre-checks pass, otherwise the
    /// first rejection error.
    pub fn handle_frame(&self, frame: Framing) {
        let key = frame.adu.transaction.map_or(WaiterKey::Fifo, WaiterKey::Tid);
        // Release the lock before running pre-checks so user closures never run under it.
        let state = self.lock_waiters().remove(&key);
        let Some(state) = state else { return };
        let outcome = match run_pre_checks(&frame, &state.pre_check) {
            CheckResult::Pass => Ok(frame),
            CheckResult::Reject(err) => Err(err),
        };
        // The requester may have dropped its receiver; that is not an error here.
        let _ = state.sender.send(outcome);
    }

    /// Rejects every pending waiter with `err`; identical to [`stop_all`](Self::stop_all).
    pub fn handle_error(&self, err: ModbusError) {
        self.stop_all(err);
    }
}
impl Default for MasterSession {
fn default() -> Self {
Self::new()
}
}
/// Combined verdict of [`run_pre_checks`]: every check accepted the frame, or the error of
/// the first check that rejected it.
enum CheckResult {
    Pass,
    Reject(ModbusError),
}

/// Evaluates `checks` against `frame` in order, stopping at the first rejection.
///
/// * `Pass` moves on to the next check.
/// * `NeedLength(n)` requires `frame.adu.data.len() == n`. Shorter data rejects with
///   `ModbusError::InsufficientData`, longer data with `ModbusError::InvalidResponse`.
/// * `Fail(err)` rejects with `err`.
/// * `InsufficientData` rejects with `ModbusError::InsufficientData`.
///
/// Checks after the first rejection are not called. An empty `checks` slice accepts every
/// frame.
fn run_pre_checks(frame: &Framing, checks: &[PreCheck]) -> CheckResult {
    let first_rejection = checks.iter().find_map(|check| match check(frame) {
        PreCheckOutcome::Pass => None,
        PreCheckOutcome::NeedLength(expected) if frame.adu.data.len() < expected => {
            Some(ModbusError::InsufficientData)
        }
        PreCheckOutcome::NeedLength(expected) if frame.adu.data.len() != expected => {
            Some(ModbusError::InvalidResponse)
        }
        PreCheckOutcome::NeedLength(_) => None,
        PreCheckOutcome::Fail(err) => Some(err),
        PreCheckOutcome::InsufficientData => Some(ModbusError::InsufficientData),
    });
    first_rejection.map_or(CheckResult::Pass, CheckResult::Reject)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::physical::{ConnectionId, ResponseFn};
    use crate::types::ApplicationDataUnit;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Builds a frame for `unit`/`fc` carrying `data`. The FIFO tests below rely on
    /// `ApplicationDataUnit::new` leaving `transaction` unset.
    fn fake_framing(unit: u8, fc: u8, data: Vec<u8>) -> Framing {
        let response: ResponseFn = Arc::new(|_| Box::pin(async { Ok(()) }));
        let connection: ConnectionId = Arc::from("test");
        Framing {
            adu: ApplicationDataUnit::new(unit, fc, data.clone()),
            raw: data,
            response,
            connection,
        }
    }

    /// Like `fake_framing`, but tagged with transaction id `tid`.
    fn fake_framing_with_tid(unit: u8, fc: u8, data: Vec<u8>, tid: u16) -> Framing {
        let mut f = fake_framing(unit, fc, data);
        f.adu.transaction = Some(tid);
        f
    }

    fn always_pass() -> PreCheck {
        Arc::new(|_| PreCheckOutcome::Pass)
    }

    #[tokio::test]
    async fn test_fifo_waiter_resolves_on_matching_frame() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.handle_frame(fake_framing(1, 0x03, vec![0x01]));
        let result = rx.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().adu.unit, 1);
    }

    #[tokio::test]
    async fn test_handle_frame_with_no_waiter_is_noop() {
        let session = MasterSession::new();
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.handle_frame(fake_framing(2, 0x04, vec![]));
        let resolved = rx.await.unwrap().unwrap();
        assert_eq!(resolved.adu.unit, 2);
    }

    #[tokio::test]
    async fn test_handle_error_rejects_every_waiter() {
        let session = MasterSession::new();
        let rx_fifo = session.start(WaiterKey::Fifo, vec![always_pass()]);
        let rx_tid = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        session.handle_error(ModbusError::Timeout);
        assert!(matches!(rx_fifo.await.unwrap(), Err(ModbusError::Timeout)));
        assert!(matches!(rx_tid.await.unwrap(), Err(ModbusError::Timeout)));
    }

    #[tokio::test]
    async fn test_stop_drops_waiter_silently() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.stop(WaiterKey::Fifo);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_stop_all_rejects_each_waiter_independently() {
        let session = MasterSession::new();
        let rx_a = session.start(WaiterKey::Tid(1), vec![always_pass()]);
        let rx_b = session.start(WaiterKey::Tid(2), vec![always_pass()]);
        session.stop_all(ModbusError::InvalidState("Master closed".into()));
        assert!(matches!(
            rx_a.await.unwrap(),
            Err(ModbusError::InvalidState(ref s)) if s == "Master closed"
        ));
        assert!(matches!(
            rx_b.await.unwrap(),
            Err(ModbusError::InvalidState(ref s)) if s == "Master closed"
        ));
    }

    #[tokio::test]
    async fn test_has_returns_correct_state() {
        let session = MasterSession::new();
        assert!(!session.has(WaiterKey::Fifo));
        let _rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        assert!(session.has(WaiterKey::Fifo));
        assert!(!session.has(WaiterKey::Tid(0)));
        session.stop(WaiterKey::Fifo);
        assert!(!session.has(WaiterKey::Fifo));
    }

    #[tokio::test]
    async fn test_tid_routing_isolates_independent_waiters() {
        let session = MasterSession::new();
        let rx_tid7 = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        let rx_tid8 = session.start(WaiterKey::Tid(8), vec![always_pass()]);
        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 8));
        let resolved8 = rx_tid8.await.unwrap().unwrap();
        assert_eq!(resolved8.adu.transaction, Some(8));
        assert!(session.has(WaiterKey::Tid(7)));
        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 7));
        let resolved7 = rx_tid7.await.unwrap().unwrap();
        assert_eq!(resolved7.adu.transaction, Some(7));
    }

    #[tokio::test]
    async fn test_fifo_frame_does_not_resolve_tid_waiter() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(session.has(WaiterKey::Tid(7)));
        session.stop(WaiterKey::Tid(7));
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_start_replaces_existing_waiter_for_same_key() {
        let session = MasterSession::new();
        let rx_old = session.start(WaiterKey::Tid(3), vec![always_pass()]);
        let rx_new = session.start(WaiterKey::Tid(3), vec![always_pass()]);
        // The replaced waiter's sender is dropped, so its receiver closes without a value.
        assert!(rx_old.await.is_err());
        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 3));
        let resolved = rx_new.await.unwrap().unwrap();
        assert_eq!(resolved.adu.transaction, Some(3));
    }

    #[tokio::test]
    async fn test_pre_check_fail_returns_error() {
        let session = MasterSession::new();
        let fail: PreCheck = Arc::new(|_| PreCheckOutcome::Fail(ModbusError::IllegalDataAddress));
        let rx = session.start(WaiterKey::Fifo, vec![fail]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::IllegalDataAddress)
        ));
    }

    #[tokio::test]
    async fn test_pre_checks_stop_at_first_rejection() {
        let session = MasterSession::new();
        let later_ran = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&later_ran);
        let fail: PreCheck = Arc::new(|_| PreCheckOutcome::Fail(ModbusError::InvalidResponse));
        let later: PreCheck = Arc::new(move |_| {
            flag.store(true, Ordering::SeqCst);
            PreCheckOutcome::Pass
        });
        let rx = session.start(WaiterKey::Fifo, vec![fail, later]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InvalidResponse)
        ));
        assert!(!later_ran.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_pre_check_insufficient_data_returns_error() {
        let session = MasterSession::new();
        let insuff: PreCheck = Arc::new(|_| PreCheckOutcome::InsufficientData);
        let rx = session.start(WaiterKey::Fifo, vec![insuff]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InsufficientData)
        ));
    }

    #[tokio::test]
    async fn test_need_length_exact_passes() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(3));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3]));
        assert!(rx.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_need_length_match_continues_to_next_check() {
        let session = MasterSession::new();
        let len: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(2));
        let fail: PreCheck = Arc::new(|_| PreCheckOutcome::Fail(ModbusError::IllegalDataAddress));
        let rx = session.start(WaiterKey::Fifo, vec![len, fail]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::IllegalDataAddress)
        ));
    }

    #[tokio::test]
    async fn test_need_length_too_short_rejects_insufficient() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(5));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InsufficientData)
        ));
    }

    #[tokio::test]
    async fn test_need_length_too_long_rejects_invalid_response() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(2));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3, 4]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InvalidResponse)
        ));
    }
}