use std::{
future::Future,
pin::Pin,
sync::{
Arc, LazyLock, Mutex, PoisonError,
atomic::{AtomicBool, AtomicU64, Ordering},
},
task::{Context, Poll, Wake, Waker},
};
use scoped_tls::scoped_thread_local;
use switchy_random::{rand::rand::seq::IteratorRandom, rng};
pub use crate::Builder;
use crate::{Error, GenericRuntime, task};
use std::cell::RefCell;
use std::collections::BTreeMap;
/// Per-thread registry of `!Send` futures spawned via `spawn_local`, keyed by the
/// id a `LocalFutureProxy` allocates from `TASK_ID`.
type LocalFutureMap = RefCell<BTreeMap<u64, Pin<Box<dyn Future<Output = ()> + 'static>>>>;
thread_local! {
// `Task::future` must be `Send`, so non-Send futures are parked here and polled
// indirectly through a `LocalFutureProxy` holding only their id.
static LOCAL_FUTURES: LocalFutureMap = RefCell::new(BTreeMap::new());
}
/// `Send` stand-in for a `!Send` future spawned with `spawn_local`.
///
/// The real future lives in this thread's `LOCAL_FUTURES` map under `id`. If the
/// proxy is polled on a thread whose map has no such entry, it resolves without
/// running the future.
struct LocalFutureProxy {
/// Key of the wrapped future in `LOCAL_FUTURES`.
id: u64,
/// Set once the wrapped future resolved (or its entry was not found).
completed: bool,
}
impl LocalFutureProxy {
fn new<T: 'static>(
future: impl Future<Output = T> + 'static,
sender: futures::channel::oneshot::Sender<T>,
) -> Self {
let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
let wrapped_future = async move {
let result = future.await;
let _ = sender.send(result);
};
LOCAL_FUTURES.with(|futures| {
futures.borrow_mut().insert(id, Box::pin(wrapped_future));
});
Self {
id,
completed: false,
}
}
}
impl Future for LocalFutureProxy {
    type Output = ();

    /// Polls the thread-local future registered under `self.id`.
    ///
    /// The future is taken out of `LOCAL_FUTURES` for the duration of the poll and
    /// put back if it is still pending. Holding the `RefCell` borrow across the
    /// inner poll would make any re-entrant access from inside the future (for
    /// example a nested `spawn_local`, which inserts into the same map) panic with
    /// a `BorrowMutError`.
    ///
    /// Resolves immediately if the proxy already completed or if no future is
    /// registered under its id on this thread.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.completed {
            return Poll::Ready(());
        }
        let id = self.id;
        let taken = LOCAL_FUTURES.with(|futures| futures.borrow_mut().remove(&id));
        let Some(mut future) = taken else {
            self.completed = true;
            return Poll::Ready(());
        };
        match future.as_mut().poll(cx) {
            Poll::Ready(()) => {
                // `future` is dropped here, after the map borrow has been released.
                self.completed = true;
                Poll::Ready(())
            }
            Poll::Pending => {
                LOCAL_FUTURES.with(|futures| {
                    futures.borrow_mut().insert(id, future);
                });
                Poll::Pending
            }
        }
    }
}
impl Drop for LocalFutureProxy {
    /// Discards the registered future of a proxy that never ran to completion.
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        // Remove the entry first and drop it only after the `RefCell` borrow is
        // released, so a destructor inside the future that touches
        // `LOCAL_FUTURES` (e.g. via `spawn_local`) cannot hit a double borrow.
        // `try_with` avoids panicking if the thread-local was already torn down
        // during thread exit; the entry is gone in that case anyway.
        let orphan = LOCAL_FUTURES
            .try_with(|futures| futures.borrow_mut().remove(&self.id))
            .ok()
            .flatten();
        drop(orphan);
    }
}
/// Run queue of tasks waiting to be polled; shared by all clones of a runtime and its spawner.
type Queue = Arc<Mutex<Vec<Arc<Task>>>>;
/// Source of unique `Runtime` ids (the only field `PartialEq` compares).
static RUNTIME_ID: LazyLock<AtomicU64> = LazyLock::new(|| AtomicU64::new(1));
/// Source of unique ids for both `Task`s and `LOCAL_FUTURES` entries.
static TASK_ID: LazyLock<AtomicU64> = LazyLock::new(|| AtomicU64::new(1));
/// Cloneable handle to a simulator [`Runtime`].
///
/// The wrapped runtime is a clone sharing the queue, task counter and active flag
/// of the runtime the handle was created from.
#[derive(Debug, Clone)]
pub struct Handle {
runtime: Arc<Runtime>,
}
impl Handle {
    /// Runs `f` to completion on the runtime behind this handle.
    ///
    /// # Panics
    ///
    /// Panics if a simulator runtime is already set on this thread.
    pub fn block_on<F: Future>(&self, f: F) -> F::Output {
        self.runtime.block_on(f)
    }

