use super::*;
use futures::Future;
use std::{
pin::Pin,
task::{Context, Poll},
};
/// Error returned when a value cannot be sent without waiting.
///
/// Produced by `Sender::try_send` (and the crate-internal `try_poll`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrySendError<T> {
    /// The channel could not accept the value right now; the rejected value
    /// is handed back to the caller so it is not lost.
    Full(T),
    /// The sender holds the only remaining reference to the shared channel
    /// state, i.e. the other side is gone. The value is not returned.
    Closed,
}

impl<T> std::fmt::Display for TrySendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The payload of `Full` is intentionally not printed: `T` is not
        // required to implement `Display`.
        match self {
            TrySendError::Full(_) => f.write_str("sending on a full channel"),
            TrySendError::Closed => f.write_str("sending on a closed channel"),
        }
    }
}

impl<T: std::fmt::Debug> std::error::Error for TrySendError<T> {}
/// The sending half of the channel.
///
/// The shared state sits behind an `Option` apparently so that `Drop` can move
/// the `Arc` out and release it at a precise moment (before waking the
/// reader). It is set to `Some` by `new` and only taken in `Drop`, so every
/// other method may assume it is present.
pub struct Sender<T, const N: usize> {
    inner: Option<Arc<Inner<T, N>>>,
}
impl<T, const N: usize> Drop for Sender<T, N> {
    /// Runs `do_close_batch`, discards this sender's writer waker, and wakes a
    /// parked reader (if any) after the shared state has been released.
    fn drop(&mut self) {
        // `inner` is populated from `new` until this point.
        let shared = self.inner.take().unwrap();
        shared.do_close_batch();

        // A sender that is being dropped will never be polled again, so any
        // waker it registered is stale.
        let _ = shared.writer.waker.take();

        // Grab the reader's waker while the shared state is still reachable,
        // then release our `Arc` *before* waking, presumably so the woken
        // reader already observes that the sender's reference is gone.
        let reader_waker = shared.reader.waker.take();
        drop(shared);

        if let Some(waker) = reader_waker {
            waker.wake();
        }
    }
}
// SAFETY: a `Sender` pushes owned `T` values into shared state that the
// receiving side takes ownership of, potentially on another thread, so moving
// a `Sender` across threads is only sound when `T` itself is `Send`. Without
// this bound, `!Send` types such as `Rc<_>` could be smuggled between threads.
// NOTE(review): this impl additionally relies on `Inner` synchronizing
// `do_send` / `do_close_batch` against the receiver — confirm in `Inner`.
unsafe impl<T: Send, const N: usize> Send for Sender<T, N> {}
impl<T, const N: usize> Sender<T, N> {
pub(crate) fn new(inner: Arc<Inner<T, N>>) -> Self {
Self { inner: Some(inner) }
}
fn inner(&self) -> &Inner<T, N> {
self.inner.as_ref().unwrap()
}
fn strong_count(&self) -> usize {
Arc::strong_count(self.inner.as_ref().unwrap())
}
pub fn try_send(&mut self, value: T) -> Result<bool, TrySendError<T>> {
if self.strong_count() == 1 {
Err(TrySendError::Closed)
} else {
match self.inner().do_send(value) {
Ok(b) => Ok(b),
Err(v) => Err(TrySendError::Full(v)),
}
}
}
pub(crate) fn try_poll(
&mut self,
value: T,
cx: &mut Context<'_>,
) -> Result<bool, TrySendError<T>> {
if self.strong_count() == 1 {
return Err(TrySendError::Closed);
}
match self.inner().do_send(value) {
Ok(b) => Ok(b),
Err(v) => {
self.inner().writer.waker.register(cx.waker());
if self.strong_count() == 1 {
Err(TrySendError::Closed)
} else {
match self.inner().do_send(v) {
Ok(b) => {
self.inner().writer.waker.take();
Ok(b)
}
Err(v) => Err(TrySendError::Full(v)),
}
}
}
}
}
pub fn close_batch(&mut self) {
self.inner().do_close_batch()
}
pub fn send(&mut self, value: T) -> SendFuture<'_, T, N> {
SendFuture {
inner: self.inner.as_ref().unwrap(),
value: Some(value),
}
}
}
/// Future returned by `Sender::send`.
///
/// It resolves once the value has been accepted by the channel, or with
/// `Closed` if the sender turns out to hold the last reference to the shared
/// state. Like every future, it does nothing unless polled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendFuture<'a, T, const N: usize> {
    /// Shared channel state, borrowed from the `Sender` that created the future.
    inner: &'a Arc<Inner<T, N>>,
    /// The value still waiting to be sent. `poll` moves it out on each attempt
    /// and puts it back only when it returns `Poll::Pending`.
    value: Option<T>,
}
// SAFETY: the future owns an `Option<T>` that may be handed to the receiving
// side, so it may only move across threads when `T: Send`. The borrowed
// `&Arc<Inner>` comes from a `&mut Sender`, so no other handle on this sender
// can use it concurrently. NOTE(review): confirm `Inner` synchronizes the
// writer/reader sides internally.
unsafe impl<'a, T: Send, const N: usize> Send for SendFuture<'a, T, N> {}
// No field of `SendFuture` is structurally pinned: `poll` only ever moves the
// pending value in and out of `value` and never creates a `Pin<&mut T>`. The
// future can therefore be `Unpin` regardless of `T`, which lets `poll` get a
// plain `&mut Self` without `unsafe`.
impl<T, const N: usize> Unpin for SendFuture<'_, T, N> {}

impl<'a, T, const N: usize> Future for SendFuture<'a, T, N> {
    type Output = Result<bool, Closed>;

    /// Tries to send the pending value.
    ///
    /// * `Ready(Err(Closed))` if the sender holds the last reference to the
    ///   shared state (checked before each attempt).
    /// * `Ready(Ok(b))` with the `bool` returned by `Inner::do_send` once the
    ///   value is accepted.
    /// * `Pending` after registering `cx`'s waker when the value is still
    ///   rejected on the post-registration retry.
    ///
    /// Polling again after completion yields `Ready(Ok(false))`, unless the
    /// closed check fires first.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        if Arc::strong_count(this.inner) == 1 {
            return Poll::Ready(Err(Closed));
        }
        let value = match this.value.take() {
            Some(v) => v,
            None => return Poll::Ready(Ok(false)),
        };

        let rejected = match this.inner.do_send(value) {
            Ok(sent) => return Poll::Ready(Ok(sent)),
            Err(v) => v,
        };

        // Register first, then retry once: capacity freed between the failed
        // attempt and the registration would otherwise never wake us.
        this.inner.writer.waker.register(cx.waker());
        if Arc::strong_count(this.inner) == 1 {
            return Poll::Ready(Err(Closed));
        }

        match this.inner.do_send(rejected) {
            Ok(sent) => {
                // Sent after all; the registration is no longer needed.
                this.inner.writer.waker.take();
                Poll::Ready(Ok(sent))
            }
            Err(v) => {
                this.value = Some(v);
                Poll::Pending
            }
        }
    }
}