async_foundation/net/tcp_stream.rs

1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::TcpStream as MioTcpStream;
7use std::io::{self, ErrorKind};
8use std::net::{Shutdown, SocketAddr, ToSocketAddrs};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::{Duration, Instant};
12
13pub struct TcpReadStream {
14    tcp_stream: MioTcpStream,
15    read_token: Token,
16    read_future: Option<ReadyFuture<()>>,
17    pub read_timeout: Duration,
18}
19
20impl TcpReadStream {
21    pub fn new(tcp_stream: MioTcpStream) -> Self {
22        TcpReadStream {
23            tcp_stream,
24            read_token: event_listener().next_token(),
25            read_future: None,
26            read_timeout: Duration::from_secs(20),
27        }
28    }
29
30    pub fn set_read_timeout(&mut self, duration: Duration) {
31        self.read_timeout = duration;
32    }
33
34    fn wait_read_data(&mut self) -> io::Result<()> {
35        let future = event_listener().listen_read(
36            &mut self.tcp_stream,
37            Instant::now() + self.read_timeout,
38            self.read_token,
39        )?;
40        self.read_future = Some(future);
41        Ok(())
42    }
43
44    fn poll_read_attempt(
45        &mut self,
46        cx: &mut Context<'_>,
47        buf: &mut [u8],
48    ) -> Poll<io::Result<usize>> {
49        let mut future = match self.read_future.take() {
50            None => {
51                match io::Read::read(&mut self.tcp_stream, buf) {
52                    Ok(size) => return Poll::Ready(Ok(size)),
53                    Err(err) if err.kind() == ErrorKind::WouldBlock => (),
54                    Err(err) => return Poll::Ready(Err(err)),
55                }
56                if let Err(err) = self.wait_read_data() {
57                    return Poll::Ready(Err(err));
58                }
59                self.read_future.take().unwrap()
60            }
61            Some(future) => future,
62        };
63        match future.poll_unpin(cx) {
64            Poll::Pending => {
65                self.read_future = Some(future);
66                Poll::Pending
67            }
68            Poll::Ready(ReadyFutureResult::Timeout) => {
69                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
70            }
71            Poll::Ready(_) => match io::Read::read(&mut self.tcp_stream, buf) {
72                Ok(size) => Poll::Ready(Ok(size)),
73                Err(err) => Poll::Ready(Err(err)),
74            },
75        }
76    }
77}
78
79impl Drop for TcpReadStream {
80    fn drop(&mut self) {
81        event_listener()
82            .stop_listening(&mut self.tcp_stream, self.read_token)
83            .ok();
84    }
85}
86
87impl AsyncRead for TcpReadStream {
88    fn poll_read(
89        self: Pin<&mut Self>,
90        cx: &mut Context<'_>,
91        buf: &mut [u8],
92    ) -> Poll<io::Result<usize>> {
93        let me = self.get_mut();
94        me.poll_read_attempt(cx, buf)
95    }
96}
97
98pub struct TcpWriteStream {
99    tcp_stream: MioTcpStream,
100    write_token: Token,
101    write_future: Option<ReadyFuture<()>>,
102    pub write_timeout: Duration,
103}
104
105impl TcpWriteStream {
106    pub fn new(tcp_stream: MioTcpStream) -> Self {
107        TcpWriteStream {
108            tcp_stream,
109            write_token: event_listener().next_token(),
110            write_future: None,
111            write_timeout: Duration::from_secs(2),
112        }
113    }
114
115    pub fn set_write_timeout(&mut self, duration: Duration) {
116        self.write_timeout = duration;
117    }
118
119    fn wait_write_channel(&mut self) -> io::Result<()> {
120        let future = event_listener().listen_write(
121            &mut self.tcp_stream,
122            Instant::now() + self.write_timeout,
123            self.write_token,
124        )?;
125        self.write_future = Some(future);
126        Ok(())
127    }
128
129    fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
130        let mut future = match self.write_future.take() {
131            None => {
132                match io::Write::write(&mut self.tcp_stream, buf) {
133                    Ok(size) => return Poll::Ready(Ok(size)),
134                    Err(err) if err.kind() == ErrorKind::WouldBlock => (),
135                    Err(err) => return Poll::Ready(Err(err)),
136                }
137
138                if let Err(err) = self.wait_write_channel() {
139                    return Poll::Ready(Err(err));
140                }
141                self.write_future.take().unwrap()
142            }
143            Some(future) => future,
144        };
145        match future.poll_unpin(cx) {
146            Poll::Pending => {
147                self.write_future = Some(future);
148                Poll::Pending
149            }
150            Poll::Ready(ReadyFutureResult::Timeout) => {
151                Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
152            }
153            Poll::Ready(_) => match io::Write::write(&mut self.tcp_stream, buf) {
154                Ok(size) => Poll::Ready(Ok(size)),
155                Err(err) => Poll::Ready(Err(err)),
156            },
157        }
158    }
159
160    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
161        self.tcp_stream.shutdown(how)
162    }
163}
164
165impl Drop for TcpWriteStream {
166    fn drop(&mut self) {
167        event_listener()
168            .stop_listening(&mut self.tcp_stream, self.write_token)
169            .ok();
170    }
171}
172
173impl AsyncWrite for TcpWriteStream {
174    fn poll_write(
175        self: Pin<&mut Self>,
176        cx: &mut Context<'_>,
177        buf: &[u8],
178    ) -> Poll<io::Result<usize>> {
179        let me = self.get_mut();
180        me.poll_write_attempt(cx, buf)
181    }
182
183    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184        Poll::Ready(Ok(()))
185    }
186
187    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
188        let me = self.get_mut();
189        me.shutdown(Shutdown::Write)?;
190        Poll::Ready(Ok(()))
191    }
192}
193
194pub struct TcpStream {
195    read_stream: TcpReadStream,
196    write_stream: TcpWriteStream,
197}
198
199impl TcpStream {
200    pub fn from(tcp_stream: std::net::TcpStream) -> io::Result<TcpStream> {
201        tcp_stream.set_nonblocking(true)?;
202        Ok(TcpStream {
203            read_stream: TcpReadStream::new(MioTcpStream::from_std(tcp_stream.try_clone()?)),
204            write_stream: TcpWriteStream::new(MioTcpStream::from_std(tcp_stream)),
205        })
206    }
207
208    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
209        Self::from(std::net::TcpStream::connect(addr)?)
210    }
211
212    pub fn read_stream(&self) -> &TcpReadStream {
213        &self.read_stream
214    }
215
216    pub fn read_stream_mut(&mut self) -> &mut TcpReadStream {
217        &mut self.read_stream
218    }
219
220    pub fn write_stream(&self) -> &TcpWriteStream {
221        &self.write_stream
222    }
223
224    pub fn write_stream_mut(&mut self) -> &mut TcpWriteStream {
225        &mut self.write_stream
226    }
227
228    pub fn split(self) -> (TcpReadStream, TcpWriteStream) {
229        (self.read_stream, self.write_stream)
230    }
231
232    pub fn set_read_timeout(&mut self, duration: Duration) {
233        self.read_stream.set_read_timeout(duration);
234    }
235
236    pub fn set_write_timeout(&mut self, duration: Duration) {
237        self.write_stream.set_write_timeout(duration);
238    }
239
240    pub fn local_addr(&self) -> io::Result<SocketAddr> {
241        self.read_stream.tcp_stream.local_addr()
242    }
243
244    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
245        self.read_stream.tcp_stream.peer_addr()
246    }
247
248    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
249        self.read_stream.tcp_stream.shutdown(how)
250    }
251}
252
253// Only AsyncRead/AsyncWrite implementations remain
254impl AsyncRead for TcpStream {
255    fn poll_read(
256        self: Pin<&mut Self>,
257        cx: &mut Context<'_>,
258        buf: &mut [u8],
259    ) -> Poll<io::Result<usize>> {
260        let me = self.get_mut();
261        Pin::new(&mut me.read_stream).poll_read(cx, buf)
262    }
263}
264
265impl AsyncWrite for TcpStream {
266    fn poll_write(
267        self: Pin<&mut Self>,
268        cx: &mut Context<'_>,
269        buf: &[u8],
270    ) -> Poll<io::Result<usize>> {
271        let me = self.get_mut();
272        Pin::new(&mut me.write_stream).poll_write(cx, buf)
273    }
274
275    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
276        let me = self.get_mut();
277        Pin::new(&mut me.write_stream).poll_flush(cx)
278    }
279
280    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281        let me = self.get_mut();
282        Pin::new(&mut me.write_stream).poll_close(cx)
283    }
284}
285
#[cfg(test)]
mod tests {
    //! Loopback tests. Each test binds a std listener on an OS-assigned port,
    //! serves it from a helper thread, and drives the async wrapper with
    //! `futures::executor::block_on`.

