1use crate::common::ready_future::ReadyFuture;
2use crate::common::ready_future_state::ReadyFutureResult;
3use crate::net::event_listener;
4use futures::{AsyncRead, AsyncWrite, FutureExt};
5use mio::Token;
6use mio::net::TcpStream as MioTcpStream;
7use std::io::{self, ErrorKind};
8use std::net::{Shutdown, SocketAddr, ToSocketAddrs};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::{Duration, Instant};
12
/// Read half of a TCP connection whose waits are driven by the event listener.
///
/// Implements `AsyncRead`: a read is first attempted directly on the socket (which is
/// expected to be non-blocking); when it reports `WouldBlock`, the stream registers
/// for read readiness and waits, bounded by `read_timeout`.
pub struct TcpReadStream {
    // Underlying mio socket; every read goes straight to it.
    tcp_stream: MioTcpStream,
    // Token under which this stream's read-readiness registrations are made.
    read_token: Token,
    // Pending readiness wait; `Some` only while a read is parked waiting for data.
    read_future: Option<ReadyFuture<()>>,
    /// Maximum time a single readiness wait may last before the read fails with
    /// `ErrorKind::TimedOut`. `TcpReadStream::new` sets it to 20 seconds.
    pub read_timeout: Duration,
}
19
20impl TcpReadStream {
21 pub fn new(tcp_stream: MioTcpStream) -> Self {
22 TcpReadStream {
23 tcp_stream,
24 read_token: event_listener().next_token(),
25 read_future: None,
26 read_timeout: Duration::from_secs(20),
27 }
28 }
29
30 pub fn set_read_timeout(&mut self, duration: Duration) {
31 self.read_timeout = duration;
32 }
33
34 fn wait_read_data(&mut self) -> io::Result<()> {
35 let future = event_listener().listen_read(
36 &mut self.tcp_stream,
37 Instant::now() + self.read_timeout,
38 self.read_token,
39 )?;
40 self.read_future = Some(future);
41 Ok(())
42 }
43
44 fn poll_read_attempt(
45 &mut self,
46 cx: &mut Context<'_>,
47 buf: &mut [u8],
48 ) -> Poll<io::Result<usize>> {
49 let mut future = match self.read_future.take() {
50 None => {
51 match io::Read::read(&mut self.tcp_stream, buf) {
52 Ok(size) => return Poll::Ready(Ok(size)),
53 Err(err) if err.kind() == ErrorKind::WouldBlock => (),
54 Err(err) => return Poll::Ready(Err(err)),
55 }
56 if let Err(err) = self.wait_read_data() {
57 return Poll::Ready(Err(err));
58 }
59 self.read_future.take().unwrap()
60 }
61 Some(future) => future,
62 };
63 match future.poll_unpin(cx) {
64 Poll::Pending => {
65 self.read_future = Some(future);
66 Poll::Pending
67 }
68 Poll::Ready(ReadyFutureResult::Timeout) => {
69 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
70 }
71 Poll::Ready(_) => match io::Read::read(&mut self.tcp_stream, buf) {
72 Ok(size) => Poll::Ready(Ok(size)),
73 Err(err) => Poll::Ready(Err(err)),
74 },
75 }
76 }
77}
78
79impl Drop for TcpReadStream {
80 fn drop(&mut self) {
81 event_listener()
82 .stop_listening(&mut self.tcp_stream, self.read_token)
83 .ok();
84 }
85}
86
87impl AsyncRead for TcpReadStream {
88 fn poll_read(
89 self: Pin<&mut Self>,
90 cx: &mut Context<'_>,
91 buf: &mut [u8],
92 ) -> Poll<io::Result<usize>> {
93 let me = self.get_mut();
94 me.poll_read_attempt(cx, buf)
95 }
96}
97
/// Write half of a TCP connection whose waits are driven by the event listener.
///
/// Implements `AsyncWrite`: a write is first attempted directly on the socket (which is
/// expected to be non-blocking); when it reports `WouldBlock`, the stream registers
/// for write readiness and waits, bounded by `write_timeout`.
pub struct TcpWriteStream {
    // Underlying mio socket; every write goes straight to it.
    tcp_stream: MioTcpStream,
    // Token under which this stream's write-readiness registrations are made.
    write_token: Token,
    // Pending readiness wait; `Some` only while a write is parked waiting for room.
    write_future: Option<ReadyFuture<()>>,
    /// Maximum time a single readiness wait may last before the write fails with
    /// `ErrorKind::TimedOut`. `TcpWriteStream::new` sets it to 2 seconds.
    pub write_timeout: Duration,
}
104
105impl TcpWriteStream {
106 pub fn new(tcp_stream: MioTcpStream) -> Self {
107 TcpWriteStream {
108 tcp_stream,
109 write_token: event_listener().next_token(),
110 write_future: None,
111 write_timeout: Duration::from_secs(2),
112 }
113 }
114
115 pub fn set_write_timeout(&mut self, duration: Duration) {
116 self.write_timeout = duration;
117 }
118
119 fn wait_write_channel(&mut self) -> io::Result<()> {
120 let future = event_listener().listen_write(
121 &mut self.tcp_stream,
122 Instant::now() + self.write_timeout,
123 self.write_token,
124 )?;
125 self.write_future = Some(future);
126 Ok(())
127 }
128
129 fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
130 let mut future = match self.write_future.take() {
131 None => {
132 match io::Write::write(&mut self.tcp_stream, buf) {
133 Ok(size) => return Poll::Ready(Ok(size)),
134 Err(err) if err.kind() == ErrorKind::WouldBlock => (),
135 Err(err) => return Poll::Ready(Err(err)),
136 }
137
138 if let Err(err) = self.wait_write_channel() {
139 return Poll::Ready(Err(err));
140 }
141 self.write_future.take().unwrap()
142 }
143 Some(future) => future,
144 };
145 match future.poll_unpin(cx) {
146 Poll::Pending => {
147 self.write_future = Some(future);
148 Poll::Pending
149 }
150 Poll::Ready(ReadyFutureResult::Timeout) => {
151 Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
152 }
153 Poll::Ready(_) => match io::Write::write(&mut self.tcp_stream, buf) {
154 Ok(size) => Poll::Ready(Ok(size)),
155 Err(err) => Poll::Ready(Err(err)),
156 },
157 }
158 }
159
160 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
161 self.tcp_stream.shutdown(how)
162 }
163}
164
165impl Drop for TcpWriteStream {
166 fn drop(&mut self) {
167 event_listener()
168 .stop_listening(&mut self.tcp_stream, self.write_token)
169 .ok();
170 }
171}
172
173impl AsyncWrite for TcpWriteStream {
174 fn poll_write(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: &[u8],
178 ) -> Poll<io::Result<usize>> {
179 let me = self.get_mut();
180 me.poll_write_attempt(cx, buf)
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184 Poll::Ready(Ok(()))
185 }
186
187 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
188 let me = self.get_mut();
189 me.shutdown(Shutdown::Write)?;
190 Poll::Ready(Ok(()))
191 }
192}
193
/// A TCP connection made of independently owned read and write halves.
///
/// `TcpStream::from` gives each half its own handle to the same socket (the read half
/// gets a `try_clone` of the std stream). Each half allocates its own event-listener
/// token in its constructor. Use [`TcpStream::split`] to move the halves apart.
pub struct TcpStream {
    // Handles `AsyncRead`; its socket handle also serves address queries and `shutdown`.
    read_stream: TcpReadStream,
    // Handles `AsyncWrite`.
    write_stream: TcpWriteStream,
}
198
199impl TcpStream {
200 pub fn from(tcp_stream: std::net::TcpStream) -> io::Result<TcpStream> {
201 tcp_stream.set_nonblocking(true)?;
202 Ok(TcpStream {
203 read_stream: TcpReadStream::new(MioTcpStream::from_std(tcp_stream.try_clone()?)),
204 write_stream: TcpWriteStream::new(MioTcpStream::from_std(tcp_stream)),
205 })
206 }
207
208 pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
209 Self::from(std::net::TcpStream::connect(addr)?)
210 }
211
212 pub fn read_stream(&self) -> &TcpReadStream {
213 &self.read_stream
214 }
215
216 pub fn read_stream_mut(&mut self) -> &mut TcpReadStream {
217 &mut self.read_stream
218 }
219
220 pub fn write_stream(&self) -> &TcpWriteStream {
221 &self.write_stream
222 }
223
224 pub fn write_stream_mut(&mut self) -> &mut TcpWriteStream {
225 &mut self.write_stream
226 }
227
228 pub fn split(self) -> (TcpReadStream, TcpWriteStream) {
229 (self.read_stream, self.write_stream)
230 }
231
232 pub fn set_read_timeout(&mut self, duration: Duration) {
233 self.read_stream.set_read_timeout(duration);
234 }
235
236 pub fn set_write_timeout(&mut self, duration: Duration) {
237 self.write_stream.set_write_timeout(duration);
238 }
239
240 pub fn local_addr(&self) -> io::Result<SocketAddr> {
241 self.read_stream.tcp_stream.local_addr()
242 }
243
244 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
245 self.read_stream.tcp_stream.peer_addr()
246 }
247
248 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
249 self.read_stream.tcp_stream.shutdown(how)
250 }
251}
252
253impl AsyncRead for TcpStream {
255 fn poll_read(
256 self: Pin<&mut Self>,
257 cx: &mut Context<'_>,
258 buf: &mut [u8],
259 ) -> Poll<io::Result<usize>> {
260 let me = self.get_mut();
261 Pin::new(&mut me.read_stream).poll_read(cx, buf)
262 }
263}
264
265impl AsyncWrite for TcpStream {
266 fn poll_write(
267 self: Pin<&mut Self>,
268 cx: &mut Context<'_>,
269 buf: &[u8],
270 ) -> Poll<io::Result<usize>> {
271 let me = self.get_mut();
272 Pin::new(&mut me.write_stream).poll_write(cx, buf)
273 }
274
275 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
276 let me = self.get_mut();
277 Pin::new(&mut me.write_stream).poll_flush(cx)
278 }
279
280 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281 let me = self.get_mut();
282 Pin::new(&mut me.write_stream).poll_close(cx)
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use std::time::Duration;

    /// Binds a listener on an ephemeral loopback port and returns it with its address.
    fn setup_test_server() -> (TcpListener, std::net::SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    #[test]
    fn test_tcp_stream_wrapper_creation() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });
        let wrapper = TcpStream::connect(addr);
        assert!(wrapper.is_ok());

        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.peer_addr().unwrap(), addr);
    }

    #[test]
    fn test_stream_accessors() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });

        let mut wrapper = TcpStream::connect(addr).unwrap();

        let read_stream = wrapper.read_stream();
        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));

        let read_stream_mut = wrapper.read_stream_mut();
        read_stream_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_stream_mut.read_timeout, Duration::from_secs(15));

        let write_stream = wrapper.write_stream();
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));

        let write_stream_mut = wrapper.write_stream_mut();
        write_stream_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_stream_mut.write_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_stream_split() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });

        let wrapper = TcpStream::connect(addr).unwrap();
        let (read_stream, write_stream) = wrapper.split();

        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));
    }

    #[test]
    fn test_async_read_write() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || match listener.accept() {
            Ok((mut stream, _)) => {
                let mut buf = [0u8; 1024];
                loop {
                    let n = stream.read(&mut buf).unwrap();
                    if n == 0 {
                        break;
                    }
                    let _ = stream.write_all(&buf[..n]);
                }
            }
            Err(err) => {
                eprintln!("server error {:?}", &err);
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();

            let test_data = &[1, 2, 3, 4, 5, 6];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());

            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..2]).await;
            assert!(read.is_ok());
            let read = wrapper.read_exact(&mut buf[2..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);

            let test_data = &[7, 8, 9, 10];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());
            // A single `read` may legally return a prefix of the echo; wait for all of it.
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }

    #[test]
    fn test_async_read_write_with_delay() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let n = stream.read(&mut buf).unwrap();
                let half = n / 2;
                stream.write_all(&buf[..half]).unwrap();
                thread::sleep(Duration::from_millis(50));
                stream.write_all(&buf[half..n]).unwrap();
            }
        });

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();
            let test_data = b"Delayed Hello!";
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());

            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }

    #[test]
    fn test_concurrent_operations() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            for _ in 0..3 {
                if let Ok((mut stream, _)) = listener.accept() {
                    thread::spawn(move || {
                        let mut buf = [0u8; 1024];
                        let n = stream.read(&mut buf).unwrap();
                        let _ = stream.write_all(&buf[..n]);
                    });
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut futures = Vec::new();

            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let mut client = TcpStream::connect(addr).unwrap();
                    client.write_all(test_data.as_bytes()).await.unwrap();
                    let mut buf = [0u8; 1024];
                    // Read exactly the echoed length instead of trusting one `read`.
                    let expected_len = test_data.len();
                    client.read_exact(&mut buf[..expected_len]).await.unwrap();
                    assert_eq!(&buf[..expected_len], test_data.as_bytes());
                };
                futures.push(future);
            }

            futures::future::join_all(futures).await;
        };

        block_on(test_future);
    }

    #[test]
    fn test_timeout_behavior() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let _ = stream.read(&mut buf);
                thread::sleep(Duration::from_millis(200));
                let _ = stream.write_all(b"slow response");
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();

            wrapper.set_read_timeout(Duration::from_millis(50));
            wrapper.write_all(b"test").await.unwrap();

            let mut buf = [0u8; 1024];
            let read_result = wrapper.read(&mut buf).await;

            assert!(read_result.is_err());
            let err = read_result.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        };

        block_on(test_future);
    }

    #[test]
    fn test_shutdown() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });

        let wrapper = TcpStream::connect(addr).unwrap();
        let result = wrapper.shutdown(Shutdown::Both);
        assert!(result.is_ok());
    }

    #[test]
    fn test_split_streams_independently() {
        let (listener, addr) = setup_test_server();

        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });

        thread::sleep(Duration::from_millis(10));

        let test_future = async {
            let wrapper = TcpStream::connect(addr).unwrap();
            let (mut read_stream, mut write_stream) = wrapper.split();

            let test_data = b"Split stream test";
            write_stream.write_all(test_data).await.unwrap();

            let mut buf = [0u8; 1024];
            // Read exactly the echoed length instead of trusting one `read`.
            read_stream
                .read_exact(&mut buf[..test_data.len()])
                .await
                .unwrap();
            assert_eq!(&buf[..test_data.len()], test_data);
        };

        block_on(test_future);
    }
}