1use std::collections::VecDeque;
2use std::future::Future;
3use std::io::{BufRead, Error, ErrorKind, IoSlice, Write};
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6
7use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
8use loole::{unbounded, Receiver, RecvFuture, Sender, TrySendError};
9
10use crate::state::SharedState;
11
12pub fn async_pipe() -> (AsyncWriter, AsyncReader) {
34 let (sender, receiver) = unbounded();
35 let state = SharedState::default();
36
37 (
38 AsyncWriter {
39 sender,
40 state: state.clone(),
41 wakers: VecDeque::new(),
42 },
43 AsyncReader {
44 receiver,
45 state,
46 buf: VecDeque::new(),
47 reading: None,
48 },
49 )
50}
51
/// Creates a pipe whose write half is the synchronous [`crate::Writer`] and whose
/// read half is an async [`AsyncReader`].
///
/// Both halves share one state object and an unbounded notification channel. This
/// lets threads writing through `std::io::Write` feed an async consumer.
#[cfg(feature = "async")]
#[cfg(feature = "sync")]
pub fn async_reader_pipe() -> (crate::Writer, AsyncReader) {
    let state = SharedState::default();
    let (sender, receiver) = unbounded();

    let writer = crate::Writer {
        sender,
        state: state.clone(),
    };
    let reader = AsyncReader {
        receiver,
        state,
        buf: VecDeque::new(),
        reading: None,
    };

    (writer, reader)
}
92
/// Creates a pipe whose write half is an async [`AsyncWriter`] and whose read half
/// is the synchronous [`crate::Reader`].
///
/// Both halves share one state object and an unbounded notification channel. This
/// lets async producers feed a consumer reading through `std::io::Read`.
#[cfg(feature = "async")]
#[cfg(feature = "sync")]
pub fn async_writer_pipe() -> (AsyncWriter, crate::Reader) {
    let state = SharedState::default();
    let (sender, receiver) = unbounded();

    let writer = AsyncWriter {
        sender,
        state: state.clone(),
        wakers: VecDeque::new(),
    };
    let reader = crate::Reader {
        receiver,
        state,
        buf: VecDeque::new(),
    };

    (writer, reader)
}
133
/// Async write half of a pipe.
///
/// Each write stores its bytes in the state shared with the reader side, then sends a
/// `()` notification over the channel. Cloning produces another handle on the same
/// channel and shared state.
#[derive(Debug)]
pub struct AsyncWriter {
    // Notification channel; `poll_send` pushes one `()` after every write.
    sender: Sender<()>,
    // Wakers parked by this handle when `try_send` reported `Full`; `poll_send` wakes
    // the oldest one on its next successful or disconnected send. Per-handle: clones
    // start with an empty queue.
    wakers: VecDeque<Waker>,
    // State shared with the reader side; written by `poll_write`/`poll_write_vectored`.
    state: SharedState,
}
143
144impl Clone for AsyncWriter {
145 fn clone(&self) -> Self {
146 Self {
147 sender: self.sender.clone(),
148 wakers: VecDeque::new(),
149 state: self.state.clone(),
150 }
151 }
152}
153
154impl AsyncWriter {
155 fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
156 match self.sender.try_send(()) {
157 Ok(_) => {
158 if let Some(waker) = self.wakers.pop_front() {
159 waker.wake()
160 }
161 Poll::Ready(Ok(()))
162 }
163 Err(TrySendError::Full(_)) => {
164 self.wakers.push_back(cx.waker().clone());
165 Poll::Pending
166 }
167 Err(e @ TrySendError::Disconnected(_)) => {
168 if let Some(waker) = self.wakers.pop_front() {
169 waker.wake()
170 }
171 Poll::Ready(Err(Error::new(ErrorKind::WriteZero, e)))
172 }
173 }
174 }
175}
176
177impl AsyncWrite for AsyncWriter {
178 fn poll_write(
179 mut self: Pin<&mut Self>,
180 cx: &mut Context<'_>,
181 buf: &[u8],
182 ) -> Poll<std::io::Result<usize>> {
183 let n = self.state.write(buf)?;
184 match self.poll_send(cx)? {
185 Poll::Ready(_) => Poll::Ready(Ok(n)),
186 Poll::Pending => Poll::Pending,
187 }
188 }
189
190 fn poll_write_vectored(
191 mut self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 bufs: &[IoSlice<'_>],
194 ) -> Poll<std::io::Result<usize>> {
195 let n = self.state.write_vectored(bufs)?;
196 match self.poll_send(cx)? {
197 Poll::Ready(_) => Poll::Ready(Ok(n)),
198 Poll::Pending => Poll::Pending,
199 }
200 }
201
202 fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
203 Poll::Ready(self.state.flush())
204 }
205
206 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
207 self.sender.close();
208 Poll::Ready(Ok(()))
209 }
210}
211
/// Async read half of a pipe; implements both [`AsyncRead`] and [`AsyncBufRead`].
#[derive(Debug)]
pub struct AsyncReader {
    // Receives the `()` notification sent by the writer side after each write.
    receiver: Receiver<()>,
    // Bytes already copied out of the shared state but not yet consumed by the caller.
    buf: VecDeque<u8>,
    // In-flight receive on `receiver`. It is kept across polls so a pending wait
    // resumes instead of starting over.
    reading: Option<RecvFuture<()>>,
    // State shared with the writer side(s); drained into `buf` via `copy_to`.
    state: SharedState,
}
222
223impl AsyncBufRead for AsyncReader {
224 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
225 let this = self.get_mut();
226 while this.buf.is_empty() {
227 let n = this.state.copy_to(&mut this.buf)?;
228 if n == 0 {
229 if this.reading.is_none() {
230 this.reading = Some(this.receiver.recv_async())
231 }
232
233 match Pin::new(this.reading.as_mut().unwrap()).poll(cx) {
234 Poll::Ready(Ok(_)) => {
235 this.reading = None;
236 }
237 Poll::Ready(Err(_)) => {
238 this.reading = None;
239 break;
240 }
241 Poll::Pending => return Poll::Pending,
242 }
243 }
244 }
245
246 if this.buf.is_empty() {
247 _ = this.state.copy_to(&mut this.buf)?;
248 }
249
250 Poll::Ready(this.buf.fill_buf())
251 }
252
253 fn consume(mut self: Pin<&mut Self>, amt: usize) {
254 self.buf.consume(amt)
255 }
256}
257
258impl AsyncRead for AsyncReader {
259 fn poll_read(
260 mut self: Pin<&mut Self>,
261 cx: &mut Context<'_>,
262 mut buf: &mut [u8],
263 ) -> Poll<std::io::Result<usize>> {
264 let data = match self.as_mut().poll_fill_buf(cx)? {
265 Poll::Ready(buf) => buf,
266 Poll::Pending => return Poll::Pending,
267 };
268 let n = buf.write(data)?;
269 self.consume(n);
270 Poll::Ready(Ok(n))
271 }
272}
273
#[cfg(test)]
mod tests {
    use std::io::IoSlice;
    use std::thread::spawn;

