#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::{
io,
net::SocketAddr,
ops::{Deref, DerefMut},
pin::Pin,
task::{self, Poll},
};
use futures::{future, ready};
use pin_project::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream},
};
use crate::{context::Context, relay::socks5::Address, ServerAddr};
use super::{
is_dual_stack_addr,
sys::{
create_inbound_tcp_socket,
set_common_sockopt_after_accept,
set_tcp_fastopen,
socket_bind_dual_stack,
TcpStream as SysTcpStream,
},
AcceptOpts,
ConnectOpts,
};
/// Newtype wrapper around `sys::TcpStream`.
///
/// The single field is marked `#[pin]`, so the `AsyncRead` / `AsyncWrite`
/// implementations can obtain a pinned reference to the wrapped stream through
/// the projection generated by `pin_project`.
#[pin_project]
pub struct TcpStream(#[pin] SysTcpStream);
impl TcpStream {
    /// Connects to `addr` through `SysTcpStream::connect`, forwarding `opts`,
    /// and wraps the resulting stream.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying connect call reports.
    pub async fn connect_with_opts(addr: &SocketAddr, opts: &ConnectOpts) -> io::Result<TcpStream> {
        let stream = SysTcpStream::connect(*addr, opts).await?;
        Ok(TcpStream(stream))
    }

    /// Connects to a server described by a [`ServerAddr`].
    ///
    /// A `SocketAddr` variant is connected to directly. A `DomainName` variant is
    /// handed to the `lookup_then_connect!` macro (defined elsewhere in the crate)
    /// together with `context`; the second element of the macro's successful
    /// result is used as the connected stream.
    ///
    /// # Errors
    ///
    /// Propagates the error from the direct connect or from the macro expression.
    pub async fn connect_server_with_opts(
        context: &Context,
        addr: &ServerAddr,
        opts: &ConnectOpts,
    ) -> io::Result<TcpStream> {
        match *addr {
            ServerAddr::SocketAddr(ref sock_addr) => {
                let stream = SysTcpStream::connect(*sock_addr, opts).await?;
                Ok(TcpStream(stream))
            }
            ServerAddr::DomainName(ref domain, port) => {
                let connected = lookup_then_connect!(context, domain, port, |addr| {
                    SysTcpStream::connect(addr, opts).await
                })?;
                // Only the second tuple element (the stream) is kept.
                Ok(TcpStream(connected.1))
            }
        }
    }

    /// Connects to a remote target described by a socks5 [`Address`].
    ///
    /// A `SocketAddress` variant is connected to directly. A `DomainNameAddress`
    /// variant is handed to the `lookup_then_connect!` macro (defined elsewhere in
    /// the crate) together with `context`; the second element of the macro's
    /// successful result is used as the connected stream.
    ///
    /// # Errors
    ///
    /// Propagates the error from the direct connect or from the macro expression.
    pub async fn connect_remote_with_opts(
        context: &Context,
        addr: &Address,
        opts: &ConnectOpts,
    ) -> io::Result<TcpStream> {
        match *addr {
            Address::SocketAddress(ref sock_addr) => {
                let stream = SysTcpStream::connect(*sock_addr, opts).await?;
                Ok(TcpStream(stream))
            }
            Address::DomainNameAddress(ref domain, port) => {
                let connected = lookup_then_connect!(context, domain, port, |addr| {
                    SysTcpStream::connect(addr, opts).await
                })?;
                // Only the second tuple element (the stream) is kept.
                Ok(TcpStream(connected.1))
            }
        }
    }

    /// Forwards to `local_addr` on the wrapped stream.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let inner = &self.0;
        inner.local_addr()
    }

    /// Forwards to `peer_addr` on the wrapped stream.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        let inner = &self.0;
        inner.peer_addr()
    }

    /// Forwards to `nodelay` on the wrapped stream.
    pub fn nodelay(&self) -> io::Result<bool> {
        let inner = &self.0;
        inner.nodelay()
    }

    /// Forwards to `set_nodelay` on the wrapped stream with the given flag.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        let inner = &self.0;
        inner.set_nodelay(nodelay)
    }
}
impl AsyncRead for TcpStream {
    /// Delegates the read to the pinned inner stream obtained via projection.
    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().0;
        inner.poll_read(cx, buf)
    }
}
impl AsyncWrite for TcpStream {
    /// Delegates the write to the pinned inner stream obtained via projection.
    fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        let inner = self.project().0;
        inner.poll_write(cx, buf)
    }

    /// Delegates the flush to the pinned inner stream obtained via projection.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().0;
        inner.poll_flush(cx)
    }

    /// Delegates the shutdown to the pinned inner stream obtained via projection.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
        let inner = self.project().0;
        inner.poll_shutdown(cx)
    }
}
/// A Tokio TCP listener paired with the `AcceptOpts` it was created with.
///
/// The stored options are passed to `set_common_sockopt_after_accept` for each
/// stream returned by `poll_accept` / `accept`.
pub struct TcpListener {
// The wrapped Tokio listener; also exposed through `Deref` / `DerefMut`.
inner: TokioTcpListener,
// Options consulted at bind time and again after every accepted connection.
accept_opts: AcceptOpts,
}
impl TcpListener {
    /// Creates a listener bound to `addr`, configured from `accept_opts`.
    ///
    /// Steps, in order: create the socket with `create_inbound_tcp_socket`, apply
    /// the optional send/receive buffer sizes, enable address reuse (skipped on
    /// Windows), bind (through `socket_bind_dual_stack` when `is_dual_stack_addr`
    /// says so, otherwise a plain bind), listen with a backlog of 1024, and
    /// finally call `set_tcp_fastopen` on the listener when
    /// `accept_opts.tcp.fastopen` is set.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any of the steps above.
    pub async fn bind_with_opts(addr: &SocketAddr, accept_opts: AcceptOpts) -> io::Result<TcpListener> {
        let socket = create_inbound_tcp_socket(addr, &accept_opts).await?;

        // Buffer sizes are only touched when explicitly configured.
        if let Some(send_size) = accept_opts.tcp.send_buffer_size {
            socket.set_send_buffer_size(send_size)?;
        }
        if let Some(recv_size) = accept_opts.tcp.recv_buffer_size {
            socket.set_recv_buffer_size(recv_size)?;
        }

        // Address reuse is requested on every target except Windows.
        #[cfg(not(windows))]
        socket.set_reuseaddr(true)?;

        if is_dual_stack_addr(addr) {
            socket_bind_dual_stack(&socket, addr, accept_opts.ipv6_only)?;
        } else {
            socket.bind(*addr)?;
        }

        let listener = socket.listen(1024)?;

        // Fast open is configured on the listener, i.e. after `listen`.
        if accept_opts.tcp.fastopen {
            set_tcp_fastopen(&listener)?;
        }

        Ok(TcpListener {
            inner: listener,
            accept_opts,
        })
    }

