use std::any::Any;
use std::cell::Cell;
use std::collections::VecDeque;
use std::fmt::{self, Debug, Display};
use std::future::Future;
use std::io::{self, Write as _};
use std::iter;
use std::panic::{catch_unwind, panic_any, AssertUnwindSafe};
use std::pin::{pin, Pin};
use std::sync::{Arc, Mutex, MutexGuard, Weak};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use futures::pin_mut;
use futures::task::{FutureObj, Spawn, SpawnError};
use futures::FutureExt as _;
use assert_matches::assert_matches;
use educe::Educe;
use itertools::Either::{self, *};
use itertools::{chain, izip};
use slotmap_careful::DenseSlotMap;
use std::backtrace::Backtrace;
use strum::EnumIter;
use tracing::{error, trace};
use oneshot_fused_workaround::{self as oneshot, Canceled};
use tor_error::error_report;
use tor_rtcompat::{Blocking, ToplevelBlockOn};
use Poll::*;
use TaskState::*;
/// A spawned task's future: boxed, type-erased, `'static`, with output `()`.
type TaskFuture = FutureObj<'static, ()>;
/// The `block_on` ("main") future, as a pinned mutable borrow.
///
/// It is not stored in `Data`; it is passed to the executor loop directly
/// (see `TaskFutureInfo::Main`).
type MainFuture<'m> = Pin<&'m mut dyn Future<Output = ()>>;
/// Mock executor for running async tasks (and "subthreads") in tests.
///
/// All clones share the same underlying state (an `Arc<Shared>`). A clone can
/// therefore be moved into a task and used there to spawn further tasks.
///
/// Tasks run from within `block_on`. At most one of the executor loop and its
/// subthreads is allowed to run at any time (see `Data::thread_to_run`).
#[derive(Clone, Default, Educe)]
#[educe(Debug)]
pub struct MockExecutor {
    /// The shared state (omitted from `Debug` output).
    #[educe(Debug(ignore))]
    shared: Arc<Shared>,
}
/// State shared by all clones of a `MockExecutor`; wakers refer to it weakly.
#[derive(Default)]
struct Shared {
    /// The mutable executor state. Acquire it via `Shared::lock`.
    data: Mutex<Data>,
    /// Notified whenever `Data::thread_to_run` changes, so that the executor and
    /// its subthreads can hand control to one another.
    thread_condvar: std::sync::Condvar,
}
/// Home of the task id key type.
// NOTE(review): presumably a separate module to keep the macro-generated key
// type's items out of this module's namespace — confirm.
mod task_id {
    slotmap_careful::new_key_type! {
        // Identifier for a task: normal tasks, the main task, and subthreads.
        pub(super) struct Ti;
    }
}
use task_id::Ti as TaskId;
/// The executor's mutable state, protected by `Shared::data`.
#[derive(Educe, derive_more::Debug)]
#[educe(Default)]
struct Data {
    /// All live tasks, including the main (`block_on`) task and subthreads.
    #[debug("{:?}", DebugTasks(self, || tasks.keys()))]
    tasks: DenseSlotMap<TaskId, Task>,
    /// Ids of tasks waiting to be polled, in the order they were woken.
    ///
    /// This may contain ids of tasks that have since finished.
    /// `executor_main_loop` skips those.
    #[debug("{:?}", DebugTasks(self, || awake.iter().cloned()))]
    awake: VecDeque<TaskId>,
    /// `Some` while a `progress_until_stalled` future is outstanding.
    progressing_until_stalled: Option<ProgressingUntilStalled>,
    /// Which end of `awake` the next task is taken from.
    scheduling: SchedulingPolicy,
    /// Which thread (the executor, or one particular subthread) may run now.
    #[educe(Default(expression = "ThreadDescriptor::Executor"))]
    thread_to_run: ThreadDescriptor,
}
/// Policy for choosing which awake task to poll next.
#[derive(Debug, Clone, Default, EnumIter)]
#[non_exhaustive]
pub enum SchedulingPolicy {
    /// Most recently woken task first (LIFO: taken from the back of the awake queue).
    #[default]
    Stack,
    /// Tasks in the order they were woken (FIFO: taken from the front of the awake queue).
    Queue,
}
/// One task known to the executor.
struct Task {
    /// Human-readable description used in debug output.
    /// For subthreads it is also the OS thread name.
    desc: String,
    /// Whether the task is awake or asleep.
    state: TaskState,
    /// The task's future, or what kind of task it is.
    ///
    /// `None` only while the executor has taken it out in order to poll it.
    fut: Option<TaskFutureInfo>,
}
/// What kind of task this is; for normal tasks, also its future.
#[derive(Educe)]
#[educe(Debug)]
enum TaskFutureInfo {
    /// An ordinary spawned task, which owns its future.
    Normal(#[educe(Debug(ignore))] TaskFuture),
    /// The `block_on` task. Its future is passed separately, as a `MainFuture`.
    Main,
    /// A subthread. "Polling" it means switching control to its OS thread.
    Subthread,
}
/// Whether a task is ready to be polled.
#[derive(Debug)]
enum TaskState {
    /// Queued on `Data::awake`, waiting to be polled.
    Awake,
    /// Not runnable until woken.
    ///
    /// Holds a `SleepLocation` for each clone of the task's waker made while it
    /// was asleep. `debug_dump` uses these.
    Asleep(Vec<SleepLocation>),
}
/// The data behind each `Waker` this executor hands out.
///
/// It holds only a weak reference to the executor's state, so outstanding
/// wakers don't keep it alive. Waking after the executor is gone does nothing.
struct ActualWaker {
    /// The executor's shared state.
    data: Weak<Shared>,
    /// The task to wake.
    id: TaskId,
}
/// State of an outstanding `progress_until_stalled` call.
#[derive(Debug)]
struct ProgressingUntilStalled {
    /// Set to `Ready` by `execute_to_completion` once the executor has stalled.
    finished: Poll<()>,
    /// Waker from the most recent poll of the `ProgressUntilStalledFuture`.
    waker: Option<Waker>,
}
/// Future returned by `MockExecutor::progress_until_stalled`.
///
/// Dropping it clears `Data::progressing_until_stalled`.
#[derive(Educe)]
#[educe(Debug)]
struct ProgressUntilStalledFuture {
    /// The executor's shared state (omitted from `Debug` output).
    #[educe(Debug(ignore))]
    shared: Arc<Shared>,
}
/// Identifies a thread from the executor's point of view.
#[derive(Copy, Clone, Eq, PartialEq, derive_more::Debug)]
enum ThreadDescriptor {
    /// A thread not managed by an executor. This is the initial value on every thread.
    #[debug("FOREIGN")]
    Foreign,
    /// A thread currently running the executor loop (inside `block_on`).
    #[debug("Exe")]
    Executor,
    /// A subthread spawned with `subthread_spawn`, identified by its task id.
    #[debug("{_0:?}")]
    Subthread(TaskId),
}
/// Marker result in `executor_main_loop`: the task to "poll" is a subthread.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
struct IsSubthread;
/// Marker argument to `Shared::subthread_yield`: also mark the subthread awake.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
struct SetAwake;
thread_local! {
    /// What the current thread is, from the executor's point of view.
    ///
    /// Used to detect re-entrancy and misuse, such as calling `spawn_blocking` or
    /// `subthread_block_on_future` from the wrong kind of thread.
    pub static THREAD_DESCRIPTOR: Cell<ThreadDescriptor> = const {
        Cell::new(ThreadDescriptor::Foreign)
    };
}
impl MockExecutor {
pub fn new() -> Self {
Self::default()
}
pub fn with_scheduling(scheduling: SchedulingPolicy) -> Self {
Data {
scheduling,
..Default::default()
}
.into()
}
}
impl From<Data> for MockExecutor {
    /// Wrap `data` in fresh shared state with a new thread-switch condvar.
    fn from(data: Data) -> MockExecutor {
        let data = Mutex::new(data);
        let thread_condvar = std::sync::Condvar::new();
        let shared = Arc::new(Shared {
            data,
            thread_condvar,
        });
        MockExecutor { shared }
    }
}
impl MockExecutor {
    /// Spawn `fut` as a new task, described by `desc` in debug output.
    ///
    /// Returns the new task's id, as an opaque `Debug` value.
    pub fn spawn_identified(
        &self,
        desc: impl Display,
        fut: impl Future<Output = ()> + Send + 'static,
    ) -> impl Debug + Clone + Send + 'static {
        self.spawn_internal(desc.to_string(), FutureObj::from(Box::new(fut)))
    }

