1use std::alloc::{Layout, alloc, dealloc};
8use std::pin::Pin;
9use std::process::abort;
10use std::ptr::{null, null_mut, slice_from_raw_parts};
11use std::task::{Context, Poll, Waker};
12use std::{io, mem};
13
14use futures::{AsyncRead, AsyncWrite};
15
/// Layout of a pipe's byte buffer: `bs` bytes with alignment 1.
///
/// # Panics
/// If `bs` exceeds `isize::MAX`, the largest size a [`Layout`] may describe.
fn pipe_layout(bs: usize) -> Layout {
    Layout::from_size_align(bs, 1).expect("pipe buffer size must not exceed isize::MAX")
}
17
18pub fn pipe(size: usize) -> (Writer, Reader) {
21 assert!(0 < size, "cannot create async pipe without buffer");
22 let start = unsafe { alloc(pipe_layout(size)) };
24 extern "C" fn drop(val: *const ()) {
25 let AsyncRingbuffer {
26 start,
27 size,
28 mut read_waker,
29 mut write_waker,
30 reader_dropped,
31 writer_dropped,
32 read_idx: _,
34 write_idx: _,
35 drop: _,
37 state: _,
38 } = *unsafe { Box::from_raw(val as *mut AsyncRingbuffer) };
39 if !writer_dropped || !reader_dropped {
40 eprintln!("Pipe dropped in err before reader or writer");
41 abort()
42 }
43 read_waker.drop();
44 write_waker.drop();
45 unsafe { dealloc(start, pipe_layout(size)) }
46 }
47 let state = Box::into_raw(Box::new(AsyncRingbuffer {
48 start,
49 size,
50 state: null(),
51 read_idx: 0,
52 write_idx: 0,
53 read_waker: Trigger::empty(),
54 write_waker: Trigger::empty(),
55 reader_dropped: false,
56 writer_dropped: false,
57 drop,
58 }));
59 let state_mut = unsafe { state.as_mut().unwrap() };
60 state_mut.state = state as *const ();
61 (Writer(state_mut as *mut _), Reader(state_mut as *mut _))
62}
63
/// A type-erased, C-ABI wake handle that is consumed at most once.
///
/// `state` is an opaque pointer and the two callbacks consume it: `invoke` by waking, `drop`
/// by discarding. A null `state` marks an empty trigger; its callbacks abort, because they
/// must never be reached. A non-empty trigger is released explicitly via
/// [`Trigger::invoke`] or [`Trigger::drop`], which both leave the trigger empty.
#[repr(C)]
struct Trigger {
    state: *const (),
    invoke: extern "C" fn(*const ()),
    drop: extern "C" fn(*const ()),
}
impl Trigger {
    /// Moves `waker` to the heap and builds callbacks that reclaim that allocation.
    fn new(waker: Waker) -> Self {
        extern "C" fn wake_boxed(ptr: *const ()) {
            // SAFETY: `ptr` came from `Box::into_raw` below and is consumed exactly once.
            let boxed = unsafe { Box::from_raw(ptr as *mut Waker) };
            boxed.wake();
        }
        extern "C" fn discard_boxed(ptr: *const ()) {
            // SAFETY: same provenance and single-use guarantee as `wake_boxed`.
            let boxed = unsafe { Box::from_raw(ptr as *mut Waker) };
            mem::drop(boxed);
        }
        let boxed: *mut Waker = Box::into_raw(Box::new(waker));
        Self { state: boxed as *const (), invoke: wake_boxed, drop: discard_boxed }
    }

    /// A trigger that holds nothing; its callbacks abort if they are ever reached.
    fn empty() -> Self {
        extern "C" fn unreachable_callback(_: *const ()) { abort() }
        Self { state: null(), invoke: unreachable_callback, drop: unreachable_callback }
    }

    /// Whether this trigger currently holds no state.
    fn is_empty(&self) -> bool { self.state.is_null() }

    /// Fires the stored callback (if any) and leaves the trigger empty.
    fn invoke(&mut self) {
        let Some(taken) = self.take() else { return };
        (taken.invoke)(taken.state)
    }

    /// Discards the stored state without firing it and leaves the trigger empty.
    fn drop(&mut self) {
        let Some(taken) = self.take() else { return };
        (taken.drop)(taken.state)
    }

