1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::TcpStream as MioTcpStream;
7use std::io::{self, ErrorKind};
8use std::net::{Shutdown, SocketAddr, ToSocketAddrs};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::{Duration, Instant};
12
13pub struct TcpReadStream {
18 tcp_stream: MioTcpStream,
19 read_token: Token,
20 read_future: Option<ReadyFuture<()>>,
21 pub read_timeout: Duration,
22}
23
24impl TcpReadStream {
25 pub fn new(tcp_stream: MioTcpStream) -> Self {
26 TcpReadStream {
27 tcp_stream,
28 read_token: event_listener().next_token(),
29 read_future: None,
30 read_timeout: Duration::from_secs(20),
31 }
32 }
33
34 pub fn set_read_timeout(&mut self, duration: Duration) {
35 self.read_timeout = duration;
36 }
37
38 fn wait_read_data(&mut self) -> io::Result<()> {
39 let future = event_listener().listen_read(
40 &mut self.tcp_stream,
41 Instant::now() + self.read_timeout,
42 self.read_token,
43 )?;
44 self.read_future = Some(future);
45 Ok(())
46 }
47
48 fn poll_read_attempt(
49 &mut self,
50 cx: &mut Context<'_>,
51 buf: &mut [u8],
52 ) -> Poll<io::Result<usize>> {
53 let mut future = match self.read_future.take() {
54 None => {
55 match io::Read::read(&mut self.tcp_stream, buf) {
56 Ok(size) => return Poll::Ready(Ok(size)),
57 Err(err) if err.kind() == ErrorKind::WouldBlock => (),
58 Err(err) => return Poll::Ready(Err(err)),
59 }
60 if let Err(err) = self.wait_read_data() {
61 return Poll::Ready(Err(err));
62 }
63 self.read_future.take().unwrap()
64 }
65 Some(future) => future,
66 };
67 match future.poll_unpin(cx) {
68 Poll::Pending => {
69 self.read_future = Some(future);
70 Poll::Pending
71 }
72 Poll::Ready(ReadyFutureResult::Timeout) => {
73 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
74 }
75 Poll::Ready(_) => match io::Read::read(&mut self.tcp_stream, buf) {
76 Ok(size) => Poll::Ready(Ok(size)),
77 Err(err) => Poll::Ready(Err(err)),
78 },
79 }
80 }
81}
82
83impl Drop for TcpReadStream {
84 fn drop(&mut self) {
85 event_listener()
86 .stop_listening(&mut self.tcp_stream, self.read_token)
87 .ok();
88 }
89}
90
91impl AsyncRead for TcpReadStream {
92 fn poll_read(
93 self: Pin<&mut Self>,
94 cx: &mut Context<'_>,
95 buf: &mut [u8],
96 ) -> Poll<io::Result<usize>> {
97 let me = self.get_mut();
98 me.poll_read_attempt(cx, buf)
99 }
100}
101
102pub struct TcpWriteStream {
107 tcp_stream: MioTcpStream,
108 write_token: Token,
109 write_future: Option<ReadyFuture<()>>,
110 pub write_timeout: Duration,
111}
112
113impl TcpWriteStream {
114 pub fn new(tcp_stream: MioTcpStream) -> Self {
115 TcpWriteStream {
116 tcp_stream,
117 write_token: event_listener().next_token(),
118 write_future: None,
119 write_timeout: Duration::from_secs(2),
120 }
121 }
122
123 pub fn set_write_timeout(&mut self, duration: Duration) {
124 self.write_timeout = duration;
125 }
126
127 fn wait_write_channel(&mut self) -> io::Result<()> {
128 let future = event_listener().listen_write(
129 &mut self.tcp_stream,
130 Instant::now() + self.write_timeout,
131 self.write_token,
132 )?;
133 self.write_future = Some(future);
134 Ok(())
135 }
136
137 fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
138 let mut future = match self.write_future.take() {
139 None => {
140 match io::Write::write(&mut self.tcp_stream, buf) {
141 Ok(size) => return Poll::Ready(Ok(size)),
142 Err(err) if err.kind() == ErrorKind::WouldBlock => (),
143 Err(err) => return Poll::Ready(Err(err)),
144 }
145
146 if let Err(err) = self.wait_write_channel() {
147 return Poll::Ready(Err(err));
148 }
149 self.write_future.take().unwrap()
150 }
151 Some(future) => future,
152 };
153 match future.poll_unpin(cx) {
154 Poll::Pending => {
155 self.write_future = Some(future);
156 Poll::Pending
157 }
158 Poll::Ready(ReadyFutureResult::Timeout) => {
159 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
160 }
161 Poll::Ready(_) => match io::Write::write(&mut self.tcp_stream, buf) {
162 Ok(size) => Poll::Ready(Ok(size)),
163 Err(err) => Poll::Ready(Err(err)),
164 },
165 }
166 }
167
168 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
169 self.tcp_stream.shutdown(how)
170 }
171}
172
173impl Drop for TcpWriteStream {
174 fn drop(&mut self) {
175 event_listener()
176 .stop_listening(&mut self.tcp_stream, self.write_token)
177 .ok();
178 }
179}
180
181impl AsyncWrite for TcpWriteStream {
182 fn poll_write(
183 self: Pin<&mut Self>,
184 cx: &mut Context<'_>,
185 buf: &[u8],
186 ) -> Poll<io::Result<usize>> {
187 let me = self.get_mut();
188 me.poll_write_attempt(cx, buf)
189 }
190
191 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192 Poll::Ready(Ok(()))
193 }
194
195 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
196 let me = self.get_mut();
197 me.shutdown(Shutdown::Write)?;
198 Poll::Ready(Ok(()))
199 }
200}
201
202pub struct TcpStream {
207 read_stream: TcpReadStream,
208 write_stream: TcpWriteStream,
209}
210
211impl TcpStream {
212 pub fn from(tcp_stream: std::net::TcpStream) -> io::Result<TcpStream> {
213 tcp_stream.set_nonblocking(true)?;
214 Ok(TcpStream {
215 read_stream: TcpReadStream::new(MioTcpStream::from_std(tcp_stream.try_clone()?)),
216 write_stream: TcpWriteStream::new(MioTcpStream::from_std(tcp_stream)),
217 })
218 }
219
220 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
221 Self::from(std::net::TcpStream::connect(addr)?)
222 }
223
224 pub fn read_stream(&self) -> &TcpReadStream {
225 &self.read_stream
226 }
227
228 pub fn read_stream_mut(&mut self) -> &mut TcpReadStream {
229 &mut self.read_stream
230 }
231
232 pub fn write_stream(&self) -> &TcpWriteStream {
233 &self.write_stream
234 }
235
236 pub fn write_stream_mut(&mut self) -> &mut TcpWriteStream {
237 &mut self.write_stream
238 }
239
240 pub fn split(self) -> (TcpReadStream, TcpWriteStream) {
241 (self.read_stream, self.write_stream)
242 }
243
244 pub fn set_read_timeout(&mut self, duration: Duration) {
245 self.read_stream.set_read_timeout(duration);
246 }
247
248 pub fn set_write_timeout(&mut self, duration: Duration) {
249 self.write_stream.set_write_timeout(duration);
250 }
251
252 pub fn local_addr(&self) -> io::Result<SocketAddr> {
253 self.read_stream.tcp_stream.local_addr()
254 }
255
256 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
257 self.read_stream.tcp_stream.peer_addr()
258 }
259
260 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
261 self.read_stream.tcp_stream.shutdown(how)
262 }
263}
264
265impl AsyncRead for TcpStream {
267 fn poll_read(
268 self: Pin<&mut Self>,
269 cx: &mut Context<'_>,
270 buf: &mut [u8],
271 ) -> Poll<io::Result<usize>> {
272 let me = self.get_mut();
273 Pin::new(&mut me.read_stream).poll_read(cx, buf)
274 }
275}
276
277impl AsyncWrite for TcpStream {
278 fn poll_write(
279 self: Pin<&mut Self>,
280 cx: &mut Context<'_>,
281 buf: &[u8],
282 ) -> Poll<io::Result<usize>> {
283 let me = self.get_mut();
284 Pin::new(&mut me.write_stream).poll_write(cx, buf)
285 }
286
287 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
288 let me = self.get_mut();
289 Pin::new(&mut me.write_stream).poll_flush(cx)
290 }
291
292 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
293 let me = self.get_mut();
294 Pin::new(&mut me.write_stream).poll_close(cx)
295 }
296}
297
#[cfg(test)]
mod tests {
    //! Integration-style tests that talk to a throwaway listener on
    //! `127.0.0.1` with an OS-assigned port.
    //!
    //! Wherever a test compares a fixed number of echoed bytes it uses
    //! `read_exact`: a single `read` is allowed to return fewer bytes than
    //! were sent, which would make the comparison fail spuriously.