    /// Spawns a `Send` future on the runtime behind this handle.
    pub fn spawn<T: Send + 'static>(
        &self,
        future: impl Future<Output = T> + Send + 'static,
    ) -> JoinHandle<T> {
        self.runtime.spawn(future)
    }

    /// Like [`Handle::spawn`], but when trace logging is enabled at spawn time it
    /// logs `name` at spawn time and again when the future completes.
    pub fn spawn_with_name<T: Send + 'static>(
        &self,
        name: &str,
        future: impl Future<Output = T> + Send + 'static,
    ) -> JoinHandle<T> {
        if log::log_enabled!(log::Level::Trace) {
            log::trace!("spawn start: {name}");
            let name = name.to_owned();
            let future = async move {
                let response = future.await;
                log::trace!("spawn finished: {name}");
                response
            };
            self.runtime.spawn(future)
        } else {
            self.runtime.spawn(future)
        }
    }

    /// Schedules `func` as a blocking task; it runs synchronously on the thread
    /// that processes it.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.runtime.spawn_blocking(func)
    }

    /// Like [`Handle::spawn_blocking`], but when trace logging is enabled at spawn
    /// time it logs `name` at spawn time and again when `func` returns.
    pub fn spawn_blocking_with_name<F, R>(&self, name: &str, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        if log::log_enabled!(log::Level::Trace) {
            log::trace!("spawn_blocking start: {name}");
            let name = name.to_owned();
            let func = move || {
                let response = func();
                log::trace!("spawn_blocking finished: {name}");
                response
            };
            self.runtime.spawn_blocking(func)
        } else {
            self.runtime.spawn_blocking(func)
        }
    }

    /// Spawns a possibly `!Send` future. It is stored in the calling thread's
    /// local-future registry, so it only makes progress when that thread drives
    /// the runtime.
    pub fn spawn_local<T: 'static>(
        &self,
        future: impl Future<Output = T> + 'static,
    ) -> JoinHandle<T> {
        self.runtime.spawn_local(future)
    }

    /// Like [`Handle::spawn_local`], but when trace logging is enabled at spawn
    /// time it logs `name` at spawn time and again when the future completes.
    pub fn spawn_local_with_name<T: 'static>(
        &self,
        name: &str,
        future: impl Future<Output = T> + 'static,
    ) -> JoinHandle<T> {
        if log::log_enabled!(log::Level::Trace) {
            log::trace!("spawn_local start: {name}");
            let name = name.to_owned();
            let future = async move {
                let response = future.await;
                log::trace!("spawn_local finished: {name}");
                response
            };
            self.runtime.spawn_local(future)
        } else {
            self.runtime.spawn_local(future)
        }
    }

    /// Returns a handle to the simulator runtime currently set on this thread.
    ///
    /// # Panics
    ///
    /// Panics if no simulator runtime is set on this thread.
    #[must_use]
    pub fn current() -> Self {
        Runtime::current()
            .as_ref()
            .map(Runtime::handle)
            .expect("Handle::current() called outside of a simulator runtime")
    }
}
// The runtime currently driving code on this thread. It is set (via
// `RUNTIME.set`) for the duration of `block_on`, of each processed task, and of
// each `spawn*` call on a `Runtime`.
scoped_thread_local! {
static RUNTIME: Runtime
}
/// Cooperative simulator runtime: every task is polled on the thread that drives
/// the runtime, and the next task is picked at random from the queue.
///
/// Clones share all state except `handle`; equality compares only `id`.
#[derive(Debug, Clone)]
pub struct Runtime {
/// Unique id drawn from `RUNTIME_ID`.
id: u64,
/// Tasks waiting to be polled (same queue as `spawner`'s).
queue: Queue,
/// Pushes newly spawned and woken tasks onto `queue`.
spawner: Spawner,
/// Live `Task` count: incremented in `Task::new`, decremented when a task is dropped.
tasks: Arc<AtomicU64>,
/// Set on first use by `start`; cleared when `wait` finishes.
active: Arc<AtomicBool>,
/// `Some` for runtimes built by `new`; `None` for the clone stored inside that handle.
handle: Option<Handle>,
}
impl Default for Runtime {
fn default() -> Self {
Self::new()
}
}
impl PartialEq for Runtime {
    /// Runtimes are equal when they share an id, i.e. one is a clone of the other
    /// (including the clone held inside its `Handle`).
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.id, other.id);
        lhs == rhs
    }
}
impl GenericRuntime for Runtime {
    /// Drives `future` to completion on the current thread.
    ///
    /// The future is polled with a no-op waker. Between polls, one queued task is
    /// processed; when the queue is empty the thread yields instead.
    ///
    /// # Panics
    ///
    /// Panics if a simulator runtime is already set on this thread.
    fn block_on<F: Future>(&self, future: F) -> F::Output {
        assert!(!RUNTIME.is_set(), "Cannot run block_on within a runtime");
        log::trace!("block_on");
        self.start();
        RUNTIME.set(self, || {
            let noop = futures::task::noop_waker();
            let mut cx = Context::from_waker(&noop);
            let mut pinned = Box::pin(future);
            loop {
                if let Poll::Ready(output) = pinned.as_mut().poll(&mut cx) {
                    break output;
                }
                let ran_task = self.process_next_task();
                if !ran_task {
                    std::thread::yield_now();
                }
            }
        })
    }

