1use crate::state::Data;
2use crate::state::State;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use tokio::io::{self, AsyncWrite};
7
8pub struct PipeWriter {
10 pub(crate) state: Arc<Mutex<State>>,
11}
12
13impl PipeWriter {
14 pub fn close(&self) -> io::Result<()> {
16 match self.state.lock() {
17 Ok(mut state) => {
18 state.closed = true;
19 self.wake_reader_half(&*state);
20 Ok(())
21 }
22 Err(err) => Err(io::Error::new(
23 io::ErrorKind::Other,
24 format!(
25 "{}: PipeWriter: Failed to lock the channel state: {}",
26 env!("CARGO_PKG_NAME"),
27 err
28 ),
29 )),
30 }
31 }
32
33 pub fn is_flushed(&self) -> io::Result<bool> {
35 let state = match self.state.lock() {
36 Ok(s) => s,
37 Err(err) => {
38 return Err(io::Error::new(
39 io::ErrorKind::Other,
40 format!(
41 "{}: PipeWriter: Failed to lock the channel state: {}",
42 env!("CARGO_PKG_NAME"),
43 err
44 ),
45 ));
46 }
47 };
48
49 Ok(state.done_cycle)
50 }
51
52 fn wake_reader_half(&self, state: &State) {
53 if let Some(ref waker) = state.reader_waker {
54 waker.clone().wake();
55 }
56 }
57}
58
59impl Drop for PipeWriter {
60 fn drop(&mut self) {
61 if let Err(err) = self.close() {
62 log::warn!(
63 "{}: PipeWriter: Failed to close the channel on drop: {}",
64 env!("CARGO_PKG_NAME"),
65 err
66 );
67 }
68 }
69}
70
71impl AsyncWrite for PipeWriter {
72 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
73 let mut state;
74 match self.state.lock() {
75 Ok(s) => state = s,
76 Err(err) => {
77 return Poll::Ready(Err(io::Error::new(
78 io::ErrorKind::Other,
79 format!(
80 "{}: PipeWriter: Failed to lock the channel state: {}",
81 env!("CARGO_PKG_NAME"),
82 err
83 ),
84 )))
85 }
86 }
87
88 if state.closed {
89 return Poll::Ready(Err(io::Error::new(
90 io::ErrorKind::BrokenPipe,
91 format!(
92 "{}: PipeWriter: The channel is closed",
93 env!("CARGO_PKG_NAME")
94 ),
95 )));
96 }
97
98 return if state.done_cycle {
99 state.data = Some(Data {
100 ptr: buf.as_ptr(),
101 len: buf.len(),
102 });
103 state.done_cycle = false;
104 state.writer_waker = Some(cx.waker().clone());
105
106 self.wake_reader_half(&*state);
107
108 Poll::Pending
109 } else {
110 if state.done_reading {
111 let read_bytes_len = state.read;
112
113 state.done_cycle = true;
114 state.read = 0;
115 state.writer_waker = None;
116 state.data = None;
117 state.done_reading = false;
118
119 Poll::Ready(Ok(read_bytes_len))
120 } else {
121 state.writer_waker = Some(cx.waker().clone());
122 Poll::Pending
123 }
124 };
125 }
126
127 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
128 Poll::Ready(Ok(()))
129 }
130
131 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
132 match self.close() {
133 Ok(_) => Poll::Ready(Ok(())),
134 Err(err) => Poll::Ready(Err(io::Error::new(
135 io::ErrorKind::Other,
136 format!(
137 "{}: PipeWriter: Failed to shutdown the channel: {}",
138 env!("CARGO_PKG_NAME"),
139 err
140 ),
141 ))),
142 }
143 }
144}