use crate::async_util::AtomicWaker;
use crate::error::{RecvError, TryRecvError, TrySendError};
use core::task::{Context, Poll};
use std::fmt;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Mutex;
/// Nothing has been written to the slot yet and the channel is still open.
pub(super) const STATE_EMPTY: usize = 0;
/// A sender has claimed the slot and is in the middle of storing its value.
pub(super) const STATE_WRITING: usize = 1;
/// A value is stored in the slot and has not been taken yet.
pub(super) const STATE_SENT: usize = 2;
/// The stored value has been moved out of the slot (received or discarded).
pub(super) const STATE_TAKEN: usize = 3;
/// The channel was closed without a value being delivered.
pub(super) const STATE_CLOSED: usize = 4;
/// State shared between the sender handle(s) and the receiver of a oneshot
/// channel.
pub(super) struct OneShotShared<T> {
    /// Slot holding the sent value. Within this file a `Some` is only ever
    /// created via `MaybeUninit::new`, so a present value is initialized.
    pub(crate) value_slot: Mutex<Option<MaybeUninit<T>>>,
    /// Current channel state; one of the `STATE_*` constants.
    pub(crate) state: AtomicUsize,
    /// Waker of the task waiting to receive.
    receiver_waker: AtomicWaker,
    /// Set to `true` once the receiver side has been dropped.
    pub(crate) receiver_dropped: AtomicBool,
    /// Number of registered senders; starts at 1 in `new`.
    pub(crate) sender_count: AtomicUsize,
}
impl<T> fmt::Debug for OneShotShared<T> {
    /// Prints the channel's bookkeeping: state label, receiver flag and
    /// sender count. The value slot and waker are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Translate the raw state word into a label; unexpected values are
        // reported as "Unknown" instead of panicking.
        let describe = |raw: usize| -> &'static str {
            match raw {
                STATE_EMPTY => "Empty",
                STATE_WRITING => "Writing",
                STATE_SENT => "Sent",
                STATE_TAKEN => "Taken",
                STATE_CLOSED => "Closed",
                _ => "Unknown",
            }
        };

        // Each atomic is read independently with Relaxed ordering, so the
        // three values are not guaranteed to form a consistent snapshot.
        let state = describe(self.state.load(Ordering::Relaxed));
        let receiver_dropped = self.receiver_dropped.load(Ordering::Relaxed);
        let sender_count = self.sender_count.load(Ordering::Relaxed);

        let mut out = f.debug_struct("OneShotShared");
        out.field("state", &state);
        out.field("receiver_dropped", &receiver_dropped);
        out.field("sender_count", &sender_count);
        // value_slot and receiver_waker are deliberately omitted.
        out.finish_non_exhaustive()
    }
}
// SAFETY: A `T` only crosses threads by being moved into and out of
// `value_slot`, which is guarded by a `Mutex`, so `T: Send` is the bound both
// impls need. The other fields are atomics plus an `AtomicWaker`.
// NOTE(review): assumes `AtomicWaker` is itself Send + Sync — confirm in
// `async_util`. If it is, these impls likely match what the compiler would
// derive automatically and could be removed.
unsafe impl<T: Send> Send for OneShotShared<T> {}
unsafe impl<T: Send> Sync for OneShotShared<T> {}
impl<T> OneShotShared<T> {
pub(super) fn new() -> Self {
OneShotShared {
value_slot: Mutex::new(None),
state: AtomicUsize::new(STATE_EMPTY),
receiver_waker: AtomicWaker::new(),
receiver_dropped: AtomicBool::new(false),
sender_count: AtomicUsize::new(1),
}
}
pub(super) fn increment_senders(&self) {
self.sender_count.fetch_add(1, Ordering::Relaxed);
}
pub(super) fn decrement_senders(&self) {
if self.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
if self
.state
.compare_exchange(
STATE_EMPTY,
STATE_CLOSED,
Ordering::AcqRel, Ordering::Relaxed,
)
.is_ok()
{
self.receiver_waker.wake(); }
else if self.state.load(Ordering::Acquire) == STATE_SENT
&& self.receiver_dropped.load(Ordering::Acquire)
{
if self
.state
.compare_exchange(STATE_SENT, STATE_TAKEN, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
let mut guard = self.value_slot.lock().unwrap_or_else(|e| e.into_inner());
if let Some(mut mu_value) = guard.take() {
unsafe {
mu_value.assume_init_drop();
}
}
}
}
else if self.state.load(Ordering::Relaxed) != STATE_TAKEN
&& self.state.load(Ordering::Relaxed) != STATE_SENT
{
self.receiver_waker.wake();
}
}
}
pub(super) fn mark_receiver_dropped(&self) {
self.receiver_dropped.store(true, Ordering::Release);
if self
.state
.compare_exchange(
STATE_EMPTY,
STATE_CLOSED,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
}
}
pub(super) fn send(&self, value: T) -> Result<(), TrySendError<T>> {
if self.receiver_dropped.load(Ordering::Acquire) {
return Err(TrySendError::Closed(value));
}
let current_state = self.state.load(Ordering::Acquire);
if current_state >= STATE_SENT {
return Err(TrySendError::Sent(value)); }
match self.state.compare_exchange(
STATE_EMPTY,
STATE_WRITING,
Ordering::AcqRel, Ordering::Acquire, ) {
Ok(_) => {
if self.receiver_dropped.load(Ordering::Acquire) {
self.state.store(STATE_EMPTY, Ordering::Release); return Err(TrySendError::Closed(value));
}
let mut guard = self
.value_slot
.lock()
.expect("Oneshot value_slot mutex poisoned");
*guard = Some(MaybeUninit::new(value));
let prev_state = self.state.swap(STATE_SENT, Ordering::AcqRel);
debug_assert_eq!(
prev_state, STATE_WRITING,
"Oneshot: State inconsistency during send, expected WRITING"
);
self.receiver_waker.wake();
Ok(())
}
Err(observed_state_on_failure) => {
if observed_state_on_failure >= STATE_SENT {
Err(TrySendError::Sent(value))
} else if observed_state_on_failure == STATE_WRITING {
Err(TrySendError::Sent(value))
} else {
Err(TrySendError::Sent(value))
}
}
}
}
pub(super) fn try_recv(&self) -> Result<T, TryRecvError> {
let current_state = self.state.load(Ordering::Acquire);
if current_state == STATE_SENT {
if self
.state
.compare_exchange(
STATE_SENT,
STATE_TAKEN,
Ordering::AcqRel, Ordering::Acquire, )
.is_ok()
{
let mut guard = self
.value_slot
.lock()
.expect("Oneshot value_slot mutex poisoned");
match guard.take() {
Some(mu_value) => unsafe { Ok(mu_value.assume_init()) },
None => {
self.state.store(STATE_CLOSED, Ordering::Relaxed); Err(TryRecvError::Disconnected) }
}
} else {
let new_state_after_cas_fail = self.state.load(Ordering::Acquire);
if new_state_after_cas_fail == STATE_TAKEN {
Err(TryRecvError::Empty) } else if new_state_after_cas_fail == STATE_CLOSED
|| self.sender_count.load(Ordering::Relaxed) == 0
{
Err(TryRecvError::Disconnected)
} else {
Err(TryRecvError::Empty) }
}
} else if current_state == STATE_TAKEN {
Err(TryRecvError::Empty) } else if current_state == STATE_CLOSED {
Err(TryRecvError::Disconnected)
} else {
if current_state == STATE_EMPTY && self.sender_count.load(Ordering::Acquire) == 0 {
self
.state
.compare_exchange(
STATE_EMPTY,
STATE_CLOSED,
Ordering::Relaxed,
Ordering::Relaxed,
)
.ok();
Err(TryRecvError::Disconnected)
} else {
Err(TryRecvError::Empty) }
}
}
pub(super) fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
loop {
match self.try_recv() {
Ok(value) => return Poll::Ready(Ok(value)),
Err(TryRecvError::Disconnected) => return Poll::Ready(Err(RecvError::Disconnected)),
Err(TryRecvError::Empty) => {
let current_state = self.state.load(Ordering::Acquire);
if current_state == STATE_TAKEN || current_state == STATE_CLOSED {
if self.sender_count.load(Ordering::Acquire) == 0 && current_state != STATE_SENT {
return Poll::Ready(Err(RecvError::Disconnected));
}
}
if current_state == STATE_EMPTY && self.sender_count.load(Ordering::Acquire) == 0 {
self
.state
.compare_exchange(
STATE_EMPTY,
STATE_CLOSED,
Ordering::Relaxed,
Ordering::Relaxed,
)
.ok();
return Poll::Ready(Err(RecvError::Disconnected));
}
self.receiver_waker.register(cx.waker());
match self.try_recv() {
Ok(value) => {
return Poll::Ready(Ok(value));
}
Err(TryRecvError::Disconnected) => {
return Poll::Ready(Err(RecvError::Disconnected));
}
Err(TryRecvError::Empty) => {
return Poll::Pending;
}
}
}
}
}
}
}
impl<T> Drop for OneShotShared<T> {
    /// Drops a value that was sent but never taken.
    ///
    /// Only in the `STATE_SENT` state is the slot treated as owning a value;
    /// in every other state the slot is left alone.
    fn drop(&mut self) {
        // `&mut self` means no other thread can touch the atomic, so a plain
        // read through `get_mut` is enough.
        let still_pending = *self.state.get_mut() == STATE_SENT;
        if !still_pending {
            return;
        }
        // A poisoned lock still guards a coherent `Option`, so recover it.
        let slot = match self.value_slot.get_mut() {
            Ok(inner) => inner,
            Err(poisoned) => poisoned.into_inner(),
        };
        if let Some(mut pending) = slot.take() {
            // SAFETY: values only enter the slot via `MaybeUninit::new`, so a
            // `Some` holds an initialized `T`, and `take()` removed it so it
            // cannot be dropped twice.
            unsafe { pending.assume_init_drop() };
        }
    }
}