    /// Processes tasks until no `Task` of this runtime is alive, then marks the
    /// runtime inactive.
    ///
    /// While tasks are alive but none is queued, the thread yields and retries.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    fn wait(self) -> Result<(), Error> {
        log::debug!("wait: entering, outstanding tasks={}", self.tasks());
        loop {
            if self.tasks() == 0 {
                break;
            }
            log::debug!("wait: processing task={}", self.tasks());
            if self.process_next_task() {
                continue;
            }
            std::thread::yield_now();
        }
        self.active.store(false, Ordering::SeqCst);
        log::debug!("wait: completed, all tasks finished");
        Ok(())
    }
}
impl Runtime {
    /// Creates a new, inactive runtime with an empty queue and a unique id.
    ///
    /// The runtime becomes active on its first `spawn*` or `block_on` call. The
    /// returned value carries a [`Handle`] whose inner runtime is a clone sharing
    /// this runtime's queue, task counter and active flag.
    #[must_use]
    pub fn new() -> Self {
        let queue = Arc::new(Mutex::new(vec![]));
        let mut this = Self {
            id: RUNTIME_ID.fetch_add(1, Ordering::SeqCst),
            spawner: Spawner {
                queue: queue.clone(),
            },
            queue,
            tasks: Arc::new(AtomicU64::new(0)),
            active: Arc::new(AtomicBool::new(false)),
            handle: None,
        };
        this.handle = Some(Handle {
            runtime: Arc::new(this.clone()),
        });
        this
    }

    /// Returns a [`Handle`] to this runtime.
    ///
    /// Runtimes built by [`Runtime::new`] return their stored handle. The runtime
    /// held inside a `Handle` has none (it was cloned before the handle was
    /// attached), yet it becomes the current runtime during `Handle::block_on`.
    /// For it, an equivalent handle sharing the same state is built on demand
    /// instead of panicking on `None`.
    #[must_use]
    pub fn handle(&self) -> Handle {
        self.handle.clone().unwrap_or_else(|| Handle {
            runtime: Arc::new(self.clone()),
        })
    }

