Skip to main content

async_acceptor/
stdio.rs

1//! [`AsyncAcceptable`] implementation for a stream that can be repeatedly created
2//! using a factory function and only one instance can be in use at a time.
3//!
4//! A particular use case is to wrap [`tokio::io::stdin`] and [`tokio::io::stdout`].
5//
6// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later
7
8use super::async_acceptable::AsyncAcceptable;
9use futures_util::task::AtomicWaker;
10use std::io;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::task::{Context, Poll};
15use tokio::io::{self as tio, AsyncRead, AsyncWrite, ReadBuf};
16
17/// A psudo-listener that repeatedly produces an `AsyncRead + AsyncWrite`
18/// using a factory function.
19#[derive(derive_more::Debug)]
20pub struct ReusableListener<R, W> {
21    in_use: Arc<AtomicBool>,
22    end_waker: Arc<AtomicWaker>,
23    #[debug(skip)]
24    factory: Box<dyn (Fn() -> (R, W)) + Send + Sync>,
25}
26
27impl ReusableListener<tio::Stdin, tio::Stdout> {
28    /// Produce a `ReusableListener` with [`tokio::io::stdin`] and [`tokio::io::stdout`]
29    /// as the underlying streams.
30    #[must_use]
31    #[inline]
32    pub fn new_stdio() -> Self {
33        Self {
34            in_use: Arc::new(AtomicBool::new(false)),
35            end_waker: Arc::new(AtomicWaker::new()),
36            factory: Box::new(|| (tio::stdin(), tio::stdout())),
37        }
38    }
39}
40
41impl<R, W> AsyncAcceptable for ReusableListener<R, W>
42where
43    R: AsyncRead + Unpin + Send + 'static,
44    W: AsyncWrite + Unpin + Send + 'static,
45{
46    type Stream = ReusableListenerStream<R, W>;
47
48    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>> {
49        if self.in_use.swap(true, Ordering::Acquire) {
50            self.end_waker.register(cx.waker());
51            Poll::Pending
52        } else {
53            let (reader, writer) = (self.factory)();
54            Poll::Ready(Ok(ReusableListenerStream {
55                reader,
56                writer,
57                in_use: self.in_use.clone(),
58                end_waker: self.end_waker.clone(),
59            }))
60        }
61    }
62}
63
64/// A stream produced by calling [`accept`](`crate::AsyncAcceptableExt::accept`) on a [`ReusableListener`].
65#[derive(Debug)]
66pub struct ReusableListenerStream<R, W> {
67    reader: R,
68    writer: W,
69    in_use: Arc<AtomicBool>,
70    end_waker: Arc<AtomicWaker>,
71}
72
/// Expands to a `poll_*` trait method that forwards to the method of the same
/// name on `self.$field`, re-pinning the field with [`Pin::new`].
///
/// Arguments: the method name, the payload type inside `Poll`, the field to
/// delegate to, then any extra `name: Type` parameters that follow `cx`.
/// Both `Self` and the field's type must be `Unpin`.
macro_rules! impl_fn_by_pin_delegate {
    ($fn:ident, $ret:ty, $field:ident$(,)? $($arg_name:ident: $arg_ty:ty),*) => {
        #[inline]
        fn $fn(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            $($arg_name: $arg_ty),*
        ) -> std::task::Poll<$ret> {
            // `Self: Unpin`, so the pin can be unwrapped and only the field re-pinned.
            let this = self.get_mut();
            Pin::new(&mut this.$field).$fn(cx, $($arg_name),*)
        }
    };
}
85
86impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for ReusableListenerStream<R, W> {
87    impl_fn_by_pin_delegate! { poll_read, io::Result<()>, reader, buf: &mut ReadBuf<'_> }
88}
89
90impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for ReusableListenerStream<R, W> {
91    impl_fn_by_pin_delegate! { poll_write, io::Result<usize>, writer, buf: &[u8] }
92    impl_fn_by_pin_delegate! { poll_flush, io::Result<()>, writer }
93    impl_fn_by_pin_delegate! { poll_shutdown, io::Result<()>, writer }
94    impl_fn_by_pin_delegate! { poll_write_vectored, io::Result<usize>, writer, bufs: &[std::io::IoSlice<'_>] }
95    fn is_write_vectored(&self) -> bool {
96        self.writer.is_write_vectored()
97    }
98}
99
100impl<R, W> Drop for ReusableListenerStream<R, W> {
101    fn drop(&mut self) {
102        self.in_use.store(false, Ordering::Release);
103        self.end_waker.wake();
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AsyncAcceptableExt;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    /// Listener whose factory returns both ends of a fresh in-memory duplex
    /// pipe, so anything written to an accepted stream can be read back from it.
    fn duplex_listener() -> ReusableListener<tokio::io::DuplexStream, tokio::io::DuplexStream> {
        ReusableListener {
            in_use: Arc::new(AtomicBool::new(false)),
            end_waker: Arc::new(AtomicWaker::new()),
            factory: Box::new(|| duplex(64)),
        }
    }

    /// Waker that only records whether it has been woken.
    struct WakeFlag(AtomicBool);

    impl std::task::Wake for WakeFlag {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn test_reusable_listener() {
        let listener = duplex_listener();
        let mut accepted_stream = listener.accept().await.expect("Failed to accept stream");
        let mut test_cx = Context::from_waker(futures_util::task::noop_waker_ref());
        let res2 = listener.poll_accept(&mut test_cx);
        assert!(res2.is_pending(), "Listener should be busy");
        accepted_stream
            .write_all(b"Hello")
            .await
            .expect("Failed to write to stream");
        let mut buf = [0u8; 5];
        accepted_stream
            .read_exact(&mut buf)
            .await
            .expect("Failed to read from stream");
        assert_eq!(&buf, b"Hello", "Data read does not match data written");
        drop(accepted_stream);
        let mut accepted_stream2 = listener
            .accept()
            .await
            .expect("Failed to accept stream after previous stream dropped");
        accepted_stream2
            .write_all(b"World")
            .await
            .expect("Failed to write to stream");
        let mut buf2 = [0u8; 5];
        accepted_stream2
            .read_exact(&mut buf2)
            .await
            .expect("Failed to read from stream");
        assert_eq!(&buf2, b"World", "Data read does not match data written");
    }

    /// A pending accept must be woken (not merely be re-pollable) once the
    /// outstanding stream is dropped.
    #[test]
    fn test_pending_accept_woken_on_drop() {
        let listener = duplex_listener();
        let flag = Arc::new(WakeFlag(AtomicBool::new(false)));
        let waker = std::task::Waker::from(Arc::clone(&flag));
        let mut cx = Context::from_waker(&waker);

        let Poll::Ready(Ok(first)) = listener.poll_accept(&mut cx) else {
            panic!("First accept should succeed immediately");
        };
        assert!(
            listener.poll_accept(&mut cx).is_pending(),
            "Listener should be busy"
        );
        assert!(
            !flag.0.load(Ordering::SeqCst),
            "No wake-up expected while the stream is alive"
        );

        drop(first);
        assert!(
            flag.0.load(Ordering::SeqCst),
            "Dropping the stream must wake the pending acceptor"
        );
        assert!(
            listener.poll_accept(&mut cx).is_ready(),
            "Listener should be free again after the stream is dropped"
        );
    }
}