use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::platform::sys::{set_nonblocking, Interest};
use crate::reactor::source::{next_token, IoSource};
use super::sockaddr::{reclaim_raw_sockaddr, sockaddr_to_socketaddr, socketaddr_to_raw};
use super::{AsyncRead, AsyncWrite};
/// A connected, non-blocking TCP socket driven by the crate's reactor.
///
/// The stream owns its file descriptor and closes it when dropped.
pub struct TcpStream {
    /// Reactor registration wrapping the raw socket descriptor.
    source: IoSource,
}
impl TcpStream {
pub fn connect(addr: SocketAddr) -> ConnectFuture {
ConnectFuture::new(addr)
}
pub(crate) fn from_raw_fd(fd: i32) -> io::Result<Self> {
let source = IoSource::new(fd, next_token(), Interest::READABLE | Interest::WRITABLE)?;
Ok(Self { source })
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
peer_addr(self.source.raw())
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
local_addr(self.source.raw())
}
}
impl Drop for TcpStream {
    /// Closes the owned socket descriptor.
    ///
    /// NOTE(review): this body does not deregister from the reactor. The `source` field
    /// is dropped after this body runs, i.e. after the fd is already closed. Confirm
    /// that `IoSource`'s drop (if any) does not operate on the fd number.
    fn drop(&mut self) {
        // SAFETY: the stream exclusively owns this descriptor (adopted in `from_raw_fd`)
        // and closes it exactly once, here. The result is ignored because `drop` has
        // no way to report it.
        unsafe {
            libc::close(self.source.raw());
        }
    }
}
#[cfg(unix)]
impl std::os::unix::io::AsRawFd for TcpStream {
    /// Exposes the underlying socket descriptor. Ownership stays with the stream, so
    /// the caller must not close it.
    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
        let Self { source } = self;
        source.raw()
    }
}
impl AsyncRead for TcpStream {
    /// Reads whatever is currently available into `buf`.
    ///
    /// Returns `Ok(0)` at end-of-stream. A read interrupted by a signal (`EINTR`) is
    /// retried immediately. On `EWOULDBLOCK` the method polls the source's read-readiness
    /// future, which then has three outcomes:
    /// - If that future reports an error, the error is returned.
    /// - If it stays pending, the task is parked.
    /// - If it completes at once, the task is woken so the read is retried on the next
    ///   poll. Returning `Pending` there with no scheduled wake-up could stall the task
    ///   forever.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let fd = self.source.raw();
        loop {
            // SAFETY: `buf` is a valid, writable region of `buf.len()` bytes for the
            // duration of the call, and `fd` is owned by this stream.
            let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
            if n >= 0 {
                // `n == 0` means the peer closed its write half (end-of-stream).
                return Poll::Ready(Ok(n as usize));
            }
            let err = io::Error::last_os_error();
            match err.kind() {
                io::ErrorKind::Interrupted => continue,
                io::ErrorKind::WouldBlock => break,
                _ => return Poll::Ready(Err(err)),
            }
        }
        match Pin::new(&mut self.source.readable()).poll(cx) {
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {
                // Readiness is already signalled; make sure we get polled again
                // instead of relying on a future reactor event.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
impl AsyncWrite for TcpStream {
    /// Writes as much of `buf` as the kernel accepts without blocking.
    ///
    /// A write interrupted by a signal (`EINTR`) is retried immediately. On `EWOULDBLOCK`
    /// the method polls the source's write-readiness future, which then has three
    /// outcomes:
    /// - If that future reports an error, the error is returned.
    /// - If it stays pending, the task is parked.
    /// - If it completes at once, the task is woken so the write is retried on the next
    ///   poll rather than parked with no wake-up scheduled.
    ///
    /// NOTE(review): plain `write` on a socket whose peer has gone away may raise
    /// SIGPIPE unless the process ignores that signal. Consider `send` with
    /// `MSG_NOSIGNAL` where the platform supports it.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let fd = self.source.raw();
        loop {
            // SAFETY: `buf` is a valid, readable region of `buf.len()` bytes for the
            // duration of the call, and `fd` is owned by this stream.
            let n = unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) };
            if n >= 0 {
                return Poll::Ready(Ok(n as usize));
            }
            let err = io::Error::last_os_error();
            match err.kind() {
                io::ErrorKind::Interrupted => continue,
                io::ErrorKind::WouldBlock => break,
                _ => return Poll::Ready(Err(err)),
            }
        }
        match Pin::new(&mut self.source.writable()).poll(cx) {
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {
                // Readiness is already signalled; re-poll instead of waiting for an
                // event that may already have been consumed.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Always ready: writes go straight to the kernel, with no user-space buffer.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write half of the connection (`shutdown(fd, SHUT_WR)`).
    ///
    /// The read half stays usable. Completes immediately, with the OS error if
    /// `shutdown` fails.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let fd = self.source.raw();
        // SAFETY: `fd` is a socket owned by this stream.
        let rc = unsafe { libc::shutdown(fd, libc::SHUT_WR) };
        if rc == -1 {
            Poll::Ready(Err(io::Error::last_os_error()))
        } else {
            Poll::Ready(Ok(()))
        }
    }
}
/// Future returned by [`TcpStream::connect`].
///
/// Resolves once the non-blocking connect completes or fails. Dropping it while the
/// connection is still in progress closes the in-flight socket.
#[must_use = "futures do nothing unless polled"]
pub struct ConnectFuture {
    /// Current step of the connection state machine.
    state: ConnectState,
}
/// Progress of a [`ConnectFuture`].
enum ConnectState {
    /// Not started. Holds the target address; the socket is created on the first poll.
    Init(SocketAddr),
    /// `connect` is in progress. `fd` is registered with the reactor for write
    /// readiness under `token`, and this state owns `fd` (the future's `Drop` closes it).
    Connecting {
        fd: i32,
        token: usize,
        /// Always set to `true` when this state is entered; not read anywhere in this file.
        registered: bool,
    },
    /// Terminal state. Polling again yields an error.
    Done,
}
impl ConnectFuture {
fn new(addr: SocketAddr) -> Self {
Self {
state: ConnectState::Init(addr),
}
}
}
impl Future for ConnectFuture {
    type Output = io::Result<TcpStream>;

    /// Advances the connection state machine.
    ///
    /// * `Init`: creates the socket and starts `connect`.
    ///   - An immediate success resolves right away.
    ///   - Otherwise the fd is registered for write readiness and the future moves to
    ///     `Connecting`.
    /// * `Connecting`: installs the waker, then checks `SO_ERROR` and current
    ///   writability.
    ///   - A pending socket error, or a failed `getsockopt`, resolves to `Err`.
    ///   - A clean, writable socket resolves to a [`TcpStream`].
    ///   - In both cases the connect-phase registration is dropped first, and the fd is
    ///     either adopted by the stream or closed.
    /// * `Done`: returns an error instead of panicking.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            match &mut self.state {
                ConnectState::Init(addr) => {
                    let addr = *addr;
                    let (fd, connected) = match start_connect(addr) {
                        Ok(started) => started,
                        Err(e) => {
                            self.state = ConnectState::Done;
                            return Poll::Ready(Err(e));
                        }
                    };
                    if connected {
                        self.state = ConnectState::Done;
                        return Poll::Ready(adopt_connected_fd(fd));
                    }
                    let token = next_token();
                    if let Err(e) = crate::reactor::with_reactor(|r| {
                        r.register(fd, token, Interest::WRITABLE)
                    }) {
                        // SAFETY: `fd` came from `start_connect` and is owned solely here.
                        unsafe { libc::close(fd) };
                        self.state = ConnectState::Done;
                        return Poll::Ready(Err(e));
                    }
                    self.state = ConnectState::Connecting {
                        fd,
                        token,
                        registered: true,
                    };
                    // Loop around and handle `Connecting` immediately.
                }
                ConnectState::Connecting { fd, token, .. } => {
                    let (fd, token) = (*fd, *token);
                    // Install the waker *before* checking, so a readiness event that
                    // fires between the check and `Pending` still wakes this task.
                    crate::reactor::with_reactor_mut(|r| {
                        r.wakers.set_write_waker(token, cx.waker().clone());
                    });
                    let outcome = match get_so_error(fd) {
                        Err(e) => Err(e),
                        Ok(Some(os_err)) => Err(io::Error::from_raw_os_error(os_err)),
                        Ok(None) if is_writable_now(fd) => Ok(()),
                        Ok(None) => return Poll::Pending,
                    };
                    // The connect-phase registration is finished with; on success
                    // `TcpStream::from_raw_fd` registers the fd again under a new token.
                    let _ = crate::reactor::with_reactor_mut(|r| {
                        r.deregister_with_token(fd, token)
                    });
                    self.state = ConnectState::Done;
                    return Poll::Ready(match outcome {
                        Ok(()) => adopt_connected_fd(fd),
                        Err(e) => {
                            // SAFETY: leaving `Connecting` hands ownership of `fd` to this
                            // branch; nothing else will close it.
                            unsafe { libc::close(fd) };
                            Err(e)
                        }
                    });
                }
                ConnectState::Done => {
                    return Poll::Ready(Err(io::Error::other(
                        "ConnectFuture polled after completion",
                    )));
                }
            }
        }
    }
}

