use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_executor::{LocalPool, LocalSpawner};
use futures_util::task::{waker_ref, ArcWake, LocalSpawnExt};
thread_local! {
    /// Per-thread executor that owns every task queued by `spawn_local`.
    static POOL: RefCell<LocalPool> = RefCell::new(LocalPool::new());

    /// Handle used to queue tasks onto this thread's `POOL`.
    ///
    /// Like all `thread_local!` values it is initialised lazily, on first
    /// access, from whatever pool `POOL` holds at that moment;
    /// `__reset_for_tests` re-binds it whenever the pool is replaced.
    static SPAWNER: RefCell<LocalSpawner> =
        RefCell::new(POOL.with(|pool| pool.borrow().spawner()));
}
/// Signature of a drive hook: a plain function run after a spawned task is woken.
type DriveHook = fn();

/// The drive hook currently in effect, replaceable via `set_drive_hook`.
///
/// This is a process-wide `Mutex` rather than a thread-local because task
/// wakers (`DriveWaker`) may fire on any thread, and every one of them must
/// reach the same hook.
static DRIVE_HOOK: std::sync::Mutex<DriveHook> = std::sync::Mutex::new(request_main_loop_drive);

/// Default drive hook: posts an empty closure to the main thread.
///
/// The closure itself does nothing. What matters is the round-trip through
/// `run_on_main_thread`, whose trampoline wakes the runtime after running it
/// (asserted by the `run_on_main_thread_trampoline_wakes_runtime` test). That
/// gives the main loop another chance to poll the pool.
fn request_main_loop_drive() {
    let nothing_to_do = || ();
    crate::main_thread::run_on_main_thread(nothing_to_do);
}
/// Installs `hook` as the drive hook and returns the hook it replaced.
///
/// Hidden from the public docs. This module's tests use it to observe drive
/// requests, then hand the returned hook back to restore the previous state.
#[doc(hidden)]
pub fn set_drive_hook(hook: DriveHook) -> DriveHook {
    let mut slot = match DRIVE_HOOK.lock() {
        Ok(guard) => guard,
        // Only a fn pointer lives behind the lock, so a poisoned mutex cannot
        // hold a half-written value; recover it instead of propagating.
        Err(poisoned) => poisoned.into_inner(),
    };
    std::mem::replace(&mut *slot, hook)
}
fn invoke_drive_hook() {
let hook = *DRIVE_HOOK.lock().unwrap_or_else(|e| e.into_inner());
hook();
}
/// Waker adapter: forwards every wake to the waker the executor supplied and
/// then fires the drive hook.
struct DriveWaker {
    /// The waker from the `Context` the executor passed for this poll.
    inner: std::task::Waker,
}

impl ArcWake for DriveWaker {
    /// Wakes the wrapped waker (scheduling the task for another poll), then
    /// invokes the drive hook. A wake arriving from a foreign thread therefore
    /// also asks the host to run the loop that polls the pool.
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let DriveWaker { inner } = &**arc_self;
        inner.wake_by_ref();
        invoke_drive_hook();
    }
}
/// Wrapper applied to every spawned future so that the waker it observes is
/// a `DriveWaker` around the executor's own waker.
struct DriveBridged<F> {
    future: F,
}