    /// Marks the runtime active; does nothing if it already is.
    ///
    /// # Panics
    ///
    /// Panics if this call activates the runtime while a runtime is set on this thread.
    fn start(&self) {
        if self.active.fetch_or(true, Ordering::SeqCst) {
            return;
        }
        assert!(!RUNTIME.is_set(), "Cannot start a Runtime within a Runtime");
    }

    /// Removes and returns the next task to run, or `None` if the queue is empty.
    ///
    /// A blocking task is picked at random if any is queued; otherwise any queued
    /// task is picked at random.
    fn next_task(&self) -> Option<Arc<Task>> {
        let mut queue = self.queue.lock().unwrap_or_else(PoisonError::into_inner);
        let task_count = queue.len();
        if task_count == 0 {
            log::debug!("No tasks");
            return None;
        }
        let index = queue
            .iter()
            .enumerate()
            .filter(|(_, x)| x.block)
            .map(|(i, _)| i)
            .choose(&mut rng())
            .unwrap_or_else(|| rng().gen_range(0..task_count));
        log::debug!("next task index={index} task_count={task_count}");
        Some(queue.remove(index))
    }

    /// Runs one queued task with this runtime set as current.
    ///
    /// Returns `false` if the queue was empty.
    pub(crate) fn process_next_task(&self) -> bool {
        let Some(task) = self.next_task() else {
            return false;
        };
        RUNTIME.set(self, || {
            task.process();
        });
        true
    }

    /// Processes at most one queued task.
    pub fn tick(&self) {
        self.process_next_task();
    }

    /// Spawns a `Send` future, activating the runtime if needed.
    pub fn spawn<T: Send + 'static>(
        &self,
        future: impl Future<Output = T> + Send + 'static,
    ) -> JoinHandle<T> {
        self.start();
        RUNTIME.set(self, || self.spawner.spawn(self.clone(), future))
    }

    /// Like [`Runtime::spawn`], with trace logging of `name` (see [`Handle::spawn_with_name`]).
    pub fn spawn_with_name<T: Send + 'static>(
        &self,
        name: &str,
        future: impl Future<Output = T> + Send + 'static,
    ) -> JoinHandle<T> {
        self.handle().spawn_with_name(name, future)
    }

    /// Schedules `func` as a blocking task, activating the runtime if needed.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.start();
        RUNTIME.set(self, || self.spawner.spawn_blocking(self.clone(), func))
    }

    /// Like [`Runtime::spawn_blocking`], with trace logging of `name`.
    pub fn spawn_blocking_with_name<F, R>(&self, name: &str, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.handle().spawn_blocking_with_name(name, func)
    }

    /// Spawns a possibly `!Send` future, activating the runtime if needed.
    ///
    /// The future is stored in the calling thread's local-future registry.
    pub fn spawn_local<T: 'static>(
        &self,
        future: impl Future<Output = T> + 'static,
    ) -> JoinHandle<T> {
        self.start();
        RUNTIME.set(self, || self.spawner.spawn_local(self.clone(), future))
    }

    /// Whether the runtime has been started and not yet finished `wait`.
    fn active(&self) -> bool {
        self.active.load(Ordering::SeqCst)
    }

    /// Number of live tasks belonging to this runtime.
    fn tasks(&self) -> u64 {
        self.tasks.load(Ordering::SeqCst)
    }

    /// Returns a clone of the runtime currently set on this thread, if any.
    #[must_use]
    pub fn current() -> Option<Self> {
        RUNTIME.is_set().then(|| RUNTIME.with(Clone::clone))
    }
}
/// Handle for awaiting the output of a spawned simulator task.
///
/// Resolves to `Err(JoinError)` if the task's result channel closes without a
/// value (e.g. the task was dropped before finishing).
pub struct JoinHandle<T> {
    /// Receiving half of the channel the task sends its output on.
    rx: futures::channel::oneshot::Receiver<T>,
    /// Outcome captured by `is_finished`, handed out by the next poll.
    result: Option<Result<T, task::JoinError>>,
    /// Whether `is_finished` has already observed completion.
    finished: bool,
    // `abort` is a no-op in the simulator, so this flag is never read.
    #[allow(unused)]
    aborted: bool,
}
// No bounds on `T`: the receiver is `Unpin` for every `T`, and requiring `Send`
// would exclude non-Send outputs of `spawn_local`.
impl<T> JoinHandle<T> {
    /// Returns whether the task's outcome is available, without blocking.
    ///
    /// Polls the result channel once with a no-op waker. Once it has resolved,
    /// the outcome (output or `JoinError`) is stored and returned by the next
    /// `.await` on this handle.
    #[must_use]
    pub fn is_finished(&mut self) -> bool {
        if self.finished {
            return true;
        }
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        match Pin::new(&mut self.rx).poll(&mut cx) {
            Poll::Ready(outcome) => {
                self.finished = true;
                self.result = Some(outcome.map_err(|_| task::JoinError::new()));
                true
            }
            Poll::Pending => false,
        }
    }

