use std::cell::Cell;
use std::mem;
use std::ops::Deref;
use std::rc::Rc;
use std::task::{RawWaker, RawWakerVTable, Waker};
/// A wake-up handler that is shared through an [`Rc`].
///
/// Implementors only have to provide [`RcWake::wake`]; a by-reference variant
/// is supplied on top of it.
pub trait RcWake {
    /// Performs the wake-up, consuming one strong reference to the handler.
    fn wake(self: Rc<Self>);

    /// Performs the wake-up while leaving the caller's reference untouched.
    ///
    /// The provided implementation takes an additional strong reference and
    /// hands that one to [`RcWake::wake`].
    fn wake_by_ref(self: &Rc<Self>) {
        let owned = Rc::clone(self);
        Self::wake(owned);
    }
}
/// Wraps a reference-counted [`RcWake`] handler in a standard [`Waker`].
///
/// The strong reference held by `waker` moves into the returned `Waker`.
/// Cloning the `Waker` takes another strong reference, and dropping it (or
/// calling `Waker::wake`) gives one back.
///
/// # Caveats
///
/// `Waker` is `Send + Sync`, whereas `Rc` reference counting is not
/// thread-safe. The returned waker and all of its clones therefore have to
/// stay on the thread that created them.
///
/// NOTE(review): `W` is not required to be `'static` although `Waker` carries
/// no lifetime; callers must make sure `W` does not borrow data that can go
/// away while a waker is still alive. Consider adding a `'static` bound.
pub fn into_std<W: RcWake>(waker: Rc<W>) -> Waker {
    let raw = raw_waker(waker);
    // SAFETY: `raw_waker` pairs a pointer obtained from `Rc::into_raw` with
    // vtable functions that reinterpret it as an `Rc<W>`, keeping the strong
    // count balanced across clone / wake / drop.
    unsafe { Waker::from_raw(raw) }
}
/// Turns an `Rc<W>` into a [`RawWaker`] whose data pointer is the result of
/// `Rc::into_raw` and whose vtable maps the waker operations onto `Rc`
/// reference counting and [`RcWake`].
///
/// Every waker produced for a given `W` (including clones) obtains its vtable
/// from the single `vtable::<W>()` helper below rather than from separately
/// written vtable expressions, so a waker and its clones refer to the same
/// table. This matters for `Waker::will_wake`, which may compare vtables by
/// address.
#[inline]
fn raw_waker<W: RcWake>(waker: Rc<W>) -> RawWaker {
    // The one place where the vtable for `W` is spelled out.
    fn vtable<W: RcWake>() -> &'static RawWakerVTable {
        &RawWakerVTable::new(
            clone_waker::<W>,
            wake::<W>,
            wake_by_ref::<W>,
            drop_waker::<W>,
        )
    }

    // Takes one more strong reference and returns a waker owning it.
    unsafe fn clone_waker<W: RcWake>(waker: *const ()) -> RawWaker {
        // SAFETY: `waker` came from `Rc::into_raw` on an `Rc<W>` that is still
        // alive. `ManuallyDrop` keeps the reference owned by the existing
        // waker from being released here.
        let borrowed = mem::ManuallyDrop::new(unsafe { Rc::from_raw(waker as *const W) });
        let cloned = Rc::clone(&borrowed);
        RawWaker::new(Rc::into_raw(cloned) as *const (), vtable::<W>())
    }

    // Consumes the waker's strong reference by passing it to `RcWake::wake`.
    unsafe fn wake<W: RcWake>(waker: *const ()) {
        // SAFETY: `waker` came from `Rc::into_raw`, and the waker owning that
        // reference is being consumed by this call.
        let owned = unsafe { Rc::from_raw(waker as *const W) };
        <W as RcWake>::wake(owned);
    }

    // Wakes through a borrowed view; the waker keeps its strong reference.
    unsafe fn wake_by_ref<W: RcWake>(waker: *const ()) {
        // SAFETY: as in `clone_waker`; `ManuallyDrop` leaves the count alone.
        let borrowed = mem::ManuallyDrop::new(unsafe { Rc::from_raw(waker as *const W) });
        <W as RcWake>::wake_by_ref(&borrowed);
    }

    // Releases the strong reference owned by the waker being dropped.
    unsafe fn drop_waker<W: RcWake>(waker: *const ()) {
        // SAFETY: `waker` came from `Rc::into_raw`, and this waker's reference
        // is released exactly once, here.
        drop(unsafe { Rc::from_raw(waker as *const W) });
    }

    RawWaker::new(Rc::into_raw(waker) as *const (), vtable::<W>())
}
/// A wake-up handler that is invoked through a shared reference.
pub trait RefWake {
/// Performs the wake-up.
fn wake(&self);
}
/// Storage for a wake-up handler together with the reference counter that
/// `RefWaker` reads and updates.
pub struct RefWakerData<W> {
// Reference counter; `new` starts it at 1 and the `RefWaker` vtable
// functions increment/decrement it.
refs: Cell<usize>,
// The wrapped handler.
w: W,
}
impl<W: RefWake> RefWakerData<W> {
    /// Wraps the handler `w`, starting the reference counter at 1.
    pub fn new(w: W) -> Self {
        let refs = Cell::new(1);
        Self { refs, w }
    }
}
/// A handle over borrowed `RefWakerData` from which standard wakers can be
/// produced without allocating.
pub struct RefWaker<'a, W: RefWake + 'a> {
// The borrowed handler and its reference counter.
data: &'a RefWakerData<W>,
}
impl<'a, W: RefWake + 'a> RefWaker<'a, W> {
    /// Creates a handle over `data`.
    pub fn new(data: &'a RefWakerData<W>) -> Self {
        Self { data }
    }

    /// Returns the current value of the reference counter stored in the data.
    pub fn ref_count(&self) -> usize {
        self.data.refs.get()
    }

    /// Writes a [`Waker`] pointing at the borrowed data into `scratch` and
    /// returns a reference to it.
    ///
    /// The waker placed in `scratch` is created without touching the reference
    /// counter, and since it lives in a `MaybeUninit` it is never dropped by
    /// this code. Only clones made from it are counted. Any value previously
    /// stored in `scratch` is overwritten without being dropped.
    ///
    /// NOTE(review): the returned lifetime `'b` is tied to `scratch` only, not
    /// to `self` or `'a`; callers must not use the returned waker after the
    /// underlying `RefWakerData` is gone.
    pub fn as_std<'b>(&self, scratch: &'b mut mem::MaybeUninit<Waker>) -> &'b Waker {
        let raw = RawWaker::new(
            self.data as *const RefWakerData<W> as *const (),
            Self::vtable(),
        );
        // SAFETY: the data pointer comes from a live `&RefWakerData<W>` and
        // the vtable functions below only ever reinterpret it as that type.
        let waker = unsafe { Waker::from_raw(raw) };
        // `write` hands back a reference to the freshly initialized value.
        scratch.write(waker)
    }

    /// The single vtable used by `as_std` and by every clone, so that all
    /// wakers derived from the same data refer to the same table (relevant
    /// for `Waker::will_wake`).
    fn vtable() -> &'static RawWakerVTable {
        &RawWakerVTable::new(Self::clone, Self::wake, Self::wake_by_ref, Self::drop)
    }

    /// Reinterprets a vtable data pointer as the `RefWakerData` it was built
    /// from.
    ///
    /// # Safety
    ///
    /// `data` must be null or point to a live `RefWakerData<W>`.
    ///
    /// # Panics
    ///
    /// Panics if `data` is null.
    unsafe fn data_ref(data: *const ()) -> &'a RefWakerData<W> {
        // SAFETY: guaranteed by the caller, see above.
        unsafe { (data as *const RefWakerData<W>).as_ref() }
            .expect("RefWaker vtable called with a null data pointer")
    }

    /// Vtable `clone`: bumps the reference counter and returns a waker over
    /// the same data.
    ///
    /// Panics if the counter would overflow instead of silently wrapping,
    /// because a wrapped counter would defeat the check done when the
    /// `RefWaker` is dropped.
    unsafe fn clone(data: *const ()) -> RawWaker {
        // SAFETY: `data` is the pointer stored by `as_std` or a prior `clone`.
        let shared = unsafe { Self::data_ref(data) };
        let bumped = shared
            .refs
            .get()
            .checked_add(1)
            .expect("RefWaker reference count overflowed");
        shared.refs.set(bumped);
        RawWaker::new(data, Self::vtable())
    }

    /// Vtable `wake`: wakes the handler, then releases the reference owned by
    /// the consumed waker.
    unsafe fn wake(data: *const ()) {
        // SAFETY: same pointer contract as the two functions being called.
        unsafe {
            Self::wake_by_ref(data);
            Self::drop(data);
        }
    }

    /// Vtable `wake_by_ref`: wakes the handler without changing the counter.
    unsafe fn wake_by_ref(data: *const ()) {
        // SAFETY: `data` is the pointer stored by `as_std` or `clone`.
        let shared = unsafe { Self::data_ref(data) };
        shared.w.wake();
    }

    /// Vtable `drop`: releases one counted reference.
    ///
    /// Panics if the counter is not above 1, i.e. if there is no counted
    /// clone left to release.
    unsafe fn drop(data: *const ()) {
        // SAFETY: `data` is the pointer stored by `as_std` or `clone`.
        let shared = unsafe { Self::data_ref(data) };
        let refs = shared.refs.get();
        assert!(refs > 1, "RefWaker reference count underflow");
        shared.refs.set(refs - 1);
    }
}
impl<'a, W: RefWake + 'a> Drop for RefWaker<'a, W> {
    /// Checks that the reference counter is back at exactly 1 when the handle
    /// goes away, and panics otherwise.
    fn drop(&mut self) {
        let remaining = self.data.refs.get();
        assert_eq!(remaining, 1);
    }
}
impl<'a, W: RefWake + 'a> Deref for RefWaker<'a, W> {
    type Target = W;

    /// Gives access to the wrapped handler stored in the borrowed data.
    fn deref(&self) -> &Self::Target {
        let handler: &W = &self.data.w;
        handler
    }
}
// Unit tests for both waker flavors. Each test follows the same script:
// create a waker, clone it, wake by value, wake by reference, drop, and
// check the reference count and wake counter after every step.
#[cfg(test)]
mod tests {
use super::*;
use std::cell::Cell;
// Handler that simply counts how many times it has been woken.
struct TestWaker {
waked: Cell<u32>,
}
impl TestWaker {
fn new() -> Self {
TestWaker {
waked: Cell::new(0),
}
}
// Number of wake-ups observed so far.
fn waked(&self) -> u32 {
self.waked.get()
}
}
impl RcWake for TestWaker {
fn wake(self: Rc<Self>) {
self.waked.set(self.waked.get() + 1);
}
}
impl RefWake for TestWaker {
fn wake(&self) {
self.waked.set(self.waked.get() + 1);
}
}
#[test]
fn test_rc_waker() {
let data = Rc::new(TestWaker::new());
assert_eq!(Rc::strong_count(&data), 1);
// The waker owns the strong reference it was given.
let waker = into_std(data.clone());
assert_eq!(Rc::strong_count(&data), 2);
// Cloning the waker takes another strong reference.
let waker2 = waker.clone();
assert_eq!(Rc::strong_count(&data), 3);
assert_eq!(data.waked(), 0);
// Waking by value consumes the clone and its reference.
waker2.wake();
assert_eq!(Rc::strong_count(&data), 2);
assert_eq!(data.waked(), 1);
// Waking by reference leaves the count unchanged afterwards.
waker.wake_by_ref();
assert_eq!(Rc::strong_count(&data), 2);
assert_eq!(data.waked(), 2);
mem::drop(waker);
assert_eq!(Rc::strong_count(&data), 1);
}
#[test]
fn test_ref_waker() {
let data = RefWakerData::new(TestWaker::new());
let data = RefWaker::new(&data);
assert_eq!(RefWaker::ref_count(&data), 1);
// The scratch waker is not counted; only the clone taken from it is.
let waker = data.as_std(&mut mem::MaybeUninit::uninit()).clone();
assert_eq!(RefWaker::ref_count(&data), 2);
let waker2 = waker.clone();
assert_eq!(RefWaker::ref_count(&data), 3);
// `waked` is reached through `Deref` on the `RefWaker`.
assert_eq!(data.waked(), 0);
// Waking by value releases the clone's counted reference.
waker2.wake();
assert_eq!(RefWaker::ref_count(&data), 2);
assert_eq!(data.waked(), 1);
waker.wake_by_ref();
assert_eq!(RefWaker::ref_count(&data), 2);
assert_eq!(data.waked(), 2);
// The count must be back at 1 before `data` (the `RefWaker`) is dropped.
mem::drop(waker);
assert_eq!(RefWaker::ref_count(&data), 1);
}
}