#[cfg(unix)]
use std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::prelude::{AsRawHandle, IntoRawHandle, RawHandle};
use std::{
cell::UnsafeCell,
future::Future,
io,
net::{SocketAddr, ToSocketAddrs},
time::Duration,
};
use crate::{
buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut},
driver::{op::Op, shared_fd::SharedFd},
io::{
as_fd::{AsReadFd, AsWriteFd, SharedFdWrapper},
AsyncReadRent, AsyncWriteRent, Split,
},
};
/// Zero-length payload used for the probe write in `TcpStream::connect_addr`.
const EMPTY_SLICE: [u8; 0] = [0; 0];
/// A TCP stream whose reads and writes are submitted as driver ops over a shared fd.
pub struct TcpStream {
/// Handle to the underlying socket; every I/O op in this file is issued against it.
fd: SharedFd,
/// Socket-option / address helper built over the same fd. On unix its `Drop`
/// releases the fd with `into_raw_fd` instead of closing it.
meta: StreamMeta,
}
// SAFETY(review): `Split` is an unsafe marker trait from `crate::io`; this impl asserts that
// `TcpStream` upholds its contract — presumably that the read and write halves may be driven
// independently over the same fd. Confirm against the trait's documentation.
unsafe impl Split for TcpStream {}
impl TcpStream {
    /// Wraps an already-created [`SharedFd`] in a `TcpStream`.
    ///
    /// Also builds the [`StreamMeta`] helper over the same fd. With the `zero-copy`
    /// feature it additionally tries to enable `SO_ZEROCOPY` on the socket (Linux only).
    pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
        #[cfg(unix)]
        let meta = StreamMeta::new(fd.raw_fd());
        #[cfg(windows)]
        let meta = StreamMeta::new(fd.raw_handle());
        #[cfg(feature = "zero-copy")]
        meta.set_zero_copy();
        Self { fd, meta }
    }

    /// Opens a TCP connection to the first address that `addr` resolves to.
    ///
    /// Only that first address is attempted.
    ///
    /// # Errors
    /// Returns any address-resolution error, an [`io::ErrorKind::Other`] error
    /// (`"empty address"`) when `addr` resolves to nothing, or any error from
    /// [`TcpStream::connect_addr`].
    pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
        let addr = addr
            .to_socket_addrs()?
            .next()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;
        Self::connect_addr(addr).await
    }

    /// Opens a TCP connection to `addr`.
    ///
    /// After the connect op completes, an empty write is awaited and the socket's pending
    /// error (`SO_ERROR`) is read. A failed connection is therefore reported here rather
    /// than on first use.
    ///
    /// # Errors
    /// Returns the error from creating or completing the connect op, or the socket's
    /// pending error if one is set.
    #[cfg(unix)]
    pub async fn connect_addr(addr: SocketAddr) -> io::Result<Self> {
        let op = Op::connect(libc::SOCK_STREAM, addr)?;
        let completion = op.await;
        completion.meta.result?;

        let mut stream = TcpStream::from_shared_fd(completion.data.fd);
        // Probe write; its result is ignored. Presumably this waits until the socket is
        // writable, i.e. the connection attempt has settled.
        let _ = stream.write(&EMPTY_SLICE).await;

        // Borrow the fd as a std socket only to query SO_ERROR. `ManuallyDrop` guarantees
        // this temporary never closes the fd on any path (including the `?` below);
        // ownership stays with `stream`.
        // SAFETY: the fd is open for as long as `stream` is alive, and the wrapper is never
        // dropped, so no double close can happen.
        let sys_socket = std::mem::ManuallyDrop::new(unsafe {
            std::net::TcpStream::from_raw_fd(stream.fd.raw_fd())
        });
        if let Some(e) = sys_socket.take_error()? {
            return Err(e);
        }
        Ok(stream)
    }

    /// Not yet supported on Windows.
    #[cfg(windows)]
    pub async fn connect_addr(addr: SocketAddr) -> io::Result<Self> {
        unimplemented!()
    }

    /// Returns the local address of the socket. It is cached after the first successful query.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.meta.local_addr()
    }

    /// Returns the peer address of the socket. It is cached after the first successful query.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.meta.peer_addr()
    }

    /// Returns whether `TCP_NODELAY` is set on the socket.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.meta.no_delay()
    }

    /// Sets or clears `TCP_NODELAY` on the socket.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.meta.set_no_delay(nodelay)
    }

    /// Enables TCP keepalive, applying only the parameters that are `Some`.
    ///
    /// `retries` is applied on unix targets only.
    pub fn set_tcp_keepalive(
        &self,
        time: Option<Duration>,
        interval: Option<Duration>,
        retries: Option<u32>,
    ) -> io::Result<()> {
        self.meta.set_tcp_keepalive(time, interval, retries)
    }

    /// Converts a std `TcpStream` into this runtime's `TcpStream`.
    ///
    /// # Errors
    /// Returns the error from `SharedFd::new`. In that case the socket is closed (the std
    /// stream is dropped) instead of being leaked.
    #[cfg(unix)]
    pub fn from_std(stream: std::net::TcpStream) -> io::Result<Self> {
        // Register while `stream` still owns the fd: on failure `?` drops `stream`, which
        // closes the fd. NOTE(review): assumes `SharedFd::new` does not close or keep the
        // fd when it returns an error — confirm.
        let shared = SharedFd::new(stream.as_raw_fd())?;
        // Registration succeeded: hand ownership to `shared` without closing the fd.
        let _ = stream.into_raw_fd();
        Ok(Self::from_shared_fd(shared))
    }
}
impl AsReadFd for TcpStream {
    /// Exposes the stream's shared fd through the crate's read-side wrapper.
    fn as_reader_fd(&mut self) -> &SharedFdWrapper {
        let shared = &self.fd;
        SharedFdWrapper::new(shared)
    }
}
impl AsWriteFd for TcpStream {
    /// Exposes the stream's shared fd through the crate's write-side wrapper.
    fn as_writer_fd(&mut self) -> &SharedFdWrapper {
        let shared = &self.fd;
        SharedFdWrapper::new(shared)
    }
}
#[cfg(unix)]
impl IntoRawFd for TcpStream {
    /// Consumes the stream and returns its raw fd without closing it.
    ///
    /// # Panics
    /// Panics if the underlying `SharedFd` is still referenced elsewhere.
    fn into_raw_fd(self) -> RawFd {
        let Self { fd, .. } = self;
        fd.try_unwrap()
            .expect("unexpected multiple reference to rawfd")
    }
}
#[cfg(unix)]
impl AsRawFd for TcpStream {
fn as_raw_fd(&self) -> RawFd {
self.fd.raw_fd()
}
}
#[cfg(windows)]
impl IntoRawHandle for TcpStream {
    /// Consumes the stream and returns its raw handle without closing it.
    ///
    /// # Panics
    /// Panics if the underlying `SharedFd` is still referenced elsewhere.
    fn into_raw_handle(self) -> RawHandle {
        let Self { fd, .. } = self;
        fd.try_unwrap()
            .expect("unexpected multiple reference to rawfd")
    }
}
#[cfg(windows)]
impl AsRawHandle for TcpStream {
    /// Borrows the raw handle; ownership stays with the stream.
    fn as_raw_handle(&self) -> RawHandle {
        SharedFd::raw_handle(&self.fd)
    }
}
impl std::fmt::Debug for TcpStream {
    /// Formats as `TcpStream { fd: .. }`; the metadata helper is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TcpStream");
        builder.field("fd", &self.fd);
        builder.finish()
    }
}
impl AsyncWriteRent for TcpStream {
    type WriteFuture<'a, B> = impl Future<Output = crate::BufResult<usize, B>> where
        B: IoBuf + 'a;
    type WritevFuture<'a, B> = impl Future<Output = crate::BufResult<usize, B>> where
        B: IoVecBuf + 'a;
    type FlushFuture<'a> = impl Future<Output = io::Result<()>>;
    type ShutdownFuture<'a> = impl Future<Output = io::Result<()>>;