    /// Requests cancellation of the task.
    ///
    /// This is a no-op in the simulator: the task keeps running and the handle
    /// still resolves to its output.
    pub fn abort(&self) {
        log::debug!("JoinHandle::abort() called (no-op in simulator)");
    }
}
// Only `Unpin` is required (for safe access through `Pin<&mut Self>`); dropping
// the former `Send` bound lets non-Send outputs of `spawn_local` be awaited.
impl<T: Unpin> Future for JoinHandle<T> {
    type Output = Result<T, task::JoinError>;

    /// Yields the task's output, or `JoinError` if its sender was dropped without
    /// sending. An outcome already captured by `is_finished` is returned first.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        if let Some(result) = self.result.take() {
            return Poll::Ready(result);
        }
        Pin::new(&mut self.rx)
            .poll(cx)
            .map(|received| received.map_err(|_| task::JoinError::new()))
    }
}
/// Wraps futures into `Task`s and enqueues them on a runtime's shared run queue.
#[derive(Debug, Clone)]
pub(crate) struct Spawner {
/// Same queue as the owning `Runtime`'s `queue`.
queue: Queue,
}
impl Spawner {
    /// Builds a fresh, unfinished `JoinHandle` around a task's result receiver.
    fn new_join_handle<T>(rx: futures::channel::oneshot::Receiver<T>) -> JoinHandle<T> {
        JoinHandle {
            rx,
            result: None,
            finished: false,
            aborted: false,
        }
    }

