1use crate::error::ModbusError;
2use crate::layers::application::{ApplicationLayer, ApplicationProtocol, ApplicationRole, Framing};
3use crate::layers::physical::{ConnectionId, PhysicalLayer, ResponseFn};
4use crate::types::{ApplicationDataUnit, CustomFcPredict, CustomFunctionCode, FramedDataUnit};
5use crate::utils::{crc, predict_rtu_frame_length, PredictResult};
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use tokio::sync::broadcast;
10use tokio::task::JoinHandle;
11
/// Largest frame (in bytes) the framer accepts. Predicted lengths above this are
/// rejected as invalid, and a buffer holding at least this many bytes is framed
/// immediately instead of waiting for the t3.5 timer.
const MAX_FRAME_LENGTH: usize = 256;
/// Smallest possible frame: unit id, function code and the two CRC bytes.
const MIN_FRAME_LENGTH: usize = 4;
/// Capacity of each per-connection reassembly pool: room for two maximal frames.
const POOL_SIZE: usize = 2 * MAX_FRAME_LENGTH;
15
/// A timing interval, given either as a number of bit-times or in milliseconds.
///
/// `Bits` values are interpreted relative to the serial baud rate when the
/// layer's timings are computed; `Ms` values are taken as-is.
#[derive(Debug, Clone, Copy)]
pub enum FrameInterval {
    /// Duration as a (possibly fractional) number of bit-times.
    Bits(f64),
    /// Duration in whole milliseconds.
    Ms(u32),
}
25
26#[derive(Clone, Copy, Debug, Default)]
32pub struct RtuApplicationLayerOptions {
33 pub interval_between_frames: Option<FrameInterval>,
34 pub inter_char_timeout: Option<FrameInterval>,
35 pub baud_rate: Option<u32>,
36}
37
38pub struct RtuApplicationLayer {
39 role: Arc<Mutex<Option<ApplicationRole>>>,
40 framing_tx: broadcast::Sender<Framing>,
41 framing_error_tx: broadcast::Sender<ModbusError>,
42 buffers: Arc<Mutex<HashMap<ConnectionId, RtuBuffer>>>,
43 tasks: Mutex<Vec<JoinHandle<()>>>,
44 custom_function_codes: Mutex<HashMap<u8, CustomFunctionCode>>,
46 interval_ms: u32,
49 inter_char_ms: u32,
52 destroyed: AtomicBool,
53}
54
55struct RtuBuffer {
64 pool: Box<[u8]>,
65 start: usize,
66 end: usize,
67 timer: Option<JoinHandle<()>>,
68 inter_char_timer: Option<JoinHandle<()>>,
69 t15_expired: bool,
70}
71
72impl RtuBuffer {
73 fn new() -> Self {
74 Self {
75 pool: vec![0u8; POOL_SIZE].into_boxed_slice(),
76 start: 0,
77 end: 0,
78 timer: None,
79 inter_char_timer: None,
80 t15_expired: false,
81 }
82 }
83
84 fn len(&self) -> usize {
85 self.end - self.start
86 }
87
88 fn is_empty(&self) -> bool {
89 self.start == self.end
90 }
91
92 fn as_slice(&self) -> &[u8] {
93 &self.pool[self.start..self.end]
94 }
95
96 fn available(&self) -> usize {
97 self.pool.len() - self.end
98 }
99
100 fn extend_from_slice(&mut self, data: &[u8]) -> usize {
103 let n = data.len().min(self.available());
104 self.pool[self.end..self.end + n].copy_from_slice(&data[..n]);
105 self.end += n;
106 n
107 }
108
109 fn drain(&mut self, n: usize) {
111 self.start += n;
112 }
113
114 fn compact(&mut self) {
117 if self.start > 0 {
118 if self.start < self.end {
119 let len = self.end - self.start;
120 self.pool.copy_within(self.start..self.end, 0);
121 self.start = 0;
122 self.end = len;
123 } else {
124 self.start = 0;
125 self.end = 0;
126 }
127 }
128 }
129
130 fn clear(&mut self) {
131 self.start = 0;
132 self.end = 0;
133 }
134}
135
136impl RtuApplicationLayer {
137 pub fn new<P: PhysicalLayer + 'static>(
149 physical: Arc<P>,
150 options: RtuApplicationLayerOptions,
151 ) -> Arc<Self> {
152 let (interval_ms, inter_char_ms) = compute_interval_ms(physical.layer_type(), options);
153
154 let (framing_tx, _) = broadcast::channel(64);
155 let (framing_error_tx, _) = broadcast::channel(64);
156 let buffers: Arc<Mutex<HashMap<ConnectionId, RtuBuffer>>> =
157 Arc::new(Mutex::new(HashMap::new()));
158 let role: Arc<Mutex<Option<ApplicationRole>>> = Arc::new(Mutex::new(None));
159 let app = Arc::new(Self {
160 role: Arc::clone(&role),
161 framing_tx: framing_tx.clone(),
162 framing_error_tx: framing_error_tx.clone(),
163 buffers: Arc::clone(&buffers),
164 tasks: Mutex::new(Vec::new()),
165 custom_function_codes: Mutex::new(HashMap::new()),
166 interval_ms,
167 inter_char_ms,
168 destroyed: AtomicBool::new(false),
169 });
170
171 let mut data_rx = physical.subscribe_data();
172 let buffers_for_data = Arc::clone(&buffers);
173 let framing_tx_for_data = framing_tx.clone();
174 let framing_error_tx_for_data = framing_error_tx.clone();
175 let app_for_data = Arc::clone(&app);
176 let data_task = tokio::spawn(async move {
177 loop {
178 match data_rx.recv().await {
179 Ok(event) => {
180 process_data_event(
181 &app_for_data,
182 &buffers_for_data,
183 &framing_tx_for_data,
184 &framing_error_tx_for_data,
185 event.data,
186 event.response,
187 event.connection,
188 );
189 }
190 Err(broadcast::error::RecvError::Lagged(_)) => continue,
191 Err(broadcast::error::RecvError::Closed) => break,
192 }
193 }
194 });
195
196 let mut close_rx = physical.subscribe_connection_close();
197 let buffers_for_close = Arc::clone(&buffers);
198 let close_task = tokio::spawn(async move {
199 loop {
200 match close_rx.recv().await {
201 Ok(conn_id) => {
202 buffers_for_close.lock().unwrap().remove(&conn_id);
203 }
204 Err(broadcast::error::RecvError::Lagged(_)) => continue,
205 Err(broadcast::error::RecvError::Closed) => break,
206 }
207 }
208 });
209
210 app.tasks.lock().unwrap().extend([data_task, close_task]);
211 app
212 }
213
214 fn role_snapshot(&self) -> Option<ApplicationRole> {
215 *self.role.lock().unwrap()
216 }
217
218 pub fn add_custom_function_code(&self, cfc: CustomFunctionCode) {
221 self.custom_function_codes
222 .lock()
223 .unwrap()
224 .insert(cfc.fc, cfc);
225 }
226
227 pub fn remove_custom_function_code(&self, fc: u8) {
228 self.custom_function_codes.lock().unwrap().remove(&fc);
229 }
230}
231
232pub(crate) fn compute_interval_ms(
233 layer_type: crate::layers::physical::PhysicalLayerType,
234 options: RtuApplicationLayerOptions,
235) -> (u32, u32) {
236 use crate::layers::physical::PhysicalLayerType;
237 use crate::utils::bits_to_ms;
238
239 let RtuApplicationLayerOptions {
240 interval_between_frames,
241 inter_char_timeout,
242 baud_rate,
243 } = options;
244
245 match layer_type {
246 PhysicalLayerType::Net => (0, 0),
247 PhysicalLayerType::Serial => {
248 let baud = baud_rate.unwrap_or(9600);
249
250 let three_point_five_t = match interval_between_frames {
251 Some(FrameInterval::Ms(n)) => n as f64,
252 other => {
253 let bits = match other {
254 Some(FrameInterval::Bits(n)) => n,
255 _ => 38.5,
256 };
257 if baud > 19200 {
258 1.75
259 } else {
260 bits_to_ms(baud, bits).ceil()
261 }
262 }
263 };
264
265 let one_point_five_t = match inter_char_timeout {
266 Some(FrameInterval::Ms(n)) => n as f64,
267 Some(FrameInterval::Bits(n)) => {
268 if baud > 19200 {
269 0.75
270 } else {
271 bits_to_ms(baud, n).ceil()
272 }
273 }
274 None => 0.0,
275 };
276
277 (
278 three_point_five_t.max(0.0) as u32,
279 one_point_five_t.max(0.0) as u32,
280 )
281 }
282 }
283}
284
/// Handles one chunk of bytes received from the physical layer for `connection`.
///
/// 1. If the t1.5 timer expired while a partial frame was buffered, that partial
///    frame is discarded and `ModbusError::T1_5Exceeded` is emitted.
/// 2. The connection's pending t3.5 / t1.5 timers are aborted.
/// 3. `data` is appended to the connection's buffer. Whenever the pool is full,
///    `flush_pool` is run to make room. If it frees nothing,
///    `ModbusError::InvalidData` is emitted and the buffer and the rest of `data`
///    are dropped.
/// 4. The buffer is flushed immediately when timing is disabled
///    (`interval_ms == 0`) or at least `MAX_FRAME_LENGTH` bytes are buffered.
///    Otherwise a t3.5 flush timer is armed, plus a t1.5 timer when
///    `inter_char_ms > 0`.
///
/// The `buffers` lock is released before every `flush_pool` call, because
/// `flush_pool` locks it again.
fn process_data_event(
    app: &Arc<RtuApplicationLayer>,
    buffers: &Arc<Mutex<HashMap<ConnectionId, RtuBuffer>>>,
    framing_tx: &broadcast::Sender<Framing>,
    framing_error_tx: &broadcast::Sender<ModbusError>,
    data: Vec<u8>,
    response: ResponseFn,
    connection: ConnectionId,
) {
    // With timers enabled, bytes that do not form a valid frame at flush time are
    // reported and dropped; without timers they may be kept for the next chunk.
    let strict = app.interval_ms > 0;

    let mut guard = buffers.lock().unwrap();
    let mut buffer = guard
        .entry(Arc::clone(&connection))
        .or_insert_with(RtuBuffer::new);

    if buffer.t15_expired && !buffer.is_empty() {
        // Gap inside a frame exceeded t1.5: discard the partial frame.
        buffer.start = 0;
        buffer.end = 0;
        buffer.t15_expired = false;
        // Emit the error without holding the lock, then re-acquire the entry.
        drop(guard);
        let _ = framing_error_tx.send(ModbusError::T1_5Exceeded);
        guard = buffers.lock().unwrap();
        buffer = guard
            .entry(Arc::clone(&connection))
            .or_insert_with(RtuBuffer::new);
    } else {
        buffer.t15_expired = false;
    }

    // New bytes arrived: restart both silence timers.
    if let Some(t) = buffer.timer.take() {
        t.abort();
    }
    if let Some(t) = buffer.inter_char_timer.take() {
        t.abort();
    }

    let mut data_offset = 0;
    while data_offset < data.len() {
        let copied = buffer.extend_from_slice(&data[data_offset..]);
        if copied == 0 {
            // Pool is full: extract what frames we can (flush_pool compacts the
            // pool and may remove the entry), then re-fetch the buffer.
            drop(guard);
            flush_pool(
                app,
                buffers,
                framing_tx,
                framing_error_tx,
                &connection,
                &response,
                strict,
            );
            guard = buffers.lock().unwrap();
            buffer = guard
                .entry(Arc::clone(&connection))
                .or_insert_with(RtuBuffer::new);
            if buffer.available() == 0 {
                // Flushing freed no space: give up on this chunk entirely.
                let _ = framing_error_tx.send(ModbusError::InvalidData);
                buffer.clear();
                data_offset = data.len();
            }
            continue;
        }
        data_offset += copied;
    }

    let len_after = buffer.len();
    drop(guard);

    if app.interval_ms == 0 || len_after >= MAX_FRAME_LENGTH {
        // No timing (interval 0), or enough bytes for a maximal frame: frame now.
        flush_pool(
            app,
            buffers,
            framing_tx,
            framing_error_tx,
            &connection,
            &response,
            strict,
        );
    }

    if app.interval_ms > 0 && len_after < MAX_FRAME_LENGTH {
        // Defer framing until the line has been silent for t3.5.
        let interval = app.interval_ms;
        let inter_char = app.inter_char_ms;
        let buffers_t = Arc::clone(buffers);
        let framing_tx_t = framing_tx.clone();
        let framing_error_tx_t = framing_error_tx.clone();
        let conn_t = Arc::clone(&connection);
        let response_t = Arc::clone(&response);
        let app_t = Arc::clone(app);

        let timer = tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(interval as u64)).await;
            flush_pool(
                &app_t,
                &buffers_t,
                &framing_tx_t,
                &framing_error_tx_t,
                &conn_t,
                &response_t,
                interval > 0,
            );
        });

        // Store the handles so the next chunk can cancel them. If the entry no
        // longer exists, the handle is simply dropped (the task stays detached).
        let mut guard = buffers.lock().unwrap();
        if let Some(b) = guard.get_mut(&connection) {
            b.timer = Some(timer);

            if inter_char > 0 {
                let buffers_ic = Arc::clone(buffers);
                let conn_ic = Arc::clone(&connection);
                // Only flags the expiry; the next chunk or the t3.5 flush reports it.
                let inter_char_timer = tokio::spawn(async move {
                    tokio::time::sleep(tokio::time::Duration::from_millis(inter_char as u64)).await;
                    let mut guard = buffers_ic.lock().unwrap();
                    if let Some(b) = guard.get_mut(&conn_ic) {
                        b.t15_expired = true;
                    }
                });
                b.inter_char_timer = Some(inter_char_timer);
            }
        }
    }
}
414
/// Extracts and publishes every complete frame currently buffered for `connection`.
///
/// Frames are parsed as responses when the layer's role is `Master`, otherwise
/// as requests. Each CRC-valid frame is sent on `framing_tx`. Other bytes are
/// handled as follows:
/// - CRC mismatch: if `strict`, emit `CrcCheckFailed` and clear the buffer;
///   otherwise drop one byte and keep scanning.
/// - Not enough bytes: with `MAX_FRAME_LENGTH` or more buffered, drop one byte
///   and keep scanning. Otherwise, if `strict`, emit `T1_5Exceeded` (when t1.5
///   expired) or `IncompleteFrame` and clear. When not strict, keep the bytes
///   unless t1.5 expired.
/// - Undeterminable or out-of-range length: emit `InvalidData` and clear.
///
/// The buffer is then compacted, and removed from the map if empty. Locks
/// `buffers` and then `custom_function_codes`, so callers must not hold the
/// `buffers` lock.
fn flush_pool(
    app: &Arc<RtuApplicationLayer>,
    buffers: &Arc<Mutex<HashMap<ConnectionId, RtuBuffer>>>,
    framing_tx: &broadcast::Sender<Framing>,
    framing_error_tx: &broadcast::Sender<ModbusError>,
    connection: &ConnectionId,
    response: &ResponseFn,
    strict: bool,
) {
    let mut guard = buffers.lock().unwrap();
    let buffer = match guard.get_mut(connection) {
        Some(b) => b,
        None => return,
    };

    // A master receives responses; any other (or unset) role receives requests.
    let is_response = matches!(app.role_snapshot(), Some(ApplicationRole::Master));
    let custom_fcs = app.custom_function_codes.lock().unwrap();

    while !buffer.is_empty() {
        match try_extract(buffer.as_slice(), is_response, &custom_fcs) {
            ExtractResult::Frame { skip, frame_len } => {
                if skip > 0 {
                    buffer.drain(skip);
                }
                let frame_bytes: Vec<u8> = buffer.as_slice()[..frame_len].to_vec();
                buffer.drain(frame_len);
                // Layout: unit id, function code, data, 2-byte CRC.
                let adu = ApplicationDataUnit {
                    transaction: None,
                    unit: frame_bytes[0],
                    fc: frame_bytes[1],
                    data: frame_bytes[2..frame_bytes.len() - 2].to_vec(),
                };
                let _ = framing_tx.send(Framing {
                    adu,
                    raw: frame_bytes,
                    response: Arc::clone(response),
                    connection: Arc::clone(connection),
                });
            }
            ExtractResult::Skip => {
                if strict {
                    let _ = framing_error_tx.send(ModbusError::CrcCheckFailed);
                    buffer.clear();
                    break;
                }
                // Lenient mode: resynchronise one byte at a time.
                buffer.drain(1);
            }
            ExtractResult::Insufficient => {
                // A frame can never be longer than this, so the head byte cannot
                // start a valid frame; slide forward.
                if buffer.len() >= MAX_FRAME_LENGTH {
                    buffer.drain(1);
                    continue;
                }
                if strict {
                    let err = if buffer.t15_expired {
                        ModbusError::T1_5Exceeded
                    } else {
                        ModbusError::IncompleteFrame
                    };
                    let _ = framing_error_tx.send(err);
                    buffer.clear();
                    buffer.t15_expired = false;
                    break;
                }
                if buffer.t15_expired {
                    let _ = framing_error_tx.send(ModbusError::T1_5Exceeded);
                    buffer.clear();
                    buffer.t15_expired = false;
                }
                // Otherwise keep the partial frame for the next chunk.
                break;
            }
            ExtractResult::Invalid => {
                let _ = framing_error_tx.send(ModbusError::InvalidData);
                buffer.clear();
                break;
            }
        }
    }

    buffer.compact();
    if buffer.is_empty() {
        guard.remove(connection);
    }
}
500
/// Outcome of one framing attempt at the head of a reassembly buffer.
enum ExtractResult {
    /// A CRC-valid frame of `frame_len` bytes begins `skip` bytes into the buffer.
    Frame {
        skip: usize,
        frame_len: usize,
    },
    /// More bytes are needed before the length or CRC can be checked.
    Insufficient,
    /// The predicted frame failed its CRC check.
    Skip,
    /// The frame length could not be determined or is out of range.
    Invalid,
}
507
508fn try_extract(
509 buffer: &[u8],
510 is_response: bool,
511 custom_fcs: &HashMap<u8, CustomFunctionCode>,
512) -> ExtractResult {
513 if buffer.len() < MIN_FRAME_LENGTH {
514 return ExtractResult::Insufficient;
515 }
516
517 let fc = buffer[1];
518
519 if let Some(cfc) = custom_fcs.get(&fc) {
521 let predictor = if is_response {
522 &cfc.predict_response_length
523 } else {
524 &cfc.predict_request_length
525 };
526 match predictor(buffer) {
527 CustomFcPredict::NeedMore => return ExtractResult::Insufficient,
528 CustomFcPredict::Length(n) => return check_expected(buffer, n),
529 }
530 }
531
532 match predict_rtu_frame_length(buffer, is_response) {
534 PredictResult::Length(n) => check_expected(buffer, n),
535 PredictResult::NeedMore => ExtractResult::Insufficient,
536 PredictResult::Unknown => {
537 ExtractResult::Invalid
540 }
541 }
542}
543
544fn check_expected(buffer: &[u8], expected: usize) -> ExtractResult {
545 if !(MIN_FRAME_LENGTH..=MAX_FRAME_LENGTH).contains(&expected) {
546 return ExtractResult::Invalid;
547 }
548 if buffer.len() < expected {
549 return ExtractResult::Insufficient;
550 }
551 if crc_matches(buffer, expected) {
552 return ExtractResult::Frame {
553 skip: 0,
554 frame_len: expected,
555 };
556 }
557 ExtractResult::Skip
559}
560
561fn crc_matches(buffer: &[u8], length: usize) -> bool {
562 if length < 2 || length > buffer.len() {
563 return false;
564 }
565 let frame_crc = u16::from_le_bytes([buffer[length - 2], buffer[length - 1]]);
566 let computed = crc(&buffer[..length - 2]);
567 frame_crc == computed
568}
569
570fn decode_inner(data: &[u8]) -> Result<ApplicationDataUnit, ModbusError> {
571 if data.len() < 4 {
572 return Err(ModbusError::InsufficientData);
573 }
574 let frame_crc = u16::from_le_bytes([data[data.len() - 2], data[data.len() - 1]]);
575 let computed = crc(&data[..data.len() - 2]);
576 if frame_crc != computed {
577 return Err(ModbusError::CrcCheckFailed);
578 }
579 Ok(ApplicationDataUnit {
580 transaction: None,
581 unit: data[0],
582 fc: data[1],
583 data: data[2..data.len() - 2].to_vec(),
584 })
585}
586
587#[async_trait::async_trait]
588impl ApplicationLayer for RtuApplicationLayer {
589 fn set_role(&self, role: ApplicationRole) -> Result<(), ModbusError> {
590 crate::layers::application::set_role_impl(&mut self.role.lock().unwrap(), role)
591 }
592
593 fn role(&self) -> Option<ApplicationRole> {
594 self.role_snapshot()
595 }
596
597 fn protocol(&self) -> ApplicationProtocol {
598 ApplicationProtocol::Rtu
599 }
600
601 fn encode(&self, adu: &ApplicationDataUnit) -> Vec<u8> {
602 let data_len = adu.data.len();
603 let payload_len = data_len + 2;
604 let mut buf = vec![0u8; payload_len + 2];
605 buf[0] = adu.unit;
606 buf[1] = adu.fc;
607 buf[2..payload_len].copy_from_slice(&adu.data);
608 let c = crc(&buf[..payload_len]);
609 buf[payload_len..].copy_from_slice(&c.to_le_bytes());
610 buf
611 }
612
613 fn decode(&self, data: &[u8]) -> Result<FramedDataUnit, ModbusError> {
614 let adu = decode_inner(data)?;
615 Ok(FramedDataUnit {
616 adu,
617 raw: data.to_vec(),
618 })
619 }
620
621 fn flush(&self) {
622 self.buffers.lock().unwrap().clear();
623 }
624
625 fn subscribe_framing(&self) -> broadcast::Receiver<Framing> {
626 self.framing_tx.subscribe()
627 }
628
629 fn subscribe_framing_error(&self) -> broadcast::Receiver<ModbusError> {
630 self.framing_error_tx.subscribe()
631 }
632
633 async fn destroy(&self) {
634 if self.destroyed.swap(true, Ordering::SeqCst) {
635 return;
636 }
637 let mut tasks = self.tasks.lock().unwrap();
638 for task in tasks.drain(..) {
639 task.abort();
640 }
641 self.buffers.lock().unwrap().clear();
642 self.custom_function_codes.lock().unwrap().clear();
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::physical::PhysicalLayerType;

    /// Computes `(t3.5, t1.5)` for a serial layer.
    fn serial(options: RtuApplicationLayerOptions) -> (u32, u32) {
        compute_interval_ms(PhysicalLayerType::Serial, options)
    }

    /// Options with only `baud_rate` set.
    fn at_baud(baud: u32) -> RtuApplicationLayerOptions {
        RtuApplicationLayerOptions {
            baud_rate: Some(baud),
            ..Default::default()
        }
    }

    #[test]
    fn test_compute_interval_ms_net_returns_zero() {
        let net = |options: RtuApplicationLayerOptions| {
            compute_interval_ms(PhysicalLayerType::Net, options)
        };
        assert_eq!(net(RtuApplicationLayerOptions::default()), (0, 0));
        assert_eq!(
            net(RtuApplicationLayerOptions {
                interval_between_frames: Some(FrameInterval::Ms(50)),
                ..at_baud(9600)
            }),
            (0, 0),
            "Net always ignores baud/interval inputs"
        );
    }

    #[test]
    fn test_compute_interval_ms_serial_default_9600() {
        assert_eq!(serial(at_baud(9600)), (5, 0));
    }

    #[test]
    fn test_compute_interval_ms_serial_default_19200() {
        assert_eq!(serial(at_baud(19200)), (3, 0));
    }

    #[test]
    fn test_compute_interval_ms_serial_above_19200_uses_spec_fixed() {
        // Above 19200 baud the fixed 1.75 ms / 0.75 ms apply, truncated to whole ms.
        assert_eq!(serial(at_baud(38400)), (1, 0));
        assert_eq!(
            serial(RtuApplicationLayerOptions {
                inter_char_timeout: Some(FrameInterval::Bits(16.5)),
                ..at_baud(115200)
            }),
            (1, 0)
        );
    }

    #[test]
    fn test_compute_interval_ms_serial_explicit_ms() {
        assert_eq!(
            serial(RtuApplicationLayerOptions {
                interval_between_frames: Some(FrameInterval::Ms(20)),
                ..at_baud(9600)
            }),
            (20, 0)
        );
    }

    #[test]
    fn test_compute_interval_ms_serial_explicit_bits() {
        assert_eq!(
            serial(RtuApplicationLayerOptions {
                interval_between_frames: Some(FrameInterval::Bits(96.0)),
                ..at_baud(9600)
            }),
            (10, 0)
        );
    }

    #[test]
    fn test_compute_interval_ms_serial_default_baud_when_unspecified() {
        assert_eq!(serial(RtuApplicationLayerOptions::default()), (5, 0));
    }

    #[test]
    fn test_compute_interval_ms_serial_with_inter_char_timeout() {
        let (t35, t15) = serial(RtuApplicationLayerOptions {
            inter_char_timeout: Some(FrameInterval::Bits(21.0)),
            ..at_baud(9600)
        });
        assert_eq!(t35, 5);
        assert_eq!(t15, 3);
    }
}