use crate::shim::atomic::{AtomicU8, Ordering};
use crate::shim::cell::UnsafeCell;
use core::mem::MaybeUninit;
use super::common::{self, OneshotStorage, TakeResult};
pub use super::common::RecvError;
pub use super::common::TryRecvError;
pub use super::common::error;
/// No value is held: the state a fresh storage starts in, and the state it
/// is left in once a stored value has been taken.
const EMPTY: u8 = 0;
/// A value has been written into the slot and has not been taken yet.
const READY: u8 = 1;
/// State written by `mark_sender_dropped`.
const SENDER_CLOSED: u8 = 2;
/// State written by `mark_receiver_closed`.
const RECEIVER_CLOSED: u8 = 3;
/// Oneshot storage for an arbitrary value type `T`.
///
/// Pairs an atomic state tag with a possibly-uninitialized slot; the slot
/// holds a live `T` only while the tag is `READY`.
pub struct GenericStorage<T> {
/// Lifecycle tag: one of `EMPTY`, `READY`, `SENDER_CLOSED`, `RECEIVER_CLOSED`.
state: AtomicU8,
/// Value slot; initialized by `store` and read out by `try_take`.
value: UnsafeCell<MaybeUninit<T>>,
}
// SAFETY: the storage owns at most one `T`, so moving the storage to another
// thread moves that `T` with it; this is sound whenever `T: Send`.
unsafe impl<T> Send for GenericStorage<T> where T: Send {}

// SAFETY: shared access to the slot goes through the `UnsafeCell` and is
// ordered by the atomic `state` (Release when publishing in `store`, Acquire
// when claiming in `try_take`). A `T` only ever crosses threads by value, so
// `T: Send` is the bound required.
// NOTE(review): this presumes the channel wrapper never lets two threads write
// or take concurrently - confirm against `common::Sender` / `common::Receiver`.
unsafe impl<T> Sync for GenericStorage<T> where T: Send {}
impl<T: Send> OneshotStorage for GenericStorage<T> {
    type Value = T;

