use std::cell::UnsafeCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// Channel state: nothing has been sent yet and neither half has closed the channel.
const EMPTY: u8 = 0;
/// Channel state: the sender stored a value and published it for the receiver.
const SENT: u8 = EMPTY + 1;
/// Channel state: one half went away before any value was published.
const CLOSED: u8 = SENT + 1;
/// State shared (behind an `Arc`) between a [`Sender`] and its [`Receiver`].
struct Inner<T> {
/// One of `EMPTY`, `SENT`, or `CLOSED`; decides which side may touch `value`.
state: AtomicU8,
/// Slot for the transmitted value. The sender writes it before publishing `SENT`
/// (Release) and takes it back if publishing fails; the receiver reads it only after
/// observing `SENT` (Acquire).
value: UnsafeCell<Option<T>>,
/// Waker registered by a pending receiver; taken by the sender side to notify it.
waker: Mutex<Option<Waker>>,
}
unsafe impl<T: Send> Send for Inner<T> {}
unsafe impl<T: Send> Sync for Inner<T> {}
impl<T> Inner<T> {
fn new() -> Self {
Self {
state: AtomicU8::new(EMPTY),
value: UnsafeCell::new(None),
waker: Mutex::new(None),
}
}
}
/// Error produced when awaiting a [`Receiver`] whose [`Sender`] was dropped without sending.
///
/// The enum has no payload, so it is cheap to copy, compare, and store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// The sender went away before any value was sent.
    Closed,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            RecvError::Closed => "oneshot channel closed without a value",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for RecvError {}
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let inner = Arc::new(Inner::new());
(
Sender {
inner: inner.clone(),
sent: false,
},
Receiver { inner },
)
}
/// Sending half of a one-shot channel created by [`channel`].
pub struct Sender<T> {
/// State shared with the paired [`Receiver`].
inner: Arc<Inner<T>>,
/// Set once `send` has published a value, so `Drop` does not mark the channel closed.
sent: bool,
}
impl<T> Sender<T> {
    /// Sends `value` to the paired [`Receiver`], consuming this sender.
    ///
    /// On success, the receiver's registered waker (if any) is woken.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)`, handing the value back unchanged, if the receiver has already
    /// been dropped.
    pub fn send(mut self, value: T) -> Result<(), T> {
        // SAFETY: while the state is EMPTY only the sender touches `value`; the receiver
        // reads it only after observing SENT, which the CAS below publishes.
        unsafe { *self.inner.value.get() = Some(value) };

        match self
            .inner
            .state
            .compare_exchange(EMPTY, SENT, Ordering::Release, Ordering::Relaxed)
        {
            Ok(_) => {
                self.sent = true;
                // Take the waker in its own statement so the mutex guard is released before
                // `wake()` runs. Waking foreign executor code under our lock could deadlock
                // if it re-enters `poll`, and a panic there would poison the mutex. A
                // poisoned lock is safe to reuse: the guarded `Option<Waker>` is always
                // in a valid state.
                let waker = self
                    .inner
                    .waker
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .take();
                if let Some(w) = waker {
                    w.wake();
                }
                Ok(())
            }
            Err(_) => {
                // SAFETY: the CAS failed, so SENT was never published; the receiver never
                // reads `value` in any other state, so reclaiming it here is race-free.
                let val = unsafe { (*self.inner.value.get()).take() }
                    .expect("oneshot: value written by send is missing");
                Err(val)
            }
        }
    }
}
impl<T> Drop for Sender<T> {
    /// Closes the channel if no value was sent, waking a receiver that is still waiting.
    fn drop(&mut self) {
        if self.sent {
            // A successful `send` already published the value and woke the receiver.
            return;
        }
        let prev = self.inner.state.swap(CLOSED, Ordering::Release);
        if prev == EMPTY {
            // Separate statement: the guard is dropped before calling into the executor's
            // waker. Poisoning is tolerated because panicking inside `Drop` can abort the
            // process, and the guarded `Option<Waker>` is always in a valid state.
            let waker = self
                .inner
                .waker
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .take();
            if let Some(w) = waker {
                w.wake();
            }
        }
    }
}
/// Receiving half of a one-shot channel; `.await` it to obtain the value.
///
/// Resolves to `Ok(value)` once the sender sends, or to `Err(RecvError::Closed)` if the
/// sender is dropped without sending.
pub struct Receiver<T> {
/// State shared with the paired [`Sender`].
inner: Arc<Inner<T>>,
}
impl<T> Future for Receiver<T> {
    type Output = Result<T, RecvError>;

