use crate::backoff::Backoff;
use crate::shared::*;
#[allow(unused_imports)]
use crate::{tokio_task_id, trace_log};
use core::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::{
fence, AtomicU8,
Ordering::{self, AcqRel, Acquire, SeqCst},
};
use std::task::{Context, Poll};
use std::thread;
use std::time::{Duration, Instant};
/// Set by the sender (together with its outcome) once it has sent a value or been dropped;
/// the receiver only treats the channel as ready when this bit is present.
const LOCK_FLAG: u8 = 0b0001;
/// Set by the receiver after it has stored a waker in the shared slot; cleared again when the
/// receiver cancels that waker.
const WAKER_SET_FLAG: u8 = 0b0010;
/// Marks that one side is finished with the shared allocation; the side that finds this bit
/// already set is the one that frees it.
const CLOSE_FLAG: u8 = 0b0100;
/// Set alongside `LOCK_FLAG` when the sender actually stored a value (as opposed to being
/// dropped without sending).
const EXIST_FLAG: u8 = 0b1000;
/// Heap-allocated state shared by one `TxOneshot` and one `RxOneshot`.
///
/// Field order is significant: fields drop in declaration order when the allocation is freed.
struct OneShotInner<T> {
    /// Bit set of `LOCK_FLAG` / `WAKER_SET_FLAG` / `CLOSE_FLAG` / `EXIST_FLAG`.
    state: AtomicU8,
    /// Slot the sender writes its value into; access is coordinated through `state`.
    value: UnsafeCell<Option<T>>,
    /// Slot holding the receiver's waker; access is coordinated through `state`.
    o_waker: UnsafeCell<Option<ThinWaker>>,
}

