use std::future::Future;
use std::io::{self, IoSlice};
use std::net::{Shutdown, SocketAddr};
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::task::{self, Poll};
#[cfg(target_os = "linux")]
use log::warn;
use mio::{net, Interest};
use heph::actor;
use socket2::SockRef;
use crate::bytes::{Bytes, BytesVectored, MaybeUninitSlice};
use crate::{self as rt, Bound};
/// A TCP stream between a local and a remote socket.
///
/// The underlying mio socket is registered with an actor's runtime, see
/// [`TcpStream::connect`]. I/O methods come in two flavours: the `try_*`
/// methods make a single attempt on the socket and return its result
/// directly, the other methods return a [`Future`] wrapping that attempt.
#[derive(Debug)]
pub struct TcpStream {
    /// Underlying mio socket.
    pub(in crate::net) socket: net::TcpStream,
}

impl TcpStream {
    /// Starts a TCP connection to `address`.
    ///
    /// The socket is registered with the runtime of `ctx` for both readable
    /// and writable events. The returned [`Connect`] future resolves to the
    /// connected stream, or to the error that made the connection fail.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection can't be started or the socket
    /// can't be registered with the runtime.
    pub fn connect<M, RT>(
        ctx: &mut actor::Context<M, RT>,
        address: SocketAddr,
    ) -> io::Result<Connect>
    where
        RT: rt::Access,
    {
        let mut socket = net::TcpStream::connect(address)?;
        ctx.runtime()
            .register(&mut socket, Interest::READABLE | Interest::WRITABLE)?;
        Ok(Connect {
            socket: Some(socket),
            // CPU the runtime runs on; applied to the stream once connected.
            #[cfg(target_os = "linux")]
            cpu_affinity: ctx.runtime_ref().cpu(),
        })
    }

    /// Returns the socket address of the remote peer of this connection.
    pub fn peer_addr(&mut self) -> io::Result<SocketAddr> {
        self.socket.peer_addr()
    }

    /// Returns the socket address of the local half of this connection.
    pub fn local_addr(&mut self) -> io::Result<SocketAddr> {
        self.socket.local_addr()
    }

    /// Sets the CPU affinity of the socket (`SO_INCOMING_CPU` via socket2).
    #[cfg(target_os = "linux")]
    pub(crate) fn set_cpu_affinity(&mut self, cpu: usize) -> io::Result<()> {
        SockRef::from(&self.socket).set_cpu_affinity(cpu)
    }

    /// Sets the value of the `IP_TTL` option for this socket.
    pub fn set_ttl(&mut self, ttl: u32) -> io::Result<()> {
        self.socket.set_ttl(ttl)
    }

    /// Returns the value of the `IP_TTL` option for this socket.
    pub fn ttl(&mut self) -> io::Result<u32> {
        self.socket.ttl()
    }

