use crate::{Sleep, SleepTrait, TimeServiceTrait, ZERO_DURATION};
use aptos_infallible::Mutex;
use futures::future::Future;
use std::{
cmp::max,
collections::btree_map::BTreeMap,
fmt::Debug,
pin::Pin,
sync::{Arc, MutexGuard},
task::{Context, Poll, Waker},
time::{Duration, Instant},
};
/// Returns the largest representable [`Duration`].
///
/// Used as the auto-advance deadline by `MockTimeService::new_auto_advance`,
/// i.e. auto-advance is effectively never switched off by a deadline.
#[inline]
fn duration_max() -> Duration {
    // Same value as `Duration::new(u64::MAX, 999_999_999)`, without the
    // hand-written arithmetic or the legacy `std::u64::MAX` constant.
    Duration::MAX
}
/// Identifier assigned to each registered sleep. Values come from a per-service
/// counter that increases by one on every registration (see
/// `Inner::next_sleep_index`), so they disambiguate sleeps with equal deadlines.
type SleepIndex = usize;
/// A mock time service whose clock only moves when it is explicitly advanced
/// (`advance*` methods) or, in auto-advance mode, when a sleep is registered.
///
/// Cloning is cheap: all clones share the same `Arc<Mutex<Inner>>`, i.e. the
/// same clock and the same set of pending sleeps.
#[derive(Clone, Debug)]
pub struct MockTimeService {
inner: Arc<Mutex<Inner>>,
}
/// Shared, lock-protected state behind a `MockTimeService`.
#[derive(Debug)]
struct Inner {
/// Real instant captured when the service was created; the service's `now()`
/// reports `base_time + now`.
base_time: Instant,
/// Mock time elapsed since `base_time`. It never decreases: every update either
/// adds to it or takes a `max` with it.
now: Duration,
/// While `Some(d)`, a newly registered sleep ending at or before mock time `d`
/// completes immediately by jumping `now` forward. The first sleep that would
/// end after `d` sets this to `None`, and nothing sets it back.
auto_advance_deadline: Option<Duration>,
/// Next unused `SleepIndex`.
next_sleep_index: SleepIndex,
/// Registered, not-yet-expired sleeps keyed by `(deadline, index)`. Tuple
/// ordering makes iteration go earliest deadline first. The value is the waker
/// to notify on expiry (set when the sleep is polled, carried over by `reset`),
/// or `None` if no waker has been attached.
pending: BTreeMap<(Duration, SleepIndex), Option<Waker>>,
}
/// A sleep future created by `MockTimeService`.
///
/// `(deadline, index)` is this sleep's key in the service's pending set. The
/// future is ready once that key is no longer registered: either the clock was
/// advanced past it, or it was never inserted because auto-advance completed it
/// immediately. Dropping the future unregisters the key.
#[derive(Debug)]
pub struct MockSleep {
time_service: MockTimeService,
deadline: Duration,
index: SleepIndex,
}
impl MockTimeService {
    /// Creates a mock time service whose clock starts at zero and only moves
    /// when one of the `advance*` methods is called.
    pub fn new() -> Self {
        Self::with_auto_advance_deadline(None)
    }

    /// Creates a mock time service that auto-advances with the largest possible
    /// deadline (`duration_max()`).
    pub fn new_auto_advance() -> Self {
        Self::new_auto_advance_for(duration_max())
    }

    /// Creates a mock time service in auto-advance mode.
    ///
    /// Sleeps that end at or before mock time `deadline` complete immediately
    /// by jumping the clock forward. The first sleep ending after `deadline`
    /// turns auto-advance off permanently.
    pub fn new_auto_advance_for(deadline: Duration) -> Self {
        Self::with_auto_advance_deadline(Some(deadline))
    }

    /// Shared constructor: clock at zero, no pending sleeps, and the given
    /// auto-advance deadline (`None` means auto-advance is off).
    fn with_auto_advance_deadline(auto_advance_deadline: Option<Duration>) -> Self {
        let state = Inner {
            base_time: Instant::now(),
            now: ZERO_DURATION,
            auto_advance_deadline,
            next_sleep_index: 0,
            pending: BTreeMap::new(),
        };
        Self {
            inner: Arc::new(Mutex::new(state)),
        }
    }

