#![forbid(unsafe_code)] #![allow(clippy::missing_docs_in_private_items)]
use std::{
cmp::{Eq, Ordering, PartialEq, PartialOrd},
collections::BinaryHeap,
fmt,
pin::Pin,
sync::{Arc, Mutex, Weak},
task::{Context, Poll, Waker},
time::{Duration, Instant, SystemTime},
};
use futures::Future;
use tracing::trace;
use std::collections::HashSet;
use std::fmt::Formatter;
use tor_rtcompat::{CoarseInstant, CoarseTimeProvider, SleepProvider};
use crate::time_core::MockTimeCore;
/// Mock implementation of `SleepProvider` whose time is moved explicitly,
/// through `advance` / `advance_noyield` (monotonic and wall-clock) and
/// `jump_to` (wall-clock only).
///
/// All clones share one schedule (`Arc<Mutex<_>>`), so advancing time via any
/// clone is observed by every clone and by every `Sleeping` future it created.
///
/// Deprecated since 0.29.0 for use outside this crate's own tests.
#[derive(Clone)]
#[cfg_attr(not(test), deprecated(since = "0.29.0"))]
pub struct MockSleepProvider {
    /// Shared mutable state: the mock clocks and the queue of sleepers.
    state: Arc<Mutex<SleepSchedule>>,
}
/// Opaque debug output: prints only the type name plus `..`, so the locked
/// internal schedule is never touched while formatting.
impl fmt::Debug for MockSleepProvider {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MockSleepProvider");
        out.finish_non_exhaustive()
    }
}
/// Shared state behind a `MockSleepProvider`, its clones, and the
/// `Sleeping` futures it hands out (the latter hold only a weak reference).
struct SleepSchedule {
    /// The mock clocks; this file queries them via `instant()`, `wallclock()`
    /// and `coarse()`, and moves them via `advance()` / `jump_wallclock()`.
    core: MockTimeCore,
    /// Sleepers that were polled before their deadline. `SleepEntry`'s
    /// reversed ordering puts the earliest deadline on top; `fire` pops and
    /// wakes every entry that is due.
    sleepers: BinaryHeap<SleepEntry>,
    /// Waker registered by the time-advancing driver; woken by
    /// `maybe_advance` when an advance looks possible.
    waitfor_waker: Option<Waker>,
    /// Number of `Sleeping` futures created by `sleep`.
    sleepers_made: usize,
    /// Number of those futures that have been polled at least once, or
    /// dropped without ever being polled.
    sleepers_polled: usize,
    /// Advance flag: set by `maybe_advance`, consumed by
    /// `MockSleepProvider::should_advance`.
    should_advance: bool,
    /// Reasons currently blocking automatic advancement, managed by
    /// `block_advance` / `release_advance`.
    blocked_advance: HashSet<String>,
    /// Remaining amount of time (granted by `allow_one_advance`) that may
    /// still be advanced while `blocked_advance` is non-empty.
    allowed_advance: Duration,
}
/// One queued sleeper: a deadline plus the waker to call once it has passed.
///
/// Equality and ordering look only at `when`; the ordering is reversed so that
/// `BinaryHeap` (a max-heap) yields the earliest deadline first.
struct SleepEntry {
    /// Mock-monotonic instant at which `waker` should be woken.
    when: Instant,
    /// Waker captured from the first (pre-deadline) poll of the `Sleeping`.
    waker: Waker,
}
/// Future returned by `MockSleepProvider::sleep`; it becomes ready once the
/// mock monotonic clock reaches `when`.
///
/// Like every future it does nothing until polled, so silently discarding one
/// is almost always a bug — hence `#[must_use]`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Sleeping {
    /// Mock-monotonic instant at which this future becomes ready.
    when: Instant,
    /// True once this sleeper has been counted as polled (and, if that first
    /// poll happened before the deadline, pushed onto the sleeper heap).
    inserted: bool,
    /// The schedule this sleeper belongs to. Held weakly, so an outstanding
    /// `Sleeping` does not keep the provider's state alive.
    provider: Weak<Mutex<SleepSchedule>>,
}
/// Build a provider whose mock wall-clock starts at the fixed instant
/// 2023-07-05T11:25:56Z rather than at the real current time.
impl Default for MockSleepProvider {
    fn default() -> Self {
        const START_WALLCLOCK: &str = "2023-07-05T11:25:56Z";
        Self::new(humantime::parse_rfc3339(START_WALLCLOCK).expect("parse"))
    }
}
impl MockSleepProvider {
pub fn new(wallclock: SystemTime) -> Self {
let instant = Instant::now();
let sleepers = BinaryHeap::new();
let core = MockTimeCore::new(instant, wallclock);
let state = SleepSchedule {
core,
sleepers,
waitfor_waker: None,
sleepers_made: 0,
sleepers_polled: 0,
should_advance: false,
blocked_advance: HashSet::new(),
allowed_advance: Duration::from_nanos(0),
};
MockSleepProvider {
state: Arc::new(Mutex::new(state)),
}
}
pub async fn advance(&self, dur: Duration) {
self.advance_noyield(dur);
tor_rtcompat::task::yield_now().await;
}
pub(crate) fn advance_noyield(&self, dur: Duration) {
let mut state = self.state.lock().expect("Poisoned lock for state");
state.core.advance(dur);
state.fire();
}
pub fn jump_to(&self, new_wallclock: SystemTime) {
let mut state = self.state.lock().expect("Poisoned lock for state");
state.core.jump_wallclock(new_wallclock);
}
pub(crate) fn time_until_next_timeout(&self) -> Option<Duration> {
let state = self.state.lock().expect("Poisoned lock for state");
let now = state.core.instant();
state
.sleepers
.peek()
.map(|sleepent| sleepent.when.saturating_duration_since(now))
}
#[allow(clippy::cognitive_complexity)]
pub(crate) fn should_advance(&mut self) -> bool {
let mut state = self.state.lock().expect("Poisoned lock for state");
if !state.blocked_advance.is_empty() && state.allowed_advance == Duration::from_nanos(0) {
trace!(
"should_advance = false: blocked by {:?}",
state.blocked_advance
);
return false;
}
if !state.should_advance {
trace!("should_advance = false; bit not previously set");
return false;
}
state.should_advance = false;
if state.sleepers_polled < state.sleepers_made {
trace!("should_advance = false; advancing no longer valid");
return false;
}
if !state.blocked_advance.is_empty() && state.allowed_advance > Duration::from_nanos(0) {
let next_timeout = {
let now = state.core.instant();
state
.sleepers
.peek()
.map(|sleepent| sleepent.when.saturating_duration_since(now))
};
let next_timeout = match next_timeout {
Some(x) => x,
None => {
trace!("should_advance = false; allow_one set but no timeout yet");
return false;
}
};
if next_timeout <= state.allowed_advance {
state.allowed_advance -= next_timeout;
trace!(
"WARNING: allowing advance due to allow_one; new allowed is {:?}",
state.allowed_advance
);
} else {
trace!(
"should_advance = false; allow_one set but only up to {:?}, next is {:?}",
state.allowed_advance,
next_timeout
);
return false;
}
}
true
}
pub(crate) fn register_waitfor_waker(&mut self, waker: Waker) {
let mut state = self.state.lock().expect("Poisoned lock for state");
state.waitfor_waker = Some(waker);
}
pub(crate) fn clear_waitfor_waker(&mut self) {
let mut state = self.state.lock().expect("Poisoned lock for state");
state.waitfor_waker = None;
}
pub(crate) fn has_waitfor_waker(&self) -> bool {
let state = self.state.lock().expect("Poisoned lock for state");
state.waitfor_waker.is_some()
}
}
impl SleepSchedule {
fn fire(&mut self) {
use std::collections::binary_heap::PeekMut;
let now = self.core.instant();
while let Some(top) = self.sleepers.peek_mut() {
if now < top.when {
return;
}
PeekMut::pop(top).waker.wake();
}
}
fn push(&mut self, ent: SleepEntry) {
self.sleepers.push(ent);
}
fn maybe_advance(&mut self) {
if self.sleepers_polled >= self.sleepers_made {
if let Some(ref waker) = self.waitfor_waker {
trace!("setting advance flag");
self.should_advance = true;
waker.wake_by_ref();
} else {
trace!("would advance, but no waker");
}
}
}
fn increment_poll_count(&mut self) {
self.sleepers_polled += 1;
trace!(
"sleeper polled, {}/{}",
self.sleepers_polled,
self.sleepers_made
);
self.maybe_advance();
}
}
impl SleepProvider for MockSleepProvider {
    type SleepFuture = Sleeping;

