use crate::shim::atomic::{AtomicU8, Ordering};
use super::common::{OneshotStorage, TakeResult};
pub use super::common::RecvError;
pub use super::common::TryRecvError;
pub use super::common::error;
/// A value that can travel through a lite oneshot channel as a single byte.
///
/// The channel keeps its entire state in one `u8`: three reserved sentinel
/// bytes describe the channel itself (pending, sender dropped, receiver
/// closed) and every other byte may encode a payload value.
///
/// # Contract
///
/// * [`to_u8`](State::to_u8) must never return one of the three sentinel
///   bytes; the storage checks the sentinels *before* it tries to decode a
///   payload, so a colliding value would be read back as "pending" or
///   "closed" instead of as data.
/// * [`from_u8`](State::from_u8) should return `Some` for every byte that
///   `to_u8` can produce and `None` for everything else.
///
/// The sentinel methods have default implementations (`0`, `255`, `254`), so
/// an implementor normally only needs to provide `to_u8` and `from_u8` and
/// keep its encodings inside `1..=253`. Override the sentinels only if a type
/// really needs those bytes for payload values.
pub trait State: Sized + Send + Sync + 'static {
    /// Encodes this value as the byte written into the channel.
    fn to_u8(&self) -> u8;

    /// Decodes a byte read from the channel.
    ///
    /// Returns `None` when `value` is not a valid payload encoding.
    fn from_u8(value: u8) -> Option<Self>;

    /// Sentinel byte meaning "nothing has been sent yet".
    ///
    /// Defaults to `0`.
    #[inline]
    fn pending_value() -> u8 {
        0
    }

    /// Sentinel byte meaning "the sender was dropped without sending".
    ///
    /// Defaults to `255`.
    #[inline]
    fn closed_value() -> u8 {
        255
    }

    /// Sentinel byte meaning "the receiver closed the channel".
    ///
    /// Defaults to `254`.
    #[inline]
    fn receiver_closed_value() -> u8 {
        254
    }
}
/// Payload-free signal: sending `()` only tells the receiver that a send
/// happened.
impl State for () {
    /// The unit value is always encoded as `1`.
    #[inline]
    fn to_u8(&self) -> u8 {
        1
    }

    /// Only the byte `1` decodes to `()`; any other byte yields `None`.
    #[inline]
    fn from_u8(value: u8) -> Option<Self> {
        (value == 1).then_some(())
    }

    /// `0` marks a channel that has not received a value yet.
    #[inline]
    fn pending_value() -> u8 {
        u8::MIN
    }

    /// `255` marks a channel whose sender was dropped.
    #[inline]
    fn closed_value() -> u8 {
        u8::MAX
    }

