1use super::async_acceptable::AsyncAcceptable;
9use futures_util::task::AtomicWaker;
10use std::io;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::task::{Context, Poll};
15use tokio::io::{self as tio, AsyncRead, AsyncWrite, ReadBuf};
16
17#[derive(derive_more::Debug)]
20pub struct ReusableListener<R, W> {
21 in_use: Arc<AtomicBool>,
22 end_waker: Arc<AtomicWaker>,
23 #[debug(skip)]
24 factory: Box<dyn (Fn() -> (R, W)) + Send + Sync>,
25}
26
27impl ReusableListener<tio::Stdin, tio::Stdout> {
28 #[must_use]
31 #[inline]
32 pub fn new_stdio() -> Self {
33 Self {
34 in_use: Arc::new(AtomicBool::new(false)),
35 end_waker: Arc::new(AtomicWaker::new()),
36 factory: Box::new(|| (tio::stdin(), tio::stdout())),
37 }
38 }
39}
40
41impl<R, W> AsyncAcceptable for ReusableListener<R, W>
42where
43 R: AsyncRead + Unpin + Send + 'static,
44 W: AsyncWrite + Unpin + Send + 'static,
45{
46 type Stream = ReusableListenerStream<R, W>;
47
48 fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>> {
49 if self.in_use.swap(true, Ordering::Acquire) {
50 self.end_waker.register(cx.waker());
51 Poll::Pending
52 } else {
53 let (reader, writer) = (self.factory)();
54 Poll::Ready(Ok(ReusableListenerStream {
55 reader,
56 writer,
57 in_use: self.in_use.clone(),
58 end_waker: self.end_waker.clone(),
59 }))
60 }
61 }
62}
63
64#[derive(Debug)]
66pub struct ReusableListenerStream<R, W> {
67 reader: R,
68 writer: W,
69 in_use: Arc<AtomicBool>,
70 end_waker: Arc<AtomicWaker>,
71}
72
/// Generates a poll-style trait method that forwards to the method of the same
/// name on `self.$field`, re-pinned with `Pin::new`.
///
/// Input: method name, the `Poll` payload type, the field to forward to, and
/// then any extra `name: Type` parameters that follow `cx`. The expansion
/// requires `Self: Unpin` (for `Pin::get_mut`) and an `Unpin` field type (for
/// `Pin::new`).
macro_rules! impl_fn_by_pin_delegate {
    ($fn:ident, $ret:ty, $field:ident$(,)? $($arg_name:ident: $arg_ty:ty),*) => {
        #[inline]
        fn $fn(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            $($arg_name: $arg_ty),*
        ) -> std::task::Poll<$ret> {
            let this = self.get_mut();
            let inner = Pin::new(&mut this.$field);
            inner.$fn(cx, $($arg_name),*)
        }
    };
}
85
86impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for ReusableListenerStream<R, W> {
87 impl_fn_by_pin_delegate! { poll_read, io::Result<()>, reader, buf: &mut ReadBuf<'_> }
88}
89
90impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for ReusableListenerStream<R, W> {
91 impl_fn_by_pin_delegate! { poll_write, io::Result<usize>, writer, buf: &[u8] }
92 impl_fn_by_pin_delegate! { poll_flush, io::Result<()>, writer }
93 impl_fn_by_pin_delegate! { poll_shutdown, io::Result<()>, writer }
94 impl_fn_by_pin_delegate! { poll_write_vectored, io::Result<usize>, writer, bufs: &[std::io::IoSlice<'_>] }
95 fn is_write_vectored(&self) -> bool {
96 self.writer.is_write_vectored()
97 }
98}
99
100impl<R, W> Drop for ReusableListenerStream<R, W> {
101 fn drop(&mut self) {
102 self.in_use.store(false, Ordering::Release);
103 self.end_waker.wake();
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AsyncAcceptableExt;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

    /// Builds a listener whose streams are fresh in-memory duplex pipes.
    fn duplex_listener() -> ReusableListener<tokio::io::DuplexStream, tokio::io::DuplexStream> {
        ReusableListener {
            in_use: Arc::new(AtomicBool::new(false)),
            end_waker: Arc::new(AtomicWaker::new()),
            factory: Box::new(|| duplex(64)),
        }
    }

    /// Waker that records whether it has been woken.
    struct FlagWaker(AtomicBool);

    impl std::task::Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn test_reusable_listener() {
        let listener = duplex_listener();
        let mut accepted_stream = listener.accept().await.expect("Failed to accept stream");
        let mut test_cx = Context::from_waker(futures_util::task::noop_waker_ref());
        let res2 = listener.poll_accept(&mut test_cx);
        assert!(res2.is_pending(), "Listener should be busy");
        accepted_stream
            .write_all(b"Hello")
            .await
            .expect("Failed to write to stream");
        let mut buf = [0u8; 5];
        accepted_stream
            .read_exact(&mut buf)
            .await
            .expect("Failed to read from stream");
        assert_eq!(&buf, b"Hello", "Data read does not match data written");
        drop(accepted_stream);
        let mut accepted_stream2 = listener
            .accept()
            .await
            .expect("Failed to accept stream after previous stream dropped");
        accepted_stream2
            .write_all(b"World")
            .await
            .expect("Failed to write to stream");
        let mut buf2 = [0u8; 5];
        accepted_stream2
            .read_exact(&mut buf2)
            .await
            .expect("Failed to read from stream");
        assert_eq!(&buf2, b"World", "Data read does not match data written");
    }

    /// A pending accept must be woken when the live stream is dropped, and the
    /// next poll must then succeed.
    #[tokio::test]
    async fn test_pending_accept_is_woken_on_drop() {
        let listener = duplex_listener();
        let stream = listener.accept().await.expect("Failed to accept stream");

        let flag = Arc::new(FlagWaker(AtomicBool::new(false)));
        let waker = std::task::Waker::from(Arc::clone(&flag));
        let mut cx = Context::from_waker(&waker);

        assert!(listener.poll_accept(&mut cx).is_pending(), "Listener should be busy");
        assert!(
            !flag.0.load(Ordering::SeqCst),
            "Waker fired before the stream was dropped"
        );

        drop(stream);
        assert!(
            flag.0.load(Ordering::SeqCst),
            "Dropping the stream should wake the pending acceptor"
        );
        assert!(
            listener.poll_accept(&mut cx).is_ready(),
            "Listener should be free after the stream was dropped"
        );
    }
}