    /// Spawn `fut` as a new task and return a future that yields its output.
    ///
    /// # Panics
    ///
    /// If the returned future has already been dropped when `fut` completes, the
    /// spawned task panics, because its output has nowhere to go. That panic
    /// propagates out of the executor.
    ///
    /// The returned future panics if the spawned task is dropped without
    /// completing.
    pub fn spawn_join<T: Debug + Send + 'static>(
        &self,
        desc: impl Display,
        fut: impl Future<Output = T> + Send + 'static,
    ) -> impl Future<Output = T> {
        let (tx, rx) = oneshot::channel();
        self.spawn_identified(desc, async move {
            let res = fut.await;
            // `send` fails only if the receiver (the future returned below) was
            // dropped. If `fut` itself panicked, we never reach this point.
            tx.send(res)
                .expect("Failed to send future's output, was the join future dropped?");
        });
        rx.map(|m| m.expect("Failed to receive future's output"))
    }

    /// Insert a normal task (initially awake) and return its id.
    fn spawn_internal(&self, desc: String, fut: TaskFuture) -> TaskId {
        let mut data = self.shared.lock();
        data.insert_task(desc, TaskFutureInfo::Normal(fut))
    }
}
impl Data {
    /// Add a new task and queue it to be polled. It starts out awake.
    ///
    /// Returns the new task's id.
    fn insert_task(&mut self, desc: String, fut: TaskFutureInfo) -> TaskId {
        let task = Task {
            desc,
            state: Awake,
            fut: Some(fut),
        };
        let id = self.tasks.insert(task);
        self.awake.push_back(id);
        trace!("MockExecutor spawned {:?}={:?}", id, self.tasks[id]);
        id
    }
}
impl Spawn for MockExecutor {
fn spawn_obj(&self, future: TaskFuture) -> Result<(), SpawnError> {
self.spawn_internal("spawn_obj".into(), future);
Ok(())
}
}
impl Blocking for MockExecutor {
    /// Handle for a `spawn_blocking` subthread. It resolves to the closure's return value.
    type ThreadHandle<T: Send + 'static> = Pin<Box<dyn Future<Output = T>>>;

