bipe/
lib.rs

1mod buffer;
2use futures_lite::prelude::*;
3
4use std::{
5    io::Read,
6    io::Write,
7    pin::Pin,
8    sync::{
9        atomic::{AtomicBool, Ordering},
10        Arc,
11    },
12    task::Context,
13    task::Poll,
14};
15
16/// Create a "bipe". Use async_dup's methods if you want something cloneable/shareable
17pub fn bipe(capacity: usize) -> (BipeWriter, BipeReader) {
18    let (send_buf, recv_buf) = buffer::new(capacity);
19    let write_ready = Arc::new(event_listener::Event::new());
20    let read_ready = Arc::new(event_listener::Event::new());
21    let closed = Arc::new(AtomicBool::new(false));
22    (
23        BipeWriter {
24            queue: send_buf,
25            signal: write_ready.clone(),
26            signal_reader: read_ready.clone(),
27            listener: write_ready.listen(),
28            closed: closed.clone(),
29        },
30        BipeReader {
31            queue: recv_buf,
32            signal: read_ready.clone(),
33            signal_writer: write_ready.clone(),
34            listener: read_ready.listen(),
35            closed,
36        },
37    )
38}
39
40/// Writing end of a byte pipe.
41pub struct BipeWriter {
42    queue: buffer::Producer,
43    signal: Arc<event_listener::Event>,
44    signal_reader: Arc<event_listener::Event>,
45    listener: event_listener::EventListener,
46    closed: Arc<AtomicBool>,
47}
48
49impl Drop for BipeWriter {
50    fn drop(&mut self) {
51        self.closed.store(true, Ordering::SeqCst);
52        self.signal_reader.notify(1);
53    }
54}
55
/// Builds the error returned by writes once the pipe has been marked closed.
///
/// NOTE(review): the kind is `ConnectionReset` although the message says "broken pipe";
/// it is kept as-is because callers may match on the kind.
fn broken_pipe() -> std::io::Error {
    use std::io::{Error, ErrorKind};

    let kind = ErrorKind::ConnectionReset;
    Error::new(kind, "broken pipe")
}
59
60impl AsyncWrite for BipeWriter {
61    fn poll_write(
62        mut self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &[u8],
65    ) -> Poll<std::io::Result<usize>> {
66        loop {
67            if self.closed.load(Ordering::SeqCst) {
68                return Poll::Ready(Err(broken_pipe()));
69            }
70            // if there's room in the buffer then it's fine
71            {
72                if let Ok(n) = self.queue.write(buf) {
73                    // if n > 0 {
74                    self.signal_reader.notify(1);
75                    return Poll::Ready(Ok(n));
76                    // }
77                }
78            }
79            let listen_capacity = &mut self.listener;
80            futures_lite::pin!(listen_capacity);
81            // there's no room, so we try again later
82            futures_lite::ready!(listen_capacity.poll(cx));
83            self.listener = self.signal.listen()
84        }
85    }
86
87    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
88        Poll::Ready(Ok(()))
89    }
90
91    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
92        self.closed.store(true, Ordering::SeqCst);
93        self.signal_reader.notify(1);
94        Poll::Ready(Ok(()))
95    }
96}
97
98/// Read end of a byte pipe.
99pub struct BipeReader {
100    queue: buffer::Consumer,
101    signal: Arc<event_listener::Event>,
102    signal_writer: Arc<event_listener::Event>,
103    listener: event_listener::EventListener,
104    closed: Arc<AtomicBool>,
105}
106
107impl AsyncRead for BipeReader {
108    fn poll_read(
109        mut self: Pin<&mut Self>,
110        cx: &mut Context<'_>,
111        buf: &mut [u8],
112    ) -> Poll<std::io::Result<usize>> {
113        loop {
114            if let Ok(n) = self.queue.read(buf) {
115                if n > 0 {
116                    self.signal_writer.notify(1);
117                    return Poll::Ready(Ok(n));
118                }
119            }
120            if self.closed.load(Ordering::Relaxed) {
121                return Poll::Ready(Ok(0));
122            }
123            let listen_new_data = &mut self.listener;
124            futures_lite::pin!(listen_new_data);
125            futures_lite::ready!(listen_new_data.poll(cx));
126            self.listener = self.signal.listen();
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Streams `ITERATIONS` big-endian `u64`s through a tiny pipe and checks that every
    /// value arrives intact and in order before end-of-stream.
    ///
    /// Capacity 9 is presumably 9 bytes — barely more than one 8-byte value — so the
    /// writer has to wait on the reader repeatedly.
    #[test]
    fn test_no_corruption() {
        const ITERATIONS: u64 = 1000;
        let (mut send, mut recv) = bipe(9);
        async_global_executor::block_on(async move {
            let writer = async_global_executor::spawn(async move {
                for iteration in 0u64..ITERATIONS {
                    send.write_all(&iteration.to_be_bytes()).await.unwrap();
                }
                // `send` is dropped here, closing the pipe so `read_to_end` can finish.
            });
            let mut buff = vec![];
            recv.read_to_end(&mut buff).await.unwrap();
            // Awaiting the task surfaces a writer panic directly instead of only as a
            // short read.
            writer.await;

            assert_eq!(buff.len() as u64, ITERATIONS * 8);
            for (expected, chunk) in (0..ITERATIONS).zip(buff.chunks_exact(8)) {
                assert_eq!(
                    chunk,
                    &expected.to_be_bytes()[..],
                    "value {} arrived corrupted",
                    expected
                );
            }
        })
    }
}