1use crate::state::{Data, State};
2use std::pin::Pin;
3use std::ptr;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use tokio::io::{self, AsyncRead};
7
8pub struct PipeReader {
10 pub(crate) state: Arc<Mutex<State>>,
11}
12
13impl PipeReader {
14 pub fn close(&self) -> io::Result<()> {
16 match self.state.lock() {
17 Ok(mut state) => {
18 state.closed = true;
19 self.wake_writer_half(&*state);
20 Ok(())
21 }
22 Err(err) => Err(io::Error::new(
23 io::ErrorKind::Other,
24 format!(
25 "{}: PipeReader: Failed to lock the channel state: {}",
26 env!("CARGO_PKG_NAME"),
27 err
28 ),
29 )),
30 }
31 }
32
33 pub fn is_flushed(&self) -> io::Result<bool> {
35 let state = match self.state.lock() {
36 Ok(s) => s,
37 Err(err) => {
38 return Err(io::Error::new(
39 io::ErrorKind::Other,
40 format!(
41 "{}: PipeReader: Failed to lock the channel state: {}",
42 env!("CARGO_PKG_NAME"),
43 err
44 ),
45 ));
46 }
47 };
48
49 Ok(state.done_cycle)
50 }
51
52 fn wake_writer_half(&self, state: &State) {
53 if let Some(ref waker) = state.writer_waker {
54 waker.clone().wake();
55 }
56 }
57
58 fn copy_data_into_buffer(&self, data: &Data, buf: &mut [u8]) -> usize {
59 let len = data.len.min(buf.len());
60 unsafe {
61 ptr::copy_nonoverlapping(data.ptr, buf.as_mut_ptr(), len);
62 }
63 len
64 }
65}
66
67impl Drop for PipeReader {
68 fn drop(&mut self) {
69 if let Err(err) = self.close() {
70 log::warn!(
71 "{}: PipeReader: Failed to close the channel on drop: {}",
72 env!("CARGO_PKG_NAME"),
73 err
74 );
75 }
76 }
77}
78
79impl AsyncRead for PipeReader {
80 fn poll_read(
81 self: Pin<&mut Self>,
82 cx: &mut Context,
83 buf: &mut [u8],
84 ) -> Poll<io::Result<usize>> {
85 let mut state;
86 match self.state.lock() {
87 Ok(s) => state = s,
88 Err(err) => {
89 return Poll::Ready(Err(io::Error::new(
90 io::ErrorKind::Other,
91 format!(
92 "{}: PipeReader: Failed to lock the channel state: {}",
93 env!("CARGO_PKG_NAME"),
94 err
95 ),
96 )))
97 }
98 }
99
100 if state.closed {
101 return Poll::Ready(Ok(0));
102 }
103
104 return if state.done_cycle {
105 state.reader_waker = Some(cx.waker().clone());
106 Poll::Pending
107 } else {
108 if let Some(ref data) = state.data {
109 let copied_bytes_len = self.copy_data_into_buffer(data, buf);
110
111 state.data = None;
112 state.read = copied_bytes_len;
113 state.done_reading = true;
114 state.reader_waker = None;
115
116 self.wake_writer_half(&*state);
117
118 Poll::Ready(Ok(copied_bytes_len))
119 } else {
120 state.reader_waker = Some(cx.waker().clone());
121 Poll::Pending
122 }
123 };
124 }
125}