use futures_sink::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem};
use tokio::sync::mpsc::OwnedPermit;
use tokio::sync::mpsc::Sender;
use super::ReusableBoxFuture;
/// Error produced by [`PollSender`] when a reservation or send cannot proceed because the
/// sender is closed.
///
/// The wrapped `Option<T>` is `Some(item)` when the error comes from `send_item`, which hands
/// back the rejected item, and `None` when it comes from `poll_reserve`.
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Consumes the error and returns the value it carries, if any.
    pub fn into_inner(self) -> Option<T> {
        let Self(payload) = self;
        payload
    }
}

impl<T> fmt::Display for PollSendError<T> {
    /// Writes the fixed message `channel closed`; the payload is never included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

// `Error` requires `Debug`, and the derived `Debug` impl is only available when `T: Debug`.
impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
/// Internal state machine driving `PollSender`.
#[derive(Debug)]
enum State<T> {
/// No reservation is in progress; holds the `Sender` that the next `poll_reserve` call hands
/// to the acquire future.
Idle(Sender<T>),
/// An acquire future has been installed in `PollSender::acquire` and has not resolved yet.
Acquiring,
/// A permit is held; `send_item` consumes it to deliver exactly one value.
ReadyToSend(OwnedPermit<T>),
/// The sender was closed or a reservation failed. No method in this file transitions out of
/// this state; it is also the placeholder `take_state` leaves behind while a transition is
/// being computed.
Closed,
}
/// Wrapper around a channel `Sender` whose capacity reservation is driven by polling
/// (`poll_reserve`) followed by a synchronous `send_item`.
#[derive(Debug)]
pub struct PollSender<T> {
/// The wrapper's own handle to the channel, exposed through `get_ref`. Set to `None` by
/// `close`, which is how a close requested while a permit is held gets remembered.
sender: Option<Sender<T>>,
/// Current position in the reserve/send state machine.
state: State<T>,
/// Storage for the in-flight permit-acquisition future; only polled while `state` is
/// `State::Acquiring`.
acquire: PollSenderFuture<T>,
}
/// Builds the future that waits for a send permit on the channel behind `data`.
///
/// With `Some(sender)` the future resolves to the owned permit, or to `PollSendError(None)` when
/// `reserve_owned` reports an error. With `None` it is a placeholder that is only meant to
/// occupy the future slot.
///
/// # Panics
///
/// The returned future panics when polled if `data` is `None`.
async fn make_acquire_future<T>(
    data: Option<Sender<T>>,
) -> Result<OwnedPermit<T>, PollSendError<T>> {
    let sender = match data {
        Some(sender) => sender,
        None => unreachable!("this future should not be pollable in this state"),
    };

    // The underlying error carries no item, so it is mapped to an empty `PollSendError`.
    let reserved = sender.reserve_owned().await;
    reserved.map_err(|_| PollSendError(None))
}
// Shorthand for the reusable boxed future whose output is the outcome of a permit reservation:
// an `OwnedPermit<T>` on success or a `PollSendError<T>` on failure.
type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>;
/// Newtype around the acquire future with its lifetime parameter pinned to `'static`.
///
/// The `impl` blocks below use `unsafe` lifetime casts to store futures in it.
#[derive(Debug)]
struct PollSenderFuture<T>(InnerFuture<'static, T>);
impl<T> PollSenderFuture<T> {
fn empty() -> Self {
Self(ReusableBoxFuture::new(async { unreachable!() }))
}
}
impl<T: Send> PollSenderFuture<T> {
/// Creates the wrapper around the placeholder future `make_acquire_future(None)`.
///
/// That placeholder panics if polled, so `set` has to install a real acquire future before
/// `poll` is called.
fn new() -> Self {
let v = InnerFuture::new(make_acquire_future(None));
// The transmute changes only the lifetime parameter of `InnerFuture`, widening it to `'static`.
// NOTE(review): no safety rationale is given in this file. Presumably this is sound because the
// future was built from `None` and therefore captures no borrowed data — confirm.
Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) })
}
/// Polls the currently stored future by delegating to the inner `ReusableBoxFuture`.
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> {
self.0.poll(cx)
}
/// Replaces the stored future with `make_acquire_future(sender)`.
///
/// Passing `None` installs the panicking placeholder; `PollSender` does this to discard an
/// in-flight acquisition.
fn set(&mut self, sender: Option<Sender<T>>) {
// Re-type the pointer from `InnerFuture<'static, T>` to an anonymous lifetime so a future
// that is not known to be `'static` can be stored in the slot.
let inner: *mut InnerFuture<'static, T> = &mut self.0;
let inner: *mut InnerFuture<'_, T> = inner.cast();
// The pointer was derived from `&mut self.0` just above, so it is non-null, aligned and
// uniquely borrowed for the rest of this call.
// NOTE(review): discarding the `'static` requirement presumably relies on the stored future
// never being moved out of `self`, so it cannot outlive this `PollSenderFuture<T>` — confirm.
let inner = unsafe { &mut *inner };
inner.set(make_acquire_future(sender));
}
}
impl<T: Send> PollSender<T> {
    /// Wraps `sender` so that sending can be driven by polling.
    ///
    /// A clone of `sender` is retained until [`close`](Self::close) is called; the handle
    /// passed in seeds the initial idle state.
    pub fn new(sender: Sender<T>) -> Self {
        let retained = sender.clone();
        Self {
            sender: Some(retained),
            state: State::Idle(sender),
            acquire: PollSenderFuture::new(),
        }
    }

