use std::{
fmt::Debug,
future::Future,
ops::Add,
prelude::v1::*,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use super::nanos::Nanos;
use crate::dst::time::Instant;
/// A point in time that a [`Clock`] can hand out and that measurements can be
/// taken against.
///
/// Implementors can be moved forward by a [`Nanos`] offset with `+`, are
/// totally ordered, `Copy`, and shareable across threads (`Send + Sync`).
pub trait Reference:
Sized + Add<Nanos, Output = Self> + PartialEq + Eq + Ord + Copy + Clone + Send + Sync + Debug
{
/// Returns the time elapsed between `earlier` and `self`.
///
/// The visible implementations (for `Duration` and `Instant`) return a zero
/// duration when `earlier` is not before `self`.
fn duration_since(&self, earlier: Self) -> Nanos;
/// Returns the reference point lying `duration` before `self`.
///
/// The visible implementations (for `Duration` and `Instant`) return `*self`
/// unchanged when the subtraction cannot be represented.
#[must_use]
fn saturating_sub(&self, duration: Nanos) -> Self;
}
/// A cloneable source of time readings that can also be slept on.
pub trait Clock: Clone {
/// The kind of reference point this clock produces.
type Instant: Reference;
/// Returns the clock's current reading.
fn now(&self) -> Self::Instant;
/// Returns a future that waits for `duration` according to this clock.
///
/// The future must be `Send` and may borrow from `self`; how the wait is
/// realized is left to the implementation.
fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + '_;
}
/// Lets a plain [`Duration`] act as a reference point, i.e. a time offset
/// measured from some implicit origin.
impl Reference for Duration {
    /// Returns `self - earlier`, or zero when `earlier` is larger than `self`.
    fn duration_since(&self, earlier: Self) -> Nanos {
        match self.checked_sub(earlier) {
            Some(elapsed) => elapsed.into(),
            // `earlier` lies after `self`: report no elapsed time.
            None => Self::ZERO.into(),
        }
    }

    /// Returns `self - duration`, or `*self` unchanged when that would
    /// underflow.
    fn saturating_sub(&self, duration: Nanos) -> Self {
        match self.checked_sub(duration.into()) {
            Some(moved_back) => moved_back,
            None => *self,
        }
    }
}
/// Adds a [`Nanos`] offset to a [`Duration`] by converting the offset to a
/// `Duration` first and then using ordinary `Duration` addition.
impl Add<Nanos> for Duration {
    type Output = Self;

    fn add(self, other: Nanos) -> Self {
        let offset: Self = other.into();
        self + offset
    }
}
/// A clock that only moves when it is told to.
///
/// The reading is a nanosecond counter that starts at zero and lives in a
/// shared atomic, so every clone of a clock observes and advances the same
/// time.
#[derive(Debug, Clone, Default)]
pub struct FakeRelativeClock {
    /// Nanoseconds elapsed on this clock; shared between clones.
    now: Arc<AtomicU64>,
}

impl FakeRelativeClock {
    /// Moves the clock forward by `by`.
    ///
    /// Concurrent callers never lose an update: the new value is installed
    /// with a compare-and-swap loop that retries on contention.
    ///
    /// # Panics
    ///
    /// Panics if `by` does not fit into `u64` nanoseconds, or if advancing
    /// would push the clock past `u64::MAX` nanoseconds. The latter used to
    /// wrap around silently in builds without overflow checks, which made the
    /// clock jump backwards.
    pub fn advance(&self, by: Duration) {
        let by: u64 = by
            .as_nanos()
            .try_into()
            .expect("Cannot represent durations greater than 584 years");

        let mut prev = self.now.load(Ordering::Acquire);
        loop {
            // Recomputed on every retry because `prev` changes when another
            // thread wins the race.
            let next = prev.checked_add(by).expect(
                "FakeRelativeClock overflow: elapsed time exceeds what u64 nanoseconds (~584 years) can hold",
            );
            match self
                .now
                .compare_exchange_weak(prev, next, Ordering::Release, Ordering::Relaxed)
            {
                Ok(_) => return,
                // Lost the race or failed spuriously: retry from the value
                // actually observed.
                Err(observed) => prev = observed,
            }
        }
    }
}
/// Two fake clocks compare equal when their current readings are equal,
/// whether or not they share the same underlying counter.
impl PartialEq for FakeRelativeClock {
    fn eq(&self, other: &Self) -> bool {
        let mine = self.now.load(Ordering::Relaxed);
        let theirs = other.now.load(Ordering::Relaxed);
        mine == theirs
    }
}
/// [`Clock`] implementation whose readings are the raw nanosecond counter.
impl Clock for FakeRelativeClock {
    type Instant = Nanos;

    /// Returns the current counter value as [`Nanos`].
    fn now(&self) -> Self::Instant {
        let elapsed_ns = self.now.load(Ordering::Relaxed);
        elapsed_ns.into()
    }

    /// "Sleeps" by advancing the clock by `duration`.
    ///
    /// The clock is advanced right away, when this method is called, not when
    /// the returned future is polled; the future itself is already complete.
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + '_ {
        self.advance(duration);
        let already_done = std::future::ready(());
        already_done
    }
}
/// A stateless clock; its [`Clock`] implementation takes readings from
/// `Instant::now()`.
#[derive(Clone, Debug, Default)]
pub struct MonotonicClock;
/// Moves an [`Instant`] forward by a [`Nanos`] offset by converting the offset
/// to a [`Duration`] and using the instant's own `Duration` addition.
impl Add<Nanos> for Instant {
    type Output = Self;

