ruchei 0.1.2

Utilities for working with many streams
Documentation
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll, Wake},
};

use futures_util::{Sink, Stream, TryStream, ready, task::AtomicWaker};
use pin_project::pin_project;

/// Outcome of running a [`ReplyBufferFilter`] `F` over one `Item`.
///
/// * `None` — the filter produced nothing at all; the `Stream` implementation
///   of `WithReply` ends the stream when it sees this.
/// * `Some((reply, filtered))` — an optional reply (`F::Reply`) paired with an
///   optional filtered item (`F::Filtered`); either half may be absent.
pub type Filtered<F, Item> = Option<(
    Option<<F as ReplyBufferFilter<Item>>::Reply>,
    Option<<F as ReplyBufferFilter<Item>>::Filtered>,
)>;

/// Splits an incoming `Item` into an optional reply and an optional filtered item.
///
/// A blanket implementation in this file covers every
/// `FnMut(Item) -> Option<(Option<Reply>, Option<Filtered>)>`.
pub trait ReplyBufferFilter<Item> {
    /// Type of the filtered item half of the result.
    type Filtered;
    /// Type of the reply half of the result.
    type Reply;
    /// Consumes `item` and returns the [`Filtered`] outcome for it.
    ///
    /// Takes `&mut self`, so implementations may keep state between calls.
    fn reply_and_item(&mut self, item: Item) -> Filtered<Self, Item>;
}

/// Lets any suitable closure (or other `FnMut`) act as a [`ReplyBufferFilter`].
///
/// `R` becomes the associated `Reply` type and `Out` the associated `Filtered`
/// type; the call is forwarded to the callable unchanged.
impl<Item, R, Out, F> ReplyBufferFilter<Item> for F
where
    F: ?Sized + FnMut(Item) -> Option<(Option<R>, Option<Out>)>,
{
    type Filtered = Out;
    type Reply = R;

    fn reply_and_item(&mut self, item: Item) -> Option<(Option<R>, Option<Out>)> {
        // Plain delegation: the callable already has the required shape.
        self(item)
    }
}

/// Waker slots for the different entry points of `WithReply`.
///
/// Waking this structure (see its `Wake` implementation) wakes every slot.
#[derive(Debug, Default)]
struct Wakers {
    /// Registered by `Stream::poll_next`.
    next: AtomicWaker,
    /// Registered by `Sink::poll_ready`.
    ready: AtomicWaker,
    /// Registered by `Sink::poll_flush` and `Sink::poll_close`.
    flush: AtomicWaker,
}

/// Waking `Wakers` wakes all three registered slots (`next`, `ready`, `flush`).
impl Wake for Wakers {
    /// Consuming wake; delegates to [`Wake::wake_by_ref`].
    fn wake(self: Arc<Self>) {
        Self::wake_by_ref(&self);
    }

    /// Wakes every slot without consuming the `Arc`.
    ///
    /// Implemented explicitly because the default `wake_by_ref` clones the
    /// `Arc` and calls `wake` on the clone; this file calls `wake_by_ref` on
    /// its send and flush paths, so the extra reference-count traffic is
    /// avoided here.
    fn wake_by_ref(self: &Arc<Self>) {
        self.next.wake();
        self.ready.wake();
        self.flush.wake();
    }
}

/// Adapter around a combined stream/sink `S` that runs incoming items through a
/// filter `F` and sends the replies the filter produces back into `S`.
///
/// Created by `WithReplyExt::with_reply`.
#[pin_project]
#[derive(Debug)]
pub struct WithReply<S, T, F> {
    /// The wrapped stream/sink; structurally pinned.
    #[pin]
    stream: S,
    /// Reply waiting to be sent into `stream` (the initial buffer passed to
    /// `new`, or a reply produced by the filter).
    reply: Option<T>,
    /// Item accepted through `Sink::start_send` that has not been forwarded to
    /// `stream` yet.
    start: Option<T>,
    /// Set when a reply is queued; cleared after `stream` flushes successfully.
    needs_flush: bool,
    /// Shared waker slots, also used as the waker for internally driven polls.
    wakers: Arc<Wakers>,
    /// The filter applied to each item read from `stream`.
    filter: F,
}

impl<S: Sink<T>, T, F> WithReply<S, T, F> {
    fn poll_reply(self: Pin<&mut Self>) -> Poll<Result<(), S::Error>> {
        let mut this = self.project();
        let waker = this.wakers.clone().into();
        let mut cx = Context::from_waker(&waker);
        if this.reply.is_some() {
            ready!(this.stream.as_mut().poll_ready(&mut cx))?;
        }
        if let Some(item) = this.reply.take() {
            this.stream.start_send(item)?;
            this.wakers.wake_by_ref();
        }
        Poll::Ready(Ok(()))
    }

