1use std::{
2 fmt::Debug,
3 future::Future,
4 io,
5 pin::Pin,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use hala_io_driver::{
11 current_poller, get_driver, Cmd, Description, Driver, Handle, Interest, OpenFlags,
12};
13
14pub struct Sleep {
15 fd: Option<Handle>,
16 driver: Driver,
17 expired: Duration,
18 poller: Handle,
19}
20
21impl Sleep {
22 pub fn new(expired: Duration) -> io::Result<Self> {
23 let driver = get_driver()?;
24
25 let poller = current_poller()?;
26
27 Ok(Self {
28 fd: None,
29 driver,
30 expired,
31 poller,
32 })
33 }
34}
35
36impl Future for Sleep {
37 type Output = io::Result<()>;
38
39 fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40 if self.fd.is_none() {
42 let fd = match self
43 .driver
44 .fd_open(Description::Timeout, OpenFlags::Duration(self.expired))
45 {
46 Err(err) => return Poll::Ready(Err(err)),
47 Ok(fd) => fd,
48 };
49
50 self.fd = Some(fd);
51
52 match self.driver.fd_cntl(
53 self.poller,
54 Cmd::Register {
55 source: fd,
56 interests: Interest::Readable,
57 },
58 ) {
59 Err(err) => return Poll::Ready(Err(err)),
60 _ => {}
61 }
62
63 log::trace!("create timeout {:?}", fd);
64 }
65
66 match self
68 .driver
69 .fd_cntl(self.fd.unwrap(), Cmd::Timeout(cx.waker().clone()))
70 {
71 Ok(resp) => match resp.try_into_timeout() {
72 Ok(status) => {
73 if status {
74 return Poll::Ready(Ok(()));
75 }
76 }
77
78 Err(err) => {
79 return Poll::Ready(Err(err));
80 }
81 },
82 Err(err) => return Poll::Ready(Err(err)),
83 }
84
85 return Poll::Pending;
86 }
87}
88
89impl Drop for Sleep {
90 fn drop(&mut self) {
91 if let Some(fd) = self.fd.take() {
92 log::trace!("dropping sleep, fd={:?}", fd);
93
94 self.driver
95 .fd_cntl(self.poller, Cmd::Deregister(fd))
96 .unwrap();
97
98 self.driver.fd_close(fd).unwrap();
99 }
100 }
101}
102
103pub struct Timeout<Fut> {
104 fut: Pin<Box<Fut>>,
105 fd: Option<Handle>,
106 driver: Driver,
107 expired: Duration,
108 poller: Handle,
109}
110
111impl<Fut> Timeout<Fut> {
112 pub fn new(fut: Fut, expired: Duration) -> io::Result<Self> {
113 let driver = get_driver()?;
114
115 let poller = current_poller()?;
116
117 Ok(Self {
118 fut: Box::pin(fut),
119 fd: None,
120 driver,
121 expired,
122 poller,
123 })
124 }
125}
126
127impl<'a, Fut, R> Future for Timeout<Fut>
128where
129 Fut: Future<Output = io::Result<R>> + 'a,
130 R: Debug,
131{
132 type Output = io::Result<R>;
133
134 fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 if self.fd.is_none() {
137 let fd = match self
138 .driver
139 .fd_open(Description::Timeout, OpenFlags::Duration(self.expired))
140 {
141 Err(err) => return Poll::Ready(Err(err)),
142 Ok(fd) => fd,
143 };
144
145 self.fd = Some(fd);
146
147 match self.driver.fd_cntl(
148 self.poller,
149 Cmd::Register {
150 source: fd,
151 interests: Interest::Readable,
152 },
153 ) {
154 Err(err) => return Poll::Ready(Err(err)),
155 _ => {}
156 }
157
158 log::trace!("create timeout {:?}", fd);
159 }
160
161 match self.fut.as_mut().poll(cx) {
163 Poll::Ready(r) => {
164 log::trace!("timeout poll future ready {:?}", r);
165 return Poll::Ready(r);
166 }
167 _ => {
168 log::trace!("timeout poll future pending");
169 }
170 }
171
172 match self
174 .driver
175 .fd_cntl(self.fd.unwrap(), Cmd::Timeout(cx.waker().clone()))
176 {
177 Ok(resp) => match resp.try_into_timeout() {
178 Ok(status) => {
179 if status {
180 return Poll::Ready(Err(io::Error::new(
181 io::ErrorKind::TimedOut,
182 format!("async io timeout={:?}", self.expired),
183 )));
184 }
185 }
186
187 Err(err) => {
188 return Poll::Ready(Err(err));
189 }
190 },
191 Err(err) => return Poll::Ready(Err(err)),
192 }
193
194 return Poll::Pending;
195 }
196}
197
198impl<Fut> Drop for Timeout<Fut> {
199 fn drop(&mut self) {
200 if let Some(fd) = self.fd.take() {
201 self.driver
202 .fd_cntl(self.poller, Cmd::Deregister(fd))
203 .unwrap();
204
205 self.driver.fd_close(fd).unwrap();
206 }
207 }
208}
209
210pub async fn timeout<'a, Fut, R>(fut: Fut, expired: Option<Duration>) -> io::Result<R>
212where
213 Fut: Future<Output = io::Result<R>> + 'a,
214 R: Debug,
215{
216 if let Some(expired) = expired {
217 Timeout::new(fut, expired)?.await
218 } else {
219 fut.await
220 }
221}
222
223pub async fn sleep(duration: Duration) -> io::Result<()> {
225 Sleep::new(duration)?.await
226}
227
#[cfg(test)]
mod tests {
    use futures::future::poll_fn;

    use super::*;

    /// A future that never completes must be cut off with `TimedOut`.
    #[hala_io_test::test]
    async fn test_timeout() {
        let never_ready = poll_fn(|_| -> Poll<io::Result<()>> { Poll::Pending });

        let err = timeout(never_ready, Some(Duration::from_millis(20)))
            .await
            .unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    /// Sleeping for one second must complete without an error.
    #[hala_io_test::test]
    async fn test_sleep() {
        let outcome = sleep(Duration::from_secs(1)).await;

        outcome.unwrap();
    }
}