1#![doc = include_str!("../README.md")]
18
19use async_trait::async_trait;
20use std::fmt::Debug;
21use std::io::Error;
22use std::net::SocketAddr;
23use tokio_modbus::client::{Client, Context};
24use tokio_modbus::slave::{Slave, SlaveContext};
25use tokio_modbus::{Request, Response};
26
/// Source of new client connections for a [`RobustClient`].
///
/// A [`RobustClient`] calls [`Connector::connect`] whenever it needs a
/// connection: before the first request, and again when a request on an
/// existing connection fails.
#[async_trait]
pub trait Connector: Send + Debug {
    /// Client type produced by [`Connector::connect`].
    type Output: Client;

    /// Opens a new connection addressed to `slave`.
    ///
    /// # Errors
    ///
    /// Implementations return the I/O error that prevented the connection
    /// from being opened.
    async fn connect(&mut self, slave: Slave) -> Result<Self::Output, Error>;
}
36
/// A [`Connector`] backed by a synchronous factory function.
///
/// Each [`Connector::connect`] call invokes `factory` with the requested slave
/// and returns its result unchanged.
pub struct SyncConnector<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> {
    /// Opens a new client for the slave it is given.
    factory: F,
}
44
45impl<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> SyncConnector<T, F> {
46 pub fn new(factory: F) -> Self {
48 Self { factory }
49 }
50}
51
52impl<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> Debug for SyncConnector<T, F> {
53 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
54 write!(fmt, "SyncConnector()")
55 }
56}
57
58#[async_trait]
59impl<T: Client, F: FnMut(Slave) -> Result<T, Error> + Send + Sync> Connector
60 for SyncConnector<T, F>
61{
62 type Output = T;
63
64 async fn connect(&mut self, slave: Slave) -> Result<T, Error> {
65 (self.factory)(slave)
66 }
67}
68
/// A [`Connector`] that opens Modbus TCP connections to a fixed address.
#[derive(Debug)]
pub struct TcpSlaveConnector {
    /// Address every new connection is opened to.
    socket_addr: SocketAddr,
}
74
75impl TcpSlaveConnector {
76 pub fn new(socket_addr: SocketAddr) -> Self {
78 Self { socket_addr }
79 }
80}
81
82#[async_trait]
83impl Connector for TcpSlaveConnector {
84 type Output = Context;
85
86 async fn connect(&mut self, slave: Slave) -> Result<Context, Error> {
87 tokio_modbus::client::tcp::connect_slave(self.socket_addr, slave).await
88 }
89}
90
/// A Modbus client that (re)connects on demand through a [`Connector`].
///
/// No connection is opened until the first request. When a request on an
/// existing connection fails, a new connection is opened and the request is
/// retried once.
#[derive(Debug)]
pub struct RobustClient<C: Connector> {
    /// Opens new connections.
    connector: C,
    /// The cached connection, or `None` when none is currently held.
    client: Option<C::Output>,
    /// Slave passed to the connector for new connections; kept in sync with
    /// `set_slave`.
    slave: Slave,
}
98
99impl<C: Connector> RobustClient<C> {
100 pub fn new(connector: C, slave: Slave) -> Self {
105 Self {
106 connector,
107 client: None,
108 slave,
109 }
110 }
111}
112
113impl<C: Connector + 'static> RobustClient<C> {
114 pub fn new_context(connector: C, slave: Slave) -> Context {
121 (Box::new(Self::new(connector, slave)) as Box<dyn Client>).into()
122 }
123}
124
125pub fn new_sync<
131 T: Client + 'static,
132 F: FnMut(Slave) -> Result<T, Error> + Send + Sync + 'static,
133>(
134 factory: F,
135 slave: Slave,
136) -> Context {
137 RobustClient::new_context(SyncConnector::new(factory), slave)
138}
139
140pub fn new_tcp_slave(socket_addr: SocketAddr, slave: Slave) -> Context {
145 RobustClient::new_context(TcpSlaveConnector::new(socket_addr), slave)
146}
147
148pub fn new_rtu_slave(device: impl Into<String>, baud_rate: u32, slave: Slave) -> Context {
157 let device = device.into();
158 new_sync(
159 move |slave| -> Result<Context, Error> {
160 let serial_builder = tokio_serial::new(&device, baud_rate);
161 let serial_stream = tokio_serial::SerialStream::open(&serial_builder)?;
162 Ok(tokio_modbus::client::rtu::attach_slave(
163 serial_stream,
164 slave,
165 ))
166 },
167 slave,
168 )
169}
170
171impl<C: Connector> SlaveContext for RobustClient<C> {
172 fn set_slave(&mut self, slave: Slave) {
173 self.slave = slave;
174 if let Some(ref mut client) = self.client {
175 client.set_slave(slave)
176 }
177 }
178}
179
180#[async_trait]
181impl<C: Connector> Client for RobustClient<C> {
182 async fn call(&mut self, req: Request<'_>) -> tokio_modbus::Result<Response> {
183 let (client, fresh) = match self.client {
184 None => {
185 let c = self
186 .connector
187 .connect(self.slave)
188 .await
189 .map_err(tokio_modbus::Error::Transport)?;
190 (self.client.insert(c), true)
191 }
192 Some(ref mut c) => (c, false),
193 };
194 match client.call(req.clone()).await {
195 result if fresh => result, Ok(response) => Ok(response),
197 Err(_) => {
198 let c = self.connector.connect(self.slave).await?;
203 self.client.insert(c).call(req).await
204 }
205 }
206 }
207
208 async fn disconnect(&mut self) -> Result<(), Error> {
209 if let Some(c) = &mut self.client {
210 c.disconnect().await?;
211 }
212 self.client.take(); Ok(())
214 }
215}
216
217#[cfg(test)]
218mod test {
219 use super::*;
220 use std::sync::{Arc, Mutex};
221 use tokio_modbus::prelude::*;
222
223 trait DummyState: Send + Debug {
224 fn connect(&mut self, slave: Slave) -> Result<(), Error>;
225 fn call(&mut self, req: Request) -> tokio_modbus::Result<Response>;
226 }
227
228 #[derive(Debug)]
229 struct IterDummyState<
230 I: Iterator<Item = tokio_modbus::Result<Response>> + Send + Debug,
231 J: Iterator<Item = Result<(), Error>> + Send + Debug,
232 > {
233 responses: I,
234 connects: J,
235 }
236
237 impl<
238 I: Iterator<Item = tokio_modbus::Result<Response>> + Send + Debug,
239 J: Iterator<Item = Result<(), Error>> + Send + Debug,
240 > IterDummyState<I, J>
241 {
242 fn new(responses: I, connects: J) -> Self {
243 Self {
244 responses,
245 connects,
246 }
247 }
248 }
249
250 impl<
251 I: Iterator<Item = tokio_modbus::Result<Response>> + Send + Debug,
252 J: Iterator<Item = Result<(), Error>> + Send + Debug,
253 > DummyState for IterDummyState<I, J>
254 {
255 fn connect(&mut self, _slave: Slave) -> Result<(), Error> {
256 self.connects.next().unwrap()
257 }
258
259 fn call(&mut self, _req: Request) -> tokio_modbus::Result<Response> {
260 self.responses.next().unwrap()
261 }
262 }
263
264 #[derive(Debug)]
265 struct DummyConnector<S: DummyState> {
266 state: Arc<Mutex<S>>,
267 }
268
269 #[derive(Debug)]
270 struct DummyClient<S: DummyState> {
271 state: Arc<Mutex<S>>,
272 }
273
274 impl<S: DummyState> DummyConnector<S> {
275 fn new(state: S) -> Self {
276 Self {
277 state: Arc::new(Mutex::new(state)),
278 }
279 }
280 }
281
282 #[async_trait]
283 impl<S: DummyState> Connector for DummyConnector<S> {
284 type Output = DummyClient<S>;
285
286 async fn connect(&mut self, slave: Slave) -> Result<DummyClient<S>, Error> {
287 let mut state = self.state.lock().unwrap();
288 state.connect(slave).map(|_| DummyClient {
289 state: self.state.clone(),
290 })
291 }
292 }
293
294 impl<S: DummyState> SlaveContext for DummyClient<S> {
295 fn set_slave(&mut self, _slave: Slave) {}
296 }
297
298 #[async_trait]
299 impl<S: DummyState> Client for DummyClient<S> {
300 async fn call(&mut self, req: Request<'_>) -> tokio_modbus::Result<Response> {
301 let mut state = self.state.lock().unwrap();
302 state.call(req)
303 }
304 async fn disconnect(&mut self) -> Result<(), Error> {
305 Ok(())
306 }
307 }
308
309 fn make_client_always_connect(responses: Vec<tokio_modbus::Result<Response>>) -> Context {
310 let state = IterDummyState::new(responses.into_iter(), std::iter::repeat_with(|| Ok(())));
311 RobustClient::new_context(DummyConnector::new(state), Slave(1))
312 }
313
314 fn make_client(
315 responses: Vec<tokio_modbus::Result<Response>>,
316 connects: Vec<Result<(), Error>>,
317 ) -> Context {
318 let state = IterDummyState::new(responses.into_iter(), connects.into_iter());
319 RobustClient::new_context(DummyConnector::new(state), Slave(1))
320 }
321
322 #[tokio::test]
323 async fn test_success() {
324 let responses = vec![Ok(Ok(Response::ReadHoldingRegisters(vec![123])))];
325 let mut client = make_client_always_connect(responses);
326 let result = client
327 .read_holding_registers(321, 1)
328 .await
329 .unwrap()
330 .unwrap();
331 assert_eq!(result, vec![123]);
332 }
333
334 #[tokio::test]
335 async fn test_call_failure() {
336 let responses = vec![
337 Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
338 Err(tokio_modbus::Error::Transport(Error::from(
339 std::io::ErrorKind::ConnectionReset,
340 ))),
341 Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
342 ];
343 let mut client = make_client_always_connect(responses);
344 let _ = client.read_holding_registers(321, 1).await; let result = client
346 .read_holding_registers(321, 1)
347 .await
348 .unwrap()
349 .unwrap();
350 assert_eq!(result, vec![123]);
351 }
352
353 #[tokio::test]
354 async fn test_call_double_failure() {
355 let responses = vec![
356 Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
357 Err(tokio_modbus::Error::Transport(Error::from(
358 std::io::ErrorKind::ConnectionReset,
359 ))),
360 Err(tokio_modbus::Error::Transport(Error::from(
361 std::io::ErrorKind::PermissionDenied,
362 ))),
363 ];
364 let mut client = make_client_always_connect(responses);
365 let _ = client.read_holding_registers(321, 1).await; match client.read_holding_registers(321, 1).await.unwrap_err() {
367 tokio_modbus::Error::Transport(err) => {
368 assert_eq!(err.kind(), std::io::ErrorKind::PermissionDenied);
369 }
370 _ => {
371 panic!("Wrong error type");
372 }
373 }
374 }
375
376 #[tokio::test]
377 async fn test_connect_failure() {
378 let responses = vec![];
379 let connects = vec![Err(Error::from(std::io::ErrorKind::ConnectionRefused))];
380 let mut client = make_client(responses, connects);
381 match client.read_holding_registers(321, 1).await.unwrap_err() {
382 tokio_modbus::Error::Transport(err) => {
383 assert_eq!(err.kind(), std::io::ErrorKind::ConnectionRefused);
384 }
385 _ => {
386 panic!("Wrong error type");
387 }
388 }
389 }
390
391 #[tokio::test]
392 async fn test_connect_failure2() {
393 let responses = vec![
394 Ok(Ok(Response::ReadHoldingRegisters(vec![123]))),
395 Err(tokio_modbus::Error::Transport(Error::from(
396 std::io::ErrorKind::ConnectionReset,
397 ))),
398 ];
399 let connects = vec![
400 Ok(()),
401 Err(Error::from(std::io::ErrorKind::ConnectionRefused)),
402 ];
403 let mut client = make_client(responses, connects);
404 let _ = client.read_holding_registers(321, 1).await; match client.read_holding_registers(321, 1).await.unwrap_err() {
406 tokio_modbus::Error::Transport(err) => {
407 assert_eq!(err.kind(), std::io::ErrorKind::ConnectionRefused);
408 }
409 _ => {
410 panic!("Wrong error type");
411 }
412 }
413 }
414}