    /// `254` marks a channel whose receiver was closed.
    #[inline]
    fn receiver_closed_value() -> u8 {
        u8::MAX - 1
    }
}
/// Channel storage that holds the whole oneshot state in a single atomic byte.
///
/// The byte is either one of the sentinels reported by `S` (pending, sender
/// dropped, receiver closed) or the `to_u8` encoding of a stored `S` value.
pub struct LiteStorage<S: State> {
// Current channel state / encoded payload; initialised to `S::pending_value()` by `new`.
state: AtomicU8,
// Ties the storage to the payload type `S` without storing an `S`.
_marker: core::marker::PhantomData<S>,
}
// SAFETY: `LiteStorage` only contains an `AtomicU8` and a `PhantomData<S>`, and
// the `State` bound already requires `S: Send + Sync`. All access to `state`
// in this file goes through atomic loads/stores.
// NOTE(review): this assumes the `AtomicU8` exported by `crate::shim::atomic`
// is itself thread-safe for every shim backend — confirm.
unsafe impl<S: State> Send for LiteStorage<S> {}
unsafe impl<S: State> Sync for LiteStorage<S> {}
impl<S: State> OneshotStorage for LiteStorage<S> {
type Value = S;
#[inline]
fn new() -> Self {
Self {
state: AtomicU8::new(S::pending_value()),
_marker: core::marker::PhantomData,
}
}
#[inline]
fn store(&self, value: S) {
self.state.store(value.to_u8(), Ordering::Release);
}
#[inline]
fn try_take(&self) -> TakeResult<S> {
let current = self.state.load(Ordering::Acquire);
if current == S::closed_value() || current == S::receiver_closed_value() {
return TakeResult::Closed;
}
if current == S::pending_value() {
return TakeResult::Pending;
}
if let Some(state) = S::from_u8(current) {
TakeResult::Ready(state)
} else {
TakeResult::Pending
}
}
#[inline]
fn is_sender_dropped(&self) -> bool {
self.state.load(Ordering::Acquire) == S::closed_value()
}
#[inline]
fn mark_sender_dropped(&self) {
self.state.store(S::closed_value(), Ordering::Release);
}
#[inline]
fn is_receiver_closed(&self) -> bool {
self.state.load(Ordering::Acquire) == S::receiver_closed_value()
}
#[inline]
fn mark_receiver_closed(&self) {
self.state
.store(S::receiver_closed_value(), Ordering::Release);
}
}
/// Sending half of the lite oneshot channel: the shared `common::Sender`
/// specialised to [`LiteStorage<S>`].
pub type Sender<S> = super::common::Sender<LiteStorage<S>>;
/// Receiving half of the lite oneshot channel: the shared `common::Receiver`
/// specialised to [`LiteStorage<S>`].
pub type Receiver<S> = super::common::Receiver<LiteStorage<S>>;
/// Creates a connected [`Sender`] / [`Receiver`] pair for the state type `S`.
///
/// This is a thin convenience wrapper around `Sender::<S>::new()`.
#[inline]
pub fn channel<S: State>() -> (Sender<S>, Receiver<S>) {
    let (sender, receiver) = Sender::<S>::new();
    (sender, receiver)
}
/// Convenience methods for receivers of `State` values.
impl<S: State> Receiver<S> {
    /// Consumes the receiver and waits for the outcome of the channel.
    ///
    /// Equivalent to awaiting the receiver directly.
    #[inline]
    pub async fn recv(self) -> Result<S, RecvError> {
        self.await
    }