    /// Run `f` on a new subthread (see `MockExecutor::subthread_spawn`).
    ///
    /// # Panics
    ///
    /// Panics unless called from a task or subthread being run by a `MockExecutor`.
    /// Awaiting the returned handle panics if `f` panicked.
    fn spawn_blocking<F, T>(&self, f: F) -> Self::ThreadHandle<T>
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        assert_matches!(
            THREAD_DESCRIPTOR.get(),
            ThreadDescriptor::Executor | ThreadDescriptor::Subthread(_),
            "MockExecutor::spawn_blocking only allowed from future or subthread, being run by this executor"
        );
        Box::pin(
            self.subthread_spawn("spawn_blocking", f)
                .map(|x| x.expect("Error in spawn_blocking subthread.")),
        )
    }

    /// Block the current subthread on `future` (see `subthread_block_on_future`).
    ///
    /// # Panics
    ///
    /// Panics unless called from a subthread of a `MockExecutor`.
    fn reenter_block_on<F>(&self, future: F) -> F::Output
    where
        F: Future,
        F::Output: Send + 'static,
    {
        self.subthread_block_on_future(future)
    }
}
impl ToplevelBlockOn for MockExecutor {
    /// Run `input_fut` as the "main" task, together with all spawned tasks.
    ///
    /// Tasks keep running until nothing is awake. Whenever the executor stalls,
    /// any outstanding `progress_until_stalled` future is resolved. Returns the
    /// output of `input_fut`.
    ///
    /// # Panics
    ///
    /// If the executor stalls before `input_fut` has completed, this writes a
    /// debug dump to stderr and then panics. Panics from tasks are re-raised.
    fn block_on<F>(&self, input_fut: F) -> F::Output
    where
        F: Future,
    {
        // Receives the main future's output once it completes.
        let mut value: Option<F::Output> = None;
        let mut input_fut = Box::pin(input_fut);
        // Adapt the input into a `()`-returning future, like every other task.
        // Its real output is stashed in `value`.
        let run_store_fut = {
            let value = &mut value;
            let input_fut = &mut input_fut;
            async {
                trace!("MockExecutor block_on future...");
                let t = input_fut.await;
                trace!("MockExecutor block_on future returned...");
                *value = Some(t);
                trace!("MockExecutor block_on future exiting.");
            }
        };
        {
            pin_mut!(run_store_fut);
            // Register the main task. Its future borrows locals of this function,
            // so it is handed to the executor loop directly instead of being
            // stored in `Data`.
            let main_id = self
                .shared
                .lock()
                .insert_task("main".into(), TaskFutureInfo::Main);
            trace!("MockExecutor {main_id:?} is task for block_on");
            self.execute_to_completion(run_store_fut);
        }
        // No value means the executor ran out of awake tasks before the main
        // future completed.
        // (The `let_and_return` allow keeps this binding, for readability.)
        #[allow(clippy::let_and_return)] let value = value.take().unwrap_or_else(|| {
            // Write straight to stderr, where the debug dump also goes, so this
            // message comes right before the dump.
            let _: io::Result<()> = writeln!(io::stderr(), "all futures blocked, crashing...");
            error!("all futures blocked, crashing...");
            {
                // Dump the state before dropping any futures.
                let mut data = self.shared.lock();
                data.debug_dump();
            }
            // Drop the main future after the lock has been released above, and
            // before panicking.
            // NOTE(review): presumably so that its drop does not run during
            // unwinding — confirm.
            drop(input_fut);
            panic!(
                r"
all futures blocked. waiting for the real world? or deadlocked (waiting for each other) ?
"
            );
        });
        value
    }
}
impl MockExecutor {
    /// Run tasks until nothing is awake and no `progress_until_stalled` is outstanding.
    ///
    /// Each time the executor stalls while a `progress_until_stalled` future
    /// exists, that future is marked finished and its waker is woken. Execution
    /// then resumes.
    fn execute_to_completion(&self, mut main_fut: MainFuture) {
        trace!("MockExecutor execute_to_completion...");
        loop {
            self.execute_until_first_stall(main_fut.as_mut());
            // We have stalled. Check whether a `progress_until_stalled` is waiting.
            let pus_waker = {
                let mut data = self.shared.lock();
                let pus = &mut data.progressing_until_stalled;
                trace!("MockExecutor execute_to_completion PUS={:?}", &pus);
                let Some(pus) = pus else {
                    // Nothing is awake and nobody is waiting for a stall: we are done.
                    break;
                };
                assert_eq!(
                    pus.finished, Pending,
                    "ProgressingUntilStalled finished twice?!"
                );
                pus.finished = Ready(());
                // Cloning one of our own wakers takes the data lock (see
                // `ActualWaker::clone`). So take the waker out, release the lock,
                // clone it, then relock and put the original back.
                let waker = pus
                    .waker
                    .take()
                    .expect("ProgressUntilStalledFuture not ever polled!");
                drop(data);
                let waker_copy = waker.clone();
                let mut data = self.shared.lock();
                let pus = &mut data.progressing_until_stalled;
                // Nothing may have installed a new waker while we were unlocked.
                if let Some(double) = pus
                    .as_mut()
                    .expect("progressing_until_stalled updated under our feet!")
                    .waker
                    .replace(waker)
                {
                    panic!("double progressing_until_stalled.waker! {double:?}");
                }
                waker_copy
            };
            // `data` has been released at the end of the block above. This matters
            // because waking one of our own wakers takes the lock (`ActualWaker::wake`).
            pus_waker.wake();
        }
        trace!("MockExecutor execute_to_completion done");
    }

    /// Run the executor loop on this thread until no task is awake.
    ///
    /// Sets `THREAD_DESCRIPTOR` to `Executor` while it runs and restores it to
    /// `Foreign` afterwards, even if a task panics. Any panic is then re-raised.
    ///
    /// # Panics
    ///
    /// Panics if this thread's descriptor is not `Foreign`, i.e. if the executor
    /// has been re-entered.
    fn execute_until_first_stall(&self, main_fut: MainFuture) {
        trace!("MockExecutor execute_until_first_stall ...");
        assert_eq!(
            THREAD_DESCRIPTOR.get(),
            ThreadDescriptor::Foreign,
            "MockExecutor executor re-entered"
        );
        THREAD_DESCRIPTOR.set(ThreadDescriptor::Executor);
        let r = catch_unwind(AssertUnwindSafe(|| self.executor_main_loop(main_fut)));
        THREAD_DESCRIPTOR.set(ThreadDescriptor::Foreign);
        match r {
            Ok(()) => trace!("MockExecutor execute_until_first_stall done."),
            Err(e) => {
                trace!("MockExecutor executor, or async task, panicked!");
                panic_any(e)
            }
        }
    }

