use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures::{ready, sink::Sink};
use pin_project_lite::pin_project;
/// Builds a [`Sink`] from an initial state `init` and an async step function `f`.
///
/// Each [`Action`] the sink performs (send, flush, close) is passed to `f`
/// together with the current state. The future returned by `f` must resolve to
/// the state used for the next action. If it resolves to an error, the sink
/// fails permanently.
pub(crate) fn make_sink<S, F, T, A, E>(init: S, f: F) -> SinkImpl<S, F, T, A, E>
where
    F: FnMut(S, Action<A>) -> T,
    T: Future<Output = Result<S, E>>,
{
    // The sink starts idle, owning the initial state and no in-flight future.
    let initial_param = Some(init);
    SinkImpl {
        state: State::Empty,
        param: initial_param,
        future: None,
        lambda: f,
        _mark: std::marker::PhantomData,
    }
}
/// An operation the sink asks its step function to perform.
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq)]
pub(crate) enum Action<A> {
    /// Deliver one item.
    Send(A),
    /// Flush any buffered items.
    Flush,
    /// Shut the underlying resource down.
    Close,
}
/// Where the sink's state machine currently is.
#[derive(Debug)]
#[derive(PartialEq, Eq)]
enum State {
    /// Idle: the step state is held and no future is in flight.
    Empty,
    /// A future for `Action::Send` is in flight.
    Sending,
    /// A future for `Action::Flush` is in flight.
    Flushing,
    /// A future for `Action::Close` is in flight.
    Closing,
    /// Terminal: the close future completed successfully.
    Closed,
    /// Terminal: a step future resolved to an error.
    Failed,
}
/// Errors reported by a sink built with [`make_sink`].
#[derive(Debug, thiserror::Error)]
pub(crate) enum Error<E> {
    /// A step future resolved to this error. The sink cannot be used afterwards.
    #[error("Error while sending over the sink, {0}")]
    Send(E),
    /// The sink is closed, or has previously failed, and accepts no more work.
    #[error("The Sink has closed")]
    Closed,
}
pin_project! {
    /// Sink produced by [`make_sink`]: it feeds each action through a user-supplied
    /// async step function, keeping at most one step future in flight.
    #[derive(Debug)]
    pub(crate) struct SinkImpl<S, F, T, A, E> {
        // The step function invoked once per action with the current state.
        lambda: F,
        // The in-flight step future; `None` while idle or after it completes.
        #[pin] future: Option<T>,
        // The step state; taken out while a future owns it and restored on success.
        param: Option<S>,
        // Current position in the send/flush/close state machine.
        state: State,
        // Ties the otherwise-unused item type `A` and error type `E` to the struct.
        _mark: std::marker::PhantomData<(A, E)>
    }
}
impl<S, F, T, A, E> Sink<A> for SinkImpl<S, F, T, A, E>
where
F: FnMut(S, Action<A>) -> T,
T: Future<Output = Result<S, E>>,
{
type Error = Error<E>;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let mut this = self.project();
match this.state {
State::Sending | State::Flushing => {
match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
Ok(p) => {
this.future.set(None);
*this.param = Some(p);
*this.state = State::Empty;
Poll::Ready(Ok(()))
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
Poll::Ready(Err(Error::Send(e)))
}
}
}
State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
Ok(_) => {
this.future.set(None);
*this.state = State::Closed;
Poll::Ready(Err(Error::Closed))
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
Poll::Ready(Err(Error::Send(e)))
}
},
State::Empty => {
assert!(this.param.is_some());
Poll::Ready(Ok(()))
}
State::Closed | State::Failed => Poll::Ready(Err(Error::Closed)),
}
}
fn start_send(self: Pin<&mut Self>, item: A) -> Result<(), Self::Error> {
assert_eq!(State::Empty, self.state);
let mut this = self.project();
let param = this.param.take().unwrap();
let future = (this.lambda)(param, Action::Send(item));
this.future.set(Some(future));
*this.state = State::Sending;
Ok(())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
loop {
let mut this = self.as_mut().project();
match this.state {
State::Empty => {
if let Some(p) = this.param.take() {
let future = (this.lambda)(p, Action::Flush);
this.future.set(Some(future));
*this.state = State::Flushing
} else {
return Poll::Ready(Ok(()));
}
}
State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
{
Ok(p) => {
this.future.set(None);
*this.param = Some(p);
*this.state = State::Empty
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
return Poll::Ready(Err(Error::Send(e)));
}
},
State::Flushing => {
match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
Ok(p) => {
this.future.set(None);
*this.param = Some(p);
*this.state = State::Empty;
return Poll::Ready(Ok(()));
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
return Poll::Ready(Err(Error::Send(e)));
}
}
}
State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
{
Ok(_) => {
this.future.set(None);
*this.state = State::Closed;
return Poll::Ready(Ok(()));
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
return Poll::Ready(Err(Error::Send(e)));
}
},
State::Closed | State::Failed => return Poll::Ready(Err(Error::Closed)),
}
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
loop {
let mut this = self.as_mut().project();
match this.state {
State::Empty => {
if let Some(p) = this.param.take() {
let future = (this.lambda)(p, Action::Close);
this.future.set(Some(future));
*this.state = State::Closing;
} else {
return Poll::Ready(Ok(()));
}
}
State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
{
Ok(p) => {
this.future.set(None);
*this.param = Some(p);
*this.state = State::Empty
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
return Poll::Ready(Err(Error::Send(e)));
}
},
State::Flushing => {
match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
Ok(p) => {
this.future.set(None);
*this.param = Some(p);
*this.state = State::Empty
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
return Poll::Ready(Err(Error::Send(e)));
}
}
}
State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
{
Ok(_) => {
this.future.set(None);
*this.state = State::Closed;
return Poll::Ready(Ok(()));
}
Err(e) => {
this.future.set(None);
*this.state = State::Failed;
return Poll::Ready(Err(Error::Send(e)));
}
},
State::Closed => return Poll::Ready(Ok(())),
State::Failed => return Poll::Ready(Err(Error::Closed)),
}
}
}
}
#[cfg(test)]
mod tests {
    use futures::{channel::mpsc, prelude::*};
    use tokio::io::{self, AsyncWriteExt};