    /// Checks for a value without waiting.
    ///
    /// # Errors
    ///
    /// * `TryRecvError::Empty` when the inner channel reports
    ///   `TakeResult::Pending`.
    /// * `TryRecvError::Closed` when the inner channel reports
    ///   `TakeResult::Closed`.
    #[inline]
    pub fn try_recv(&mut self) -> Result<S, TryRecvError> {
        let outcome = self.inner.try_recv();
        let failure = match outcome {
            TakeResult::Ready(value) => return Ok(value),
            TakeResult::Pending => TryRecvError::Empty,
            TakeResult::Closed => TryRecvError::Closed,
        };
        Err(failure)
    }
}
// Unit tests for the lite oneshot channel.
// Compiled only for test builds that do NOT enable the `loom` feature.
#[cfg(all(test, not(feature = "loom")))]
mod tests {
use super::*;
// Two-variant payload used by most tests; encoded as 1 (`Called`) and 2 (`Cancelled`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TestCompletion {
Called,
Cancelled,
}
impl State for TestCompletion {
fn to_u8(&self) -> u8 {
match self {
TestCompletion::Called => 1,
TestCompletion::Cancelled => 2,
}
}
fn from_u8(value: u8) -> Option<Self> {
match value {
1 => Some(TestCompletion::Called),
2 => Some(TestCompletion::Cancelled),
_ => None,
}
}
fn pending_value() -> u8 {
0
}
fn closed_value() -> u8 {
255
}
fn receiver_closed_value() -> u8 {
254
}
}
// --- `recv()` with `TestCompletion` ---
// A value sent from a spawned task after a 10 ms sleep is delivered to `recv()`.
#[tokio::test]
async fn test_oneshot_called() {
let (notifier, receiver) = Sender::<TestCompletion>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(TestCompletion::Called).unwrap();
});
let result = receiver.recv().await;
assert_eq!(result, Ok(TestCompletion::Called));
}
// Same as above, but with the second variant.
#[tokio::test]
async fn test_oneshot_cancelled() {
let (notifier, receiver) = Sender::<TestCompletion>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(TestCompletion::Cancelled).unwrap();
});
let result = receiver.recv().await;
assert_eq!(result, Ok(TestCompletion::Cancelled));
}
// The value is sent before `recv()` is called, so it is already available.
#[tokio::test]
async fn test_oneshot_immediate_called() {
let (notifier, receiver) = Sender::<TestCompletion>::new();
notifier.send(TestCompletion::Called).unwrap();
let result = receiver.recv().await;
assert_eq!(result, Ok(TestCompletion::Called));
}
// Send-before-receive with the second variant.
#[tokio::test]
async fn test_oneshot_immediate_cancelled() {
let (notifier, receiver) = Sender::<TestCompletion>::new();
notifier.send(TestCompletion::Cancelled).unwrap();
let result = receiver.recv().await;
assert_eq!(result, Ok(TestCompletion::Cancelled));
}
// Three-variant payload; encoded as 1, 2 and 3.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CustomState {
Success,
Failure,
Timeout,
}
impl State for CustomState {
fn to_u8(&self) -> u8 {
match self {
CustomState::Success => 1,
CustomState::Failure => 2,
CustomState::Timeout => 3,
}
}
fn from_u8(value: u8) -> Option<Self> {
match value {
1 => Some(CustomState::Success),
2 => Some(CustomState::Failure),
3 => Some(CustomState::Timeout),
_ => None,
}
}
fn pending_value() -> u8 {
0
}
fn closed_value() -> u8 {
255
}
fn receiver_closed_value() -> u8 {
254
}
}
// --- `recv()` with `CustomState` and `()` ---
// Delayed send of a `CustomState` value.
#[tokio::test]
async fn test_oneshot_custom_state() {
let (notifier, receiver) = Sender::<CustomState>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(CustomState::Success).unwrap();
});
let result = receiver.recv().await;
assert_eq!(result, Ok(CustomState::Success));
}
// Send-before-receive of the third `CustomState` variant.
#[tokio::test]
async fn test_oneshot_custom_state_timeout() {
let (notifier, receiver) = Sender::<CustomState>::new();
notifier.send(CustomState::Timeout).unwrap();
let result = receiver.recv().await;
assert_eq!(result, Ok(CustomState::Timeout));
}
// Delayed send of the unit payload.
#[tokio::test]
async fn test_oneshot_unit_type() {
let (notifier, receiver) = Sender::<()>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(()).unwrap();
});
let result = receiver.recv().await;
assert_eq!(result, Ok(()));
}
// Send-before-receive of the unit payload.
#[tokio::test]
async fn test_oneshot_unit_type_immediate() {
let (notifier, receiver) = Sender::<()>::new();
notifier.send(()).unwrap();
let result = receiver.recv().await;
assert_eq!(result, Ok(()));
}
// --- Awaiting the receiver directly (no `recv()` call) ---
// Delayed send, receiver awaited by value.
#[tokio::test]
async fn test_oneshot_into_future_called() {
let (notifier, receiver) = Sender::<TestCompletion>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(TestCompletion::Called).unwrap();
});
let result = receiver.await;
assert_eq!(result, Ok(TestCompletion::Called));
}
// Send-before-await, receiver awaited by value.
#[tokio::test]
async fn test_oneshot_into_future_immediate() {
let (notifier, receiver) = Sender::<TestCompletion>::new();
notifier.send(TestCompletion::Cancelled).unwrap();
let result = receiver.await;
assert_eq!(result, Ok(TestCompletion::Cancelled));
}
// Delayed send of `()`, receiver awaited by value.
#[tokio::test]
async fn test_oneshot_into_future_unit_type() {
let (notifier, receiver) = Sender::<()>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(()).unwrap();
});
let result = receiver.await;
assert_eq!(result, Ok(()));
}
// Delayed send of a `CustomState`, receiver awaited by value.
#[tokio::test]
async fn test_oneshot_into_future_custom_state() {
let (notifier, receiver) = Sender::<CustomState>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(CustomState::Failure).unwrap();
});
let result = receiver.await;
assert_eq!(result, Ok(CustomState::Failure));
}
// --- Awaiting `&mut receiver` ---
// Delayed send, receiver awaited through a mutable reference.
#[tokio::test]
async fn test_oneshot_await_mut_reference() {
let (notifier, mut receiver) = Sender::<TestCompletion>::new();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
notifier.send(TestCompletion::Called).unwrap();
});
let result = (&mut receiver).await;
assert_eq!(result, Ok(TestCompletion::Called));
}
// Send-before-await of `()`, receiver awaited through a mutable reference.
#[tokio::test]
async fn test_oneshot_await_mut_reference_unit_type() {
let (notifier, mut receiver) = Sender::<()>::new();
notifier.send(()).unwrap();
let result = (&mut receiver).await;
assert_eq!(result, Ok(()));
}
// --- `try_recv()` ---
// Nothing sent and the sender is still alive (`_notifier` is kept): `Empty`.
#[tokio::test]
async fn test_oneshot_try_recv_pending() {
let (_notifier, mut receiver) = Sender::<TestCompletion>::new();
let result = receiver.try_recv();
assert_eq!(result, Err(TryRecvError::Empty));
}
// A value was sent: `try_recv()` returns it.
#[tokio::test]
async fn test_oneshot_try_recv_ready() {
let (notifier, mut receiver) = Sender::<TestCompletion>::new();
notifier.send(TestCompletion::Called).unwrap();
let result = receiver.try_recv();
assert_eq!(result, Ok(TestCompletion::Called));
}
// The sender was dropped without sending: `Closed`.
#[tokio::test]
async fn test_oneshot_try_recv_sender_dropped() {
let (notifier, mut receiver) = Sender::<TestCompletion>::new();
drop(notifier);
let result = receiver.try_recv();
assert_eq!(result, Err(TryRecvError::Closed));
}
// --- Sender dropped before `recv()` ---
// `recv()` resolves to `Err(RecvError)` for each payload type.
#[tokio::test]
async fn test_oneshot_sender_dropped_before_recv() {
let (notifier, receiver) = Sender::<TestCompletion>::new();
drop(notifier);
let result = receiver.recv().await;
assert_eq!(result, Err(RecvError));
}
#[tokio::test]
async fn test_oneshot_sender_dropped_unit_type() {
let (notifier, receiver) = Sender::<()>::new();
drop(notifier);
let result = receiver.recv().await;
assert_eq!(result, Err(RecvError));
}
#[tokio::test]
async fn test_oneshot_sender_dropped_custom_state() {
let (notifier, receiver) = Sender::<CustomState>::new();
drop(notifier);
let result = receiver.recv().await;
assert_eq!(result, Err(RecvError));
}
// --- `Sender::is_closed()` and `Receiver::close()` ---
// A fresh channel with a live receiver is not closed.
#[test]
fn test_sender_is_closed_initially_false() {
let (sender, _receiver) = Sender::<()>::new();
assert!(!sender.is_closed());
}
// Dropping the receiver makes the sender observe a closed channel.
#[test]
fn test_sender_is_closed_after_receiver_drop() {
let (sender, receiver) = Sender::<()>::new();
drop(receiver);
assert!(sender.is_closed());
}
// Explicitly closing the receiver has the same effect as dropping it.
#[test]
fn test_sender_is_closed_after_receiver_close() {
let (sender, mut receiver) = Sender::<()>::new();
receiver.close();
assert!(sender.is_closed());
}
// After `close()`, `send` reports an error instead of delivering the value.
#[test]
fn test_receiver_close_prevents_send() {
let (sender, mut receiver) = Sender::<TestCompletion>::new();
receiver.close();
assert!(sender.send(TestCompletion::Called).is_err());
}
// --- `blocking_recv()` (plain threads, no async runtime) ---
// The value is already present when `blocking_recv()` is called.
#[test]
fn test_blocking_recv_immediate() {
let (sender, receiver) = Sender::<TestCompletion>::new();
sender.send(TestCompletion::Called).unwrap();
let result = receiver.blocking_recv();
assert_eq!(result, Ok(TestCompletion::Called));
}
// The value is sent from another thread after a 10 ms sleep.
#[test]
fn test_blocking_recv_with_thread() {
let (sender, receiver) = Sender::<()>::new();
std::thread::spawn(move || {
std::thread::sleep(std::time::Duration::from_millis(10));
sender.send(()).unwrap();
});
let result = receiver.blocking_recv();
assert_eq!(result, Ok(()));
}
// The sender is dropped from another thread without sending: `Err(RecvError)`.
#[test]
fn test_blocking_recv_sender_dropped() {
let (sender, receiver) = Sender::<()>::new();
std::thread::spawn(move || {
std::thread::sleep(std::time::Duration::from_millis(10));
drop(sender);
});
let result = receiver.blocking_recv();
assert_eq!(result, Err(RecvError));
}
}