    fn poll_start(self: Pin<&mut Self>) -> Poll<Result<(), S::Error>> {
        let mut this = self.project();
        let waker = this.wakers.clone().into();
        let mut cx = Context::from_waker(&waker);
        if this.start.is_some() {
            ready!(this.stream.as_mut().poll_ready(&mut cx))?;
        }
        if let Some(item) = this.start.take() {
            this.stream.start_send(item)?;
            this.wakers.wake_by_ref();
        }
        Poll::Ready(Ok(()))
    }

    fn poll(mut self: Pin<&mut Self>) -> Poll<Result<(), S::Error>> {
        ready!(self.as_mut().poll_reply())?;
        self.poll_start()
    }

    fn poll_flush_if_needed(mut self: Pin<&mut Self>) -> Result<(), S::Error> {
        let this = self.as_mut().project();
        let waker = this.wakers.clone().into();
        let mut cx = Context::from_waker(&waker);
        if *this.needs_flush {
            let _ = self.poll_flush_raw(&mut cx)?;
        }
        Ok(())
    }

    fn poll_flush_raw(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        let this = self.project();
        ready!(this.stream.poll_flush(cx))?;
        if *this.needs_flush {
            *this.needs_flush = false;
            this.wakers.wake_by_ref();
        }
        Poll::Ready(Ok(()))
    }

    fn set_reply(self: Pin<&mut Self>, reply: T) {
        assert!(self.reply.is_none());
        let this = self.project();
        *this.reply = Some(reply);
        *this.needs_flush = true;
    }

    #[must_use]
    fn new(stream: S, buffer: Option<T>, filter: F) -> Self {
        let needs_flush = buffer.is_some();
        Self {
            stream,
            reply: buffer,
            start: None,
            needs_flush,
            wakers: Default::default(),
            filter,
        }
    }
}

/// Extension trait adding [`WithReplyExt::with_reply`] to types that are both
/// a `Sink<T>` and a `TryStream`.
pub trait WithReplyExt<T>: Sized + Sink<T> + TryStream {
    /// Wraps `self` in a [`WithReply`] adapter.
    ///
    /// `buffer` is an optional reply queued from the start; `filter` is applied
    /// to every successful item of the stream and its replies have type `T`.
    #[must_use]
    fn with_reply<F>(self, buffer: Option<T>, filter: F) -> WithReply<Self, T, F>
    where
        F: ReplyBufferFilter<Self::Ok, Reply = T>,
    {
        WithReply::new(self, buffer, filter)
    }
}

impl<S: TryStream<Ok = U, Error = E> + Sink<T, Error = E>, T, U, E> WithReplyExt<T> for S {}

/// Yields the filtered items of the wrapped stream, queueing filter replies
/// for sending back into the wrapped sink.
impl<S, T, U, E, F> Stream for WithReply<S, T, F>
where
    S: TryStream<Ok = U, Error = E> + Sink<T, Error = E>,
    F: ReplyBufferFilter<U, Reply = T>,
{
    type Item = Result<F::Filtered, E>;

    /// Each iteration: send any queued reply, attempt a flush if one is
    /// needed, read one item from the wrapped stream, and run the filter.
    ///
    /// The stream ends when the wrapped stream ends or the filter returns
    /// `None`. Items for which the filter yields no filtered value are
    /// skipped and the loop continues.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.wakers.next.register(cx.waker());
        loop {
            // A previous reply must be handed to the sink before a new one can be queued.
            ready!(self.as_mut().poll_reply())?;
            self.as_mut().poll_flush_if_needed()?;

            let incoming = match ready!(self.as_mut().project().stream.try_poll_next(cx)?) {
                Some(incoming) => incoming,
                None => return Poll::Ready(None),
            };

            let (reply, output) = match self.as_mut().project().filter.reply_and_item(incoming) {
                Some(parts) => parts,
                None => return Poll::Ready(None),
            };

            if let Some(reply) = reply {
                self.as_mut().set_reply(reply);
            }
            if let Some(output) = output {
                return Poll::Ready(Some(Ok(output)));
            }
        }
    }
}

impl<S: Sink<T, Error = E>, T, E, F> Sink<T> for WithReply<S, T, F> {
    type Error = E;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.wakers.ready.register(cx.waker());
        ready!(self.as_mut().poll())?;
        self.project().stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        assert!(self.start.is_none());
        *self.project().start = Some(item);
        Ok(())
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.wakers.flush.register(cx.waker());
        ready!(self.as_mut().poll())?;
        self.poll_flush_raw(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.wakers.flush.register(cx.waker());
        ready!(self.as_mut().poll())?;
        self.project().stream.poll_close(cx)
    }
}