// async_foundation/net/tcp_stream.rs

1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::TcpStream as MioTcpStream;
7use std::io::{self, ErrorKind};
8use std::net::{Shutdown, SocketAddr, ToSocketAddrs};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::{Duration, Instant};
12
13/// Async wrapper around a non-blocking `TcpStream` for reading.
14///
15/// `TcpReadStream` integrates with the shared event listener to wait for
16/// readability with a configurable timeout, and implements [`AsyncRead`].
17pub struct TcpReadStream {
18    tcp_stream: MioTcpStream,
19    read_token: Token,
20    read_future: Option<ReadyFuture<()>>,
21    pub read_timeout: Duration,
22}
23
24impl TcpReadStream {
25    pub fn new(tcp_stream: MioTcpStream) -> Self {
26        TcpReadStream {
27            tcp_stream,
28            read_token: event_listener().next_token(),
29            read_future: None,
30            read_timeout: Duration::from_secs(20),
31        }
32    }
33
34    pub fn set_read_timeout(&mut self, duration: Duration) {
35        self.read_timeout = duration;
36    }
37
38    fn wait_read_data(&mut self) -> io::Result<()> {
39        let future = event_listener().listen_read(
40            &mut self.tcp_stream,
41            Instant::now() + self.read_timeout,
42            self.read_token,
43        )?;
44        self.read_future = Some(future);
45        Ok(())
46    }
47
48    fn poll_read_attempt(
49        &mut self,
50        cx: &mut Context<'_>,
51        buf: &mut [u8],
52    ) -> Poll<io::Result<usize>> {
53        let mut future = match self.read_future.take() {
54            None => {
55                match io::Read::read(&mut self.tcp_stream, buf) {
56                    Ok(size) => return Poll::Ready(Ok(size)),
57                    Err(err) if err.kind() == ErrorKind::WouldBlock => (),
58                    Err(err) => return Poll::Ready(Err(err)),
59                }
60                if let Err(err) = self.wait_read_data() {
61                    return Poll::Ready(Err(err));
62                }
63                self.read_future.take().unwrap()
64            }
65            Some(future) => future,
66        };
67        match future.poll_unpin(cx) {
68            Poll::Pending => {
69                self.read_future = Some(future);
70                Poll::Pending
71            }
72            Poll::Ready(ReadyFutureResult::Timeout) => {
73                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
74            }
75            Poll::Ready(_) => match io::Read::read(&mut self.tcp_stream, buf) {
76                Ok(size) => Poll::Ready(Ok(size)),
77                Err(err) => Poll::Ready(Err(err)),
78            },
79        }
80    }
81}
82
83impl Drop for TcpReadStream {
84    fn drop(&mut self) {
85        event_listener()
86            .stop_listening(&mut self.tcp_stream, self.read_token)
87            .ok();
88    }
89}
90
91impl AsyncRead for TcpReadStream {
92    fn poll_read(
93        self: Pin<&mut Self>,
94        cx: &mut Context<'_>,
95        buf: &mut [u8],
96    ) -> Poll<io::Result<usize>> {
97        let me = self.get_mut();
98        me.poll_read_attempt(cx, buf)
99    }
100}
101
102/// Async wrapper around a non-blocking `TcpStream` for writing.
103///
104/// `TcpWriteStream` also uses the shared event listener to wait for writable
105/// readiness and implements [`AsyncWrite`].
106pub struct TcpWriteStream {
107    tcp_stream: MioTcpStream,
108    write_token: Token,
109    write_future: Option<ReadyFuture<()>>,
110    pub write_timeout: Duration,
111}
112
113impl TcpWriteStream {
114    pub fn new(tcp_stream: MioTcpStream) -> Self {
115        TcpWriteStream {
116            tcp_stream,
117            write_token: event_listener().next_token(),
118            write_future: None,
119            write_timeout: Duration::from_secs(2),
120        }
121    }
122
123    pub fn set_write_timeout(&mut self, duration: Duration) {
124        self.write_timeout = duration;
125    }
126
127    fn wait_write_channel(&mut self) -> io::Result<()> {
128        let future = event_listener().listen_write(
129            &mut self.tcp_stream,
130            Instant::now() + self.write_timeout,
131            self.write_token,
132        )?;
133        self.write_future = Some(future);
134        Ok(())
135    }
136
137    fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
138        let mut future = match self.write_future.take() {
139            None => {
140                match io::Write::write(&mut self.tcp_stream, buf) {
141                    Ok(size) => return Poll::Ready(Ok(size)),
142                    Err(err) if err.kind() == ErrorKind::WouldBlock => (),
143                    Err(err) => return Poll::Ready(Err(err)),
144                }
145
146                if let Err(err) = self.wait_write_channel() {
147                    return Poll::Ready(Err(err));
148                }
149                self.write_future.take().unwrap()
150            }
151            Some(future) => future,
152        };
153        match future.poll_unpin(cx) {
154            Poll::Pending => {
155                self.write_future = Some(future);
156                Poll::Pending
157            }
158            Poll::Ready(ReadyFutureResult::Timeout) => {
159                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
160            }
161            Poll::Ready(_) => match io::Write::write(&mut self.tcp_stream, buf) {
162                Ok(size) => Poll::Ready(Ok(size)),
163                Err(err) => Poll::Ready(Err(err)),
164            },
165        }
166    }
167
168    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
169        self.tcp_stream.shutdown(how)
170    }
171}
172
173impl Drop for TcpWriteStream {
174    fn drop(&mut self) {
175        event_listener()
176            .stop_listening(&mut self.tcp_stream, self.write_token)
177            .ok();
178    }
179}
180
181impl AsyncWrite for TcpWriteStream {
182    fn poll_write(
183        self: Pin<&mut Self>,
184        cx: &mut Context<'_>,
185        buf: &[u8],
186    ) -> Poll<io::Result<usize>> {
187        let me = self.get_mut();
188        me.poll_write_attempt(cx, buf)
189    }
190
191    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192        Poll::Ready(Ok(()))
193    }
194
195    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196        let me = self.get_mut();
197        me.shutdown(Shutdown::Write)?;
198        Poll::Ready(Ok(()))
199    }
200}
201
202/// Combined async `TcpStream` that implements both [`AsyncRead`] and [`AsyncWrite`].
203///
204/// This type owns both a [`TcpReadStream`] and a [`TcpWriteStream`] and can be
205/// used directly with async/await or split into its read/write halves.
206pub struct TcpStream {
207    read_stream: TcpReadStream,
208    write_stream: TcpWriteStream,
209}
210
211impl TcpStream {
212    pub fn from(tcp_stream: std::net::TcpStream) -> io::Result<TcpStream> {
213        tcp_stream.set_nonblocking(true)?;
214        Ok(TcpStream {
215            read_stream: TcpReadStream::new(MioTcpStream::from_std(tcp_stream.try_clone()?)),
216            write_stream: TcpWriteStream::new(MioTcpStream::from_std(tcp_stream)),
217        })
218    }
219
220    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
221        Self::from(std::net::TcpStream::connect(addr)?)
222    }
223
224    pub fn read_stream(&self) -> &TcpReadStream {
225        &self.read_stream
226    }
227
228    pub fn read_stream_mut(&mut self) -> &mut TcpReadStream {
229        &mut self.read_stream
230    }
231
232    pub fn write_stream(&self) -> &TcpWriteStream {
233        &self.write_stream
234    }
235
236    pub fn write_stream_mut(&mut self) -> &mut TcpWriteStream {
237        &mut self.write_stream
238    }
239
240    pub fn split(self) -> (TcpReadStream, TcpWriteStream) {
241        (self.read_stream, self.write_stream)
242    }
243
244    pub fn set_read_timeout(&mut self, duration: Duration) {
245        self.read_stream.set_read_timeout(duration);
246    }
247
248    pub fn set_write_timeout(&mut self, duration: Duration) {
249        self.write_stream.set_write_timeout(duration);
250    }
251
252    pub fn local_addr(&self) -> io::Result<SocketAddr> {
253        self.read_stream.tcp_stream.local_addr()
254    }
255
256    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
257        self.read_stream.tcp_stream.peer_addr()
258    }
259
260    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
261        self.read_stream.tcp_stream.shutdown(how)
262    }
263}
264
265// Only AsyncRead/AsyncWrite implementations remain
266impl AsyncRead for TcpStream {
267    fn poll_read(
268        self: Pin<&mut Self>,
269        cx: &mut Context<'_>,
270        buf: &mut [u8],
271    ) -> Poll<io::Result<usize>> {
272        let me = self.get_mut();
273        Pin::new(&mut me.read_stream).poll_read(cx, buf)
274    }
275}
276
277impl AsyncWrite for TcpStream {
278    fn poll_write(
279        self: Pin<&mut Self>,
280        cx: &mut Context<'_>,
281        buf: &[u8],
282    ) -> Poll<io::Result<usize>> {
283        let me = self.get_mut();
284        Pin::new(&mut me.write_stream).poll_write(cx, buf)
285    }
286
287    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
288        let me = self.get_mut();
289        Pin::new(&mut me.write_stream).poll_flush(cx)
290    }
291
292    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
293        let me = self.get_mut();
294        Pin::new(&mut me.write_stream).poll_close(cx)
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use std::time::Duration;

    /// Binds a listener on an ephemeral loopback port and returns it with its address.
    fn setup_test_server() -> (TcpListener, std::net::SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let local = listener.local_addr().unwrap();
        (listener, local)
    }

    /// Background server: accept one connection and close it straight away.
    fn spawn_accept_and_drop(listener: TcpListener) {
        thread::spawn(move || {
            if let Ok((peer, _)) = listener.accept() {
                drop(peer);
            }
        });
    }

    /// Background server: accept one connection, echo a single read back, then close.
    fn spawn_echo_once(listener: TcpListener) {
        thread::spawn(move || {
            if let Ok((mut peer, _)) = listener.accept() {
                let mut chunk = [0u8; 1024];
                if let Ok(n) = peer.read(&mut chunk) {
                    let _ = peer.write_all(&chunk[..n]);
                }
            }
        });
    }

    #[test]
    fn test_tcp_stream_wrapper_creation() {
        let (listener, addr) = setup_test_server();
        spawn_accept_and_drop(listener);

        let connected = TcpStream::connect(addr);
        assert!(connected.is_ok());

        let stream = connected.unwrap();
        assert_eq!(stream.peer_addr().unwrap(), addr);
    }

    #[test]
    fn test_stream_accessors() {
        let (listener, addr) = setup_test_server();
        spawn_accept_and_drop(listener);

        let mut stream = TcpStream::connect(addr).unwrap();

        assert_eq!(stream.read_stream().read_timeout, Duration::from_secs(20));
        let reader = stream.read_stream_mut();
        reader.set_read_timeout(Duration::from_secs(15));
        assert_eq!(reader.read_timeout, Duration::from_secs(15));

        assert_eq!(stream.write_stream().write_timeout, Duration::from_secs(2));
        let writer = stream.write_stream_mut();
        writer.set_write_timeout(Duration::from_secs(10));
        assert_eq!(writer.write_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_stream_split() {
        let (listener, addr) = setup_test_server();
        spawn_accept_and_drop(listener);

        let (reader, writer) = TcpStream::connect(addr).unwrap().split();

        assert_eq!(reader.read_timeout, Duration::from_secs(20));
        assert_eq!(writer.write_timeout, Duration::from_secs(2));
    }

    #[test]
    fn test_async_read_write() {
        let (listener, addr) = setup_test_server();

        // Echo server: bounce everything back until the client hangs up.
        thread::spawn(move || {
            let mut peer = match listener.accept() {
                Ok((peer, _)) => peer,
                Err(err) => {
                    eprintln!("server error {:?}", &err);
                    return;
                }
            };
            let mut chunk = [0u8; 1024];
            loop {
                match peer.read(&mut chunk).unwrap() {
                    0 => break,
                    n => {
                        let _ = peer.write_all(&chunk[..n]);
                    }
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        block_on(async {
            let mut client = TcpStream::connect(addr).unwrap();

            let first = &[1, 2, 3, 4, 5, 6];
            assert!(client.write_all(first).await.is_ok());

            let mut buf = [0u8; 1024];
            assert!(client.read_exact(&mut buf[..2]).await.is_ok());
            assert!(client.read_exact(&mut buf[2..first.len()]).await.is_ok());
            assert_eq!(&buf[..first.len()], first);

            let second = &[7, 8, 9, 10];
            assert!(client.write_all(second).await.is_ok());
            assert!(client.read(&mut buf).await.is_ok());
            assert_eq!(&buf[..second.len()], second);
        });
    }

    #[test]
    fn test_async_read_write_with_delay() {
        let (listener, addr) = setup_test_server();

        // Echo the request back in two halves separated by a pause.
        thread::spawn(move || {
            if let Ok((mut peer, _)) = listener.accept() {
                let mut chunk = [0u8; 1024];
                let n = peer.read(&mut chunk).unwrap();
                let (head, tail) = chunk[..n].split_at(n / 2);
                peer.write_all(head).unwrap();
                thread::sleep(Duration::from_millis(50));
                peer.write_all(tail).unwrap();
            }
        });

        block_on(async {
            let mut client = TcpStream::connect(addr).unwrap();
            let payload = b"Delayed Hello!";
            assert!(client.write_all(payload).await.is_ok());

            let mut buf = [0u8; 1024];
            assert!(client.read_exact(&mut buf[..payload.len()]).await.is_ok());
            assert_eq!(&buf[..payload.len()], payload);
        });
    }

    #[test]
    fn test_concurrent_operations() {
        let (listener, addr) = setup_test_server();

        // Up to three accept attempts, each successful connection echoed once
        // on its own thread.
        thread::spawn(move || {
            for mut peer in listener.incoming().take(3).flatten() {
                thread::spawn(move || {
                    let mut chunk = [0u8; 1024];
                    let n = peer.read(&mut chunk).unwrap();
                    let _ = peer.write_all(&chunk[..n]);
                });
            }
        });

        thread::sleep(Duration::from_millis(10));

        block_on(async {
            let clients = (0..3).map(|i| {
                let message = format!("Message {}", i);
                async move {
                    let mut client = TcpStream::connect(addr).unwrap();
                    client.write_all(message.as_bytes()).await.unwrap();
                    let mut buf = [0u8; 1024];
                    let got = client.read(&mut buf).await.unwrap();
                    assert_eq!(&buf[..got], message.as_bytes());
                }
            });

            futures::future::join_all(clients).await;
        });
    }

    #[test]
    fn test_timeout_behavior() {
        let (listener, addr) = setup_test_server();

        // Respond only after a delay longer than the client's read timeout.
        thread::spawn(move || {
            if let Ok((mut peer, _)) = listener.accept() {
                let mut chunk = [0u8; 1024];
                let _ = peer.read(&mut chunk);
                thread::sleep(Duration::from_millis(200));
                let _ = peer.write_all(b"slow response");
            }
        });

        thread::sleep(Duration::from_millis(10));

        block_on(async {
            let mut client = TcpStream::connect(addr).unwrap();

            client.set_read_timeout(Duration::from_millis(50));
            client.write_all(b"test").await.unwrap();

            let mut buf = [0u8; 1024];
            let outcome = client.read(&mut buf).await;

            assert!(outcome.is_err());
            assert_eq!(outcome.unwrap_err().kind(), io::ErrorKind::TimedOut);
        });
    }

    #[test]
    fn test_shutdown() {
        let (listener, addr) = setup_test_server();
        spawn_echo_once(listener);

        let stream = TcpStream::connect(addr).unwrap();
        assert!(stream.shutdown(Shutdown::Both).is_ok());
    }

    #[test]
    fn test_split_streams_independently() {
        let (listener, addr) = setup_test_server();
        spawn_echo_once(listener);

        thread::sleep(Duration::from_millis(10));

        block_on(async {
            let (mut reader, mut writer) = TcpStream::connect(addr).unwrap().split();

            // Write through one half, read the echo through the other.
            let payload = b"Split stream test";
            writer.write_all(payload).await.unwrap();

            let mut buf = [0u8; 1024];
            let got = reader.read(&mut buf).await.unwrap();
            assert_eq!(&buf[..got], payload);
        });
    }
}