    use super::*;
    use futures::executor::block_on;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use std::time::Duration;

    /// Binds a listener on an OS-assigned loopback port and returns it with its address.
    fn setup_test_server() -> (TcpListener, std::net::SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    #[test]
    fn test_tcp_stream_wrapper_creation() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });
        let wrapper = TcpStream::connect(addr);
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.peer_addr().unwrap(), addr);
    }

    #[test]
    fn test_stream_accessors() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });

        let mut wrapper = TcpStream::connect(addr).unwrap();

        let read_stream = wrapper.read_stream();
        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));

        let read_stream_mut = wrapper.read_stream_mut();
        read_stream_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_stream_mut.read_timeout, Duration::from_secs(15));

        let write_stream = wrapper.write_stream();
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));

        let write_stream_mut = wrapper.write_stream_mut();
        write_stream_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_stream_mut.write_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_stream_split() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });

        let wrapper = TcpStream::connect(addr).unwrap();
        let (read_stream, write_stream) = wrapper.split();

        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));
    }

    #[test]
    fn test_async_read_write() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || match listener.accept() {
            Ok((mut stream, _)) => {
                let mut buf = [0u8; 1024];
                loop {
                    let n = stream.read(&mut buf).unwrap();
                    if n == 0 {
                        break;
                    }
                    let _ = stream.write_all(&buf[..n]);
                }
            }
            Err(err) => {
                eprintln!("server error {:?}", &err);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();

            let test_data = &[1, 2, 3, 4, 5, 6];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());

            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..2]).await;
            assert!(read.is_ok());
            let read = wrapper.read_exact(&mut buf[2..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);

            let test_data = &[7, 8, 9, 10];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());
            // A single `read` may legally return fewer bytes than were echoed,
            // so read exactly the expected length to keep the test deterministic.
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }

    #[test]
    fn test_async_read_write_with_delay() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let n = stream.read(&mut buf).unwrap();
                let half = n / 2;
                stream.write_all(&buf[..half]).unwrap();
                thread::sleep(Duration::from_millis(50));
                stream.write_all(&buf[half..n]).unwrap();
            }
        });

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();
            let test_data = b"Delayed Hello!";
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());

            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }

    #[test]
    fn test_concurrent_operations() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            for _ in 0..3 {
                if let Ok((mut stream, _)) = listener.accept() {
                    thread::spawn(move || {
                        let mut buf = [0u8; 1024];
                        let n = stream.read(&mut buf).unwrap();
                        let _ = stream.write_all(&buf[..n]);
                    });
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut futures = Vec::new();

            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let mut client = TcpStream::connect(addr).unwrap();
                    client.write_all(test_data.as_bytes()).await.unwrap();
                    let mut buf = [0u8; 1024];
                    // Read exactly the echoed length; a single `read` may be short.
                    let expected_len = test_data.len();
                    client.read_exact(&mut buf[..expected_len]).await.unwrap();
                    assert_eq!(&buf[..expected_len], test_data.as_bytes());
                };
                futures.push(future);
            }

            futures::future::join_all(futures).await;
        };

        block_on(test_future);
    }

    #[test]
    fn test_timeout_behavior() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let _ = stream.read(&mut buf);
                thread::sleep(Duration::from_millis(200));
                let _ = stream.write_all(b"slow response");
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();

            wrapper.set_read_timeout(Duration::from_millis(50));
            wrapper.write_all(b"test").await.unwrap();

            let mut buf = [0u8; 1024];
            let read_result = wrapper.read(&mut buf).await;

            assert!(read_result.is_err());
            let err = read_result.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        };

        block_on(test_future);
    }

    #[test]
    fn test_shutdown() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });

        let wrapper = TcpStream::connect(addr).unwrap();
        let result = wrapper.shutdown(Shutdown::Both);
        assert!(result.is_ok());
    }

    #[test]
    fn test_split_streams_independently() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = TcpStream::connect(addr).unwrap();
            let (mut read_stream, mut write_stream) = wrapper.split();

            // Writing through one half and reading through the other must work
            // independently.
            let test_data = b"Split stream test";
            write_stream.write_all(test_data).await.unwrap();

            let mut buf = [0u8; 1024];
            // Read exactly the echoed length; a single `read` may be short.
            read_stream
                .read_exact(&mut buf[..test_data.len()])
                .await
                .unwrap();
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }
}