    /// Repeatedly pick an awake task and poll it, until none is awake.
    ///
    /// A subthread is "polled" by switching control to its thread until it
    /// switches back.
    #[allow(clippy::cognitive_complexity)]
    fn executor_main_loop(&self, mut main_fut: MainFuture) {
        'outer: loop {
            // Pick the next awake task, mark it asleep, and take its future out.
            let (id, mut fut) = 'inner: loop {
                let mut data = self.shared.lock();
                let Some(id) = data.schedule() else {
                    // Nothing awake: we have stalled.
                    break 'outer;
                };
                let Some(task) = data.tasks.get_mut(id) else {
                    // Stale entry in `awake`: that task has already finished.
                    trace!("MockExecutor {id:?} vanished");
                    continue;
                };
                // Mark it asleep *before* polling. If it wakes itself while being
                // polled, it will then correctly end up awake again.
                task.state = Asleep(vec![]);
                let fut = task.fut.take().expect("future missing from task!");
                break 'inner (id, fut);
            };
            // Poll the task with the data lock released, because the task may call
            // back into the executor.
            trace!("MockExecutor {id:?} polling...");
            let waker = ActualWaker::make_waker(&self.shared, id);
            let mut cx = Context::from_waker(&waker);
            let r: Either<Poll<()>, IsSubthread> = match &mut fut {
                TaskFutureInfo::Normal(fut) => Left(fut.poll_unpin(&mut cx)),
                TaskFutureInfo::Main => Left(main_fut.as_mut().poll(&mut cx)),
                TaskFutureInfo::Subthread => Right(IsSubthread),
            };
            // A finished future is moved here, so that it is dropped only after
            // the lock below is released: its drop may re-enter the executor
            // (e.g. by spawning), which would deadlock.
            let _fut_drop_late;
            {
                let mut data = self.shared.lock();
                let task = data
                    .tasks
                    .get_mut(id)
                    .expect("task vanished while we were polling it");
                match r {
                    Left(Pending) => {
                        trace!("MockExecutor {id:?} -> Pending");
                        if task.fut.is_some() {
                            panic!("task reinserted while we polled it?!");
                        }
                        // Put the future back. Its state was already set above.
                        task.fut = Some(fut);
                    }
                    Left(Ready(())) => {
                        trace!("MockExecutor {id:?} -> Ready");
                        // Any remaining entry for it in `awake` becomes stale; that is allowed.
                        data.tasks.remove(id);
                        _fut_drop_late = fut;
                    }
                    Right(IsSubthread) => {
                        trace!("MockExecutor {id:?} -> Ready, waking Subthread");
                        // Put back the `Subthread` marker taken out above, then
                        // run the subthread until it hands control back to us.
                        // By then it may have finished and removed its own task.
                        task.fut = Some(fut);
                        self.shared.thread_context_switch(
                            data,
                            ThreadDescriptor::Executor,
                            ThreadDescriptor::Subthread(id),
                        );
                    }
                }
            }
        }
    }
}
impl Data {
fn schedule(&mut self) -> Option<TaskId> {
use SchedulingPolicy as SP;
match self.scheduling {
SP::Stack => self.awake.pop_back(),
SP::Queue => self.awake.pop_front(),
}
}
}
impl ActualWaker {
fn upgrade_data(&self) -> Option<Arc<Shared>> {
self.data.upgrade()
}
fn wake(&self) {
let Some(data) = self.upgrade_data() else {
return;
};
let mut data = data.lock();
let data = &mut *data;
trace!("MockExecutor {:?} wake", &self.id);
let Some(task) = data.tasks.get_mut(self.id) else {
return;
};
task.set_awake(self.id, &mut data.awake);
}
fn make_waker(shared: &Arc<Shared>, id: TaskId) -> Waker {
ActualWaker {
data: Arc::downgrade(shared),
id,
}
.new_waker()
}
}
impl MockExecutor {
pub fn progress_until_stalled(&self) -> impl Future<Output = ()> {
let mut data = self.shared.lock();
assert!(
data.progressing_until_stalled.is_none(),
"progress_until_stalled called more than once"
);
trace!("MockExecutor progress_until_stalled...");
data.progressing_until_stalled = Some(ProgressingUntilStalled {
finished: Pending,
waker: None,
});
ProgressUntilStalledFuture {
shared: self.shared.clone(),
}
}
}
impl Future for ProgressUntilStalledFuture {
    type Output = ();

