fennec_modbus/tcp/
tokio.rs1#![cfg(feature = "tokio")]
4
5use alloc::{vec, vec::Vec};
6use core::{fmt::Debug, time::Duration};
7
8use thiserror::Error;
9use tokio::{
10 io::{AsyncReadExt, AsyncWriteExt},
11 net::{TcpStream, ToSocketAddrs},
12 sync::{Mutex, MutexGuard},
13 time::timeout,
14};
15
16use crate::{
17 protocol::{Function, Request, Response, codec::Decode, function::IntoValue},
18 tcp,
19 tcp::{Header, transaction},
20};
21
22#[must_use]
23struct Connection<E> {
24 endpoint: E,
25 connect_timeout: Duration,
26 stream: Mutex<Option<TcpStream>>,
27}
28
29impl<E> Connection<E> {
30 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "debug"))]
32 async fn get(&self) -> Result<ConnectionGuard<'_>, Error>
33 where
34 E: Clone + ToSocketAddrs,
35 {
36 let mut guard = self.stream.lock().await;
37
38 if guard.is_none() {
39 #[cfg(feature = "tracing")]
40 tracing::debug!("connecting…");
41
42 let stream =
43 timeout(self.connect_timeout, TcpStream::connect(Clone::clone(&self.endpoint)))
44 .await
45 .map_err(Error::ConnectionTimeout)??;
46 stream.set_nodelay(true)?;
47 socket2::SockRef::from(&stream).set_keepalive(true)?;
48 *guard = Some(stream);
49 }
50
51 Ok(ConnectionGuard(guard))
52 }
53}
54
55struct ConnectionGuard<'a>(MutexGuard<'a, Option<TcpStream>>);
56
57impl ConnectionGuard<'_> {
58 fn get_mut(&mut self) -> &mut TcpStream {
59 self.0.as_mut().unwrap()
60 }
61
62 fn invalidate(mut self) {
63 *self.0 = None;
64 }
65}
66
67#[must_use]
101pub struct Client<E> {
102 encoder: transaction::Encoder,
103 connection: Connection<E>,
104 round_trip_timeout: Duration,
105}
106
107impl<E> Client<E> {
108 pub fn new(endpoint: E) -> Self {
109 Self {
110 encoder: transaction::Encoder::default(),
111 connection: Connection {
112 endpoint,
113 connect_timeout: Duration::from_secs(5),
114 stream: Mutex::new(None),
115 },
116 round_trip_timeout: Duration::from_secs(1),
117 }
118 }
119
120 pub const fn with_connect_timeout(mut self, duration: Duration) -> Self {
121 self.connection.connect_timeout = duration;
122 self
123 }
124
125 pub const fn with_round_trip_timeout(mut self, duration: Duration) -> Self {
126 self.round_trip_timeout = duration;
127 self
128 }
129}
130
131impl<E> Client<E>
132where
133 E: Clone + ToSocketAddrs,
134{
135 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "trace"))]
136 pub async fn call<F: Function>(
137 &self,
138 unit_id: tcp::UnitId,
139 args: impl Into<F::Args>,
140 ) -> Result<<F::Output as IntoValue>::Value, Error> {
141 #[cfg(feature = "tracing")]
142 tracing::debug!(?unit_id, code = ?F::CODE, "calling function…");
143
144 let mut frame = Vec::new();
145 let transaction_id =
146 self.encoder.encode(unit_id, &Request::wrap::<F>(args.into()), &mut frame)?;
147
148 let mut connection = self.connection.get().await?;
149
150 let future = async {
151 #[cfg(feature = "tracing")]
152 tracing::trace!(transaction_id, len = frame.len(), "writing frame…");
153 connection.get_mut().write_all(&frame).await?;
154
155 let header = loop {
156 #[cfg(feature = "tracing")]
157 tracing::trace!(transaction_id, "awaiting header…");
158
159 let header = {
160 let mut header_bytes = [0; tcp::Header::N_BYTES];
161 connection.get_mut().read_exact(&mut header_bytes).await?;
162 Header::decode(&mut header_bytes.as_slice())?
163 };
164
165 #[cfg(feature = "tracing")]
166 tracing::trace!(transaction_id = header.transaction_id, "received header");
167
168 if header.transaction_id == transaction_id {
169 break header;
170 }
171
172 #[cfg(feature = "tracing")]
173 tracing::warn!(header.transaction_id, "discarding response");
174
175 let mut discarded_bytes = vec![0; header.payload_length().into()];
176 connection.get_mut().read_exact(&mut discarded_bytes).await?;
177 };
178
179 let mut payload_bytes = vec![0; header.payload_length().into()];
180
181 #[cfg(feature = "tracing")]
182 tracing::trace!(len = header.payload_length(), "reading payload…");
183
184 connection.get_mut().read_exact(&mut payload_bytes).await?;
185
186 Ok::<_, Error>(payload_bytes)
187 };
188
189 let payload_bytes = timeout(self.round_trip_timeout, future)
190 .await
191 .map_err(Error::TransactionTimeout)
192 .flatten()
193 .inspect_err(|error| {
194 #[cfg(feature = "tracing")]
195 tracing::debug!("invalidating connection because of error: {error:#}");
196
197 connection.invalidate();
198 })?;
199 Ok(Response::<F>::decode(&mut payload_bytes.as_slice())?.into_result()?.into_value())
200 }
201}
202
203impl<E> Client<E> {
204 pub async fn disconnect(&self) {
211 *self.connection.stream.lock().await = None;
212 }
213}
214
215#[derive(Debug, Error)]
216pub enum Error {
217 #[error("protocol error")]
218 Protocol(#[from] crate::Error),
219
220 #[error("I/O error")]
221 Io(#[from] tokio::io::Error),
222
223 #[error("timed out connecting")]
224 ConnectionTimeout(tokio::time::error::Elapsed),
225
226 #[error("transaction timeout")]
227 TransactionTimeout(tokio::time::error::Elapsed),
228}