// SAFETY: the cells are only touched according to the `state` protocol, and the only payload
// that crosses threads is the `T` itself, so `T: Send` is what both impls require.
unsafe impl<T> Send for OneShotInner<T> where T: Send {}
unsafe impl<T> Sync for OneShotInner<T> where T: Send {}
impl<T> OneShotInner<T> {
    /// Allocates a fresh channel state on the heap: no flags set, no value, no waker.
    #[inline]
    fn new() -> Box<Self> {
        Box::new(Self {
            value: UnsafeCell::new(None),
            state: AtomicU8::new(0),
            o_waker: UnsafeCell::new(None),
        })
    }
    /// Returns a mutable reference to the waker slot through `&self`.
    ///
    /// Exclusivity is not enforced here; callers rely on the `state` protocol (only write the
    /// slot while `WAKER_SET_FLAG` is clear). NOTE(review): the receiver reads the slot while
    /// `WAKER_SET_FLAG` is set, and the sender may read it concurrently while waking, so two
    /// `&mut` can coexist (aliasing); a shared-reference accessor for reads would be cleaner.
    #[inline]
    fn get_waker(&self) -> &mut Option<ThinWaker> {
        unsafe { &mut *self.o_waker.get() }
    }
    /// Returns a mutable reference to the value slot through `&self`; the caller must follow
    /// the `state` protocol (sender writes before publishing `LOCK_FLAG`, receiver reads after).
    #[inline(always)]
    fn value_mut(&self) -> &mut Option<T> {
        unsafe { &mut *self.value.get() }
    }
    /// Atomically ORs `flag` into the state and returns the previous state.
    #[inline(always)]
    fn set_state(&self, flag: u8) -> u8 {
        self.state.fetch_or(flag, Ordering::AcqRel)
    }
    /// Loads the state with `order`; `Ok(state)` when the sender has published (`LOCK_FLAG`
    /// set), `Err(state)` otherwise.
    #[inline(always)]
    fn _try_recv(&self, order: Ordering) -> Result<u8, u8> {
        let state = self.state.load(order);
        if state & LOCK_FLAG > 0 {
            Ok(state)
        } else {
            Err(state)
        }
    }
    /// Receiver side: takes the value out (when `EXIST_FLAG` is in `state`) and closes the
    /// receiver's side.
    ///
    /// `state` must already contain `LOCK_FLAG` (debug-asserted). If `CLOSE_FLAG` is (or
    /// becomes) observed as already set, this call frees the allocation behind `p`, so the
    /// caller must not use `p` afterwards. Returns the value, or `None` if the sender was
    /// dropped without sending.
    #[inline(always)]
    fn _consume_value(p: NonNull<Self>, mut state: u8) -> Option<T> {
        debug_assert!(
            state & LOCK_FLAG > 0,
            "oneshot:({:?}) consume value unexpected {state}",
            tokio_task_id!()
        );
        let this = unsafe { p.as_ref() };
        let item = if state & EXIST_FLAG > 0 { this.value_mut().take() } else { None };
        loop {
            // The sender has already finished with the allocation: we are the last user.
            if state & CLOSE_FLAG > 0 {
                trace_log!(
                    "oneshot:({:?}) recv value={} & destroy",
                    tokio_task_id!(),
                    item.is_some()
                );
                fence(Acquire);
                let _ = unsafe { Box::from_raw(p.as_ptr()) };
                return item;
            }
            // Otherwise try to set CLOSE_FLAG so the sender frees it; retry if the sender
            // changed the state in between (e.g. it finished waking and set CLOSE_FLAG).
            if let Err(s) = this.state.compare_exchange(state, CLOSE_FLAG | state, AcqRel, Acquire)
            {
                trace_log!(
                    "oneshot:({:?}) recv value={} {state} close retry",
                    tokio_task_id!(),
                    item.is_some()
                );
                state = s;
            } else {
                trace_log!(
                    "oneshot:({:?}) recv value={} {state}",
                    tokio_task_id!(),
                    item.is_some()
                );
                return item;
            }
        }
    }
    /// Sender side: publishes `LOCK_FLAG` (plus `EXIST_FLAG` when `exist`) and wakes a
    /// registered receiver.
    ///
    /// Returns `true` when the receiver had already closed its side, in which case the
    /// caller is the last user and must free the allocation; returns `false` otherwise.
    /// Panics on a state this protocol never produces.
    #[inline(always)]
    fn _notify_rx(p: NonNull<Self>, exist: bool) -> bool {
        let this = unsafe { p.as_ref() };
        let mut old_state = 0;
        let exist_flag: u8 = if exist { EXIST_FLAG } else { 0 };
        loop {
            let new_state = if old_state == 0 {
                // No waker registered: publish and close the sender's side in one step.
                LOCK_FLAG | CLOSE_FLAG | exist_flag
            } else if old_state == WAKER_SET_FLAG {
                // Waker registered: publish but keep WAKER_SET_FLAG while we read the waker,
                // so the receiver cannot replace it underneath us.
                LOCK_FLAG | WAKER_SET_FLAG | exist_flag
            } else if old_state & CLOSE_FLAG > 0 {
                trace_log!("oneshot:({:?}) rx closed", tokio_task_id!());
                return true;
            } else {
                panic!("unexpected state {}", old_state);
            };
            match this.state.compare_exchange_weak(old_state, new_state, AcqRel, Acquire) {
                Ok(_) => {
                    if old_state == 0 {
                        trace_log!("oneshot:({:?}) send value", tokio_task_id!());
                        return false;
                    } else {
                        if let Some(waker) = this.get_waker().as_ref() {
                            trace_log!("oneshot:({:?}) wake rx", tokio_task_id!());
                            waker.wake_by_ref();
                        } else {
                            unreachable!();
                        }
                        // Done with the waker: drop WAKER_SET_FLAG and close our side. If the
                        // receiver closed meanwhile, the CAS fails and we are the last user.
                        if let Err(state) = this.state.compare_exchange(
                            new_state,
                            CLOSE_FLAG | LOCK_FLAG | exist_flag,
                            AcqRel,
                            Acquire,
                        ) {
                            debug_assert!(state & CLOSE_FLAG > 0, "unexpected state {state}");
                            trace_log!("oneshot:({:?}) rx closed {state}", tokio_task_id!());
                            return true;
                        } else {
                            return false;
                        }
                    }
                }
                Err(s) => {
                    old_state = s;
                }
            }
        }
    }
    /// Receiver side: stores `waker` in the slot, then tries to move the state from 0 to
    /// `WAKER_SET_FLAG`.
    ///
    /// Returns `Err(current_state)` when the state was not 0 (for example the sender already
    /// published); the waker stays stored in the slot in that case.
    #[inline(always)]
    fn set_waker(&self, waker: ThinWaker) -> Result<(), u8> {
        self.get_waker().replace(waker);
        self.state.compare_exchange(0, WAKER_SET_FLAG, AcqRel, Acquire)?;
        Ok(())
    }
    /// Receiver side: withdraws a registered waker by moving `WAKER_SET_FLAG` to 0, or to
    /// `CLOSE_FLAG` when `abandon` is true (receiver gives up, e.g. on timeout).
    ///
    /// Returns `Err(current_state)` when the sender changed the state first.
    #[inline(always)]
    fn cancel_waker(&self, abandon: bool) -> Result<(), u8> {
        let new_state = if abandon { CLOSE_FLAG } else { 0 };
        if let Err(state) = self.state.compare_exchange(WAKER_SET_FLAG, new_state, AcqRel, Acquire)
        {
            return Err(state);
        } else {
            Ok(())
        }
    }
    /// Returns `true` unless `EXIST_FLAG` is set in the current state.
    #[inline(always)]
    fn is_empty(&self) -> bool {
        let state = self.state.load(Ordering::SeqCst);
        state & EXIST_FLAG == 0
    }
}
/// Sending half of a oneshot channel created by [`oneshot`].
///
/// Holds a raw pointer to the allocation shared with the matching [`RxOneshot`].
pub struct TxOneshot<T>(NonNull<OneShotInner<T>>);
// SAFETY: the sender moves a `T` into the shared slot and the receiver may take it out on a
// different thread, so moving the sender to another thread is only sound when `T: Send`
// (otherwise e.g. an `Rc` or `MutexGuard` could be smuggled across threads).
unsafe impl<T: Send> Send for TxOneshot<T> {}
// SAFETY: through `&TxOneshot` only `is_disconnected` is reachable, which performs a single
// atomic load and never touches the `T` slot.
unsafe impl<T> Sync for TxOneshot<T> {}
impl<T> TxOneshot<T> {
    /// Delivers `item` to the receiver, consuming the sender.
    ///
    /// If the receiver has already been dropped, `item` is discarded together with the
    /// shared allocation.
    #[inline]
    pub fn send(self, item: T) {
        let shared = self.0;
        // SAFETY: the allocation is alive until both halves have closed and the sender has
        // not closed yet; the receiver only reads the slot after `LOCK_FLAG` is published.
        let slot = unsafe { shared.as_ref() }.value_mut();
        slot.replace(item);
        let receiver_closed = OneShotInner::_notify_rx(shared, true);
        if receiver_closed {
            // SAFETY: the receiver already closed its side, so we are the last user of the
            // pointer produced by `Box::into_raw`.
            std::mem::drop(unsafe { Box::from_raw(shared.as_ptr()) });
        }
        // The outcome is published; `Drop` must not publish a second (empty) one.
        std::mem::forget(self);
    }