    /// Moves the current contents out, replacing them with an empty trigger.
    fn take(&mut self) -> Option<Self> {
        if self.is_empty() {
            None
        } else {
            Some(mem::replace(self, Self::empty()))
        }
    }
}
100
101#[repr(C)]
103struct AsyncRingbuffer {
104 state: *const (),
105 start: *mut u8,
106 size: usize,
107 read_idx: usize,
108 write_idx: usize,
109 read_waker: Trigger,
110 write_waker: Trigger,
111 reader_dropped: bool,
112 writer_dropped: bool,
113 drop: extern "C" fn(*const ()),
114}
115impl AsyncRingbuffer {
116 fn drop_writer(&mut self) {
117 self.writer_dropped = true;
118 if self.reader_dropped {
119 (self.drop)(self.state)
120 }
121 }
122 fn drop_reader(&mut self) {
123 self.reader_dropped = true;
124 if self.writer_dropped {
125 (self.drop)(self.state)
126 }
127 }
128 fn writer_wait<T>(&mut self, waker: &Waker) -> Poll<io::Result<T>> {
129 if self.reader_dropped {
130 return Poll::Ready(Err(broken_pipe_error()));
131 }
132 self.read_waker.invoke();
133 self.write_waker.drop();
134 self.write_waker = Trigger::new(waker.clone());
135 Poll::Pending
136 }
137 fn reader_wait(&mut self, waker: &Waker) -> Poll<io::Result<usize>> {
138 if self.writer_dropped {
139 return Poll::Ready(Err(broken_pipe_error()));
140 }
141 self.write_waker.invoke();
142 self.read_waker.drop();
143 self.read_waker = Trigger::new(waker.clone());
144 Poll::Pending
145 }
146 unsafe fn non_wrapping_write_unchecked(&mut self, buf: &[u8]) {
147 let write_ptr = unsafe { self.start.add(self.write_idx) };
148 let slc = slice_from_raw_parts(write_ptr, buf.len()).cast_mut();
149 unsafe { &mut *slc }.copy_from_slice(buf);
150 self.write_idx = (self.write_idx + buf.len()) % self.size;
151 }
152 unsafe fn non_wrapping_read_unchecked(&mut self, buf: &mut [u8]) {
153 let read_ptr = unsafe { self.start.add(self.read_idx) };
154 let slc = slice_from_raw_parts(read_ptr, buf.len()).cast_mut();
155 buf.copy_from_slice(unsafe { &*slc });
156 self.read_idx = (self.read_idx + buf.len()) % self.size;
157 }
158 fn is_full(&self) -> bool { (self.write_idx + 1) % self.size == self.read_idx }
159 fn is_empty(&self) -> bool { self.write_idx == self.read_idx }
160}
161
/// Error for an operation attempted on a pipe end that this side has already closed.
fn already_closed_error() -> io::Error {
    let message = "Pipe already closed from this end";
    io::Error::new(io::ErrorKind::BrokenPipe, message)
}
/// Error for an operation that needs the opposite pipe end, which has already gone away.
fn broken_pipe_error() -> io::Error {
    let message = "Pipe already closed from other end";
    io::Error::new(io::ErrorKind::BrokenPipe, message)
}
168
169#[repr(C)]
172pub struct Writer(*mut AsyncRingbuffer);
173impl Writer {
174 unsafe fn get_state(self: Pin<&mut Self>) -> io::Result<&mut AsyncRingbuffer> {
175 match unsafe { self.0.as_mut() } {
176 Some(data) => Ok(data),
177 None => Err(already_closed_error()),
178 }
179 }
180}
181impl AsyncWrite for Writer {
182 fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
183 unsafe {
184 match self.as_mut().get_state() {
185 Err(e) => return Poll::Ready(Err(e)),
186 Ok(data) => {
187 data.drop_writer();
188 },
189 }
190 }
191 self.0 = null_mut();
192 Poll::Ready(Ok(()))
193 }
194 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
195 unsafe {
196 let data = self.as_mut().get_state()?;
197 if data.is_empty() { Poll::Ready(Ok(())) } else { data.writer_wait(cx.waker()) }
198 }
199 }
200 fn poll_write(
201 mut self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 buf: &[u8],
204 ) -> Poll<io::Result<usize>> {
205 unsafe {
206 let data = self.as_mut().get_state()?;
207 let AsyncRingbuffer { write_idx, read_idx, size, .. } = *data;
208 if !buf.is_empty() && data.is_empty() {
209 data.read_waker.invoke();
210 }
211 if !buf.is_empty() && data.is_full() {
212 data.writer_wait(cx.waker())
214 } else if write_idx < read_idx {
215 let count = buf.len().min(read_idx - write_idx - 1);
217 data.non_wrapping_write_unchecked(&buf[0..count]);
218 Poll::Ready(Ok(count))
219 } else if data.write_idx + buf.len() < size {
220 data.non_wrapping_write_unchecked(&buf[0..buf.len()]);
222 Poll::Ready(Ok(buf.len()))
223 } else if read_idx == 0 {
224 data.non_wrapping_write_unchecked(&buf[0..size - write_idx - 1]);
226 Poll::Ready(Ok(size - write_idx - 1))
227 } else {
228 let (end, start) = buf.split_at(size - write_idx);
229 data.non_wrapping_write_unchecked(end);
231 let start_count = start.len().min(read_idx - 1);
232 data.non_wrapping_write_unchecked(&start[0..start_count]);
233 Poll::Ready(Ok(end.len() + start_count))
234 }
235 }
236 }
237}
238impl Drop for Writer {
239 fn drop(&mut self) {
240 unsafe {
241 if let Some(data) = self.0.as_mut() {
242 data.drop_writer();
243 }
244 }
245 }
246}
247
/// Reading end of an async pipe created by [`pipe`].
///
/// Wraps a raw pointer to the ring buffer shared with the matching [`Writer`]. Unlike the
/// writer, nothing in this file ever sets the pointer to null.
#[repr(C)]
pub struct Reader(*mut AsyncRingbuffer);
252impl AsyncRead for Reader {
253 fn poll_read(
254 self: Pin<&mut Self>,
255 cx: &mut Context<'_>,
256 buf: &mut [u8],
257 ) -> Poll<io::Result<usize>> {
258 unsafe {
259 let data = self.0.as_mut().expect("Cannot be null");
260 let AsyncRingbuffer { read_idx, write_idx, size, .. } = *data;
261 if !buf.is_empty() && data.is_full() {
262 data.write_waker.invoke();
263 }
264 if !buf.is_empty() && data.is_empty() {
265 data.reader_wait(cx.waker())
267 } else if read_idx < write_idx {
268 let count = buf.len().min(write_idx - read_idx);
270 data.non_wrapping_read_unchecked(&mut buf[0..count]);
271 Poll::Ready(Ok(count))
272 } else if read_idx + buf.len() < size {
273 data.non_wrapping_read_unchecked(buf);
275 Poll::Ready(Ok(buf.len()))
276 } else {
277 let (end, start) = buf.split_at_mut(size - read_idx);
279 data.non_wrapping_read_unchecked(end);
280 let start_count = start.len().min(write_idx);
281 data.non_wrapping_read_unchecked(&mut start[0..start_count]);
282 Poll::Ready(Ok(end.len() + start_count))
283 }
284 }
285 }
286}
287impl Drop for Reader {
288 fn drop(&mut self) {
289 unsafe {
290 if let Some(data) = self.0.as_mut() {
291 data.drop_reader();
292 }
293 }
294 }
295}
296
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures::future::join;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use itertools::Itertools;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;
    use test_executors::spin_on;

    use super::*;

    /// Round-trips 10 million big-endian `u32` counters (40 MB) through a 1024-byte pipe.
    ///
    /// Writer and reader each draw chunk lengths in `0..200` from their own seeded RNG, so
    /// reads and writes are misaligned with each other and with the ring's wrap-around point.
    /// Each chunk the reader receives must equal the matching slice of the expected stream.
    #[test]
    fn basic_io() {
        let mut w_rng = ChaCha8Rng::seed_from_u64(2);
        let mut r_rng = ChaCha8Rng::seed_from_u64(1);
        spin_on(async {
            let (w, r) = pipe(1024);
            let test_length = 10_000_000;
            // The reference byte stream: every counter value as 4 big-endian bytes.
            let data = (0u32..test_length).flat_map(|num| num.to_be_bytes());
            let write_fut = async {
                let mut w = pin!(w);
                let mut source = data.clone();
                // Bytes sent so far; the stream is `test_length * 4` bytes long.
                let mut tally = 0;
                while tally < test_length * 4 {
                    let values = source.by_ref().take(w_rng.random_range(0..200)).collect::<Vec<_>>();
                    tally += values.len() as u32;
                    w.write_all(&values).await.unwrap();
                }
                // Wait until the reader has drained everything still buffered.
                w.flush().await.unwrap();
            };
            let read_fut = async {
                let mut r = pin!(r);
                let mut expected = data.clone();
                // Bytes received so far.
                let mut tally = 0;
                while tally < test_length * 4 {
                    let expected_values =
                        expected.by_ref().take(r_rng.random_range(0..200)).collect::<Vec<_>>();
                    tally += expected_values.len() as u32;
                    let mut values = vec![0; expected_values.len()];
                    r.read_exact(&mut values[..]).await.unwrap_or_else(|e| panic!("At {tally} bytes: {e}"));
                    if values != expected_values {
                        // Hex dump, 32 bytes per row, for comparing the two chunks by eye.
                        fn print_bytes(bytes: &[u8]) -> String {
                            (bytes.iter().map(|s| format!("{s:>2x}")).chunks(32).into_iter())
                                .map(|c| c.into_iter().join(" "))
                                .join("\n")
                        }
                        panic!(
                            "Difference in generated numbers\n{}\n{}",
                            print_bytes(&values),
                            print_bytes(&expected_values),
                        )
                    }
                }
            };
            // Run both halves concurrently; neither can finish alone once the 1 KiB buffer fills.
            join(write_fut, read_fut).await;
        })
    }
}