    /// Acquires the lock on the shared state.
    fn lock(&self) -> MutexGuard<'_, Inner> {
        self.inner.lock()
    }

    /// Returns how many sleeps are currently registered and not yet expired.
    pub fn num_waiters(&self) -> usize {
        let state = self.lock();
        state.pending.len()
    }

    /// Expires the earliest pending sleep, waking it if it has a waker, and
    /// moves the clock up to its deadline.
    ///
    /// Returns the new mock time, or `None` (without touching the clock) when
    /// nothing is pending.
    pub fn advance_next(&self) -> Option<Duration> {
        let mut state = self.lock();
        state.advance_next()
    }

    /// Moves the clock forward by `duration` and expires every sleep whose
    /// deadline has been reached. Returns the number of expired sleeps.
    pub fn advance(&self, duration: Duration) -> usize {
        let mut state = self.lock();
        state.advance(duration)
    }

    /// [`Self::advance`] with the duration given in whole seconds.
    pub fn advance_secs(&self, duration: u64) -> usize {
        self.advance(Duration::from_secs(duration))
    }

    /// [`Self::advance`] with the duration given in milliseconds.
    pub fn advance_ms(&self, duration: u64) -> usize {
        self.advance(Duration::from_millis(duration))
    }

    /// [`Self::advance_next`] followed by one `tokio::task::yield_now`.
    ///
    /// The yield may let tasks woken by the advance make progress before this
    /// returns. The lock is released before yielding.
    pub async fn advance_next_async(&self) -> Option<Duration> {
        let new_time = self.advance_next();
        tokio::task::yield_now().await;
        new_time
    }

    /// [`Self::advance`] followed by one `tokio::task::yield_now`. The lock is
    /// released before yielding.
    pub async fn advance_async(&self, duration: Duration) -> usize {
        let expired = self.advance(duration);
        tokio::task::yield_now().await;
        expired
    }

    /// [`Self::advance_async`] with the duration given in whole seconds.
    pub async fn advance_secs_async(&self, duration: u64) -> usize {
        let step = Duration::from_secs(duration);
        self.advance_async(step).await
    }

    /// [`Self::advance_async`] with the duration given in milliseconds.
    pub async fn advance_ms_async(&self, duration: u64) -> usize {
        let step = Duration::from_millis(duration);
        self.advance_async(step).await
    }
}
impl TimeServiceTrait for MockTimeService {
    /// The mock "current" instant: the real instant captured at construction
    /// plus the mock time elapsed since then.
    fn now(&self) -> Instant {
        // Read both fields under a single lock acquisition so they are consistent.
        let (base, elapsed) = {
            let state = self.lock();
            (state.base_time, state.now)
        };
        base + elapsed
    }

    /// The mock time elapsed since the service was created.
    fn now_unix_time(&self) -> Duration {
        let state = self.lock();
        state.now
    }

    /// Registers a new mock sleep of `duration` on this service.
    fn sleep(&self, duration: Duration) -> Sleep {
        let mock = MockSleep::new(self.clone(), duration);
        mock.into()
    }

    /// Blocks the current thread until a mock sleep of `duration` completes.
    ///
    /// Without auto-advance this only returns once some other thread advances
    /// the clock far enough.
    fn sleep_blocking(&self, duration: Duration) {
        futures::executor::block_on(self.sleep(duration));
    }
}
impl Inner {
    /// Expires the earliest pending sleep (waking it if it has a waker) and
    /// moves the clock up to its deadline, unless the clock is already past it.
    ///
    /// Returns the clock value afterwards, or `None` (leaving the clock
    /// untouched) when no sleeps are pending.
    fn advance_next(&mut self) -> Option<Duration> {
        let deadline = self.trigger_min_sleep()?;
        self.now = max(self.now, deadline);
        Some(self.now)
    }