    /// Sets the value of the `TCP_NODELAY` option for this socket.
    pub fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
        self.socket.set_nodelay(nodelay)
    }

    /// Returns the value of the `TCP_NODELAY` option for this socket.
    pub fn nodelay(&mut self) -> io::Result<bool> {
        self.socket.nodelay()
    }

    /// Returns whether `SO_KEEPALIVE` is enabled on this socket.
    pub fn keepalive(&self) -> io::Result<bool> {
        let socket = SockRef::from(&self.socket);
        socket.keepalive()
    }

    /// Enables or disables `SO_KEEPALIVE` on this socket.
    pub fn set_keepalive(&self, enable: bool) -> io::Result<()> {
        let socket = SockRef::from(&self.socket);
        socket.set_keepalive(enable)
    }

    /// Attempts to send the bytes in `buf` to the peer.
    ///
    /// Returns the number of bytes written, which may be less than
    /// `buf.len()`. The socket is non-blocking, so if nothing can be sent
    /// right now an error of kind [`io::ErrorKind::WouldBlock`] is returned.
    pub fn try_send(&mut self, buf: &[u8]) -> io::Result<usize> {
        SockRef::from(&self.socket).send(buf)
    }

    /// Sends the bytes in `buf` to the peer.
    ///
    /// The returned future resolves to the number of bytes written, which
    /// may be less than `buf.len()`. Use [`TcpStream::send_all`] to send the
    /// entire buffer.
    pub fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> Send<'a, 'b> {
        Send { stream: self, buf }
    }

    /// Sends all bytes in `buf` to the peer.
    ///
    /// The returned future fails with [`io::ErrorKind::WriteZero`] if the
    /// socket stops accepting bytes before the buffer is fully sent.
    pub fn send_all<'a, 'b>(&'a mut self, buf: &'b [u8]) -> SendAll<'a, 'b> {
        SendAll { stream: self, buf }
    }

    /// Attempts to send the bytes in `bufs` to the peer using vectored I/O.
    ///
    /// Returns the total number of bytes written; see [`TcpStream::try_send`]
    /// for the non-blocking behaviour.
    pub fn try_send_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        SockRef::from(&self.socket).send_vectored(bufs)
    }

    /// Sends the bytes in `bufs` to the peer using vectored I/O.
    ///
    /// The returned future resolves to the number of bytes written, which
    /// may be less than the total length of `bufs`.
    pub fn send_vectored<'a, 'b>(
        &'a mut self,
        bufs: &'b mut [IoSlice<'b>],
    ) -> SendVectored<'a, 'b> {
        SendVectored { stream: self, bufs }
    }

    /// Sends all bytes in `bufs` to the peer using vectored I/O.
    ///
    /// `bufs` is advanced as bytes are written, so its contents are
    /// unspecified once the future completes.
    pub fn send_vectored_all<'a, 'b>(
        &'a mut self,
        bufs: &'b mut [IoSlice<'b>],
    ) -> SendVectoredAll<'a, 'b> {
        SendVectoredAll { stream: self, bufs }
    }

    /// Attempts to receive bytes into the spare capacity of `buf`.
    ///
    /// On success the length of `buf` is increased by the number of bytes
    /// read, which is also returned. `Ok(0)` means no bytes were received,
    /// e.g. because the peer closed the connection.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `buf` has no spare capacity.
    pub fn try_recv<B>(&mut self, mut buf: B) -> io::Result<usize>
    where
        B: Bytes,
    {
        debug_assert!(
            buf.has_spare_capacity(),
            "called `TcpStream::try_recv` with an empty buffer"
        );
        SockRef::from(&self.socket)
            .recv(buf.as_bytes())
            .map(|read| {
                // SAFETY: `recv` wrote `read` bytes into the spare capacity
                // returned by `as_bytes`.
                unsafe { buf.update_length(read) }
                read
            })
    }

    /// Receives bytes into the spare capacity of `buf`.
    ///
    /// The returned future resolves to the number of bytes read.
    pub fn recv<'a, B>(&'a mut self, buf: B) -> Recv<'a, B>
    where
        B: Bytes,
    {
        Recv { stream: self, buf }
    }

    /// Receives at least `n` bytes into the spare capacity of `buf`.
    ///
    /// The returned future fails with [`io::ErrorKind::UnexpectedEof`] if no
    /// more bytes can be read before `n` bytes were received.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the spare capacity of `buf` is smaller than
    /// `n`.
    pub fn recv_n<'a, B>(&'a mut self, buf: B, n: usize) -> RecvN<'a, B>
    where
        B: Bytes,
    {
        debug_assert!(
            buf.spare_capacity() >= n,
            "called `TcpStream::recv_n` with a buffer smaller than `n`"
        );
        RecvN {
            stream: self,
            buf,
            left: n,
        }
    }

    /// Attempts to receive bytes into the spare capacity of `bufs` using
    /// vectored I/O.
    ///
    /// On success the lengths of `bufs` are updated and the total number of
    /// bytes read is returned.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `bufs` has no spare capacity.
    pub fn try_recv_vectored<B>(&mut self, mut bufs: B) -> io::Result<usize>
    where
        B: BytesVectored,
    {
        debug_assert!(
            bufs.has_spare_capacity(),
            "called `TcpStream::try_recv_vectored` with empty buffers"
        );
        // Kept as a separate statement so the temporary returned by
        // `as_bufs` is dropped before `bufs` is borrowed again below.
        let res = SockRef::from(&self.socket)
            .recv_vectored(MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()));
        res.map(|(read, _)| {
            // SAFETY: `recv_vectored` wrote `read` bytes into the buffers
            // returned by `as_bufs`.
            unsafe { bufs.update_lengths(read) }
            read
        })
    }

    /// Receives bytes into the spare capacity of `bufs` using vectored I/O.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `bufs` has no spare capacity.
    pub fn recv_vectored<B>(&mut self, bufs: B) -> RecvVectored<'_, B>
    where
        B: BytesVectored,
    {
        debug_assert!(
            bufs.has_spare_capacity(),
            "called `TcpStream::recv_vectored` with empty buffers"
        );
        RecvVectored { stream: self, bufs }
    }

    /// Receives at least `n` bytes into the spare capacity of `bufs` using
    /// vectored I/O.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the spare capacity of `bufs` is smaller
    /// than `n`.
    pub fn recv_n_vectored<B>(&mut self, bufs: B, n: usize) -> RecvNVectored<'_, B>
    where
        B: BytesVectored,
    {
        debug_assert!(
            bufs.spare_capacity() >= n,
            "called `TcpStream::recv_n_vectored` with a buffer smaller than `n`"
        );
        RecvNVectored {
            stream: self,
            bufs,
            left: n,
        }
    }

    /// Attempts to receive bytes into `buf` without removing them from the
    /// socket's receive queue, so a subsequent receive returns them again.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `buf` has no spare capacity.
    pub fn try_peek<B>(&mut self, mut buf: B) -> io::Result<usize>
    where
        B: Bytes,
    {
        debug_assert!(
            buf.has_spare_capacity(),
            "called `TcpStream::try_peek` with an empty buffer"
        );
        SockRef::from(&self.socket)
            .peek(buf.as_bytes())
            .map(|read| {
                // SAFETY: `peek` wrote `read` bytes into the spare capacity
                // returned by `as_bytes`.
                unsafe { buf.update_length(read) }
                read
            })
    }

    /// Receives bytes into `buf` without removing them from the receive
    /// queue, see [`TcpStream::try_peek`].
    pub fn peek<'a, B>(&'a mut self, buf: B) -> Peek<'a, B>
    where
        B: Bytes,
    {
        Peek { stream: self, buf }
    }

    /// Attempts to peek bytes into `bufs` using vectored I/O (`MSG_PEEK`).
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `bufs` has no spare capacity.
    pub fn try_peek_vectored<B>(&mut self, mut bufs: B) -> io::Result<usize>
    where
        B: BytesVectored,
    {
        debug_assert!(
            bufs.has_spare_capacity(),
            "called `TcpStream::try_peek_vectored` with empty buffers"
        );
        // Separate statement for the same borrow reason as in
        // `try_recv_vectored`.
        let res = SockRef::from(&self.socket).recv_vectored_with_flags(
            MaybeUninitSlice::as_socket2(bufs.as_bufs().as_mut()),
            libc::MSG_PEEK,
        );
        res.map(|(read, _)| {
            // SAFETY: `recv_vectored_with_flags` wrote `read` bytes into the
            // buffers returned by `as_bufs`.
            unsafe { bufs.update_lengths(read) }
            read
        })
    }

    /// Peeks bytes into `bufs` using vectored I/O, see
    /// [`TcpStream::try_peek_vectored`].
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `bufs` has no spare capacity.
    pub fn peek_vectored<B>(&mut self, bufs: B) -> PeekVectored<'_, B>
    where
        B: BytesVectored,
    {
        debug_assert!(
            bufs.has_spare_capacity(),
            "called `TcpStream::peek_vectored` with empty buffers"
        );
        PeekVectored { stream: self, bufs }
    }

    /// Attempts to send bytes of `file`, starting at `offset`, to the peer
    /// using `sendfile(2)` (via socket2).
    ///
    /// At most `length` bytes are sent; `None` means until the end of the
    /// file. Returns the number of bytes sent.
    pub fn try_send_file<F>(
        &mut self,
        file: &F,
        offset: usize,
        length: Option<NonZeroUsize>,
    ) -> io::Result<usize>
    where
        F: FileSend,
    {
        SockRef::from(&self.socket).sendfile(file, offset, length)
    }

    /// Sends bytes of `file` to the peer, see [`TcpStream::try_send_file`].
    ///
    /// The returned future resolves to the number of bytes sent, which may
    /// be less than requested.
    pub fn send_file<'a, 'f, F>(
        &'a mut self,
        file: &'f F,
        offset: usize,
        length: Option<NonZeroUsize>,
    ) -> SendFile<'a, 'f, F>
    where
        F: FileSend,
    {
        SendFile {
            stream: self,
            file,
            offset,
            length,
        }
    }

    /// Sends `length` bytes of `file`, starting at `offset`, to the peer.
    ///
    /// If `length` is `None` the file is sent until its end. The future also
    /// completes once the end of the file is reached.
    pub fn send_file_all<'a, 'f, F>(
        &'a mut self,
        file: &'f F,
        offset: usize,
        length: Option<NonZeroUsize>,
    ) -> SendFileAll<'a, 'f, F>
    where
        F: FileSend,
    {
        SendFileAll {
            stream: self,
            file,
            start: offset,
            // If `offset + length` doesn't fit in `usize`, no file can be
            // that large, so fall back to sending until the end of the file
            // instead of overflowing.
            end: length.and_then(|length| length.checked_add(offset)),
        }
    }

    /// Sends the entire `file` to the peer, see [`TcpStream::send_file_all`].
    pub fn send_entire_file<'a, 'f, F>(&'a mut self, file: &'f F) -> SendFileAll<'a, 'f, F>
    where
        F: FileSend,
    {
        self.send_file_all(file, 0, None)
    }

    /// Shuts down the read half, write half or both halves of the connection.
    pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
        self.socket.shutdown(how)
    }

    /// Returns and clears the pending socket error (`SO_ERROR`), if any.
    pub fn take_error(&mut self) -> io::Result<Option<io::Error>> {
        self.socket.take_error()
    }
}
/// [`Future`] returned by [`TcpStream::connect`], resolving to the connected
/// [`TcpStream`] or the error that made the connection fail.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Connect {
    /// Socket that is connecting; `None` once the future has completed.
    socket: Option<net::TcpStream>,
    /// CPU to set as the stream's affinity once connected, if any.
    #[cfg(target_os = "linux")]
    cpu_affinity: Option<usize>,
}

