edge_nal/
timeout.rs

//! This module provides a utility function and a decorator struct
//! for adding timeouts to IO types.
//!
//! Note that the presence of this module in the `edge-nal` crate
//! is a bit controversial, as it is a utility, while `edge-nal` is
//! otherwise a pure traits crate.
//!
//! Therefore, the module might be moved to another location in the future.
9
10use core::{
11    fmt::{self, Display},
12    future::Future,
13    net::SocketAddr,
14};
15
16use embassy_time::Duration;
17use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
18
19use crate::{Readable, TcpAccept, TcpConnect, TcpShutdown, TcpSplit};
20
/// Error type for the `with_timeout` function and the `WithTimeout` struct.
///
/// It either carries the error produced by the decorated operation itself,
/// or signals that the operation did not complete within the configured timeout.
///
/// `Clone`/`Copy`/`PartialEq`/`Eq` are derived so callers can inspect and compare
/// the outcome; each is only available when the inner error type `E` supports it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WithTimeoutError<E> {
    /// An error occurred during the execution of the operation
    Error(E),
    /// The operation timed out
    Timeout,
}
29
30impl<E> From<E> for WithTimeoutError<E> {
31    fn from(e: E) -> Self {
32        Self::Error(e)
33    }
34}
35
36impl<E> fmt::Display for WithTimeoutError<E>
37where
38    E: Display,
39{
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            Self::Error(e) => write!(f, "{}", e),
43            Self::Timeout => write!(f, "Operation timed out"),
44        }
45    }
46}
47
48impl<E> embedded_io_async::Error for WithTimeoutError<E>
49where
50    E: embedded_io_async::Error,
51{
52    fn kind(&self) -> ErrorKind {
53        match self {
54            Self::Error(e) => e.kind(),
55            Self::Timeout => ErrorKind::TimedOut,
56        }
57    }
58}
59
60/// Run a fallible future with a timeout.
61///
62/// A future is a fallible future if it resolves to a `Result<T, E>`.
63///
64/// If the future completes before the timeout, its output is returned.
65/// Otherwise, on timeout, a timeout error is returned.
66///
67/// Parameters:
68/// - `timeout_ms`: The timeout duration in milliseconds
69/// - `fut`: The future to run
70pub async fn with_timeout<F, T, E>(timeout_ms: u32, fut: F) -> Result<T, WithTimeoutError<E>>
71where
72    F: Future<Output = Result<T, E>>,
73{
74    map_result(embassy_time::with_timeout(Duration::from_millis(timeout_ms as _), fut).await)
75}
76
/// A type that wraps an IO stream type and adds a timeout to all operations.
///
/// The operations decorated with a timeout are the ones offered via the following traits:
/// - `embedded_io_async::Read`
/// - `embedded_io_async::Write`
/// - `Readable`
/// - `TcpConnect` (the connected socket is returned wrapped with the same timeout)
/// - `TcpShutdown`
///
/// The timeout applies per operation: every individual call gets the full timeout.
///
/// Splitting a wrapped socket via `TcpSplit` yields read and write halves that are each
/// wrapped with the same timeout.
///
/// Additionally, wrapping with `WithTimeout` an IO type that implements `TcpAccept` will result
/// in a `TcpAccept` implementation that - while waiting potentially indefinitely for an incoming
/// connection - will return a connected socket readily wrapped with a timeout.
///
/// Fields: the wrapped IO object, followed by the timeout in milliseconds.
#[derive(Debug)]
pub struct WithTimeout<T>(T, u32);
90
91impl<T> WithTimeout<T> {
92    /// Create a new `WithTimeout` instance.
93    ///
94    /// Parameters:
95    /// - `timeout_ms`: The timeout duration in milliseconds
96    /// - `io`: The IO type to add a timeout to
97    pub const fn new(timeout_ms: u32, io: T) -> Self {
98        Self(io, timeout_ms)
99    }
100
101    /// Get a reference to the inner IO type.
102    pub fn io(&self) -> &T {
103        &self.0
104    }
105
106    /// Get a mutable reference to the inner IO type.
107    pub fn io_mut(&mut self) -> &mut T {
108        &mut self.0
109    }
110
111    /// Get the timeout duration in milliseconds.
112    pub fn timeout_ms(&self) -> u32 {
113        self.1
114    }
115
116    /// Get the IO type by destructuring the `WithTimeout` instance.
117    pub fn into_io(self) -> T {
118        self.0
119    }
120}
121
122impl<T> ErrorType for WithTimeout<T>
123where
124    T: ErrorType,
125{
126    type Error = WithTimeoutError<T::Error>;
127}
128
129impl<T> Read for WithTimeout<T>
130where
131    T: Read,
132{
133    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
134        with_timeout(self.1, self.0.read(buf)).await
135    }
136}
137
138impl<T> Write for WithTimeout<T>
139where
140    T: Write,
141{
142    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
143        with_timeout(self.1, self.0.write(buf)).await
144    }
145
146    async fn flush(&mut self) -> Result<(), Self::Error> {
147        with_timeout(self.1, self.0.flush()).await
148    }
149}
150
151impl<T> TcpConnect for WithTimeout<T>
152where
153    T: TcpConnect,
154{
155    type Error = WithTimeoutError<T::Error>;
156
157    type Socket<'a>
158        = WithTimeout<T::Socket<'a>>
159    where
160        Self: 'a;
161
162    async fn connect(&self, remote: SocketAddr) -> Result<Self::Socket<'_>, Self::Error> {
163        with_timeout(self.1, self.0.connect(remote))
164            .await
165            .map(|s| WithTimeout::new(self.1, s))
166    }
167}
168
169impl<T> Readable for WithTimeout<T>
170where
171    T: Readable,
172{
173    async fn readable(&mut self) -> Result<(), Self::Error> {
174        with_timeout(self.1, self.0.readable()).await
175    }
176}
177
178impl<T> TcpSplit for WithTimeout<T>
179where
180    T: TcpSplit,
181{
182    type Read<'a>
183        = WithTimeout<T::Read<'a>>
184    where
185        Self: 'a;
186
187    type Write<'a>
188        = WithTimeout<T::Write<'a>>
189    where
190        Self: 'a;
191
192    fn split(&mut self) -> (Self::Read<'_>, Self::Write<'_>) {
193        let (r, w) = self.0.split();
194        (WithTimeout::new(self.1, r), WithTimeout::new(self.1, w))
195    }
196}
197
198impl<T> TcpShutdown for WithTimeout<T>
199where
200    T: TcpShutdown,
201{
202    async fn close(&mut self, what: crate::Close) -> Result<(), Self::Error> {
203        with_timeout(self.1, self.0.close(what)).await
204    }
205
206    async fn abort(&mut self) -> Result<(), Self::Error> {
207        with_timeout(self.1, self.0.abort()).await
208    }
209}
210
211impl<T> TcpAccept for WithTimeout<T>
212where
213    T: TcpAccept,
214{
215    type Error = WithTimeoutError<T::Error>;
216
217    type Socket<'a>
218        = WithTimeout<T::Socket<'a>>
219    where
220        Self: 'a;
221
222    async fn accept(&self) -> Result<(SocketAddr, Self::Socket<'_>), Self::Error> {
223        let (addr, socket) = self.0.accept().await?;
224
225        Ok((addr, WithTimeout::new(self.1, socket)))
226    }
227}
228
229fn map_result<T, E>(
230    result: Result<Result<T, E>, embassy_time::TimeoutError>,
231) -> Result<T, WithTimeoutError<E>> {
232    match result {
233        Ok(Ok(t)) => Ok(t),
234        Ok(Err(e)) => Err(WithTimeoutError::Error(e)),
235        Err(_) => Err(WithTimeoutError::Timeout),
236    }
237}