    /// Advances the clock by `duration` and expires every pending sleep whose
    /// deadline is at or before the new time, earliest first.
    ///
    /// The clock saturates at `Duration::MAX` instead of panicking on overflow,
    /// so "advance to the end of time" calls are safe.
    ///
    /// Returns the number of sleeps that expired.
    fn advance(&mut self, duration: Duration) -> usize {
        self.now = self.now.saturating_add(duration);
        let now = self.now;
        // `pending` is ordered by `(deadline, index)`, so expired entries form a prefix.
        let num_expired = self
            .pending
            .keys()
            .take_while(|&&(deadline, _index)| deadline <= now)
            .count();
        for _ in 0..num_expired {
            self.trigger_min_sleep()
                .expect("must be at least num_expired waiters");
        }
        num_expired
    }

    /// Hands out the next unused sleep index.
    ///
    /// # Panics
    /// Panics if the index counter would overflow.
    fn next_sleep_index(&mut self) -> SleepIndex {
        let index = self.next_sleep_index;
        self.next_sleep_index = self
            .next_sleep_index
            .checked_add(1)
            .expect("too many sleep entries");
        index
    }

    /// Mutable access to the waker slot of a registered sleep, or `None` if the
    /// sleep is not (or no longer) registered.
    fn get_mut_sleep(
        &mut self,
        deadline: Duration,
        index: SleepIndex,
    ) -> Option<&mut Option<Waker>> {
        self.pending.get_mut(&(deadline, index))
    }

    /// Whether the sleep identified by `(deadline, index)` is still pending.
    fn is_sleep_registered(&self, deadline: Duration, index: SleepIndex) -> bool {
        self.pending.contains_key(&(deadline, index))
    }

    /// Registers a sleep lasting `duration` from the current mock time and
    /// returns the `(deadline, index)` key identifying it.
    ///
    /// Deadlines saturate at `Duration::MAX`. So, e.g., a `Duration::MAX` sleep
    /// stays pending until the clock is advanced to the end, rather than
    /// panicking with an overflow.
    ///
    /// In auto-advance mode:
    /// - a sleep ending at or before the auto-advance deadline completes
    ///   immediately. The clock jumps to its deadline, it is never inserted into
    ///   `pending`, and `maybe_waker` (if any) is woken.
    /// - the first sleep ending after the auto-advance deadline moves the clock
    ///   up to that deadline (if not already past it), permanently disables
    ///   auto-advance, and is then registered normally.
    fn register_sleep(
        &mut self,
        duration: Duration,
        maybe_waker: Option<Waker>,
    ) -> (Duration, SleepIndex) {
        let deadline = self.now.saturating_add(duration);
        let index = self.next_sleep_index();

        if let Some(auto_advance_deadline) = self.auto_advance_deadline {
            if deadline <= auto_advance_deadline {
                self.now = deadline;
                // The sleep has already completed; don't silently drop a waker
                // the caller expects to be notified through.
                if let Some(waker) = maybe_waker {
                    waker.wake();
                }
                return (deadline, index);
            }
            self.now = max(self.now, auto_advance_deadline);
            self.auto_advance_deadline = None;
        }

        let prev_entry = self.pending.insert((deadline, index), maybe_waker);
        assert!(
            prev_entry.is_none(),
            "there can never be an entry at an unused SleepIndex"
        );
        (deadline, index)
    }

    /// Removes a sleep from the pending set.
    ///
    /// Returns its waker slot if it was registered, or `None` if it was not.
    fn unregister_sleep(&mut self, deadline: Duration, index: SleepIndex) -> Option<Option<Waker>> {
        self.pending.remove(&(deadline, index))
    }

    /// Removes and returns the pending sleep with the earliest `(deadline, index)`.
    fn unregister_min_sleep(&mut self) -> Option<((Duration, SleepIndex), Option<Waker>)> {
        let earliest = *self.pending.keys().next()?;
        self.pending.remove_entry(&earliest)
    }

