use std::io::{IoSlice, IoSliceMut};
use std::mem;
use std::net::{self, SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use cfg_if::cfg_if;
use futures::future;
use futures::io::{AsyncRead, AsyncWrite};
use crate::io;
use crate::net::driver::IoHandle;
use crate::task::{Context, Poll};
/// A TCP stream whose I/O goes through an `IoHandle` wrapping a
/// `mio::net::TcpStream`.
///
/// Values are created by `TcpStream::connect`, by converting a
/// `std::net::TcpStream` with `From`, or (on Unix) from a raw file descriptor.
#[derive(Debug)]
pub struct TcpStream {
/// Handle around the underlying mio TCP stream; all reads, writes and socket
/// option calls in this file are forwarded through it.
pub(super) io_handle: IoHandle<mio::net::TcpStream>,
/// Raw file descriptor of the stream (Unix only). The constructors in this
/// file fill it from `as_raw_fd()` on the mio stream before handing the
/// stream to the `IoHandle`.
#[cfg(unix)]
pub(super) raw_fd: std::os::unix::io::RawFd,
}
impl TcpStream {
pub async fn connect<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpStream> {
enum State {
Waiting(TcpStream),
Error(io::Error),
Done,
}
let mut last_err = None;
for addr in addrs.to_socket_addrs()? {
let mut state = {
match mio::net::TcpStream::connect(&addr) {
Ok(mio_stream) => {
#[cfg(unix)]
let stream = TcpStream {
raw_fd: mio_stream.as_raw_fd(),
io_handle: IoHandle::new(mio_stream),
};
#[cfg(windows)]
let stream = TcpStream {
io_handle: IoHandle::new(mio_stream),
};
State::Waiting(stream)
}
Err(err) => State::Error(err),
}
};
let res = future::poll_fn(|cx| {
match mem::replace(&mut state, State::Done) {
State::Waiting(stream) => {
if let Poll::Pending = stream.io_handle.poll_writable(cx)? {
state = State::Waiting(stream);
return Poll::Pending;
}
if let Some(err) = stream.io_handle.get_ref().take_error()? {
return Poll::Ready(Err(err));
}
Poll::Ready(Ok(stream))
}
State::Error(err) => Poll::Ready(Err(err)),
State::Done => panic!("`TcpStream::connect()` future polled after completion"),
}
})
.await;
match res {
Ok(stream) => return Ok(stream),
Err(err) => last_err = Some(err),
}
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
)
}))
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().local_addr()
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().peer_addr()
}
pub fn ttl(&self) -> io::Result<u32> {
self.io_handle.get_ref().ttl()
}
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.io_handle.get_ref().set_ttl(ttl)
}
pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
let res = future::poll_fn(|cx| {
futures::ready!(self.io_handle.poll_readable(cx)?);
match self.io_handle.get_ref().peek(buf) {
Ok(len) => Poll::Ready(Ok(len)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
})
.await?;
Ok(res)
}
pub fn nodelay(&self) -> io::Result<bool> {
self.io_handle.get_ref().nodelay()
}
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.io_handle.get_ref().set_nodelay(nodelay)
}
pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
self.io_handle.get_ref().shutdown(how)
}
}
/// Reading from an owned `TcpStream` delegates to the `AsyncRead`
/// implementation for `&TcpStream`.
impl AsyncRead for TcpStream {
    /// Forwards to `poll_read` on a pinned shared reference to this stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let mut shared: &TcpStream = &*self;
        Pin::new(&mut shared).poll_read(cx, buf)
    }

    /// Forwards to `poll_read_vectored` on a pinned shared reference to this
    /// stream.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let mut shared: &TcpStream = &*self;
        Pin::new(&mut shared).poll_read_vectored(cx, bufs)
    }
}
/// Reading through a shared reference forwards to the stream's `io_handle`.
impl AsyncRead for &TcpStream {
    /// Forwards to `poll_read` on a pinned reference to `io_handle`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let mut handle = &self.io_handle;
        Pin::new(&mut handle).poll_read(cx, buf)
    }

    /// Forwards to `poll_read_vectored` on a pinned reference to `io_handle`.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let mut handle = &self.io_handle;
        Pin::new(&mut handle).poll_read_vectored(cx, bufs)
    }
}
/// Writing to an owned `TcpStream` delegates to the `AsyncWrite`
/// implementation for `&TcpStream`.
impl AsyncWrite for TcpStream {
    /// Forwards to `poll_write` on a pinned shared reference to this stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut shared: &TcpStream = &*self;
        Pin::new(&mut shared).poll_write(cx, buf)
    }

    /// Forwards to `poll_write_vectored` on a pinned shared reference to this
    /// stream.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let mut shared: &TcpStream = &*self;
        Pin::new(&mut shared).poll_write_vectored(cx, bufs)
    }

    /// Forwards to `poll_flush` on a pinned shared reference to this stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut shared: &TcpStream = &*self;
        Pin::new(&mut shared).poll_flush(cx)
    }

    /// Forwards to `poll_close` on a pinned shared reference to this stream.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut shared: &TcpStream = &*self;
        Pin::new(&mut shared).poll_close(cx)
    }
}
/// Writing through a shared reference forwards to the stream's `io_handle`.
impl AsyncWrite for &TcpStream {
    /// Forwards to `poll_write` on a pinned reference to `io_handle`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut handle = &self.io_handle;
        Pin::new(&mut handle).poll_write(cx, buf)
    }

    /// Forwards to `poll_write_vectored` on a pinned reference to `io_handle`.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let mut handle = &self.io_handle;
        Pin::new(&mut handle).poll_write_vectored(cx, bufs)
    }

    /// Forwards to `poll_flush` on a pinned reference to `io_handle`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut handle = &self.io_handle;
        Pin::new(&mut handle).poll_flush(cx)
    }

    /// Forwards to `poll_close` on a pinned reference to `io_handle`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut handle = &self.io_handle;
        Pin::new(&mut handle).poll_close(cx)
    }
}
/// Converts a `std::net::TcpStream` into this crate's `TcpStream`.
impl From<net::TcpStream> for TcpStream {
    /// Wraps `stream` in a mio TCP stream and then in an `IoHandle`.
    ///
    /// # Panics
    ///
    /// Panics if `mio::net::TcpStream::from_stream` returns an error.
    fn from(stream: net::TcpStream) -> TcpStream {
        let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap();

        // `raw_fd` is listed first so the descriptor is read before the mio
        // stream is moved into the handle.
        TcpStream {
            #[cfg(unix)]
            raw_fd: mio_stream.as_raw_fd(),
            io_handle: IoHandle::new(mio_stream),
        }
    }
}
cfg_if! {
if #[cfg(feature = "docs")] {
use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
} else if #[cfg(unix)] {
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
} else if #[cfg(windows)] {
}
}
#[cfg_attr(feature = "docs", doc(cfg(unix)))]
cfg_if! {
    if #[cfg(any(unix, feature = "docs"))] {
        impl AsRawFd for TcpStream {
            /// Returns the descriptor recorded when the stream was created.
            /// Ownership stays with the `TcpStream`.
            fn as_raw_fd(&self) -> RawFd {
                self.raw_fd
            }
        }

        impl FromRawFd for TcpStream {
            /// Builds a `TcpStream` from `fd` by first wrapping it in a
            /// `std::net::TcpStream` and then converting that with `From`.
            ///
            /// # Safety
            ///
            /// The caller must uphold the contract of
            /// `std::net::TcpStream::from_raw_fd` for `fd`.
            unsafe fn from_raw_fd(fd: RawFd) -> TcpStream {
                net::TcpStream::from_raw_fd(fd).into()
            }
        }

        impl IntoRawFd for TcpStream {
            /// Consumes the stream and hands ownership of its descriptor to
            /// the caller.
            fn into_raw_fd(self) -> RawFd {
                let fd = self.raw_fd;
                // Dropping `self` would drop the mio stream that owns the
                // descriptor, so the value returned here would already be
                // closed. Forgetting `self` keeps the descriptor open for the
                // caller.
                //
                // NOTE(review): this also skips the `IoHandle` destructor, so
                // any cleanup it does never runs for this stream. If
                // `IoHandle` has a way to take back the inner mio stream,
                // prefer calling `into_raw_fd()` on that instead.
                std::mem::forget(self);
                fd
            }
        }
    }
}
#[cfg_attr(feature = "docs", doc(cfg(windows)))]
cfg_if! {
if #[cfg(any(windows, feature = "docs"))] {
}
}