// fluke_buffet/io/pipe.rs

1use tokio::sync::mpsc;
2
3use crate::{Piece, ReadOwned, WriteOwned};
4
5/// Create a new pipe.
6pub fn pipe() -> (PipeWrite, PipeRead) {
7    let (tx, rx) = mpsc::channel(1);
8    (
9        PipeWrite { tx },
10        PipeRead {
11            rx,
12            state: Default::default(),
13            remain: None,
14        },
15    )
16}
17
18enum PipeEvent {
19    Piece(Piece),
20    Reset,
21    // close is just dropping the channel
22}
23
/// Lifecycle of the reading half of a pipe.
///
/// Once the reader leaves `Live` it never goes back: every later read reports the
/// same terminal outcome.
#[derive(Clone, Copy, Default)]
enum ReadState {
    /// Still receiving events from the writer (the initial state).
    #[default]
    Live,
    /// The writer simulated a connection reset; reads fail with `ConnectionReset`.
    Reset,
    /// The writer was dropped; reads return `Ok(0)`.
    Eof,
}
31
32pub struct PipeRead {
33    rx: mpsc::Receiver<PipeEvent>,
34    remain: Option<Piece>,
35    state: ReadState,
36}
37
38impl ReadOwned for PipeRead {
39    async fn read_owned<B: crate::IoBufMut>(&mut self, mut buf: B) -> crate::BufResult<usize, B> {
40        loop {
41            match self.state {
42                ReadState::Live => {
43                    // good, continue
44                }
45                ReadState::Reset => {
46                    let err = std::io::Error::new(
47                        std::io::ErrorKind::ConnectionReset,
48                        "simulated connection reset",
49                    );
50                    return (Err(err), buf);
51                }
52                ReadState::Eof => return (Ok(0), buf),
53            }
54
55            if self.remain.is_none() {
56                match self.rx.recv().await {
57                    Some(PipeEvent::Piece(piece)) => {
58                        assert!(!piece.is_empty());
59                        self.remain = Some(piece);
60                    }
61                    Some(PipeEvent::Reset) => {
62                        self.state = ReadState::Reset;
63                        continue;
64                    }
65                    None => {
66                        self.state = ReadState::Eof;
67                        continue;
68                    }
69                }
70            }
71
72            let remain = self.remain.take().unwrap();
73            let avail = buf.io_buf_mut_capacity();
74            let read_size = avail.min(remain.len());
75            let (copied, remain) = remain.split_at(read_size);
76            assert_eq!(copied.len(), read_size);
77            unsafe {
78                buf.slice_mut()[..read_size].copy_from_slice(&copied[..]);
79            }
80
81            if !remain.is_empty() {
82                self.remain = Some(remain);
83            }
84            return (Ok(read_size), buf);
85        }
86    }
87}
88
89pub struct PipeWrite {
90    tx: mpsc::Sender<PipeEvent>,
91}
92
93impl PipeWrite {
94    /// Simulate a connection reset
95    pub async fn reset(self) {
96        self.tx.send(PipeEvent::Reset).await.unwrap()
97    }
98}
99
100impl WriteOwned for PipeWrite {
101    async fn write_owned(&mut self, buf: impl Into<Piece>) -> crate::BufResult<usize, Piece> {
102        let buf = buf.into();
103        if buf.is_empty() {
104            // ignore 0-length writes
105        }
106
107        if self.tx.send(PipeEvent::Piece(buf.clone())).await.is_err() {
108            let err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "simulated broken pipe");
109            return (Err(err), buf);
110        }
111
112        (Ok(buf.len()), buf)
113    }
114
115    async fn shutdown(&mut self) -> std::io::Result<()> {
116        Ok(())
117    }
118}
119
#[cfg(all(test, not(feature = "miri")))]
mod tests {
    use crate::{ReadOwned, WriteOwned};

    use super::pipe;
    use std::{cell::RefCell, rc::Rc};

    /// Pieces arrive in order and only one piece is in flight at a time. A piece
    /// larger than the read buffer is split across reads. Dropping the writer yields
    /// EOF.
    #[test]
    fn test_pipe() {
        crate::start(async move {
            let (mut w, mut r) = pipe();
            let wrote_three = Rc::new(RefCell::new(false));

            crate::spawn({
                let wrote_three = wrote_three.clone();
                async move {
                    w.write_all_owned("one").await.unwrap();
                    w.write_all_owned("two").await.unwrap();
                    w.write_all_owned("three").await.unwrap();
                    *wrote_three.borrow_mut() = true;
                    w.write_all_owned("splitread").await.unwrap();
                }
            });

            {
                let buf = vec![0u8; 256];
                let (res, buf) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"one");
            }

            assert!(!*wrote_three.borrow());

            {
                let buf = vec![0u8; 256];
                let (res, buf) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"two");
            }

            tokio::task::yield_now().await;
            assert!(*wrote_three.borrow());

            {
                let buf = vec![0u8; 256];
                let (res, buf) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"three");
            }

            {
                let buf = vec![0u8; 5];
                let (res, buf) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"split");

                let buf = vec![0u8; 256];
                let (res, buf) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"read");
            }

            {
                // Use a non-empty buffer. A zero-capacity read returns Ok(0) even when
                // data is pending, so it could not tell EOF apart from an empty read.
                let buf = vec![0u8; 256];
                let (res, _) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(n, 0, "reached EOF");
            }
        })
    }

    /// A piece is split across two small reads. A later reset surfaces as
    /// `ConnectionReset` only after all previously written data has been read.
    #[test]
    fn test_pipe_fragmented_read() {
        crate::start(async move {
            let (mut w, mut r) = pipe();

            crate::spawn({
                async move {
                    w.write_all_owned("two-part").await.unwrap();
                    w.reset().await;
                }
            });

            for _ in 0..5 {
                tokio::task::yield_now().await;
            }

            {
                let buf = vec![0u8; 4];
                let (res, buf) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"two-");
            }

            {
                let buf = vec![0u8; 4];
                let (res, buf) = r.read_owned(buf).await;
                let n = res.unwrap();
                assert_eq!(&buf[..n], b"part");
            }

            {
                let buf = vec![0u8; 0];
                let (res, _) = r.read_owned(buf).await;
                let err = res.unwrap_err();
                assert_eq!(
                    err.kind(),
                    std::io::ErrorKind::ConnectionReset,
                    "expected a simulated connection reset"
                );
            }
        })
    }
}