impl Future for Connect {
    type Output = io::Result<TcpStream>;

    #[track_caller]
    fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let Some(socket) = self.socket.take() else {
            panic!("polled `tcp::stream::Connect` after completion");
        };

        // A failed connection attempt is reported as the pending socket error.
        match socket.take_error() {
            Ok(None) => {}
            Ok(Some(err)) | Err(err) => return Poll::Ready(Err(err)),
        }

        // Getting the peer's address only succeeds once connected.
        let err = match socket.peer_addr() {
            Ok(_) => {
                #[allow(unused_mut)] // Only mutated when targeting Linux.
                let mut stream = TcpStream { socket };
                #[cfg(target_os = "linux")]
                if let Some(cpu) = self.cpu_affinity {
                    if let Err(err) = stream.set_cpu_affinity(cpu) {
                        warn!("failed to set CPU affinity on TcpStream: {}", err);
                    }
                }
                return Poll::Ready(Ok(stream));
            }
            Err(err) => err,
        };

        let still_connecting = err.kind() == io::ErrorKind::NotConnected
            || err.raw_os_error() == Some(libc::EINPROGRESS);
        if still_connecting {
            // Keep the socket around for the next poll.
            self.socket = Some(socket);
            Poll::Pending
        } else {
            Poll::Ready(Err(err))
        }
    }
}
/// [`Future`] behind [`TcpStream::send`], resolving to the number of bytes
/// written.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Send<'a, 'b> {
    stream: &'a mut TcpStream,
    buf: &'b [u8],
}