    /// Record the current waker and report whether the executor has stalled yet.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
        // Clone before locking: cloning one of this executor's own wakers takes
        // the data lock (`ActualWaker::clone`).
        let our_waker = cx.waker().clone();
        let mut guard = self.shared.lock();
        let state = guard.progressing_until_stalled.as_mut();
        trace!("MockExecutor progress_until_stalled polling... {:?}", &state);
        let state = state.expect("ProgressingUntilStalled missing");
        state.waker = Some(our_waker);
        state.finished
    }
}
impl Drop for ProgressUntilStalledFuture {
    /// Forget the outstanding `progress_until_stalled`, so that another one may be started.
    fn drop(&mut self) {
        let mut data = self.shared.lock();
        data.progressing_until_stalled = None;
    }
}
impl MockExecutor {
    /// Spawn a "subthread": a real OS thread, named `desc`, that runs `call`.
    ///
    /// The thread is scheduled cooperatively by this executor. It only runs while
    /// the executor has handed control to it (see `Data::thread_to_run`). Inside
    /// `call`, use `subthread_block_on_future` to await futures.
    ///
    /// Returns a future that yields `call`'s return value, or `Err` carrying the
    /// panic payload if `call` panicked.
    pub fn subthread_spawn<T: Send + 'static>(
        &self,
        desc: impl Display,
        call: impl FnOnce() -> T + Send + 'static,
    ) -> impl Future<Output = Result<T, Box<dyn Any + Send>>> + Unpin + Send + Sync + 'static {
        let desc = desc.to_string();
        let (output_tx, output_rx) = oneshot::channel();
        {
            let mut data = self.shared.lock();
            // The new task starts out awake. When the executor schedules it, it
            // switches control to the new thread, which waits for exactly that
            // in `subthread_entrypoint`.
            let id = data.insert_task(desc.clone(), TaskFutureInfo::Subthread);
            // The JoinHandle is dropped (the thread is detached). Its result
            // comes back through `output_tx` instead.
            let _: std::thread::JoinHandle<()> = std::thread::Builder::new()
                .name(desc)
                .spawn({
                    let shared = self.shared.clone();
                    move || shared.subthread_entrypoint(id, call, output_tx)
                })
                .expect("spawn failed");
        }
        output_rx.map(|r| {
            r.unwrap_or_else(|_: Canceled| panic!("Subthread cancelled but should be impossible!"))
        })
    }

    /// Block the calling subthread until `fut` completes, and return its output.
    ///
    /// `fut` is polled on this thread. Between polls, control goes back to the
    /// executor until this subthread's waker is woken.
    ///
    /// # Panics
    ///
    /// Panics unless called from a subthread of a `MockExecutor` (i.e. from within
    /// a closure passed to `subthread_spawn`).
    #[allow(clippy::cognitive_complexity)] pub fn subthread_block_on_future<T: Send + 'static>(&self, fut: impl Future<Output = T>) -> T {
        let id = match THREAD_DESCRIPTOR.get() {
            ThreadDescriptor::Subthread(id) => id,
            ThreadDescriptor::Executor => {
                panic!("subthread_block_on_future called from MockExecutor thread (async task?)")
            }
            ThreadDescriptor::Foreign => panic!(
                "subthread_block_on_future called on foreign thread (not spawned with spawn_subthread)"
            ),
        };
        trace!("MockExecutor thread {id:?}, subthread_block_on_future...");
        let mut fut = pin!(fut);
        // Hand control back to the executor. With `Some(SetAwake)` we stay runnable;
        // with `None` we resume only after our waker has been woken.
        let yield_ = |set_awake| self.shared.subthread_yield(id, set_awake);
        // Let the executor schedule us again before the first poll.
        yield_(Some(SetAwake));
        let ret = loop {
            trace!("MockExecutor thread {id:?}, s.t._block_on_future polling...");
            let waker = ActualWaker::make_waker(&self.shared, id);
            let mut cx = Context::from_waker(&waker);
            let r: Poll<T> = fut.as_mut().poll(&mut cx);
            if let Ready(r) = r {
                trace!("MockExecutor thread {id:?}, s.t._block_on_future poll -> Ready");
                break r;
            }
            trace!("MockExecutor thread {id:?}, s.t._block_on_future poll -> Pending");
            // Sleep until woken.
            yield_(None);
        };
        // Again, let the executor schedule us before we continue with the caller's code.
        yield_(Some(SetAwake));
        trace!("MockExecutor thread {id:?}, subthread_block_on_future complete.");
        ret
    }
}
impl Shared {
    /// Body of a subthread's OS thread.
    ///
    /// Waits until the executor hands control to this subthread, then runs `call`
    /// (catching any panic) and sends the result on `output_tx`. Finally it removes
    /// the subthread's task and hands control back to the executor for good.
    fn subthread_entrypoint<T: Send + 'static>(
        self: Arc<Self>,
        id: TaskId,
        call: impl FnOnce() -> T + Send + 'static,
        output_tx: oneshot::Sender<Result<T, Box<dyn Any + Send>>>,
    ) {
        THREAD_DESCRIPTOR.set(ThreadDescriptor::Subthread(id));
        trace!("MockExecutor thread {id:?}, entrypoint");
        {
            // Don't run any user code until the executor switches to us.
            let data = self.lock();
            self.thread_context_switch_waitfor_instruction_to_run(
                data,
                ThreadDescriptor::Subthread(id),
            );
        }
        trace!("MockExecutor thread {id:?}, entering user code");
        let ret = catch_unwind(AssertUnwindSafe(call));
        trace!("MockExecutor thread {id:?}, completed user code");
        // `send` fails only if the receiver has been dropped; then nobody wants
        // the result, so the error is ignored.
        output_tx.send(ret).unwrap_or_else(
            #[allow(clippy::unnecessary_lazy_evaluations)]
            |_| {},
        );
        {
            // We are finished: remove our task and hand control back to the
            // executor without waiting to be switched back to.
            let mut data = self.lock();
            let _: Task = data.tasks.remove(id).expect("Subthread task vanished!");
            self.thread_context_switch_send_instruction_to_run(
                &mut data,
                ThreadDescriptor::Subthread(id),
                ThreadDescriptor::Executor,
            );
        }
    }

    /// Hand control from subthread `us` back to the executor, then wait until
    /// control is handed back.
    ///
    /// With `Some(SetAwake)`, first mark the subthread's task awake, so that the
    /// executor switches back when it next schedules it. With `None`, the task
    /// keeps whatever state it has, i.e. it waits to be woken.
    fn subthread_yield(&self, us: TaskId, set_awake: Option<SetAwake>) {
        let mut data = self.lock();
        {
            let data = &mut *data;
            let task = data.tasks.get_mut(us).expect("Subthread task vanished!");
            match &task.fut {
                Some(TaskFutureInfo::Subthread) => {}
                other => panic!("subthread_block_on_future but TFI {other:?}"),
            };
            if let Some(SetAwake) = set_awake {
                task.set_awake(us, &mut data.awake);
            }
        }
        self.thread_context_switch(
            data,
            ThreadDescriptor::Subthread(us),
            ThreadDescriptor::Executor,
        );
    }

    /// Hand control from thread `us` to thread `them`, then block until control
    /// comes back to `us`.
    ///
    /// Consumes the lock guard. The lock is released while waiting and is not
    /// held when this returns.
    fn thread_context_switch(
        &self,
        mut data: MutexGuard<Data>,
        us: ThreadDescriptor,
        them: ThreadDescriptor,
    ) {
        trace!("MockExecutor thread {us:?}, switching to {them:?}");
        self.thread_context_switch_send_instruction_to_run(&mut data, us, them);
        self.thread_context_switch_waitfor_instruction_to_run(data, us);
    }

    /// Set `thread_to_run` to `them` and notify all waiting threads.
    ///
    /// # Panics
    ///
    /// Panics if `thread_to_run` was not `us`, i.e. if the caller was not the
    /// thread allowed to run.
    fn thread_context_switch_send_instruction_to_run(
        &self,
        data: &mut MutexGuard<Data>,
        us: ThreadDescriptor,
        them: ThreadDescriptor,
    ) {
        assert_eq!(data.thread_to_run, us);
        data.thread_to_run = them;
        self.thread_condvar.notify_all();
    }

    /// Block on the condvar (releasing the lock) until `thread_to_run` is `us`.
    ///
    /// The re-acquired guard is dropped immediately, so the lock is not held on return.
    fn thread_context_switch_waitfor_instruction_to_run(
        &self,
        data: MutexGuard<Data>,
        us: ThreadDescriptor,
    ) {
        // Binding to `_` deliberately drops the returned guard at once.
        #[allow(let_underscore_lock)]
        let _: MutexGuard<_> = self
            .thread_condvar
            .wait_while(data, |data| {
                let live = data.thread_to_run;
                let resume = live == us;
                if resume {
                    trace!("MockExecutor thread {us:?}, resuming");
                } else {
                    trace!("MockExecutor thread {us:?}, waiting for {live:?}");
                }
                !resume
            })
            .expect("data lock poisoned");
    }
}
/// Compile-time check that the types below are `Sync + Send + 'static`.
/// The impls fail to compile otherwise.
// The trait is never used as a bound, hence the `dead_code` allow.
#[allow(dead_code)] trait EnsureSyncSend: Sync + Send + 'static {}
impl EnsureSyncSend for ActualWaker {}
impl EnsureSyncSend for MockExecutor {}
impl MockExecutor {
    /// Number of tasks that currently exist, awake or asleep. This includes the
    /// main task and subthreads.
    pub fn n_tasks(&self) -> usize {
        let data = self.shared.lock();
        data.tasks.len()
    }
}
impl Shared {
    /// Lock the executor's mutable state.
    ///
    /// std `Mutex` is not re-entrant, so callers must not already hold this lock.
    /// In particular, don't clone or wake one of this executor's `Waker`s while
    /// holding it: `ActualWaker::clone` and `ActualWaker::wake` both lock it.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. some thread panicked while holding it.
    fn lock(&self) -> MutexGuard<'_, Data> {
        self.data.lock().expect("data lock poisoned")
    }
}
impl Task {
    /// Mark this task (whose id is `id`) awake.
    ///
    /// If it was asleep, it is also pushed onto `data_awake`. If it is already
    /// awake, nothing happens, so a live task is queued at most once per sleep.
    fn set_awake(&mut self, id: TaskId, data_awake: &mut VecDeque<TaskId>) {
        if matches!(self.state, Asleep(_)) {
            self.state = Awake;
            data_awake.push_back(id);
        }
    }
}
impl ActualWaker {
    /// Turn `self` into a `Waker` backed by `RAW_WAKER_VTABLE`.
    fn new_waker(self) -> Waker {
        // SAFETY: `raw_new` returns a `RawWaker` whose data pointer is a leaked
        // `Box<ActualWaker>`, and every function in `RAW_WAKER_VTABLE` treats it
        // as exactly that. `ActualWaker` is `Send + Sync` (checked by
        // `EnsureSyncSend`), as wakers may be used from any thread.
        unsafe { Waker::from_raw(self.raw_new()) }
    }