    fn add(self, other: Nanos) -> Self {
        let offset: Duration = other.into();
        self + offset
    }
}
/// Lets an [`Instant`] act as a reference point.
impl Reference for Instant {
    /// Returns the time from `earlier` to `self`, or zero when `earlier` is
    /// not strictly before `self`.
    fn duration_since(&self, earlier: Self) -> Nanos {
        // Guard first so the subtraction below is only evaluated when
        // `earlier` really precedes `self`.
        if earlier >= *self {
            return Nanos::from(Duration::ZERO);
        }
        (*self - earlier).into()
    }

    /// Returns the instant `duration` before `self`, or `*self` unchanged when
    /// that instant cannot be represented.
    fn saturating_sub(&self, duration: Nanos) -> Self {
        match self.checked_sub(duration.into()) {
            Some(moved_back) => moved_back,
            None => *self,
        }
    }
}
/// [`Clock`] implementation that reads time from [`Instant`] and sleeps on the
/// timer selected at compile time.
impl Clock for MonotonicClock {
type Instant = Instant;
/// Returns `Instant::now()`.
fn now(&self) -> Self::Instant {
Instant::now()
}
/// Waits for `duration` on the timer chosen by the build configuration.
///
/// Exactly one of the two statements below is compiled in, because their
/// `cfg` conditions are each other's negation.
async fn sleep(&self, duration: Duration) {
// Regular builds: sleep on tokio's timer.
#[cfg(not(all(feature = "simulation", madsim)))]
tokio::time::sleep(duration).await;
// Builds with the `simulation` feature and the `madsim` cfg: sleep on
// madsim's timer instead.
#[cfg(all(feature = "simulation", madsim))]
madsim::time::sleep(duration).await;
}
}
#[cfg(test)]
mod test {
    use std::{sync::Arc, thread, time::Duration};

    use rstest::rstest;

    use super::*;

    /// Ten threads hammer one shared fake clock; after each thread's own
    /// advance, the reading it sees must be later than the one it took before.
    #[rstest]
    fn fake_clock_parallel_advances() {
        let shared = Arc::new(FakeRelativeClock::default());
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let clock = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..1_000_000 {
                        let before = clock.now();
                        clock.advance(Duration::from_nanos(1));
                        assert!(clock.now() > before);
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Exercises `Duration + Nanos`: adding one nanosecond must grow the value.
    #[rstest]
    fn duration_addition_coverage() {
        let base = Duration::from_secs(1);
        let bumped = base + Nanos::from(1);
        assert!(bumped > base);
    }

    /// Under madsim, a 100 ms sleep must take at least 100 ms and less than
    /// 101 ms as measured by `Instant`.
    #[cfg(all(feature = "simulation", madsim))]
    #[madsim::test]
    async fn test_monotonic_clock_sleep_uses_virtual_time() {
        let nap = Duration::from_millis(100);
        let upper_bound = Duration::from_millis(101);

        let clock = MonotonicClock;
        let start = Instant::now();
        clock.sleep(nap).await;
        let elapsed = start.elapsed();

        assert!(elapsed >= nap);
        assert!(
            elapsed < upper_bound,
            "virtual sleep showed real-tokio jitter: {elapsed:?}"
        );
    }
}