use crate::common::ready_future::ReadyFuture;
use crate::common::ready_future_state::ReadyFutureResult;
use crate::net::event_listener;
use futures::{AsyncRead, AsyncWrite, FutureExt};
use mio::Token;
use mio::net::TcpStream as MioTcpStream;
use std::io::{self, ErrorKind};
use std::net::{Shutdown, SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
/// Read half of a TCP connection; the `AsyncRead` implementation for this type
/// drives reads through `poll_read_attempt`.
pub struct TcpReadStream {
/// mio socket that reads are issued against.
tcp_stream: MioTcpStream,
/// Token obtained from `event_listener().next_token()`; handed to the listener
/// when waiting for readability and again when the stream is dropped.
read_token: Token,
/// Readiness future kept from a poll that returned `Poll::Pending`;
/// `None` while no wait is in progress.
read_future: Option<ReadyFuture<()>>,
/// Added to `Instant::now()` to form the deadline of each wait for readable
/// data; `new` initialises it to 20 seconds.
pub read_timeout: Duration,
}
impl TcpReadStream {
    /// Wraps `tcp_stream` as the read half of a connection.
    ///
    /// A token is requested from `event_listener()` and the read timeout
    /// starts at 20 seconds.
    pub fn new(tcp_stream: MioTcpStream) -> Self {
        TcpReadStream {
            tcp_stream,
            read_token: event_listener().next_token(),
            read_future: None,
            read_timeout: Duration::from_secs(20),
        }
    }

    /// Sets how long a wait for readable data may last before the read fails
    /// with `ErrorKind::TimedOut`.
    pub fn set_read_timeout(&mut self, duration: Duration) {
        self.read_timeout = duration;
    }

    /// Requests a readiness future from the event listener and stores it in
    /// `read_future`.
    ///
    /// The deadline passed to the listener is `now + read_timeout`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` when `now + read_timeout` cannot be
    /// represented as an `Instant` (for example `Duration::MAX`), instead of
    /// panicking on the overflowing addition. Any error reported by
    /// `listen_read` is propagated unchanged.
    fn wait_read_data(&mut self) -> io::Result<()> {
        let deadline = Instant::now()
            .checked_add(self.read_timeout)
            .ok_or_else(|| io::Error::new(ErrorKind::InvalidInput, "read timeout is too large"))?;
        let future =
            event_listener().listen_read(&mut self.tcp_stream, deadline, self.read_token)?;
        self.read_future = Some(future);
        Ok(())
    }

    /// Performs one poll of a non-blocking read into `buf`.
    ///
    /// If a readiness future from an earlier poll is stored, it is polled
    /// first; otherwise the socket is read directly. Whenever the socket
    /// reports `WouldBlock`, a new readiness future is requested and polled so
    /// that the task's waker is registered before `Poll::Pending` is returned.
    ///
    /// # Returns
    ///
    /// * `Poll::Ready(Ok(n))` with the number of bytes read.
    /// * `Poll::Ready(Err(_))` with `ErrorKind::TimedOut` when the readiness
    ///   future reports a timeout, or with any other I/O error from the
    ///   socket or the listener.
    /// * `Poll::Pending` while the readiness future is still pending.
    fn poll_read_attempt(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            if let Some(mut future) = self.read_future.take() {
                match future.poll_unpin(cx) {
                    Poll::Pending => {
                        // Keep the future so the next poll resumes the same wait.
                        self.read_future = Some(future);
                        return Poll::Pending;
                    }
                    Poll::Ready(ReadyFutureResult::Timeout) => {
                        return Poll::Ready(Err(ErrorKind::TimedOut.into()));
                    }
                    // Any other outcome is treated as "try the socket now".
                    Poll::Ready(_) => {}
                }
            }

            match io::Read::read(&mut self.tcp_stream, buf) {
                Ok(size) => return Poll::Ready(Ok(size)),
                Err(err) if err.kind() == ErrorKind::WouldBlock => {
                    // Readiness notifications may be spurious, so `WouldBlock`
                    // can also show up right after a wake-up. It must never be
                    // handed to the caller: start a new wait (with a fresh
                    // deadline) and loop to poll it.
                    if let Err(err) = self.wait_read_data() {
                        return Poll::Ready(Err(err));
                    }
                }
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }
}
impl Drop for TcpReadStream {
    /// Tells the event listener to stop watching this socket under the read
    /// token.
    fn drop(&mut self) {
        // A failure cannot be reported from `drop`, so the result is discarded.
        let _ = event_listener().stop_listening(&mut self.tcp_stream, self.read_token);
    }
}
impl AsyncRead for TcpReadStream {
    /// Forwards to `TcpReadStream::poll_read_attempt` on the unpinned stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        self.get_mut().poll_read_attempt(cx, buf)
    }
}
/// Write half of a TCP connection; the `AsyncWrite` implementation for this
/// type drives writes through `poll_write_attempt`.
pub struct TcpWriteStream {
/// mio socket that writes are issued against.
tcp_stream: MioTcpStream,
/// Token obtained from `event_listener().next_token()`; handed to the listener
/// when waiting for writability and again when the stream is dropped.
write_token: Token,
/// Readiness future kept from a poll that returned `Poll::Pending`;
/// `None` while no wait is in progress.
write_future: Option<ReadyFuture<()>>,
/// Added to `Instant::now()` to form the deadline of each wait for the socket
/// to become writable; `new` initialises it to 2 seconds.
pub write_timeout: Duration,
}
impl TcpWriteStream {
    /// Wraps `tcp_stream` as the write half of a connection.
    ///
    /// A token is requested from `event_listener()` and the write timeout
    /// starts at 2 seconds.
    pub fn new(tcp_stream: MioTcpStream) -> Self {
        TcpWriteStream {
            tcp_stream,
            write_token: event_listener().next_token(),
            write_future: None,
            write_timeout: Duration::from_secs(2),
        }
    }

    /// Sets how long a wait for the socket to become writable may last before
    /// the write fails with `ErrorKind::TimedOut`.
    pub fn set_write_timeout(&mut self, duration: Duration) {
        self.write_timeout = duration;
    }

    /// Requests a readiness future from the event listener and stores it in
    /// `write_future`.
    ///
    /// The deadline passed to the listener is `now + write_timeout`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` when `now + write_timeout` cannot be
    /// represented as an `Instant` (for example `Duration::MAX`), instead of
    /// panicking on the overflowing addition. Any error reported by
    /// `listen_write` is propagated unchanged.
    fn wait_write_channel(&mut self) -> io::Result<()> {
        let deadline = Instant::now()
            .checked_add(self.write_timeout)
            .ok_or_else(|| io::Error::new(ErrorKind::InvalidInput, "write timeout is too large"))?;
        let future =
            event_listener().listen_write(&mut self.tcp_stream, deadline, self.write_token)?;
        self.write_future = Some(future);
        Ok(())
    }

    /// Performs one poll of a non-blocking write of `buf`.
    ///
    /// If a readiness future from an earlier poll is stored, it is polled
    /// first; otherwise the socket is written directly. Whenever the socket
    /// reports `WouldBlock`, a new readiness future is requested and polled so
    /// that the task's waker is registered before `Poll::Pending` is returned.
    ///
    /// # Returns
    ///
    /// * `Poll::Ready(Ok(n))` with the number of bytes accepted by the socket
    ///   (possibly fewer than `buf.len()`).
    /// * `Poll::Ready(Err(_))` with `ErrorKind::TimedOut` when the readiness
    ///   future reports a timeout, or with any other I/O error from the
    ///   socket or the listener.
    /// * `Poll::Pending` while the readiness future is still pending.
    fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        loop {
            if let Some(mut future) = self.write_future.take() {
                match future.poll_unpin(cx) {
                    Poll::Pending => {
                        // Keep the future so the next poll resumes the same wait.
                        self.write_future = Some(future);
                        return Poll::Pending;
                    }
                    Poll::Ready(ReadyFutureResult::Timeout) => {
                        return Poll::Ready(Err(ErrorKind::TimedOut.into()));
                    }
                    // Any other outcome is treated as "try the socket now".
                    Poll::Ready(_) => {}
                }
            }

            match io::Write::write(&mut self.tcp_stream, buf) {
                Ok(size) => return Poll::Ready(Ok(size)),
                Err(err) if err.kind() == ErrorKind::WouldBlock => {
                    // Readiness notifications may be spurious, so `WouldBlock`
                    // can also show up right after a wake-up. It must never be
                    // handed to the caller: start a new wait (with a fresh
                    // deadline) and loop to poll it.
                    if let Err(err) = self.wait_write_channel() {
                        return Poll::Ready(Err(err));
                    }
                }
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }

    /// Shuts down the read, write, or both directions of the underlying
    /// socket, as selected by `how`.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.tcp_stream.shutdown(how)
    }
}
impl Drop for TcpWriteStream {
    /// Tells the event listener to stop watching this socket under the write
    /// token.
    fn drop(&mut self) {
        // A failure cannot be reported from `drop`, so the result is discarded.
        let _ = event_listener().stop_listening(&mut self.tcp_stream, self.write_token);
    }
}
impl AsyncWrite for TcpWriteStream {
    /// Forwards to `TcpWriteStream::poll_write_attempt` on the unpinned stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.get_mut().poll_write_attempt(cx, buf)
    }

    /// Always completes immediately with success; this type keeps no write
    /// buffer of its own, so there is nothing to flush here.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write direction of the socket and completes immediately
    /// with the result of that call.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let outcome = self.get_mut().shutdown(Shutdown::Write);
        Poll::Ready(outcome)
    }
}
/// TCP connection made of a separate read half and write half, which can be
/// used together through this type or taken apart with `split`.
pub struct TcpStream {
/// Half used for `AsyncRead` and for address/shutdown queries.
read_stream: TcpReadStream,
/// Half used for `AsyncWrite`.
write_stream: TcpWriteStream,
}
impl TcpStream {
    /// Builds an async stream from a standard-library TCP stream.
    ///
    /// The socket is switched to non-blocking mode, then cloned: the clone
    /// backs the read half and the original backs the write half.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be made non-blocking or cannot be cloned.
    pub fn from(tcp_stream: std::net::TcpStream) -> io::Result<TcpStream> {
        tcp_stream.set_nonblocking(true)?;
        let read_half = MioTcpStream::from_std(tcp_stream.try_clone()?);
        let write_half = MioTcpStream::from_std(tcp_stream);
        Ok(TcpStream {
            read_stream: TcpReadStream::new(read_half),
            write_stream: TcpWriteStream::new(write_half),
        })
    }