    /// Sends `buf`; the future yields the byte count together with the buffer.
    ///
    /// NOTE(review): a failure to build the op panics. `buf` has already been moved into
    /// `Op::send`, so presumably it cannot be handed back in a `BufResult` — confirm.
    fn write<T: IoBuf>(&mut self, buf: T) -> Self::WriteFuture<'_, T> {
        Op::send(&self.fd, buf).unwrap().write()
    }

    /// Vectored write of `buf_vec`; the future yields the byte count together with the buffers.
    fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> Self::WritevFuture<'_, T> {
        Op::writev(&self.fd, buf_vec).unwrap().write()
    }

    /// No-op: this type holds no user-space write buffer, so there is nothing to flush.
    fn flush(&mut self) -> Self::FlushFuture<'_> {
        async { Ok(()) }
    }

    /// Shuts down the write half (`SHUT_WR`).
    ///
    /// The syscall runs eagerly, when this method is called. The returned future only
    /// yields the result that was already computed.
    #[cfg(unix)]
    fn shutdown(&mut self) -> Self::ShutdownFuture<'_> {
        let raw = self.as_raw_fd();
        // SAFETY: `raw` is the fd owned by this stream and remains open for this call.
        let rc = unsafe { libc::shutdown(raw, libc::SHUT_WR) };
        let res = if rc == -1 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        };
        async move { res }
    }

    /// Not yet supported on Windows.
    #[cfg(windows)]
    fn shutdown(&mut self) -> Self::ShutdownFuture<'_> {
        async { unimplemented!() }
    }
}
impl AsyncReadRent for TcpStream {
    type ReadFuture<'a, B> = impl std::future::Future<Output = crate::BufResult<usize, B>> where
        B: IoBufMut + 'a;
    type ReadvFuture<'a, B> = impl std::future::Future<Output = crate::BufResult<usize, B>> where
        B: IoVecBufMut + 'a;

