use crate::Error;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures_channel::{mpsc, oneshot};
use http::HeaderMap;
use super::watch;
/// Sending half of the body data channel: every item is either a `Bytes`
/// chunk or an `Error` to be surfaced on the receiving side.
type BodySender = mpsc::Sender<Result<Bytes, Error>>;
/// One-shot sender used to hand over a single `HeaderMap` of trailers.
type TrailersSender = oneshot::Sender<HeaderMap>;
/// `want_rx` state for which `Sender::poll_want` answers `Poll::Pending`.
// NOTE(review): presumably stored by the receiving side through the `watch`
// module before it has asked for data — confirm against the writer.
pub(crate) const WANT_PENDING: usize = 1;
/// `want_rx` state for which `Sender::poll_want` answers `Poll::Ready(Ok(()))`.
pub(crate) const WANT_READY: usize = 2;
/// Sending half of a streaming body: pushes data chunks, an optional set of
/// trailers, or an error towards whoever holds the matching receivers.
#[must_use = "Sender does nothing unless sent on"]
pub struct Sender {
/// Watched state loaded by `poll_want`; its value (`WANT_PENDING`,
/// `WANT_READY` or `watch::CLOSED`) gates `poll_ready`.
pub(crate) want_rx: watch::Receiver,
/// Channel carrying `Ok(chunk)` data items and `Err(error)` notifications.
pub(crate) data_tx: BodySender,
/// Trailers channel; `send_trailers` takes it, so it is `None` afterwards.
pub(crate) trailers_tx: Option<TrailersSender>,
}
impl Sender {
pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
ready!(self.poll_want(cx)?);
self.data_tx
.poll_ready(cx)
.map_err(|_| Error::new(SenderError::ChannelClosed))
}
fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.want_rx.load(cx) {
WANT_READY => Poll::Ready(Ok(())),
WANT_PENDING => Poll::Pending,
watch::CLOSED => Poll::Ready(Err(Error::new(SenderError::ChannelClosed))),
unexpected => unreachable!("want_rx value: {}", unexpected),
}
}
async fn ready(&mut self) -> Result<(), Error> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await
}
#[allow(unused)]
pub async fn send_data(&mut self, chunk: Bytes) -> Result<(), Error> {
self.ready().await?;
self.data_tx
.try_send(Ok(chunk))
.map_err(|_| Error::new(SenderError::ChannelClosed))
}
#[allow(unused)]
pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> {
let tx = match self.trailers_tx.take() {
Some(tx) => tx,
None => return Err(Error::new(SenderError::ChannelClosed)),
};
tx.send(trailers).map_err(|_| Error::new(SenderError::ChannelClosed))
}
pub fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> {
self.data_tx
.try_send(Ok(chunk))
.map_err(|err| err.into_inner().expect("just sent Ok"))
}
#[allow(unused)]
pub fn abort(mut self) {
self.send_error(Error::new(SenderError::BodyWriteAborted));
}
pub fn send_error(&mut self, err: Error) {
let _ = self
.data_tx
.clone()
.try_send(Err(err));
}
}
/// Failure causes reported by the body `Sender`.
#[derive(Debug)]
enum SenderError {
    /// A channel could not accept the value, the want state was closed, or
    /// the trailers channel had already been used.
    ChannelClosed,
    /// The sender was aborted through `Sender::abort`.
    BodyWriteAborted,
}

impl SenderError {
    /// Returns the fixed, human-readable message for this error.
    ///
    /// Every message is a string literal, so the result is `'static` and can
    /// be kept after the error value itself has been dropped.
    fn description(&self) -> &'static str {
        match self {
            SenderError::BodyWriteAborted => "user body write aborted",
            SenderError::ChannelClosed => "channel closed",
        }
    }
}

impl std::fmt::Display for SenderError {
    /// Writes the same text as [`SenderError::description`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.description())
    }
}

// No underlying source error: the default `Error` methods are sufficient.
impl std::error::Error for SenderError {}