edge_nal/
timeout.rs

//! This module provides a utility function and a decorator struct
//! for adding timeouts to IO types.
//!
//! Note that the presence of this module in the `edge-nal` crate
//! is a bit controversial, as it is a utility, while `edge-nal` is
//! otherwise a pure traits crate.
//!
//! Therefore, the module might be moved to another location in the future.
9
10use core::{
11    fmt::{self, Display},
12    future::Future,
13    net::SocketAddr,
14};
15
16use embassy_time::Duration;
17use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
18
19use crate::{Readable, TcpAccept, TcpConnect, TcpShutdown, TcpSplit};
20
/// The error produced by the `with_timeout` function and by the `WithTimeout` struct.
///
/// It either carries the error reported by the underlying operation,
/// or signals that the operation did not finish in time.
#[derive(Debug)]
pub enum WithTimeoutError<E> {
    /// The operation itself failed; the payload is the error it reported
    Error(E),
    /// The operation timed out
    Timeout,
}

/// Wraps an operation error into the `WithTimeoutError::Error` variant,
/// so that `?` can lift a plain `E` into a `WithTimeoutError<E>`.
impl<E> From<E> for WithTimeoutError<E> {
    fn from(error: E) -> Self {
        WithTimeoutError::Error(error)
    }
}

/// Displays the wrapped error as-is, or a fixed message for a timeout.
impl<E: Display> Display for WithTimeoutError<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Timeout => f.write_str("Operation timed out"),
            Self::Error(inner) => write!(f, "{inner}"),
        }
    }
}

/// `WithTimeoutError` is a proper error type whenever the wrapped error is one.
impl<E: core::error::Error> core::error::Error for WithTimeoutError<E> {}
49
50impl<E> embedded_io_async::Error for WithTimeoutError<E>
51where
52    E: embedded_io_async::Error,
53{
54    fn kind(&self) -> ErrorKind {
55        match self {
56            Self::Error(e) => e.kind(),
57            Self::Timeout => ErrorKind::TimedOut,
58        }
59    }
60}
61
62/// Run a fallible future with a timeout.
63///
64/// A future is a fallible future if it resolves to a `Result<T, E>`.
65///
66/// If the future completes before the timeout, its output is returned.
67/// Otherwise, on timeout, a timeout error is returned.
68///
69/// Parameters:
70/// - `timeout_ms`: The timeout duration in milliseconds
71/// - `fut`: The future to run
72pub async fn with_timeout<F, T, E>(timeout_ms: u32, fut: F) -> Result<T, WithTimeoutError<E>>
73where
74    F: Future<Output = Result<T, E>>,
75{
76    map_result(embassy_time::with_timeout(Duration::from_millis(timeout_ms as _), fut).await)
77}
78
/// A type that wraps an IO stream type and adds a timeout to all operations.
///
/// The operations decorated with a timeout are the ones offered via the following traits:
/// - `embedded_io_async::Read`
/// - `embedded_io_async::Write`
/// - `Readable`
/// - `TcpConnect`
/// - `TcpShutdown`
///
/// Additionally, wrapping with `WithTimeout` an IO type that implements `TcpAccept` will result
/// in a `TcpAccept` implementation that - while waiting potentially indefinitely for an incoming
/// connection - will return a connected socket readily wrapped with a timeout.
///
/// The wrapper is `Debug` whenever the wrapped IO type is, so it can be logged
/// and embedded in other types that derive `Debug`.
///
/// Fields, in order: the wrapped IO value, and the timeout in milliseconds.
#[derive(Debug)]
pub struct WithTimeout<T>(T, u32);

impl<T> WithTimeout<T> {
    /// Create a new `WithTimeout` instance.
    ///
    /// Parameters:
    /// - `timeout_ms`: The timeout duration in milliseconds
    /// - `io`: The IO type to add a timeout to
    pub const fn new(timeout_ms: u32, io: T) -> Self {
        // Note the swapped order: the IO value is stored first, the timeout second.
        Self(io, timeout_ms)
    }

