use std::cell::RefCell;
use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use super::CURRENT_TASK_ID;
use super::io::try_with_state;
/// Asks the executor to wake the task whose id is stored in `waiter`.
///
/// A `None` waiter means nobody is parked, so nothing happens. Otherwise the
/// wake request is routed through `try_with_state`, and whatever that call
/// returns is discarded.
fn wake_waiter(waiter: Option<u32>) {
    let Some(task_id) = waiter else {
        return;
    };
    try_with_state(|_driver, executor| {
        executor.wake_task(task_id);
    });
}
/// Error produced by awaiting a `oneshot::Receiver` when no value is available
/// and the sending side can no longer provide one.
///
/// This is a zero-sized marker, so it derives `Copy` and `Hash` alongside the
/// other common traits (Rust API guidelines, C-COMMON-TRAITS).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RecvError;

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

impl std::error::Error for RecvError {}
/// Error returned by an asynchronous send when the receiving side has been
/// dropped.
///
/// The value that could not be delivered is handed back in field `0`, so the
/// caller keeps ownership of it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Deliberately does not include the payload: `T` need not be Display.
        write!(f, "channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for SendError<T> {}
/// Error returned by the non-blocking `try_recv` methods.
///
/// This is a fieldless enum, so it derives `Copy` and `Hash` alongside the other
/// common traits (Rust API guidelines, C-COMMON-TRAITS).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TryRecvError {
    /// Nothing is available right now, but a sender may still produce a value.
    Empty,
    /// Nothing is available and no sender remains to produce one.
    Disconnected,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TryRecvError::Empty => f.write_str("channel empty"),
            TryRecvError::Disconnected => f.write_str("channel disconnected"),
        }
    }
}

impl std::error::Error for TryRecvError {}
/// Error returned by a non-blocking `try_send`.
///
/// Both variants give the rejected value back to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// The buffer is at capacity; retry after the receiver drains a message.
    Full(T),
    /// The receiving side is gone; the value can never be delivered.
    Disconnected(T),
}

impl<T> fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Full(_) => "channel full",
            Self::Disconnected(_) => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl<T: fmt::Debug> std::error::Error for TrySendError<T> {}
/// Single-use channel: at most one value travels from a [`Sender`] to a
/// [`Receiver`].
///
/// Both halves share one `Rc<RefCell<State>>`, so neither half is `Send`. The
/// receiving task is identified through `CURRENT_TASK_ID` (the `Context`
/// waker is not used) and is woken via `wake_waiter`.
pub mod oneshot {
    use super::*;

    /// State shared by the two halves of a oneshot channel.
    struct State<T> {
        /// The sent value, held until the receiver takes it.
        value: Option<T>,
        /// Id of the task parked in `Receiver::poll`, if any.
        recv_waiter: Option<u32>,
        /// Set as soon as the `Sender` is gone, whether it sent or not. No value
        /// can arrive after this, so an empty `value` means "disconnected".
        closed: bool,
    }

