use std::cmp::Reverse;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex, MutexGuard};
use std::task::{Context, Poll, Waker};
use web_time_compat::{Duration, Instant, InstantExt, SystemTime, SystemTimeExt};
use derive_more::AsMut;
use priority_queue::priority_queue::PriorityQueue;
use slotmap_careful::DenseSlotMap;
use tor_rtcompat::CoarseInstant;
use tor_rtcompat::CoarseTimeProvider;
use tor_rtcompat::SleepProvider;
use crate::time_core::MockTimeCore;
/// A simple mock time provider.
///
/// The mock clocks are changed only through [`advance`](Self::advance) and
/// [`jump_wallclock`](Self::jump_wallclock). Sleep futures created by this
/// provider complete once the mock monotonic time reaches their deadline.
///
/// Cloning is cheap and yields another handle onto the *same* mock clock and
/// the same set of pending sleeps.
#[derive(Debug, Clone)]
pub struct SimpleMockTimeProvider {
    /// Mock clock plus sleeper bookkeeping, shared by every clone.
    state: Arc<Mutex<State>>,
}
// Shorter crate-internal name for the provider, used throughout this module.
pub(crate) use SimpleMockTimeProvider as Provider;
/// Key identifying one [`SleepFuture`] in the provider's `futures` and
/// `unready` collections.
type Id = slotmap_careful::DefaultKey;
/// Future returned by [`SleepProvider::sleep`] on a [`SimpleMockTimeProvider`].
///
/// Completes once the provider's mock monotonic time reaches the deadline
/// chosen when the future was created.
///
/// While this value exists, the provider state holds an entry for `id` in its
/// `futures` map (possibly with no waker yet); dropping the future removes it.
#[derive(Debug)]
pub struct SleepFuture {
    /// Handle to the shared provider state.
    prov: Provider,
    /// Key of this future's entries in the provider's `futures` and `unready`.
    id: Id,
}
/// Mutable state shared by a [`SimpleMockTimeProvider`] and all its clones,
/// protected by the provider's mutex.
#[derive(Debug, AsMut)]
struct State {
    /// The mock clocks: current monotonic instant and wall-clock time.
    core: MockTimeCore,
    /// One entry per live [`SleepFuture`], keyed by its `id`.
    ///
    /// Holds the waker recorded by that future's most recent `Pending` poll, or
    /// `None` if it has not been polled yet or `wake_any` already took the waker.
    futures: DenseSlotMap<Id, Option<Waker>>,
    /// Sleep futures whose deadline has not yet been reached.
    ///
    /// `PriorityQueue` yields the greatest priority first; wrapping the deadline
    /// in `Reverse` makes `peek`/`pop` return the earliest deadline.
    unready: PriorityQueue<Id, Reverse<Instant>>,
}
impl Default for Provider {
fn default() -> Self {
Self::from_real()
}
}
impl Provider {
    /// Create a provider whose mock monotonic clock starts at `now` and whose
    /// mock wall clock starts at `wallclock`, with no pending sleeps.
    pub fn new(now: Instant, wallclock: SystemTime) -> Self {
        let core = MockTimeCore::new(now, wallclock);
        let state = State {
            core,
            futures: Default::default(),
            unready: Default::default(),
        };
        Self {
            state: Arc::new(Mutex::new(state)),
        }
    }

    /// Create a provider starting from the real clocks, as read by
    /// `Instant::get()` and `SystemTime::get()`.
    pub fn from_real() -> Self {
        let wallclock = SystemTime::get();
        Self::from_wallclock(wallclock)
    }

    /// Create a provider whose mock wall clock starts at `wallclock` and whose
    /// mock monotonic clock starts at `Instant::get()`.
    pub fn from_wallclock(wallclock: SystemTime) -> Self {
        let now = Instant::get();
        Self::new(now, wallclock)
    }

    /// Advance the mock clocks (monotonic and wall-clock) by `d`, then wake
    /// every sleep whose deadline has now been reached.
    ///
    /// Woken tasks are not run here; an executor still has to poll them.
    pub fn advance(&self, d: Duration) {
        let mut guard = self.lock();
        guard.core.advance(d);
        guard.wake_any();
    }

    /// Set the mock wall clock to `new_wallclock`.
    ///
    /// The monotonic clock is left alone and no sleeping futures are woken.
    pub fn jump_wallclock(&self, new_wallclock: SystemTime) {
        let mut guard = self.lock();
        guard.core.jump_wallclock(new_wallclock);
    }

