use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use futures_core::Stream;
use super::Clock;
use super::timers::TimerKey;
use crate::timers::TIMER_RESOLUTION;
/// A timer that ticks repeatedly, driven by a [`Clock`].
///
/// The type implements [`Stream`] with `Item = ()`: each yielded item marks
/// one elapsed period. A timer is only registered with the clock once the
/// stream is polled, and any timer still registered is removed when the
/// value is dropped.
#[derive(Debug)]
pub struct PeriodicTimer {
    /// Interval between ticks. `new` raises it to at least
    /// `TIMER_RESOLUTION`; it is overwritten with `Duration::MAX` when the
    /// next deadline cannot be represented, which disables the timer.
    period: Duration,
    /// Clock used to read the current instant and to register and
    /// unregister timers.
    clock: Clock,
    /// Key of the timer currently registered with `clock`, or `None` when
    /// nothing is registered.
    current_timer: Option<TimerKey>,
}
impl PeriodicTimer {
    /// Creates a periodic timer bound to a clone of `clock`.
    ///
    /// A `period` shorter than `TIMER_RESOLUTION` is raised to
    /// `TIMER_RESOLUTION`. Nothing is registered with the clock here; the
    /// first timer is armed when the stream is polled.
    #[must_use]
    pub fn new(clock: &Clock, period: Duration) -> Self {
        Self {
            period: period.max(TIMER_RESOLUTION),
            clock: clock.clone(),
            current_timer: None,
        }
    }

    /// Arms a timer due one `period` after the clock's current instant and
    /// stores its key in `current_timer`.
    ///
    /// When the deadline is not representable (`checked_add` yields `None`),
    /// no timer is registered and `period` is set to `Duration::MAX`, the
    /// value `poll_next` checks to stay pending forever.
    fn register_timer(&mut self, waker: Waker) {
        let Some(deadline) = self.clock.instant().checked_add(self.period) else {
            self.period = Duration::MAX;
            return;
        };

        let key = self.clock.register_timer(deadline, waker);
        self.current_timer = Some(key);
    }
}
impl Stream for PeriodicTimer {
    type Item = ();

    /// Polls the timer for its next tick.
    ///
    /// Returns `Poll::Ready(Some(()))` once the clock's current instant has
    /// reached the deadline of the registered timer; the timer is then
    /// cleared, so the following poll arms a fresh one a full period after
    /// that poll. Returns `Poll::Pending` otherwise. The stream never yields
    /// `None`.
    ///
    /// A timer whose period is `Duration::MAX` never fires and is always
    /// pending.
    #[cfg_attr(test, mutants::skip)]
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // `Duration::MAX` marks a timer that must never fire.
        if this.period == Duration::MAX {
            return Poll::Pending;
        }