    /// Creates storage in the `EMPTY` state with an uninitialized value slot.
    #[inline]
    fn new() -> Self {
        Self {
            state: AtomicU8::new(EMPTY),
            value: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// Writes `value` into the slot, then publishes it by setting the state
    /// to `READY`.
    ///
    /// The `Release` store pairs with the `Acquire` operation in `try_take`,
    /// so a taker that observes `READY` also observes the written value.
    #[inline]
    fn store(&self, value: T) {
        // SAFETY: NOTE(review) - assumes this is called at most once and that
        // nothing else is accessing the slot at this point; that guarantee has
        // to come from the channel wrapper. Confirm against the caller.
        self.value.with_mut(|v| unsafe { (*v).write(value) });
        self.state.store(READY, Ordering::Release);
    }

    /// Attempts to take the stored value.
    ///
    /// Returns `TakeResult::Ready(value)` if a value was present,
    /// `TakeResult::Closed` if the state is `SENDER_CLOSED` or
    /// `RECEIVER_CLOSED`, and `TakeResult::Pending` otherwise.
    ///
    /// Only the `READY -> EMPTY` transition is ever performed here. A closed
    /// state is left untouched, so repeated calls keep reporting `Closed` and
    /// `is_sender_dropped` / `is_receiver_closed` keep returning `true`.
    #[inline]
    fn try_take(&self) -> TakeResult<T> {
        // Claim the value atomically: exactly one caller can win the
        // READY -> EMPTY exchange, which rules out reading the slot twice.
        match self
            .state
            .compare_exchange(READY, EMPTY, Ordering::Acquire, Ordering::Acquire)
        {
            Ok(_) => {
                // SAFETY: the state was `READY`, which is only set by `store`
                // after the slot has been written, and the exchange above has
                // handed ownership of that value to this call alone.
                self.value
                    .with(|v| unsafe { TakeResult::Ready((*v).assume_init_read()) })
            }
            Err(SENDER_CLOSED | RECEIVER_CLOSED) => TakeResult::Closed,
            Err(_) => TakeResult::Pending,
        }
    }

    /// Returns `true` while the state is `SENDER_CLOSED`.
    #[inline]
    fn is_sender_dropped(&self) -> bool {
        self.state.load(Ordering::Acquire) == SENDER_CLOSED
    }

    /// Sets the state to `SENDER_CLOSED`.
    #[inline]
    fn mark_sender_dropped(&self) {
        // NOTE(review): this store is unconditional. If it can ever run while
        // the state is `READY`, the stored value would be neither delivered
        // nor dropped (`Drop` only cleans up in `READY`). Presumably the
        // sender only calls this when no value was sent - confirm.
        self.state.store(SENDER_CLOSED, Ordering::Release);
    }

    /// Returns `true` while the state is `RECEIVER_CLOSED`.
    #[inline]
    fn is_receiver_closed(&self) -> bool {
        self.state.load(Ordering::Acquire) == RECEIVER_CLOSED
    }

    /// Sets the state to `RECEIVER_CLOSED`.
    #[inline]
    fn mark_receiver_closed(&self) {
        // NOTE(review): unconditional store, same caveat as
        // `mark_sender_dropped`: overwriting `READY` here would leave an
        // already-sent value in the slot without ever dropping it. Verify how
        // the receiver handles close/drop after a successful send.
        self.state.store(RECEIVER_CLOSED, Ordering::Release);
    }
}
impl<T> Drop for GenericStorage<T> {
    /// Drops a value that was stored but never taken.
    fn drop(&mut self) {
        let holds_value = self.state.load(Ordering::Acquire) == READY;
        if !holds_value {
            // In every other state there is no live `T` in the slot to drop.
            return;
        }
        // SAFETY: `READY` is only set after the slot has been written, and
        // `&mut self` means nobody else can be reading it any more.
        self.value.with_mut(|slot| unsafe { (*slot).assume_init_drop() });
    }
}
/// Sending half of a oneshot channel backed by [`GenericStorage`].
pub type Sender<T> = common::Sender<GenericStorage<T>>;
/// Receiving half of a oneshot channel backed by [`GenericStorage`].
pub type Receiver<T> = common::Receiver<GenericStorage<T>>;
/// Creates a oneshot channel for values of type `T`.
///
/// Thin wrapper around `Sender::<T>::new()`; returns the connected
/// `(Sender<T>, Receiver<T>)` pair it produces.
#[inline]
pub fn channel<T: Send>() -> (Sender<T>, Receiver<T>) {
    Sender::<T>::new()
}
impl<T: Send> Receiver<T> {
#[inline]
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
match self.inner.try_recv() {
super::common::TakeResult::Ready(v) => Ok(v),
super::common::TakeResult::Pending => Err(TryRecvError::Empty),
super::common::TakeResult::Closed => Err(TryRecvError::Closed),
}
}
}
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Delay used by tests that send (or drop the sender) from another task/thread.
    const SEND_DELAY: Duration = Duration::from_millis(10);

    // ---- async receive via `wait()` ----