    use futures::{
        executor::block_on, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, StreamExt, TryStreamExt,
    };

    /// 1000 writes succeed while the reader is alive but never reads; the reader is
    /// dropped only at the end.
    #[test]
    fn base_write_case() {
        block_on(async {
            let (mut writer, reader) = crate::async_pipe();
            for _ in 0..1000 {
                writer.write_all("hello".as_bytes()).await.unwrap();
            }
            drop(reader)
        })
    }

    /// Two writes followed by dropping the writer are read back, in order, up to EOF.
    #[test]
    fn base_read_case() {
        block_on(async {
            let (mut writer, mut reader) = crate::async_pipe();

            writer.write_all("hello ".as_bytes()).await.unwrap();
            writer.write_all("world".as_bytes()).await.unwrap();
            drop(writer);

            let mut str = String::new();
            reader.read_to_string(&mut str).await.unwrap();

            assert_eq!("hello world".to_string(), str);
        });
    }

    /// A vectored write of two slices is read back as their concatenation.
    /// The returned byte count is ignored, so this relies on `write_vectored`
    /// accepting both slices in a single call.
    #[test]
    fn base_vectored_case() {
        block_on(async {
            let (mut writer, mut reader) = crate::async_pipe();
            _ = writer
                .write_vectored(&[
                    IoSlice::new("hello ".as_bytes()),
                    IoSlice::new("world".as_bytes()),
                ])
                .await
                .unwrap();
            drop(writer);

            let mut str = String::new();
            reader.read_to_string(&mut str).await.unwrap();

            assert_eq!("hello world".to_string(), str);
        });
    }

