use crate::core::{Cause, Clock, Ctx, Effect, EnvRef, Exit, FiberId, ScopeExit, ScopeHandle};
use crate::runtime::Runtime;
use futures::future::BoxFuture;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::sync::{Arc, Mutex as StdMutex};
use tokio::sync::Mutex as TokioMutex;
use tokio::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
/// A pending [`TestClock::sleep`] call, parked until the virtual clock
/// reaches `wake_time`.
struct Sleeper {
    /// Virtual instant at which this sleeper becomes due.
    wake_time: Instant,
    /// Shared with the sleeping future so [`TestClock::adjust`] can wake it.
    waker: Arc<tokio::sync::Notify>,
}

// Equality and ordering look only at `wake_time`, so the heap orders
// sleepers purely by deadline.
impl PartialEq for Sleeper {
    fn eq(&self, other: &Self) -> bool {
        self.wake_time == other.wake_time
    }
}

impl Eq for Sleeper {}

impl PartialOrd for Sleeper {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Sleeper {
    /// Reversed comparison: `BinaryHeap` is a max-heap, so inverting the order
    /// makes `peek`/`pop` yield the sleeper with the earliest `wake_time`.
    fn cmp(&self, other: &Self) -> Ordering {
        other.wake_time.cmp(&self.wake_time)
    }
}

/// A manually driven [`Clock`] for tests.
///
/// Virtual time only moves when [`TestClock::adjust`] is called. `sleep`
/// futures complete once virtual time reaches their deadline. Cloning yields
/// a handle to the same shared state.
#[derive(Clone)]
pub struct TestClock {
    state: Arc<StdMutex<TestClockState>>,
}

/// Mutable state shared by all clones of a [`TestClock`].
struct TestClockState {
    /// Current virtual time.
    now: Instant,
    /// Pending sleepers, earliest deadline on top.
    sleepers: BinaryHeap<Sleeper>,
}

impl Default for TestClock {
    fn default() -> Self {
        Self::new()
    }
}

impl TestClock {
    /// Creates a clock whose virtual time starts at `Instant::now()` with no
    /// pending sleepers.
    pub fn new() -> Self {
        Self {
            state: Arc::new(StdMutex::new(TestClockState {
                now: Instant::now(),
                sleepers: BinaryHeap::new(),
            })),
        }
    }

    /// Advances virtual time by `duration`.
    ///
    /// Wakes and removes every sleeper whose deadline is at or before the new
    /// time.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn adjust(&self, duration: Duration) {
        let mut state = self.state.lock().unwrap();
        state.now += duration;
        let now = state.now;
        while let Some(sleeper) = state.sleepers.peek() {
            if sleeper.wake_time > now {
                break;
            }
            let sleeper = state.sleepers.pop().unwrap();
            // `notify_one` stores a permit if the sleeping future has not yet
            // started waiting, so the wakeup cannot be lost.
            sleeper.waker.notify_one();
        }
    }
}

impl Clock for TestClock {
    /// Returns a future that completes once virtual time reaches
    /// "now + `duration`".
    ///
    /// "Now" is read when the future is first polled. If that deadline has
    /// already been reached (zero `duration`), the future completes
    /// immediately.
    fn sleep(&self, duration: Duration) -> BoxFuture<'static, ()> {
        let state = self.state.clone();
        Box::pin(async move {
            let notify = Arc::new(tokio::sync::Notify::new());
            {
                let mut guard = state.lock().unwrap();
                let wake_time = guard.now + duration;
                if wake_time <= guard.now {
                    return;
                }
                guard.sleepers.push(Sleeper {
                    wake_time,
                    waker: Arc::clone(&notify),
                });
            } // std mutex guard dropped here, before awaiting.
            notify.notified().await;
        })
    }

    /// Current virtual time.
    fn now(&self) -> Instant {
        self.state.lock().unwrap().now
    }
}
/// A pending [`TestClockImpl::sleep`] call, parked until the virtual clock
/// reaches `wake_time`.
struct SharedSleeper {
    /// Virtual instant at which this sleeper becomes due.
    wake_time: Instant,
    /// Shared with the sleeping future; signalled once the sleeper is due.
    notify: Arc<tokio::sync::Notify>,
}

// Equality and ordering look only at `wake_time`, so the heap orders
// sleepers purely by deadline.
impl PartialEq for SharedSleeper {
    fn eq(&self, other: &Self) -> bool {
        self.wake_time == other.wake_time
    }
}

impl Eq for SharedSleeper {}

impl PartialOrd for SharedSleeper {
    // Canonical form: delegate to `Ord` so the two orderings can never drift
    // apart (no lint suppression needed).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SharedSleeper {
    /// Reversed comparison: `BinaryHeap` is a max-heap, so inverting the order
    /// makes `peek`/`pop` yield the sleeper with the earliest `wake_time`.
    fn cmp(&self, other: &Self) -> Ordering {
        other.wake_time.cmp(&self.wake_time)
    }
}
/// A manually driven [`Clock`] used by [`TestRuntime`].
///
/// Cloning yields another handle to the same virtual time and sleeper queue.
#[derive(Clone)]
pub struct TestClockImpl {
    state: Arc<StdMutex<TestClockStateImpl>>,
}

/// Mutable state shared by every clone of a [`TestClockImpl`].
struct TestClockStateImpl {
    /// Current virtual time.
    now: Instant,
    /// Pending sleepers; `SharedSleeper`'s reversed `Ord` keeps the earliest
    /// deadline on top.
    sleepers: BinaryHeap<SharedSleeper>,
}

