#![no_std]
extern crate alloc;
use alloc::boxed::Box;
use alloc::rc;
use alloc::sync as arc;
use core::{
mem::ManuallyDrop,
ptr,
task::{RawWaker, RawWakerVTable, Waker},
};
/// Conversion of a value into a raw pointer and back again.
///
/// `into_raw` consumes the value and yields a `*mut Self::Target`;
/// `from_raw` rebuilds a value from such a pointer.
///
/// # Safety
///
/// NOTE(review): the exact contract implementors must uphold is not written
/// down in this file. The blanket `IntoWaker` impl debug-asserts that a value
/// rebuilt with `from_raw(ptr)` hands `ptr` back from `into_raw`, so
/// implementations presumably must round-trip the same pointer — confirm
/// against the crate-level documentation.
pub unsafe trait ViaRawPointer {
/// Pointee type of the raw pointer produced by `into_raw`.
type Target: ?Sized;
/// Consumes `self` and returns a raw pointer to `Self::Target`.
fn into_raw(self) -> *mut Self::Target;
/// Rebuilds a value from a raw pointer.
///
/// # Safety
///
/// NOTE(review): presumably `ptr` must have been obtained from `into_raw`
/// on the same implementing type — confirm.
unsafe fn from_raw(ptr: *mut Self::Target) -> Self;
}
/// A value that can be woken through a shared reference.
///
/// The blanket `IntoWaker` impl in this file calls `wake_by_ref` from the
/// `wake_by_ref` slot of the vtable it builds.
pub trait WakeRef {
/// Performs the wakeup without consuming `self`.
fn wake_by_ref(&self);
}
pub trait Wake: WakeRef + Sized {
#[inline]
fn wake(self) {
self.wake_by_ref()
}
}
/// Conversion of a value into a [`Waker`].
pub trait IntoWaker {
/// Vtable paired with the data pointer of the raw waker built by the
/// implementation. Hidden from the rendered documentation.
#[doc(hidden)]
const VTABLE: &'static RawWakerVTable;
/// Consumes `self` and returns a [`Waker`].
#[must_use]
fn into_waker(self) -> Waker;
}
/// Blanket [`IntoWaker`] implementation for every cloneable, thread-safe
/// [`Wake`] type that can round-trip through a thin raw pointer.
///
/// The waker's data pointer is whatever `T::into_raw` returns, cast to
/// `*const ()`. Each vtable entry casts it back and rebuilds a `T` with
/// `T::from_raw`.
impl<T> IntoWaker for T
where
    T: Wake + Clone + Send + Sync + 'static + ViaRawPointer,
    T::Target: Sized,
{
    const VTABLE: &'static RawWakerVTable = &RawWakerVTable::new(
        // clone: duplicate the value without giving up the caller's copy.
        |data| {
            let target = data as *mut T::Target;
            // SAFETY: `data` is only ever paired with this vtable after coming
            // out of `T::into_raw` (see `into_waker` and the end of this
            // closure). `ManuallyDrop` keeps the borrowed value from being
            // dropped here, including if `clone` unwinds.
            let borrowed = ManuallyDrop::new(unsafe { <T as ViaRawPointer>::from_raw(target) });
            let duplicate: T = Clone::clone(&*borrowed);
            // The pointer cannot be written back into the existing raw waker,
            // so debug builds only check that it has not changed.
            debug_assert_eq!(ManuallyDrop::into_inner(borrowed).into_raw(), target);
            RawWaker::new(duplicate.into_raw() as *const (), T::VTABLE)
        },
        // wake: take ownership of the value and wake by value.
        |data| {
            // SAFETY: same pairing as above; ownership moves into `owned`.
            let owned: T = unsafe { ViaRawPointer::from_raw(data as *mut T::Target) };
            owned.wake();
        },
        // wake_by_ref: wake through a borrow, leaving ownership with the waker.
        |data| {
            let target = data as *mut T::Target;
            // SAFETY: same pairing as above. `ManuallyDrop` means the value is
            // not dropped here, even if `wake_by_ref` unwinds.
            let borrowed = ManuallyDrop::new(unsafe { <T as ViaRawPointer>::from_raw(target) });
            borrowed.wake_by_ref();
            debug_assert_eq!(ManuallyDrop::into_inner(borrowed).into_raw(), target);
        },
        // drop: rebuild the value and let it fall out of scope.
        |data| {
            // SAFETY: same pairing as above; the rebuilt value is dropped at once.
            drop(unsafe { <T as ViaRawPointer>::from_raw(data as *mut T::Target) });
        },
    );

    /// Turns `self` into a [`Waker`] whose data pointer is `self.into_raw()`.
    fn into_waker(self) -> Waker {
        let data = self.into_raw() as *const ();
        // SAFETY: `data` comes from `T::into_raw` and is paired with the
        // vtable whose entries rebuild a `T` from exactly that pointer.
        unsafe { Waker::from_raw(RawWaker::new(data, T::VTABLE)) }
    }
}
/// A shared reference wakes whatever it points at.
impl<T: WakeRef + ?Sized> WakeRef for &T {
    /// Forwards to the referenced value's `wake_by_ref`.
    #[inline]
    fn wake_by_ref(&self) {
        let target: &T = *self;
        <T as WakeRef>::wake_by_ref(target)
    }
}
// Waking a reference by value uses the provided `Wake::wake`, which forwards
// to `wake_by_ref`.
impl<T: WakeRef + ?Sized> Wake for &T {}
/// `Box<T>` converts through `Box::into_raw` / `Box::from_raw`.
unsafe impl<T: ?Sized> ViaRawPointer for Box<T> {
    type Target = T;

