intra_pipe/
lib.rs

1extern crate futures;
2extern crate tokio_current_thread;
3extern crate tokio_io;
4
5use futures::{sync::mpsc, Async, Poll, Sink, Stream};
6use std::io::{Error as IOError, ErrorKind, Read, Write};
7use tokio_io::{AsyncRead, AsyncWrite};
8
9pub trait WritePipe {
10    fn new(sender: mpsc::Sender<Vec<u8>>) -> Self;
11}
12
13pub trait ReadPipe {
14    fn new(receiver: mpsc::Receiver<Vec<u8>>) -> Self;
15}
16
/// Marker trait with no methods; in this file it is implemented by the
/// `Async*` pipe halves.
pub trait IsAsync {}

/// Marker trait with no methods; in this file it is implemented by the
/// `Sync*` pipe halves.
pub trait IsSync {}
20
21#[derive(Debug)]
22pub struct AsyncWritePipe {
23    sender: mpsc::Sender<Vec<u8>>,
24}
25
26impl Write for AsyncWritePipe {
27    fn write(&mut self, buf: &[u8]) -> Result<usize, IOError> {
28        if self.sender.is_closed() {
29            return Ok(0);
30        }
31        let len = buf.len();
32        if len == 0 {
33            return Ok(0);
34        }
35        self.sender
36            .start_send(buf.to_vec())
37            .map_err(|err| IOError::new(ErrorKind::BrokenPipe, err))
38            .and_then(|ret| {
39                if ret.is_not_ready() {
40                    Err(IOError::new(ErrorKind::WouldBlock, ""))
41                } else {
42                    Ok(())
43                }
44            })
45            .and_then(|_| {
46                self.sender
47                    .poll_complete()
48                    .map_err(|err| IOError::new(ErrorKind::BrokenPipe, err))
49                    .map(|_| len)
50            })
51    }
52
53    fn flush(&mut self) -> Result<(), IOError> {
54        // Fake flush since data is always flushed after write.
55        Ok(())
56    }
57}
58
59impl AsyncWrite for AsyncWritePipe {
60    fn shutdown(&mut self) -> Poll<(), IOError> {
61        self.sender
62            .close()
63            .map_err(|err| IOError::new(ErrorKind::BrokenPipe, err))
64    }
65}
66
67impl WritePipe for AsyncWritePipe {
68    fn new(sender: mpsc::Sender<Vec<u8>>) -> Self {
69        AsyncWritePipe { sender }
70    }
71}
72
73impl IsAsync for AsyncWritePipe {}
74
75#[derive(Debug)]
76pub struct SyncWritePipe {
77    writer: AsyncWritePipe,
78}
79
80impl Write for SyncWritePipe {
81    fn write(&mut self, buf: &[u8]) -> Result<usize, IOError> {
82        let fut = tokio_io::io::write_all(&mut self.writer, buf);
83        tokio_current_thread::block_on_all(fut).map(|_| buf.len())
84    }
85
86    fn flush(&mut self) -> Result<(), IOError> {
87        let fut = tokio_io::io::flush(&mut self.writer);
88        tokio_current_thread::block_on_all(fut).map(|_| ())
89    }
90}
91
92impl WritePipe for SyncWritePipe {
93    fn new(sender: mpsc::Sender<Vec<u8>>) -> Self {
94        SyncWritePipe {
95            writer: AsyncWritePipe::new(sender),
96        }
97    }
98}
99
100impl IsSync for SyncWritePipe {}
101
102#[derive(Debug)]
103pub struct AsyncReadPipe {
104    receiver: mpsc::Receiver<Vec<u8>>,
105    buf: Vec<u8>,
106    pos: usize,
107}
108
109impl Read for AsyncReadPipe {
110    fn read(&mut self, buf: &mut [u8]) -> Result<usize, IOError> {
111        if self.pos == self.buf.len() {
112            self.buf = match self.receiver.poll() {
113                Ok(Async::Ready(Some(data))) => data,
114                Ok(Async::Ready(None)) => return Ok(0),
115                Ok(Async::NotReady) => {
116                    return if buf.len() == 0 {
117                        Ok(0)
118                    } else {
119                        Err(IOError::new(ErrorKind::WouldBlock, ""))
120                    };
121                }
122                Err(_) => return Err(IOError::new(ErrorKind::BrokenPipe, "")),
123            };
124            self.pos = 0;
125        }
126        let ret_len = (self.buf.len() - self.pos).min(buf.len());
127        buf[..ret_len].clone_from_slice(&self.buf[self.pos..(self.pos + ret_len)]);
128        self.pos += ret_len;
129        return Ok(ret_len);
130    }
131}
132
133impl AsyncRead for AsyncReadPipe {}
134
135impl ReadPipe for AsyncReadPipe {
136    fn new(receiver: mpsc::Receiver<Vec<u8>>) -> Self {
137        AsyncReadPipe {
138            receiver,
139            buf: vec![],
140            pos: 0,
141        }
142    }
143}
144
145impl IsAsync for AsyncReadPipe {}
146
147#[derive(Debug)]
148pub struct SyncReadPipe {
149    reader: AsyncReadPipe,
150}
151
152impl Read for SyncReadPipe {
153    fn read(&mut self, buf: &mut [u8]) -> Result<usize, IOError> {
154        let fut = tokio_io::io::read(&mut self.reader, buf);
155        tokio_current_thread::block_on_all(fut).map(|(_, _, len)| len)
156    }
157}
158
159impl ReadPipe for SyncReadPipe {
160    fn new(receiver: mpsc::Receiver<Vec<u8>>) -> Self {
161        SyncReadPipe {
162            reader: AsyncReadPipe::new(receiver),
163        }
164    }
165}
166
167impl IsSync for SyncReadPipe {}
168
169pub fn pipe<W: WritePipe, R: ReadPipe>() -> (W, R) {
170    let (sender, receiver) = mpsc::channel(1);
171    (W::new(sender), R::new(receiver))
172}
173
174#[derive(Debug)]
175pub struct Channel<R: ReadPipe, W: WritePipe> {
176    reader: R,
177    writer: W,
178}
179
180impl<R: ReadPipe + Read, W: WritePipe> Read for Channel<R, W> {
181    fn read(&mut self, buf: &mut [u8]) -> Result<usize, IOError> {
182        self.reader.read(buf)
183    }
184}
185
186impl<R: ReadPipe, W: Write + WritePipe> Write for Channel<R, W> {
187    fn write(&mut self, buf: &[u8]) -> Result<usize, IOError> {
188        self.writer.write(buf)
189    }
190
191    fn flush(&mut self) -> Result<(), IOError> {
192        self.writer.flush()
193    }
194}
195
196impl<R: ReadPipe + AsyncRead, W: WritePipe> AsyncRead for Channel<R, W> {}
197
198impl<R: ReadPipe, W: WritePipe + AsyncWrite> AsyncWrite for Channel<R, W> {
199    fn shutdown(&mut self) -> Poll<(), IOError> {
200        self.writer.shutdown()
201    }
202}
203
204pub fn channel<FR, FW, SR, SW>() -> (Channel<FR, FW>, Channel<SR, SW>)
205where
206    FR: ReadPipe,
207    FW: WritePipe,
208    SR: ReadPipe,
209    SW: WritePipe,
210{
211    let (fst_tx, fst_rx) = mpsc::channel(1);
212    let (snd_tx, snd_rx) = mpsc::channel(1);
213    (
214        Channel {
215            reader: FR::new(fst_rx),
216            writer: FW::new(snd_tx),
217        },
218        Channel {
219            reader: SR::new(snd_rx),
220            writer: SW::new(fst_tx),
221        },
222    )
223}
224
225pub type AsyncChannel = Channel<AsyncReadPipe, AsyncWritePipe>;
226pub type SyncChannel = Channel<SyncReadPipe, SyncWritePipe>;
227
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future, Future};
    use std::thread;
    use tokio_io::io::{read_exact, read_to_end, write_all};

    // The payload is written as two chunks; readers must see their concatenation.
    const TEST_WRITE_DATA_A: &[u8] = b"Hello ";
    const TEST_WRITE_DATA_B: &[u8] = b"World";
    const TEST_EXPECT_DATA: &[u8] = b"Hello World";

    /// Spawns a thread that writes both chunks through a blocking write half.
    fn sync_sender(mut tx: SyncWritePipe) -> thread::JoinHandle<()> {
        thread::spawn(move || {
            assert_eq!(
                tx.write(TEST_WRITE_DATA_A).unwrap(),
                TEST_WRITE_DATA_A.len()
            );
            assert_eq!(
                tx.write(TEST_WRITE_DATA_B).unwrap(),
                TEST_WRITE_DATA_B.len()
            );
        })
    }

    /// Spawns a thread that reads a blocking read half to the end and checks
    /// the received bytes.
    fn sync_receiver(mut rx: SyncReadPipe) -> thread::JoinHandle<()> {
        thread::spawn(move || {
            let mut buf = Vec::new();
            rx.read_to_end(&mut buf).unwrap();
            assert_eq!(buf, TEST_EXPECT_DATA);
        })
    }

    /// Future that writes both chunks through a non-blocking write half.
    fn async_sender(tx: AsyncWritePipe) -> impl Future<Item = (), Error = ()> {
        write_all(tx, TEST_WRITE_DATA_A)
            .and_then(|(tx, _)| write_all(tx, TEST_WRITE_DATA_B))
            .then(|result| {
                assert!(result.is_ok());
                Ok(())
            })
    }

    /// Future that reads a non-blocking read half to the end and checks the
    /// received bytes.
    fn async_receiver(rx: AsyncReadPipe) -> impl Future<Item = (), Error = ()> {
        read_to_end(rx, Vec::new()).then(|result| {
            let (_, buf) = result.unwrap();
            assert_eq!(buf, TEST_EXPECT_DATA);
            Ok(())
        })
    }

    /// Runs every future in `futs` on a current-thread executor, then joins
    /// every thread in `thds`.
    fn run_and_wait(
        thds: Vec<thread::JoinHandle<()>>,
        futs: Vec<Box<dyn Future<Item = (), Error = ()>>>,
    ) {
        tokio_current_thread::block_on_all(future::lazy(|| {
            for fut in futs {
                tokio_current_thread::spawn(fut);
            }
            future::ok::<(), ()>(())
        }))
        .unwrap();
        for thd in thds {
            thd.join().unwrap();
        }
    }

    /// Exercises every sync/async combination of writer and reader on a pipe.
    #[test]
    fn normal_pipe() {
        let (tx, rx) = pipe();
        run_and_wait(vec![sync_sender(tx), sync_receiver(rx)], vec![]);
        let (tx, rx) = pipe();
        run_and_wait(vec![sync_sender(tx)], vec![Box::new(async_receiver(rx))]);
        let (tx, rx) = pipe();
        run_and_wait(vec![sync_receiver(rx)], vec![Box::new(async_sender(tx))]);
        let (tx, rx) = pipe();
        run_and_wait(
            vec![],
            vec![Box::new(async_sender(tx)), Box::new(async_receiver(rx))],
        );
    }

    /// Zero-length writes and reads report zero bytes.
    #[test]
    fn zero_read_write_pipe() {
        let (mut tx, mut rx): (SyncWritePipe, SyncReadPipe) = pipe();
        assert_eq!(tx.write(&[]).unwrap(), 0);
        let mut buf = [0u8; 0];
        assert_eq!(rx.read(&mut buf).unwrap(), 0);

        let (mut tx, mut _rx): (AsyncWritePipe, AsyncReadPipe) = pipe();
        assert_eq!(tx.write(&[]).unwrap(), 0);
    }

    /// Writing after the reader is dropped yields `WriteZero`; reading after
    /// the writer is dropped yields end of stream.
    #[test]
    fn broken_pipe() {
        let (mut tx, rx): (SyncWritePipe, SyncReadPipe) = pipe();
        drop(rx);
        assert_eq!(tx.write(&[]).unwrap(), 0);
        assert_eq!(
            tx.write(TEST_EXPECT_DATA).err().unwrap().kind(),
            ErrorKind::WriteZero
        );
        let (tx, mut rx): (SyncWritePipe, SyncReadPipe) = pipe();
        drop(tx);
        let mut buf = [0u8; 1];
        assert_eq!(rx.read(&mut buf).unwrap(), 0);
    }

    /// Flushing a blocking write half succeeds.
    #[test]
    fn flush_pipe() {
        let (mut tx, mut _rx): (SyncWritePipe, SyncReadPipe) = pipe();
        assert_eq!(tx.flush().unwrap(), ());
    }

    /// Spawns a thread that sends and receives the test payload over a
    /// blocking endpoint; `reverse` selects receive-then-send order.
    fn sync_send_receive(ch: SyncChannel, reverse: bool) -> thread::JoinHandle<()> {
        thread::spawn(move || {
            let send = |mut ch: SyncChannel| {
                assert_eq!(
                    ch.write(TEST_WRITE_DATA_A).unwrap(),
                    TEST_WRITE_DATA_A.len()
                );
                assert_eq!(
                    ch.write(TEST_WRITE_DATA_B).unwrap(),
                    TEST_WRITE_DATA_B.len()
                );
                ch
            };
            let receive = |mut ch: SyncChannel| {
                let mut buf = vec![0u8; TEST_EXPECT_DATA.len()];
                ch.read_exact(&mut buf).unwrap();
                assert_eq!(buf, TEST_EXPECT_DATA);
                ch
            };
            if reverse {
                let ch = receive(ch);
                send(ch);
            } else {
                let ch = send(ch);
                receive(ch);
            }
        })
    }

    /// Future that sends and receives the test payload over a non-blocking
    /// endpoint; `reverse` selects receive-then-send order.
    fn async_send_receive(
        ch: AsyncChannel,
        reverse: bool,
    ) -> Box<dyn Future<Item = (), Error = ()>> {
        let send = |tx| {
            write_all(tx, TEST_WRITE_DATA_A)
                .and_then(|(tx, _)| write_all(tx, TEST_WRITE_DATA_B))
                .then(|result| {
                    let (tx, _) = result.unwrap();
                    Ok(tx)
                })
        };
        let receive = |rx| {
            let buf = vec![0u8; TEST_EXPECT_DATA.len()];
            read_exact(rx, buf).then(|result| {
                let (rx, buf) = result.unwrap();
                assert_eq!(buf, TEST_EXPECT_DATA);
                Ok(rx)
            })
        };
        if reverse {
            Box::new(receive(ch).and_then(move |ch| send(ch)).map(|_| ()))
        } else {
            Box::new(send(ch).and_then(move |ch| receive(ch)).map(|_| ()))
        }
    }

    /// Exercises every sync/async combination of the two channel endpoints.
    #[test]
    fn normal_channel() {
        let (fst, snd) = channel();
        run_and_wait(
            vec![sync_send_receive(fst, false), sync_send_receive(snd, true)],
            vec![],
        );
        let (fst, snd) = channel();
        run_and_wait(
            vec![sync_send_receive(fst, false)],
            vec![async_send_receive(snd, true)],
        );
        let (fst, snd) = channel();
        run_and_wait(
            vec![sync_send_receive(snd, false)],
            vec![async_send_receive(fst, true)],
        );
        let (fst, snd) = channel();
        run_and_wait(
            vec![],
            vec![
                async_send_receive(fst, false),
                async_send_receive(snd, true),
            ],
        );
    }

    /// Shutting down a non-blocking write half or endpoint completes successfully.
    #[test]
    fn shutdown() {
        let (mut tx, mut _rx): (AsyncWritePipe, AsyncReadPipe) = pipe();
        assert_eq!(
            tokio_current_thread::block_on_all(future::poll_fn(|| tx.shutdown())).unwrap(),
            ()
        );

        let (mut fst, mut _snd): (AsyncChannel, AsyncChannel) = channel();
        assert_eq!(
            tokio_current_thread::block_on_all(future::poll_fn(|| fst.shutdown())).unwrap(),
            ()
        );
    }
}
432        );
433    }
434}