    /// How long until the earliest pending sleep's deadline.
    ///
    /// Returns `None` when no sleep is pending. Sleeps already woken by
    /// [`advance`](Self::advance) are no longer pending.
    pub fn time_until_next_timeout(&self) -> Option<Duration> {
        let guard = self.lock();
        let now = guard.core.instant();
        guard
            .unready
            .peek()
            .map(|(_, &Reverse(until))| until.duration_since(now))
    }

    /// Lock the shared state.
    ///
    /// # Panics
    ///
    /// Panics if the mutex has been poisoned by an earlier panic.
    fn lock(&self) -> MutexGuard<'_, State> {
        let locked = self.state.lock();
        locked.expect("simple time state poisoned")
    }
}
impl SleepProvider for Provider {
    type SleepFuture = SleepFuture;

    /// Return a future that completes once the mock monotonic time has
    /// advanced by `d` from now.
    ///
    /// A zero `d` yields a future that is already complete. May panic if the
    /// current mock instant plus `d` cannot be represented.
    fn sleep(&self, d: Duration) -> SleepFuture {
        let id = {
            let mut state = self.lock();
            let deadline = state.core.instant() + d;
            let id = state.futures.insert(None);
            state.unready.push(id, Reverse(deadline));
            // The new entry may already be due (e.g. `d` is zero); let
            // `wake_any` retire it like any other due sleep.
            state.wake_any();
            id
        };
        // Built after the guard is released: `SleepFuture`'s `Drop` takes the
        // same lock.
        SleepFuture {
            prov: self.clone(),
            id,
        }
    }

    /// Current mock monotonic time.
    fn now(&self) -> Instant {
        let state = self.lock();
        state.core.instant()
    }

    /// Current mock wall-clock time.
    fn wallclock(&self) -> SystemTime {
        let state = self.lock();
        state.core.wallclock()
    }
}
impl CoarseTimeProvider for Provider {
    /// Current mock time as a [`CoarseInstant`], taken from the mock clock core.
    fn now_coarse(&self) -> CoarseInstant {
        let state = self.lock();
        let coarse = state.core.coarse();
        coarse.now_coarse()
    }
}
impl Future for SleepFuture {
    type Output = ();

    /// Ready once this future's entry has left `unready` (its deadline was
    /// reached); otherwise record the current waker and return `Pending`.
    ///
    /// Asserts the invariant that an entry still in `unready` has a deadline
    /// strictly after the current mock instant.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut state = self.prov.lock();
        let Some((_, Reverse(scheduled))) = state.unready.get(&self.id) else {
            // Removed from `unready` by `wake_any`: the deadline has passed.
            return Poll::Ready(());
        };
        assert!(*scheduled > state.core.instant());
        let waker = cx.waker().clone();
        let slot = state
            .futures
            .get_mut(self.id)
            .expect("polling futures entry");
        *slot = Some(waker);
        Poll::Pending
    }
}
impl State {
    /// Retire every sleep whose deadline is at or before the current mock
    /// instant.
    ///
    /// Each retired sleep is removed from `unready`, and its task is woken if
    /// a waker was recorded. Entries are examined earliest-deadline first,
    /// stopping at the first one still in the future.
    fn wake_any(&mut self) {
        while let Some((_, &Reverse(deadline))) = self.unready.peek() {
            if deadline > self.core.instant() {
                break;
            }
            let (id, _) = self.unready.pop().expect("vanished");
            let slot = self.futures.get_mut(id).expect("stale unready entry");
            if let Some(waker) = slot.take() {
                waker.wake();
            }
        }
    }
}
impl Drop for SleepFuture {
    /// Remove this future's entries from `futures` and, if still present,
    /// from `unready`.
    ///
    /// This does not call `wake_any`: removing a sleeper does not change the
    /// mock time, so it cannot make any other deadline due.
    fn drop(&mut self) {
        // If we are unwinding from a panic that poisoned the state mutex (for
        // example a failed assertion in `poll`, or a waker that panicked inside
        // `wake_any` while `advance` held the lock), `self.prov.lock()` would
        // panic again. A panic inside a destructor while already panicking
        // aborts the whole process and hides the original failure, so skip the
        // cleanup: the shared state is no longer trustworthy anyway.
        if std::thread::panicking() && self.prov.state.is_poisoned() {
            return;
        }
        let mut state = self.prov.lock();
        let _: Option<Waker> = state.futures.remove(self.id).expect("entry vanished");
        let _: Option<(Id, Reverse<Instant>)> = state.unready.remove(&self.id);
    }
}
#[cfg(test)]
mod test {
    // Lint relaxations for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::task::MockExecutor;
    use Poll::*;
    use futures::poll;
    use humantime::parse_rfc3339;
    use tor_rtcompat::ToplevelBlockOn as _;