    /// Receives into `buf`; the future yields the byte count together with the buffer.
    ///
    /// Panics if the receive op cannot be constructed.
    fn read<T: IoBufMut>(&mut self, buf: T) -> Self::ReadFuture<'_, T> {
        Op::recv(&self.fd, buf).unwrap().read()
    }

    /// Vectored read into `buf`; the future yields the byte count together with the buffers.
    ///
    /// Panics if the readv op cannot be constructed.
    fn readv<T: IoVecBufMut>(&mut self, buf: T) -> Self::ReadvFuture<'_, T> {
        Op::readv(&self.fd, buf).unwrap().read()
    }
}
/// Synchronous helper that reads and sets socket options and caches addresses for a
/// stream's fd.
///
/// It wraps the fd in a `socket2::Socket` without owning it: on unix, `Drop` releases the
/// socket via `into_raw_fd`, so the fd is not closed here.
struct StreamMeta {
/// View of the stream's fd; set to `Some` by `new` (unix) and taken in `Drop`.
socket: Option<socket2::Socket>,
/// Address cache, mutated through `&self` in `local_addr` / `peer_addr`.
meta: UnsafeCell<Meta>,
}
/// Cached socket addresses; each field stays `None` until its first successful query.
#[derive(Debug, Default, Clone)]
struct Meta {
/// Filled by the first successful `StreamMeta::local_addr` call.
local_addr: Option<SocketAddr>,
/// Filled by the first successful `StreamMeta::peer_addr` call.
peer_addr: Option<SocketAddr>,
}
impl StreamMeta {
#[cfg(unix)]
fn new(fd: RawFd) -> Self {
Self {
socket: unsafe { Some(socket2::Socket::from_raw_fd(fd)) },
meta: Default::default(),
}
}
#[cfg(windows)]
fn new(fd: RawHandle) -> Self {
unimplemented!()
}
fn local_addr(&self) -> io::Result<SocketAddr> {
let meta = unsafe { &mut *self.meta.get() };
if let Some(addr) = meta.local_addr {
return Ok(addr);
}
let ret = self
.socket
.as_ref()
.unwrap()
.local_addr()
.map(|addr| addr.as_socket().expect("tcp socket is expected"));
if let Ok(addr) = ret {
meta.local_addr = Some(addr);
}
ret
}
fn peer_addr(&self) -> io::Result<SocketAddr> {
let meta = unsafe { &mut *self.meta.get() };
if let Some(addr) = meta.peer_addr {
return Ok(addr);
}
let ret = self
.socket
.as_ref()
.unwrap()
.peer_addr()
.map(|addr| addr.as_socket().expect("tcp socket is expected"));
if let Ok(addr) = ret {
meta.peer_addr = Some(addr);
}
ret
}
fn no_delay(&self) -> io::Result<bool> {
self.socket.as_ref().unwrap().nodelay()
}
fn set_no_delay(&self, no_delay: bool) -> io::Result<()> {
self.socket.as_ref().unwrap().set_nodelay(no_delay)
}
fn set_tcp_keepalive(
&self,
time: Option<Duration>,
interval: Option<Duration>,
retries: Option<u32>,
) -> io::Result<()> {
let mut t = socket2::TcpKeepalive::new();
if let Some(time) = time {
t = t.with_time(time)
}
if let Some(interval) = interval {
t = t.with_interval(interval)
}
#[cfg(unix)]
if let Some(retries) = retries {
t = t.with_retries(retries)
}
self.socket.as_ref().unwrap().set_tcp_keepalive(&t)
}
#[cfg(feature = "zero-copy")]
fn set_zero_copy(&self) {
#[cfg(target_os = "linux")]
unsafe {
let fd = self.socket.as_ref().unwrap().as_raw_fd();
let v: libc::c_int = 1;
libc::setsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_ZEROCOPY,
&v as *const _ as *const _,
std::mem::size_of::<libc::c_int>() as _,
);
}
}
}
impl Drop for StreamMeta {
    /// Releases the wrapped socket without closing its fd.
    ///
    /// The fd belongs to the `SharedFd` this helper was built from (see
    /// `TcpStream::from_shared_fd`). Letting `socket2::Socket` drop normally would close
    /// it a second time.
    fn drop(&mut self) {
        #[cfg(unix)]
        {
            // `if let` instead of `unwrap`: never panic inside `drop`.
            if let Some(socket) = self.socket.take() {
                // Explicitly discard the `#[must_use]` raw fd; ownership stays with the
                // `SharedFd`.
                let _ = socket.into_raw_fd();
            }
        }
        #[cfg(windows)]
        unimplemented!()
    }
}