impl<F: Future<Output = ()>> Future for DriveBridged<F> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // A fresh adapter is built on every poll, so it always forwards to the
        // waker supplied for *this* poll.
        let bridge = Arc::new(DriveWaker {
            inner: cx.waker().clone(),
        });
        let bridged_waker = waker_ref(&bridge);
        let mut bridged_cx = Context::from_waker(&bridged_waker);

        // SAFETY: `future` is structurally pinned. `DriveBridged` has no `Drop`
        // impl and never moves `future` out of a pinned `self`. Its auto-derived
        // `Unpin` holds only when `F: Unpin`, so projecting the pin onto the
        // field upholds `Pin`'s guarantees.
        let inner = unsafe { self.map_unchecked_mut(|this| &mut this.future) };
        inner.poll(&mut bridged_cx)
    }
}
/// Queues `future` on the calling thread's local pool without polling it.
///
/// The future is wrapped in `DriveBridged`, so each of its wakes (including
/// ones from other threads) also fires the drive hook. After queueing,
/// `crate::host_wake::wake_runtime()` is called, presumably so the host
/// schedules a tick that runs the pool (confirm in `host_wake`).
///
/// # Panics
/// Panics if this thread's pool refuses the task because it has shut down.
pub fn spawn_local<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    let task = DriveBridged { future };
    SPAWNER.with(|cell| {
        let spawner = cell.borrow();
        spawner
            .spawn_local(task)
            .expect("whisker tasks: local pool is shut down");
    });
    crate::host_wake::wake_runtime();
}
/// Polls this thread's pool until no queued task can make further progress.
///
/// A task that wakes itself during the run is polled again before this
/// returns.
///
/// # Panics
/// Panics if called re-entrantly from inside a task. The pool's `RefCell`
/// stays mutably borrowed for the whole run.
pub fn run_until_stalled() {
    POOL.with(|cell| cell.borrow_mut().run_until_stalled());
}
/// Runs `closure` on a freshly spawned OS thread and returns a future that
/// resolves to its result.
///
/// The value is handed back through `crate::main_thread::run_on_main_thread`,
/// i.e. delivered via the main-thread dispatcher. If it never arrives, the
/// oneshot sender is dropped and the returned future stays `Pending` forever.
/// Two ways this happens:
/// - no dispatcher is registered (see the
///   `run_blocking_future_parks_when_no_dispatcher_registered` test);
/// - `closure` panics on the worker.
///
/// The worker thread is detached; its handle is not kept.
pub fn run_blocking<F, T>(closure: F) -> impl Future<Output = T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let (sender, receiver) = futures_channel::oneshot::channel::<T>();
    let worker = move || {
        let value = closure();
        let deliver = move || {
            // The awaiting task may already be gone; a failed send is fine.
            let _ = sender.send(value);
        };
        crate::main_thread::run_on_main_thread(deliver);
    };
    std::thread::spawn(worker);
    BlockingResult { rx: receiver }
}
struct BlockingResult<T> {
rx: futures_channel::oneshot::Receiver<T>,
}
impl<T> Future for BlockingResult<T> {
type Output = T;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<T> {
use std::task::Poll;
match std::pin::Pin::new(&mut self.rx).poll(cx) {
Poll::Ready(Ok(v)) => Poll::Ready(v),
Poll::Ready(Err(_)) => Poll::Pending,
Poll::Pending => Poll::Pending,
}
}
}
/// Replaces this thread's pool with a fresh one and re-binds `SPAWNER` to it.
///
/// Every task still queued on the old pool is dropped without being polled
/// again. Hidden from docs; intended for test setup.
///
/// The old pool is swapped out and dropped only after both thread-locals
/// point at the new pool and neither `RefCell` is borrowed. Dropping the old
/// pool in place would run the pending tasks' destructors while `POOL` is
/// still mutably borrowed and `SPAWNER` still targets the dying pool. A
/// destructor that re-enters the runtime would then either hit a `RefCell`
/// borrow panic (`run_until_stalled`) or spawn onto the pool being torn down
/// (`spawn_local`).
#[doc(hidden)]
pub fn __reset_for_tests() {
    let fresh = LocalPool::new();
    let fresh_spawner = fresh.spawner();
    let stale = POOL.with(|pool| std::mem::replace(&mut *pool.borrow_mut(), fresh));
    SPAWNER.with(|spawner| *spawner.borrow_mut() = fresh_spawner);
    // Run the old tasks' destructors only now, with the runtime consistent.
    drop(stale);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::main_thread::{set_main_thread_dispatcher, DispatchFn};
    use std::cell::Cell;
    use std::ffi::c_void;
    use std::rc::Rc;
    use std::sync::MutexGuard;

    // Acquires the crate's shared host test lock; presumably serialises tests
    // that touch host-global state (dispatcher, host_wake callbacks, and the
    // process-wide DRIVE_HOOK).
    fn lock<'a>() -> MutexGuard<'a, ()> {
        crate::main_thread::host_test_lock()
    }

    // Fresh task pool for this thread plus a reset of the main_thread module.
    fn reset_all() {
        __reset_for_tests();
        crate::main_thread::__reset_for_tests();
    }

    // Runs its closure when dropped, so cleanup of global state also happens
    // when an assertion fails and the test unwinds. Bind it *after* the
    // `lock()` guard: locals drop in reverse order, so the cleanup then runs
    // while the lock is still held.
    struct Cleanup<F: FnMut()>(F);

    impl<F: FnMut()> Drop for Cleanup<F> {
        fn drop(&mut self) {
            (self.0)();
        }
    }

    // Dispatcher that runs the callback immediately on the calling thread.
    extern "C" fn sync_invoke(
        _engine: *mut c_void,
        callback: extern "C" fn(*mut c_void),
        user_data: *mut c_void,
    ) -> bool {
        callback(user_data);
        true
    }

    fn install_sync_dispatcher() {
        set_main_thread_dispatcher(Some(sync_invoke as DispatchFn), std::ptr::null_mut());
    }

    #[test]
    fn spawn_local_does_not_block_poll_at_call_time() {
        let _g = lock();
        reset_all();
        let flag = Rc::new(Cell::new(false));
        let f = flag.clone();
        spawn_local(async move {
            f.set(true);
        });
        assert!(!flag.get(), "spawn should not poll synchronously");
        run_until_stalled();
        assert!(flag.get(), "tick should drive the task to completion");
    }

    #[test]
    fn run_until_stalled_drains_multiple_independent_tasks() {
        let _g = lock();
        reset_all();
        let counter = Rc::new(Cell::new(0));
        for _ in 0..5 {
            let c = counter.clone();
            spawn_local(async move {
                c.set(c.get() + 1);
            });
        }
        run_until_stalled();
        assert_eq!(counter.get(), 5);
    }

    // Future that wakes itself and returns Pending on its first poll, then
    // completes on the second; `phase` records how far it got.
    struct Yielder {
        phase: Rc<Cell<i32>>,
        polled_once: bool,
    }

    impl Future for Yielder {
        type Output = ();
        fn poll(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<()> {
            if !self.polled_once {
                self.polled_once = true;
                self.phase.set(1);
                cx.waker().wake_by_ref();
                std::task::Poll::Pending
            } else {
                self.phase.set(2);
                std::task::Poll::Ready(())
            }
        }
    }

    #[test]
    fn run_until_stalled_resumes_self_woken_tasks_within_one_call() {
        let _g = lock();
        reset_all();
        let phase = Rc::new(Cell::new(0));
        let phase_for_task = phase.clone();
        spawn_local(async move {
            Yielder {
                phase: phase_for_task,
                polled_once: false,
            }
            .await;
        });
        run_until_stalled();
        assert_eq!(phase.get(), 2);
    }

    #[test]
    fn run_blocking_returns_value_from_worker_thread() {
        let _g = lock();
        reset_all();
        install_sync_dispatcher();
        let _cleanup = Cleanup(|| {
            crate::main_thread::__reset_for_tests();
        });
        let got: Rc<RefCell<Option<i32>>> = Rc::new(RefCell::new(None));
        let got_for_task = got.clone();
        spawn_local(async move {
            let v = run_blocking(|| 42_i32).await;
            *got_for_task.borrow_mut() = Some(v);
        });
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
        while got.borrow().is_none() && std::time::Instant::now() < deadline {
            run_until_stalled();
            std::thread::sleep(std::time::Duration::from_millis(5));
        }
        assert_eq!(*got.borrow(), Some(42));
    }

    #[test]
    fn run_on_main_thread_trampoline_wakes_runtime() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let _g = lock();
        reset_all();
        install_sync_dispatcher();
        static WOKE: AtomicBool = AtomicBool::new(false);
        WOKE.store(false, Ordering::SeqCst);
        extern "C" fn wake_cb(_: *mut c_void) {
            WOKE.store(true, Ordering::SeqCst);
        }
        let _cleanup = Cleanup(|| {
            crate::host_wake::__reset_for_tests();
            crate::main_thread::__reset_for_tests();
        });
        crate::host_wake::set_request_frame_callback(Some(wake_cb), std::ptr::null_mut());
        crate::main_thread::run_on_main_thread(|| {});
        assert!(
            WOKE.load(Ordering::SeqCst),
            "run_on_main_thread's trampoline must wake the runtime \
             after the closure runs — otherwise the awaiting future \
             never gets re-polled (hn-reader Loading-stuck bug)"
        );
    }

    #[test]
    fn run_blocking_future_parks_when_no_dispatcher_registered() {
        let _g = lock();
        reset_all();
        let polled = Rc::new(Cell::new(false));
        let polled_for_task = polled.clone();
        spawn_local(async move {
            let _v: () = run_blocking(|| {}).await;
            polled_for_task.set(true);
        });
        for _ in 0..20 {
            run_until_stalled();
            std::thread::sleep(std::time::Duration::from_millis(5));
        }
        assert!(
            !polled.get(),
            "task body should NOT have completed without a dispatcher \
             (cancel-on-dispose semantics)"
        );
    }

    // Future completed from another thread: stashes its waker in `waker_slot`
    // until `done` is set.
    struct ForeignSignal {
        done: Arc<std::sync::atomic::AtomicBool>,
        waker_slot: Arc<std::sync::Mutex<Option<std::task::Waker>>>,
    }

    impl Future for ForeignSignal {
        type Output = ();
        fn poll(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<()> {
            if self.done.load(std::sync::atomic::Ordering::SeqCst) {
                std::task::Poll::Ready(())
            } else {
                *self.waker_slot.lock().unwrap() = Some(cx.waker().clone());
                std::task::Poll::Pending
            }
        }
    }

    #[test]
    fn foreign_thread_wake_invokes_drive_hook_and_resumes_task() {
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
        let _g = lock();
        reset_all();
        static DRIVE_POKES: AtomicUsize = AtomicUsize::new(0);
        DRIVE_POKES.store(0, Ordering::SeqCst);
        fn test_hook() {
            DRIVE_POKES.fetch_add(1, Ordering::SeqCst);
        }
        let prev = set_drive_hook(test_hook);
        // DRIVE_HOOK is process-wide: restore it even if an assertion fails.
        let _restore_hook = Cleanup(move || {
            set_drive_hook(prev);
        });
        let done = Arc::new(AtomicBool::new(false));
        let waker_slot: Arc<std::sync::Mutex<Option<std::task::Waker>>> =
            Arc::new(std::sync::Mutex::new(None));
        let completed = Rc::new(Cell::new(false));
        let done_for_task = done.clone();
        let slot_for_task = waker_slot.clone();
        let completed_for_task = completed.clone();
        spawn_local(async move {
            ForeignSignal {
                done: done_for_task,
                waker_slot: slot_for_task,
            }
            .await;
            completed_for_task.set(true);
        });
        run_until_stalled();
        assert!(!completed.get(), "task should be parked awaiting signal");
        assert!(
            waker_slot.lock().unwrap().is_some(),
            "task must have stashed its waker on the first poll"
        );
        let done_for_thread = done.clone();
        let slot_for_thread = waker_slot.clone();
        let handle = std::thread::spawn(move || {
            done_for_thread.store(true, Ordering::SeqCst);
            let waker = slot_for_thread.lock().unwrap().take().unwrap();
            waker.wake();
        });
        handle.join().unwrap();
        assert!(
            DRIVE_POKES.load(Ordering::SeqCst) >= 1,
            "foreign-thread wake must invoke the drive hook so the main \
             loop re-polls the pool (issue #7)"
        );
        run_until_stalled();
        assert!(
            completed.get(),
            "task must resume and complete after the foreign-thread wake \
             drove a re-poll"
        );
    }

    #[test]
    fn reset_clears_pending_tasks() {
        let _g = lock();
        reset_all();
        let counter = Rc::new(Cell::new(0));
        let c = counter.clone();
        spawn_local(async move {
            c.set(c.get() + 1);
        });
        __reset_for_tests();
        run_until_stalled();
        assert_eq!(counter.get(), 0, "reset should drop pending tasks");
    }

    #[test]
    fn spawn_local_after_reset_uses_fresh_spawner() {
        let _g = lock();
        reset_all();
        __reset_for_tests();
        let flag = Rc::new(Cell::new(false));
        let f = flag.clone();
        spawn_local(async move {
            f.set(true);
        });
        run_until_stalled();
        assert!(
            flag.get(),
            "spawner must be re-bound to the new pool after reset"
        );
    }
}