    /// Resolves with the sent value, or with [`RecvError::Closed`] if the sender went away.
    ///
    /// While the channel is still empty this registers (or keeps) `cx`'s waker and then
    /// re-reads the state, so a send that races with registration is never missed.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned the value.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = &*self.inner;

        match inner.state.load(Ordering::Acquire) {
            SENT => {
                // SAFETY: observing SENT with Acquire synchronizes with the sender's Release
                // CAS, so its write to `value` is visible and the sender no longer touches it.
                let val = unsafe { (*inner.value.get()).take() }
                    .expect("oneshot: SENT state but value is None (logic error)");
                return Poll::Ready(Ok(val));
            }
            CLOSED => return Poll::Ready(Err(RecvError::Closed)),
            _ => {}
        }

        // Register our waker. Skip the clone when the stored waker already wakes this task.
        // A replaced waker is dropped only after the guard is released, so foreign waker
        // code never runs under our lock. Poisoning is tolerated: the guarded
        // `Option<Waker>` is always in a valid state.
        let mut slot = inner
            .waker
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let up_to_date = matches!(&*slot, Some(current) if current.will_wake(cx.waker()));
        let replaced = if up_to_date {
            None
        } else {
            slot.replace(cx.waker().clone())
        };
        drop(slot);
        drop(replaced);

        // Re-check: the sender may have published SENT/CLOSED before it could see our waker.
        match inner.state.load(Ordering::Acquire) {
            SENT => {
                // SAFETY: same reasoning as the first SENT branch above.
                let val = unsafe { (*inner.value.get()).take() }
                    .expect("oneshot: SENT but value None after re-check");
                Poll::Ready(Ok(val))
            }
            CLOSED => Poll::Ready(Err(RecvError::Closed)),
            _ => Poll::Pending,
        }
    }
}
impl<T> Drop for Receiver<T> {
    /// Marks a still-empty channel as closed, so a later `send` hands its value back as
    /// `Err`, and releases this receiver's registered waker.
    fn drop(&mut self) {
        let closed_by_us = self
            .inner
            .state
            .compare_exchange(EMPTY, CLOSED, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok();
        if closed_by_us {
            // The state was EMPTY, so the sender is still alive and holds its own `Arc`.
            // Without this, our waker (and the task it references) would stay alive until
            // that sender is dropped. In every other state the sender side has taken, or
            // will take, the waker itself. Poisoning is tolerated so `Drop` never panics.
            let stale = self
                .inner
                .waker
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .take();
            // Dropped after the guard (released at the end of the statement above).
            drop(stale);
        }
    }
}
#[cfg(test)]
mod tests {
    // Behavioural tests for the oneshot channel. They rely on the crate's own executor
    // (`block_on`, `block_on_with_spawn`, `spawn`) to drive the futures.
    use super::*;
    use crate::executor::{block_on, block_on_with_spawn, spawn};