impl<'a, 'b> Future for Send<'a, 'b> {
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        // `try_io!` (defined elsewhere in the crate) turns the attempt's
        // result into a `Poll`; presumably `WouldBlock` maps to `Pending`.
        try_io!(this.stream.try_send(this.buf))
    }
}
/// [`Future`] behind [`TcpStream::send_all`], resolving once every byte of
/// `buf` has been written.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendAll<'a, 'b> {
    stream: &'a mut TcpStream,
    /// Bytes still to send; shrinks as bytes are written.
    buf: &'b [u8],
}

impl<'a, 'b> Future for SendAll<'a, 'b> {
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let SendAll { stream, buf } = Pin::into_inner(self);
        // Like `Write::write_all` (and `SendVectoredAll`), an empty buffer is
        // trivially sent instead of being reported as `WriteZero`.
        while !buf.is_empty() {
            match stream.try_send(*buf) {
                // The socket accepted no bytes although we had some to send.
                Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                // `min` guards against a count larger than the buffer.
                Ok(n) => *buf = &buf[n.min(buf.len())..],
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
                    return Poll::Pending
                }
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
        Poll::Ready(Ok(()))
    }
}
/// [`Future`] behind [`TcpStream::send_vectored`], resolving to the number of
/// bytes written.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendVectored<'a, 'b> {
    stream: &'a mut TcpStream,
    bufs: &'b mut [IoSlice<'b>],
}

