// rs_modbus/master_session.rs
//! `MasterSession` — owns the in-flight "awaiting response" slots of a
//! `ModbusMaster`. Slots are keyed by [`WaiterKey`]: TCP requests are keyed
//! by their transaction ID (TID), while RTU/ASCII requests (serialized FIFO)
//! share the single [`WaiterKey::Fifo`] slot because they carry no TID to
//! disambiguate by.
//!
//! Mirrors the njs-modbus `MasterSession` (FIFO + TID-validation design).
//! The master pushes framing events into [`MasterSession::handle_frame`];
//! the session looks up the keyed waiter, runs its pre-check chain, and
//! either resolves the awaiting receiver with the frame or rejects it with
//! a [`ModbusError`].
11
12use crate::error::ModbusError;
13use crate::layers::application::Framing;
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16use tokio::sync::oneshot;
17
18/// Outcome of a single `PreCheck` evaluation against an incoming [`Framing`].
19///
20/// Mirrors the return values of `master-session.ts`'s `preCheck` functions
21/// (`undefined | number | boolean`).
22#[derive(Clone, Debug)]
23pub enum PreCheckOutcome {
24    /// This check accepts the frame; move on to the next check.
25    Pass,
26    /// The frame's `data` must be exactly this many bytes:
27    /// - `data.len() < n`  → rejected as `InsufficientData`
28    /// - `data.len() != n` → rejected as `InvalidResponse`
29    /// - `data.len() == n` → passes
30    NeedLength(usize),
31    /// Reject with the given error and stop pre-checking.
32    Fail(ModbusError),
33    /// Equivalent to njs `undefined`: the check can't decide yet; treated
34    /// as `InsufficientData`.
35    InsufficientData,
36}
37
38/// A single pre-check predicate applied to a [`Framing`].
39pub type PreCheck = Arc<dyn Fn(&Framing) -> PreCheckOutcome + Send + Sync>;
40
/// Routing key that pairs an incoming response with the request awaiting it.
///
/// - [`WaiterKey::Tid`] — used for TCP. Every request gets its own
///   transaction ID, which the slave echoes back, so several pipelined
///   requests can be in flight at once and each response still finds the
///   requester it belongs to.
/// - [`WaiterKey::Fifo`] — used for RTU and ASCII, whose frames carry no
///   transaction ID. Requests on those transports are serialized, so at most
///   one waiter is outstanding at a time and one shared key is enough.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaiterKey {
    /// Transaction ID of a TCP request.
    Tid(u16),
    /// Single shared slot for transports without transaction IDs.
    Fifo,
}
55
56struct WaitingState {
57    pre_check: Vec<PreCheck>,
58    sender: oneshot::Sender<Result<Framing, ModbusError>>,
59}
60
61/// Owns the "awaiting response" slot(s) for a master. The master calls
62/// [`MasterSession::start`] when it sends a request, and [`MasterSession::stop`]
63/// when the response arrives or times out. [`MasterSession::handle_frame`] /
64/// [`MasterSession::handle_error`] are called by the master's framing /
65/// framing_error subscription tasks.
66pub struct MasterSession {
67    waiters: Mutex<HashMap<WaiterKey, WaitingState>>,
68}
69
70impl MasterSession {
71    pub fn new() -> Self {
72        Self {
73            waiters: Mutex::new(HashMap::new()),
74        }
75    }
76
77    /// Arm a waiter under `key`. Returns a receiver that resolves with
78    /// either the first matching `Framing` or a rejection reason.
79    ///
80    /// If `key` already has a waiter (e.g. TID wrap collision), the
81    /// previous waiter's receiver is dropped — equivalent to the caller
82    /// calling [`MasterSession::stop`] first.
83    pub fn start(
84        &self,
85        key: WaiterKey,
86        pre_check: Vec<PreCheck>,
87    ) -> oneshot::Receiver<Result<Framing, ModbusError>> {
88        let (tx, rx) = oneshot::channel();
89        let mut guard = self.waiters.lock().unwrap();
90        guard.insert(
91            key,
92            WaitingState {
93                pre_check,
94                sender: tx,
95            },
96        );
97        rx
98    }
99
100    /// Drop the waiter under `key` without notifying it. Used on timeout,
101    /// where the caller has already given up on the receiver.
102    pub fn stop(&self, key: WaiterKey) {
103        self.waiters.lock().unwrap().remove(&key);
104    }
105
106    /// Reject every armed waiter with `err`. Used by `handle_error`
107    /// (framing errors lose transaction context) and on master
108    /// close/destroy.
109    pub fn stop_all(&self, err: ModbusError) {
110        let drained: Vec<WaitingState> = {
111            let mut guard = self.waiters.lock().unwrap();
112            guard.drain().map(|(_, v)| v).collect()
113        };
114        for w in drained {
115            let _ = w.sender.send(Err(err.clone()));
116        }
117    }
118
119    /// True if a waiter is currently armed under `key`.
120    pub fn has(&self, key: WaiterKey) -> bool {
121        self.waiters.lock().unwrap().contains_key(&key)
122    }
123
124    /// Push a successfully framed PDU at the session. Looks up the waiter
125    /// keyed by `frame.adu.transaction` (TCP) or `WaiterKey::Fifo`
126    /// (RTU/ASCII). If found, removes it and runs the pre-checks; on the
127    /// first failing check, rejects with the corresponding error. On all
128    /// checks passing, resolves with the frame. No-op if no waiter matches.
129    pub fn handle_frame(&self, frame: Framing) {
130        let key = match frame.adu.transaction {
131            Some(tid) => WaiterKey::Tid(tid),
132            None => WaiterKey::Fifo,
133        };
134        let state = {
135            let mut guard = self.waiters.lock().unwrap();
136            guard.remove(&key)
137        };
138        let Some(state) = state else { return };
139        match run_pre_checks(&frame, &state.pre_check) {
140            CheckResult::Pass => {
141                let _ = state.sender.send(Ok(frame));
142            }
143            CheckResult::Reject(err) => {
144                let _ = state.sender.send(Err(err));
145            }
146        }
147    }
148
149    /// Push a framing error at the session. Framing errors arrive without
150    /// transaction context (CRC/LRC failure, bogus MBAP header, etc.), so
151    /// every in-flight waiter is rejected.
152    pub fn handle_error(&self, err: ModbusError) {
153        self.stop_all(err);
154    }
155}
156
157impl Default for MasterSession {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
163enum CheckResult {
164    Pass,
165    Reject(ModbusError),
166}
167
168fn run_pre_checks(frame: &Framing, checks: &[PreCheck]) -> CheckResult {
169    for check in checks {
170        match check(frame) {
171            PreCheckOutcome::Pass => continue,
172            PreCheckOutcome::NeedLength(n) => {
173                if frame.adu.data.len() < n {
174                    return CheckResult::Reject(ModbusError::InsufficientData);
175                }
176                if frame.adu.data.len() != n {
177                    return CheckResult::Reject(ModbusError::InvalidResponse);
178                }
179                // exact length — continue
180            }
181            PreCheckOutcome::Fail(err) => return CheckResult::Reject(err),
182            PreCheckOutcome::InsufficientData => {
183                return CheckResult::Reject(ModbusError::InsufficientData);
184            }
185        }
186    }
187    CheckResult::Pass
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::physical::{ConnectionId, ResponseFn};
    use crate::types::ApplicationDataUnit;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Build a frame without a transaction ID (routes to `WaiterKey::Fifo`).
    fn fake_framing(unit: u8, fc: u8, data: Vec<u8>) -> Framing {
        let response: ResponseFn = Arc::new(|_| Box::pin(async { Ok(()) }));
        let connection: ConnectionId = Arc::from("test");
        Framing {
            adu: ApplicationDataUnit::new(unit, fc, data.clone()),
            raw: data,
            response,
            connection,
        }
    }

    /// Build a TCP-style frame carrying transaction ID `tid`.
    fn fake_framing_with_tid(unit: u8, fc: u8, data: Vec<u8>, tid: u16) -> Framing {
        let mut f = fake_framing(unit, fc, data);
        f.adu.transaction = Some(tid);
        f
    }

    fn always_pass() -> PreCheck {
        Arc::new(|_| PreCheckOutcome::Pass)
    }

    #[tokio::test]
    async fn test_fifo_waiter_resolves_on_matching_frame() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.handle_frame(fake_framing(1, 0x03, vec![0x01]));
        let result = rx.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().adu.unit, 1);
    }

    #[tokio::test]
    async fn test_handle_frame_with_no_waiter_is_noop() {
        let session = MasterSession::new();
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.handle_frame(fake_framing(2, 0x04, vec![]));
        let resolved = rx.await.unwrap().unwrap();
        assert_eq!(resolved.adu.unit, 2);
    }

    #[tokio::test]
    async fn test_handle_error_rejects_every_waiter() {
        let session = MasterSession::new();
        let rx_fifo = session.start(WaiterKey::Fifo, vec![always_pass()]);
        let rx_tid = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        session.handle_error(ModbusError::Timeout);
        assert!(matches!(rx_fifo.await.unwrap(), Err(ModbusError::Timeout)));
        assert!(matches!(rx_tid.await.unwrap(), Err(ModbusError::Timeout)));
    }

    #[tokio::test]
    async fn test_stop_drops_waiter_silently() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.stop(WaiterKey::Fifo);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(rx.await.is_err()); // sender dropped → RecvError
    }

    #[tokio::test]
    async fn test_stop_all_rejects_each_waiter_independently() {
        let session = MasterSession::new();
        let rx_a = session.start(WaiterKey::Tid(1), vec![always_pass()]);
        let rx_b = session.start(WaiterKey::Tid(2), vec![always_pass()]);
        session.stop_all(ModbusError::InvalidState("Master closed".into()));
        assert!(matches!(
            rx_a.await.unwrap(),
            Err(ModbusError::InvalidState(ref s)) if s == "Master closed"
        ));
        assert!(matches!(
            rx_b.await.unwrap(),
            Err(ModbusError::InvalidState(ref s)) if s == "Master closed"
        ));
    }

    #[tokio::test]
    async fn test_has_returns_correct_state() {
        let session = MasterSession::new();
        assert!(!session.has(WaiterKey::Fifo));
        let _rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        assert!(session.has(WaiterKey::Fifo));
        assert!(!session.has(WaiterKey::Tid(0)));
        session.stop(WaiterKey::Fifo);
        assert!(!session.has(WaiterKey::Fifo));
    }

    #[tokio::test]
    async fn test_tid_routing_isolates_independent_waiters() {
        let session = MasterSession::new();
        let rx_tid7 = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        let rx_tid8 = session.start(WaiterKey::Tid(8), vec![always_pass()]);

        // Push tid=8 first — only the tid8 waiter should resolve.
        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 8));
        let resolved8 = rx_tid8.await.unwrap().unwrap();
        assert_eq!(resolved8.adu.transaction, Some(8));

        // tid=7 still pending.
        assert!(session.has(WaiterKey::Tid(7)));

        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 7));
        let resolved7 = rx_tid7.await.unwrap().unwrap();
        assert_eq!(resolved7.adu.transaction, Some(7));
    }

    #[tokio::test]
    async fn test_fifo_frame_does_not_resolve_tid_waiter() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        // RTU-style frame without TID lands on the FIFO slot — not tid=7.
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(session.has(WaiterKey::Tid(7)));
        session.stop(WaiterKey::Tid(7));
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_start_replaces_existing_waiter_on_same_key() {
        let session = MasterSession::new();
        let rx_old = session.start(WaiterKey::Tid(1), vec![always_pass()]);
        let rx_new = session.start(WaiterKey::Tid(1), vec![always_pass()]);
        // The displaced waiter's sender is dropped → RecvError.
        assert!(rx_old.await.is_err());
        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 1));
        let resolved = rx_new.await.unwrap().unwrap();
        assert_eq!(resolved.adu.transaction, Some(1));
    }

    #[tokio::test]
    async fn test_empty_pre_check_chain_passes() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Fifo, Vec::new());
        session.handle_frame(fake_framing(3, 0x03, vec![0xAA]));
        assert_eq!(rx.await.unwrap().unwrap().adu.unit, 3);
    }

    #[tokio::test]
    async fn test_pre_checks_stop_at_first_rejection() {
        let session = MasterSession::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let counted = Arc::clone(&calls);
        let fail: PreCheck =
            Arc::new(|_: &Framing| PreCheckOutcome::Fail(ModbusError::IllegalDataAddress));
        let later: PreCheck = Arc::new(move |_: &Framing| {
            counted.fetch_add(1, Ordering::SeqCst);
            PreCheckOutcome::Pass
        });
        let rx = session.start(WaiterKey::Fifo, vec![fail, later]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::IllegalDataAddress)
        ));
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_rejected_frame_still_consumes_waiter() {
        let session = MasterSession::new();
        let fail: PreCheck =
            Arc::new(|_: &Framing| PreCheckOutcome::Fail(ModbusError::InvalidResponse));
        let rx = session.start(WaiterKey::Fifo, vec![fail]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(!session.has(WaiterKey::Fifo));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InvalidResponse)
        ));
    }

    #[tokio::test]
    async fn test_pre_check_fail_returns_error() {
        let session = MasterSession::new();
        let fail: PreCheck = Arc::new(|_| PreCheckOutcome::Fail(ModbusError::IllegalDataAddress));
        let rx = session.start(WaiterKey::Fifo, vec![fail]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::IllegalDataAddress)
        ));
    }

    #[tokio::test]
    async fn test_pre_check_insufficient_data_returns_error() {
        let session = MasterSession::new();
        let insuff: PreCheck = Arc::new(|_| PreCheckOutcome::InsufficientData);
        let rx = session.start(WaiterKey::Fifo, vec![insuff]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InsufficientData)
        ));
    }

    #[tokio::test]
    async fn test_need_length_exact_passes() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(3));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3]));
        assert!(rx.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_need_length_too_short_rejects_insufficient() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(5));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InsufficientData)
        ));
    }

    #[tokio::test]
    async fn test_need_length_too_long_rejects_invalid_response() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(2));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3, 4]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InvalidResponse)
        ));
    }
}