    /// Wraps `future` so its output is sent to the returned handle, and queues it.
    fn spawn<T: Send + 'static>(
        &self,
        runtime: Runtime,
        future: impl Future<Output = T> + Send + 'static,
    ) -> JoinHandle<T> {
        let (tx, rx) = futures::channel::oneshot::channel();
        let wrapped = async move {
            let _ = tx.send(future.await);
        };
        self.inner_spawn(&Task::new(runtime, false, wrapped));
        Self::new_join_handle(rx)
    }

    /// Queues `func` as a blocking task; it runs synchronously when first polled.
    fn spawn_blocking<F, R>(&self, runtime: Runtime, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        log::trace!("spawn_blocking");
        let (tx, rx) = futures::channel::oneshot::channel();
        let wrapped = async move {
            let _ = tx.send(func());
        };
        self.inner_spawn_blocking(&Task::new(runtime, true, wrapped));
        Self::new_join_handle(rx)
    }

    /// Registers `future` in the calling thread's local-future map and queues a
    /// `Send` proxy task that polls it.
    fn spawn_local<T: 'static>(
        &self,
        runtime: Runtime,
        future: impl Future<Output = T> + 'static,
    ) -> JoinHandle<T> {
        log::trace!("spawn_local");
        let (tx, rx) = futures::channel::oneshot::channel();
        let wrapped = LocalFutureProxy::new(future, tx);
        self.inner_spawn(&Task::new(runtime, false, wrapped));
        Self::new_join_handle(rx)
    }

    fn inner_spawn(&self, task: &Arc<Task>) {
        log::trace!("inner_spawn");
        self.add_task(task);
    }

    fn inner_spawn_blocking(&self, task: &Arc<Task>) {
        log::trace!("inner_spawn_blocking");
        self.add_task(task);
    }

    /// Pushes `task` onto the queue unless it is already queued.
    ///
    /// The duplicate check and the push happen under a single lock so that two
    /// concurrent wakes cannot both enqueue the same task. A poisoned lock is
    /// recovered, matching the other queue accesses.
    fn add_task(&self, task: &Arc<Task>) {
        log::trace!("add_task");
        let mut queue = self.queue.lock().unwrap_or_else(PoisonError::into_inner);
        if queue.iter().any(|queued| queued.id == task.id) {
            return;
        }
        queue.push(Arc::clone(task));
    }
}
pub fn spawn<T: Send + 'static>(future: impl Future<Output = T> + Send + 'static) -> JoinHandle<T> {
RUNTIME.with(|runtime| runtime.spawn(future))
}
pub fn spawn_local<T: 'static>(future: impl Future<Output = T> + 'static) -> JoinHandle<T> {
RUNTIME.with(|runtime| runtime.spawn_local(future))
}
pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
RUNTIME.with(|runtime| runtime.spawn_blocking(func))
}
pub fn block_on<F: Future>(future: F) -> F::Output {
RUNTIME.with(|runtime| runtime.block_on(future))
}
pub fn wait() -> Result<(), Error> {
RUNTIME.with(|runtime| runtime.clone().wait())
}
/// A unit of work scheduled on a simulator runtime.
struct Task {
/// Unique id drawn from `TASK_ID`; used to avoid queuing the same task twice.
id: u64,
/// Runtime that owns this task; woken tasks are pushed onto its queue.
runtime: Runtime,
/// The type-erased future, locked for each poll.
future: Mutex<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
/// Set once `future` returned `Ready`; later polls return `Ready` immediately.
finished: AtomicBool,
/// True for `spawn_blocking` tasks: preferred by `next_task` and polled until done in `process`.
block: bool,
}
impl std::fmt::Debug for Task {
    /// Formats `id`, `finished` and `block`; the runtime and future are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            id, finished, block, ..
        } = self;
        let mut builder = f.debug_struct("Task");
        builder.field("id", id);
        builder.field("finished", finished);
        builder.field("block", block);
        builder.finish_non_exhaustive()
    }
}
impl Task {
    /// Creates a task for `future`, counting it as live on `runtime`.
    fn new(
        runtime: Runtime,
        block: bool,
        future: impl Future<Output = ()> + Send + 'static,
    ) -> Arc<Self> {
        runtime.tasks.fetch_add(1, Ordering::SeqCst);
        let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
        Arc::new(Self {
            id,
            runtime,
            future: Mutex::new(Box::pin(future)),
            finished: AtomicBool::new(false),
            block,
        })
    }

    /// A waker that re-queues this task when woken.
    fn waker(self: &Arc<Self>) -> Waker {
        Waker::from(Arc::clone(self))
    }

    /// Polls the future once (unless already finished) and records completion.
    fn poll(self: &Arc<Self>) -> Poll<()> {
        if self.finished() {
            return Poll::Ready(());
        }
        let waker = self.waker();
        let mut cx = Context::from_waker(&waker);
        let mut future = self.future.lock().unwrap_or_else(PoisonError::into_inner);
        let outcome = future.as_mut().poll(&mut cx);
        if outcome.is_ready() {
            self.finished.store(true, Ordering::SeqCst);
        }
        outcome
    }

    /// Runs the task if its runtime is active and it has not finished.
    ///
    /// Blocking tasks are polled repeatedly until done, processing other queued
    /// tasks (or yielding) in between; other tasks are polled once.
    fn process(self: Arc<Self>) {
        if !self.runtime.active() || self.finished() {
            return;
        }
        if !self.block {
            let _ = self.poll();
            return;
        }
        while self.poll().is_pending() {
            if !self.runtime.process_next_task() {
                std::thread::yield_now();
            }
        }
    }

