1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use crate::state::Data;
use crate::state::State;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::io::{self, AsyncWrite};

/// The write half of the pipe which implements [`AsyncWrite`](https://docs.rs/tokio/0.2.16/tokio/io/trait.AsyncWrite.html).
///
/// Dropping the writer closes the pipe (see the `Drop` implementation), and
/// [`PipeWriter::close`] can be used to close it explicitly.
pub struct PipeWriter {
    /// Pipe state guarded by a mutex.
    // NOTE(review): presumably the same `Arc` is held by the reader half —
    // confirm at the place where the pipe is constructed.
    pub(crate) state: Arc<Mutex<State>>,
}

impl PipeWriter {
    /// Closes the pipe, any further read will return EOF and any further write will raise an error.
    pub fn close(&self) -> io::Result<()> {
        match self.state.lock() {
            Ok(mut state) => {
                state.closed = true;
                self.wake_reader_half(&*state);
                Ok(())
            }
            Err(err) => Err(io::Error::new(
                io::ErrorKind::Other,
                format!(
                    "{}: PipeWriter: Failed to lock the channel state: {}",
                    env!("CARGO_PKG_NAME"),
                    err
                ),
            )),
        }
    }

    fn wake_reader_half(&self, state: &State) {
        if let Some(ref waker) = state.reader_waker {
            waker.clone().wake();
        }
    }
}

impl Drop for PipeWriter {
    /// Closes the pipe when the writer is dropped.
    ///
    /// `drop` cannot return an error, so a failed close is reported through
    /// `log::warn!` instead of being propagated.
    fn drop(&mut self) {
        match self.close() {
            Ok(()) => {}
            Err(close_err) => {
                log::warn!(
                    "{}: PipeWriter: Failed to close the channel on drop: {}",
                    env!("CARGO_PKG_NAME"),
                    close_err
                );
            }
        }
    }
}

impl AsyncWrite for PipeWriter {
    /// Offers `buf` through the shared pipe state and reports how many bytes
    /// were taken from it.
    ///
    /// A write is a multi-poll cycle driven by flags in the shared state:
    ///
    /// 1. When `done_cycle` is set (no write in flight), the address and length
    ///    of `buf` are published in `state.data`, `done_cycle` is cleared, the
    ///    writer's waker is stored, the reader is woken and `Poll::Pending` is
    ///    returned.
    /// 2. On a later poll, if `done_reading` is set, the cycle is finished:
    ///    `state.read` is returned as the number of bytes written and the
    ///    per-cycle fields are reset.
    /// 3. Otherwise the waker is refreshed and `Poll::Pending` is returned.
    ///
    /// NOTE(review): `done_reading` and `read` are never set in this block;
    /// presumably the reader half sets them and wakes `writer_waker` — confirm
    /// against the reader implementation.
    ///
    /// NOTE(review): only a raw pointer to `buf` is stored, and it outlives this
    /// call. This looks like it relies on the caller polling again with the
    /// same, still-alive buffer until `Poll::Ready` is returned — confirm the
    /// reader never dereferences the pointer once the writer has gone away.
    ///
    /// # Errors
    ///
    /// * [`io::ErrorKind::Other`] when the state mutex cannot be locked.
    /// * [`io::ErrorKind::BrokenPipe`] when the pipe has been closed.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let mut state = match self.state.lock() {
            Ok(guard) => guard,
            Err(err) => {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!(
                        "{}: PipeWriter: Failed to lock the channel state: {}",
                        env!("CARGO_PKG_NAME"),
                        err
                    ),
                )));
            }
        };

        if state.closed {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                format!(
                    "{}: PipeWriter: The channel is closed",
                    env!("CARGO_PKG_NAME")
                ),
            )));
        }

        if state.done_cycle {
            // Start of a new cycle: publish the buffer and hand over to the reader.
            state.data = Some(Data {
                ptr: buf.as_ptr(),
                len: buf.len(),
            });
            state.done_cycle = false;
            state.writer_waker = Some(cx.waker().clone());

            self.wake_reader_half(&state);

            Poll::Pending
        } else if state.done_reading {
            // The in-flight buffer has been consumed: report the byte count and
            // reset the state so the next call starts a fresh cycle.
            let read_bytes_len = state.read;

            state.done_cycle = true;
            state.read = 0;
            state.writer_waker = None;
            state.data = None;
            state.done_reading = false;

            Poll::Ready(Ok(read_bytes_len))
        } else {
            // Still waiting: keep the stored waker current, since the task may
            // be polled with a different waker each time.
            state.writer_waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    /// Always completes immediately with success; this writer has no flush
    /// step of its own.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts the writer down by closing the pipe; always completes immediately.
    ///
    /// # Errors
    ///
    /// Any error from [`PipeWriter::close`] is wrapped in a new error of kind
    /// [`io::ErrorKind::Other`] whose message embeds the original one.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(self.close().map_err(|err| {
            io::Error::new(
                io::ErrorKind::Other,
                format!(
                    "{}: PipeWriter: Failed to shutdown the channel: {}",
                    env!("CARGO_PKG_NAME"),
                    err
                ),
            )
        }))
    }
}