#![warn(clippy::undocumented_unsafe_blocks)]
use std::cell::UnsafeCell;
use std::mem::ManuallyDrop;
use std::ptr::{self, NonNull};
use std::task::{Context, RawWaker, RawWakerVTable, Waker};
use crate::lock::RwLock;
use crate::task::Header;
/// Polls using `f` with a [`Context`] whose waker points at the task `header`.
///
/// The waker is wrapped in [`ManuallyDrop`] because it *borrows* the header's
/// reference count rather than owning one: no `increment_ref` is performed
/// here, so letting the waker drop would run the vtable's `drop` and release
/// a reference this function never took.
pub(crate) fn poll_with_ref<F, R>(header: NonNull<Header>, f: F) -> R
where
    F: FnOnce(&mut Context<'_>) -> R,
{
    // SAFETY: the vtable produced by `header_to_raw_waker` upholds the
    // `RawWaker` contract, and `ManuallyDrop` guarantees the vtable's `drop`
    // never runs for this borrowed, non-counted waker.
    let waker = ManuallyDrop::new(unsafe { Waker::from_raw(header_to_raw_waker(header)) });
    let mut cx = Context::from_waker(&waker);
    f(&mut cx)
}
/// Builds a [`RawWaker`] whose data pointer is the raw task `header`.
///
/// The vtable manages the header's reference count: `clone` increments it,
/// `wake` and `drop` decrement it (deallocating the header when
/// `decrement_ref` reports the count is exhausted), and `wake_by_ref` leaves
/// it untouched.
fn header_to_raw_waker(header: NonNull<Header>) -> RawWaker {
    const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone, wake, wake_by_ref, drop);

    /// Clone the waker by taking a new reference to the task header.
    ///
    /// SAFETY (caller): `ptr` must point to a live `Header`.
    unsafe fn clone(ptr: *const ()) -> RawWaker {
        (*ptr.cast::<Header>()).increment_ref();
        RawWaker::new(ptr, VTABLE)
    }

    /// Wake by value: record the task's index in the shared wake set, notify
    /// the shared waker, then release the reference this waker held.
    ///
    /// SAFETY (caller): `ptr` must point to a live `Header`, and this call
    /// consumes one reference count.
    unsafe fn wake(ptr: *const ()) {
        let header = NonNull::new_unchecked(ptr.cast_mut().cast::<Header>());
        // Scope the shared borrow so it has ended before the header is
        // potentially deallocated below.
        {
            let header = header.as_ref();
            let shared = header.shared();
            shared.wake_set.wake(header.index());
            shared.waker.wake_by_ref();
        }
        // Waking by value consumes the waker, so drop the reference it held;
        // `decrement_ref` returning `true` is taken to mean this was the
        // last reference (inferred from usage here and in `drop`).
        if header.as_ref().decrement_ref() {
            Header::deallocate(header);
        }
    }

    /// Wake without consuming the waker: no reference-count change.
    ///
    /// SAFETY (caller): `ptr` must point to a live `Header`.
    unsafe fn wake_by_ref(ptr: *const ()) {
        let header = &*(ptr as *const Header);
        let shared = header.shared();
        shared.wake_set.wake(header.index());
        shared.waker.wake_by_ref();
    }

    /// Drop the waker: release its reference, deallocating on the last one.
    ///
    /// SAFETY (caller): `ptr` must point to a live `Header`, and this call
    /// consumes one reference count.
    unsafe fn drop(ptr: *const ()) {
        let header = NonNull::new_unchecked(ptr.cast_mut().cast::<Header>());
        if header.as_ref().decrement_ref() {
            Header::deallocate(header);
        }
    }

    RawWaker::new(header.as_ptr().cast_const().cast(), VTABLE)
}
/// A [`Waker`] guarded by a custom [`RwLock`], so it can be woken under a
/// shared lock and replaced under an exclusive lock.
pub(crate) struct SharedWaker {
    /// Guards `waker`: shared to wake it, exclusive to swap it out.
    lock: RwLock,
    /// The currently registered waker; starts out as a no-op waker.
    /// Note that `swap` performs one unsynchronized read of this cell.
    waker: UnsafeCell<Waker>,
}
impl SharedWaker {
    /// Constructs a shared waker initially holding a no-op waker.
    pub(crate) fn new() -> Self {
        Self {
            lock: RwLock::new(),
            waker: UnsafeCell::new(noop_waker()),
        }
    }

    /// Wakes the stored waker by reference.
    ///
    /// Best effort: if the shared lock cannot be acquired the wake is
    /// silently skipped. NOTE(review): the lock is held exclusively only
    /// while `swap` installs a new waker, so a skipped wake presumably
    /// coincides with an imminent re-poll — confirm against the callers.
    pub(crate) fn wake_by_ref(&self) {
        if let Some(_guard) = self.lock.try_lock_shared() {
            // SAFETY: holding the shared lock ensures no exclusive writer is
            // mutating the waker while we read it.
            let waker = unsafe { &*self.waker.get() };
            waker.wake_by_ref();
        }
    }

