Skip to main content

fennec_modbus/tcp/
tokio.rs

1//! Modbus-over-TCP implementation for [`tokio`].
2
3#![cfg(feature = "tokio")]
4
5use alloc::{vec, vec::Vec};
6use core::{fmt::Debug, time::Duration};
7
8use thiserror::Error;
9use tokio::{
10    io::{AsyncReadExt, AsyncWriteExt},
11    net::{TcpStream, ToSocketAddrs},
12    sync::{Mutex, MutexGuard},
13    time::timeout,
14};
15
16use crate::{
17    protocol::{Function, Request, Response, codec::Decode, function::IntoValue},
18    tcp,
19    tcp::{Header, transaction},
20};
21
22#[must_use]
23struct Connection<E> {
24    endpoint: E,
25    connect_timeout: Duration,
26    stream: Mutex<Option<TcpStream>>,
27}
28
29impl<E> Connection<E> {
30    /// Lazily establish a connection when needed and return the TCP stream.
31    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "debug"))]
32    async fn get(&self) -> Result<ConnectionGuard<'_>, Error>
33    where
34        E: Clone + ToSocketAddrs,
35    {
36        let mut guard = self.stream.lock().await;
37
38        if guard.is_none() {
39            #[cfg(feature = "tracing")]
40            tracing::debug!("connecting…");
41
42            let stream =
43                timeout(self.connect_timeout, TcpStream::connect(Clone::clone(&self.endpoint)))
44                    .await
45                    .map_err(Error::ConnectionTimeout)??;
46            stream.set_nodelay(true)?;
47            socket2::SockRef::from(&stream).set_keepalive(true)?;
48            *guard = Some(stream);
49        }
50
51        Ok(ConnectionGuard(guard))
52    }
53}
54
55struct ConnectionGuard<'a>(MutexGuard<'a, Option<TcpStream>>);
56
57impl ConnectionGuard<'_> {
58    fn get_mut(&mut self) -> &mut TcpStream {
59        self.0.as_mut().unwrap()
60    }
61
62    fn invalidate(mut self) {
63        *self.0 = None;
64    }
65}
66
67/// Modbus TCP client for [`tokio`].
68///
69/// # Example
70///
71/// ```rust,no_run
72/// use anyhow::Result;
73/// use fennec_modbus::{
74///     protocol::{address, function::ReadHoldingRegisters},
75///     tcp::{UnitId, tokio::Client},
76/// };
77///
78/// #[tokio::main]
79/// async fn main() -> Result<()> {
80///     let unit_id = UnitId::Significant(1);
81///     let client = Client::new("battery.iot.home.arpa:502");
82///     let decivolts = client.call::<ReadHoldingRegisters<_, u16>>(unit_id, 39201).await?;
83///     Ok(())
84/// }
85/// ```
86///
87/// # Connection management
88///
89/// The underlying connection is managed automatically:
90///
91/// - An initial connection is established on first use.
92/// - The connection is dropped on any error, except for response decoding errors – in that case, the connection itself stays healthy.
93/// - Connection is re-established upon next use, so it is safe to retry operations via, for example, `backon`.
94/// - It is safe to wrap the client in [`alloc::sync::Arc`] and clone it.
95///
96/// # Pipelining
97///
98/// - The pipelining is currently *not supported*. The underlying connection stays locked for the entire transaction.
99/// - Mismatching transaction responses are *dropped*.
100#[must_use]
101pub struct Client<E> {
102    encoder: transaction::Encoder,
103    connection: Connection<E>,
104    round_trip_timeout: Duration,
105}
106
107impl<E> Client<E> {
108    pub fn new(endpoint: E) -> Self {
109        Self {
110            encoder: transaction::Encoder::default(),
111            connection: Connection {
112                endpoint,
113                connect_timeout: Duration::from_secs(5),
114                stream: Mutex::new(None),
115            },
116            round_trip_timeout: Duration::from_secs(1),
117        }
118    }
119
120    pub const fn with_connect_timeout(mut self, duration: Duration) -> Self {
121        self.connection.connect_timeout = duration;
122        self
123    }
124
125    pub const fn with_round_trip_timeout(mut self, duration: Duration) -> Self {
126        self.round_trip_timeout = duration;
127        self
128    }
129}
130
131impl<E> Client<E>
132where
133    E: Clone + ToSocketAddrs,
134{
135    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "trace"))]
136    pub async fn call<F: Function>(
137        &self,
138        unit_id: tcp::UnitId,
139        args: impl Into<F::Args>,
140    ) -> Result<<F::Output as IntoValue>::Value, Error> {
141        #[cfg(feature = "tracing")]
142        tracing::debug!(?unit_id, code = ?F::CODE, "calling function…");
143
144        let mut frame = Vec::new();
145        let transaction_id =
146            self.encoder.encode(unit_id, &Request::wrap::<F>(args.into()), &mut frame)?;
147
148        let mut connection = self.connection.get().await?;
149
150        let future = async {
151            #[cfg(feature = "tracing")]
152            tracing::trace!(transaction_id, len = frame.len(), "writing frame…");
153            connection.get_mut().write_all(&frame).await?;
154
155            let header = loop {
156                #[cfg(feature = "tracing")]
157                tracing::trace!(transaction_id, "awaiting header…");
158
159                let header = {
160                    let mut header_bytes = [0; tcp::Header::N_BYTES];
161                    connection.get_mut().read_exact(&mut header_bytes).await?;
162                    Header::decode(&mut header_bytes.as_slice())?
163                };
164
165                #[cfg(feature = "tracing")]
166                tracing::trace!(transaction_id = header.transaction_id, "received header");
167
168                if header.transaction_id == transaction_id {
169                    break header;
170                }
171
172                #[cfg(feature = "tracing")]
173                tracing::warn!(header.transaction_id, "discarding response");
174
175                let mut discarded_bytes = vec![0; header.payload_length().into()];
176                connection.get_mut().read_exact(&mut discarded_bytes).await?;
177            };
178
179            let mut payload_bytes = vec![0; header.payload_length().into()];
180
181            #[cfg(feature = "tracing")]
182            tracing::trace!(len = header.payload_length(), "reading payload…");
183
184            connection.get_mut().read_exact(&mut payload_bytes).await?;
185
186            Ok::<_, Error>(payload_bytes)
187        };
188
189        let payload_bytes = timeout(self.round_trip_timeout, future)
190            .await
191            .map_err(Error::TransactionTimeout)
192            .flatten()
193            .inspect_err(|error| {
194                #[cfg(feature = "tracing")]
195                tracing::debug!("invalidating connection because of error: {error:#}");
196
197                connection.invalidate();
198            })?;
199        Ok(Response::<F>::decode(&mut payload_bytes.as_slice())?.into_result()?.into_value())
200    }
201}
202
203impl<E> Client<E> {
204    /// Disconnect the client.
205    ///
206    /// Subsequent call will re-establish a connection.
207    /// Note that the client normally disconnects automatically on error.
208    ///
209    /// This operation is idempotent, closing a closed connection is a no-op.
210    pub async fn disconnect(&self) {
211        *self.connection.stream.lock().await = None;
212    }
213}
214
215#[derive(Debug, Error)]
216pub enum Error {
217    #[error("protocol error")]
218    Protocol(#[from] crate::Error),
219
220    #[error("I/O error")]
221    Io(#[from] tokio::io::Error),
222
223    #[error("timed out connecting")]
224    ConnectionTimeout(tokio::time::error::Elapsed),
225
226    #[error("transaction timeout")]
227    TransactionTimeout(tokio::time::error::Elapsed),
228}