// multipart-write 0.1.0
//
// Sink-like interface for writing an object in parts.
// Documentation
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::future;
use futures::stream::{StreamExt as _, iter};
use multipart_write::stream::MultipartStreamExt as _;
use multipart_write::{
    FusedMultipartWrite, MultipartWrite, MultipartWriteExt as _,
};

/// In-memory writer used by the tests below.
///
/// Each part sent to it is multiplied by `multiplier` and stored; completing
/// the writer hands back the stored values and bumps `completed`.
#[derive(Debug, Clone)]
struct TestWriter {
    /// Scaled parts accumulated since the last completion.
    inner: Vec<usize>,
    /// Factor applied to every part before it is stored.
    multiplier: usize,
    /// How many times the writer has been completed so far.
    completed: usize,
    /// Completion count at which the writer reports itself terminated;
    /// `None` means it never does.
    max_completed: Option<usize>,
}

impl Default for TestWriter {
    /// An empty writer that stores parts unscaled (multiplier of 1) and has
    /// no completion limit.
    fn default() -> Self {
        Self::new(1)
    }
}

impl TestWriter {
    /// Creates an empty writer that scales each part by `multiplier`.
    fn new(multiplier: usize) -> Self {
        Self {
            inner: Vec::new(),
            multiplier,
            completed: 0,
            max_completed: None,
        }
    }

    /// Builder-style setter for the completion limit.
    fn max_completed(self, max: usize) -> Self {
        Self {
            max_completed: Some(max),
            ..self
        }
    }
}

/// `TestWriter` accepts `usize` parts, never blocks, and never fails.
impl MultipartWrite<usize> for TestWriter {
    type Error = String;
    type Output = Vec<usize>;
    type Recv = usize;

