1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
use crate::internal::io_future::{IoFuture, IoFutureState};
use futures::AsyncWrite;
use std::io;
use std::mem::replace;
use std::pin::Pin;
use std::task::{Context, Poll};

/// State for writing one pre-built byte buffer to an `AsyncWrite` sink.
///
/// The payload is held as an `io::Result`: when it is `Err`, polling reports
/// that error instead of writing anything. `completion` tracks how far into the
/// buffer the writer has already accepted bytes, so a poll that returns
/// `Pending` can be resumed where it stopped.
#[derive(Debug)]
pub struct BufferWriteState {
    /// Bytes to write, or the error to report instead of writing.
    buffer: io::Result<Vec<u8>>,
    /// Number of leading bytes of `buffer` already accepted by the writer.
    completion: usize,
}

// Compile-time guarantee that `BufferWriteState` is `Send`: if a field that is
// not `Send` is ever added, this constant stops type-checking. The helper lives
// inside the const block, so it is never "dead" and needs no lint allowance.
const _: () = {
    const fn assert_send<T: Send>() {}
    assert_send::<BufferWriteState>()
};

impl BufferWriteState {
    /// Creates a write state positioned at the very start of `buffer`.
    ///
    /// If `buffer` is `Err`, polling the state yields that error without
    /// writing anything.
    pub fn new(buffer: io::Result<Vec<u8>>) -> Self {
        let completion = 0;
        Self { completion, buffer }
    }
}

impl<IO: AsyncWrite + Unpin> IoFutureState<IO> for BufferWriteState {
    /// Writes the remaining bytes of the buffer into `io`.
    ///
    /// Calls `poll_write` repeatedly until every byte has been accepted, the
    /// writer returns `Pending`, or an error occurs. `self.completion` is
    /// advanced after every successful write (including the last one), so a
    /// later poll resumes exactly where the previous one stopped and never
    /// re-sends bytes. An empty buffer completes without calling the writer.
    ///
    /// # Errors
    /// - The error stored in `self.buffer` when the state was created from an
    ///   `Err`. It is moved out on the first poll and replaced by an empty `Ok`
    ///   buffer.
    /// - Any error returned by `poll_write`.
    /// - `io::ErrorKind::WriteZero` if the writer accepts zero bytes while data
    ///   remains. Without this check no progress is possible and the loop
    ///   would spin forever.
    fn poll(&mut self, cx: &mut Context<'_>, io: &mut IO) -> Poll<io::Result<()>> {
        let buffer = match &self.buffer {
            Ok(buffer) => buffer,
            Err(_) => {
                // Move the error out so it can be returned by value.
                let r = replace(&mut self.buffer, Ok(Vec::new()));
                return Poll::Ready(Err(r.unwrap_err()));
            }
        };
        while self.completion < buffer.len() {
            let remainder = &buffer[self.completion..];
            match Pin::new(&mut *io).poll_write(cx, remainder) {
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    )));
                }
                Poll::Ready(Ok(n)) => self.completion += n,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            }
        }
        Poll::Ready(Ok(()))
    }
}

/// Awaitable future (an `IoFuture`) that drives a `BufferWriteState` against
/// the writer type `W`.
pub type BufferWrite<W> = IoFuture<BufferWriteState, W>;

#[cfg(test)]
mod tests {
    use crate::internal::buffer_write::{BufferWrite, BufferWriteState};
    use crate::internal::io_future::IoFutureState;
    use futures::executor::block_on;
    use futures::io::Cursor;
    use std::io;

    const HELLO_WORLD: &[u8] = b"Hello World!";

    /// A successful write leaves exactly the buffer's bytes in the sink.
    #[test]
    fn test() {
        block_on(async {
            let mut sink = Cursor::new(Vec::new());
            let fut: BufferWrite<_> =
                BufferWriteState::new(Ok(HELLO_WORLD.to_vec())).into_future(&mut sink);
            fut.await.unwrap();

            assert_eq!(sink.into_inner(), HELLO_WORLD);
        })
    }

    /// An empty buffer completes successfully and writes nothing.
    #[test]
    fn empty_buffer() {
        block_on(async {
            let mut sink = Cursor::new(Vec::new());
            let fut: BufferWrite<_> = BufferWriteState::new(Ok(Vec::new())).into_future(&mut sink);
            fut.await.unwrap();

            assert!(sink.into_inner().is_empty());
        })
    }

    /// A state built from `Err` reports that error and writes nothing.
    #[test]
    fn stored_error_is_propagated() {
        block_on(async {
            let mut sink = Cursor::new(Vec::new());
            let fut: BufferWrite<_> =
                BufferWriteState::new(Err(io::Error::new(io::ErrorKind::Other, "boom")))
                    .into_future(&mut sink);
            match fut.await {
                Err(err) => assert_eq!(err.kind(), io::ErrorKind::Other),
                Ok(_) => panic!("expected the stored error to be returned"),
            }

            assert!(sink.into_inner().is_empty());
        })
    }
}