    /// Releases the box and returns the pointer to its contents.
    fn into_raw(self) -> *mut T {
        let raw: *mut T = Box::into_raw(self);
        raw
    }

    /// Re-creates the box from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of `Box::from_raw`.
    unsafe fn from_raw(ptr: *mut T) -> Self {
        // SAFETY: the caller upholds the `Box::from_raw` requirements.
        unsafe { Box::from_raw(ptr) }
    }
}
/// A box wakes its contents.
impl<T: WakeRef + ?Sized> WakeRef for Box<T> {
    /// Forwards to the boxed value's `wake_by_ref`.
    #[inline]
    fn wake_by_ref(&self) {
        <T as WakeRef>::wake_by_ref(&**self)
    }
}
/// Waking a box by value wakes its contents by value.
impl<T: Wake> Wake for Box<T> {
    /// Moves the value out of the box and calls its own `wake`.
    #[inline]
    fn wake(self) {
        let inner: T = *self;
        <T as Wake>::wake(inner)
    }
}
/// `Arc<T>` converts through `Arc::into_raw` / `Arc::from_raw`.
unsafe impl<T: ?Sized> ViaRawPointer for arc::Arc<T> {
    type Target = T;

    /// Gives up this handle and returns the pointer from `Arc::into_raw`,
    /// cast to `*mut T`.
    fn into_raw(self) -> *mut T {
        let shared: *const T = arc::Arc::into_raw(self);
        shared as *mut T
    }

    /// Re-creates the handle from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of `Arc::from_raw`.
    unsafe fn from_raw(ptr: *mut T) -> Self {
        let shared = ptr as *const T;
        // SAFETY: the caller upholds the `Arc::from_raw` requirements.
        unsafe { arc::Arc::from_raw(shared) }
    }
}
/// An `Arc` wakes the value it shares.
impl<T: WakeRef + ?Sized> WakeRef for arc::Arc<T> {
    /// Forwards to the shared value's `wake_by_ref`.
    #[inline]
    fn wake_by_ref(&self) {
        let shared: &T = &**self;
        <T as WakeRef>::wake_by_ref(shared)
    }
}
// Waking an `Arc` by value uses the provided `Wake::wake`, which forwards to
// `wake_by_ref`.
impl<T: WakeRef + ?Sized> Wake for arc::Arc<T> {}
/// `sync::Weak<T>` converts through `Weak::into_raw` / `Weak::from_raw`.
unsafe impl<T> ViaRawPointer for arc::Weak<T> {
    type Target = T;

    /// Gives up this weak handle and returns the pointer from
    /// `Weak::into_raw`, cast to `*mut T`.
    fn into_raw(self) -> *mut T {
        let weak: *const T = arc::Weak::into_raw(self);
        weak as *mut T
    }