    // Shorthand: a `Duration` of `ms` milliseconds.
    fn ms(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    // Run one test body.
    //
    // Creates a provider whose mock monotonic clock starts at `Instant::get()`
    // and whose mock wall clock starts at 2000-01-01T00:00:00Z, plus a fresh
    // `MockExecutor`. Then blocks on the future that `f` returns.
    fn run_test<FUT>(f: impl FnOnce(Provider, MockExecutor) -> FUT)
    where
        FUT: Future<Output = ()>,
    {
        let sp = Provider::new(
            Instant::get(),
            parse_rfc3339("2000-01-01T00:00:00Z").unwrap(),
        );
        let exec = MockExecutor::new();
        exec.block_on(f(sp, exec.clone()));
    }

    // Poll sleep futures by hand. Checks how `advance`, `jump_wallclock`,
    // dropping a pending future and zero-length sleeps interact.
    #[test]
    fn simple() {
        run_test(|sp, _exec| async move {
            let n1 = sp.now();
            let w1 = sp.wallclock();
            // Deadlines at n1+500ms and n1+1500ms.
            let mut f1 = sp.sleep(ms(500));
            let mut f2 = sp.sleep(ms(1500));
            assert_eq!(poll!(&mut f1), Pending);
            // `advance` moves both the monotonic and the wall clock.
            sp.advance(ms(200));
            assert_eq!(n1 + ms(200), sp.now());
            assert_eq!(w1 + ms(200), sp.wallclock());
            assert_eq!(poll!(&mut f1), Pending);
            assert_eq!(poll!(&mut f2), Pending);
            // Dropping a pending future removes its bookkeeping entries.
            drop(f2);
            // A wall-clock jump leaves the monotonic clock untouched.
            sp.jump_wallclock(w1 + ms(10_000));
            sp.advance(ms(300));
            assert_eq!(n1 + ms(500), sp.now());
            assert_eq!(w1 + ms(10_300), sp.wallclock());
            // f1's deadline (n1+500ms) has now been reached.
            assert_eq!(poll!(&mut f1), Ready(()));
            // A zero-length sleep is complete immediately.
            let mut f0 = sp.sleep(ms(0));
            assert_eq!(poll!(&mut f0), Ready(()));
        });
    }

    // Drive a spawned task through two sleeps with `MockExecutor`.
    //
    // Checks that `advance` only wakes the task, which does not run until the
    // next `progress_until_stalled`. Also checks that `time_until_next_timeout`
    // reports the pending deadline.
    #[test]
    fn task() {
        run_test(|sp, exec| async move {
            // Progress marker written by the task:
            //   0 = not started
            //   1 = waiting in first sleep
            //   2 = waiting in second sleep
            //   3 = finished
            let st = Arc::new(Mutex::new(0_i8));
            exec.spawn_identified("test task", {
                let st = st.clone();
                let sp = sp.clone();
                async move {
                    *st.lock().unwrap() = 1;
                    sp.sleep(ms(500)).await;
                    *st.lock().unwrap() = 2;
                    sp.sleep(ms(300)).await;
                    *st.lock().unwrap() = 3;
                }
            });
            let st = move || *st.lock().unwrap();
            assert_eq!(st(), 0);
            exec.progress_until_stalled().await;
            assert_eq!(st(), 1);
            assert_eq!(sp.time_until_next_timeout(), Some(ms(500)));
            sp.advance(ms(500));
            // Woken by `advance`, but not yet run.
            assert_eq!(st(), 1);
            assert_eq!(sp.time_until_next_timeout(), None);
            exec.progress_until_stalled().await;
            assert_eq!(st(), 2);
            assert_eq!(sp.time_until_next_timeout(), Some(ms(300)));
            // Advancing past the deadline (500ms > 300ms) is fine.
            sp.advance(ms(500));
            assert_eq!(st(), 2);
            assert_eq!(sp.time_until_next_timeout(), None);
            exec.progress_until_stalled().await;
            assert_eq!(sp.time_until_next_timeout(), None);
            assert_eq!(st(), 3);
        });
    }
}