    use super::*;
    use futures::executor::block_on;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use std::time::Duration;

    /// Binds a listener on an ephemeral loopback port and returns it together
    /// with the address clients should connect to.
    fn setup_test_server() -> (TcpListener, std::net::SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Connecting succeeds and the peer address matches the listener.
    #[test]
    fn test_tcp_stream_wrapper_creation() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });
        let wrapper = TcpStream::connect(addr);
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.peer_addr().unwrap(), addr);
    }

    /// Accessors expose the halves with their default timeouts, and the
    /// mutable accessors allow changing them.
    #[test]
    fn test_stream_accessors() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });

        let mut wrapper = TcpStream::connect(addr).unwrap();

        let read_stream = wrapper.read_stream();
        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));

        let read_stream_mut = wrapper.read_stream_mut();
        read_stream_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_stream_mut.read_timeout, Duration::from_secs(15));

        let write_stream = wrapper.write_stream();
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));

        let write_stream_mut = wrapper.write_stream_mut();
        write_stream_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_stream_mut.write_timeout, Duration::from_secs(10));
    }

    /// Splitting yields both halves with their default timeouts intact.
    #[test]
    fn test_stream_split() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });

        let wrapper = TcpStream::connect(addr).unwrap();
        let (read_stream, write_stream) = wrapper.split();

        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));
    }

    /// Two write/read round trips against an echo server.
    #[test]
    fn test_async_read_write() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || match listener.accept() {
            Ok((mut stream, _)) => {
                let mut buf = [0u8; 1024];
                loop {
                    let n = stream.read(&mut buf).unwrap();
                    if n == 0 {
                        break;
                    }
                    let _ = stream.write_all(&buf[..n]);
                }
            }
            Err(err) => {
                eprintln!("server error {:?}", &err);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();

            let test_data = &[1, 2, 3, 4, 5, 6];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());

            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..2]).await;
            assert!(read.is_ok());
            let read = wrapper.read_exact(&mut buf[2..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);

            let test_data = &[7, 8, 9, 10];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());
            // `read_exact` so a short read cannot leave stale bytes from the
            // first round inside the compared prefix.
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }

    /// The server answers in two chunks with a pause in between; the client
    /// must wait for the second chunk.
    #[test]
    fn test_async_read_write_with_delay() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let n = stream.read(&mut buf).unwrap();
                let half = n / 2;
                stream.write_all(&buf[..half]).unwrap();
                thread::sleep(Duration::from_millis(50));
                stream.write_all(&buf[half..n]).unwrap();
            }
        });

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();
            let test_data = b"Delayed Hello!";
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());

            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }

    /// Three clients are driven concurrently, each echoing its own message.
    #[test]
    fn test_concurrent_operations() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            for _ in 0..3 {
                if let Ok((mut stream, _)) = listener.accept() {
                    thread::spawn(move || {
                        let mut buf = [0u8; 1024];
                        let n = stream.read(&mut buf).unwrap();
                        let _ = stream.write_all(&buf[..n]);
                    });
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut futures = Vec::new();

            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let mut client = TcpStream::connect(addr).unwrap();
                    client.write_all(test_data.as_bytes()).await.unwrap();
                    let mut buf = [0u8; 1024];
                    let expected_len = test_data.len();
                    client.read_exact(&mut buf[..expected_len]).await.unwrap();
                    assert_eq!(&buf[..expected_len], test_data.as_bytes());
                };
                futures.push(future);
            }

            futures::future::join_all(futures).await;
        };

        block_on(test_future);
    }

    /// A read with a 50 ms timeout against a server that answers after
    /// 200 ms must fail with `TimedOut`.
    #[test]
    fn test_timeout_behavior() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let _ = stream.read(&mut buf);
                thread::sleep(Duration::from_millis(200));
                let _ = stream.write_all(b"slow response");
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();

            wrapper.set_read_timeout(Duration::from_millis(50));
            wrapper.write_all(b"test").await.unwrap();

            let mut buf = [0u8; 1024];
            let read_result = wrapper.read(&mut buf).await;

            assert!(read_result.is_err());
            let err = read_result.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        };

        block_on(test_future);
    }

    /// Shutting down both directions of a fresh connection succeeds.
    #[test]
    fn test_shutdown() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });

        let wrapper = TcpStream::connect(addr).unwrap();
        let result = wrapper.shutdown(Shutdown::Both);
        assert!(result.is_ok());
    }

    /// After `split`, the write half sends and the read half receives the echo.
    #[test]
    fn test_split_streams_independently() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = TcpStream::connect(addr).unwrap();
            let (mut read_stream, mut write_stream) = wrapper.split();

            let test_data = b"Split stream test";
            write_stream.write_all(test_data).await.unwrap();

            let mut buf = [0u8; 1024];
            read_stream
                .read_exact(&mut buf[..test_data.len()])
                .await
                .unwrap();
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }
}