1use std::{
9 ffi::c_void,
10 future::poll_fn,
11 io,
12 mem::{self, MaybeUninit},
13 ops::Neg,
14 os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd},
15 pin::Pin,
16 ptr,
17 task::{Context, Poll, ready},
18 time::{Duration, SystemTime},
19};
20
21use futures_core::Stream;
22use tokio::io::unix::AsyncFd;
23
24#[derive(Debug)]
26pub struct Deadline {
27 inner: AsyncFd<RawSystemTimer>,
28}
29
30impl Deadline {
31 pub fn new(at: SystemTime) -> Self {
32 Self {
33 inner: {
34 let mut raw = RawSystemTimer::new().unwrap();
35 raw.set_timing(RawTiming {
36 initial_expiration: Some(at),
37 interval: None,
38 });
39 AsyncFd::new(raw).unwrap()
40 },
41 }
42 }
43 pub fn reset(&mut self, to: SystemTime) {
46 self.inner.get_mut().set_timing(RawTiming {
48 initial_expiration: Some(to),
49 interval: None,
50 });
51 }
52}
53
54impl Future for Deadline {
55 type Output = ();
56 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
57 let Self { ref inner } = *self;
58 ready!(inner.poll_read_ready(cx)).unwrap().clear_ready();
59 Poll::Ready(())
60 }
61}
62
63#[derive(Debug)]
67pub struct Interval {
68 missed_ticks: u64,
69 inner: AsyncFd<RawSystemTimer>,
70}
71
72impl Interval {
73 pub fn new(every: Duration) -> Self {
74 Self {
75 missed_ticks: 0,
76 inner: {
77 let mut raw = RawSystemTimer::new().unwrap();
78 raw.set_timing(RawTiming {
79 initial_expiration: None,
80 interval: Some(every),
81 });
82 AsyncFd::new(raw).unwrap()
83 },
84 }
85 }
86 pub fn poll_tick(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
87 let Self {
88 missed_ticks,
89 inner,
90 } = &mut *self;
91 loop {
92 if let Some(new) = missed_ticks.checked_sub(1) {
93 *missed_ticks = new;
94 return Poll::Ready(());
95 }
96 match ready!(inner.poll_read_ready_mut(cx)) {
97 Ok(mut gd) => {
98 let mut missed = MaybeUninit::<u64>::uninit();
99 let count = mem::size_of_val(&missed);
100 let n = unsafe {
101 libc::read(
102 gd.get_mut().as_raw_fd(),
103 missed.as_mut_ptr().cast::<c_void>(),
104 count,
105 )
106 };
107 assert_eq!(n, count as isize, "{}", io::Error::last_os_error());
108 *missed_ticks = unsafe { missed.assume_init() };
109 gd.clear_ready();
110 }
111 Err(e) => panic!("io error on TimerFd {inner:?}: {e}"),
112 }
113 }
114 }
115 pub async fn tick(&mut self) {
116 poll_fn(|cx| Pin::new(&mut *self).poll_tick(cx)).await
117 }
118}
119
120impl Stream for Interval {
121 type Item = ();
122 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123 self.poll_tick(cx).map(Some)
124 }
125}
126
127#[derive(Debug)]
131pub struct IntervalAfter {
132 wrapped: Interval,
133}
134
135impl IntervalAfter {
136 pub fn new(after: SystemTime, every: Duration) -> Self {
137 Self {
138 wrapped: Interval {
139 missed_ticks: 0,
140 inner: {
141 let mut raw = RawSystemTimer::new().unwrap();
142 raw.set_timing(RawTiming {
143 initial_expiration: Some(after),
144 interval: Some(every),
145 });
146 AsyncFd::new(raw).unwrap()
147 },
148 },
149 }
150 }
151
152 pub fn poll_tick(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
153 Pin::new(&mut self.wrapped).poll_tick(cx)
154 }
155 pub async fn tick(&mut self) {
156 poll_fn(|cx| Pin::new(&mut *self).poll_tick(cx)).await
157 }
158}
159
160impl Stream for IntervalAfter {
161 type Item = ();
162 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
163 self.poll_tick(cx).map(Some)
164 }
165}
166
/// Owns a Linux timerfd descriptor, which is closed when this value is dropped.
///
/// [`RawSystemTimer::new`] creates it on `CLOCK_REALTIME` in non-blocking mode. Descriptors
/// adopted through `FromRawFd` are taken as-is.
#[derive(Debug)]
pub struct RawSystemTimer {
    fd: OwnedFd,
}
171
/// Settings for a [`RawSystemTimer`], mirroring the two halves of timerfd's `itimerspec`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RawTiming {
    /// Absolute realtime-clock instant of the first expiration; `None` leaves the timer
    /// disarmed.
    pub initial_expiration: Option<SystemTime>,
    /// Period between subsequent expirations; `None` means no repetition.
    pub interval: Option<Duration>,
}
177
178impl RawTiming {
179 pub const DISABLED: Self = Self {
180 initial_expiration: None,
181 interval: None,
182 };
183 pub fn is_disabled(&self) -> bool {
184 matches!(*self, Self::DISABLED)
185 }
186 pub fn is_enabled(&self) -> bool {
187 !self.is_disabled()
188 }
189}
190
191impl FromRawFd for RawSystemTimer {
192 unsafe fn from_raw_fd(fd: RawFd) -> Self {
193 Self {
194 fd: unsafe { OwnedFd::from_raw_fd(fd) },
195 }
196 }
197}
198
199impl IntoRawFd for RawSystemTimer {
200 fn into_raw_fd(self) -> RawFd {
201 let Self { fd } = self;
202 fd.into_raw_fd()
203 }
204}
205
206const TIMESPEC_ZERO: libc::timespec = libc::timespec {
207 tv_sec: 0,
208 tv_nsec: 0,
209};
210
211impl RawSystemTimer {
212 pub fn new() -> io::Result<Self> {
213 let res = unsafe { libc::timerfd_create(libc::CLOCK_REALTIME, libc::TFD_NONBLOCK) };
214 match res == -1 {
215 false => Ok(Self {
216 fd: unsafe { OwnedFd::from_raw_fd(res) },
217 }),
218 true => Err(io::Error::last_os_error()),
219 }
220 }
221 pub fn timing(&self) -> RawTiming {
222 let Self { fd } = self;
223 let mut out = MaybeUninit::uninit();
224 let rc = unsafe { libc::timerfd_gettime(fd.as_raw_fd(), out.as_mut_ptr()) };
225 assert_eq!(
226 rc,
227 0,
228 "timerfd_gettime({fd:?}) threw {}",
229 io::Error::last_os_error()
230 );
231 let libc::itimerspec {
232 it_value,
233 it_interval,
234 } = unsafe { out.assume_init() };
235
236 let initial_expiration = match it_value {
237 libc::timespec {
238 tv_sec: 0,
239 tv_nsec: 0,
240 } => None,
241 libc::timespec { tv_sec, tv_nsec } => {
242 let mut instant = SystemTime::UNIX_EPOCH;
243 match tv_sec.is_positive() {
244 true => instant += Duration::from_secs(tv_sec as u64),
245 false => instant -= Duration::from_secs(tv_sec.unsigned_abs()),
246 }
247 match tv_nsec.is_positive() {
248 true => instant += Duration::from_nanos(tv_sec as u64),
249 false => instant -= Duration::from_nanos(tv_sec.unsigned_abs()),
250 }
251 Some(instant)
252 }
253 };
254
255 let interval = match it_interval {
256 libc::timespec {
257 tv_sec: 0,
258 tv_nsec: 0,
259 } => None,
260 libc::timespec { tv_sec, tv_nsec } => {
261 let mut duration = Duration::ZERO;
262 match tv_sec.is_positive() {
263 true => duration += Duration::from_secs(tv_sec as u64),
264 false => duration -= Duration::from_secs(tv_sec.unsigned_abs()),
265 }
266 match tv_nsec.is_positive() {
267 true => duration += Duration::from_nanos(tv_sec as u64),
268 false => duration -= Duration::from_nanos(tv_sec.unsigned_abs()),
269 }
270 Some(duration)
271 }
272 };
273
274 RawTiming {
275 initial_expiration,
276 interval,
277 }
278 }
279 pub fn set_timing(&mut self, to: RawTiming) {
280 let Self { fd } = self;
281 let RawTiming {
282 initial_expiration,
283 interval,
284 } = to;
285 let it_value = match initial_expiration {
286 Some(SystemTime::UNIX_EPOCH) => panic!(
287 "Due to platform limitations, SystemTime::UNIX_EPOCH is not supported as an `initial_expiration`"
288 ),
289 Some(time) => {
290 let (neg, dur) = match time.duration_since(SystemTime::UNIX_EPOCH) {
291 Ok(pos) => (false, pos),
292 Err(e) => (true, e.duration()),
293 };
294 libc::timespec {
295 tv_sec: neg_if(
296 i64::try_from(dur.as_secs()).unwrap_or_else(|e| {
297 panic!("can't fit duration {dur:?} in a timespec: {e}")
298 }),
299 neg,
300 ),
301 tv_nsec: neg_if(i64::from(dur.subsec_nanos()), neg),
302 }
303 }
304 None => TIMESPEC_ZERO,
305 };
306 let it_interval = match interval {
307 Some(dur) => libc::timespec {
308 tv_sec: dur
309 .as_secs()
310 .try_into()
311 .unwrap_or_else(|e| panic!("can't fit duration {dur:?} in a timespec: {e}")),
312 tv_nsec: i64::from(dur.subsec_nanos()),
313 },
314 None => TIMESPEC_ZERO,
315 };
316 let rc = unsafe {
317 libc::timerfd_settime(
318 fd.as_raw_fd(),
319 libc::TFD_TIMER_ABSTIME,
320 &libc::itimerspec {
321 it_interval,
322 it_value,
323 },
324 ptr::null_mut(),
325 )
326 };
327 assert_eq!(
328 rc,
329 0,
330 "timerfd_settime({fd:?}, {to:?}) threw {}",
331 io::Error::last_os_error()
332 );
333 }
334}
335
/// Returns `-t` when `neg` is `true`, and `t` unchanged otherwise.
// This small generic helper may have no caller in this module; don't let that warn.
#[allow(dead_code)]
fn neg_if<T: Neg<Output = T>>(t: T, neg: bool) -> T {
    if neg { -t } else { t }
}
342
343impl AsRawFd for RawSystemTimer {
344 fn as_raw_fd(&self) -> RawFd {
345 self.fd.as_raw_fd()
346 }
347}
348
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;

    /// A fresh timer starts disabled and reports both fields once armed.
    #[test]
    fn raw() {
        let mut raw_timer = RawSystemTimer::new().unwrap();
        assert_eq!(raw_timer.timing(), RawTiming::DISABLED);

        raw_timer.set_timing(RawTiming {
            initial_expiration: Some(SystemTime::now()),
            interval: Some(Duration::from_mins(60)),
        });
        let armed = raw_timer.timing();
        assert_matches!(
            armed,
            RawTiming {
                initial_expiration: Some(_),
                interval: Some(_)
            }
        );
    }

    /// Awaiting a `Deadline` takes at least as long as the requested delay.
    #[tokio::test]
    async fn deadline() {
        let start = SystemTime::now();
        let wait = Duration::from_secs(5);
        Deadline::new(start + wait).await;
        let waited = start.elapsed().unwrap();
        assert!(waited >= wait);
    }

    /// Starting an `IntervalAfter` far in the past queues every missed tick.
    #[tokio::test]
    async fn interval_since() {
        let origin = SystemTime::UNIX_EPOCH + Duration::from_nanos(1);
        let mut ticks = IntervalAfter::new(origin, Duration::from_hours(1));
        ticks.tick().await;
        assert!(ticks.wrapped.missed_ticks > 100);
    }
}