modbus_robust/
lib.rs

1/* Copyright 2022-2024 Bruce Merry
2 *
3 * This program is free software: you can redistribute it and/or modify it
4 * under the terms of the GNU General Public License as published by the Free
5 * Software Foundation, either version 3 of the License, or (at your option)
6 * any later version.
7 *
8 * This program is distributed in the hope that it will be useful, but WITHOUT
9 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
11 * more details.
12 *
13 * You should have received a copy of the GNU General Public License along
14 * with this program. If not, see <https://www.gnu.org/licenses/>.
15 */
16
17#![doc = include_str!("../README.md")]
18
19use async_trait::async_trait;
20use std::fmt::Debug;
21use std::io::Error;
22use std::net::SocketAddr;
23use tokio_modbus::client::{Client, Context};
24use tokio_modbus::slave::{Slave, SlaveContext};
25use tokio_modbus::{Request, Response};
26
27/// Establish a connection. The implementation must support calling
28/// [`Connector::connect`] multiple times.
29#[async_trait]
30pub trait Connector: Send + Debug {
31    type Output: Client;
32
33    /// Establish a connection.
34    async fn connect(&mut self, slave: Slave) -> Result<Self::Output, Error>;
35}
36
37/// Establish a connection using a factory function.
38///
39/// In practice, the function needs to be `'static` to be able to use this to
40/// obtain a [`tokio_modbus::client::Context`].
41pub struct SyncConnector<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> {
42    factory: F,
43}
44
45impl<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> SyncConnector<T, F> {
46    /// Create from a factory function.
47    pub fn new(factory: F) -> Self {
48        Self { factory }
49    }
50}
51
52impl<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> Debug for SyncConnector<T, F> {
53    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
54        write!(fmt, "SyncConnector()")
55    }
56}
57
58#[async_trait]
59impl<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> Connector
60    for SyncConnector<T, F>
61{
62    type Output = T;
63
64    async fn connect(&mut self, slave: Slave) -> Result<T, Error> {
65        (self.factory)(slave)
66    }
67}
68
69/// Implementation of [`Connector`] for TCP connections.
70#[derive(Debug)]
71pub struct TcpSlaveConnector {
72    socket_addr: SocketAddr,
73}
74
75impl TcpSlaveConnector {
76    /// Construct.
77    pub fn new(socket_addr: SocketAddr) -> Self {
78        Self { socket_addr }
79    }
80}
81
82#[async_trait]
83impl Connector for TcpSlaveConnector {
84    type Output = Context;
85
86    async fn connect(&mut self, slave: Slave) -> Result<Context, Error> {
87        tokio_modbus::client::tcp::connect_slave(self.socket_addr, slave).await
88    }
89}
90
91/// Client that automatically reconnects and retries on failure.
92#[derive(Debug)]
93pub struct RobustClient<C: Connector> {
94    connector: C,
95    client: Option<C::Output>,
96    slave: Slave,
97}
98
99impl<C: Connector> RobustClient<C> {
100    /// Construct a robust client.
101    ///
102    /// When constructed, there is no established connection. An attempt will
103    /// be made to establish one on the first call.
104    pub fn new(connector: C, slave: Slave) -> Self {
105        Self {
106            connector,
107            client: None,
108            slave,
109        }
110    }
111}
112
113impl<C: Connector + 'static> RobustClient<C> {
114    /// Construct a robust client wrapped in a
115    /// [`tokio_modbus::client::Context`].
116    ///
117    /// This is the constructor you will most likely want to use, because
118    /// [Context][`tokio_modbus::client::Context`] provides all the
119    /// convenience functions.
120    pub fn new_context(connector: C, slave: Slave) -> Context {
121        (Box::new(Self::new(connector, slave)) as Box<dyn Client>).into()
122    }
123}
124
125/// Construct a [`tokio_modbus::client::Context`] from a connection factory
126/// function.
127///
128/// The connection is not immediately established. It will be attempted on
129/// the first call.
130pub fn new_sync<
131    T: Client + 'static,
132    F: FnMut(Slave) -> Result<T, Error> + Send + Sync + 'static,
133>(
134    factory: F,
135    slave: Slave,
136) -> Context {
137    RobustClient::new_context(SyncConnector::new(factory), slave)
138}
139
140/// Construct a [`tokio_modbus::client::Context`] for a TCP connection.
141///
142/// The connection is not immediately established. It will be attempted on
143/// the first call.
144pub fn new_tcp_slave(socket_addr: SocketAddr, slave: Slave) -> Context {
145    RobustClient::new_context(TcpSlaveConnector::new(socket_addr), slave)
146}
147
148/// Construct a [`tokio_modbus::client::Context`] for an RTU connection.
149///
150/// The connection is not immediately established. It will be attempted on
151/// the first call.
152///
153/// This implementation only allows the baud rate to be set, and not other
154/// options such as the parity bits. If more control is needed, use
155/// [`new_sync`].
156pub fn new_rtu_slave(device: impl Into<String>, baud_rate: u32, slave: Slave) -> Context {
157    let device = device.into();
158    new_sync(
159        move |slave| -> Result<Context, Error> {
160            let serial_builder = tokio_serial::new(&device, baud_rate);
161            let serial_stream = tokio_serial::SerialStream::open(&serial_builder)?;
162            Ok(tokio_modbus::client::rtu::attach_slave(
163                serial_stream,
164                slave,
165            ))
166        },
167        slave,
168    )
169}
170
171impl<C: Connector> SlaveContext for RobustClient<C> {
172    fn set_slave(&mut self, slave: Slave) {
173        self.slave = slave;
174        if let Some(ref mut client) = self.client {
175            client.set_slave(slave)
176        }
177    }
178}
179
180#[async_trait]
181impl<C: Connector> Client for RobustClient<C> {
182    async fn call(&mut self, req: Request<'_>) -> tokio_modbus::Result<Response> {
183        let (client, fresh) = match self.client {
184            None => {
185                let c = self
186                    .connector
187                    .connect(self.slave)
188                    .await
189                    .map_err(tokio_modbus::Error::Transport)?;
190                (self.client.insert(c), true)
191            }
192            Some(ref mut c) => (c, false),
193        };
194        match client.call(req.clone()).await {
195            result if fresh => result, // Don't retry if this is a brand new connection
196            Ok(response) => Ok(response),
197            Err(_) => {
198                /* Note: an inner ExceptionCode takes the Ok path above. It
199                 * indicates that the server rejected the request, so retrying
200                 * is unlikely to help.
201                 */
202                let c = self.connector.connect(self.slave).await?;
203                self.client.insert(c).call(req).await
204            }
205        }
206    }
207
208    async fn disconnect(&mut self) -> Result<(), Error> {
209        if let Some(c) = &mut self.client {
210            c.disconnect().await?;
211        }
212        self.client.take(); // Clears out the client
213        Ok(())
214    }
215}
216
#[cfg(test)]
mod test {
    // These tests drive `RobustClient` through a fake connector whose
    // connection and request outcomes are scripted in advance. That lets each
    // test choose exactly when connecting or calling fails.
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio_modbus::prelude::*;