    /// 1000 cloned writers write concurrently while the reader reads to EOF; every
    /// write must be observed. Despite the name, no OS threads are involved: all
    /// writers are driven by `buffer_unordered` on the single `block_on` executor.
    #[test]
    fn thread_case() {
        block_on(async {
            let (writer, mut reader) = crate::async_pipe();
            let writers = (0..1000).map(|_| writer.clone()).collect::<Vec<_>>();
            let writers_len = writers.len();
            // Drop the original handle so EOF is reached once every clone is dropped.
            drop(writer);
            let write_fut = futures::stream::iter(writers)
                .map(|mut writer| async move { writer.write_all("hello".as_bytes()).await })
                .buffer_unordered(writers_len)
                .try_collect::<Vec<()>>();

            let mut str = String::new();
            let read_fut = reader.read_to_string(&mut str);
            futures::join!(
                async {
                    write_fut.await.unwrap();
                },
                async { read_fut.await.unwrap() }
            );

            assert_eq!("hello".repeat(writers_len), str);
        });
    }

    /// Writing after the reader has been dropped reports an error.
    #[test]
    fn writer_err_case() {
        block_on(async {
            let (mut writer, reader) = crate::async_pipe();
            drop(reader);

            assert!(writer.write("hello".as_bytes()).await.is_err());
        });
    }

    /// `read_line` splits on '\n', returns the final unterminated line, then 0 at EOF.
    #[test]
    fn bufread_case() {
        block_on(async {
            let (mut writer, mut reader) = crate::async_pipe();
            writer.write_all("hello\n".as_bytes()).await.unwrap();
            writer.write_all("world".as_bytes()).await.unwrap();
            drop(writer);

            let mut str = String::new();
            assert_ne!(0, reader.read_line(&mut str).await.unwrap());
            assert_eq!("hello\n".to_string(), str);

            let mut str = String::new();
            assert_ne!(0, reader.read_line(&mut str).await.unwrap());
            assert_eq!("world".to_string(), str);

            let mut str = String::new();
            assert_eq!(0, reader.read_line(&mut str).await.unwrap());
        });
    }

    /// `lines()` yields exactly two lines, each `Ok`.
    #[test]
    fn bufread_lines_case() {
        block_on(async {
            let (mut writer, reader) = crate::async_pipe();
            writer.write_all("hello\n".as_bytes()).await.unwrap();
            writer.write_all("world".as_bytes()).await.unwrap();
            drop(writer);

            assert_eq!(2, reader.lines().map(|l| assert!(l.is_ok())).count().await)
        });
    }

    /// 1000 OS threads each write "hello" through a clone of the sync `Writer`; the
    /// async reader must see all of it by EOF. The threads are not joined: EOF
    /// arrives once every clone has been dropped.
    #[test]
    fn thread_writer_case() {
        use std::io::Write;

        let (writer, mut reader) = crate::async_reader_pipe();
        for _ in 0..1000 {
            let mut writer = writer.clone();
            spawn(move || {
                writer.write_all("hello".as_bytes()).unwrap();
            });
        }
        drop(writer);

        block_on(async {
            let mut str = String::new();
            reader.read_to_string(&mut str).await.unwrap();

            assert_eq!("hello".repeat(1000), str);
        })
    }

    /// Mirror of `thread_writer_case`: 1000 OS threads each drive an async write
    /// on their own executor, and the sync `Reader` must see all of it by EOF.
    #[test]
    fn thread_reader_case() {
        use std::io::Read;

        let (writer, mut reader) = crate::async_writer_pipe();
        for _ in 0..1000 {
            let mut writer = writer.clone();
            spawn(move || {
                block_on(async {
                    writer.write_all("hello".as_bytes()).await.unwrap();
                })
            });
        }
        drop(writer);

        let mut str = String::new();
        reader.read_to_string(&mut str).unwrap();

        assert_eq!("hello".repeat(1000), str);
    }

    /// 1000 rounds of: spawn a thread that writes 4 bytes, then block on the main
    /// thread until a single `read` returns exactly those 4 bytes. This exercises
    /// the reader waking up for a write made from another thread.
    #[test]
    fn threads_write_and_read_case() {
        let (writer, mut reader) = crate::async_pipe();

        for _ in 0..1000 {
            let mut writer = writer.clone();

            spawn(move || {
                block_on(async {
                    writer.write_all(&[0; 4]).await.unwrap();
                })
            });

            block_on(async {
                let mut buf = [0; 4];
                assert_eq!(buf.len(), reader.read(&mut buf).await.unwrap());
            })
        }
        drop(writer);
    }
}