    /// Returns `true` once the receiver side has set `CLOSE_FLAG` (e.g. it was dropped).
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        // SAFETY: the sender keeps the allocation alive while it exists.
        let state = unsafe { self.0.as_ref() }.state.load(Acquire);
        (state & CLOSE_FLAG) != 0
    }
}
impl<T> Drop for TxOneshot<T> {
    /// Runs when the sender goes away without `send` (which forgets `self`): publishes a
    /// "no value" outcome and frees the allocation if the receiver is already gone.
    #[inline]
    fn drop(&mut self) {
        let shared = self.0;
        let receiver_closed = OneShotInner::_notify_rx(shared, false);
        if receiver_closed {
            // SAFETY: the receiver already closed its side, so this is the last user of the
            // pointer produced by `Box::into_raw`.
            std::mem::drop(unsafe { Box::from_raw(shared.as_ptr()) });
        }
    }
}
/// Receiving half of a oneshot channel created by [`oneshot`].
///
/// The pointer becomes `None` once the outcome has been consumed (e.g. by `try_recv` or a
/// completed poll), after which the allocation may already have been freed.
#[must_use]
pub struct RxOneshot<T>(Option<NonNull<OneShotInner<T>>>);
// SAFETY: the receiver takes ownership of the `T` produced by the sender, possibly on another
// thread, so moving the receiver across threads is only sound when `T: Send`.
unsafe impl<T: Send> Send for RxOneshot<T> {}
impl<T> Drop for RxOneshot<T> {
    /// Closes the receiver's side; frees the allocation when the sender side has already
    /// closed, otherwise leaves that to the sender.
    #[inline]
    fn drop(&mut self) {
        let p = match self.0 {
            Some(p) => p,
            None => return,
        };
        // SAFETY: `self.0` is still `Some`, so the receiver has not released the allocation.
        let inner = unsafe { p.as_ref() };
        let old_state = inner.set_state(CLOSE_FLAG);
        if (old_state & CLOSE_FLAG) == 0 {
            // The sender is still around; it will observe CLOSE_FLAG and free the memory.
            trace_log!("oneshot:({:?}) rx drop, state={}", tokio_task_id!(), old_state);
            debug_assert!(
                old_state == 0
                    || old_state == WAKER_SET_FLAG
                    || (old_state | EXIST_FLAG) == (EXIST_FLAG | LOCK_FLAG | WAKER_SET_FLAG),
                "oneshot:({:?}) rx drop, unexpected state={}",
                tokio_task_id!(),
                old_state
            );
            return;
        }
        trace_log!("oneshot:({:?}) rx drop destroy, state={}", tokio_task_id!(), old_state);
        debug_assert_eq!(
            old_state & (!EXIST_FLAG),
            CLOSE_FLAG | LOCK_FLAG,
            "unexpected state {old_state}"
        );
        // SAFETY: the sender already closed its side, so we are the last user.
        let _ = unsafe { Box::from_raw(p.as_ptr()) };
    }
}
impl<T> RxOneshot<T> {
    /// Blocks the current thread until the sender delivers a value or is dropped.
    ///
    /// # Errors
    ///
    /// Returns [`RecvError`] if the sender was dropped without sending, or if this receiver
    /// had already consumed the channel.
    #[inline]
    pub fn recv(self) -> Result<T, RecvError> {
        self._recv_blocking(None).map_err(|_| RecvError)
    }

