1use core::{
11 fmt::{self, Display},
12 future::Future,
13 net::SocketAddr,
14};
15
16use embassy_time::Duration;
17use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
18
19use crate::{Readable, TcpAccept, TcpConnect, TcpShutdown, TcpSplit};
20
/// Error produced by [`with_timeout`] and by the [`WithTimeout`] wrapper.
///
/// Common traits are derived (conditionally on `E`) so callers can clone and
/// compare results, e.g. `assert_eq!(res, Err(WithTimeoutError::Timeout))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WithTimeoutError<E> {
    /// The wrapped operation finished, but with its own error `E`.
    Error(E),
    /// The wrapped operation did not finish before the timeout elapsed.
    Timeout,
}
29
30impl<E> From<E> for WithTimeoutError<E> {
31 fn from(e: E) -> Self {
32 Self::Error(e)
33 }
34}
35
36impl<E> fmt::Display for WithTimeoutError<E>
37where
38 E: Display,
39{
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::Error(e) => write!(f, "{}", e),
43 Self::Timeout => write!(f, "Operation timed out"),
44 }
45 }
46}
47
48impl<E> core::error::Error for WithTimeoutError<E> where E: core::error::Error {}
49
50impl<E> embedded_io_async::Error for WithTimeoutError<E>
51where
52 E: embedded_io_async::Error,
53{
54 fn kind(&self) -> ErrorKind {
55 match self {
56 Self::Error(e) => e.kind(),
57 Self::Timeout => ErrorKind::TimedOut,
58 }
59 }
60}
61
62pub async fn with_timeout<F, T, E>(timeout_ms: u32, fut: F) -> Result<T, WithTimeoutError<E>>
73where
74 F: Future<Output = Result<T, E>>,
75{
76 map_result(embassy_time::with_timeout(Duration::from_millis(timeout_ms as _), fut).await)
77}
78
/// Wraps an IO object (field `0`) together with a timeout in milliseconds (field `1`).
///
/// The `Read`, `Write`, `Readable`, `TcpConnect` and `TcpShutdown`
/// implementations for this type bound each inner operation by the stored
/// timeout. `TcpSplit` and `TcpAccept` hand back their halves/sockets wrapped
/// with the same timeout.
///
/// `Debug` and `Clone` are derived (when `T` supports them) so the public
/// wrapper can be inspected and duplicated like the type it wraps.
#[derive(Debug, Clone)]
pub struct WithTimeout<T>(T, u32);
92
93impl<T> WithTimeout<T> {
94 pub const fn new(timeout_ms: u32, io: T) -> Self {
100 Self(io, timeout_ms)
101 }
102
103 pub fn io(&self) -> &T {
105 &self.0
106 }
107
108 pub fn io_mut(&mut self) -> &mut T {
110 &mut self.0
111 }
112
113 pub fn timeout_ms(&self) -> u32 {
115 self.1
116 }
117
118 pub fn into_io(self) -> T {
120 self.0
121 }
122}
123
124impl<T> ErrorType for WithTimeout<T>
125where
126 T: ErrorType,
127{
128 type Error = WithTimeoutError<T::Error>;
129}
130
131impl<T> Read for WithTimeout<T>
132where
133 T: Read,
134{
135 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
136 with_timeout(self.1, self.0.read(buf)).await
137 }
138}
139
140impl<T> Write for WithTimeout<T>
141where
142 T: Write,
143{
144 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
145 with_timeout(self.1, self.0.write(buf)).await
146 }
147
148 async fn flush(&mut self) -> Result<(), Self::Error> {
149 with_timeout(self.1, self.0.flush()).await
150 }
151}
152
153impl<T> TcpConnect for WithTimeout<T>
154where
155 T: TcpConnect,
156{
157 type Error = WithTimeoutError<T::Error>;
158
159 type Socket<'a>
160 = WithTimeout<T::Socket<'a>>
161 where
162 Self: 'a;
163
164 async fn connect(&self, remote: SocketAddr) -> Result<Self::Socket<'_>, Self::Error> {
165 with_timeout(self.1, self.0.connect(remote))
166 .await
167 .map(|s| WithTimeout::new(self.1, s))
168 }
169}
170
171impl<T> Readable for WithTimeout<T>
172where
173 T: Readable,
174{
175 async fn readable(&mut self) -> Result<(), Self::Error> {
176 with_timeout(self.1, self.0.readable()).await
177 }
178}
179
180impl<T> TcpSplit for WithTimeout<T>
181where
182 T: TcpSplit,
183{
184 type Read<'a>
185 = WithTimeout<T::Read<'a>>
186 where
187 Self: 'a;
188
189 type Write<'a>
190 = WithTimeout<T::Write<'a>>
191 where
192 Self: 'a;
193
194 fn split(&mut self) -> (Self::Read<'_>, Self::Write<'_>) {
195 let (r, w) = self.0.split();
196 (WithTimeout::new(self.1, r), WithTimeout::new(self.1, w))
197 }
198}
199
200impl<T> TcpShutdown for WithTimeout<T>
201where
202 T: TcpShutdown,
203{
204 async fn close(&mut self, what: crate::Close) -> Result<(), Self::Error> {
205 with_timeout(self.1, self.0.close(what)).await
206 }
207
208 async fn abort(&mut self) -> Result<(), Self::Error> {
209 with_timeout(self.1, self.0.abort()).await
210 }
211}
212
213impl<T> TcpAccept for WithTimeout<T>
214where
215 T: TcpAccept,
216{
217 type Error = WithTimeoutError<T::Error>;
218
219 type Socket<'a>
220 = WithTimeout<T::Socket<'a>>
221 where
222 Self: 'a;
223
224 async fn accept(&self) -> Result<(SocketAddr, Self::Socket<'_>), Self::Error> {
225 let (addr, socket) = self.0.accept().await?;
226
227 Ok((addr, WithTimeout::new(self.1, socket)))
228 }
229}
230
231fn map_result<T, E>(
232 result: Result<Result<T, E>, embassy_time::TimeoutError>,
233) -> Result<T, WithTimeoutError<E>> {
234 match result {
235 Ok(Ok(t)) => Ok(t),
236 Ok(Err(e)) => Err(WithTimeoutError::Error(e)),
237 Err(_) => Err(WithTimeoutError::Timeout),
238 }
239}