/// Wraps a connected socket in a [`TcpStream`], closing `fd` if that fails.
///
/// `TcpStream::from_raw_fd` does not close the descriptor on error, so without this
/// helper the socket would leak whenever reactor registration fails.
fn adopt_connected_fd(fd: i32) -> io::Result<TcpStream> {
    TcpStream::from_raw_fd(fd).map_err(|e| {
        // SAFETY: no `TcpStream` took ownership of `fd`, so we are its sole owner and
        // close it exactly once.
        // NOTE(review): assumes `IoSource::new` leaves `fd` open on failure; confirm.
        unsafe { libc::close(fd) };
        e
    })
}
impl Drop for ConnectFuture {
    /// Abandons an in-flight connection attempt.
    ///
    /// Deregisters the socket from the reactor on a best-effort basis, then closes it.
    /// The `Init` and `Done` states hold no descriptor, so dropping in those states does
    /// nothing.
    fn drop(&mut self) {
        let ConnectState::Connecting { fd, token, .. } = self.state else {
            return;
        };
        let _ = crate::reactor::with_reactor_mut(|r| r.deregister_with_token(fd, token));
        // SAFETY: the `Connecting` state owns `fd`, and the future is going away.
        unsafe { libc::close(fd) };
    }
}
/// Checks whether `fd` currently reports `POLLOUT`, without blocking.
///
/// Uses `poll` with a zero timeout. A `poll` failure or timeout counts as "not writable".
fn is_writable_now(fd: i32) -> bool {
    let mut probe = libc::pollfd {
        fd,
        events: libc::POLLOUT,
        revents: 0,
    };
    // SAFETY: `probe` is a single, initialised `pollfd`, matching the count of 1.
    let ready = unsafe { libc::poll(&mut probe, 1, 0) };
    ready > 0 && (probe.revents & libc::POLLOUT) != 0
}
/// Creates a non-blocking TCP socket for `addr`'s address family and starts connecting.
///
/// Returns:
/// - `Ok((fd, true))` if the connection completed immediately;
/// - `Ok((fd, false))` if it is in progress (`EINPROGRESS`).
///
/// In both cases the caller owns `fd`. On any error, the socket (if one was created) is
/// closed before returning.
fn start_connect(addr: SocketAddr) -> io::Result<(i32, bool)> {
    let family = match addr {
        SocketAddr::V4(_) => libc::AF_INET,
        SocketAddr::V6(_) => libc::AF_INET6,
    };
    // SAFETY: plain FFI call; the result is checked below.
    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
    if fd == -1 {
        return Err(io::Error::last_os_error());
    }
    // Without this cleanup, a failing `set_nonblocking` would leak the new socket.
    set_nonblocking(fd).map_err(|e| {
        // SAFETY: `fd` was created above and has not been handed to anyone else.
        unsafe { libc::close(fd) };
        e
    })?;
    let (sa, sa_len) = socketaddr_to_raw(addr);
    // SAFETY: assumes `socketaddr_to_raw` yields a pointer to a valid sockaddr of
    // `sa_len` bytes that stays alive until `reclaim_raw_sockaddr`.
    let rc = unsafe { libc::connect(fd, sa, sa_len) };
    // Capture errno immediately. `reclaim_raw_sockaddr` runs more code (presumably
    // freeing the buffer), which could overwrite errno before it is read.
    let connect_err = (rc != 0).then(io::Error::last_os_error);
    unsafe { reclaim_raw_sockaddr(sa, addr) };
    match connect_err {
        None => Ok((fd, true)),
        Some(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => Ok((fd, false)),
        Some(err) => {
            // SAFETY: `fd` is still exclusively owned here.
            unsafe { libc::close(fd) };
            Err(err)
        }
    }
}
/// Reads the pending socket error (`SO_ERROR`) for `fd`.
///
/// Returns:
/// - `Ok(None)` when no error is pending;
/// - `Ok(Some(code))` with the pending OS error code otherwise;
/// - `Err` if `getsockopt` itself fails.
fn get_so_error(fd: i32) -> io::Result<Option<i32>> {
    let mut pending: libc::c_int = 0;
    let mut pending_len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
    let pending_ptr = (&mut pending as *mut libc::c_int).cast::<libc::c_void>();
    // SAFETY: `pending_ptr` points to `pending_len` writable bytes owned by this frame.
    let rc = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_ERROR,
            pending_ptr,
            &mut pending_len,
        )
    };
    if rc == -1 {
        Err(io::Error::last_os_error())
    } else {
        Ok((pending != 0).then_some(pending))
    }
}
/// Returns the remote address of the connected socket `fd`, via `getpeername`.
///
/// A zeroed `sockaddr_in6` serves as the output buffer. It is large enough for either
/// an IPv4 or an IPv6 address. The length written by the kernel is passed on to
/// `sockaddr_to_socketaddr`.
fn peer_addr(fd: i32) -> io::Result<SocketAddr> {
    // SAFETY: `sockaddr_in6` is plain old data; all-zero bytes are a valid value.
    let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
    let mut storage_len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
    let storage_ptr = (&mut storage as *mut libc::sockaddr_in6).cast::<libc::sockaddr>();
    // SAFETY: `storage_ptr` points to `storage_len` writable bytes owned by this frame.
    match unsafe { libc::getpeername(fd, storage_ptr, &mut storage_len) } {
        -1 => Err(io::Error::last_os_error()),
        _ => sockaddr_to_socketaddr(&storage, storage_len),
    }
}
/// Returns the local address the socket `fd` is bound to, via `getsockname`.
///
/// A zeroed `sockaddr_in6` serves as the output buffer. It is large enough for either
/// an IPv4 or an IPv6 address. The length written by the kernel is passed on to
/// `sockaddr_to_socketaddr`.
fn local_addr(fd: i32) -> io::Result<SocketAddr> {
    // SAFETY: `sockaddr_in6` is plain old data; all-zero bytes are a valid value.
    let mut storage: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
    let mut storage_len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
    let storage_ptr = (&mut storage as *mut libc::sockaddr_in6).cast::<libc::sockaddr>();
    // SAFETY: `storage_ptr` points to `storage_len` writable bytes owned by this frame.
    match unsafe { libc::getsockname(fd, storage_ptr, &mut storage_len) } {
        -1 => Err(io::Error::last_os_error()),
        _ => sockaddr_to_socketaddr(&storage, storage_len),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::block_on_with_spawn;
    use crate::net::TcpListener;
    use std::future::poll_fn;

    /// Reads exactly `buf.len()` bytes from `stream`.
    ///
    /// Panics on an I/O error, or if the peer reaches end-of-stream first. Stopping
    /// silently at EOF would leave the tail of `buf` zeroed. That could satisfy an
    /// assertion expecting zero bytes, as in the `i == 0` round of the sequential test.
    async fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) {
        let want = buf.len();
        let mut filled = 0;
        while filled < want {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_read(cx, &mut buf[filled..]))
                .await
                .expect("read_exact: io error");
            assert!(n > 0, "read_exact: unexpected EOF after {filled} of {want} bytes");
            filled += n;
        }
    }

    /// Writes all of `buf` to `stream`.
    ///
    /// Panics on an I/O error or on a zero-length write, which would otherwise spin
    /// this loop forever.
    async fn write_all(stream: &mut TcpStream, buf: &[u8]) {
        let want = buf.len();
        let mut sent = 0;
        while sent < want {
            let n = poll_fn(|cx| Pin::new(&mut *stream).poll_write(cx, &buf[sent..]))
                .await
                .expect("write_all: io error");
            assert!(n > 0, "write_all: zero-length write after {sent} of {want} bytes");
            sent += n;
        }
    }

    /// Shuts down the write half of `stream`, panicking on failure.
    async fn shutdown(stream: &mut TcpStream) {
        poll_fn(|cx| Pin::new(&mut *stream).poll_shutdown(cx))
            .await
            .expect("shutdown failed");
    }

    /// The client sends "hello" and half-closes; the server task receives it.
    /// Despite the name, this is a one-way transfer, not an echo.
    #[test]
    fn tcp_connect_and_echo() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let server = crate::spawn(async move {
                let (mut stream, _peer) = listener.accept().await.unwrap();
                let mut buf = [0u8; 5];
                read_exact(&mut stream, &mut buf).await;
                buf
            });
            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"hello").await;
            shutdown(&mut client).await;
            let received = server.await.unwrap();
            assert_eq!(&received, b"hello");
        });
    }

    /// A client in a spawned task writes; the main task accepts and reads.
    #[test]
    fn tcp_stream_connect_and_write_read() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                write_all(&mut client, b"hello").await;
            });
            let (mut server, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 5];
            read_exact(&mut server, &mut buf).await;
            assert_eq!(&buf, b"hello");
            jh.await.unwrap();
        });
    }

    /// The server echoes 4 bytes back; the client checks it receives what it sent.
    #[test]
    fn tcp_stream_echo_roundtrip() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let (mut conn, _) = listener.accept().await.unwrap();
                let mut buf = [0u8; 4];
                read_exact(&mut conn, &mut buf).await;
                write_all(&mut conn, &buf).await;
            });
            let mut client = TcpStream::connect(addr).await.unwrap();
            write_all(&mut client, b"ping").await;
            let mut buf = [0u8; 4];
            read_exact(&mut client, &mut buf).await;
            assert_eq!(&buf, b"ping");
            jh.await.unwrap();
        });
    }

    /// Connecting to a port with no listener must fail.
    /// This assumes nothing listens on 127.0.0.1:1.
    #[test]
    fn tcp_stream_connect_refused_returns_err() {
        let result = block_on_with_spawn(async {
            TcpStream::connect("127.0.0.1:1".parse().unwrap()).await
        });
        assert!(result.is_err());
    }

    /// `peer_addr` matches the listener's address; `local_addr` is on loopback.
    #[test]
    fn tcp_stream_local_and_peer_addr() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let server_addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move { listener.accept().await.unwrap() });
            let client = TcpStream::connect(server_addr).await.unwrap();
            assert_eq!(client.peer_addr().unwrap(), server_addr);
            assert_eq!(client.local_addr().unwrap().ip().to_string(), "127.0.0.1");
            drop(jh);
        });
    }

    /// A payload larger than a single tiny write arrives intact.
    #[test]
    fn tcp_stream_large_payload() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let payload_size = 4096usize;
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                let data = vec![0xABu8; payload_size];
                write_all(&mut client, &data).await;
            });
            let (mut server, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; payload_size];
            read_exact(&mut server, &mut buf).await;
            assert!(buf.iter().all(|&b| b == 0xAB));
            jh.await.unwrap();
        });
    }

    /// Several connections in turn on the same listener each deliver their own byte.
    #[test]
    fn tcp_stream_multiple_connections_sequential() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            for i in 0u8..3 {
                let a = addr;
                let jh = crate::spawn(async move {
                    let mut client = TcpStream::connect(a).await.unwrap();
                    write_all(&mut client, &[i]).await;
                });
                let (mut server, _) = listener.accept().await.unwrap();
                let mut buf = [0u8; 1];
                read_exact(&mut server, &mut buf).await;
                assert_eq!(buf[0], i);
                jh.await.unwrap();
            }
        });
    }

    /// `poll_shutdown` succeeds, and the peer then observes end-of-stream.
    #[test]
    fn tcp_stream_shutdown_write_half() {
        block_on_with_spawn(async {
            let listener = TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = listener.local_addr().unwrap();
            let jh = crate::spawn(async move {
                let mut client = TcpStream::connect(addr).await.unwrap();
                shutdown(&mut client).await;
            });
            let (mut server, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1];
            let n = poll_fn(|cx| Pin::new(&mut server).poll_read(cx, &mut buf))
                .await
                .expect("read after peer shutdown failed");
            assert_eq!(n, 0, "expected EOF after the peer's SHUT_WR");
            jh.await.unwrap();
        });
    }
}