    /// Expires the earliest pending sleep, waking it if it has a waker, and
    /// returns its deadline (`None` if nothing is pending).
    ///
    /// NOTE(review): `MockTimeService` calls into `Inner` through `self.lock()`.
    /// Wakers therefore run while that lock is held. A waker that synchronously
    /// re-enters the same service would presumably deadlock.
    fn trigger_min_sleep(&mut self) -> Option<Duration> {
        self.unregister_min_sleep()
            .map(|((deadline, _index), maybe_waker)| {
                if let Some(waker) = maybe_waker {
                    waker.wake();
                }
                deadline
            })
    }
}
impl MockSleep {
fn new(time_service: MockTimeService, duration: Duration) -> Self {
let (deadline, index) = time_service.lock().register_sleep(duration, None);
Self {
time_service,
deadline,
index,
}
}
}
impl SleepTrait for MockSleep {
    /// `true` once this sleep's key is no longer in the service's pending set.
    fn is_elapsed(&self) -> bool {
        let still_pending = self
            .time_service
            .lock()
            .is_sleep_registered(self.deadline, self.index);
        !still_pending
    }

    /// Re-arms the sleep to end `duration` after the current mock time.
    ///
    /// A waker attached to the previous registration, if still pending, is
    /// carried over to the new one.
    fn reset(self: Pin<&mut Self>, duration: Duration) {
        let sleep = Pin::into_inner(self);
        let mut state = sleep.time_service.lock();
        let carried_waker = state
            .unregister_sleep(sleep.deadline, sleep.index)
            .flatten();
        let (new_deadline, new_index) = state.register_sleep(duration, carried_waker);
        sleep.deadline = new_deadline;
        sleep.index = new_index;
    }

    /// Re-arms the sleep to end at mock instant `deadline`. A deadline already
    /// in the past is treated as a zero-length sleep.
    fn reset_until(self: Pin<&mut Self>, deadline: Instant) {
        let current = self.time_service.now();
        let remaining = deadline.saturating_duration_since(current);
        self.reset(remaining);
    }
}
impl Future for MockSleep {
    type Output = ();