    /// Re-creates the weak handle from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of `Weak::from_raw`.
    unsafe fn from_raw(ptr: *mut T) -> Self {
        let weak = ptr as *const T;
        // SAFETY: the caller upholds the `Weak::from_raw` requirements.
        unsafe { arc::Weak::from_raw(weak) }
    }
}
/// A weak handle wakes its target only while the target is still alive.
impl<T: WakeRef + ?Sized> WakeRef for arc::Weak<T> {
    /// Upgrades the handle and wakes the resulting `Arc`; does nothing when
    /// the upgrade fails.
    #[inline]
    fn wake_by_ref(&self) {
        if let Some(strong) = self.upgrade() {
            strong.wake()
        }
    }
}
// Waking a `sync::Weak` by value uses the provided `Wake::wake`, which
// forwards to `wake_by_ref`.
impl<T: WakeRef + ?Sized> Wake for arc::Weak<T> {}
/// An `Rc` wakes the value it shares.
impl<T: WakeRef + ?Sized> WakeRef for rc::Rc<T> {
    /// Forwards to the shared value's `wake_by_ref`.
    #[inline]
    fn wake_by_ref(&self) {
        let shared: &T = &**self;
        <T as WakeRef>::wake_by_ref(shared)
    }
}
/// `Rc<T>` converts through `Rc::into_raw` / `Rc::from_raw`.
unsafe impl<T: ?Sized> ViaRawPointer for rc::Rc<T> {
    type Target = T;

    /// Gives up this handle and returns the pointer from `Rc::into_raw`,
    /// cast to `*mut T`.
    fn into_raw(self) -> *mut T {
        let shared: *const T = rc::Rc::into_raw(self);
        shared as *mut T
    }

    /// Re-creates the handle from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of `Rc::from_raw`.
    unsafe fn from_raw(ptr: *mut T) -> Self {
        let shared = ptr as *const T;
        // SAFETY: the caller upholds the `Rc::from_raw` requirements.
        unsafe { rc::Rc::from_raw(shared) }
    }
}
/// Waking an `Rc` by value wakes the shared value by reference, then releases
/// this handle.
impl<T: WakeRef + ?Sized> Wake for rc::Rc<T> {
    /// Calls the shared value's `wake_by_ref`; `self` is dropped on return.
    #[inline]
    fn wake(self) {
        let shared: &T = &*self;
        <T as WakeRef>::wake_by_ref(shared)
    }
}
/// `rc::Weak<T>` converts through `Weak::into_raw` / `Weak::from_raw`.
unsafe impl<T> ViaRawPointer for rc::Weak<T> {
    type Target = T;

    /// Gives up this weak handle and returns the pointer from
    /// `Weak::into_raw`, cast to `*mut T`.
    fn into_raw(self) -> *mut T {
        let weak: *const T = rc::Weak::into_raw(self);
        weak as *mut T
    }