    /// Whether the future has returned `Ready`.
    fn finished(&self) -> bool {
        self.finished.load(Ordering::SeqCst)
    }
}
impl Drop for Task {
fn drop(&mut self) {
RUNTIME.with(|runtime| runtime.tasks.fetch_sub(1, Ordering::SeqCst));
}
}
impl Wake for Task {
    /// Re-queues this task on its runtime so it is polled again.
    ///
    /// # Panics
    ///
    /// Panics if the owning runtime is not active.
    fn wake(self: Arc<Self>) {
        log::trace!("wake");
        let runtime = &self.runtime;
        assert!(
            runtime.active(),
            "Attempted to wake on an inactive Runtime"
        );
        let spawner = &runtime.spawner;
        if self.block {
            spawner.inner_spawn_blocking(&self);
        } else {
            spawner.inner_spawn(&self);
        }
    }
}
/// Builds a simulator runtime.
///
/// The builder's settings are ignored; every call yields a fresh [`Runtime`].
/// The `Result` is never `Err` here. Presumably it keeps the signature uniform
/// with other backends; confirm.
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn build_runtime(_builder: &Builder) -> Result<Runtime, Error> {
    let runtime = Runtime::new();
    Ok(runtime)
}
#[cfg(test)]
mod test {
    #[allow(unused)]
    use pretty_assertions::{assert_eq, assert_ne};
    use std::sync::{Arc, Mutex};

    use crate::{
        runtime::Builder,
        simulator::runtime::{Handle, Runtime, build_runtime},
        task,
    };