    /// Move `self` onto the heap, leak it, and wrap the pointer in a `RawWaker`.
    ///
    /// Ownership of the allocation passes to the `RawWaker`. `raw_drop` frees it.
    fn raw_new(self) -> RawWaker {
        let self_: Box<ActualWaker> = self.into();
        let self_: *mut ActualWaker = Box::into_raw(self_);
        let self_: *const () = self_ as _;
        RawWaker::new(self_, &RAW_WAKER_VTABLE)
    }

    /// Vtable `clone`: make an independent `RawWaker` for the same task.
    ///
    /// # Safety
    ///
    /// `self_` must come from `raw_new` and must not yet have been passed to `raw_drop`.
    unsafe fn raw_clone(self_: *const ()) -> RawWaker {
        let self_: *const ActualWaker = self_ as _;
        // SAFETY: per the contract above, `self_` points to a live `ActualWaker`,
        // so it is non-null and valid to borrow.
        let self_: &ActualWaker = self_.as_ref().unwrap_unchecked();
        // `ActualWaker::clone` may record a sleep location, and it takes the data lock.
        let copy: ActualWaker = self_.clone();
        copy.raw_new()
    }

    /// Vtable `wake`: wake the task, then free this waker (it is consumed).
    ///
    /// # Safety
    ///
    /// As for `raw_clone`. Afterwards `self_` must not be used again.
    unsafe fn raw_wake(self_: *const ()) {
        Self::raw_wake_by_ref(self_);
        Self::raw_drop(self_);
    }

    /// Vtable `wake_by_ref`: wake the task without consuming the waker.
    ///
    /// # Safety
    ///
    /// As for `raw_clone`.
    unsafe fn raw_wake_by_ref(self_: *const ()) {
        let self_: *const ActualWaker = self_ as _;
        // SAFETY: `self_` points to a live `ActualWaker` (see the contract above).
        let self_: &ActualWaker = self_.as_ref().unwrap_unchecked();
        self_.wake();
    }

    /// Vtable `drop`: free the `Box<ActualWaker>` leaked by `raw_new`.
    ///
    /// # Safety
    ///
    /// As for `raw_clone`. Afterwards `self_` must not be used again.
    unsafe fn raw_drop(self_: *const ()) {
        let self_: *mut ActualWaker = self_ as _;
        // SAFETY: `self_` came from `Box::into_raw` in `raw_new` and is freed exactly once.
        let self_: Box<ActualWaker> = Box::from_raw(self_);
        drop(self_);
    }
}
/// Vtable for wakers made by `ActualWaker::new_waker`.
///
/// Each function treats the data pointer as a `Box<ActualWaker>` leaked by
/// `ActualWaker::raw_new`.
static RAW_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    ActualWaker::raw_clone,
    ActualWaker::raw_wake,
    ActualWaker::raw_wake_by_ref,
    ActualWaker::raw_drop,
);
/// Where an asleep task's waker was cloned: a captured backtrace.
///
/// Under Miri a placeholder type is used instead (see `miri_sleep_location`).
#[cfg(not(miri))]
type SleepLocation = Backtrace;
impl Data {
    /// Write a per-task summary of each task's state.
    ///
    /// An awake task gets a single `awake` line. An asleep task gets one line per
    /// recorded waker-clone location. If its waker was never cloned, it gets a
    /// warning line instead, since nothing holds a waker for it.
    fn dump_backtraces(&self, f: &mut fmt::Formatter) -> fmt::Result {
        for (id, task) in self.tasks.iter() {
            let prefix = |f: &mut fmt::Formatter| write!(f, "{id:?}={task:?}: ");
            let locs = match &task.state {
                Awake => {
                    prefix(f)?;
                    writeln!(f, "awake")?;
                    continue;
                }
                Asleep(locs) => locs,
            };
            let n = locs.len();
            if n == 0 {
                prefix(f)?;
                writeln!(f, "asleep, no backtraces, Waker never cloned, stuck!")?;
            }
            for (i, loc) in locs.iter().enumerate() {
                prefix(f)?;
                writeln!(f, "asleep, backtrace {i}/{n}:\n{loc}")?;
            }
        }
        writeln!(
            f,
            "\nNote: there might be spurious traces, see docs for MockExecutor::debug_dump\n"
        )
    }
}
impl Clone for ActualWaker {
fn clone(&self) -> Self {
let id = self.id;
if let Some(data) = self.upgrade_data() {
let mut data = data.lock();
if let Some(task) = data.tasks.get_mut(self.id) {
match &mut task.state {
Awake => trace!("MockExecutor cloned waker for awake task {id:?}"),
Asleep(locs) => locs.push(SleepLocation::force_capture()),
}
} else {
trace!("MockExecutor cloned waker for dead task {id:?}");
}
}
ActualWaker {
data: self.data.clone(),
id,
}
}
}
/// A dump of a `MockExecutor`'s state, printed by its `Debug` impl.
///
/// When obtained from `MockExecutor::as_debug_dump`, it holds the executor's
/// data lock until it is dropped.
pub struct DebugDump<'a>(Either<&'a Data, MutexGuard<'a, Data>>);
impl MockExecutor {
    /// Write a dump of the executor's state, including where sleeping tasks'
    /// wakers were cloned, to stderr.
    pub fn debug_dump(&self) {
        self.as_debug_dump().to_stderr();
    }

