use futures::future::poll_fn;
use shuttle::sync::atomic::{AtomicBool, Ordering};
use shuttle::sync::Mutex;
use shuttle::{check_dfs, future, thread};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use test_log::test;
/// Invokes a future's waker only after `block_on` has already returned.
///
/// The future records the waker it was polled with and completes on its first poll. Once the
/// executor is finished with it, the test pulls that waker back out and wakes it. The test
/// passes as long as this late wake-up does not panic.
#[test]
fn wake_after_finish() {
    // Shared storage for the most recently seen waker. Deliberately `std` rather than shuttle
    // synchronization — presumably to keep this bookkeeping out of shuttle's scheduling; confirm.
    type WakerSlot = std::sync::Arc<std::sync::Mutex<Option<Waker>>>;

    /// A future that stashes the waker from its poll and is immediately ready.
    #[derive(Clone)]
    struct RecordingFuture {
        slot: WakerSlot,
    }

    impl RecordingFuture {
        fn new() -> Self {
            let slot: WakerSlot = Default::default();
            Self { slot }
        }
    }

    impl Future for RecordingFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let current = cx.waker().clone();
            *self.slot.lock().unwrap() = Some(current);
            Poll::Ready(())
        }
    }

    check_dfs(
        || {
            let recorder = RecordingFuture::new();
            let polled = recorder.clone();
            future::block_on(async move {
                polled.await;
            });
            // The lock guard is a temporary of this `let` statement, so it is released
            // before the waker fires below.
            let Some(late_waker) = recorder.slot.lock().unwrap().take() else {
                return;
            };
            late_waker.wake();
        },
        None,
    );
}
/// Races a waker invocation against the poll that registers it.
///
/// A helper thread raises a flag and then fires whatever waker is currently registered. The
/// main thread blocks on a future that re-registers its waker on every poll and completes once
/// it observes the flag. `check_dfs` explores the interleavings of the two threads.
#[test]
fn wake_during_poll() {
    check_dfs(
        || {
            let slot: Arc<Mutex<Option<Waker>>> = Arc::new(Mutex::new(None));
            let ready = Arc::new(AtomicBool::new(false));

            thread::spawn({
                let slot = Arc::clone(&slot);
                let ready = Arc::clone(&ready);
                move || {
                    ready.store(true, Ordering::SeqCst);
                    // The lock guard is a temporary of the `if let` scrutinee, so it is still
                    // held while the registered waker is invoked.
                    if let Some(registered) = slot.lock().unwrap().take() {
                        registered.wake();
                    }
                }
            });

            let until_ready = poll_fn(move |cx| {
                *slot.lock().unwrap() = Some(cx.waker().clone());
                if !ready.load(Ordering::SeqCst) {
                    return Poll::Pending;
                }
                Poll::Ready(())
            });
            future::block_on(until_ready);
        },
        None,
    );
}
/// Fires a waker while the poll that registered it is blocked on a contended lock.
///
/// One helper thread holds the `counter` lock across a yield. The future first registers its
/// waker and then takes the same lock, which gives a second helper thread the chance to invoke
/// that waker in the middle of the poll. The static flag confirms that at least one explored
/// execution actually exercised the wake path.
#[test]
fn wake_during_blocked_poll() {
    // A plain `std` atomic, so it can still be read after `check_dfs` has returned.
    static RAN_WAKER: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);

    check_dfs(
        || {
            let slot: Arc<Mutex<Option<Waker>>> = Arc::new(Mutex::new(None));
            let counter = Arc::new(Mutex::new(0));

            // Holds the counter lock across a yield so the poll below can block on it.
            thread::spawn({
                let counter = Arc::clone(&counter);
                move || {
                    let mut held = counter.lock().unwrap();
                    thread::yield_now();
                    *held += 1;
                }
            });

            // Invokes whichever waker has been registered so far, if any.
            thread::spawn({
                let slot = Arc::clone(&slot);
                move || {
                    if let Some(registered) = slot.lock().unwrap().take() {
                        RAN_WAKER.store(true, Ordering::SeqCst);
                        registered.wake();
                    }
                }
            });

            future::block_on(poll_fn(move |cx| {
                *slot.lock().unwrap() = Some(cx.waker().clone());
                *counter.lock().unwrap() += 1;
                Poll::Ready(())
            }));
        },
        None,
    );
    assert!(RAN_WAKER.load(Ordering::SeqCst), "waker was not invoked by any test");
}