    /// Wraps an already-created Tokio listener.
    ///
    /// When `accept_opts.tcp.fastopen` is set, `set_tcp_fastopen` is called on
    /// `listener` first and its error, if any, is returned.
    pub fn from_listener(listener: TokioTcpListener, accept_opts: AcceptOpts) -> io::Result<TcpListener> {
        let wants_fastopen = accept_opts.tcp.fastopen;
        if wants_fastopen {
            set_tcp_fastopen(&listener)?;
        }

        let wrapped = TcpListener {
            inner: listener,
            accept_opts,
        };
        Ok(wrapped)
    }

    /// Polls the inner listener for an incoming connection.
    ///
    /// Once a connection is ready, `set_common_sockopt_after_accept` is applied to
    /// it with the stored accept options before it is handed back. A failure of
    /// either the accept or the option call is returned as `Poll::Ready(Err(..))`.
    pub fn poll_accept(&self, cx: &mut task::Context<'_>) -> Poll<io::Result<(TokioTcpStream, SocketAddr)>> {
        let accepted = ready!(self.inner.poll_accept(cx));
        let (stream, peer_addr) = accepted?;

        set_common_sockopt_after_accept(&stream, &self.accept_opts)?;

        Poll::Ready(Ok((stream, peer_addr)))
    }

    /// Awaits the next incoming connection by driving [`TcpListener::poll_accept`].
    pub async fn accept(&self) -> io::Result<(TokioTcpStream, SocketAddr)> {
        let accepting = future::poll_fn(|cx| self.poll_accept(cx));
        accepting.await
    }

    /// Consumes the wrapper and returns the inner Tokio listener.
    pub fn into_inner(self) -> TokioTcpListener {
        let Self { inner, .. } = self;
        inner
    }
}
impl Deref for TcpListener {
    type Target = TokioTcpListener;

    /// Borrows the wrapped Tokio listener.
    fn deref(&self) -> &TokioTcpListener {
        let Self { inner, .. } = self;
        inner
    }
}
impl DerefMut for TcpListener {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl From<TcpListener> for TokioTcpListener {
fn from(listener: TcpListener) -> TokioTcpListener {
listener.inner
}
}
#[cfg(unix)]
impl AsRawFd for TcpStream {
    /// Returns the raw file descriptor reported by the wrapped stream.
    fn as_raw_fd(&self) -> RawFd {
        let inner = &self.0;
        inner.as_raw_fd()
    }
}
#[cfg(windows)]
impl AsRawSocket for TcpStream {
    /// Returns the raw socket handle reported by the wrapped stream.
    fn as_raw_socket(&self) -> RawSocket {
        let inner = &self.0;
        inner.as_raw_socket()
    }
}