use super::shared::MpscShared;
use crate::error::{CloseError, SendError, TrySendError};
use crate::mpsc::block_queue::Block;
use core::marker::PhantomPinned;
use parking_lot::Mutex;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
/// Pushes `value` onto the shared queue unless the receiver has gone away.
///
/// On success the value is enqueued, `current_len` is incremented and the
/// consumer is woken. If `receiver_dropped` is already set, nothing is
/// enqueued and the value is handed back untouched in `Err`, leaving the
/// caller to decide whether to return or discard it.
///
/// `cache` is passed straight through to the queue's `push`.
/// NOTE(review): presumably a reusable block slot owned by the calling
/// handle — confirm against `block_queue`.
pub(crate) fn send_internal<T: Send>(
    shared: &Arc<MpscShared<T>>,
    value: T,
    cache: &mut Option<Arc<Block<T>>>,
) -> Result<(), T> {
    let receiver_gone = shared.receiver_dropped.load(Ordering::Acquire);
    if receiver_gone {
        Err(value)
    } else {
        shared.queue.push(value, cache);
        // The length counter is bumped only after the push completes.
        shared.current_len.fetch_add(1, Ordering::Relaxed);
        shared.wake_consumer();
        Ok(())
    }
}
/// Synchronous producer handle for the MPSC channel.
///
/// Open handles are tracked in `shared.sender_count`: cloning a handle
/// increments it, and closing or dropping a handle decrements it once.
#[derive(Debug)]
pub struct Sender<T: Send> {
    /// State shared with the receiver and every other sender.
    pub(crate) shared: Arc<MpscShared<T>>,
    /// Set once this handle has released its `sender_count` slot (through
    /// `close` or `Drop`); prevents releasing it twice and rejects later sends.
    pub(crate) closed: AtomicBool,
    /// Per-handle block cache handed to the queue's `push` on every send.
    /// NOTE(review): exact reuse semantics live in `block_queue` — confirm there.
    pub(crate) cache: Mutex<Option<Arc<Block<T>>>>,
}
/// Asynchronous producer handle for the MPSC channel.
///
/// Shares its bookkeeping with [`Sender`]: open handles are tracked in
/// `shared.sender_count`, incremented on clone and decremented once on
/// close or drop.
#[derive(Debug)]
pub struct AsyncSender<T: Send> {
    /// State shared with the receiver and every other sender.
    pub(crate) shared: Arc<MpscShared<T>>,
    /// Set once this handle has released its `sender_count` slot (through
    /// `close` or `Drop`); prevents releasing it twice and rejects later sends.
    pub(crate) closed: AtomicBool,
    /// Per-handle block cache handed to the queue's `push` on every send.
    /// NOTE(review): exact reuse semantics live in `block_queue` — confirm there.
    pub(crate) cache: Mutex<Option<Arc<Block<T>>>>,
}
impl<T: Send> Sender<T> {
pub fn send(&self, value: T) -> Result<(), SendError> {
if self.closed.load(Ordering::Relaxed) {
return Err(SendError::Closed);
}
let mut cache_guard = self.cache.lock();
send_internal(&self.shared, value, &mut *cache_guard).map_err(|_| SendError::Closed)
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
if self.closed.load(Ordering::Relaxed) {
return Err(TrySendError::Closed(value));
}
let mut cache_guard = self.cache.lock();
send_internal(&self.shared, value, &mut *cache_guard).map_err(TrySendError::Closed)
}
pub fn is_closed(&self) -> bool {
self.shared.receiver_dropped.load(Ordering::Acquire)
}
pub fn close(&self) -> Result<(), CloseError> {
if self
.closed
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.close_internal();
Ok(())
} else {
Err(CloseError)
}
}
fn close_internal(&self) {
if self.shared.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
self.shared.wake_consumer();
}
}
pub fn sender_count(&self) -> usize {
self.shared.sender_count.load(Ordering::Relaxed)
}
pub fn len(&self) -> usize {
self.shared.current_len.load(Ordering::Relaxed)
}
pub fn is_empty(&self) -> bool {
self.shared.queue.is_empty()
}
pub fn to_async(self) -> AsyncSender<T> {
let shared = unsafe { std::ptr::read(&self.shared) };
{
let mut guard = self.cache.lock();
let _cached_block = guard.take();
}
std::mem::forget(self);
AsyncSender {
shared,
closed: AtomicBool::new(false),
cache: Mutex::new(None),
}
}
}
impl<T: Send> AsyncSender<T> {
pub fn send(&self, value: T) -> SendFuture<'_, T> {
SendFuture {
producer: self,
value: Some(value),
_phantom: PhantomPinned,
}
}
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
if self.closed.load(Ordering::Relaxed) {
return Err(TrySendError::Closed(value));
}
let mut cache_guard = self.cache.lock();
send_internal(&self.shared, value, &mut *cache_guard).map_err(TrySendError::Closed)
}
pub fn is_closed(&self) -> bool {
self.shared.receiver_dropped.load(Ordering::Acquire)
}
pub fn close(&self) -> Result<(), CloseError> {
if self
.closed
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.close_internal();
Ok(())
} else {
Err(CloseError)
}
}
fn close_internal(&self) {
if self.shared.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 {
self.shared.wake_consumer();
}
}
pub fn sender_count(&self) -> usize {
self.shared.sender_count.load(Ordering::Relaxed)
}
pub fn len(&self) -> usize {
self.shared.current_len.load(Ordering::Relaxed)
}
pub fn is_empty(&self) -> bool {
self.shared.queue.is_empty()
}
pub fn to_sync(self) -> Sender<T> {
let shared = unsafe { std::ptr::read(&self.shared) };
{
let mut c = self.cache.lock();
*c = None;
}
std::mem::forget(self);
Sender {
shared,
closed: AtomicBool::new(false),
cache: Mutex::new(None),
}
}
}
/// Future returned by [`AsyncSender::send`].
///
/// Holds the value to send until the future is polled. The `PhantomPinned`
/// marker makes this type `!Unpin`, so polling it by hand requires pinning
/// it first (`.await` takes care of that).
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct SendFuture<'a, T: Send> {
    /// The sender handle the value is sent through.
    producer: &'a AsyncSender<T>,
    /// The value to send; taken out on the poll that performs the send.
    value: Option<T>,
    /// Opts the future out of `Unpin`.
    /// NOTE(review): no field here is self-referential, so this marker looks
    /// unnecessary; dropping it would make the future `Unpin` and let `poll`
    /// avoid `get_unchecked_mut`.
    _phantom: PhantomPinned,
}
impl<T: Send> Future for SendFuture<'_, T> {
    type Output = Result<(), SendError>;

    /// Attempts the send on this poll; the future never returns `Pending`.
    ///
    /// If the producer handle is closed, resolves to `Err(SendError::Closed)`
    /// without consuming the stored value. Otherwise the value is taken and
    /// pushed through `send_internal`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after a send was attempted (the stored value
    /// has already been taken).
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: no field of `SendFuture` is structurally pinned — nothing
        // ever produces a `Pin` to `producer` or `value`, and the struct holds
        // no self-references, so moving the `T` out of `value` cannot
        // invalidate anything. Getting `&mut Self` here is therefore sound.
        let fut = unsafe { self.get_unchecked_mut() };
        let sender = fut.producer;

        let outcome = if sender.closed.load(Ordering::Relaxed) {
            Err(SendError::Closed)
        } else {
            let item = fut
                .value
                .take()
                .expect("SendFuture polled after completion");
            let mut slot = sender.cache.lock();
            send_internal(&sender.shared, item, &mut *slot).map_err(|_| SendError::Closed)
        };
        Poll::Ready(outcome)
    }
}
impl<T: Send> Clone for Sender<T> {
    /// Creates another handle to the same channel.
    ///
    /// Increments the shared `sender_count` for the new handle. The clone
    /// always starts open with an empty block cache, even when `self` has
    /// already been closed.
    fn clone(&self) -> Self {
        let shared = &self.shared;
        shared.sender_count.fetch_add(1, Ordering::Relaxed);
        Self {
            shared: Arc::clone(shared),
            closed: AtomicBool::new(false),
            cache: Mutex::new(None),
        }
    }
}
impl<T: Send> Drop for Sender<T> {
    /// Releases this handle's `sender_count` slot unless `close` already did.
    fn drop(&mut self) {
        let was_closed = self.closed.swap(true, Ordering::AcqRel);
        if was_closed {
            return;
        }
        self.close_internal();
    }
}
impl<T: Send> Clone for AsyncSender<T> {
    /// Creates another async handle to the same channel.
    ///
    /// Increments the shared `sender_count` for the new handle. The clone
    /// always starts open with an empty block cache, even when `self` has
    /// already been closed.
    fn clone(&self) -> Self {
        let shared = &self.shared;
        shared.sender_count.fetch_add(1, Ordering::Relaxed);
        Self {
            shared: Arc::clone(shared),
            closed: AtomicBool::new(false),
            cache: Mutex::new(None),
        }
    }
}
impl<T: Send> Drop for AsyncSender<T> {
    /// Releases this handle's `sender_count` slot unless `close` already did.
    fn drop(&mut self) {
        let was_closed = self.closed.swap(true, Ordering::AcqRel);
        if was_closed {
            return;
        }
        self.close_internal();
    }
}