    /// Return a value whose `Debug` output is a dump of the executor's state.
    ///
    /// The returned value holds the executor's data lock until it is dropped.
    /// Do not keep it while using this executor, or its tasks and wakers, from
    /// the same thread.
    pub fn as_debug_dump(&self) -> DebugDump<'_> {
        let data = self.shared.lock();
        DebugDump(Right(data))
    }
}
impl Data {
    /// Write a dump of this state to stderr. The caller already holds the lock.
    fn debug_dump(&mut self) {
        let dump = DebugDump(Left(&*self));
        dump.to_stderr();
    }
}
impl DebugDump<'_> {
    /// Write this dump to stderr. A write error is reported through the logs, not returned.
    // Consumes `self` so that any lock it holds is released afterwards; hence the
    // `wrong_self_convention` allow for a `to_*` method taking `self`.
    #[allow(clippy::wrong_self_convention)]
    fn to_stderr(self) {
        let mut stderr = io::stderr().lock();
        if let Err(e) = write!(stderr, "{:?}", self) {
            error_report!(e, "failed to write debug dump to stderr");
        }
    }
}
impl Debug for DebugDump<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let self_: &Data = &self.0;
writeln!(f, "MockExecutor state:\n{self_:#?}")?;
writeln!(f, "MockExecutor task dump:")?;
self_.dump_backtraces(f)?;
Ok(())
}
}
impl Debug for Task {
    /// Compact form: `"desc"=<kind><state>`.
    ///
    /// `<kind>` is `P` (future taken out for polling), `f` (normal), `m` (main)
    /// or `T` (subthread). `<state>` is `W` (awake) or `s<N>` (asleep, with N
    /// recorded waker clones).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let kind = match &self.fut {
            None => 'P',
            Some(TaskFutureInfo::Normal(_)) => 'f',
            Some(TaskFutureInfo::Main) => 'm',
            Some(TaskFutureInfo::Subthread) => 'T',
        };
        write!(f, "{:?}={}", self.desc, kind)?;
        match &self.state {
            Awake => write!(f, "W"),
            Asleep(locs) => write!(f, "s{}", locs.len()),
        }
    }
}
#[allow(dead_code)]
struct DebugTasks<'d, F>(&'d Data, F);
impl<F, I> Debug for DebugTasks<'_, F>
where
F: Fn() -> I,
I: Iterator<Item = TaskId>,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let DebugTasks(data, ids) = self;
for (id, delim) in izip!(ids(), chain!(iter::once(""), iter::repeat(" ")),) {
write!(f, "{delim}{id:?}")?;
match data.tasks.get(id) {
None => write!(f, "-")?,
Some(task) => write!(f, "={task:?}")?,
}
}
Ok(())
}
}
/// Placeholder for `SleepLocation` when running under Miri.
// NOTE(review): presumably because backtrace capture is unsupported or too slow
// under Miri — confirm.
#[cfg(miri)]
mod miri_sleep_location {
    /// Records nothing. Displays as `<SleepLocation>`.
    #[derive(Debug, derive_more::Display)]
    #[display("<SleepLocation>")]
    pub(super) struct SleepLocation {}
    impl SleepLocation {
        /// Same name as `Backtrace::force_capture`, so callers need no `cfg`.
        pub(super) fn force_capture() -> Self {
            SleepLocation {}
        }
    }
}
#[cfg(miri)]
use miri_sleep_location::SleepLocation;
// Tests for the mock executor: block_on, stalls and progress_until_stalled,
// spawn_blocking, drop re-entrancy, and subthreads.
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use strum::IntoEnumIterator;
    use tracing::info;
    #[cfg(not(miri))] use tracing_test::traced_test;

    // One executor per scheduling policy, each announced on stderr.
    fn various_mock_executors() -> impl Iterator<Item = MockExecutor> {
        SchedulingPolicy::iter().map(|scheduling| {
            eprintln!("===== MockExecutor::with_scheduling({scheduling:?}) =====");
            MockExecutor::with_scheduling(scheduling)
        })
    }

    // `block_on` returns the main future's output.
    #[cfg_attr(not(miri), traced_test)]
    #[test]
    fn simple() {
        let runtime = MockExecutor::default();
        let val = runtime.block_on(async { 42 });
        assert_eq!(val, 42);
    }

    // A chain of tasks forwards values over channels. This checks that
    // `progress_until_stalled` runs everything runnable, and that `block_on`
    // still returns even though the forwarding tasks loop forever.
    #[cfg_attr(not(miri), traced_test)]
    #[test]
    fn stall() {
        let runtime = MockExecutor::default();
        runtime.block_on({
            let runtime = runtime.clone();
            async move {
                const N: usize = 3;
                let (mut txs, mut rxs): (Vec<_>, Vec<_>) =
                    (0..N).map(|_| mpsc::channel::<usize>(5)).unzip();
                let mut rx_n = rxs.pop().unwrap();
                // Task i receives v on channel i and sends v+1 on channel v+1.
                for (i, mut rx) in rxs.into_iter().enumerate() {
                    runtime.spawn_identified(i, {
                        let mut txs = txs.clone();
                        async move {
                            loop {
                                eprintln!("task {i} rx...");
                                let v = rx.next().await.unwrap();
                                let nv = v + 1;
                                eprintln!("task {i} rx {v}, tx {nv}");
                                let v = nv;
                                txs[v].send(v).await.unwrap();
                            }
                        }
                    });
                }
                dbg!();
                let _: mpsc::TryRecvError = rx_n.try_next().unwrap_err();
                dbg!();
                runtime.progress_until_stalled().await;
                dbg!();
                // Nothing has been sent yet, so nothing can have arrived.
                let _: mpsc::TryRecvError = rx_n.try_next().unwrap_err();
                dbg!();
                txs[0].send(0).await.unwrap();
                dbg!();
                runtime.progress_until_stalled().await;
                dbg!();
                // 0 has been forwarded along the whole chain.
                let r = rx_n.next().await;
                assert_eq!(r, Some(N - 1));
                dbg!();
                let _: mpsc::TryRecvError = rx_n.try_next().unwrap_err();
                // Send 0 into every channel from another task.
                runtime.spawn_identified("tx", {
                    let txs = txs.clone();
                    async {
                        eprintln!("sending task...");
                        for (i, mut tx) in txs.into_iter().enumerate() {
                            eprintln!("sending 0 to {i}...");
                            tx.send(0).await.unwrap();
                        }
                        eprintln!("sending task done");
                    }
                });
                runtime.debug_dump();
                for i in 0..txs.len() {
                    eprintln!("main {i} wait stall...");
                    runtime.progress_until_stalled().await;
                    eprintln!("main {i} rx wait...");
                    let r = rx_n.next().await;
                    eprintln!("main {i} rx = {r:?}");
                    assert!(r == Some(0) || r == Some(N - 1));
                }
                eprintln!("finishing...");
                runtime.progress_until_stalled().await;
                eprintln!("finished.");
            }
        });
    }

    // `spawn_blocking` results can be awaited, in any order.
    #[cfg_attr(not(miri), traced_test)]
    #[test]
    fn spawn_blocking() {
        let runtime = MockExecutor::default();
        runtime.block_on({
            let runtime = runtime.clone();
            async move {
                let thr_1 = runtime.spawn_blocking(|| 42);
                let thr_2 = runtime.spawn_blocking(|| 99);
                assert_eq!(thr_2.await, 99);
                assert_eq!(thr_1.await, 42);
            }
        });
    }

    // A task whose future spawns another task when dropped. The executor must
    // drop completed futures without holding its data lock.
    #[cfg_attr(not(miri), traced_test)]
    #[test]
    fn drop_reentrancy() {
        struct ReentersOnDrop {
            runtime: MockExecutor,
        }
        impl Future for ReentersOnDrop {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<()> {
                Poll::Ready(())
            }
        }
        impl Drop for ReentersOnDrop {
            fn drop(&mut self) {
                self.runtime
                    .spawn_identified("dummy", futures::future::ready(()));
            }
        }
        for runtime in various_mock_executors() {
            runtime.block_on(async {
                runtime.spawn_identified("trapper", {
                    let runtime = runtime.clone();
                    ReentersOnDrop { runtime }
                });
            });
        }
    }

    // A subthread blocks on a oneshot that the main task fulfils, then returns a value.
    #[cfg_attr(not(miri), traced_test)]
    #[test]
    fn subthread_oneshot() {
        for runtime in various_mock_executors() {
            runtime.block_on(async {
                let (tx, rx) = oneshot::channel();
                info!("spawning subthread");
                let thr = runtime.subthread_spawn("thr1", {
                    let runtime = runtime.clone();
                    move || {
                        info!("subthread_block_on_future...");
                        let i = runtime.subthread_block_on_future(rx).unwrap();
                        info!("subthread_block_on_future => {i}");
                        i + 1
                    }
                });
                info!("main task sending");
                tx.send(12).unwrap();
                info!("main task sent");
                let r = thr.await.unwrap();
                info!("main task thr => {r}");
                assert_eq!(r, 13);
            });
        }
    }

    // The main task and a subthread exchange several messages over bounded
    // channels. The subthread exits when its input channel closes.
    #[cfg_attr(not(miri), traced_test)]
    #[test]
    #[allow(clippy::cognitive_complexity)] fn subthread_pingpong() {
        for runtime in various_mock_executors() {
            runtime.block_on(async {
                let (mut i_tx, mut i_rx) = mpsc::channel(1);
                let (mut o_tx, mut o_rx) = mpsc::channel(1);
                info!("spawning subthread");
                let thr = runtime.subthread_spawn("thr", {
                    let runtime = runtime.clone();
                    move || {
                        while let Some(i) = {
                            info!("thread receiving ...");
                            runtime.subthread_block_on_future(i_rx.next())
                        } {
                            let o = i + 12;
                            info!("thread received {i}, sending {o}");
                            runtime.subthread_block_on_future(o_tx.send(o)).unwrap();
                            info!("thread sent {o}");
                        }
                        info!("thread exiting");
                        42
                    }
                });
                for i in 0..2 {
                    info!("main task sending {i}");
                    i_tx.send(i).await.unwrap();
                    info!("main task sent {i}");
                    let o = o_rx.next().await.unwrap();
                    info!("main task recv => {o}");
                    assert_eq!(o, i + 12);
                }
                info!("main task dropping sender");
                drop(i_tx);
                info!("main task awaiting thread");
                let r = thr.await.unwrap();
                info!("main task complete");
                assert_eq!(r, 42);
            });
        }
    }
}