hala_io_util/
timeout.rs

1use std::{
2    fmt::Debug,
3    future::Future,
4    io,
5    pin::Pin,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use hala_io_driver::{
11    current_poller, get_driver, Cmd, Description, Driver, Handle, Interest, OpenFlags,
12};
13
14pub struct Sleep {
15    fd: Option<Handle>,
16    driver: Driver,
17    expired: Duration,
18    poller: Handle,
19}
20
21impl Sleep {
22    pub fn new(expired: Duration) -> io::Result<Self> {
23        let driver = get_driver()?;
24
25        let poller = current_poller()?;
26
27        Ok(Self {
28            fd: None,
29            driver,
30            expired,
31            poller,
32        })
33    }
34}
35
36impl Future for Sleep {
37    type Output = io::Result<()>;
38
39    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40        // first time, create timeout fd
41        if self.fd.is_none() {
42            let fd = match self
43                .driver
44                .fd_open(Description::Timeout, OpenFlags::Duration(self.expired))
45            {
46                Err(err) => return Poll::Ready(Err(err)),
47                Ok(fd) => fd,
48            };
49
50            self.fd = Some(fd);
51
52            match self.driver.fd_cntl(
53                self.poller,
54                Cmd::Register {
55                    source: fd,
56                    interests: Interest::Readable,
57                },
58            ) {
59                Err(err) => return Poll::Ready(Err(err)),
60                _ => {}
61            }
62
63            log::trace!("create timeout {:?}", fd);
64        }
65
66        // try check status of timeout fd
67        match self
68            .driver
69            .fd_cntl(self.fd.unwrap(), Cmd::Timeout(cx.waker().clone()))
70        {
71            Ok(resp) => match resp.try_into_timeout() {
72                Ok(status) => {
73                    if status {
74                        return Poll::Ready(Ok(()));
75                    }
76                }
77
78                Err(err) => {
79                    return Poll::Ready(Err(err));
80                }
81            },
82            Err(err) => return Poll::Ready(Err(err)),
83        }
84
85        return Poll::Pending;
86    }
87}
88
89impl Drop for Sleep {
90    fn drop(&mut self) {
91        if let Some(fd) = self.fd.take() {
92            log::trace!("dropping sleep, fd={:?}", fd);
93
94            self.driver
95                .fd_cntl(self.poller, Cmd::Deregister(fd))
96                .unwrap();
97
98            self.driver.fd_close(fd).unwrap();
99        }
100    }
101}
102
103pub struct Timeout<Fut> {
104    fut: Pin<Box<Fut>>,
105    fd: Option<Handle>,
106    driver: Driver,
107    expired: Duration,
108    poller: Handle,
109}
110
111impl<Fut> Timeout<Fut> {
112    pub fn new(fut: Fut, expired: Duration) -> io::Result<Self> {
113        let driver = get_driver()?;
114
115        let poller = current_poller()?;
116
117        Ok(Self {
118            fut: Box::pin(fut),
119            fd: None,
120            driver,
121            expired,
122            poller,
123        })
124    }
125}
126
127impl<'a, Fut, R> Future for Timeout<Fut>
128where
129    Fut: Future<Output = io::Result<R>> + 'a,
130    R: Debug,
131{
132    type Output = io::Result<R>;
133
134    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135        // first time, create timeout fd
136        if self.fd.is_none() {
137            let fd = match self
138                .driver
139                .fd_open(Description::Timeout, OpenFlags::Duration(self.expired))
140            {
141                Err(err) => return Poll::Ready(Err(err)),
142                Ok(fd) => fd,
143            };
144
145            self.fd = Some(fd);
146
147            match self.driver.fd_cntl(
148                self.poller,
149                Cmd::Register {
150                    source: fd,
151                    interests: Interest::Readable,
152                },
153            ) {
154                Err(err) => return Poll::Ready(Err(err)),
155                _ => {}
156            }
157
158            log::trace!("create timeout {:?}", fd);
159        }
160
161        // try poll fut once
162        match self.fut.as_mut().poll(cx) {
163            Poll::Ready(r) => {
164                log::trace!("timeout poll future ready {:?}", r);
165                return Poll::Ready(r);
166            }
167            _ => {
168                log::trace!("timeout poll future pending");
169            }
170        }
171
172        // try check status of timeout fd
173        match self
174            .driver
175            .fd_cntl(self.fd.unwrap(), Cmd::Timeout(cx.waker().clone()))
176        {
177            Ok(resp) => match resp.try_into_timeout() {
178                Ok(status) => {
179                    if status {
180                        return Poll::Ready(Err(io::Error::new(
181                            io::ErrorKind::TimedOut,
182                            format!("async io timeout={:?}", self.expired),
183                        )));
184                    }
185                }
186
187                Err(err) => {
188                    return Poll::Ready(Err(err));
189                }
190            },
191            Err(err) => return Poll::Ready(Err(err)),
192        }
193
194        return Poll::Pending;
195    }
196}
197
198impl<Fut> Drop for Timeout<Fut> {
199    fn drop(&mut self) {
200        if let Some(fd) = self.fd.take() {
201            self.driver
202                .fd_cntl(self.poller, Cmd::Deregister(fd))
203                .unwrap();
204
205            self.driver.fd_close(fd).unwrap();
206        }
207    }
208}
209
210/// Add timeout feature for exists `Fut`
211pub async fn timeout<'a, Fut, R>(fut: Fut, expired: Option<Duration>) -> io::Result<R>
212where
213    Fut: Future<Output = io::Result<R>> + 'a,
214    R: Debug,
215{
216    if let Some(expired) = expired {
217        Timeout::new(fut, expired)?.await
218    } else {
219        fut.await
220    }
221}
222
223/// Sleep for a while
224pub async fn sleep(duration: Duration) -> io::Result<()> {
225    Sleep::new(duration)?.await
226}
227
#[cfg(test)]
mod tests {
    use futures::future::poll_fn;

    use super::*;

    /// A future that never completes must be cut off with `ErrorKind::TimedOut`.
    #[hala_io_test::test]
    async fn test_timeout() {
        let never_ready = poll_fn(|_| -> Poll<io::Result<()>> { Poll::Pending });
        let limit = Some(Duration::from_millis(20));

        let err = timeout(never_ready, limit).await.unwrap_err();

        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    /// Sleeping for one second must complete without a driver error.
    #[hala_io_test::test]
    async fn test_sleep() {
        let outcome = sleep(Duration::from_secs(1)).await;

        outcome.unwrap();
    }
}