Skip to main content

rustmod_client/
lib.rs

1//! High-level Modbus client crate.
2
3#![forbid(unsafe_code)]
4
5pub mod points;
6pub mod sync;
7
8pub use points::{CoilPoints, RegisterPoints};
9pub use sync::{SyncClientError, SyncModbusTcpClient};
10
11use rustmod_core::encoding::{Reader, Writer};
12use rustmod_core::pdu::{
13    CustomRequest, ExceptionResponse, ReadCoilsRequest, ReadDiscreteInputsRequest,
14    ReadHoldingRegistersRequest, ReadInputRegistersRequest, ReadWriteMultipleRegistersRequest,
15    Request, Response, MaskWriteRegisterRequest, WriteMultipleCoilsRequest,
16    WriteMultipleRegistersRequest, WriteSingleCoilRequest, WriteSingleRegisterRequest,
17};
18use rustmod_core::{DecodeError, EncodeError};
19use rustmod_datalink::{DataLink, DataLinkError};
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::Duration;
22use thiserror::Error;
23use tokio::sync::Mutex;
24use tokio::time::{Instant, sleep, timeout};
25use tracing::{debug, warn};
26
27#[cfg(feature = "metrics")]
28use std::sync::Arc;
29
/// Selects which requests may be re-sent after a timeout or a retryable
/// transport error.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryPolicy {
    /// Never re-send; the first failure is returned to the caller.
    Never,
    /// Re-send only requests the client classifies as read-only
    /// (e.g. read coils / discrete inputs / holding / input registers).
    ReadOnly,
    /// Re-send every request, including writes.
    All,
}
36
37#[derive(Debug, Clone, Copy)]
38pub struct ClientConfig {
39    pub response_timeout: Duration,
40    pub retry_count: u8,
41    pub throttle_delay: Option<Duration>,
42    pub retry_policy: RetryPolicy,
43}
44
45impl Default for ClientConfig {
46    fn default() -> Self {
47        Self {
48            response_timeout: Duration::from_secs(5),
49            retry_count: 3,
50            throttle_delay: None,
51            retry_policy: RetryPolicy::ReadOnly,
52        }
53    }
54}
55
56impl ClientConfig {
57    pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
58        self.response_timeout = timeout;
59        self
60    }
61
62    pub fn with_retry_count(mut self, retry_count: u8) -> Self {
63        self.retry_count = retry_count;
64        self
65    }
66
67    pub fn with_throttle_delay(mut self, throttle_delay: Option<Duration>) -> Self {
68        self.throttle_delay = throttle_delay;
69        self
70    }
71
72    pub fn with_retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
73        self.retry_policy = retry_policy;
74        self
75    }
76}
77
78#[derive(Debug, Error)]
79pub enum ClientError {
80    #[error("datalink error: {0}")]
81    DataLink(#[from] DataLinkError),
82    #[error("encode error: {0}")]
83    Encode(#[from] EncodeError),
84    #[error("decode error: {0}")]
85    Decode(#[from] DecodeError),
86    #[error("request timed out")]
87    Timeout,
88    #[error("modbus exception: {0:?}")]
89    Exception(ExceptionResponse),
90    #[error("invalid response: {0}")]
91    InvalidResponse(&'static str),
92}
93
/// Parsed payload of a Report Server ID (function 0x11) response.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReportServerIdResponse {
    /// First byte after the byte count.
    pub server_id: u8,
    /// `true` when the run-indicator byte is non-zero.
    pub run_indicator_status: bool,
    /// Any bytes following the run indicator.
    pub additional_data: Vec<u8>,
}
100
/// One `(id, value)` object from a Read Device Identification response.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeviceIdentificationObject {
    /// Object id byte as sent by the server.
    pub object_id: u8,
    /// Raw object bytes, exactly as long as the object's length byte says.
    pub value: Vec<u8>,
}
106
107#[derive(Debug, Clone, PartialEq, Eq)]
108pub struct ReadDeviceIdentificationResponse {
109    pub read_device_id_code: u8,
110    pub conformity_level: u8,
111    pub more_follows: bool,
112    pub next_object_id: u8,
113    pub objects: Vec<DeviceIdentificationObject>,
114}
115
/// Live counters shared by a client; read them via a snapshot.
#[cfg(feature = "metrics")]
#[derive(Default, Debug)]
pub struct ClientMetrics {
    /// Logical requests started (retries are not counted here).
    requests_total: AtomicU64,
    /// Attempts that returned a response from the transport.
    successful_responses: AtomicU64,
    /// Attempts that were re-sent.
    retries_total: AtomicU64,
    /// Attempts that hit the response timeout.
    timeouts_total: AtomicU64,
    /// Attempts that failed inside the transport.
    transport_errors_total: AtomicU64,
    /// Responses that were Modbus exceptions.
    exceptions_total: AtomicU64,
    /// Responses that failed to decode or carried trailing bytes.
    decode_errors_total: AtomicU64,
}
127
/// Plain-value copy of the client counters at one point in time.
#[cfg(feature = "metrics")]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ClientMetricsSnapshot {
    /// Logical requests started.
    pub requests_total: u64,
    /// Attempts that returned a response from the transport.
    pub successful_responses: u64,
    /// Attempts that were re-sent.
    pub retries_total: u64,
    /// Attempts that hit the response timeout.
    pub timeouts_total: u64,
    /// Attempts that failed inside the transport.
    pub transport_errors_total: u64,
    /// Responses that were Modbus exceptions.
    pub exceptions_total: u64,
    /// Responses that failed to decode or carried trailing bytes.
    pub decode_errors_total: u64,
}
139
#[cfg(feature = "metrics")]
impl ClientMetrics {
    /// Copies every counter into a [`ClientMetricsSnapshot`].
    ///
    /// Each counter is loaded independently with `Relaxed` ordering, so under
    /// concurrent traffic the snapshot is not a consistent cut across fields.
    fn snapshot(&self) -> ClientMetricsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        ClientMetricsSnapshot {
            requests_total: read(&self.requests_total),
            successful_responses: read(&self.successful_responses),
            retries_total: read(&self.retries_total),
            timeouts_total: read(&self.timeouts_total),
            transport_errors_total: read(&self.transport_errors_total),
            exceptions_total: read(&self.exceptions_total),
            decode_errors_total: read(&self.decode_errors_total),
        }
    }
}
154
155pub struct ModbusClient<D: DataLink> {
156    datalink: D,
157    config: ClientConfig,
158    last_request_at: Mutex<Option<Instant>>,
159    request_counter: AtomicU64,
160    #[cfg(feature = "metrics")]
161    metrics: Arc<ClientMetrics>,
162}
163
164impl<D: DataLink> ModbusClient<D> {
165    pub fn new(datalink: D) -> Self {
166        Self::with_config(datalink, ClientConfig::default())
167    }
168
169    pub fn with_config(datalink: D, config: ClientConfig) -> Self {
170        Self {
171            datalink,
172            config,
173            last_request_at: Mutex::new(None),
174            request_counter: AtomicU64::new(1),
175            #[cfg(feature = "metrics")]
176            metrics: Arc::new(ClientMetrics::default()),
177        }
178    }
179
180    pub fn config(&self) -> ClientConfig {
181        self.config
182    }
183
184    #[cfg(feature = "metrics")]
185    pub fn metrics_snapshot(&self) -> ClientMetricsSnapshot {
186        self.metrics.snapshot()
187    }
188
189    fn next_correlation_id(&self) -> u64 {
190        self.request_counter.fetch_add(1, Ordering::Relaxed)
191    }
192
193    async fn apply_throttle(&self) {
194        let Some(delay) = self.config.throttle_delay else {
195            return;
196        };
197
198        let mut last = self.last_request_at.lock().await;
199        if let Some(previous) = *last {
200            let elapsed = previous.elapsed();
201            if elapsed < delay {
202                sleep(delay - elapsed).await;
203            }
204        }
205        *last = Some(Instant::now());
206    }
207
208    fn is_retryable(err: &DataLinkError) -> bool {
209        matches!(
210            err,
211            DataLinkError::Io(_)
212                | DataLinkError::Timeout
213                | DataLinkError::ConnectionClosed
214        )
215    }
216
217    fn request_is_retry_eligible(&self, request: &Request<'_>) -> bool {
218        match self.config.retry_policy {
219            RetryPolicy::Never => false,
220            RetryPolicy::All => true,
221            RetryPolicy::ReadOnly => matches!(
222                request,
223                Request::ReadCoils(_)
224                    | Request::ReadDiscreteInputs(_)
225                    | Request::ReadHoldingRegisters(_)
226                    | Request::ReadInputRegisters(_)
227            ),
228        }
229    }
230
231    async fn exchange_raw(
232        &self,
233        correlation_id: u64,
234        unit_id: u8,
235        request_pdu: &[u8],
236        response_buf: &mut [u8],
237        retry_eligible: bool,
238    ) -> Result<usize, ClientError> {
239        self.apply_throttle().await;
240
241        #[cfg(feature = "metrics")]
242        self.metrics.requests_total.fetch_add(1, Ordering::Relaxed);
243
244        let attempts = usize::from(self.config.retry_count) + 1;
245        let mut last_err: Option<ClientError> = None;
246
247        for attempt in 1..=attempts {
248            let result = timeout(
249                self.config.response_timeout,
250                self.datalink.exchange(unit_id, request_pdu, response_buf),
251            )
252            .await;
253
254            match result {
255                Ok(Ok(len)) => {
256                    debug!(
257                        correlation_id,
258                        unit_id,
259                        attempt,
260                        len,
261                        "modbus request succeeded"
262                    );
263                    #[cfg(feature = "metrics")]
264                    self.metrics
265                        .successful_responses
266                        .fetch_add(1, Ordering::Relaxed);
267                    return Ok(len);
268                }
269                Ok(Err(err)) => {
270                    #[cfg(feature = "metrics")]
271                    self.metrics
272                        .transport_errors_total
273                        .fetch_add(1, Ordering::Relaxed);
274                    if attempt < attempts && retry_eligible && Self::is_retryable(&err) {
275                        warn!(
276                            correlation_id,
277                            unit_id,
278                            attempt,
279                            error = %err,
280                            "retrying modbus request after transport error"
281                        );
282                        #[cfg(feature = "metrics")]
283                        self.metrics.retries_total.fetch_add(1, Ordering::Relaxed);
284                        last_err = Some(ClientError::DataLink(err));
285                        continue;
286                    }
287                    return Err(ClientError::DataLink(err));
288                }
289                Err(_) => {
290                    #[cfg(feature = "metrics")]
291                    self.metrics.timeouts_total.fetch_add(1, Ordering::Relaxed);
292                    if attempt < attempts && retry_eligible {
293                        warn!(
294                            correlation_id,
295                            unit_id,
296                            attempt,
297                            "retrying modbus request after timeout"
298                        );
299                        #[cfg(feature = "metrics")]
300                        self.metrics.retries_total.fetch_add(1, Ordering::Relaxed);
301                        last_err = Some(ClientError::Timeout);
302                        continue;
303                    }
304                    return Err(ClientError::Timeout);
305                }
306            }
307        }
308
309        Err(last_err.unwrap_or(ClientError::InvalidResponse(
310            "retry loop exhausted",
311        )))
312    }
313
314    async fn send_request<'a>(
315        &self,
316        unit_id: u8,
317        request: &Request<'_>,
318        response_storage: &'a mut [u8],
319    ) -> Result<Response<'a>, ClientError> {
320        let correlation_id = self.next_correlation_id();
321        let mut req_buf = [0u8; 260];
322        let mut writer = Writer::new(&mut req_buf);
323        request.encode(&mut writer)?;
324
325        debug!(
326            correlation_id,
327            unit_id,
328            function = request.function_code().as_u8(),
329            pdu_len = writer.as_written().len(),
330            "dispatching modbus request"
331        );
332        let retry_eligible = self.request_is_retry_eligible(request);
333
334        let response_len = self
335            .exchange_raw(
336                correlation_id,
337                unit_id,
338                writer.as_written(),
339                response_storage,
340                retry_eligible,
341            )
342            .await?;
343
344        let mut reader = Reader::new(&response_storage[..response_len]);
345        let response = match Response::decode(&mut reader) {
346            Ok(resp) => resp,
347            Err(err) => {
348                #[cfg(feature = "metrics")]
349                self.metrics
350                    .decode_errors_total
351                    .fetch_add(1, Ordering::Relaxed);
352                return Err(ClientError::Decode(err));
353            }
354        };
355
356        if !reader.is_empty() {
357            #[cfg(feature = "metrics")]
358            self.metrics
359                .decode_errors_total
360                .fetch_add(1, Ordering::Relaxed);
361            return Err(ClientError::InvalidResponse("trailing bytes in response"));
362        }
363
364        if let Response::Exception(ex) = response {
365            #[cfg(feature = "metrics")]
366            self.metrics.exceptions_total.fetch_add(1, Ordering::Relaxed);
367            return Err(ClientError::Exception(ex));
368        }
369
370        Ok(response)
371    }
372
373    pub async fn read_coils(
374        &self,
375        unit_id: u8,
376        start: u16,
377        quantity: u16,
378    ) -> Result<Vec<bool>, ClientError> {
379        let request = Request::ReadCoils(ReadCoilsRequest {
380            start_address: start,
381            quantity,
382        });
383
384        let mut response_buf = [0u8; 260];
385        let response = self
386            .send_request(unit_id, &request, &mut response_buf)
387            .await?;
388
389        match response {
390            Response::ReadCoils(data) => {
391                let count = usize::from(quantity);
392                if data.coil_status.len() * 8 < count {
393                    return Err(ClientError::InvalidResponse(
394                        "coil payload shorter than requested",
395                    ));
396                }
397                Ok((0..count).filter_map(|idx| data.coil(idx)).collect())
398            }
399            _ => Err(ClientError::InvalidResponse("unexpected function response")),
400        }
401    }
402
403    pub async fn custom_request(
404        &self,
405        unit_id: u8,
406        function_code: u8,
407        payload: &[u8],
408    ) -> Result<Vec<u8>, ClientError> {
409        let request = Request::Custom(CustomRequest {
410            function_code,
411            data: payload,
412        });
413
414        let mut response_buf = [0u8; 260];
415        let response = self
416            .send_request(unit_id, &request, &mut response_buf)
417            .await?;
418
419        match response {
420            Response::Custom(custom) if custom.function_code == function_code => {
421                Ok(custom.data.to_vec())
422            }
423            Response::Custom(_) => {
424                Err(ClientError::InvalidResponse("custom response function mismatch"))
425            }
426            _ => Err(ClientError::InvalidResponse("unexpected function response")),
427        }
428    }
429
430    pub async fn report_server_id(&self, unit_id: u8) -> Result<ReportServerIdResponse, ClientError> {
431        let payload = self.custom_request(unit_id, 0x11, &[]).await?;
432        let Some((&byte_count, data)) = payload.split_first() else {
433            return Err(ClientError::InvalidResponse(
434                "report server id payload missing byte count",
435            ));
436        };
437        let byte_count = usize::from(byte_count);
438        if data.len() != byte_count || byte_count < 2 {
439            return Err(ClientError::InvalidResponse(
440                "report server id payload length mismatch",
441            ));
442        }
443
444        Ok(ReportServerIdResponse {
445            server_id: data[0],
446            run_indicator_status: data[1] != 0,
447            additional_data: data[2..].to_vec(),
448        })
449    }
450
451    pub async fn read_device_identification(
452        &self,
453        unit_id: u8,
454        read_device_id_code: u8,
455        object_id: u8,
456    ) -> Result<ReadDeviceIdentificationResponse, ClientError> {
457        let payload = self
458            .custom_request(unit_id, 0x2B, &[0x0E, read_device_id_code, object_id])
459            .await?;
460
461        if payload.len() < 6 {
462            return Err(ClientError::InvalidResponse(
463                "read device identification payload too short",
464            ));
465        }
466        if payload[0] != 0x0E {
467            return Err(ClientError::InvalidResponse(
468                "read device identification MEI type mismatch",
469            ));
470        }
471
472        let object_count = usize::from(payload[5]);
473        let mut cursor = 6usize;
474        let mut objects = Vec::with_capacity(object_count);
475        for _ in 0..object_count {
476            if payload.len().saturating_sub(cursor) < 2 {
477                return Err(ClientError::InvalidResponse(
478                    "read device identification object header truncated",
479                ));
480            }
481            let id = payload[cursor];
482            let len = usize::from(payload[cursor + 1]);
483            cursor += 2;
484            let end = cursor
485                .checked_add(len)
486                .ok_or(ClientError::InvalidResponse(
487                    "read device identification object length overflow",
488                ))?;
489            if end > payload.len() {
490                return Err(ClientError::InvalidResponse(
491                    "read device identification object data truncated",
492                ));
493            }
494            objects.push(DeviceIdentificationObject {
495                object_id: id,
496                value: payload[cursor..end].to_vec(),
497            });
498            cursor = end;
499        }
500        if cursor != payload.len() {
501            return Err(ClientError::InvalidResponse(
502                "read device identification trailing data",
503            ));
504        }
505
506        Ok(ReadDeviceIdentificationResponse {
507            read_device_id_code: payload[1],
508            conformity_level: payload[2],
509            more_follows: payload[3] != 0,
510            next_object_id: payload[4],
511            objects,
512        })
513    }
514
515    pub async fn read_discrete_inputs(
516        &self,
517        unit_id: u8,
518        start: u16,
519        quantity: u16,
520    ) -> Result<Vec<bool>, ClientError> {
521        let request = Request::ReadDiscreteInputs(ReadDiscreteInputsRequest {
522            start_address: start,
523            quantity,
524        });
525
526        let mut response_buf = [0u8; 260];
527        let response = self
528            .send_request(unit_id, &request, &mut response_buf)
529            .await?;
530
531        match response {
532            Response::ReadDiscreteInputs(data) => {
533                let count = usize::from(quantity);
534                if data.input_status.len() * 8 < count {
535                    return Err(ClientError::InvalidResponse(
536                        "discrete input payload shorter than requested",
537                    ));
538                }
539                Ok((0..count).filter_map(|idx| data.coil(idx)).collect())
540            }
541            _ => Err(ClientError::InvalidResponse("unexpected function response")),
542        }
543    }
544
545    pub async fn read_holding_registers(
546        &self,
547        unit_id: u8,
548        start: u16,
549        quantity: u16,
550    ) -> Result<Vec<u16>, ClientError> {
551        let request = Request::ReadHoldingRegisters(ReadHoldingRegistersRequest {
552            start_address: start,
553            quantity,
554        });
555
556        let mut response_buf = [0u8; 260];
557        let response = self
558            .send_request(unit_id, &request, &mut response_buf)
559            .await?;
560
561        match response {
562            Response::ReadHoldingRegisters(data) => {
563                let count = usize::from(quantity);
564                if data.register_count() < count {
565                    return Err(ClientError::InvalidResponse(
566                        "register payload shorter than requested",
567                    ));
568                }
569                Ok((0..count).filter_map(|idx| data.register(idx)).collect())
570            }
571            _ => Err(ClientError::InvalidResponse("unexpected function response")),
572        }
573    }
574
575    pub async fn read_input_registers(
576        &self,
577        unit_id: u8,
578        start: u16,
579        quantity: u16,
580    ) -> Result<Vec<u16>, ClientError> {
581        let request = Request::ReadInputRegisters(ReadInputRegistersRequest {
582            start_address: start,
583            quantity,
584        });
585
586        let mut response_buf = [0u8; 260];
587        let response = self
588            .send_request(unit_id, &request, &mut response_buf)
589            .await?;
590
591        match response {
592            Response::ReadInputRegisters(data) => {
593                let count = usize::from(quantity);
594                if data.register_count() < count {
595                    return Err(ClientError::InvalidResponse(
596                        "register payload shorter than requested",
597                    ));
598                }
599                Ok((0..count).filter_map(|idx| data.register(idx)).collect())
600            }
601            _ => Err(ClientError::InvalidResponse("unexpected function response")),
602        }
603    }
604
605    pub async fn write_single_coil(
606        &self,
607        unit_id: u8,
608        address: u16,
609        value: bool,
610    ) -> Result<(), ClientError> {
611        let request = Request::WriteSingleCoil(WriteSingleCoilRequest { address, value });
612
613        let mut response_buf = [0u8; 260];
614        let response = self
615            .send_request(unit_id, &request, &mut response_buf)
616            .await?;
617
618        match response {
619            Response::WriteSingleCoil(resp) if resp.address == address && resp.value == value => Ok(()),
620            Response::WriteSingleCoil(_) => {
621                Err(ClientError::InvalidResponse("write single coil echo mismatch"))
622            }
623            _ => Err(ClientError::InvalidResponse("unexpected function response")),
624        }
625    }
626
627    pub async fn write_single_register(
628        &self,
629        unit_id: u8,
630        address: u16,
631        value: u16,
632    ) -> Result<(), ClientError> {
633        let request = Request::WriteSingleRegister(WriteSingleRegisterRequest { address, value });
634
635        let mut response_buf = [0u8; 260];
636        let response = self
637            .send_request(unit_id, &request, &mut response_buf)
638            .await?;
639
640        match response {
641            Response::WriteSingleRegister(resp) if resp.address == address && resp.value == value => {
642                Ok(())
643            }
644            Response::WriteSingleRegister(_) => {
645                Err(ClientError::InvalidResponse("write single register echo mismatch"))
646            }
647            _ => Err(ClientError::InvalidResponse("unexpected function response")),
648        }
649    }
650
651    pub async fn mask_write_register(
652        &self,
653        unit_id: u8,
654        address: u16,
655        and_mask: u16,
656        or_mask: u16,
657    ) -> Result<(), ClientError> {
658        let request = Request::MaskWriteRegister(MaskWriteRegisterRequest {
659            address,
660            and_mask,
661            or_mask,
662        });
663
664        let mut response_buf = [0u8; 260];
665        let response = self
666            .send_request(unit_id, &request, &mut response_buf)
667            .await?;
668
669        match response {
670            Response::MaskWriteRegister(resp)
671                if resp.address == address && resp.and_mask == and_mask && resp.or_mask == or_mask =>
672            {
673                Ok(())
674            }
675            Response::MaskWriteRegister(_) => {
676                Err(ClientError::InvalidResponse("mask write register echo mismatch"))
677            }
678            _ => Err(ClientError::InvalidResponse("unexpected function response")),
679        }
680    }
681
682    pub async fn write_multiple_coils(
683        &self,
684        unit_id: u8,
685        start: u16,
686        values: &[bool],
687    ) -> Result<(), ClientError> {
688        let request_variant = WriteMultipleCoilsRequest {
689            start_address: start,
690            values,
691        };
692        let expected_qty = request_variant.quantity()?;
693
694        let request = Request::WriteMultipleCoils(request_variant);
695        let mut response_buf = [0u8; 260];
696        let response = self
697            .send_request(unit_id, &request, &mut response_buf)
698            .await?;
699
700        match response {
701            Response::WriteMultipleCoils(resp)
702                if resp.start_address == start && resp.quantity == expected_qty =>
703            {
704                Ok(())
705            }
706            Response::WriteMultipleCoils(_) => {
707                Err(ClientError::InvalidResponse("write multiple coils echo mismatch"))
708            }
709            _ => Err(ClientError::InvalidResponse("unexpected function response")),
710        }
711    }
712
713    pub async fn write_multiple_registers(
714        &self,
715        unit_id: u8,
716        start: u16,
717        values: &[u16],
718    ) -> Result<(), ClientError> {
719        let request_variant = WriteMultipleRegistersRequest {
720            start_address: start,
721            values,
722        };
723        let expected_qty = request_variant.quantity()?;
724
725        let request = Request::WriteMultipleRegisters(request_variant);
726        let mut response_buf = [0u8; 260];
727        let response = self
728            .send_request(unit_id, &request, &mut response_buf)
729            .await?;
730
731        match response {
732            Response::WriteMultipleRegisters(resp)
733                if resp.start_address == start && resp.quantity == expected_qty =>
734            {
735                Ok(())
736            }
737            Response::WriteMultipleRegisters(_) => {
738                Err(ClientError::InvalidResponse(
739                    "write multiple registers echo mismatch",
740                ))
741            }
742            _ => Err(ClientError::InvalidResponse("unexpected function response")),
743        }
744    }
745
746    pub async fn read_write_multiple_registers(
747        &self,
748        unit_id: u8,
749        read_start: u16,
750        read_quantity: u16,
751        write_start: u16,
752        write_values: &[u16],
753    ) -> Result<Vec<u16>, ClientError> {
754        let request = Request::ReadWriteMultipleRegisters(ReadWriteMultipleRegistersRequest {
755            read_start_address: read_start,
756            read_quantity,
757            write_start_address: write_start,
758            values: write_values,
759        });
760
761        let mut response_buf = [0u8; 260];
762        let response = self
763            .send_request(unit_id, &request, &mut response_buf)
764            .await?;
765
766        match response {
767            Response::ReadWriteMultipleRegisters(data) => {
768                let count = usize::from(read_quantity);
769                if data.register_count() < count {
770                    return Err(ClientError::InvalidResponse(
771                        "read-write register payload shorter than requested",
772                    ));
773                }
774                Ok((0..count).filter_map(|idx| data.register(idx)).collect())
775            }
776            _ => Err(ClientError::InvalidResponse("unexpected function response")),
777        }
778    }
779}
780
781#[cfg(test)]
782mod tests {
783    use super::{ClientConfig, ClientError, ModbusClient, RetryPolicy};
784    use async_trait::async_trait;
785    use rustmod_datalink::{DataLink, DataLinkError};
786    use std::collections::VecDeque;
787    use std::sync::Arc;
788    use std::sync::atomic::{AtomicUsize, Ordering};
789    use std::time::Duration;
790    use tokio::sync::Mutex;
791    use tokio::time::sleep;
792
793    type MockQueue = VecDeque<Result<Vec<u8>, DataLinkError>>;
794
795    #[derive(Clone, Default)]
796    struct MockLink {
797        responses: Arc<Mutex<MockQueue>>,
798        calls: Arc<AtomicUsize>,
799    }
800
801    impl MockLink {
802        fn with_responses(responses: Vec<Result<Vec<u8>, DataLinkError>>) -> Self {
803            Self {
804                responses: Arc::new(Mutex::new(responses.into())),
805                calls: Arc::new(AtomicUsize::new(0)),
806            }
807        }
808
809        fn call_count(&self) -> usize {
810            self.calls.load(Ordering::Relaxed)
811        }
812    }
813
814    #[async_trait]
815    impl DataLink for MockLink {
816        async fn exchange(
817            &self,
818            _unit_id: u8,
819            _request_pdu: &[u8],
820            response_pdu: &mut [u8],
821        ) -> Result<usize, DataLinkError> {
822            self.calls.fetch_add(1, Ordering::Relaxed);
823            let mut guard = self.responses.lock().await;
824            let next = guard
825                .pop_front()
826                .ok_or(DataLinkError::InvalidResponse("no mock response"))?;
827            let bytes = next?;
828            if bytes.len() > response_pdu.len() {
829                return Err(DataLinkError::ResponseBufferTooSmall {
830                    needed: bytes.len(),
831                    available: response_pdu.len(),
832                });
833            }
834            response_pdu[..bytes.len()].copy_from_slice(&bytes);
835            Ok(bytes.len())
836        }
837    }
838
839    #[derive(Clone, Default)]
840    struct ConnectionClosedThenSlowLink {
841        calls: Arc<AtomicUsize>,
842    }
843
844    impl ConnectionClosedThenSlowLink {
845        fn call_count(&self) -> usize {
846            self.calls.load(Ordering::Relaxed)
847        }
848    }
849
850    #[async_trait]
851    impl DataLink for ConnectionClosedThenSlowLink {
852        async fn exchange(
853            &self,
854            _unit_id: u8,
855            _request_pdu: &[u8],
856            response_pdu: &mut [u8],
857        ) -> Result<usize, DataLinkError> {
858            let call = self.calls.fetch_add(1, Ordering::Relaxed);
859            if call == 0 {
860                return Err(DataLinkError::ConnectionClosed);
861            }
862
863            sleep(Duration::from_millis(50)).await;
864            response_pdu[..4].copy_from_slice(&[0x03, 0x02, 0x00, 0x2A]);
865            Ok(4)
866        }
867    }
868
869    #[tokio::test]
870    async fn read_holding_registers_success() {
871        let link = MockLink::with_responses(vec![Ok(vec![
872            0x03, 0x04, 0x12, 0x34, 0xAB, 0xCD,
873        ])]);
874        let client = ModbusClient::new(link);
875
876        let values = client.read_holding_registers(1, 0, 2).await.unwrap();
877        assert_eq!(values, vec![0x1234, 0xABCD]);
878    }
879
880    #[tokio::test]
881    async fn exception_is_mapped() {
882        let link = MockLink::with_responses(vec![Ok(vec![0x83, 0x02])]);
883        let client = ModbusClient::new(link);
884
885        let err = client.read_holding_registers(1, 0, 1).await.unwrap_err();
886        assert!(matches!(err, ClientError::Exception(_)));
887    }
888
889    #[tokio::test]
890    async fn custom_request_roundtrip() {
891        let link = MockLink::with_responses(vec![Ok(vec![0x41, 0x12, 0x34])]);
892        let client = ModbusClient::new(link);
893
894        let payload = client.custom_request(1, 0x41, &[0xAA]).await.unwrap();
895        assert_eq!(payload, vec![0x12, 0x34]);
896    }
897
898    #[tokio::test]
899    async fn report_server_id_parses_payload() {
900        let link = MockLink::with_responses(vec![Ok(vec![0x11, 0x03, 0x2A, 0xFF, 0x10])]);
901        let client = ModbusClient::new(link);
902
903        let report = client.report_server_id(1).await.unwrap();
904        assert_eq!(report.server_id, 0x2A);
905        assert!(report.run_indicator_status);
906        assert_eq!(report.additional_data, vec![0x10]);
907    }
908
909    #[tokio::test]
910    async fn read_device_identification_parses_objects() {
911        let link = MockLink::with_responses(vec![Ok(vec![
912            0x2B, 0x0E, 0x01, 0x01, 0x00, 0x00, 0x02, 0x00, 0x07, b'r', b'u', b's', b't', b'-',
913            b'm', b'o', 0x01, 0x03, b'0', b'.', b'1',
914        ])]);
915        let client = ModbusClient::new(link);
916
917        let response = client.read_device_identification(1, 0x01, 0x00).await.unwrap();
918        assert_eq!(response.read_device_id_code, 0x01);
919        assert_eq!(response.conformity_level, 0x01);
920        assert!(!response.more_follows);
921        assert_eq!(response.next_object_id, 0x00);
922        assert_eq!(response.objects.len(), 2);
923        assert_eq!(response.objects[0].object_id, 0x00);
924        assert_eq!(response.objects[0].value, b"rust-mo".to_vec());
925        assert_eq!(response.objects[1].object_id, 0x01);
926        assert_eq!(response.objects[1].value, b"0.1".to_vec());
927    }
928
929    #[tokio::test]
930    async fn read_device_identification_rejects_wrong_mei_type() {
931        let link = MockLink::with_responses(vec![Ok(vec![
932            0x2B, 0x0D, 0x01, 0x01, 0x00, 0x00, 0x00,
933        ])]);
934        let client = ModbusClient::new(link);
935
936        let err = client
937            .read_device_identification(1, 0x01, 0x00)
938            .await
939            .unwrap_err();
940        assert!(matches!(
941            err,
942            ClientError::InvalidResponse("read device identification MEI type mismatch")
943        ));
944    }
945
946    #[tokio::test]
947    async fn retries_after_connection_closed() {
948        let link = MockLink::with_responses(vec![
949            Err(DataLinkError::ConnectionClosed),
950            Ok(vec![0x03, 0x02, 0x00, 0x2A]),
951        ]);
952        let link_for_assert = link.clone();
953
954        let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
955
956        let values = client.read_holding_registers(1, 0, 1).await.unwrap();
957        assert_eq!(values, vec![42]);
958        assert_eq!(link_for_assert.call_count(), 2);
959    }
960
961    #[tokio::test]
962    async fn write_is_not_retried_by_default() {
963        let link = MockLink::with_responses(vec![
964            Err(DataLinkError::ConnectionClosed),
965            Ok(vec![0x06, 0x00, 0x01, 0x00, 0x2A]),
966        ]);
967        let link_for_assert = link.clone();
968
969        let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
970        let err = client.write_single_register(1, 1, 42).await.unwrap_err();
971
972        assert!(matches!(
973            err,
974            ClientError::DataLink(DataLinkError::ConnectionClosed)
975        ));
976        assert_eq!(link_for_assert.call_count(), 1);
977    }
978
979    #[tokio::test]
980    async fn response_buffer_too_small_is_not_retried() {
981        let link = MockLink::with_responses(vec![
982            Err(DataLinkError::ResponseBufferTooSmall {
983                needed: 300,
984                available: 260,
985            }),
986            Ok(vec![0x03, 0x02, 0x00, 0x2A]),
987        ]);
988        let link_for_assert = link.clone();
989
990        let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
991        let err = client.read_holding_registers(1, 0, 1).await.unwrap_err();
992
993        assert!(matches!(
994            err,
995            ClientError::DataLink(DataLinkError::ResponseBufferTooSmall { .. })
996        ));
997        assert_eq!(link_for_assert.call_count(), 1);
998    }
999
1000    #[tokio::test]
1001    async fn write_can_retry_when_policy_is_all() {
1002        let link = MockLink::with_responses(vec![
1003            Err(DataLinkError::ConnectionClosed),
1004            Ok(vec![0x06, 0x00, 0x01, 0x00, 0x2A]),
1005        ]);
1006        let link_for_assert = link.clone();
1007
1008        let config = ClientConfig::default()
1009            .with_retry_count(1)
1010            .with_retry_policy(RetryPolicy::All);
1011        let client = ModbusClient::with_config(link, config);
1012        client.write_single_register(1, 1, 42).await.unwrap();
1013
1014        assert_eq!(link_for_assert.call_count(), 2);
1015    }
1016
1017    #[tokio::test]
1018    async fn final_timeout_is_reported_over_previous_transport_error() {
1019        let link = ConnectionClosedThenSlowLink::default();
1020        let link_for_assert = link.clone();
1021
1022        let config = ClientConfig::default()
1023            .with_retry_count(1)
1024            .with_response_timeout(Duration::from_millis(10));
1025        let client = ModbusClient::with_config(link, config);
1026
1027        let err = client.read_holding_registers(1, 0, 1).await.unwrap_err();
1028        assert!(matches!(err, ClientError::Timeout));
1029        assert_eq!(link_for_assert.call_count(), 2);
1030    }
1031
1032    #[tokio::test]
1033    async fn mask_write_register_success() {
1034        let link = MockLink::with_responses(vec![Ok(vec![0x16, 0x00, 0x04, 0xFF, 0x00, 0x00, 0x12])]);
1035        let client = ModbusClient::new(link);
1036        client
1037            .mask_write_register(1, 0x0004, 0xFF00, 0x0012)
1038            .await
1039            .unwrap();
1040    }
1041
1042    #[tokio::test]
1043    async fn read_write_multiple_registers_success() {
1044        let link = MockLink::with_responses(vec![Ok(vec![0x17, 0x04, 0x12, 0x34, 0xAB, 0xCD])]);
1045        let client = ModbusClient::new(link);
1046
1047        let values = client
1048            .read_write_multiple_registers(1, 0x0010, 2, 0x0020, &[0x0102, 0x0304])
1049            .await
1050            .unwrap();
1051        assert_eq!(values, vec![0x1234, 0xABCD]);
1052    }
1053
1054    #[tokio::test]
1055    async fn read_coils_rejects_truncated_payload() {
1056        let link = MockLink::with_responses(vec![Ok(vec![0x01, 0x01, 0b0000_1111])]);
1057        let client = ModbusClient::new(link);
1058        let err = client.read_coils(1, 0, 9).await.unwrap_err();
1059        assert!(matches!(
1060            err,
1061            ClientError::InvalidResponse("coil payload shorter than requested")
1062        ));
1063    }
1064
1065    #[tokio::test]
1066    async fn read_discrete_inputs_rejects_truncated_payload() {
1067        let link = MockLink::with_responses(vec![Ok(vec![0x02, 0x01, 0b0000_1111])]);
1068        let client = ModbusClient::new(link);
1069        let err = client.read_discrete_inputs(1, 0, 9).await.unwrap_err();
1070        assert!(matches!(
1071            err,
1072            ClientError::InvalidResponse("discrete input payload shorter than requested")
1073        ));
1074    }
1075
1076    #[cfg(feature = "metrics")]
1077    #[tokio::test]
1078    async fn metrics_count_success() {
1079        let link = MockLink::with_responses(vec![Ok(vec![0x03, 0x02, 0x00, 0x2A])]);
1080        let client = ModbusClient::new(link);
1081
1082        let _ = client.read_holding_registers(1, 0, 1).await.unwrap();
1083        let metrics = client.metrics_snapshot();
1084
1085        assert_eq!(metrics.requests_total, 1);
1086        assert_eq!(metrics.successful_responses, 1);
1087        assert_eq!(metrics.exceptions_total, 0);
1088    }
1089}