    /// Creates a connected `(Sender, Receiver)` pair.
    pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
        let state = Rc::new(RefCell::new(State {
            value: None,
            recv_waiter: None,
            closed: false,
        }));
        (
            Sender {
                state: Rc::clone(&state),
            },
            Receiver { state },
        )
    }

    /// Sending half of a oneshot channel.
    pub struct Sender<T> {
        state: Rc<RefCell<State<T>>>,
    }

    impl<T> Sender<T> {
        /// Delivers `value` to the receiver, consuming the sender.
        ///
        /// # Errors
        ///
        /// Returns `Err(value)` if the [`Receiver`] has already been dropped.
        pub fn send(self, value: T) -> Result<(), T> {
            // Once the receiver is gone, this sender holds the only `Rc`.
            if Rc::strong_count(&self.state) == 1 {
                return Err(value);
            }
            self.state.borrow_mut().value = Some(value);
            // `self` is dropped on return; `Drop` marks the channel closed and
            // wakes the parked receiver, if any.
            Ok(())
        }
    }

    impl<T> Drop for Sender<T> {
        fn drop(&mut self) {
            // Close unconditionally, even after a successful send. A value that
            // is already stored is still delivered first, because the receiver
            // checks `value` before `closed`. After it has been taken, the
            // receiver reports disconnection instead of waiting forever for a
            // second value that can never come.
            let waiter = {
                let mut s = self.state.borrow_mut();
                s.closed = true;
                s.recv_waiter.take()
            };
            // Release the `RefCell` borrow before calling into the executor.
            wake_waiter(waiter);
        }
    }

    /// Receiving half of a oneshot channel; `.await` it to obtain the value.
    pub struct Receiver<T> {
        state: Rc<RefCell<State<T>>>,
    }

    impl<T> Receiver<T> {
        /// Takes the value if it has arrived, without waiting.
        ///
        /// # Errors
        ///
        /// * [`TryRecvError::Empty`]: the sender is alive and has not sent yet.
        /// * [`TryRecvError::Disconnected`]: the sender is gone and no value is
        ///   pending. Either it was dropped unsent, or the value was already
        ///   received.
        pub fn try_recv(&self) -> Result<T, TryRecvError> {
            let mut s = self.state.borrow_mut();
            if let Some(value) = s.value.take() {
                return Ok(value);
            }
            if s.closed {
                Err(TryRecvError::Disconnected)
            } else {
                Err(TryRecvError::Empty)
            }
        }
    }

    impl<T> Future for Receiver<T> {
        type Output = Result<T, RecvError>;

        /// Resolves to `Ok(value)` once sent, or `Err(RecvError)` once the
        /// sender is gone with nothing pending. Otherwise it parks the current
        /// task, read from `CURRENT_TASK_ID`, until the sender sends or is
        /// dropped.
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            let mut s = self.state.borrow_mut();
            if let Some(value) = s.value.take() {
                return Poll::Ready(Ok(value));
            }
            if s.closed {
                return Poll::Ready(Err(RecvError));
            }
            s.recv_waiter = Some(CURRENT_TASK_ID.with(|c| c.get()));
            Poll::Pending
        }
    }
}
/// Bounded multi-producer, single-consumer channel for tasks on one thread.
///
/// All handles share one `Rc<RefCell<State>>`, so no handle is `Send`. Blocked
/// tasks are identified through `CURRENT_TASK_ID` (the `Context` waker is not
/// used) and woken via `wake_waiter`.
pub mod mpsc {
    use super::*;

    /// State shared by every handle of one channel.
    struct State<T> {
        /// Buffered messages, oldest first; never longer than `capacity`.
        queue: VecDeque<T>,
        /// Maximum number of buffered messages.
        capacity: usize,
        /// Task parked in `RecvFuture::poll`, if any.
        recv_waiter: Option<u32>,
        /// Every task parked in `SendFuture::poll` on a full queue. Cloned
        /// senders can block in several tasks at once, so a single slot would
        /// overwrite (and permanently lose) all but the last registration.
        send_waiters: Vec<u32>,
        /// Number of live `Sender` handles.
        sender_count: usize,
        /// Cleared by `Receiver::drop`; senders then fail with "disconnected".
        receiver_alive: bool,
    }

    /// Wakes every task id in `waiters`.
    ///
    /// NOTE(review): ids may belong to `SendFuture`s that were dropped while
    /// pending. That only causes a spurious wake. The original single-slot
    /// `recv_waiter` could already hold such stale ids, so `wake_task` is
    /// assumed to tolerate them. Confirm against the executor.
    fn wake_all(waiters: Vec<u32>) {
        for task_id in waiters {
            wake_waiter(Some(task_id));
        }
    }

    /// Non-blocking push shared by `Sender::try_send` and `SendFuture::poll`.
    fn try_push<T>(state: &RefCell<State<T>>, value: T) -> Result<(), TrySendError<T>> {
        let mut s = state.borrow_mut();
        if !s.receiver_alive {
            return Err(TrySendError::Disconnected(value));
        }
        if s.queue.len() >= s.capacity {
            return Err(TrySendError::Full(value));
        }
        s.queue.push_back(value);
        let waiter = s.recv_waiter.take();
        drop(s);
        wake_waiter(waiter);
        Ok(())
    }