    /// Scripted behaviour shared by the fake connector and the fake clients
    /// it hands out.
    trait DummyState: Send + Debug {
        /// Outcome of the next connection attempt.
        fn connect(&mut self, slave: Slave) -> Result<(), Error>;
        /// Outcome of the next request.
        fn call(&mut self, req: Request) -> tokio_modbus::Result<Response>;
    }

    /// [`DummyState`] that replays results from two iterators: `responses`
    /// for requests and `connects` for connection attempts.
    ///
    /// Running out of either iterator panics (via `unwrap`), which flags a
    /// test that made more calls than it scripted.
    #[derive(Debug)]
    struct IterDummyState<
        I: Iterator<Item = tokio_modbus::Result<Response>> + Send + Debug,
        J: Iterator<Item = Result<(), Error>> + Send + Debug,
    > {
        responses: I,
        connects: J,
    }

    impl<
            I: Iterator<Item = tokio_modbus::Result<Response>> + Send + Debug,
            J: Iterator<Item = Result<(), Error>> + Send + Debug,
        > IterDummyState<I, J>
    {
        fn new(responses: I, connects: J) -> Self {
            Self {
                responses,
                connects,
            }
        }
    }

    impl<
            I: Iterator<Item = tokio_modbus::Result<Response>> + Send + Debug,
            J: Iterator<Item = Result<(), Error>> + Send + Debug,
        > DummyState for IterDummyState<I, J>
    {
        fn connect(&mut self, _slave: Slave) -> Result<(), Error> {
            self.connects.next().unwrap()
        }

        fn call(&mut self, _req: Request) -> tokio_modbus::Result<Response> {
            self.responses.next().unwrap()
        }
    }

    /// Fake [`Connector`]. Each successful connect yields a [`DummyClient`]
    /// that shares the same scripted state.
    #[derive(Debug)]
    struct DummyConnector<S: DummyState> {
        state: Arc<Mutex<S>>,
    }

    /// Fake client that forwards every request to the shared scripted state.
    #[derive(Debug)]
    struct DummyClient<S: DummyState> {
        state: Arc<Mutex<S>>,
    }

    impl<S: DummyState> DummyConnector<S> {
        fn new(state: S) -> Self {
            Self {
                state: Arc::new(Mutex::new(state)),
            }
        }
    }

    #[async_trait]
    impl<S: DummyState> Connector for DummyConnector<S> {
        type Output = DummyClient<S>;