    /// Connects to `addr` with the standard library's connect and wraps the
    /// resulting socket via [`TcpStream::from`].
    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
        std::net::TcpStream::connect(addr).and_then(Self::from)
    }

    /// Shared reference to the read half.
    pub fn read_stream(&self) -> &TcpReadStream {
        &self.read_stream
    }

    /// Exclusive reference to the read half.
    pub fn read_stream_mut(&mut self) -> &mut TcpReadStream {
        &mut self.read_stream
    }

    /// Shared reference to the write half.
    pub fn write_stream(&self) -> &TcpWriteStream {
        &self.write_stream
    }

    /// Exclusive reference to the write half.
    pub fn write_stream_mut(&mut self) -> &mut TcpWriteStream {
        &mut self.write_stream
    }

    /// Consumes the stream and returns its `(read, write)` halves.
    pub fn split(self) -> (TcpReadStream, TcpWriteStream) {
        let TcpStream {
            read_stream,
            write_stream,
        } = self;
        (read_stream, write_stream)
    }

    /// Sets the read timeout on the read half.
    pub fn set_read_timeout(&mut self, duration: Duration) {
        self.read_stream.set_read_timeout(duration);
    }

    /// Sets the write timeout on the write half.
    pub fn set_write_timeout(&mut self, duration: Duration) {
        self.write_stream.set_write_timeout(duration);
    }

    /// Local address of the connection, queried on the read half's socket.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let socket = &self.read_stream.tcp_stream;
        socket.local_addr()
    }

    /// Remote address of the connection, queried on the read half's socket.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        let socket = &self.read_stream.tcp_stream;
        socket.peer_addr()
    }

    /// Shuts down the directions selected by `how`, issued on the read half's
    /// socket.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        let socket = &self.read_stream.tcp_stream;
        socket.shutdown(how)
    }
}
impl AsyncRead for TcpStream {
    /// Delegates the read to the read half.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let reader = &mut self.get_mut().read_stream;
        Pin::new(reader).poll_read(cx, buf)
    }
}
impl AsyncWrite for TcpStream {
    /// Delegates the write to the write half.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let writer = &mut self.get_mut().write_stream;
        Pin::new(writer).poll_write(cx, buf)
    }

    /// Delegates the flush to the write half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = &mut self.get_mut().write_stream;
        Pin::new(writer).poll_flush(cx)
    }

    /// Delegates the close to the write half.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = &mut self.get_mut().write_stream;
        Pin::new(writer).poll_close(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use std::time::Duration;

    /// Binds a listener on an OS-assigned loopback port and returns it with
    /// the address clients should connect to.
    fn setup_test_server() -> (TcpListener, std::net::SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Connecting succeeds and the wrapper reports the server as its peer.
    #[test]
    fn test_tcp_stream_wrapper_creation() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });
        let wrapper = TcpStream::connect(addr);
        assert!(wrapper.is_ok());
        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.peer_addr().unwrap(), addr);
    }

    /// The half accessors expose the default timeouts and allow changing them.
    #[test]
    fn test_stream_accessors() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });
        let mut wrapper = TcpStream::connect(addr).unwrap();
        let read_stream = wrapper.read_stream();
        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));
        let read_stream_mut = wrapper.read_stream_mut();
        read_stream_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_stream_mut.read_timeout, Duration::from_secs(15));
        let write_stream = wrapper.write_stream();
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));
        let write_stream_mut = wrapper.write_stream_mut();
        write_stream_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_stream_mut.write_timeout, Duration::from_secs(10));
    }

    /// `split` hands back both halves with their default timeouts intact.
    #[test]
    fn test_stream_split() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });
        let wrapper = TcpStream::connect(addr).unwrap();
        let (read_stream, write_stream) = wrapper.split();
        assert_eq!(read_stream.read_timeout, Duration::from_secs(20));
        assert_eq!(write_stream.write_timeout, Duration::from_secs(2));
    }

    /// Two write/read round trips against an echo server that loops until EOF.
    #[test]
    fn test_async_read_write() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || match listener.accept() {
            Ok((mut stream, _)) => {
                let mut buf = [0u8; 1024];
                loop {
                    let n = stream.read(&mut buf).unwrap();
                    if n == 0 {
                        break;
                    }
                    let _ = stream.write_all(&buf[..n]);
                }
            }
            Err(err) => {
                eprintln!("server error {:?}", &err);
            }
        });
        thread::sleep(Duration::from_millis(10));
        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();
            let test_data = &[1, 2, 3, 4, 5, 6];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());
            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..2]).await;
            assert!(read.is_ok());
            let read = wrapper.read_exact(&mut buf[2..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
            let test_data = &[7, 8, 9, 10];
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());
            // `read_exact` rather than a single `read`: one `read` may return
            // fewer bytes than were echoed, which would leave bytes from the
            // first round in the compared range of `buf`.
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };
        block_on(test_future);
    }

    /// The server echoes the message in two parts 50 ms apart; `read_exact`
    /// must wait across the gap and still deliver the whole message.
    #[test]
    fn test_async_read_write_with_delay() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let n = stream.read(&mut buf).unwrap();
                let half = n / 2;
                stream.write_all(&buf[..half]).unwrap();
                thread::sleep(Duration::from_millis(50));
                stream.write_all(&buf[half..n]).unwrap();
            }
        });
        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();
            let test_data = b"Delayed Hello!";
            let written = wrapper.write_all(test_data).await;
            assert!(written.is_ok());
            let mut buf = [0u8; 1024];
            let read = wrapper.read_exact(&mut buf[..test_data.len()]).await;
            assert!(read.is_ok());
            assert_eq!(&buf[..test_data.len()], test_data);
        };
        block_on(test_future);
    }

    /// Three clients run as joined futures, each echoing its own message
    /// through a dedicated server thread.
    #[test]
    fn test_concurrent_operations() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            for _ in 0..3 {
                if let Ok((mut stream, _)) = listener.accept() {
                    thread::spawn(move || {
                        let mut buf = [0u8; 1024];
                        let n = stream.read(&mut buf).unwrap();
                        let _ = stream.write_all(&buf[..n]);
                    });
                }
            }
        });
        thread::sleep(Duration::from_millis(10));
        let test_future = async {
            let mut futures = Vec::new();
            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let mut client = TcpStream::connect(addr).unwrap();
                    client.write_all(test_data.as_bytes()).await.unwrap();
                    let mut buf = [0u8; 1024];
                    let read_bytes = client.read(&mut buf).await.unwrap();
                    assert_eq!(&buf[..read_bytes], test_data.as_bytes());
                };
                futures.push(future);
            }
            futures::future::join_all(futures).await;
        };
        block_on(test_future);
    }

    /// With a 50 ms read timeout and a server that answers after 200 ms, the
    /// read must fail with `TimedOut`.
    #[test]
    fn test_timeout_behavior() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                let _ = stream.read(&mut buf);
                thread::sleep(Duration::from_millis(200));
                let _ = stream.write_all(b"slow response");
            }
        });
        thread::sleep(Duration::from_millis(10));
        let test_future = async {
            let mut wrapper = TcpStream::connect(addr).unwrap();
            wrapper.set_read_timeout(Duration::from_millis(50));
            wrapper.write_all(b"test").await.unwrap();
            let mut buf = [0u8; 1024];
            let read_result = wrapper.read(&mut buf).await;
            assert!(read_result.is_err());
            let err = read_result.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        };
        block_on(test_future);
    }

    /// Shutting down both directions of a freshly connected stream succeeds.
    #[test]
    fn test_shutdown() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });
        let wrapper = TcpStream::connect(addr).unwrap();
        let result = wrapper.shutdown(Shutdown::Both);
        assert!(result.is_ok());
    }

    /// After `split`, writing on one half and reading on the other completes
    /// an echo round trip.
    #[test]
    fn test_split_streams_independently() {
        let (listener, addr) = setup_test_server();
        thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let mut buf = [0u8; 1024];
                if let Ok(n) = stream.read(&mut buf) {
                    let _ = stream.write_all(&buf[..n]);
                }
            }
        });
        thread::sleep(Duration::from_millis(10));
        let test_future = async {
            let wrapper = TcpStream::connect(addr).unwrap();
            let (mut read_stream, mut write_stream) = wrapper.split();
            let test_data = b"Split stream test";
            write_stream.write_all(test_data).await.unwrap();
            let mut buf = [0u8; 1024];
            let read_bytes = read_stream.read(&mut buf).await.unwrap();
            assert_eq!(&buf[..read_bytes], test_data);
        };
        block_on(test_future);
    }
}