1use crate::error::ModbusError;
13use crate::layers::application::Framing;
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16use tokio::sync::oneshot;
17
18#[derive(Clone, Debug)]
23pub enum PreCheckOutcome {
24 Pass,
26 NeedLength(usize),
31 Fail(ModbusError),
33 InsufficientData,
36}
37
38pub type PreCheck = Arc<dyn Fn(&Framing) -> PreCheckOutcome + Send + Sync>;
40
/// Identifies which pending request an incoming frame answers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaiterKey {
    /// Matches frames whose ADU carries this transaction id.
    Tid(u16),
    /// Matches frames that carry no transaction id. Waiters are stored per key,
    /// so at most one such waiter is pending at a time.
    Fifo,
}
55
56struct WaitingState {
57 pre_check: Vec<PreCheck>,
58 sender: oneshot::Sender<Result<Framing, ModbusError>>,
59}
60
61pub struct MasterSession {
67 waiters: Mutex<HashMap<WaiterKey, WaitingState>>,
68}
69
70impl MasterSession {
71 pub fn new() -> Self {
72 Self {
73 waiters: Mutex::new(HashMap::new()),
74 }
75 }
76
77 pub fn start(
84 &self,
85 key: WaiterKey,
86 pre_check: Vec<PreCheck>,
87 ) -> oneshot::Receiver<Result<Framing, ModbusError>> {
88 let (tx, rx) = oneshot::channel();
89 let mut guard = self.waiters.lock().unwrap();
90 guard.insert(
91 key,
92 WaitingState {
93 pre_check,
94 sender: tx,
95 },
96 );
97 rx
98 }
99
100 pub fn stop(&self, key: WaiterKey) {
103 self.waiters.lock().unwrap().remove(&key);
104 }
105
106 pub fn stop_all(&self, err: ModbusError) {
110 let drained: Vec<WaitingState> = {
111 let mut guard = self.waiters.lock().unwrap();
112 guard.drain().map(|(_, v)| v).collect()
113 };
114 for w in drained {
115 let _ = w.sender.send(Err(err.clone()));
116 }
117 }
118
119 pub fn has(&self, key: WaiterKey) -> bool {
121 self.waiters.lock().unwrap().contains_key(&key)
122 }
123
124 pub fn handle_frame(&self, frame: Framing) {
130 let key = match frame.adu.transaction {
131 Some(tid) => WaiterKey::Tid(tid),
132 None => WaiterKey::Fifo,
133 };
134 let state = {
135 let mut guard = self.waiters.lock().unwrap();
136 guard.remove(&key)
137 };
138 let Some(state) = state else { return };
139 match run_pre_checks(&frame, &state.pre_check) {
140 CheckResult::Pass => {
141 let _ = state.sender.send(Ok(frame));
142 }
143 CheckResult::Reject(err) => {
144 let _ = state.sender.send(Err(err));
145 }
146 }
147 }
148
149 pub fn handle_error(&self, err: ModbusError) {
153 self.stop_all(err);
154 }
155}
156
157impl Default for MasterSession {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
163enum CheckResult {
164 Pass,
165 Reject(ModbusError),
166}
167
168fn run_pre_checks(frame: &Framing, checks: &[PreCheck]) -> CheckResult {
169 for check in checks {
170 match check(frame) {
171 PreCheckOutcome::Pass => continue,
172 PreCheckOutcome::NeedLength(n) => {
173 if frame.adu.data.len() < n {
174 return CheckResult::Reject(ModbusError::InsufficientData);
175 }
176 if frame.adu.data.len() != n {
177 return CheckResult::Reject(ModbusError::InvalidResponse);
178 }
179 }
181 PreCheckOutcome::Fail(err) => return CheckResult::Reject(err),
182 PreCheckOutcome::InsufficientData => {
183 return CheckResult::Reject(ModbusError::InsufficientData);
184 }
185 }
186 }
187 CheckResult::Pass
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::physical::{ConnectionId, ResponseFn};
    use crate::types::ApplicationDataUnit;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Builds a frame for `unit`/`fc` whose ADU payload and raw bytes are `data`.
    fn fake_framing(unit: u8, fc: u8, data: Vec<u8>) -> Framing {
        let response: ResponseFn = Arc::new(|_| Box::pin(async { Ok(()) }));
        let connection: ConnectionId = Arc::from("test");
        Framing {
            adu: ApplicationDataUnit::new(unit, fc, data.clone()),
            raw: data,
            response,
            connection,
        }
    }

    /// Like `fake_framing`, but with the ADU transaction id set to `tid`.
    fn fake_framing_with_tid(unit: u8, fc: u8, data: Vec<u8>, tid: u16) -> Framing {
        let mut f = fake_framing(unit, fc, data);
        f.adu.transaction = Some(tid);
        f
    }

    fn always_pass() -> PreCheck {
        Arc::new(|_| PreCheckOutcome::Pass)
    }

    #[tokio::test]
    async fn test_fifo_waiter_resolves_on_matching_frame() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.handle_frame(fake_framing(1, 0x03, vec![0x01]));
        let result = rx.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().adu.unit, 1);
    }

    #[tokio::test]
    async fn test_handle_frame_with_no_waiter_is_noop() {
        let session = MasterSession::new();
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.handle_frame(fake_framing(2, 0x04, vec![]));
        let resolved = rx.await.unwrap().unwrap();
        assert_eq!(resolved.adu.unit, 2);
    }

    #[tokio::test]
    async fn test_handle_error_rejects_every_waiter() {
        let session = MasterSession::new();
        let rx_fifo = session.start(WaiterKey::Fifo, vec![always_pass()]);
        let rx_tid = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        session.handle_error(ModbusError::Timeout);
        assert!(matches!(rx_fifo.await.unwrap(), Err(ModbusError::Timeout)));
        assert!(matches!(rx_tid.await.unwrap(), Err(ModbusError::Timeout)));
    }

    #[tokio::test]
    async fn test_stop_drops_waiter_silently() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        session.stop(WaiterKey::Fifo);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_stop_all_rejects_each_waiter_independently() {
        let session = MasterSession::new();
        let rx_a = session.start(WaiterKey::Tid(1), vec![always_pass()]);
        let rx_b = session.start(WaiterKey::Tid(2), vec![always_pass()]);
        session.stop_all(ModbusError::InvalidState("Master closed".into()));
        assert!(matches!(
            rx_a.await.unwrap(),
            Err(ModbusError::InvalidState(ref s)) if s == "Master closed"
        ));
        assert!(matches!(
            rx_b.await.unwrap(),
            Err(ModbusError::InvalidState(ref s)) if s == "Master closed"
        ));
    }

    #[tokio::test]
    async fn test_has_returns_correct_state() {
        let session = MasterSession::new();
        assert!(!session.has(WaiterKey::Fifo));
        let _rx = session.start(WaiterKey::Fifo, vec![always_pass()]);
        assert!(session.has(WaiterKey::Fifo));
        assert!(!session.has(WaiterKey::Tid(0)));
        session.stop(WaiterKey::Fifo);
        assert!(!session.has(WaiterKey::Fifo));
    }

    #[tokio::test]
    async fn test_tid_routing_isolates_independent_waiters() {
        let session = MasterSession::new();
        let rx_tid7 = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        let rx_tid8 = session.start(WaiterKey::Tid(8), vec![always_pass()]);

        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 8));
        let resolved8 = rx_tid8.await.unwrap().unwrap();
        assert_eq!(resolved8.adu.transaction, Some(8));

        // The TID 8 frame must not have consumed the TID 7 waiter.
        assert!(session.has(WaiterKey::Tid(7)));

        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 7));
        let resolved7 = rx_tid7.await.unwrap().unwrap();
        assert_eq!(resolved7.adu.transaction, Some(7));
    }

    #[tokio::test]
    async fn test_fifo_frame_does_not_resolve_tid_waiter() {
        let session = MasterSession::new();
        let rx = session.start(WaiterKey::Tid(7), vec![always_pass()]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(session.has(WaiterKey::Tid(7)));
        session.stop(WaiterKey::Tid(7));
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_start_replaces_existing_waiter_for_same_key() {
        let session = MasterSession::new();
        let first = session.start(WaiterKey::Tid(3), vec![always_pass()]);
        let second = session.start(WaiterKey::Tid(3), vec![always_pass()]);
        // The displaced waiter's sender was dropped, so its receiver is closed.
        assert!(first.await.is_err());
        session.handle_frame(fake_framing_with_tid(1, 0x03, vec![], 3));
        let resolved = second.await.unwrap().unwrap();
        assert_eq!(resolved.adu.transaction, Some(3));
    }

    #[tokio::test]
    async fn test_pre_check_fail_returns_error() {
        let session = MasterSession::new();
        let fail: PreCheck = Arc::new(|_| PreCheckOutcome::Fail(ModbusError::IllegalDataAddress));
        let rx = session.start(WaiterKey::Fifo, vec![fail]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::IllegalDataAddress)
        ));
    }

    #[tokio::test]
    async fn test_pre_checks_stop_at_first_rejection() {
        let session = MasterSession::new();
        let later_ran = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&later_ran);
        let later: PreCheck = Arc::new(move |_| {
            flag.store(true, Ordering::SeqCst);
            PreCheckOutcome::Pass
        });
        let reject: PreCheck =
            Arc::new(|_| PreCheckOutcome::Fail(ModbusError::IllegalDataAddress));
        let rx = session.start(WaiterKey::Fifo, vec![always_pass(), reject, later]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::IllegalDataAddress)
        ));
        assert!(!later_ran.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_pre_check_insufficient_data_returns_error() {
        let session = MasterSession::new();
        let insuff: PreCheck = Arc::new(|_| PreCheckOutcome::InsufficientData);
        let rx = session.start(WaiterKey::Fifo, vec![insuff]);
        session.handle_frame(fake_framing(1, 0x03, vec![]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InsufficientData)
        ));
    }

    #[tokio::test]
    async fn test_need_length_exact_passes() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(3));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3]));
        assert!(rx.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn test_need_length_too_short_rejects_insufficient() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(5));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InsufficientData)
        ));
    }

    #[tokio::test]
    async fn test_need_length_too_long_rejects_invalid_response() {
        let session = MasterSession::new();
        let check: PreCheck = Arc::new(|_| PreCheckOutcome::NeedLength(2));
        let rx = session.start(WaiterKey::Fifo, vec![check]);
        session.handle_frame(fake_framing(1, 0x03, vec![1, 2, 3, 4]));
        assert!(matches!(
            rx.await.unwrap(),
            Err(ModbusError::InvalidResponse)
        ));
    }
}