    /// Moves the current state out, leaving `State::Closed` in its place until the caller
    /// stores a replacement.
    fn take_state(&mut self) -> State<T> {
        let mut taken = State::Closed;
        mem::swap(&mut self.state, &mut taken);
        taken
    }

    /// Attempts to reserve capacity for sending one item.
    ///
    /// Returns `Poll::Ready(Ok(()))` once a permit is held (repeated calls keep returning it
    /// until the permit is used or aborted), `Poll::Pending` while the reservation is still
    /// outstanding, and `Poll::Ready(Err(_))` when the reservation fails or the sender is
    /// already closed. Errors produced here never carry an item.
    pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
        loop {
            match self.take_state() {
                State::Idle(sender) => {
                    // Install the acquire future, then go around again to poll it right away.
                    self.acquire.set(Some(sender));
                    self.state = State::Acquiring;
                }
                State::Acquiring => {
                    return match self.acquire.poll(cx) {
                        Poll::Ready(Ok(permit)) => {
                            self.state = State::ReadyToSend(permit);
                            Poll::Ready(Ok(()))
                        }
                        Poll::Ready(Err(err)) => {
                            self.state = State::Closed;
                            Poll::Ready(Err(err))
                        }
                        Poll::Pending => {
                            self.state = State::Acquiring;
                            Poll::Pending
                        }
                    };
                }
                State::Closed => {
                    self.state = State::Closed;
                    return Poll::Ready(Err(PollSendError(None)));
                }
                ready @ State::ReadyToSend(_) => {
                    // A permit is already held; put it back untouched.
                    self.state = ready;
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }

    /// Sends `value` using the permit obtained by a successful [`poll_reserve`](Self::poll_reserve).
    ///
    /// # Errors
    ///
    /// Returns the item back inside the error if the sender is in the closed state.
    ///
    /// # Panics
    ///
    /// Panics if no permit is held and the sender is not closed, i.e. `poll_reserve` has not
    /// returned `Poll::Ready(Ok(()))` since the last send or abort.
    #[track_caller]
    pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
        // `take_state` leaves `State::Closed` behind, so only the transition back to idle needs
        // to be written explicitly.
        match self.take_state() {
            State::ReadyToSend(permit) => {
                // Sending consumes the permit and yields the sender back.
                let sender = permit.send(value);
                // If `close` was called while the permit was held, stay closed instead of
                // returning to idle.
                if self.sender.is_some() {
                    self.state = State::Idle(sender);
                }
                Ok(())
            }
            State::Closed => Err(PollSendError(Some(value))),
            State::Idle(_) | State::Acquiring => {
                panic!("`send_item` called without first calling `poll_reserve`")
            }
        }
    }

    /// Returns `true` if the state machine is closed or [`close`](Self::close) has been called.
    pub fn is_closed(&self) -> bool {
        self.sender.is_none() || matches!(self.state, State::Closed)
    }

    /// Returns the retained sender, or `None` after [`close`](Self::close) has been called.
    pub fn get_ref(&self) -> Option<&Sender<T>> {
        self.sender.as_ref()
    }

    /// Closes this sender.
    ///
    /// The retained sender is dropped immediately. An idle sender or an in-flight reservation
    /// moves straight to the closed state; a permit that is already held is left in place so the
    /// pending item can still be sent or aborted.
    pub fn close(&mut self) {
        self.sender = None;

        match &self.state {
            // Keep a held permit; nothing to do when already closed.
            State::ReadyToSend(_) | State::Closed => return,
            // Replacing the future drops the in-flight acquisition.
            State::Acquiring => self.acquire.set(None),
            State::Idle(_) => {}
        }
        self.state = State::Closed;
    }

    /// Cancels an in-progress or completed reservation.
    ///
    /// Returns `true` if a reservation was being acquired or a permit was held, and `false`
    /// otherwise (the state is then left unchanged). Afterwards the sender is idle again, or
    /// closed if [`close`](Self::close) was called in the meantime.
    pub fn abort_send(&mut self) -> bool {
        match self.take_state() {
            State::Acquiring => {
                // Replacing the future drops the in-flight acquisition together with the sender
                // it owned, so the idle state is rebuilt from a clone of the retained sender.
                self.acquire.set(None);
                self.state = self.sender.clone().map_or(State::Closed, State::Idle);
                true
            }
            State::ReadyToSend(permit) => {
                self.state = if self.sender.is_some() {
                    // Releasing the permit hands the sender back.
                    State::Idle(permit.release())
                } else {
                    State::Closed
                };
                true
            }
            unchanged => {
                self.state = unchanged;
                false
            }
        }
    }
}
impl<T> Clone for PollSender<T> {
fn clone(&self) -> PollSender<T> {
let (sender, state) = match self.sender.clone() {
Some(sender) => (Some(sender.clone()), State::Idle(sender)),
None => (None, State::Closed),
};
Self {
sender,
state,
acquire: PollSenderFuture::empty(),
}
}
}
impl<T: Send> Sink<T> for PollSender<T> {
    type Error = PollSendError<T>;

    /// Delegates to `PollSender::poll_reserve`.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.get_mut().poll_reserve(cx)
    }

    /// Always ready: this wrapper performs no work when flushing.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Delegates to `PollSender::send_item`, including its panic when no permit is held.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        self.get_mut().send_item(item)
    }

    /// Calls `PollSender::close` and reports completion immediately.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.get_mut().close();
        Poll::Ready(Ok(()))
    }
}