impl<'a, 'b> Future for SendVectored<'a, 'b> {
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        // Single vectored send attempt, converted into a `Poll` by `try_io!`.
        try_io!(this.stream.try_send_vectored(&*this.bufs))
    }
}
/// [`Future`] behind [`TcpStream::send_vectored_all`], resolving once every
/// byte in `bufs` has been written.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendVectoredAll<'a, 'b> {
    stream: &'a mut TcpStream,
    /// Buffers still to send; advanced as bytes are written.
    bufs: &'b mut [IoSlice<'b>],
}

impl<'a, 'b> Future for SendVectoredAll<'a, 'b> {
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let SendVectoredAll { stream, bufs } = Pin::into_inner(self);
        // Drop leading empty slices so `bufs` is empty exactly when there is
        // nothing left to send. Otherwise buffers holding only empty slices
        // would make the send return `Ok(0)`, which is reported as
        // `WriteZero`. Mirrors `std::io::Write::write_all_vectored`.
        IoSlice::advance_slices(bufs, 0);
        while !bufs.is_empty() {
            match stream.try_send_vectored(bufs) {
                Ok(0) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                Ok(n) => IoSlice::advance_slices(bufs, n),
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
        Poll::Ready(Ok(()))
    }
}
/// [`Future`] behind [`TcpStream::recv`], resolving to the number of bytes
/// read into the buffer.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Recv<'b, B> {
    stream: &'b mut TcpStream,
    buf: B,
}

impl<'b, B> Future for Recv<'b, B>
where
    B: Bytes + Unpin,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        // The buffer is passed by mutable reference so it stays in the future.
        try_io!(this.stream.try_recv(&mut this.buf))
    }
}
/// [`Future`] behind [`TcpStream::peek`], resolving to the number of bytes
/// peeked into the buffer.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Peek<'b, B> {
    stream: &'b mut TcpStream,
    buf: B,
}

impl<'b, B> Future for Peek<'b, B>
where
    B: Bytes + Unpin,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        // The buffer is passed by mutable reference so it stays in the future.
        try_io!(this.stream.try_peek(&mut this.buf))
    }
}
/// [`Future`] behind [`TcpStream::recv_n`], resolving once at least `n` bytes
/// have been read into the buffer.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvN<'b, B> {
    stream: &'b mut TcpStream,
    buf: B,
    /// Number of bytes still to read.
    left: usize,
}

impl<'b, B> Future for RecvN<'b, B>
where
    B: Bytes + Unpin,
{
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let RecvN { stream, buf, left } = Pin::into_inner(self);
        // Like `Read::read_exact`, asking for zero bytes succeeds right away
        // instead of touching the socket, which could fail with
        // `UnexpectedEof` on a full buffer.
        while *left != 0 {
            match stream.try_recv(&mut *buf) {
                // No bytes read: treated as the end of the stream.
                Ok(0) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
                Ok(n) => *left = left.saturating_sub(n),
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
        Poll::Ready(Ok(()))
    }
}
/// [`Future`] behind [`TcpStream::recv_vectored`], resolving to the number of
/// bytes read into the buffers.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvVectored<'b, B> {
    stream: &'b mut TcpStream,
    bufs: B,
}

impl<'b, B> Future for RecvVectored<'b, B>
where
    B: BytesVectored + Unpin,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        // The buffers are passed by mutable reference so they stay in the future.
        try_io!(this.stream.try_recv_vectored(&mut this.bufs))
    }
}
/// [`Future`] behind [`TcpStream::recv_n_vectored`], resolving once at least
/// `n` bytes have been read into the buffers.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RecvNVectored<'b, B> {
    stream: &'b mut TcpStream,
    bufs: B,
    /// Number of bytes still to read.
    left: usize,
}