    /// Non-blocking pop shared by `Receiver::try_recv` and `RecvFuture::poll`.
    ///
    /// Freeing a slot wakes *all* blocked senders. Whichever re-polls first
    /// takes the slot, and the rest re-register. This stays correct even if
    /// one of the registered tasks has since abandoned its send.
    fn try_pop<T>(state: &RefCell<State<T>>) -> Result<T, TryRecvError> {
        let mut s = state.borrow_mut();
        if let Some(value) = s.queue.pop_front() {
            let waiters = std::mem::take(&mut s.send_waiters);
            drop(s);
            wake_all(waiters);
            return Ok(value);
        }
        if s.sender_count == 0 {
            Err(TryRecvError::Disconnected)
        } else {
            Err(TryRecvError::Empty)
        }
    }

    /// Creates a channel that buffers at most `capacity` messages.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
        assert!(capacity > 0, "mpsc channel capacity must be > 0");
        let state = Rc::new(RefCell::new(State {
            // Grow lazily. `capacity` is only an upper bound, and pre-allocating
            // it would eagerly reserve large bounds. Huge bounds such as
            // `usize::MAX` would even panic with a capacity overflow.
            queue: VecDeque::new(),
            capacity,
            recv_waiter: None,
            send_waiters: Vec::new(),
            sender_count: 1,
            receiver_alive: true,
        }));
        (
            Sender {
                state: Rc::clone(&state),
            },
            Receiver { state },
        )
    }

    /// Sending half; clone it to get more producers.
    pub struct Sender<T> {
        state: Rc<RefCell<State<T>>>,
    }

    impl<T> Sender<T> {
        /// Enqueues `value` without waiting.
        ///
        /// # Errors
        ///
        /// * [`TrySendError::Full`]: the buffer is at capacity.
        /// * [`TrySendError::Disconnected`]: the receiver has been dropped.
        ///
        /// Both variants hand `value` back.
        pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
            try_push(&self.state, value)
        }

        /// Returns a future that enqueues `value`, waiting for buffer space.
        ///
        /// It resolves to `Err(SendError(value))` if the receiver is (or
        /// becomes) dropped before the value is queued.
        pub fn send(&self, value: T) -> SendFuture<'_, T> {
            SendFuture {
                state: &self.state,
                value: Some(value),
            }
        }
    }

    impl<T> Clone for Sender<T> {
        fn clone(&self) -> Self {
            self.state.borrow_mut().sender_count += 1;
            Sender {
                state: Rc::clone(&self.state),
            }
        }
    }

    impl<T> Drop for Sender<T> {
        fn drop(&mut self) {
            let mut s = self.state.borrow_mut();
            s.sender_count -= 1;
            if s.sender_count == 0 {
                // Last producer gone: let a parked receiver observe disconnection.
                let waiter = s.recv_waiter.take();
                drop(s);
                wake_waiter(waiter);
            }
        }
    }

    /// Future returned by [`Sender::send`].
    pub struct SendFuture<'a, T> {
        state: &'a Rc<RefCell<State<T>>>,
        /// The value still to be sent; `None` once the future has completed.
        value: Option<T>,
    }

    // `value` is only ever moved out by value, never pinned, so the future is
    // `Unpin` regardless of `T`.
    impl<T> Unpin for SendFuture<'_, T> {}

    impl<T> Future for SendFuture<'_, T> {
        type Output = Result<(), SendError<T>>;

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();
            let value = this
                .value
                .take()
                .expect("SendFuture polled after completion");
            match try_push(this.state, value) {
                Ok(()) => Poll::Ready(Ok(())),
                Err(TrySendError::Disconnected(value)) => Poll::Ready(Err(SendError(value))),
                Err(TrySendError::Full(value)) => {
                    // Keep the value for the next poll and park this task
                    // (once, even if it is polled repeatedly while blocked).
                    this.value = Some(value);
                    let task_id = CURRENT_TASK_ID.with(|c| c.get());
                    let mut s = this.state.borrow_mut();
                    if !s.send_waiters.contains(&task_id) {
                        s.send_waiters.push(task_id);
                    }
                    Poll::Pending
                }
            }
        }
    }

    /// Receiving half of the channel.
    pub struct Receiver<T> {
        state: Rc<RefCell<State<T>>>,
    }

    impl<T> Receiver<T> {
        /// Dequeues the oldest message without waiting.
        ///
        /// # Errors
        ///
        /// * [`TryRecvError::Empty`]: no message is buffered, but a sender is alive.
        /// * [`TryRecvError::Disconnected`]: no message is buffered and every
        ///   sender has been dropped.
        pub fn try_recv(&self) -> Result<T, TryRecvError> {
            try_pop(&self.state)
        }

        /// Returns a future for the next message.
        ///
        /// It resolves to `None` once every sender is gone and the buffer has
        /// been drained.
        pub fn recv(&self) -> RecvFuture<'_, T> {
            RecvFuture { state: &self.state }
        }
    }

    impl<T> Drop for Receiver<T> {
        fn drop(&mut self) {
            // Mark the channel disconnected before waking anyone. Every sender
            // parked on a full queue must then re-poll and observe the
            // disconnect; without this wake it would stay pending forever.
            let (buffered, waiters) = {
                let mut s = self.state.borrow_mut();
                s.receiver_alive = false;
                (
                    std::mem::take(&mut s.queue),
                    std::mem::take(&mut s.send_waiters),
                )
            };
            // Undeliverable messages are released now rather than when the last
            // sender goes away. They are dropped outside the `RefCell` borrow.
            drop(buffered);
            wake_all(waiters);
        }
    }

    /// Future returned by [`Receiver::recv`].
    pub struct RecvFuture<'a, T> {
        state: &'a Rc<RefCell<State<T>>>,
    }

    impl<T> Future for RecvFuture<'_, T> {
        type Output = Option<T>;

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            match try_pop(self.state) {
                Ok(value) => Poll::Ready(Some(value)),
                Err(TryRecvError::Disconnected) => Poll::Ready(None),
                Err(TryRecvError::Empty) => {
                    self.state.borrow_mut().recv_waiter =
                        Some(CURRENT_TASK_ID.with(|c| c.get()));
                    Poll::Pending
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Synchronous (non-awaiting) checks of both channel flavours.

    use super::*;

    #[test]
    fn oneshot_send_before_recv() {
        let (sender, receiver) = oneshot::channel::<i32>();
        let sent = sender.send(42);
        assert!(sent.is_ok());
        assert_eq!(receiver.try_recv(), Ok(42));
    }

    #[test]
    fn oneshot_sender_dropped() {
        let (sender, receiver) = oneshot::channel::<i32>();
        drop(sender);
        assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn oneshot_receiver_dropped() {
        let (sender, receiver) = oneshot::channel::<i32>();
        drop(receiver);
        assert_eq!(sender.send(42), Err(42));
    }

    #[test]
    fn mpsc_send_and_recv() {
        let (sender, receiver) = mpsc::channel::<i32>(4);
        for value in 1..=3 {
            sender.try_send(value).unwrap();
        }
        // Messages come out in FIFO order.
        for expected in 1..=3 {
            assert_eq!(receiver.try_recv(), Ok(expected));
        }
        assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
    }

    #[test]
    fn mpsc_full_channel() {
        let (sender, _receiver) = mpsc::channel::<i32>(2);
        sender.try_send(1).unwrap();
        sender.try_send(2).unwrap();
        let result = sender.try_send(3);
        assert!(
            matches!(result, Err(TrySendError::Full(3))),
            "expected Full(3), got {result:?}"
        );
    }

    #[test]
    fn mpsc_sender_clone_and_drop() {
        let (first, receiver) = mpsc::channel::<i32>(4);
        let second = first.clone();
        first.try_send(1).unwrap();
        second.try_send(2).unwrap();

        // One producer remaining keeps the channel connected.
        drop(first);
        assert_eq!(receiver.try_recv(), Ok(1));
        assert_eq!(receiver.try_recv(), Ok(2));
        assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));

        // Dropping the last producer disconnects it.
        drop(second);
        assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn mpsc_receiver_dropped() {
        let (sender, receiver) = mpsc::channel::<i32>(4);
        drop(receiver);
        let result = sender.try_send(1);
        assert!(
            matches!(result, Err(TrySendError::Disconnected(1))),
            "expected Disconnected(1), got {result:?}"
        );
    }

    #[test]
    #[should_panic(expected = "capacity must be > 0")]
    fn mpsc_zero_capacity_panics() {
        let _ = mpsc::channel::<i32>(0);
    }
}