    #[tokio::test]
    async fn test_oneshot_string() {
        let (tx, rx) = Sender::<String>::new();
        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send("Hello".to_string()).unwrap();
        });
        assert_eq!(rx.wait().await.unwrap(), "Hello");
    }

    #[tokio::test]
    async fn test_oneshot_integer() {
        let (tx, rx) = Sender::<i32>::new();
        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send(42).unwrap();
        });
        assert_eq!(rx.wait().await.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_oneshot_immediate() {
        let (tx, rx) = Sender::<String>::new();
        tx.send("Immediate".to_string()).unwrap();
        assert_eq!(rx.wait().await.unwrap(), "Immediate");
    }

    #[tokio::test]
    async fn test_oneshot_custom_struct() {
        #[derive(Debug, Clone, PartialEq)]
        struct CustomData {
            id: u64,
            name: String,
        }

        let (tx, rx) = Sender::<CustomData>::new();
        let payload = CustomData {
            id: 123,
            name: "Test".to_string(),
        };
        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send(payload).unwrap();
        });

        let received = rx.wait().await.unwrap();
        assert_eq!(received.id, 123);
        assert_eq!(received.name, "Test");
    }

    // ---- awaiting the receiver directly ----

    #[tokio::test]
    async fn test_oneshot_direct_await() {
        let (tx, rx) = Sender::<i32>::new();
        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send(99).unwrap();
        });
        assert_eq!(rx.await.unwrap(), 99);
    }

    #[tokio::test]
    async fn test_oneshot_await_mut_reference() {
        let (tx, mut rx) = Sender::<String>::new();
        tokio::spawn(async move {
            tokio::time::sleep(SEND_DELAY).await;
            tx.send("Mutable".to_string()).unwrap();
        });
        assert_eq!((&mut rx).await.unwrap(), "Mutable");
    }

    #[tokio::test]
    async fn test_oneshot_immediate_await() {
        let (tx, rx) = Sender::<Vec<u8>>::new();
        tx.send(vec![1, 2, 3]).unwrap();
        assert_eq!(rx.await.unwrap(), vec![1, 2, 3]);
    }

    // ---- non-blocking receive ----

    #[tokio::test]
    async fn test_oneshot_try_recv() {
        let (tx, mut rx) = Sender::<i32>::new();
        // Nothing sent yet.
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        tx.send(42).unwrap();
        assert_eq!(rx.try_recv(), Ok(42));
    }

    #[tokio::test]
    async fn test_oneshot_try_recv_closed() {
        let (tx, mut rx) = Sender::<i32>::new();
        drop(tx);
        assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
    }

    #[tokio::test]
    async fn test_oneshot_dropped() {
        let (tx, rx) = Sender::<i32>::new();
        drop(tx);
        assert_eq!(rx.await, Err(RecvError));
    }

    #[tokio::test]
    async fn test_oneshot_large_data() {
        const LEN: usize = 1024 * 1024;
        let (tx, rx) = Sender::<Vec<u8>>::new();
        let big = vec![0u8; LEN];
        tokio::spawn(async move {
            tx.send(big).unwrap();
        });
        assert_eq!(rx.await.unwrap().len(), LEN);
    }

    // ---- sender-side closed detection ----

    #[test]
    fn test_sender_is_closed_initially_false() {
        let (tx, _rx) = Sender::<i32>::new();
        assert!(!tx.is_closed());
    }

    #[test]
    fn test_sender_is_closed_after_receiver_drop() {
        let (tx, rx) = Sender::<i32>::new();
        drop(rx);
        assert!(tx.is_closed());
    }

    #[test]
    fn test_sender_is_closed_after_receiver_close() {
        let (tx, mut rx) = Sender::<i32>::new();
        rx.close();
        assert!(tx.is_closed());
    }

    #[test]
    fn test_receiver_close_prevents_send() {
        let (tx, mut rx) = Sender::<i32>::new();
        rx.close();
        assert!(tx.send(42).is_err());
    }

    // ---- blocking receive ----

    #[test]
    fn test_blocking_recv_immediate() {
        let (tx, rx) = Sender::<i32>::new();
        tx.send(42).unwrap();
        assert_eq!(rx.blocking_recv(), Ok(42));
    }

    #[test]
    fn test_blocking_recv_with_thread() {
        let (tx, rx) = Sender::<String>::new();
        std::thread::spawn(move || {
            std::thread::sleep(SEND_DELAY);
            tx.send("hello".to_string()).unwrap();
        });
        assert_eq!(rx.blocking_recv(), Ok("hello".to_string()));
    }

    #[test]
    fn test_blocking_recv_sender_dropped() {
        let (tx, rx) = Sender::<i32>::new();
        std::thread::spawn(move || {
            std::thread::sleep(SEND_DELAY);
            drop(tx);
        });
        assert_eq!(rx.blocking_recv(), Err(RecvError));
    }
}