    /// Swaps the stored waker for `waker` unless the stored one would
    /// already wake the same task.
    ///
    /// Returns `true` when the stored waker is up to date after the call.
    /// Returns `false` when the exclusive lock could not be taken; in that
    /// case `waker` is woken so the caller gets polled again and can retry.
    ///
    /// # Safety
    ///
    /// The initial `will_wake` check reads the stored waker *without*
    /// holding the lock, so the caller must guarantee `swap` is never
    /// invoked concurrently with another `swap` on the same instance
    /// (NOTE(review): this contract is inferred from the unguarded read —
    /// confirm against the call sites).
    pub(crate) unsafe fn swap(&self, waker: &Waker) -> bool {
        let shared_waker = self.waker.get();
        // Fast path: nothing to do if the stored waker already wakes the
        // same task.
        if (*shared_waker).will_wake(waker) {
            return true;
        }
        if let Some(_guard) = self.lock.try_lock_exclusive_guard() {
            // Exclusive lock held: no reader can observe the waker while it
            // is replaced. `clone_from` may reuse the existing allocation.
            (*self.waker.get()).clone_from(waker);
            return true;
        }
        // Could not lock: wake the caller so it polls — and retries — again.
        waker.wake_by_ref();
        false
    }
}
/// Builds a [`Waker`] that does nothing when woken.
///
/// Used as the initial value stored in `SharedWaker` before a real waker has
/// been registered.
fn noop_waker() -> Waker {
    // SAFETY: the no-op vtable trivially satisfies the `RawWaker` contract;
    // every operation ignores the (null) data pointer.
    unsafe { Waker::from_raw(noop_raw_waker()) }
}

/// Builds a [`RawWaker`] whose vtable entries all do nothing.
fn noop_raw_waker() -> RawWaker {
    /// Cloning a no-op waker just produces another no-op waker.
    fn noop_clone(_: *const ()) -> RawWaker {
        noop_raw_waker()
    }

    fn noop_wake(_: *const ()) {}

    fn noop_wake_by_ref(_: *const ()) {}

    fn noop_drop(_: *const ()) {}

    // Placing the helper items before the trailing expression removes the
    // explicit `return` the original needed, and naming the vtable makes the
    // promoted `'static` constant explicit.
    const VTABLE: RawWakerVTable =
        RawWakerVTable::new(noop_clone, noop_wake, noop_wake_by_ref, noop_drop);

    RawWaker::new(ptr::null(), &VTABLE)
}
#[cfg(test)]
mod test {
    use crate::task::Storage;
    use crate::FuturesUnordered;
    use super::poll_with_ref;
    use futures::future::poll_fn;
    use futures::Future;
    use std::cell::RefCell;
    use std::mem;
    use std::pin::Pin;
    use std::rc::Rc;
    use std::task::Context;
    use std::task::{Poll, Waker};

    /// `poll_with_ref` can build a context around a freshly stored task
    /// header without panicking.
    #[test]
    fn basic_waker() {
        let mut slab = Storage::new();
        slab.insert(());
        let (header, _) = slab.get_pin_mut(0).unwrap();
        poll_with_ref(header, |_| ())
    }

