1use core::{
11 fmt::{self, Display},
12 future::Future,
13 net::SocketAddr,
14};
15
16use embassy_time::Duration;
17use embedded_io_async::{ErrorKind, ErrorType, Read, Write};
18
19use crate::{Readable, TcpAccept, TcpConnect, TcpShutdown, TcpSplit};
20
/// Error produced by an operation that was run under a timeout.
///
/// The wrapped operation either finished in time but failed with its own error
/// (`Error`), or did not finish before the timeout elapsed (`Timeout`).
#[derive(Debug)]
pub enum WithTimeoutError<E> {
    /// The wrapped operation completed in time and returned this error.
    Error(E),
    /// The wrapped operation did not complete within the timeout.
    Timeout,
}
29
30impl<E> From<E> for WithTimeoutError<E> {
31 fn from(e: E) -> Self {
32 Self::Error(e)
33 }
34}
35
36impl<E> fmt::Display for WithTimeoutError<E>
37where
38 E: Display,
39{
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::Error(e) => write!(f, "{}", e),
43 Self::Timeout => write!(f, "Operation timed out"),
44 }
45 }
46}
47
48impl<E> embedded_io_async::Error for WithTimeoutError<E>
49where
50 E: embedded_io_async::Error,
51{
52 fn kind(&self) -> ErrorKind {
53 match self {
54 Self::Error(e) => e.kind(),
55 Self::Timeout => ErrorKind::TimedOut,
56 }
57 }
58}
59
60pub async fn with_timeout<F, T, E>(timeout_ms: u32, fut: F) -> Result<T, WithTimeoutError<E>>
71where
72 F: Future<Output = Result<T, E>>,
73{
74 map_result(embassy_time::with_timeout(Duration::from_millis(timeout_ms as _), fut).await)
75}
76
/// Adapter that applies a timeout, in milliseconds, to the async operations of the
/// wrapped I/O object `T`.
///
/// Each timed call gets a fresh timeout of [`timeout_ms`](Self::timeout_ms)
/// milliseconds; the timeout is not a budget shared across calls. The timed calls are
/// `read`, `write`, `flush`, `connect`, `readable`, `close` and `abort`.
///
/// Sockets produced by a wrapped connector or acceptor are themselves wrapped with the
/// same timeout. So are the halves produced by splitting a wrapped socket. Note that
/// `accept` itself is not time-limited.
pub struct WithTimeout<T>(T, u32);

impl<T> WithTimeout<T> {
    /// Wraps `io` so that its operations time out after `timeout_ms` milliseconds.
    pub const fn new(timeout_ms: u32, io: T) -> Self {
        Self(io, timeout_ms)
    }

    /// Borrows the wrapped I/O object.
    pub fn io(&self) -> &T {
        let Self(io, _) = self;
        io
    }

    /// Mutably borrows the wrapped I/O object.
    pub fn io_mut(&mut self) -> &mut T {
        let Self(io, _) = self;
        io
    }

    /// Returns the per-operation timeout in milliseconds.
    pub fn timeout_ms(&self) -> u32 {
        let Self(_, ms) = self;
        *ms
    }

    /// Consumes the wrapper and returns the wrapped I/O object.
    pub fn into_io(self) -> T {
        let Self(io, _) = self;
        io
    }
}
121
122impl<T> ErrorType for WithTimeout<T>
123where
124 T: ErrorType,
125{
126 type Error = WithTimeoutError<T::Error>;
127}
128
129impl<T> Read for WithTimeout<T>
130where
131 T: Read,
132{
133 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
134 with_timeout(self.1, self.0.read(buf)).await
135 }
136}
137
138impl<T> Write for WithTimeout<T>
139where
140 T: Write,
141{
142 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
143 with_timeout(self.1, self.0.write(buf)).await
144 }
145
146 async fn flush(&mut self) -> Result<(), Self::Error> {
147 with_timeout(self.1, self.0.flush()).await
148 }
149}
150
151impl<T> TcpConnect for WithTimeout<T>
152where
153 T: TcpConnect,
154{
155 type Error = WithTimeoutError<T::Error>;
156
157 type Socket<'a>
158 = WithTimeout<T::Socket<'a>>
159 where
160 Self: 'a;
161
162 async fn connect(&self, remote: SocketAddr) -> Result<Self::Socket<'_>, Self::Error> {
163 with_timeout(self.1, self.0.connect(remote))
164 .await
165 .map(|s| WithTimeout::new(self.1, s))
166 }
167}
168
169impl<T> Readable for WithTimeout<T>
170where
171 T: Readable,
172{
173 async fn readable(&mut self) -> Result<(), Self::Error> {
174 with_timeout(self.1, self.0.readable()).await
175 }
176}
177
178impl<T> TcpSplit for WithTimeout<T>
179where
180 T: TcpSplit,
181{
182 type Read<'a>
183 = WithTimeout<T::Read<'a>>
184 where
185 Self: 'a;
186
187 type Write<'a>
188 = WithTimeout<T::Write<'a>>
189 where
190 Self: 'a;
191
192 fn split(&mut self) -> (Self::Read<'_>, Self::Write<'_>) {
193 let (r, w) = self.0.split();
194 (WithTimeout::new(self.1, r), WithTimeout::new(self.1, w))
195 }
196}
197
198impl<T> TcpShutdown for WithTimeout<T>
199where
200 T: TcpShutdown,
201{
202 async fn close(&mut self, what: crate::Close) -> Result<(), Self::Error> {
203 with_timeout(self.1, self.0.close(what)).await
204 }
205
206 async fn abort(&mut self) -> Result<(), Self::Error> {
207 with_timeout(self.1, self.0.abort()).await
208 }
209}
210
211impl<T> TcpAccept for WithTimeout<T>
212where
213 T: TcpAccept,
214{
215 type Error = WithTimeoutError<T::Error>;
216
217 type Socket<'a>
218 = WithTimeout<T::Socket<'a>>
219 where
220 Self: 'a;
221
222 async fn accept(&self) -> Result<(SocketAddr, Self::Socket<'_>), Self::Error> {
223 let (addr, socket) = self.0.accept().await?;
224
225 Ok((addr, WithTimeout::new(self.1, socket)))
226 }
227}
228
229fn map_result<T, E>(
230 result: Result<Result<T, E>, embassy_time::TimeoutError>,
231) -> Result<T, WithTimeoutError<E>> {
232 match result {
233 Ok(Ok(t)) => Ok(t),
234 Ok(Err(e)) => Err(WithTimeoutError::Error(e)),
235 Err(_) => Err(WithTimeoutError::Timeout),
236 }
237}