    /// The writer has no back-pressure: it is ready on every poll.
    fn poll_ready(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Stores `part` scaled by the writer's multiplier and echoes the
    /// original, unscaled `part` back to the caller.
    fn start_send(
        self: Pin<&mut Self>,
        part: usize,
    ) -> Result<Self::Recv, Self::Error> {
        let this = self.get_mut();
        let scaled = this.multiplier * part;
        this.inner.push(scaled);
        Ok(part)
    }

    /// Nothing is buffered outside `inner`, so flushing is a no-op.
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Counts the completion and returns the stored parts, leaving the
    /// writer empty.
    fn poll_complete(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Output, Self::Error>> {
        let this = self.get_mut();
        this.completed += 1;
        let parts = std::mem::take(&mut this.inner);
        Poll::Ready(Ok(parts))
    }
}

impl FusedMultipartWrite<usize> for TestWriter {
    /// Reports termination once the number of completions has reached the
    /// configured limit; a writer without a limit never terminates.
    fn is_terminated(&self) -> bool {
        match self.max_completed {
            Some(limit) => self.completed >= limit,
            None => false,
        }
    }
}

/// Test writer whose parts are whole `Vec<usize>` batches.
///
/// Each batch is reduced to its sum and stored as a decimal string; the
/// completed output is those strings joined with commas.
#[derive(Debug, Clone, Default)]
struct OtherTestWriter(Vec<String>);

impl MultipartWrite<Vec<usize>> for OtherTestWriter {
    type Error = String;
    type Output = String;
    type Recv = usize;

    /// The writer has no back-pressure: it is ready on every poll.
    fn poll_ready(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), String>> {
        Poll::Ready(Ok(()))
    }

    /// Records the sum of `part` as a string and returns how many sums are
    /// currently stored.
    fn start_send(
        mut self: Pin<&mut Self>,
        part: Vec<usize>,
    ) -> Result<usize, String> {
        let n: usize = part.iter().sum();
        self.as_mut().0.push(n.to_string());
        Ok(self.0.len())
    }

    /// Nothing is buffered outside the stored sums, so flushing is a no-op.
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Joins the stored sums with `","` and returns the result.
    fn poll_complete(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Output, Self::Error>> {
        // Take the stored sums out so the writer starts empty again; without
        // this, a second completion would repeat the parts of the first.
        let sums = std::mem::take(&mut self.0);
        Poll::Ready(Ok(sums.join(",")))
    }
}

/// Drives the writer through the `send_flush`, `flush`, and `complete`
/// futures, and checks that it can be reused after a completion.
#[tokio::test]
async fn trait_futures() {
    let mut writer = TestWriter::default();

    // First round: every part is flushed as it is sent.
    for part in [1_usize, 2, 3] {
        writer.send_flush(part).await.unwrap();
    }
    let first = writer.complete().await.unwrap();

    // Second round on the same writer, with an extra explicit flush.
    for part in [10_usize, 20] {
        writer.send_flush(part).await.unwrap();
    }
    writer.flush().await.unwrap();
    let second = writer.complete().await.unwrap();

    assert_eq!(first, vec![1, 2, 3]);
    assert_eq!(second, vec![10, 20]);
}

/// Wraps the writer with `map_ok` so that the completed output is the sum of
/// the collected parts rather than the parts themselves.
#[tokio::test]
async fn map_writer() {
    let mut summing =
        TestWriter::default().map_ok(|parts| parts.iter().sum::<usize>());

    // Parts are fed without flushing; a single flush precedes completion.
    for part in 1..=5 {
        summing.feed(part).await.unwrap();
    }
    summing.flush().await.unwrap();

    let total = summing.complete().await.unwrap();
    assert_eq!(total, 15);
}

/// Adapts the writer with `ready_part` so it takes `&str` parts, each mapped
/// to its length by an immediately-ready future. The underlying writer has a
/// multiplier of 2, so the output holds the doubled lengths.
#[tokio::test]
async fn ready_part_writer() {
    let mut writer = TestWriter::new(2)
        .ready_part(|text: &str| future::ready(Ok(text.len())));

    for text in ["abc", "d", "ef"] {
        writer.send_flush(text).await.unwrap();
    }

    let doubled_lens = writer.complete().await.unwrap();
    assert_eq!(doubled_lens, vec![6, 2, 4]);
}

/// Combines the two test writers with `lift` and feeds `1..=5` through the
/// result. The expected output `"15"` is consistent with the `TestWriter`'s
/// completed `Vec` being handed to `OtherTestWriter` as a single part.
#[tokio::test]
async fn lift_writer() {
    let front = TestWriter::default();
    let lifted = OtherTestWriter::default().lift(front);

    let joined = iter(1..=5).complete_with(lifted).await.unwrap();

    assert_eq!(joined, "15");
}

/// Streams `1..=12` through `try_complete_when`, completing whenever the
/// value returned from sending a part is a multiple of five. Expects two full
/// batches of five followed by the two leftover parts.
#[tokio::test]
async fn try_complete_when_stream() {
    let completions = iter(1..=12)
        .try_complete_when(TestWriter::default(), |ret| ret % 5 == 0);

    // Errors are dropped; only successful outputs are collected.
    let outputs: Vec<_> = completions
        .filter_map(|item| future::ready(item.ok()))
        .collect()
        .await;

    let expected =
        vec![vec![1, 2, 3, 4, 5], vec![6, 7, 8, 9, 10], vec![11, 12]];
    assert_eq!(outputs, expected);
}

#[tokio::test]
async fn complete_with_future() {
    let writer =
        TestWriter::default().fold_sent(String::new(), |mut acc, n| {
            acc += &n.to_string();
            acc
        });
    let (acc, out) = iter(1..=5).complete_with(writer).await.unwrap();
    assert_eq!(acc, "12345".to_string());
    assert_eq!(out, vec![1, 2, 3, 4, 5]);
}

/// Streams `1..=10` with a completion after every fifth part, so the input
/// ends exactly on a completion boundary. Exactly two outputs are expected,
/// i.e. no extra output once the stream is exhausted.
#[tokio::test]
async fn skip_last_complete_if_empty() {
    let completions = iter(1..=10)
        .try_complete_when(TestWriter::default(), |ret| ret % 5 == 0);

    let outputs: Vec<_> = completions
        .filter_map(|item| future::ready(item.ok()))
        .collect()
        .await;

    let produced = outputs.len();
    assert_eq!(produced, 2);
}

/// Feeds an unbounded stream into a writer that reports itself terminated
/// after four completions, completing after every second part. The collected
/// stream must end with exactly the first four pairs.
#[tokio::test]
async fn stream_ends_on_fused_writer() {
    let capped = TestWriter::default().max_completed(4);
    let completions =
        iter(1..).try_complete_when(capped, |ret| ret % 2 == 0);

    let outputs: Vec<_> = completions
        .filter_map(|item| future::ready(item.ok()))
        .collect()
        .await;

    let expected = vec![vec![1, 2], vec![3, 4], vec![5, 6], vec![7, 8]];
    assert_eq!(outputs, expected);
}