    #[test]
    fn send_then_recv() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            tx.send(42).unwrap();
            rx.await
        });
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn recv_then_send_via_spawn() {
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<String>();
            let jh = spawn(async move {
                tx.send("hello".to_string()).unwrap();
            });
            let val = rx.await.unwrap();
            jh.await.unwrap();
            val
        });
        assert_eq!(result, "hello");
    }

    #[test]
    fn sender_drop_closes_channel() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            drop(tx);
            rx.await
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn send_after_receiver_drop_returns_err() {
        let (tx, rx) = channel::<u32>();
        drop(rx);
        assert!(tx.send(1).is_err());
    }

    #[test]
    fn value_types_roundtrip() {
        block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(vec![1, 2, 3]).unwrap();
            assert_eq!(rx.await.unwrap(), vec![1, 2, 3]);
        });
    }

    #[test]
    fn oneshot_send_string() {
        let result = block_on(async {
            let (tx, rx) = channel::<String>();
            tx.send("world".to_string()).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), "world");
    }

    #[test]
    fn oneshot_send_struct() {
        #[derive(Debug, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }
        let result = block_on(async {
            let (tx, rx) = channel::<Point>();
            tx.send(Point { x: 1, y: 2 }).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), Point { x: 1, y: 2 });
    }

    #[test]
    fn oneshot_send_vec() {
        let result = block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(vec![1, 2, 3, 4, 5]).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn oneshot_multiple_pairs_concurrent() {
        block_on_with_spawn(async {
            let mut rxs = Vec::new();
            for i in 0u32..10 {
                let (tx, rx) = channel::<u32>();
                spawn(async move {
                    tx.send(i).unwrap();
                });
                rxs.push(rx);
            }
            let mut results: Vec<u32> = Vec::new();
            for rx in rxs {
                results.push(rx.await.unwrap());
            }
            results.sort();
            assert_eq!(results, (0..10).collect::<Vec<_>>());
        });
    }

    #[test]
    fn oneshot_recv_error_display() {
        let err = RecvError::Closed;
        let s = format!("{err}");
        // Pin the exact user-facing message rather than a loose substring check.
        assert_eq!(s, "oneshot channel closed without a value");
    }

    #[test]
    fn oneshot_send_returns_err_when_rx_dropped() {
        let (tx, rx) = channel::<i32>();
        drop(rx);
        let result = tx.send(42);
        assert_eq!(result, Err(42));
    }

    #[test]
    fn oneshot_send_value_then_recv_in_separate_block_on() {
        let (tx, rx) = channel::<u64>();
        tx.send(12345).unwrap();
        let val = block_on(async { rx.await.unwrap() });
        assert_eq!(val, 12345);
    }

    #[test]
    fn oneshot_sender_drop_closes_from_spawn() {
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<u32>();
            let jh = spawn(async move {
                drop(tx);
            });
            jh.await.unwrap();
            rx.await
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn oneshot_recv_error_is_error_trait() {
        let err = RecvError::Closed;
        let _e: &dyn std::error::Error = &err;
    }

    #[test]
    fn oneshot_u8_roundtrip() {
        let result = block_on(async {
            let (tx, rx) = channel::<u8>();
            tx.send(255).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, 255);
    }

    #[test]
    fn oneshot_bool_roundtrip() {
        let result = block_on(async {
            let (tx, rx) = channel::<bool>();
            tx.send(true).unwrap();
            rx.await.unwrap()
        });
        assert!(result);
    }

    #[test]
    fn oneshot_unit_roundtrip() {
        // `unwrap()` panics unless a `()` actually arrived; comparing unit values would
        // always succeed (and trips the deny-by-default `clippy::unit_cmp`).
        block_on(async {
            let (tx, rx) = channel::<()>();
            tx.send(()).unwrap();
            rx.await.unwrap();
        });
    }

    #[test]
    fn oneshot_10_pairs_in_parallel() {
        block_on_with_spawn(async {
            let mut rxs = Vec::new();
            for i in 0..10u32 {
                let (tx, rx) = channel::<u32>();
                let val = i * 3;
                spawn(async move { tx.send(val).unwrap() });
                rxs.push((i, rx));
            }
            for (i, rx) in rxs {
                let v = rx.await.unwrap();
                assert_eq!(v, i * 3);
            }
        });
    }

    #[test]
    fn oneshot_send_before_poll_synchronous() {
        let (tx, rx) = channel::<u32>();
        tx.send(777).unwrap();
        let v = block_on(async { rx.await.unwrap() });
        assert_eq!(v, 777);
    }

    #[test]
    fn oneshot_send_i64_max() {
        let result = block_on(async {
            let (tx, rx) = channel::<i64>();
            tx.send(i64::MAX).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, i64::MAX);
    }

    #[test]
    fn oneshot_send_i64_min() {
        let result = block_on(async {
            let (tx, rx) = channel::<i64>();
            tx.send(i64::MIN).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, i64::MIN);
    }

    #[test]
    fn oneshot_send_empty_vec() {
        let result = block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(Vec::new()).unwrap();
            rx.await.unwrap()
        });
        assert!(result.is_empty());
    }

    #[test]
    fn oneshot_two_separate_channels_independent() {
        block_on(async {
            let (tx1, rx1) = channel::<u32>();
            let (tx2, rx2) = channel::<u32>();
            tx1.send(1).unwrap();
            tx2.send(2).unwrap();
            assert_eq!(rx1.await.unwrap(), 1);
            assert_eq!(rx2.await.unwrap(), 2);
        });
    }

    #[test]
    fn oneshot_unreceived_value_dropped_exactly_once() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        struct Counted;
        impl Drop for Counted {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let (tx, rx) = channel::<Counted>();
        assert!(tx.send(Counted).is_ok());
        // The value sits in the channel until the receiver goes away.
        assert_eq!(DROPS.load(Ordering::SeqCst), 0);
        drop(rx);
        assert_eq!(DROPS.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn oneshot_rejected_value_returned_not_dropped() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static DROPS: AtomicUsize = AtomicUsize::new(0);
        struct Counted;
        impl Drop for Counted {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let (tx, rx) = channel::<Counted>();
        drop(rx);
        let rejected = tx.send(Counted);
        assert!(rejected.is_err());
        // The value was handed back to the caller, not dropped inside the channel.
        assert_eq!(DROPS.load(Ordering::SeqCst), 0);
        drop(rejected);
        assert_eq!(DROPS.load(Ordering::SeqCst), 1);
    }
}