use alloc::boxed::Box;
use pin_project::pin_project;
use super::ring::Ring;
use crate::sync::{Closure, Discriminant, Flags, Signal, Waiter, Waiters, WaitersExt};
use core::{
pin::{Pin, pin},
sync::atomic::Ordering::*,
task::{Context, Poll},
};
/// Backing storage for a bidirectional channel: one [`Half`] carrying `A`
/// values and one carrying `B` values.
///
/// `duplex` boxes a single `Duplex` and hands both endpoints raw pointers
/// into its two fields.
pub struct Duplex<A, B>(Half<A>, Half<B>);
/// One direction of a duplex channel: a buffer of `T` values plus the
/// bookkeeping needed to park and wake tasks on either side of it.
pub struct Half<T> {
    /// Buffered values; `try_send` pushes into it and `try_recv` pops from it.
    ring: Ring<T>,
    /// Waiters registered by senders that found the ring full; one is woken
    /// after each successful `try_recv`.
    senders: Waiters,
    /// Waiters registered by receivers that found the ring empty; one is
    /// woken after each successful `try_send`.
    receivers: Waiters,
    /// Carries the `Closure::Closed` flag once `close` has been called.
    is_closed: Flags<Closure>,
}
/// One side of a duplex channel: sends `S` values through one [`Half`] and
/// receives `R` values from the other.
///
/// Both fields are raw pointers, so the compiler enforces no lifetime or
/// ownership relationship between an endpoint and the halves it points at.
pub struct Endpoint<S, R> {
    /// Half this endpoint pushes outgoing values into.
    send_half: *const Half<S>,
    /// Half this endpoint pops incoming values from.
    recv_half: *const Half<R>,
}

// NOTE(review): raw pointers are neither `Send` nor `Sync`, so these impls are
// a manual promise that the pointed-to halves may be used from several threads
// at once through `&self`. That presumably relies on `Ring`, `Waiters` and
// `Flags` being safe for concurrent shared access — confirm against their
// definitions.
unsafe impl<S: Send, R: Send> Send for Endpoint<S, R> {}
unsafe impl<S: Send, R: Send> Sync for Endpoint<S, R> {}
impl<T> Half<T> {
    /// Creates an open half whose ring is built with `Ring::new(capacity)`
    /// and whose waiter queues start out empty.
    fn new(capacity: usize) -> Self {
        let ring = Ring::new(capacity);
        let senders = Waiters::new();
        let receivers = Waiters::new();
        let is_closed = Flags::default();
        Self {
            ring,
            senders,
            receivers,
            is_closed,
        }
    }

    /// Attempts to push `value` without waiting.
    ///
    /// On success one queued receiver (if any) is woken.
    ///
    /// # Errors
    ///
    /// * [`TrySendError::Disconnected`] if the closed flag is already set;
    ///   the ring is not touched in that case.
    /// * [`TrySendError::Full`] if the ring rejects the push.
    ///
    /// Both variants hand `value` back to the caller.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body performs no unsafe operation itself; the
    /// contract callers must uphold is not stated in this part of the file.
    unsafe fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
        let closed = self.is_closed.is_set(&Closure::Closed);
        if closed {
            return Err(TrySendError::Disconnected(value));
        }
        if let Err(rejected) = self.ring.push(value) {
            return Err(TrySendError::Full(rejected));
        }
        // A value just became available: hand it to one parked receiver.
        if let Some(receiver) = self.receivers.dequeue() {
            receiver.wake_by_ref();
        }
        Ok(())
    }

    /// Attempts to pop a value without waiting.
    ///
    /// Buffered values are still delivered after the half has been closed;
    /// the closed flag is only consulted once the ring yields nothing. On
    /// success one queued sender (if any) is woken.
    ///
    /// # Errors
    ///
    /// * [`TryRecvError::Disconnected`] if nothing was popped and the closed
    ///   flag is set.
    /// * [`TryRecvError::Empty`] if nothing was popped and the half is open.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body performs no unsafe operation itself; the
    /// contract callers must uphold is not stated in this part of the file.
    unsafe fn try_recv(&self) -> Result<T, TryRecvError> {
        match self.ring.pop() {
            Some(value) => {
                // A slot just became free: let one parked sender retry.
                if let Some(sender) = self.senders.dequeue() {
                    sender.wake_by_ref();
                }
                Ok(value)
            }
            None if self.is_closed.is_set(&Closure::Closed) => Err(TryRecvError::Disconnected),
            None => Err(TryRecvError::Empty),
        }
    }

    /// Sets the closed flag, then notifies every queued receiver followed by
    /// every queued sender so they can observe it.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body performs no unsafe operation itself; the
    /// contract callers must uphold is not stated in this part of the file.
    unsafe fn close(&self) {
        self.is_closed.set(&Closure::Closed);
        for queue in [&self.receivers, &self.senders] {
            queue.notify_all();
        }
    }
}
impl<S: Send + Unpin, R: Send> Endpoint<S, R> {
    /// Tries to send `value` on the outgoing half without waiting.
    ///
    /// # Errors
    ///
    /// Forwards the [`TrySendError`] from the half; both variants return
    /// `value` to the caller.
    pub fn try_send(&self, value: S) -> Result<(), TrySendError<S>> {
        // SAFETY: relies on `send_half` pointing at a live `Half<S>`.
        // NOTE(review): nothing in this section frees that allocation — confirm.
        let outgoing = unsafe { &*self.send_half };
        unsafe { outgoing.try_send(value) }
    }