        match this.current_timer {
            Some(key) if key.tick() <= this.clock.instant() => {
                this.current_timer = None;
                // Remove the registration in case the clock still holds it.
                this.clock.unregister_timer(key);
                Poll::Ready(Some(()))
            }
            Some(key) => {
                // The deadline has not been reached. The poll contract
                // requires that the waker of the *most recent* poll is the
                // one that gets woken, and the caller may have moved this
                // stream to another task since the timer was armed. Re-arm
                // the timer at the same deadline with the current waker so
                // the wakeup is not delivered to a stale task.
                this.clock.unregister_timer(key);
                this.current_timer = Some(this.clock.register_timer(key.tick(), cx.waker().clone()));
                Poll::Pending
            }
            None => {
                // Nothing armed yet (first poll, or the previous tick was
                // just consumed): start the next period now.
                this.register_timer(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
impl Drop for PeriodicTimer {
    /// Unregisters the timer still held in `current_timer`, if any, so the
    /// clock does not keep an entry for a timer that no longer exists.
    fn drop(&mut self) {
        let Some(key) = self.current_timer else {
            return;
        };
        self.clock.unregister_timer(key);
    }
}
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;
    use crate::ClockControl;

    /// Polls `timer` exactly once with a no-op waker and returns the result.
    fn poll_timer(timer: &mut PeriodicTimer) -> Poll<Option<()>> {
        let mut cx = Context::from_waker(Waker::noop());
        Pin::new(timer).poll_next(&mut cx)
    }

    /// The timer must be usable across threads.
    #[test]
    fn assert_types() {
        static_assertions::assert_impl_all!(PeriodicTimer: Send, Sync);
    }

    /// Awaiting the stream twice yields two ticks well within the timeout.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn next_ensure_awaited() {
        use futures::StreamExt;

        use crate::FutureExt;

        let clock = Clock::new_tokio();
        let mut timer = PeriodicTimer::new(&clock, Duration::from_millis(1));

        let two_ticks = async move {
            for _ in 0..2 {
                assert_eq!(timer.next().await, Some(()));
            }
        };

        two_ticks.timeout(&clock, Duration::from_secs(5)).await.unwrap();
    }

    /// With a controlled clock the timer is expected to stay pending while
    /// only real time passes, and to fire after the clock is advanced.
    #[test]
    fn next_with_control() {
        let control = ClockControl::new();
        let clock = control.to_clock();
        let period = Duration::from_millis(1);
        let mut timer = PeriodicTimer::new(&clock, period);

        // First poll arms the timer.
        assert_eq!(poll_timer(&mut timer), Poll::Pending);

        // Sleeping in real time is expected to leave the timer pending.
        thread::sleep(period);
        assert_eq!(poll_timer(&mut timer), Poll::Pending);

        // Advancing past the deadline is expected to drop one registration.
        let registered_before = control.timers_len();
        control.advance(Duration::from_millis(2));
        assert_eq!(control.timers_len(), registered_before - 1);

        assert_eq!(poll_timer(&mut timer), Poll::Ready(Some(())));
    }

    /// The very first poll only arms the timer and reports pending.
    #[test]
    fn first_poll_next_should_be_pending() {
        let clock = Clock::new_frozen();
        let mut timer = PeriodicTimer::new(&clock, Duration::from_millis(1));

        assert_eq!(poll_timer(&mut timer), Poll::Pending);
    }

    /// A zero period is raised to one millisecond.
    #[test]
    fn new_zero_duration_period_adjusted() {
        let clock = Clock::new_frozen();
        let timer = PeriodicTimer::new(&clock, Duration::ZERO);

        assert_eq!(timer.period, Duration::from_millis(1));
    }

    /// A period whose deadline cannot be represented disables the timer:
    /// it stays pending, the period becomes `Duration::MAX`, and nothing is
    /// registered.
    #[test]
    fn new_duration_near_max_never_fires() {
        let clock = Clock::new_frozen();
        let almost_max = Duration::MAX.saturating_sub(Duration::from_millis(1));
        let mut timer = PeriodicTimer::new(&clock, almost_max);

        assert_eq!(poll_timer(&mut timer), Poll::Pending);
        assert_eq!(poll_timer(&mut timer), Poll::Pending);

        assert_eq!(timer.period, Duration::MAX);
        assert_eq!(timer.current_timer, None);
    }

    /// When the deadline passes without the clock's timers being advanced,
    /// the poll that reports the tick must also remove the registration.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn ready_without_advancing_timers_ensure_timer_unregistered() {
        let clock = Clock::new_tokio();
        let period = Duration::from_millis(1);
        let mut timer = PeriodicTimer::new(&clock, period);

        assert_eq!(poll_timer(&mut timer), Poll::Pending);

        thread::sleep(period);

        assert_eq!(poll_timer(&mut timer), Poll::Ready(Some(())));
        assert_eq!(timer.current_timer, None);
        assert_eq!(clock.clock_state().timers_len(), 0);
    }

    /// Dropping an armed timer removes its registration from the clock.
    #[test]
    fn drop_periodic_timer_unregisters_timer() {
        let clock = Clock::new_frozen();
        let period = Duration::from_millis(1);

        {
            let mut scoped_timer = PeriodicTimer::new(&clock, period);
            assert_eq!(poll_timer(&mut scoped_timer), Poll::Pending);
            assert_eq!(clock.clock_state().timers_len(), 1);
        }

        assert_eq!(clock.clock_state().timers_len(), 0);
    }
}