    /// Get a reference to the inner IO type.
    pub fn io(&self) -> &T {
        &self.0
    }

    /// Get a mutable reference to the inner IO type.
    ///
    /// Operations invoked directly on the returned reference are NOT
    /// wrapped with the timeout.
    pub fn io_mut(&mut self) -> &mut T {
        &mut self.0
    }

    /// Get the timeout duration in milliseconds.
    pub fn timeout_ms(&self) -> u32 {
        self.1
    }

    /// Get the IO type by destructuring the `WithTimeout` instance.
    ///
    /// The timeout value is discarded.
    pub fn into_io(self) -> T {
        self.0
    }
}
123
124impl<T> ErrorType for WithTimeout<T>
125where
126    T: ErrorType,
127{
128    type Error = WithTimeoutError<T::Error>;
129}
130
131impl<T> Read for WithTimeout<T>
132where
133    T: Read,
134{
135    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
136        with_timeout(self.1, self.0.read(buf)).await
137    }
138}
139
140impl<T> Write for WithTimeout<T>
141where
142    T: Write,
143{
144    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
145        with_timeout(self.1, self.0.write(buf)).await
146    }
147
148    async fn flush(&mut self) -> Result<(), Self::Error> {
149        with_timeout(self.1, self.0.flush()).await
150    }
151}
152
153impl<T> TcpConnect for WithTimeout<T>
154where
155    T: TcpConnect,
156{
157    type Error = WithTimeoutError<T::Error>;
158
159    type Socket<'a>
160        = WithTimeout<T::Socket<'a>>
161    where
162        Self: 'a;
163
164    async fn connect(&self, remote: SocketAddr) -> Result<Self::Socket<'_>, Self::Error> {
165        with_timeout(self.1, self.0.connect(remote))
166            .await
167            .map(|s| WithTimeout::new(self.1, s))
168    }
169}
170
171impl<T> Readable for WithTimeout<T>
172where
173    T: Readable,
174{
175    async fn readable(&mut self) -> Result<(), Self::Error> {
176        with_timeout(self.1, self.0.readable()).await
177    }
178}
179
180impl<T> TcpSplit for WithTimeout<T>
181where
182    T: TcpSplit,
183{
184    type Read<'a>
185        = WithTimeout<T::Read<'a>>
186    where
187        Self: 'a;
188
189    type Write<'a>
190        = WithTimeout<T::Write<'a>>
191    where
192        Self: 'a;
193
194    fn split(&mut self) -> (Self::Read<'_>, Self::Write<'_>) {
195        let (r, w) = self.0.split();
196        (WithTimeout::new(self.1, r), WithTimeout::new(self.1, w))
197    }
198}
199
200impl<T> TcpShutdown for WithTimeout<T>
201where
202    T: TcpShutdown,
203{
204    async fn close(&mut self, what: crate::Close) -> Result<(), Self::Error> {
205        with_timeout(self.1, self.0.close(what)).await
206    }
207
208    async fn abort(&mut self) -> Result<(), Self::Error> {
209        with_timeout(self.1, self.0.abort()).await
210    }
211}
212
213impl<T> TcpAccept for WithTimeout<T>
214where
215    T: TcpAccept,
216{
217    type Error = WithTimeoutError<T::Error>;
218
219    type Socket<'a>
220        = WithTimeout<T::Socket<'a>>
221    where
222        Self: 'a;
223
224    async fn accept(&self) -> Result<(SocketAddr, Self::Socket<'_>), Self::Error> {
225        let (addr, socket) = self.0.accept().await?;
226
227        Ok((addr, WithTimeout::new(self.1, socket)))
228    }
229}
230
231fn map_result<T, E>(
232    result: Result<Result<T, E>, embassy_time::TimeoutError>,
233) -> Result<T, WithTimeoutError<E>> {
234    match result {
235        Ok(Ok(t)) => Ok(t),
236        Ok(Err(e)) => Err(WithTimeoutError::Error(e)),
237        Err(_) => Err(WithTimeoutError::Timeout),
238    }
239}