    /// Tries to take a value from the incoming half without waiting.
    ///
    /// # Errors
    ///
    /// Forwards the [`TryRecvError`] from the half.
    pub fn try_recv(&self) -> Result<R, TryRecvError> {
        // SAFETY: relies on `recv_half` pointing at a live `Half<R>`.
        let incoming = unsafe { &*self.recv_half };
        unsafe { incoming.try_recv() }
    }

    /// Sends `value`, waiting while the outgoing half is full.
    ///
    /// # Errors
    ///
    /// Returns `Err(value)` if the outgoing half is closed.
    pub async fn send(&self, value: S) -> Result<(), S> {
        let half = self.send_half;
        let pending = SendFuture {
            half,
            value: Some(value),
            waiter: None,
        };
        pending.await
    }

    /// Receives the next value, waiting while the incoming half is empty.
    ///
    /// Resolves to `None` once the incoming half is closed and has no
    /// buffered value left.
    pub async fn recv(&self) -> Option<R> {
        let half = self.recv_half;
        let pending = RecvFuture { half };
        pending.await
    }

    /// Closes both directions: first the outgoing half, then the incoming one.
    pub fn close(&self) {
        // SAFETY: relies on both pointers referring to live halves.
        let (outgoing, incoming) = unsafe { (&*self.send_half, &*self.recv_half) };
        unsafe {
            outgoing.close();
            incoming.close();
        }
    }
}
/// Future that pushes one value into a [`Half`], resolving once the value is
/// accepted or the half is found closed.
#[pin_project]
struct SendFuture<T> {
    /// Half the value is sent into; a raw pointer, so no lifetime is tracked.
    half: *const Half<T>,
    /// The value still waiting to be sent; taken out while a poll attempts
    /// the send.
    #[pin]
    value: Option<T>,
    /// Slot for a waiter associated with this future; constructed as `None`.
    #[pin]
    waiter: Option<Waiter>,
}
impl<T: Send> Half<T> {
    /// Builds a future that sends `value` into this half.
    ///
    /// NOTE(review): the future stores `self` as a raw pointer, so the
    /// borrow checker does not stop it from outliving this `Half`; callers
    /// presumably must keep the half alive until the future is dropped —
    /// confirm.
    pub fn send(&self, value: T) -> SendFuture<T> {
        let half: *const Self = self;
        SendFuture {
            half,
            value: Some(value),
            waiter: None,
        }
    }
}
impl<T: Send + Unpin> Future for SendFuture<T> {
    type Output = Result<(), T>;

    /// Tries to push the pending value into the half.
    ///
    /// * `Ready(Ok(()))` — the value was accepted.
    /// * `Ready(Err(value))` — the half is closed; the value is handed back.
    /// * `Pending` — the half is full; a waiter built from `cx`'s waker has
    ///   been queued on the half's sender list and the value has been put
    ///   back into the future so the next poll can retry.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Ready`, or if the
    /// stored half pointer is null.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        // Move the value out for the attempt; every `Pending` path below
        // must put it back.
        let Some(value) = this.value.as_mut().get_mut().take() else {
            panic!("SendFuture polled after completion");
        };