    /// Return a future that completes once the mock clock has advanced by
    /// at least `duration`.
    ///
    /// If `now + duration` is not representable as an `Instant` (for
    /// example `Duration::MAX`, commonly used to mean "forever"), the
    /// deadline falls back to roughly 30 years from the mock `now` instead
    /// of panicking on overflow. This mirrors the fallback used by tokio's
    /// `time::sleep`.
    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
        // Roughly 30 years: the fallback deadline distance on overflow.
        const FAR_FUTURE: Duration = Duration::from_secs(86400 * 365 * 30);

        let mut provider = self.state.lock().expect("Poisoned lock for state");
        let now = provider.core.instant();
        let when = now
            .checked_add(duration)
            .or_else(|| now.checked_add(FAR_FUTURE))
            .expect("mock clock is too far in the future to schedule a sleep");
        provider.sleepers_made += 1;
        trace!(
            "sleeper made for {:?}, {}/{}",
            duration,
            provider.sleepers_polled,
            provider.sleepers_made
        );
        Sleeping {
            when,
            inserted: false,
            provider: Arc::downgrade(&self.state),
        }
    }

    /// Record `reason` as blocking automatic advancement. While any reason
    /// is recorded, advances are refused unless an `allow_one_advance`
    /// quota covers them.
    fn block_advance<T: Into<String>>(&self, reason: T) {
        let mut provider = self.state.lock().expect("Poisoned lock for state");
        let reason = reason.into();
        trace!("advancing blocked: {}", reason);
        provider.blocked_advance.insert(reason);
    }

    /// Remove `reason` from the blocking set. If no reasons remain, check
    /// immediately whether an advance is possible.
    fn release_advance<T: Into<String>>(&self, reason: T) {
        let mut provider = self.state.lock().expect("Poisoned lock for state");
        let reason = reason.into();
        trace!("advancing released: {}", reason);
        provider.blocked_advance.remove(&reason);
        if provider.blocked_advance.is_empty() {
            provider.maybe_advance();
        }
    }

    /// Allow up to `dur` of advancement even while blocked. The quota is
    /// raised to `dur` if smaller and never lowered here.
    fn allow_one_advance(&self, dur: Duration) {
        let mut provider = self.state.lock().expect("Poisoned lock for state");
        provider.allowed_advance = Duration::max(provider.allowed_advance, dur);
        trace!(
            "** allow_one_advance fired; may advance up to {:?} **",
            provider.allowed_advance
        );
        provider.maybe_advance();
    }

    /// Current mock monotonic instant.
    fn now(&self) -> Instant {
        self.state
            .lock()
            .expect("Poisoned lock for state")
            .core
            .instant()
    }

    /// Current mock wall-clock time.
    fn wallclock(&self) -> SystemTime {
        self.state
            .lock()
            .expect("Poisoned lock for state")
            .core
            .wallclock()
    }
}
/// The coarse clock is read from the same shared mock time core as the
/// precise clocks.
impl CoarseTimeProvider for MockSleepProvider {
    fn now_coarse(&self) -> CoarseInstant {
        let guard = self.state.lock().expect("poisoned");
        guard.core.coarse().now_coarse()
    }
}
/// Two entries are equal when they share a deadline; the stored waker does
/// not take part in the comparison (keeping this consistent with `Ord`).
impl PartialEq for SleepEntry {
    fn eq(&self, other: &Self) -> bool {
        other.when == self.when
    }
}
impl Eq for SleepEntry {}
impl PartialOrd for SleepEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `Ord` so the partial and total orders can never disagree.
        Some(self.cmp(other))
    }
}
/// Reverse chronological order: an earlier deadline compares as *greater*,
/// which turns `BinaryHeap` (a max-heap) into a min-heap on `when`.
impl Ord for SleepEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        other.when.cmp(&self.when)
    }
}
impl Drop for Sleeping {
fn drop(&mut self) {
if let Some(provider) = Weak::upgrade(&self.provider) {
let mut provider = provider.lock().expect("Poisoned lock for provider");
if !self.inserted {
trace!("sleeper dropped, incrementing count");
provider.increment_poll_count();
self.inserted = true;
}
}
}
}
/// Ready once the mock monotonic clock has reached `self.when`.
impl Future for Sleeping {
    type Output = ();
    /// Poll the sleeper.
    ///
    /// Each `Sleeping` is counted exactly once in `sleepers_polled`, the first
    /// time it is polled (tracked by `inserted`). If that first poll happens
    /// before the deadline, the task's waker is queued so that `fire` can wake
    /// it later.
    ///
    /// If the provider's state has already been dropped, no waker is registered
    /// and this returns `Pending`, so nothing in this file will ever wake it.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if let Some(provider) = Weak::upgrade(&self.provider) {
            let mut provider = provider.lock().expect("Poisoned lock for provider");
            let now = provider.core.instant();
            if now >= self.when {
                // Deadline reached. Count this sleeper if it was never counted
                // (e.g. first polled only after the deadline passed).
                if !self.inserted {
                    provider.increment_poll_count();
                    self.inserted = true;
                }
                // Unless an advance is already pending, let the schedule
                // re-check whether the driver may advance.
                if !provider.should_advance {
                    provider.maybe_advance();
                }
                return Poll::Ready(());
            }
            // Not yet due. Only the first poll queues a waker. Later polls
            // before the deadline leave the originally stored waker in place.
            // NOTE(review): the `Future` contract expects the waker from the most
            // recent poll to be woken; confirm that keeping the first waker is
            // acceptable for this mock.
            if !self.inserted {
                let entry = SleepEntry {
                    when: self.when,
                    waker: cx.waker().clone(),
                };
                provider.push(entry);
                self.inserted = true;
                provider.increment_poll_count();
            }
        }
        Poll::Pending
    }
}
// Unit tests for the mock sleep provider (not run under Miri).
#[cfg(all(test, not(miri)))]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use tor_rtcompat::test_with_all_runtimes;

    /// `advance_noyield` moves both the monotonic and wall-clock times;
    /// `jump_to` then moves only the wall-clock.
    #[test]
    fn basics_of_time_travel() {
        let w1 = SystemTime::now();
        let sp = MockSleepProvider::new(w1);
        let i1 = sp.now();
        assert_eq!(sp.wallclock(), w1);
        // 4h13m
        let interval = Duration::new(4 * 3600 + 13 * 60, 0);
        sp.advance_noyield(interval);
        assert_eq!(sp.now(), i1 + interval);
        assert_eq!(sp.wallclock(), w1 + interval);
        sp.jump_to(w1 + interval * 3);
        // Monotonic time is unaffected by the wall-clock jump.
        assert_eq!(sp.now(), i1 + interval);
        assert_eq!(sp.wallclock(), w1 + interval * 3);
    }

    /// Three sleepers (1h, 3h, 5h) are released in deadline order by three
    /// successive 2-hour advances. The whole run must take far less than an
    /// hour of real time, i.e. no real sleeping happens.
    #[test]
    fn time_moves_on() {
        test_with_all_runtimes!(|_| async {
            use oneshot_fused_workaround as oneshot;
            use std::sync::atomic::AtomicBool;
            use std::sync::atomic::Ordering;
            let sp = MockSleepProvider::new(SystemTime::now());
            let one_hour = Duration::new(3600, 0);
            let (s1, r1) = oneshot::channel();
            let (s2, r2) = oneshot::channel();
            let (s3, r3) = oneshot::channel();
            let b1 = AtomicBool::new(false);
            let b2 = AtomicBool::new(false);
            let b3 = AtomicBool::new(false);
            let real_start = Instant::now();
            futures::join!(
                async {
                    sp.sleep(one_hour).await;
                    b1.store(true, Ordering::SeqCst);
                    s1.send(()).unwrap();
                },
                async {
                    sp.sleep(one_hour * 3).await;
                    b2.store(true, Ordering::SeqCst);
                    s2.send(()).unwrap();
                },
                async {
                    sp.sleep(one_hour * 5).await;
                    b3.store(true, Ordering::SeqCst);
                    s3.send(()).unwrap();
                },
                // Driver: after each 2h step, exactly the sleepers whose
                // deadlines have passed should have completed.
                async {
                    sp.advance(one_hour * 2).await;
                    r1.await.unwrap();
                    assert!(b1.load(Ordering::SeqCst));
                    assert!(!b2.load(Ordering::SeqCst));
                    assert!(!b3.load(Ordering::SeqCst));
                    sp.advance(one_hour * 2).await;
                    r2.await.unwrap();
                    assert!(b1.load(Ordering::SeqCst));
                    assert!(b2.load(Ordering::SeqCst));
                    assert!(!b3.load(Ordering::SeqCst));
                    sp.advance(one_hour * 2).await;
                    r3.await.unwrap();
                    assert!(b1.load(Ordering::SeqCst));
                    assert!(b2.load(Ordering::SeqCst));
                    assert!(b3.load(Ordering::SeqCst));
                    let real_end = Instant::now();
                    assert!(real_end - real_start < one_hour);
                }
            );
            std::io::Result::Ok(())
        });
    }
}