    #[test_log::test]
    fn rt_current_thread_runtime_spawns_on_same_thread() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let thread_id = std::thread::current().id();
        runtime.block_on(async move {
            task::spawn(async move { assert_eq!(std::thread::current().id(), thread_id) });
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn rt_spawn_local_works_with_non_send() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async move {
            use std::cell::RefCell;
            use std::rc::Rc;
            let data = Rc::new(RefCell::new(42));
            let data_clone = data.clone();
            let handle = task::spawn_local(async move {
                *data_clone.borrow_mut() += 1;
                *data_clone.borrow()
            });
            let result = handle.await.unwrap();
            assert_eq!(result, 43);
            assert_eq!(*data.borrow(), 43);
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn rt_current_thread_runtime_block_on_same_thread() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let thread_id = std::thread::current().id();
        runtime.block_on(async move {
            assert_eq!(std::thread::current().id(), thread_id);
        });
        runtime.wait().unwrap();
    }

    #[cfg(feature = "rt-multi-thread")]
    #[test_log::test]
    fn rt_multi_thread_runtime_spawns_on_same_thread() {
        let runtime = build_runtime(Builder::new().max_blocking_threads(1)).unwrap();
        let thread_id = std::thread::current().id();
        runtime.block_on(async move {
            task::spawn(async move { assert_eq!(std::thread::current().id(), thread_id) });
        });
        runtime.wait().unwrap();
    }

    #[cfg(feature = "rt-multi-thread")]
    #[test_log::test]
    fn rt_multi_thread_runtime_block_on_same_thread() {
        let runtime = build_runtime(Builder::new().max_blocking_threads(1)).unwrap();
        let thread_id = std::thread::current().id();
        runtime.block_on(async move {
            assert_eq!(std::thread::current().id(), thread_id);
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn runtime_tick_processes_single_task() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let completed = Arc::new(Mutex::new(false));
        let completed_clone = Arc::clone(&completed);
        runtime.spawn(async move {
            *completed_clone.lock().unwrap() = true;
        });
        // `tick` polls the only queued task synchronously on this thread, so no
        // sleep is needed before checking its effect.
        runtime.tick();
        assert!(*completed.lock().unwrap());
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn runtime_current_returns_none_outside_runtime() {
        let current = Runtime::current();
        assert!(current.is_none());
    }

    #[test_log::test]
    fn runtime_current_returns_some_inside_runtime() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            let current = Runtime::current();
            assert!(current.is_some());
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn runtime_equality_based_on_id() {
        let runtime1 = build_runtime(&Builder::new()).unwrap();
        let runtime2 = build_runtime(&Builder::new()).unwrap();
        assert_eq!(runtime1, runtime1.clone());
        assert_ne!(runtime1, runtime2);
    }

    #[test_log::test]
    fn runtime_default_is_same_as_new() {
        let runtime1 = Runtime::default();
        let runtime2 = Runtime::new();
        let result1 = runtime1.block_on(async { 42 });
        let result2 = runtime2.block_on(async { 42 });
        assert_eq!(result1, result2);
        runtime1.wait().unwrap();
        runtime2.wait().unwrap();
    }

    #[test_log::test]
    fn handle_spawn_executes_task() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let handle = runtime.handle();
        let join_handle = handle.spawn(async { 42 });
        let result = runtime.block_on(async { join_handle.await.unwrap() });
        assert_eq!(result, 42);
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn handle_spawn_blocking_executes_blocking_code() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let handle = runtime.handle();
        let join_handle = handle.spawn_blocking(|| 42);
        let result = runtime.block_on(async { join_handle.await.unwrap() });
        assert_eq!(result, 42);
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn runtime_spawn_with_name_executes_task() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let join_handle = runtime.spawn_with_name("test_task", async { 123 });
        let result = runtime.block_on(async { join_handle.await.unwrap() });
        assert_eq!(result, 123);
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn runtime_spawn_blocking_with_name_executes_task() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let join_handle = runtime.spawn_blocking_with_name("blocking_task", || 456);
        let result = runtime.block_on(async { join_handle.await.unwrap() });
        assert_eq!(result, 456);
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn join_handle_is_finished_detects_completion() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let mut join_handle = runtime.spawn(async { 42 });
        // Nothing has been processed yet, so the task cannot have finished.
        assert!(!join_handle.is_finished());
        // `tick` processes the only queued task synchronously, sending its output.
        runtime.tick();
        assert!(join_handle.is_finished());
        // The output observed by `is_finished` is cached and returned by the await.
        let result = runtime.block_on(async { join_handle.await.unwrap() });
        assert_eq!(result, 42);
        let join_handle2 = runtime.spawn(async { 10 });
        runtime.block_on(async {
            let _ = join_handle2.await;
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn join_handle_abort_is_noop() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        let join_handle = runtime.spawn(async { 42 });
        join_handle.abort();
        let result = runtime.block_on(async { join_handle.await.unwrap() });
        assert_eq!(result, 42);
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn handle_spawn_local_with_name_executes_task() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            use std::cell::RefCell;
            use std::rc::Rc;
            let data = Rc::new(RefCell::new(10));
            let data_clone = data.clone();
            let handle = Handle::current();
            let join_handle = handle.spawn_local_with_name("local_task", async move {
                *data_clone.borrow_mut() += 5;
                *data_clone.borrow()
            });
            let result = join_handle.await.unwrap();
            assert_eq!(result, 15);
            assert_eq!(*data.borrow(), 15);
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn runtime_spawn_local_executes_non_send_future() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            use std::cell::RefCell;
            use std::rc::Rc;
            let data = Rc::new(RefCell::new(vec![1, 2, 3]));
            let data_clone = data.clone();
            let handle = runtime.spawn_local(async move {
                data_clone.borrow_mut().push(4);
                data_clone.borrow().len()
            });
            let len = handle.await.unwrap();
            assert_eq!(len, 4);
            assert_eq!(data.borrow().len(), 4);
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn local_future_proxy_handles_drop_correctly() {
        let runtime = build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            use std::cell::RefCell;
            use std::rc::Rc;
            let data = Rc::new(RefCell::new(false));
            let _handle = runtime.spawn_local(async move {
                *data.borrow_mut() = true;
            });
        });
        runtime.wait().unwrap();
    }

    #[test_log::test]
    fn join_error_display_formatting() {
        let err = task::JoinError::new();
        assert_eq!(err.to_string(), "JoinError");
    }

    #[test_log::test]
    fn join_error_is_clonable() {
        let err1 = task::JoinError::new();
        let err2 = err1.clone();
        assert_eq!(err1.to_string(), err2.to_string());
    }
}