    /// Re-creates the weak handle from `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must satisfy the requirements of `Weak::from_raw`.
    unsafe fn from_raw(ptr: *mut T) -> Self {
        let weak = ptr as *const T;
        // SAFETY: the caller upholds the `Weak::from_raw` requirements.
        unsafe { rc::Weak::from_raw(weak) }
    }
}
/// A weak handle wakes its target only while the target is still alive.
impl<T: WakeRef + ?Sized> WakeRef for rc::Weak<T> {
    /// Upgrades the handle and wakes the resulting `Rc`; does nothing when
    /// the upgrade fails.
    #[inline]
    fn wake_by_ref(&self) {
        if let Some(strong) = self.upgrade() {
            strong.wake()
        }
    }
}
// Waking an `rc::Weak` by value uses the provided `Wake::wake`, which
// forwards to `wake_by_ref`.
impl<T: WakeRef + ?Sized> Wake for rc::Weak<T> {}
/// `Option<T>` reuses `T`'s pointer representation and encodes `None` as the
/// null pointer.
unsafe impl<T: ViaRawPointer> ViaRawPointer for Option<T>
where
    T::Target: Sized,
{
    type Target = T::Target;

    /// Returns null for `None`, otherwise the pointer from `T::into_raw`.
    ///
    /// If `T::into_raw` itself returns null, the value is rebuilt and dropped
    /// straight away and null is returned.
    fn into_raw(self) -> *mut Self::Target {
        let raw = match self {
            None => return ptr::null_mut(),
            Some(inner) => inner.into_raw(),
        };
        if raw.is_null() {
            // SAFETY: `raw` was produced by `T::into_raw` on the line above.
            drop(unsafe { T::from_raw(raw) });
            return ptr::null_mut();
        }
        raw
    }

    /// Maps null to `None` and any other pointer to `Some(T::from_raw(ptr))`.
    ///
    /// # Safety
    ///
    /// A non-null `ptr` must satisfy the requirements of `T::from_raw`.
    unsafe fn from_raw(ptr: *mut Self::Target) -> Self {
        if ptr.is_null() {
            None
        } else {
            // SAFETY: the caller upholds the `T::from_raw` requirements.
            Some(unsafe { T::from_raw(ptr) })
        }
    }
}
/// An optional waker wakes its contents when present.
impl<T: WakeRef> WakeRef for Option<T> {
    /// Forwards to the contained value's `wake_by_ref`; `None` is a no-op.
    #[inline]
    fn wake_by_ref(&self) {
        match self {
            Some(inner) => T::wake_by_ref(inner),
            None => {}
        }
    }
}
/// Waking an optional waker by value wakes its contents by value.
impl<T: Wake> Wake for Option<T> {
    /// Forwards to the contained value's `wake`; `None` is a no-op.
    #[inline]
    fn wake(self) {
        match self {
            Some(inner) => T::wake(inner),
            None => {}
        }
    }
}
/// A standard [`Waker`] is itself a `WakeRef`.
impl WakeRef for Waker {
    /// Forwards to the inherent `Waker::wake_by_ref`.
    #[inline]
    fn wake_by_ref(&self) {
        let waker: &Waker = self;
        // Path syntax picks the inherent method rather than this trait method.
        Waker::wake_by_ref(waker);
    }
}
/// A standard [`Waker`] is itself a `Wake`.
impl Wake for Waker {
    /// Forwards to the inherent, consuming `Waker::wake`.
    #[inline]
    fn wake(self) {
        let waker: Waker = self;
        // Path syntax picks the inherent method rather than this trait method.
        Waker::wake(waker);
    }
}
// Unit test checking how the blanket `IntoWaker` vtable treats ownership when
// a wake call panics.
#[cfg(test)]
mod test {
extern crate std;
use super::*;
use std::panic;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Waker;
// Global counters bumped by `PanicWaker`: calls to `wake_by_ref`, calls to
// the by-value `wake`, and runs of `Drop`.
static PANIC_WAKE_REF_COUNT: AtomicUsize = AtomicUsize::new(0);
static PANIC_WAKE_VALUE_COUNT: AtomicUsize = AtomicUsize::new(0);
static PANIC_DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
// Stateless waker: `wake_by_ref` counts and then panics, `wake` only counts.
#[derive(Debug, Clone)]
struct PanicWaker;
impl WakeRef for PanicWaker {
fn wake_by_ref(&self) {
PANIC_WAKE_REF_COUNT.fetch_add(1, Ordering::SeqCst);
panic!();
}
}
impl Wake for PanicWaker {
fn wake(self) {
PANIC_WAKE_VALUE_COUNT.fetch_add(1, Ordering::SeqCst);
}
}
impl Drop for PanicWaker {
fn drop(&mut self) {
PANIC_DROP_COUNT.fetch_add(1, Ordering::SeqCst);
}
}
// The pointer carries no information: `into_raw` forgets the value (so its
// `Drop` does not run) and returns null; `from_raw` ignores the pointer and
// produces a fresh `PanicWaker`.
unsafe impl ViaRawPointer for PanicWaker {
type Target = ();
fn into_raw(self) -> *mut () {
std::mem::forget(self);
std::ptr::null_mut()
}
unsafe fn from_raw(_ptr: *mut ()) -> Self {
PanicWaker
}
}
#[test]
fn panic_wake() {
assert_eq!(PANIC_DROP_COUNT.load(Ordering::SeqCst), 0);
let waker = PanicWaker;
{
let waker1: Waker = waker.into_waker();
let waker2: Waker = waker1.clone();
// `wake_by_ref` panics, but the vtable entry holds the rebuilt value in
// `ManuallyDrop`, so unwinding must not run `PanicWaker`'s `Drop`.
let result = panic::catch_unwind(|| {
waker2.wake_by_ref();
});
assert!(result.is_err());
assert_eq!(PANIC_WAKE_REF_COUNT.load(Ordering::SeqCst), 1);
assert_eq!(PANIC_DROP_COUNT.load(Ordering::SeqCst), 0);
// Same expectation for the original waker.
let result = panic::catch_unwind(|| {
waker1.wake_by_ref();
});
assert!(result.is_err());
assert_eq!(PANIC_WAKE_REF_COUNT.load(Ordering::SeqCst), 2);
assert_eq!(PANIC_DROP_COUNT.load(Ordering::SeqCst), 0);
// Waking by value consumes `waker1`: the non-panicking `wake` runs and
// the rebuilt value is dropped exactly once.
let result = panic::catch_unwind(|| {
waker1.wake();
});
assert!(result.is_ok());
assert_eq!(PANIC_WAKE_VALUE_COUNT.load(Ordering::SeqCst), 1);
assert_eq!(PANIC_DROP_COUNT.load(Ordering::SeqCst), 1);
}
// Leaving the block drops `waker2`, which accounts for the second drop.
assert_eq!(PANIC_DROP_COUNT.load(Ordering::SeqCst), 2);
}
}