        // SAFETY: relies on `half` pointing at a `Half<T>` that is still
        // alive. NOTE(review): nothing in this section frees the allocation
        // made by `duplex` — confirm for futures built via `Half::send`.
        let half = unsafe { (*this.half).as_ref() }.expect("SendFuture holds a null Half pointer");

        let value = match unsafe { half.try_send(value) } {
            Ok(()) => return Poll::Ready(Ok(())),
            Err(TrySendError::Disconnected(value)) => return Poll::Ready(Err(value)),
            Err(TrySendError::Full(value)) => value,
        };

        // Register for a wake-up. A fresh waiter is queued by value on every
        // pending poll; the `waiter` slot of the future is left untouched.
        half.senders
            .enqueue(Waiter::from_waker(cx.waker().clone()));

        // Retry once now that the waiter is queued: a receiver may have freed
        // a slot (or the half may have been closed) between the failed
        // attempt and the registration, in which case nobody would wake us.
        // If the retry succeeds the queued waiter is merely stale and causes
        // at most one spurious wake.
        match unsafe { half.try_send(value) } {
            Ok(()) => Poll::Ready(Ok(())),
            Err(TrySendError::Disconnected(value)) => Poll::Ready(Err(value)),
            Err(TrySendError::Full(value)) => {
                // Write the value back into the future's own storage.
                // Assigning a `pin!`-ed temporary to the projection would
                // only rebind the local and drop the value.
                this.value.set(Some(value));
                Poll::Pending
            }
        }
    }
}
/// Future that pops one value from a [`Half`], resolving to `None` once the
/// half is closed and nothing is buffered.
struct RecvFuture<T> {
    /// Half the value is received from; a raw pointer, so no lifetime is
    /// tracked.
    half: *const Half<T>,
}
impl<T: Send> Future for RecvFuture<T> {
    type Output = Option<T>;

    /// Tries to pop a value from the half.
    ///
    /// * `Ready(Some(value))` — a value was available.
    /// * `Ready(None)` — the half is closed and nothing is buffered.
    /// * `Pending` — the half is empty but open; a waiter built from `cx`'s
    ///   waker has been queued on the half's receiver list.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: relies on `half` pointing at a `Half<T>` that is still
        // alive. NOTE(review): nothing in this section frees the allocation
        // made by `duplex` — confirm.
        let half = unsafe { &*self.half };

        match unsafe { half.try_recv() } {
            Ok(value) => return Poll::Ready(Some(value)),
            Err(TryRecvError::Disconnected) => return Poll::Ready(None),
            Err(TryRecvError::Empty) => {}
        }

        // Register for a wake-up before declaring the half empty.
        half.receivers
            .enqueue(Waiter::from_waker(cx.waker().clone()));

        // Check again now that the waiter is queued: a send or a close that
        // landed between the first check and the registration found no
        // waiter to wake, so without this retry the task could sleep
        // forever. If the retry succeeds the queued waiter is merely stale
        // and causes at most one spurious wake.
        match unsafe { half.try_recv() } {
            Ok(value) => Poll::Ready(Some(value)),
            Err(TryRecvError::Disconnected) => Poll::Ready(None),
            Err(TryRecvError::Empty) => Poll::Pending,
        }
    }
}
/// Reason a non-blocking send was rejected; both variants give the value
/// back to the caller.
#[derive(Debug)]
pub enum TrySendError<T> {
    /// The ring refused the value because it had no free slot.
    Full(T),
    /// The half was already closed when the send was attempted.
    Disconnected(T),
}
/// Reason a non-blocking receive produced no value.
#[derive(Debug)]
pub enum TryRecvError {
    /// Nothing is buffered, but the half is still open.
    Empty,
    /// Nothing is buffered and the half has been closed.
    Disconnected,
}
pub fn duplex<A: Send, B: Send>(capacity: usize) -> (Endpoint<A, B>, Endpoint<B, A>) {
let duplex = Box::new(Duplex(Half::new(capacity), Half::new(capacity)));
let ptr = Box::into_raw(duplex);
let a = Endpoint {
send_half: unsafe { &(*ptr).0 },
recv_half: unsafe { &(*ptr).1 },
};
let b = Endpoint {
send_half: unsafe { &(*ptr).1 },
recv_half: unsafe { &(*ptr).0 },
};
(a, b)
}