    use crate::quicksink::{make_sink, Action};

    // Writes two lines to stdout through the sink and expects forwarding to succeed.
    #[tokio::test]
    async fn smoke_test() {
        let sink = make_sink(io::stdout(), |mut stdout, action| async move {
            match action {
                Action::Send(x) => stdout.write_all(x).await?,
                Action::Flush => stdout.flush().await?,
                Action::Close => stdout.shutdown().await?,
            }
            Ok::<_, io::Error>(stdout)
        });
        let values = vec![Ok(&b"hello\n"[..]), Ok(&b"world\n"[..])];
        assert!(stream::iter(values).forward(sink).await.is_ok())
    }

    // Records every action the sink performs and checks the exact sequence:
    // each `send` is followed by a flush, and `close` ends with a `Close`.
    #[tokio::test]
    async fn replay() {
        let (tx, rx) = mpsc::channel(5);
        let sink = make_sink(tx, |mut tx, action| async move {
            tx.send(action.clone()).await?;
            if action == Action::Close {
                tx.close().await?
            }
            Ok::<_, mpsc::SendError>(tx)
        });
        futures::pin_mut!(sink);
        let expected = [
            Action::Send("hello\n"),
            Action::Flush,
            Action::Send("world\n"),
            Action::Flush,
            Action::Close,
        ];
        for &item in &["hello\n", "world\n"] {
            sink.send(item).await.unwrap()
        }
        sink.close().await.unwrap();
        let actual = rx.collect::<Vec<_>>().await;
        assert_eq!(&expected[..], &actual[..])
    }

    // A failing step function surfaces its error once; afterwards the sink
    // reports `Closed` instead of panicking.
    #[tokio::test]
    async fn error_does_not_panic() {
        let sink = make_sink(io::stdout(), |_stdout, _action| async move {
            Err(io::Error::other("oh no"))
        });
        futures::pin_mut!(sink);
        let result = sink.send("hello").await;
        match result {
            Err(crate::quicksink::Error::Send(e)) => {
                assert_eq!(e.kind(), io::ErrorKind::Other);
                assert_eq!(e.to_string(), "oh no")
            }
            _ => panic!("unexpected result: {result:?}"),
        };
        let result = sink.send("hello").await;
        match result {
            Err(crate::quicksink::Error::Closed) => {}
            _ => panic!("unexpected result: {result:?}"),
        };
    }

    // Closing twice succeeds, and a closed sink rejects further items with `Closed`.
    #[tokio::test]
    async fn close_is_idempotent_and_rejects_sends() {
        let sink = make_sink((), |state, _action: Action<u8>| async move {
            Ok::<_, io::Error>(state)
        });
        futures::pin_mut!(sink);
        sink.close().await.unwrap();
        sink.close().await.unwrap();
        let result = sink.send(1).await;
        match result {
            Err(crate::quicksink::Error::Closed) => {}
            _ => panic!("unexpected result: {result:?}"),
        };
    }
}