    /// Blocks until a value arrives, the sender is dropped, or `timeout` elapses.
    ///
    /// A `timeout` too large to be represented as an [`Instant`] deadline is treated as
    /// "wait forever" (as `std::sync::mpsc::Receiver::recv_timeout` does) instead of
    /// panicking on `Instant` overflow.
    ///
    /// # Errors
    ///
    /// [`RecvTimeoutError::Timeout`] when the deadline passes first,
    /// [`RecvTimeoutError::Disconnected`] when the sender was dropped without sending.
    #[inline]
    pub fn recv_timeout(self, timeout: Duration) -> Result<T, RecvTimeoutError> {
        let deadline = Instant::now().checked_add(timeout);
        match self._recv_blocking(deadline) {
            Ok(item) => Ok(item),
            Err(true) => Err(RecvTimeoutError::Timeout),
            Err(false) => Err(RecvTimeoutError::Disconnected),
        }
    }

    /// Returns `true` unless a value is currently stored (always `true` once consumed).
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        if let Some(p) = self.0.as_ref() {
            // SAFETY: while `self.0` is `Some` the receiver has not released the allocation.
            let inner = unsafe { p.as_ref() };
            inner.is_empty()
        } else {
            true
        }
    }

    /// Non-blocking receive.
    ///
    /// # Errors
    ///
    /// [`TryRecvError::Empty`] when the sender has not finished yet;
    /// [`TryRecvError::Disconnected`] when the sender was dropped without sending, or the
    /// outcome was already consumed.
    #[inline]
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        if let Some(p) = self.0.as_ref() {
            let p = *p;
            if let Ok(state) = unsafe { p.as_ref() }._try_recv(Acquire) {
                // `_consume_value` may free the allocation: forget the pointer first.
                self.0 = None;
                if let Some(item) = OneShotInner::_consume_value(p, state) {
                    return Ok(item);
                } else {
                    return Err(TryRecvError::Disconnected);
                }
            } else {
                Err(TryRecvError::Empty)
            }
        } else {
            Err(TryRecvError::Disconnected)
        }
    }

    /// Asynchronously waits for the value; equivalent to awaiting the receiver itself.
    #[inline]
    pub async fn recv_async(self) -> Result<T, RecvError> {
        self.await
    }

    /// Shared polling logic for the `Future` impls.
    ///
    /// Returns `Ready(Ok(item))` on delivery, `Ready(Err(()))` when the sender was dropped
    /// without sending (or the outcome was already consumed), and `Pending` after
    /// registering `ctx`'s waker.
    #[inline]
    fn poll(&mut self, ctx: &mut Context<'_>) -> Poll<Result<T, ()>> {
        let p: NonNull<OneShotInner<T>> = if let Some(p) = self.0.as_ref() {
            *p
        } else {
            return Poll::Ready(Err(()));
        };
        let inner = unsafe { p.as_ref() };
        // Consumes the published outcome; `self.0` is cleared first because
        // `_consume_value` may free the allocation.
        macro_rules! process {
            ($state: expr) => {
                self.0 = None;
                if let Some(item) = OneShotInner::_consume_value(p, $state) {
                    return Poll::Ready(Ok(item));
                } else {
                    return Poll::Ready(Err(()));
                }
            };
        }
        // Returns early when the sender has published; otherwise yields the observed state.
        macro_rules! check_exist {
            ($order: expr) => {{
                match inner._try_recv($order) {
                    Ok(state) => {
                        process!(state);
                    }
                    Err(s) => s,
                }
            }};
        }
        let state = check_exist!(SeqCst);
        if state & WAKER_SET_FLAG > 0 {
            let waker = inner.get_waker().as_ref().unwrap();
            if waker.will_wake(ctx) {
                trace_log!("oneshot:({:?}) spurious waked state {}", tokio_task_id!(), state,);
                return Poll::Pending;
            }
            // A different task/waker is polling: withdraw the old waker before replacing it.
            if let Err(state) = inner.cancel_waker(false) {
                process!(state);
            }
        }
        if let Err(state) = inner.set_waker(ThinWaker::Async(ctx.waker().clone())) {
            process!(state);
        }
        Poll::Pending
    }

    /// Blocking receive with an optional deadline.
    ///
    /// Spins with `Backoff` first, then registers the current thread as a blocking waker
    /// and parks until woken or `deadline` passes. Returns `Err(true)` on timeout and
    /// `Err(false)` on disconnect. NOTE(review): assumes `check_timeout` yields `Ok(None)`
    /// without a deadline, `Ok(Some(remaining))` before it, and `Err(_)` after it — confirm.
    #[inline(always)]
    pub(crate) fn _recv_blocking(self, deadline: Option<Instant>) -> Result<T, bool> {
        let p: NonNull<OneShotInner<T>> = if let Some(p) = self.0.as_ref() {
            *p
        } else {
            return Err(false);
        };
        let inner = unsafe { p.as_ref() };
        // Consumes the published outcome; `self` is forgotten so its Drop does not touch an
        // allocation that `_consume_value` may free.
        macro_rules! process {
            ($state: expr) => {
                let _ = inner;
                std::mem::forget(self);
                if let Some(item) = OneShotInner::_consume_value(p, $state) {
                    return Ok(item);
                } else {
                    return Err(false);
                }
            };
        }
        macro_rules! try_recv {
            ($order: expr) => {
                if let Ok(state) = inner._try_recv($order) {
                    trace_log!("try_recv got {state}");
                    process!(state);
                }
            };
        }
        try_recv!(Acquire);
        let mut backoff = Backoff::new();
        while !backoff.snooze() {
            try_recv!(Acquire);
        }
        if let Err(state) = inner.set_waker(ThinWaker::Blocking(thread::current())) {
            process!(state);
        }
        trace_log!("oneshot: waker set");
        loop {
            try_recv!(SeqCst);
            match check_timeout(deadline) {
                Ok(None) => {
                    std::thread::park();
                }
                Ok(Some(dur)) => {
                    std::thread::park_timeout(dur);
                }
                Err(_) => {
                    trace_log!("oneshot: to cancel_waker on timeout");
                    // Abandon the waker and close our side; if the sender published first,
                    // consume its outcome instead.
                    if let Err(state) = inner.cancel_waker(true) {
                        process!(state);
                    } else {
                        let _ = inner;
                        std::mem::forget(self);
                        return Err(true);
                    }
                }
            }
        }
    }

    /// Asynchronously waits for the value, giving up after `timeout`.
    ///
    /// Uses tokio's timer when the `tokio` feature is enabled, otherwise async-std's. When
    /// both features are enabled (Cargo features are additive) tokio is used; previously
    /// this combination failed to compile.
    #[cfg(any(feature = "tokio", feature = "async_std"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "tokio", feature = "async_std"))))]
    #[inline]
    pub async fn recv_async_timeout(
        self, timeout: std::time::Duration,
    ) -> Result<T, RecvTimeoutError> {
        #[cfg(feature = "tokio")]
        {
            let sleep = tokio::time::sleep(timeout);
            self.recv_async_with_timer(sleep).await
        }
        #[cfg(all(feature = "async_std", not(feature = "tokio")))]
        {
            let sleep = async_std::task::sleep(timeout);
            self.recv_async_with_timer(sleep).await
        }
    }

    /// Wraps the receiver in a future that resolves with `RecvTimeoutError::Timeout` once
    /// `sleep` completes before a value arrives.
    #[inline]
    pub fn recv_async_with_timer<F, R>(self, sleep: F) -> OneshotTimeoutFuture<T, F, R>
    where
        F: Future<Output = R>,
    {
        OneshotTimeoutFuture { rx: self, sleep }
    }
}
impl<T> Future for RxOneshot<T> {
    type Output = Result<T, RecvError>;