    /// Ready once the sleep is no longer registered. Otherwise stores the
    /// caller's waker (replacing any earlier one) and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.time_service.lock();
        if let Some(waker_slot) = state.get_mut_sleep(self.deadline, self.index) {
            *waker_slot = Some(cx.waker().clone());
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}
impl Drop for MockSleep {
    /// Removes this sleep from the pending set (a no-op if it already expired),
    /// so a dropped sleep no longer counts as a waiter.
    fn drop(&mut self) {
        let mut state = self.time_service.lock();
        let _ = state.unregister_sleep(self.deadline, self.index);
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::channel::oneshot;
    use tokio_test::{
        assert_pending, assert_ready, assert_ready_eq, assert_ready_err, assert_ready_ok, task,
    };

    /// A single sleep: manual advancing, wake-ups, and `reset`.
    #[tokio::test]
    async fn test_sleep() {
        let time = MockTimeService::new();
        assert_eq!(time.num_waiters(), 0);
        // With nothing pending there is nothing to advance to.
        assert_eq!(time.advance_next_async().await, None);

        // Advancing moves both `now()` and `now_unix_time()`.
        let start = time.now();
        let start_unix = time.now_unix_time();
        assert_eq!(time.advance_async(ms(10)).await, 0);
        assert_eq!(time.now() - start, ms(10));
        assert_eq!(time.now_unix_time() - start_unix, ms(10));

        // The sleep completes only once its full duration has elapsed.
        let mut sleep = task::spawn(time.sleep(ms(10)));
        assert!(!sleep.is_woken());
        assert_eq!(time.num_waiters(), 1);
        assert_pending!(sleep.poll());
        assert_eq!(time.advance_async(ms(5)).await, 0);
        assert!(!sleep.is_woken());
        assert_pending!(sleep.poll());
        assert_eq!(time.advance_async(ms(5)).await, 1);
        assert_eq!(time.num_waiters(), 0);
        assert!(sleep.is_woken());
        assert_ready!(sleep.poll());

        // Resetting an elapsed sleep re-arms it.
        sleep.enter(|_c, sleep| sleep.reset(ms(5)));
        assert!(!sleep.is_woken());
        assert_eq!(time.num_waiters(), 1);
        assert_pending!(sleep.poll());
        assert_eq!(time.advance_async(ms(5)).await, 1);
        assert_eq!(time.num_waiters(), 0);
        assert!(sleep.is_woken());
        assert_ready!(sleep.poll());

        // A re-armed sleep that is never polled has no waker. It expires
        // without waking the task but is still ready when polled.
        sleep.enter(|_c, sleep| sleep.reset(ms(5)));
        assert!(!sleep.is_woken());
        assert_eq!(time.advance_async(ms(5)).await, 1);
        assert_eq!(time.num_waiters(), 0);
        assert!(!sleep.is_woken());
        assert_ready!(sleep.poll());
    }

    /// `sleep_until` and `reset_until` with absolute deadlines.
    #[tokio::test]
    async fn test_sleep_until() {
        let time = MockTimeService::new();
        assert_eq!(time.num_waiters(), 0);

        let start = time.now();
        let mut sleep = task::spawn(time.sleep_until(start + ms(10)));
        assert!(!sleep.is_woken());
        assert_eq!(time.num_waiters(), 1);
        assert_pending!(sleep.poll());
        assert_eq!(time.advance_async(ms(5)).await, 0);
        assert!(!sleep.is_woken());
        assert_pending!(sleep.poll());
        assert_eq!(time.advance_async(ms(5)).await, 1);
        assert_eq!(time.num_waiters(), 0);
        assert!(sleep.is_woken());
        assert_ready!(sleep.poll());

        sleep.enter(|_c, sleep| sleep.reset_until(time.now() + ms(5)));
        assert!(!sleep.is_woken());
        assert_eq!(time.num_waiters(), 1);
        assert_pending!(sleep.poll());
        assert_eq!(time.advance_async(ms(5)).await, 1);
        assert_eq!(time.num_waiters(), 0);
        assert!(sleep.is_woken());
        assert_ready!(sleep.poll());
    }

    /// Several sleeps, including ones with identical deadlines, expire in
    /// deadline order.
    #[tokio::test]
    async fn test_many_sleep() {
        let time = MockTimeService::new();
        assert_eq!(time.num_waiters(), 0);

        let mut sleep_5ms = task::spawn(time.sleep(ms(5)));
        let mut sleep_10ms_1 = task::spawn(time.sleep(ms(10)));
        let mut sleep_10ms_2 = task::spawn(time.sleep(ms(10)));
        let mut sleep_10ms_3 = task::spawn(time.sleep(ms(10)));
        let mut sleep_15ms = task::spawn(time.sleep(ms(15)));
        let mut sleep_20ms = task::spawn(time.sleep(ms(20)));
        assert_eq!(time.num_waiters(), 6);

        assert_pending!(sleep_5ms.poll());
        assert_pending!(sleep_10ms_1.poll());
        assert_pending!(sleep_10ms_2.poll());
        assert_pending!(sleep_10ms_3.poll());
        assert_pending!(sleep_15ms.poll());
        assert_pending!(sleep_20ms.poll());

        assert_eq!(time.advance_async(ms(10)).await, 4);
        assert_eq!(time.num_waiters(), 2);
        assert_ready!(sleep_5ms.poll());
        assert_ready!(sleep_10ms_1.poll());
        assert_ready!(sleep_10ms_2.poll());
        assert_ready!(sleep_10ms_3.poll());
        assert_pending!(sleep_15ms.poll());
        assert_pending!(sleep_20ms.poll());

        assert_eq!(time.advance_next_async().await, Some(ms(15)));
        assert_eq!(time.num_waiters(), 1);
        assert_ready!(sleep_15ms.poll());
        assert_pending!(sleep_20ms.poll());

        assert_eq!(time.advance_next_async().await, Some(ms(20)));
        assert_eq!(time.num_waiters(), 0);
        assert_ready!(sleep_20ms.poll());
    }

    /// An interval's first tick is due at the current mock time but still
    /// needs the clock to be advanced. Later ticks follow the period.
    #[tokio::test]
    async fn test_interval() {
        let time = MockTimeService::new();
        let mut interval = task::spawn(time.interval(ms(10)));

        assert_pending!(interval.poll_next());
        assert!(!interval.is_woken());
        assert_eq!(time.advance_next_async().await, Some(ms(0)));
        assert!(interval.is_woken());
        assert_ready_eq!(interval.poll_next(), Some(()));
        assert_pending!(interval.poll_next());

        assert_eq!(time.advance_async(ms(5)).await, 0);
        assert!(!interval.is_woken());
        assert_pending!(interval.poll_next());
        assert_eq!(time.advance_async(ms(5)).await, 1);
        assert!(interval.is_woken());
        assert_ready_eq!(interval.poll_next(), Some(()));
        assert_pending!(interval.poll_next());
    }

    /// Timeouts: an immediately ready future, an inner future finishing before
    /// the deadline, and the deadline elapsing first.
    #[tokio::test]
    async fn test_timeout() {
        let time = MockTimeService::new();
        let mut timeout = task::spawn(time.timeout(ms(10), async {}));
        assert_ready_ok!(timeout.poll());

        let time = MockTimeService::new();
        let (tx, rx) = oneshot::channel();
        let mut timeout = task::spawn(time.timeout(ms(10), rx));
        assert_pending!(timeout.poll());
        assert_eq!(time.advance_async(ms(5)).await, 0);
        assert_pending!(timeout.poll());
        tx.send(()).unwrap();
        assert!(timeout.is_woken());
        assert_ready_ok!(timeout.poll()).unwrap();

        let time = MockTimeService::new();
        let (_tx, rx) = oneshot::channel::<()>();
        let mut timeout = task::spawn(time.timeout(ms(10), rx));
        assert_pending!(timeout.poll());
        assert_eq!(time.advance_async(ms(15)).await, 1);
        assert!(timeout.is_woken());
        assert_ready_err!(timeout.poll());
    }

    /// Sleeps within the auto-advance window complete immediately. The first
    /// sleep past it moves the clock to the window's end and disables
    /// auto-advance.
    #[tokio::test]
    async fn test_auto_advance() {
        let time = MockTimeService::new_auto_advance_for(ms(100));
        assert_eq!(time.now_unix_time(), ms(0));

        let mut sleep = task::spawn(time.sleep(ms(20)));
        assert_eq!(time.now_unix_time(), ms(20));
        assert_ready!(sleep.poll());
        assert_eq!(time.now_unix_time(), ms(20));

        let mut sleep = task::spawn(time.sleep(ms(30)));
        assert_eq!(time.now_unix_time(), ms(50));
        assert_ready!(sleep.poll());
        assert_eq!(time.now_unix_time(), ms(50));

        assert_eq!(time.advance_async(ms(30)).await, 0);
        assert_eq!(time.now_unix_time(), ms(80));

        // Deadline 170ms lies past the 100ms window: auto-advance switches off.
        let mut sleep = task::spawn(time.sleep(ms(90)));
        assert_pending!(sleep.poll());
        assert_eq!(time.now_unix_time(), ms(100));
        assert_eq!(time.advance_async(ms(30)).await, 0);
        assert_pending!(sleep.poll());
        assert_eq!(time.now_unix_time(), ms(130));
        assert_eq!(time.advance_async(ms(100)).await, 1);
        assert_ready!(sleep.poll());
        assert_eq!(time.now_unix_time(), ms(230));
    }

    /// `sleep_blocking` returns without any other thread advancing the clock
    /// when auto-advance is enabled.
    #[test]
    fn test_auto_advance_blocking() {
        let time = MockTimeService::new_auto_advance();
        time.sleep_blocking(secs(10_000));
        assert_eq!(time.now_unix_time(), secs(10_000));
    }

    /// Interval ticks auto-advance until the window ends. After that the next
    /// tick stays pending.
    #[test]
    fn test_auto_advance_interval() {
        let time = MockTimeService::new_auto_advance_for(ms(10));
        let mut interval = task::spawn(time.interval(ms(4)));
        assert_eq!(time.now_unix_time(), ms(0));
        assert_ready!(interval.poll_next());
        assert_eq!(time.now_unix_time(), ms(4));
        assert_ready!(interval.poll_next());
        assert_eq!(time.now_unix_time(), ms(8));
        assert_ready!(interval.poll_next());
        assert_eq!(time.now_unix_time(), ms(10));
        assert_pending!(interval.poll_next());
    }

    fn ms(duration: u64) -> Duration {
        Duration::from_millis(duration)
    }

    fn secs(duration: u64) -> Duration {
        Duration::from_secs(duration)
    }
}