agnostic_lite/tokio/interval.rs

1use core::{
2  pin::Pin,
3  task::{Context, Poll},
4  time::Duration,
5};
6use std::time::Instant;
7
8use futures_util::stream::Stream;
9
10use crate::time::{AsyncLocalInterval, AsyncLocalIntervalExt};
11
12pin_project_lite::pin_project! {
13  /// The [`AsyncInterval`] implementation for tokio runtime
14  #[repr(transparent)]
15  pub struct TokioInterval {
16    #[pin]
17    inner: ::tokio::time::Interval,
18  }
19}
20
21impl From<::tokio::time::Interval> for TokioInterval {
22  fn from(interval: ::tokio::time::Interval) -> Self {
23    Self { inner: interval }
24  }
25}
26
27impl From<TokioInterval> for ::tokio::time::Interval {
28  fn from(interval: TokioInterval) -> Self {
29    interval.inner
30  }
31}
32
33impl Stream for TokioInterval {
34  type Item = tokio::time::Instant;
35
36  fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
37    self
38      .project()
39      .inner
40      .poll_tick(cx)
41      .map(Some)
42  }
43}
44
45impl AsyncLocalInterval for TokioInterval {
46  type Instant = ::tokio::time::Instant;
47
48  fn reset(&mut self, interval: Duration) {
49    self.inner.reset_after(interval);
50  }
51
52  fn reset_at(&mut self, instant: Self::Instant) {
53    self.inner.reset_at(instant);
54  }
55  
56  fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Self::Instant> {
57    self.inner.poll_tick(cx)
58  }
59}
60
61impl AsyncLocalIntervalExt for TokioInterval {
62  fn interval_local(period: Duration) -> Self
63  where
64    Self: Sized,
65  {
66    Self {
67      inner: tokio::time::interval_at((Instant::now() + period).into(), period),
68    }
69  }
70
71  fn interval_local_at(start: Self::Instant, period: Duration) -> Self
72  where
73    Self: Sized,
74  {
75    Self {
76      inner: tokio::time::interval_at(start, period),
77    }
78  }
79}
80
#[cfg(test)]
mod tests {
  use futures::StreamExt;

  use super::TokioInterval;
  use crate::time::{AsyncInterval, AsyncIntervalExt};
  use tokio::time::{Duration, Instant};

  const INTERVAL: Duration = Duration::from_millis(100);
  const BOUND: Duration = Duration::from_millis(50);
  const IMMEDIATE: Duration = Duration::from_millis(1);

  /// Asserts that a tick reported as `ins` landed roughly `target` after `start`.
  ///
  /// Both the reported tick instant and the time elapsed since `start` (sampled
  /// when this helper is called, i.e. right after the tick was awaited) may be
  /// up to `early` before `target`; the elapsed time may be at most `BOUND`
  /// after it.
  #[track_caller]
  fn assert_tick(start: Instant, ins: Instant, target: Duration, early: Duration) {
    let elapsed = start.elapsed();
    assert!(
      ins >= start + target - early,
      "tick instant {:?} after start, expected at least {:?}",
      ins.saturating_duration_since(start),
      target - early,
    );
    assert!(
      elapsed >= target - early && elapsed <= target + BOUND,
      "tick observed {:?} after start, expected within [{:?}, {:?}]",
      elapsed,
      target - early,
      target + BOUND,
    );
  }

  #[tokio::test]
  async fn test_interval() {
    let start = Instant::now();
    let mut interval = TokioInterval::interval(INTERVAL).take(3);

    // `interval` delays its first tick by one full period.
    for n in 1..=3u32 {
      let ins = interval.next().await.unwrap();
      assert_tick(start, ins, INTERVAL * n, BOUND);
    }

    assert!(interval.next().await.is_none());
  }

  #[tokio::test(flavor = "multi_thread")]
  async fn test_interval_at() {
    let start = Instant::now();
    let mut interval = TokioInterval::interval_at(Instant::now(), INTERVAL).take(4);

    // The first tick is immediate.
    let ins = interval.next().await.unwrap();
    let elapsed = start.elapsed();
    assert!(ins <= start + IMMEDIATE);
    assert!(elapsed <= IMMEDIATE + BOUND);

    for n in 1..=3u32 {
      let ins = interval.next().await.unwrap();
      assert_tick(start, ins, INTERVAL * n, BOUND);
    }

    assert!(interval.next().await.is_none());
  }

  #[tokio::test(flavor = "multi_thread")]
  async fn test_interval_reset() {
    let start = Instant::now();
    let mut interval = TokioInterval::interval(INTERVAL);

    let ins = interval.next().await.unwrap();
    assert_tick(start, ins, INTERVAL, BOUND);

    // `reset` is relative to now: the next tick lands 2x after the first one,
    // i.e. interval + 2x interval = 3x.
    interval.reset(INTERVAL * 2);
    let ins = interval.next().await.unwrap();
    assert_tick(start, ins, INTERVAL * 3, BOUND);

    // The regular period resumes from the reset deadline: 3x + interval = 4x.
    let ins = interval.next().await.unwrap();
    assert_tick(start, ins, INTERVAL * 4, BOUND);
  }

  #[tokio::test(flavor = "multi_thread")]
  async fn test_interval_reset_at() {
    let start = Instant::now();
    let mut interval = TokioInterval::interval(INTERVAL);

    // No early slack here: the tick must not be reported or observed before its target.
    let ins = interval.next().await.unwrap();
    assert_tick(start, ins, INTERVAL, Duration::ZERO);

    // `reset_at` is absolute: move the next tick to exactly `start + 3x`
    // (2x after the first tick).
    interval.reset_at(start + INTERVAL * 3);
    let ins = interval.next().await.unwrap();
    assert_tick(start, ins, INTERVAL * 3, Duration::ZERO);

    // The regular period resumes from the reset deadline: 3x + interval = 4x.
    let ins = interval.next().await.unwrap();
    assert_tick(start, ins, INTERVAL * 4, BOUND);
  }
}