impl<'b, B> Future for RecvNVectored<'b, B>
where
    B: BytesVectored + Unpin,
{
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let RecvNVectored { stream, bufs, left } = Pin::into_inner(self);
        // Asking for zero bytes succeeds right away, like `Read::read_exact`.
        while *left != 0 {
            match stream.try_recv_vectored(&mut *bufs) {
                // No bytes read: treated as the end of the stream.
                Ok(0) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
                Ok(n) => *left = left.saturating_sub(n),
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
        Poll::Ready(Ok(()))
    }
}
/// [`Future`] behind [`TcpStream::peek_vectored`], resolving to the number of
/// bytes peeked into the buffers.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PeekVectored<'b, B> {
    stream: &'b mut TcpStream,
    bufs: B,
}

impl<'b, B> Future for PeekVectored<'b, B>
where
    B: BytesVectored + Unpin,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        // The buffers are passed by mutable reference so they stay in the future.
        try_io!(this.stream.try_peek_vectored(&mut this.bufs))
    }
}
/// [`Future`] behind [`TcpStream::send_file`], resolving to the number of
/// bytes of the file sent.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendFile<'a, 'f, F> {
    stream: &'a mut TcpStream,
    file: &'f F,
    offset: usize,
    length: Option<NonZeroUsize>,
}

impl<'a, 'f, F> Future for SendFile<'a, 'f, F>
where
    F: FileSend,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        // All arguments besides the stream are `Copy`, so they're reused as-is.
        try_io!(this.stream.try_send_file(this.file, this.offset, this.length))
    }
}
/// [`Future`] behind [`TcpStream::send_file_all`], resolving once the
/// requested range of the file (or the rest of the file) has been sent.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SendFileAll<'a, 'f, F> {
    stream: &'a mut TcpStream,
    file: &'f F,
    /// Offset of the next byte to send; advanced as bytes are sent.
    start: usize,
    /// Exclusive end offset, or `None` to send until the end of the file.
    end: Option<NonZeroUsize>,
}

impl<'a, 'f, F> Future for SendFileAll<'a, 'f, F>
where
    F: FileSend,
{
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
        #[rustfmt::skip]
        let SendFileAll { stream, file, start, end } = Pin::into_inner(self);
        loop {
            let length = match *end {
                // Check for completion before subtracting so `end - start`
                // can't underflow. A zero length must not reach
                // `try_send_file`, because `None` means "until end of file".
                Some(end) => match end.get().checked_sub(*start).and_then(NonZeroUsize::new) {
                    Some(length) => Some(length),
                    None => break Poll::Ready(Ok(())),
                },
                None => None,
            };
            match stream.try_send_file(*file, *start, length) {
                // Reached the end of the file.
                Ok(0) => break Poll::Ready(Ok(())),
                Ok(n) => *start += n,
                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => break Poll::Pending,
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => break Poll::Ready(Err(err)),
            }
        }
    }
}
/// Files that can be sent using [`TcpStream::send_file`] and friends.
///
/// This trait is sealed: its supertrait lives in a private module, so it can
/// only be implemented in this crate. Currently only [`std::fs::File`]
/// implements it.
pub trait FileSend: PrivateFileSend {}

use private::PrivateFileSend;

mod private {
    use std::fs::File;
    use std::os::unix::io::AsRawFd;

    /// Sealing supertrait of `FileSend`; requires a raw file descriptor.
    pub trait PrivateFileSend: AsRawFd {}

    impl PrivateFileSend for File {}
    impl super::FileSend for File {}
}
/// Re-registers the stream's socket with the runtime reached through the
/// given actor context, for both readable and writable events.
impl<RT> Bound<RT> for TcpStream
where
    RT: rt::Access,
{
    type Error = io::Error;

    fn bind_to<M>(&mut self, ctx: &mut actor::Context<M, RT>) -> io::Result<()> {
        let interests = Interest::READABLE | Interest::WRITABLE;
        ctx.runtime().reregister(&mut self.socket, interests)
    }
}