        async fn connect(&mut self, slave: Slave) -> Result<DummyClient<S>, Error> {
            let mut state = self.state.lock().unwrap();
            state.connect(slave).map(|_| DummyClient {
                state: self.state.clone(),
            })
        }
    }

    // The fake ignores slave changes.
    impl<S: DummyState> SlaveContext for DummyClient<S> {
        fn set_slave(&mut self, _slave: Slave) {}
    }

    #[async_trait]
    impl<S: DummyState> Client for DummyClient<S> {
        async fn call(&mut self, req: Request<'_>) -> tokio_modbus::Result<Response> {
            // The lock is not held across an await point.
            let mut state = self.state.lock().unwrap();
            state.call(req)
        }
        async fn disconnect(&mut self) -> Result<(), Error> {
            Ok(())
        }
    }

    /// Build a context where every connection attempt succeeds and requests
    /// return `responses` in order.
    fn make_client_always_connect(responses: Vec<tokio_modbus::Result<Response>>) -> Context {
        let state = IterDummyState::new(responses.into_iter(), std::iter::repeat_with(|| Ok(())));
        RobustClient::new_context(DummyConnector::new(state), Slave(1))
    }

    /// Build a context where both request results (`responses`) and
    /// connection outcomes (`connects`) are scripted, each consumed in order.
    fn make_client(
        responses: Vec<tokio_modbus::Result<Response>>,
        connects: Vec<Result<(), Error>>,
    ) -> Context {
        let state = IterDummyState::new(responses.into_iter(), connects.into_iter());
        RobustClient::new_context(DummyConnector::new(state), Slave(1))
    }

    // A successful first request returns the server's data.
    #[tokio::test]
    async fn test_success() {
        let responses = vec![Ok(Ok(Response::ReadHoldingRegisters(vec![123])))];
        let mut client = make_client_always_connect(responses);
        let result = client
            .read_holding_registers(321, 1)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result, vec![123]);
    }

    // A transport error on an already-established connection triggers a
    // reconnect, and the retried request succeeds.
    #[tokio::test]
    async fn test_call_failure() {
        let responses = vec![
            Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
            Err(tokio_modbus::Error::Transport(Error::from(
                std::io::ErrorKind::ConnectionReset,
            ))),
            Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
        ];
        let mut client = make_client_always_connect(responses);
        let _ = client.read_holding_registers(321, 1).await; // Establish connection
        let result = client
            .read_holding_registers(321, 1)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result, vec![123]);
    }

    // When the retry also fails, the retry's error (not the first one) is
    // returned.
    #[tokio::test]
    async fn test_call_double_failure() {
        let responses = vec![
            Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
            Err(tokio_modbus::Error::Transport(Error::from(
                std::io::ErrorKind::ConnectionReset,
            ))),
            Err(tokio_modbus::Error::Transport(Error::from(
                std::io::ErrorKind::PermissionDenied,
            ))),
        ];
        let mut client = make_client_always_connect(responses);
        let _ = client.read_holding_registers(321, 1).await; // Establish connection
        match client.read_holding_registers(321, 1).await.unwrap_err() {
            tokio_modbus::Error::Transport(err) => {
                assert_eq!(err.kind(), std::io::ErrorKind::PermissionDenied);
            }
            _ => {
                panic!("Wrong error type");
            }
        }
    }

    // A failure to establish the initial connection surfaces as a transport
    // error.
    #[tokio::test]
    async fn test_connect_failure() {
        let responses = vec![];
        let connects = vec![Err(Error::from(std::io::ErrorKind::ConnectionRefused))];
        let mut client = make_client(responses, connects);
        match client.read_holding_registers(321, 1).await.unwrap_err() {
            tokio_modbus::Error::Transport(err) => {
                assert_eq!(err.kind(), std::io::ErrorKind::ConnectionRefused);
            }
            _ => {
                panic!("Wrong error type");
            }
        }
    }

    // A failure to reconnect after a failed request surfaces as the
    // reconnect's error.
    #[tokio::test]
    async fn test_connect_failure2() {
        let responses = vec![
            Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
            Err(tokio_modbus::Error::Transport(Error::from(
                std::io::ErrorKind::ConnectionReset,
            ))),
        ];
        let connects = vec![
            Ok(()),
            Err(Error::from(std::io::ErrorKind::ConnectionRefused)),
        ];
        let mut client = make_client(responses, connects);
        let _ = client.read_holding_registers(321, 1).await; // Establish connection
        match client.read_holding_registers(321, 1).await.unwrap_err() {
            tokio_modbus::Error::Transport(err) => {
                assert_eq!(err.kind(), std::io::ErrorKind::ConnectionRefused);
            }
            _ => {
                panic!("Wrong error type");
            }
        }
    }
}