    /// Cloning the task waker out of a poll works (exercises the `clone`
    /// vtable entry); the clone is dropped before the executor returns.
    #[test]
    fn clone_waker() {
        /// Future that resolves immediately with a clone of its own waker.
        struct GetWaker;

        impl Future for GetWaker {
            type Output = Waker;

            fn poll(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Self::Output> {
                Poll::Ready(cx.waker().clone())
            }
        }

        block_on::block_on(async {
            let mut futures = FuturesUnordered::new();
            futures.push(GetWaker);
            futures.next().await.unwrap();
        });
    }

    /// A cloned task waker that outlives the futures set must remain usable:
    /// waking it after the set is dropped should not crash or leak
    /// (exercises the vtable's `wake` / `drop` reference counting).
    #[test]
    fn long_lived_waker() {
        /// Future that resolves immediately with a clone of its own waker.
        struct GetWaker;

        impl Future for GetWaker {
            type Output = Waker;

            fn poll(
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Self::Output> {
                Poll::Ready(cx.waker().clone())
            }
        }

        let waker = block_on::block_on(async {
            let mut futures = FuturesUnordered::new();
            futures.push(GetWaker);
            futures.next().await.unwrap()
        });
        // The set is gone by now; this wake targets a task that no longer
        // exists and must be a safe no-op.
        waker.wake();
    }

    /// Wake a pending task via a waker captured earlier, after many more
    /// tasks have been pushed (the 127 extra pushes presumably force the
    /// internal wake set to grow and swap — TODO confirm against the
    /// wake-set implementation).
    #[test]
    fn many_wakers() {
        block_on::block_on(async {
            let mut futures = FuturesUnordered::new();
            // wake1: waker of the first (pending) task, once captured.
            let wake1 = Rc::new(RefCell::new(None));
            // wake2: waker of the last task, if it ran before wake1 was set.
            let wake2: Rc<RefCell<Option<Waker>>> = Rc::new(RefCell::new(None));
            // Flag flipped by the last task to let the first one complete.
            let woken = Rc::new(RefCell::new(false));

            {
                let woken = woken.clone();
                let wake1 = wake1.clone();
                let wake2 = wake2.clone();
                // First task: stays pending until `woken` is set, publishing
                // its waker through `wake1` on the first poll.
                futures.push(FakeDynFuture::new(poll_fn(move |cx| {
                    if *woken.borrow() {
                        Poll::Ready(())
                    } else {
                        if wake1.borrow().is_none() {
                            *wake1.borrow_mut() = Some(cx.waker().clone());
                        }
                        if let Some(waker) = wake2.borrow().as_ref() {
                            waker.wake_by_ref()
                        }
                        Poll::Pending
                    }
                })));
            }

            // Poll the set once so the first task registers its waker.
            poll_fn(|cx| {
                assert_eq!(
                    crate::PollNext::poll_next(Pin::new(&mut futures), cx),
                    Poll::Pending
                );
                Poll::Ready(())
            })
            .await;

            // Pad the set with ready tasks that each clone their waker.
            for _ in 0..127 {
                futures.push(FakeDynFuture::new(poll_fn(|cx| {
                    let _ = cx.waker().clone();
                    Poll::Ready(())
                })));
            }

            // Last task: wakes the first task through the captured `wake1`,
            // or parks itself via `wake2` if `wake1` is not yet available.
            futures.push(FakeDynFuture::new(poll_fn(move |cx| {
                match &*wake1.borrow() {
                    Some(waker) => {
                        *woken.borrow_mut() = true;
                        waker.wake_by_ref();
                        Poll::Ready(())
                    }
                    None => {
                        *wake2.borrow_mut() = Some(cx.waker().clone());
                        Poll::Pending
                    }
                }
            })));

            // Drain the set; completes only if the cross-task wakes work.
            while futures.next().await.is_some() {}
        })
    }

    /// A hand-rolled type-erased future: a boxed future reduced to a raw
    /// pointer plus poll/drop function pointers, used so tests can store
    /// heterogeneous futures without `dyn`.
    struct FakeDynFuture<T> {
        /// The boxed future, erased to a raw pointer; owned by this struct
        /// and released by `drop_fn`.
        future: *const (),
        /// Polls `future` as the concrete future type it was built from.
        poll_fn: fn(this: *const (), cx: &mut Context) -> Poll<T>,
        /// Reconstitutes the box and drops the future.
        drop_fn: fn(this: *const ()),
    }

    impl<T> FakeDynFuture<T> {
        fn new<F: Future<Output = T>>(fut: F) -> Self {
            Self {
                future: Box::into_raw(Box::new(fut)) as *const _,
                poll_fn: |this, cx| {
                    // NOTE(review): test-only hack — transmutes the raw
                    // pointer back into a pinned reference to the concrete
                    // future type; relies on the heap allocation never
                    // moving.
                    let this = unsafe { mem::transmute::<*const (), Pin<&mut F>>(this) };
                    this.poll(cx)
                },
                drop_fn: |this| {
                    // Rebuild the owning box and let it drop immediately.
                    unsafe {
                        mem::transmute::<*const (), Pin<Box<F>>>(this);
                    }
                },
            }
        }
    }

    impl<T> Future for FakeDynFuture<T> {
        type Output = T;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            (self.poll_fn)(self.future, cx)
        }
    }

    impl<T> Drop for FakeDynFuture<T> {
        fn drop(&mut self) {
            (self.drop_fn)(self.future)
        }
    }

    /// Minimal single-threaded executor used by the tests above.
    mod block_on {
        use futures::Future;
        use std::sync::Arc;
        use std::task::{Context, Poll, Wake};
        use std::thread::{self, Thread};

        /// Waker that unparks the thread that created it.
        struct ThreadWaker(Thread);

        impl Wake for ThreadWaker {
            fn wake(self: Arc<Self>) {
                self.0.unpark();
            }
        }

        /// Drives `fut` to completion, parking the current thread between
        /// polls until the waker unparks it.
        pub(super) fn block_on<T>(fut: impl Future<Output = T>) -> T {
            let mut fut = Box::pin(fut);
            let t = thread::current();
            let waker = Arc::new(ThreadWaker(t)).into();
            let mut cx = Context::from_waker(&waker);
            loop {
                match fut.as_mut().poll(&mut cx) {
                    Poll::Ready(res) => return res,
                    Poll::Pending => thread::park(),
                }
            }
        }
    }
}