impl Default for TestClockImpl {
    /// Same as [`TestClockImpl::new`].
    fn default() -> Self {
        TestClockImpl::new()
    }
}
impl TestClockImpl {
    /// Creates a clock whose virtual time starts at `Instant::now()` with no
    /// pending sleepers.
    pub fn new() -> Self {
        Self {
            state: Arc::new(StdMutex::new(TestClockStateImpl {
                now: Instant::now(),
                sleepers: BinaryHeap::new(),
            })),
        }
    }

    /// Advances virtual time by `duration`.
    ///
    /// Wakes and removes every sleeper whose deadline is at or before the new
    /// time.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn adjust(&self, duration: Duration) {
        let mut state = self.state.lock().unwrap();
        state.now += duration;
        let now = state.now;
        while let Some(sleeper) = state.sleepers.peek() {
            if sleeper.wake_time > now {
                break;
            }
            let sleeper = state.sleepers.pop().unwrap();
            // Each sleeper's `Notify` is created for a single sleeping future.
            // `notify_one` stores a permit if that future has not started
            // waiting yet. `notify_waiters` would silently drop the wakeup in
            // that window and leave the sleep hanging forever.
            sleeper.notify.notify_one();
        }
    }
}
impl Clock for TestClockImpl {
    /// Returns a future that completes once virtual time reaches
    /// "now + `duration`".
    ///
    /// "Now" is read when the future is first polled. If that deadline has
    /// already been reached (zero `duration`), the future completes
    /// immediately instead of waiting for the next `adjust`.
    fn sleep(&self, duration: Duration) -> BoxFuture<'static, ()> {
        let state = Arc::clone(&self.state);
        Box::pin(async move {
            let notify = Arc::new(tokio::sync::Notify::new());
            // Create the `Notified` future *before* publishing the sleeper.
            // `notify_waiters` only reaches `Notified` futures that already
            // exist. Creating it after the lock is released would miss an
            // `adjust` that runs in between.
            let notified = notify.notified();
            {
                let mut guard = state.lock().unwrap();
                let wake_time = guard.now + duration;
                if wake_time <= guard.now {
                    return;
                }
                guard.sleepers.push(SharedSleeper {
                    wake_time,
                    notify: Arc::clone(&notify),
                });
            } // std mutex guard dropped here, before awaiting.
            notified.await;
        })
    }

    /// Current virtual time.
    fn now(&self) -> Instant {
        self.state.lock().unwrap().now
    }
}
/// Test harness pairing a [`Runtime`] with a [`TestClockImpl`].
///
/// Every effect run through this harness gets `clock` as its fiber clock.
pub struct TestRuntime {
    runtime: Runtime,
    /// Virtual clock shared with every effect run by this harness. Move it
    /// forward with `advance`/`advance_blocking`, or call `adjust` directly.
    pub clock: TestClockImpl,
}

impl Default for TestRuntime {
    /// Same as [`TestRuntime::new`].
    fn default() -> Self {
        TestRuntime::new()
    }
}
impl TestRuntime {
pub fn new() -> Self {
Self {
runtime: Runtime::new(),
clock: TestClockImpl::new(),
}
}
pub fn block_on<R, E, A>(&self, effect: Effect<R, E, A>, env: R) -> Exit<E, A>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + Clone + 'static,
A: Send + Sync + Clone + 'static,
{
let clock = Arc::new(self.clock.clone());
let ctx = Ctx {
token: CancellationToken::new(),
scope: ScopeHandle::new(),
fiber_id: FiberId(0),
locals: Arc::new(TokioMutex::new(HashMap::new())),
clock,
};
self.runtime.rt.block_on(async move {
let result = (effect.inner)(EnvRef { value: env }, ctx.clone()).await;
let scope_exit = match &result {
Exit::Success(_) => ScopeExit::Success,
Exit::Failure(Cause::Interrupt) => ScopeExit::Interrupt,
Exit::Failure(_) => ScopeExit::Failure,
};
ctx.scope.close(scope_exit).await;
result
})
}
pub fn spawn<R, E, A>(
&self,
effect: Effect<R, E, A>,
env: R,
) -> tokio::task::JoinHandle<Exit<E, A>>
where
R: Clone + Send + Sync + 'static,
E: Send + Sync + Clone + 'static,
A: Send + Sync + Clone + 'static,
{
let clock = Arc::new(self.clock.clone());
let ctx = Ctx {
token: CancellationToken::new(),
scope: ScopeHandle::new(),
fiber_id: FiberId(0),
locals: Arc::new(TokioMutex::new(HashMap::new())),
clock,
};
self.runtime.rt.spawn(async move {
let result = (effect.inner)(EnvRef { value: env }, ctx.clone()).await;
let scope_exit = match &result {
Exit::Success(_) => ScopeExit::Success,
Exit::Failure(Cause::Interrupt) => ScopeExit::Interrupt,
Exit::Failure(_) => ScopeExit::Failure,
};
ctx.scope.close(scope_exit).await;
result
})
}
pub async fn advance(&self, duration: Duration) {
self.clock.adjust(duration);
tokio::task::yield_now().await;
}
pub fn advance_blocking(&self, duration: Duration) {
self.runtime.rt.block_on(async {
self.advance(duration).await;
})
}
pub fn block_on_future<F>(&self, future: F) -> F::Output
where
F: std::future::Future,
{
self.runtime.rt.block_on(future)
}
}