1use std::collections::VecDeque;
21use std::io;
22use std::pin::Pin;
23use std::sync::atomic::{AtomicBool, Ordering};
24use std::sync::{Arc, Mutex};
25use std::task::{Context, Poll, Waker};
26
27use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
28
/// Default capacity, in bytes (64 KiB), used by `pipe_stream_default`.
pub const PIPE_BUFFER_SIZE: usize = 1 << 16;
31
/// Mutable pipe state shared by both ends; it lives inside the `Mutex` held by
/// `PipeShared` and is only touched while that lock is held.
struct PipeBuffer {
    /// Maximum number of bytes `buffer` may hold before writes must wait.
    capacity: usize,
    /// Bytes written but not yet read, oldest first.
    buffer: VecDeque<u8>,
    /// Waker of a reader that parked because `buffer` was empty.
    reader_waker: Option<Waker>,
    /// Waker of a writer that parked because `buffer` was full.
    writer_waker: Option<Waker>,
}
44
45struct PipeShared {
47 buf: Mutex<PipeBuffer>,
48 writer_closed: AtomicBool,
50 reader_closed: AtomicBool,
52}
53
54pub struct PipeWriter {
56 shared: Arc<PipeShared>,
57}
58
59pub struct PipeReader {
61 shared: Arc<PipeShared>,
62}
63
64pub fn pipe_stream(capacity: usize) -> (PipeWriter, PipeReader) {
70 let shared = Arc::new(PipeShared {
71 buf: Mutex::new(PipeBuffer {
72 buffer: VecDeque::with_capacity(capacity.min(8192)),
73 capacity,
74 reader_waker: None,
75 writer_waker: None,
76 }),
77 writer_closed: AtomicBool::new(false),
78 reader_closed: AtomicBool::new(false),
79 });
80
81 (
82 PipeWriter { shared: shared.clone() },
83 PipeReader { shared },
84 )
85}
86
87pub fn pipe_stream_default() -> (PipeWriter, PipeReader) {
89 pipe_stream(PIPE_BUFFER_SIZE)
90}
91
92impl PipeWriter {
93 pub async fn write_bytes(&self, data: &[u8]) -> io::Result<usize> {
97 use std::future::poll_fn;
98
99 if data.is_empty() {
100 return Ok(0);
101 }
102
103 poll_fn(|cx| Pin::new(&mut &*self).poll_write_impl(cx, data)).await
104 }
105}
106
107impl PipeWriter {
108 fn poll_write_impl(
110 &self,
111 cx: &mut Context<'_>,
112 buf: &[u8],
113 ) -> Poll<io::Result<usize>> {
114 if buf.is_empty() {
115 return Poll::Ready(Ok(0));
116 }
117
118 if self.shared.reader_closed.load(Ordering::Acquire) {
119 return Poll::Ready(Err(io::Error::new(
120 io::ErrorKind::BrokenPipe,
121 "pipe reader closed",
122 )));
123 }
124
125 let mut inner = self.shared.buf.lock().unwrap_or_else(|e| e.into_inner());
126
127 if self.shared.reader_closed.load(Ordering::Acquire) {
129 return Poll::Ready(Err(io::Error::new(
130 io::ErrorKind::BrokenPipe,
131 "pipe reader closed",
132 )));
133 }
134
135 let available = inner.capacity.saturating_sub(inner.buffer.len());
136 if available > 0 {
137 let to_write = buf.len().min(available);
138 inner.buffer.extend(&buf[..to_write]);
139 if let Some(waker) = inner.reader_waker.take() {
141 waker.wake();
142 }
143 Poll::Ready(Ok(to_write))
144 } else {
145 inner.writer_waker = Some(cx.waker().clone());
147 Poll::Pending
148 }
149 }
150}
151
152impl AsyncWrite for PipeWriter {
153 fn poll_write(
154 self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 buf: &[u8],
157 ) -> Poll<io::Result<usize>> {
158 self.poll_write_impl(cx, buf)
159 }
160
161 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
162 Poll::Ready(Ok(()))
163 }
164
165 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166 self.shared.writer_closed.store(true, Ordering::Release);
167 let mut inner = self.shared.buf.lock().unwrap_or_else(|e| e.into_inner());
168 if let Some(waker) = inner.reader_waker.take() {
169 waker.wake();
170 }
171 Poll::Ready(Ok(()))
172 }
173}
174
175impl Drop for PipeWriter {
176 fn drop(&mut self) {
177 self.shared.writer_closed.store(true, Ordering::Release);
178 if let Ok(mut inner) = self.shared.buf.lock() {
181 if let Some(waker) = inner.reader_waker.take() {
182 waker.wake();
183 }
184 }
185 }
187}
188
189impl AsyncRead for PipeReader {
190 fn poll_read(
191 self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 buf: &mut ReadBuf<'_>,
194 ) -> Poll<io::Result<()>> {
195 let mut inner = self.shared.buf.lock().unwrap_or_else(|e| e.into_inner());
196
197 if !inner.buffer.is_empty() {
198 let to_read = buf.remaining().min(inner.buffer.len());
199 let (front, back) = inner.buffer.as_slices();
200
201 if to_read <= front.len() {
202 buf.put_slice(&front[..to_read]);
203 } else {
204 buf.put_slice(front);
205 let remaining = to_read - front.len();
206 buf.put_slice(&back[..remaining]);
207 }
208
209 inner.buffer.drain(..to_read);
210 if let Some(waker) = inner.writer_waker.take() {
212 waker.wake();
213 }
214 Poll::Ready(Ok(()))
215 } else if self.shared.writer_closed.load(Ordering::Acquire) {
216 Poll::Ready(Ok(()))
218 } else {
219 inner.reader_waker = Some(cx.waker().clone());
221 Poll::Pending
222 }
223 }
224}
225
226impl Drop for PipeReader {
227 fn drop(&mut self) {
228 self.shared.reader_closed.store(true, Ordering::Release);
229 if let Ok(mut inner) = self.shared.buf.lock() {
230 if let Some(waker) = inner.writer_waker.take() {
231 waker.wake();
232 }
233 }
234 }
235}
236
237impl std::fmt::Debug for PipeWriter {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("PipeWriter").finish()
240 }
241}
242
243impl std::fmt::Debug for PipeReader {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 f.debug_struct("PipeReader").finish()
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Writes all of `data`. This loops over `write_bytes` because a single
    /// call may accept only part of the slice when the pipe is nearly full.
    async fn write_all_bytes(writer: &PipeWriter, mut data: &[u8]) -> io::Result<()> {
        while !data.is_empty() {
            let n = writer.write_bytes(data).await?;
            data = &data[n..];
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_basic_write_read() {
        let (writer, mut reader) = pipe_stream(1024);

        assert_eq!(writer.write_bytes(b"hello").await.unwrap(), 5);
        drop(writer);

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"hello");
    }

    #[tokio::test]
    async fn test_concurrent_write_read() {
        let (writer, mut reader) = pipe_stream(64);

        let write_task = tokio::spawn(async move {
            for i in 0..100 {
                let msg = format!("line {i}\n");
                // The 64-byte buffer regularly accepts only part of a message;
                // a single `write_bytes` call would silently drop the rest.
                write_all_bytes(&writer, msg.as_bytes()).await.unwrap();
            }
        });

        let mut output = Vec::new();
        reader.read_to_end(&mut output).await.unwrap();

        write_task.await.unwrap();

        let expected: String = (0..100).map(|i| format!("line {i}\n")).collect();
        assert_eq!(String::from_utf8(output).unwrap(), expected);
    }

    #[tokio::test]
    async fn test_backpressure() {
        let (writer, mut reader) = pipe_stream(16);
        // Longer than the 16-byte capacity, so the writer must wait for the reader.
        let data: &[u8] = b"0123456789ABCDEF_EXTRA_DATA";

        let write_task = tokio::spawn(async move {
            write_all_bytes(&writer, data).await.unwrap();
        });

        let mut output = Vec::new();
        reader.read_to_end(&mut output).await.unwrap();

        write_task.await.unwrap();
        assert_eq!(output, data);
    }

    #[tokio::test]
    async fn test_eof_on_writer_drop() {
        let (writer, mut reader) = pipe_stream(1024);

        assert_eq!(writer.write_bytes(b"data").await.unwrap(), 4);
        drop(writer);

        let mut buf = [0u8; 1024];
        let n = reader.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"data");

        // Buffer drained and writer gone: the next read must report EOF.
        let n = reader.read(&mut buf).await.unwrap();
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn test_broken_pipe_on_reader_drop() {
        let (writer, reader) = pipe_stream(1024);
        drop(reader);

        let err = writer.write_bytes(b"data").await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_empty_pipe() {
        let (writer, mut reader) = pipe_stream(1024);
        drop(writer);

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(buf.is_empty());
    }

    #[tokio::test]
    async fn test_async_write_trait() {
        let (mut writer, mut reader) = pipe_stream(1024);

        writer.write_all(b"async write").await.unwrap();
        writer.shutdown().await.unwrap();

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(buf, b"async write");
    }

    #[tokio::test]
    async fn test_large_data_through_small_buffer() {
        let (writer, mut reader) = pipe_stream(32);

        let data: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let expected = data.clone();

        let write_task = tokio::spawn(async move {
            write_all_bytes(&writer, &data).await.unwrap();
        });

        let mut output = Vec::new();
        reader.read_to_end(&mut output).await.unwrap();

        write_task.await.unwrap();
        assert_eq!(output, expected);
    }

    #[tokio::test]
    async fn test_no_lost_wakeups_under_contention() {
        let (writer, mut reader) = pipe_stream(16);

        let write_task = tokio::spawn(async move {
            for i in 0u32..1000 {
                write_all_bytes(&writer, &i.to_le_bytes()).await.unwrap();
            }
        });

        let mut output = Vec::new();
        reader.read_to_end(&mut output).await.unwrap();
        write_task.await.unwrap();

        assert_eq!(output.len(), 4000);
        // Every value must arrive intact and in order, not just the right byte count.
        let decoded: Vec<u32> = output
            .chunks_exact(4)
            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        assert_eq!(decoded, (0u32..1000).collect::<Vec<u32>>());
    }

    #[tokio::test]
    async fn test_concurrent_stress_no_hang() {
        let result = tokio::time::timeout(Duration::from_secs(5), async {
            let (writer, mut reader) = pipe_stream(64);

            let write_task = tokio::spawn(async move {
                // 37 does not divide 64, so writes regularly straddle the capacity.
                let chunk = vec![0xABu8; 37];
                for _ in 0..5000 {
                    if write_all_bytes(&writer, &chunk).await.is_err() {
                        return;
                    }
                }
            });

            let mut total = 0usize;
            let mut buf = [0u8; 128];
            loop {
                match reader.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => total += n,
                }
            }

            write_task.await.unwrap();
            assert_eq!(total, 37 * 5000);
        })
        .await;

        assert!(result.is_ok(), "pipe stress test timed out — likely deadlock");
    }

    #[tokio::test]
    async fn test_writer_drop_during_active_read() {
        let (writer, mut reader) = pipe_stream(1024);

        let read_task = tokio::spawn(async move {
            let mut buf = Vec::new();
            reader.read_to_end(&mut buf).await.unwrap();
            buf
        });

        // Give the reader time to park on the empty pipe before closing it.
        tokio::time::sleep(Duration::from_millis(10)).await;
        drop(writer);

        let buf = tokio::time::timeout(Duration::from_secs(2), read_task)
            .await
            .expect("reader hung after writer dropped")
            .expect("reader task panicked");
        assert!(buf.is_empty());
    }

    #[tokio::test]
    async fn test_reader_drop_while_writer_blocked() {
        let (writer, reader) = pipe_stream(8);

        let write_task = tokio::spawn(async move {
            // Far larger than the 8-byte capacity, so the writer is sure to block.
            let data = vec![0u8; 1024];
            let err = write_all_bytes(&writer, &data)
                .await
                .expect_err("writer should have gotten broken pipe");
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });

        // Give the writer time to fill the buffer and park before closing.
        tokio::time::sleep(Duration::from_millis(10)).await;
        drop(reader);

        tokio::time::timeout(Duration::from_secs(2), write_task)
            .await
            .expect("writer hung after reader dropped")
            // Surface panics inside the task (e.g. a wrong error kind); the
            // timeout check alone would silently ignore them.
            .expect("writer task panicked");
    }
}