    /// Resolves to the sent value, or to [`RecvError`] when the sender was dropped without
    /// sending.
    #[inline]
    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        // `RxOneshot` only holds a raw pointer, so it is `Unpin` and `get_mut` is available.
        let rx = self.get_mut();
        // Method resolution picks the inherent `poll(&mut self, ..)` here.
        rx.poll(ctx).map(|outcome| outcome.map_err(|()| RecvError))
    }
}
/// Future returned by `RxOneshot::recv_async_with_timer`.
///
/// Polls the receiver first and then the `sleep` future; yields `Timeout` once `sleep`
/// completes before a value has arrived.
pub struct OneshotTimeoutFuture<T, F: Future<Output = R>, R> {
    /// The wrapped receiver.
    rx: RxOneshot<T>,
    /// Timer future; structurally pinned by the `Future` impl.
    sleep: F,
}
impl<T, F, R> Future for OneshotTimeoutFuture<T, F, R>
where
    F: Future<Output = R>,
{
    type Output = Result<T, RecvTimeoutError>;

    /// Checks the receiver first so a value that is already available wins over an expired
    /// timer; only then polls the timer.
    #[inline]
    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        // SAFETY: no field is ever moved out of `self`; `sleep` is only accessed through a
        // re-pinned reference below, and `rx` is `Unpin`.
        let this = unsafe { self.get_unchecked_mut() };
        if let Poll::Ready(outcome) = this.rx.poll(ctx) {
            return Poll::Ready(outcome.map_err(|()| RecvTimeoutError::Disconnected));
        }
        // SAFETY: `this.sleep` lives inside a pinned struct and is never moved (structural
        // pinning).
        let timer = unsafe { Pin::new_unchecked(&mut this.sleep) };
        match timer.poll(ctx) {
            Poll::Ready(_) => Poll::Ready(Err(RecvTimeoutError::Timeout)),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Creates a oneshot channel and returns its sending and receiving halves.
///
/// Both halves point at one heap allocation; whichever side closes last frees it.
#[inline]
pub fn oneshot<T>() -> (TxOneshot<T>, RxOneshot<T>) {
    let raw = Box::into_raw(OneShotInner::new());
    // SAFETY: `Box::into_raw` never returns a null pointer.
    let shared = unsafe { NonNull::new_unchecked(raw) };
    (TxOneshot(shared), RxOneshot(Some(shared)))
}