use crate::config::{TcpConfig, TransportConfig};
use crate::constants::MESSAGE_TIMEOUT_SECS;
use super::{AddrMaybeCached, ProtobufStream, SocketOpts, Transport};
pub use crate::unix_tcp::{Listener, NamedSocketAddr, SocketAddr, Stream};
use crate::utils::host_port_pair;
use anyhow::{Context as _, Result};
use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth};
use async_trait::async_trait;
use socket2::{SockRef, TcpKeepalive};
#[cfg(unix)]
use std::os::fd::RawFd;
use std::str::FromStr;
use std::time::Duration;
/// Raw stream type wrapped by `TcpStream`: an alias for the TCP-or-Unix `Stream`
/// re-exported from `crate::unix_tcp`.
type RawTcpStream = Stream;
use crate::protocol::message::Message as ProtocolMessage;
use crate::protocol::{read_message, write_message};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::time::timeout;
use tracing::trace;
use url::Url;
/// Transport-level stream that wraps a raw `Stream` (TCP, or a Unix-domain socket
/// on Unix) and exposes message send/receive through `ProtobufStream`.
#[derive(Debug)]
pub struct TcpStream {
    inner: RawTcpStream,
}

impl TcpStream {
    /// Wraps an already-connected raw stream.
    pub fn new(stream: RawTcpStream) -> Self {
        TcpStream { inner: stream }
    }

    /// Consumes the wrapper and hands back the raw stream.
    pub fn into_inner(self) -> RawTcpStream {
        let TcpStream { inner } = self;
        inner
    }

    /// Borrows the underlying raw stream.
    pub fn get_ref(&self) -> &RawTcpStream {
        &self.inner
    }

    /// Mutably borrows the underlying raw stream.
    pub fn get_mut(&mut self) -> &mut RawTcpStream {
        &mut self.inner
    }

    /// Consumes the wrapper and hands back the raw stream; equivalent to
    /// [`TcpStream::into_inner`].
    pub fn into_stream(self) -> Stream {
        self.into_inner()
    }
}
/// Reads are delegated unchanged to the wrapped raw stream.
impl AsyncRead for TcpStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // The wrapper is `Unpin` whenever the inner stream is, so the pin can be
        // projected to a plain `&mut` and re-pinned on the field.
        let this: &mut TcpStream = Pin::get_mut(self);
        let inner = Pin::new(&mut this.inner);
        inner.poll_read(cx, buf)
    }
}
/// Writes are delegated unchanged to the wrapped raw stream.
///
/// Vectored writes are forwarded as well. Without these overrides the trait's
/// default `poll_write_vectored` writes only the first non-empty buffer, and the
/// default `is_write_vectored` reports `false`, which hides any vectored-write
/// support the inner stream has.
impl AsyncWrite for TcpStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}
#[async_trait]
impl ProtobufStream for TcpStream {
async fn recv_message(&mut self) -> anyhow::Result<Option<ProtocolMessage>> {
let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
match timeout(timeout_duration, read_message(&mut self.inner)).await {
Ok(Ok(msg)) => Ok(Some(msg)),
Ok(Err(e)) => {
if let Some(io_err) = e.downcast_ref::<std::io::Error>() {
if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
return Ok(None);
}
}
Err(e)
}
Err(_) => Err(anyhow::anyhow!(
"Timeout reading message after {} seconds",
MESSAGE_TIMEOUT_SECS
)),
}
}
async fn send_message(&mut self, msg: &ProtocolMessage) -> anyhow::Result<()> {
let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
timeout(timeout_duration, write_message(&mut self.inner, msg))
.await
.map_err(|_| {
anyhow::anyhow!(
"Timeout writing message after {} seconds",
MESSAGE_TIMEOUT_SECS
)
})?
}
async fn close(&mut self) -> anyhow::Result<()> {
self.inner
.shutdown()
.await
.context("Failed to shutdown stream")
}
}
#[derive(Debug)]
pub struct TcpTransport {
pub socket_opts: SocketOpts,
pub cfg: TcpConfig,
}
#[async_trait]
impl Transport for TcpTransport {
type Acceptor = Listener;
type Stream = TcpStream;
type RawStream = Stream;
fn new(config: &TransportConfig) -> Result<Self> {
Ok(TcpTransport {
socket_opts: SocketOpts::for_control_channel(),
cfg: config.tcp.clone(),
})
}
#[cfg(unix)]
fn as_raw_fd(conn: &Self::Stream) -> RawFd {
use std::os::fd::AsRawFd;
match conn.get_ref() {
Stream::Tcp(tcp_stream) => tcp_stream.as_raw_fd(),
#[cfg(unix)]
Stream::Unix(unix_stream) => unix_stream.as_raw_fd(),
}
}
fn hint(conn: &Self::Stream, opt: SocketOpts) {
opt.apply(conn.get_ref());
}
async fn bind(&self, addr: NamedSocketAddr) -> Result<Self::Acceptor> {
#[cfg(unix)]
if let NamedSocketAddr::Unix(path) = &addr {
if path.exists() {
tokio::fs::remove_file(path).await?;
}
}
Ok(Listener::bind(&addr).await?)
}
async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::RawStream, SocketAddr)> {
let (s, addr) = a.accept().await?;
self.socket_opts.apply(&s);
Ok((s, addr))
}
async fn handshake(&self, conn: Self::RawStream) -> Result<Self::Stream> {
Ok(TcpStream::new(conn))
}
async fn connect(&self, addr: &AddrMaybeCached) -> Result<Self::Stream> {
let s = tcp_connect_with_proxy(addr, self.cfg.proxy.as_ref()).await?;
self.socket_opts.apply(&s);
Ok(TcpStream::new(s))
}
}
/// Enables TCP keepalive on `conn` when it is a TCP socket.
///
/// `keepalive_duration` is passed to `TcpKeepalive::with_time` (idle time before
/// probing starts). `keepalive_interval` is passed to `TcpKeepalive::with_interval`
/// (spacing between probes). On Unix, a Unix-domain stream is left untouched and
/// `Ok(())` is returned.
///
/// # Errors
///
/// Returns the OS error if the keepalive socket option cannot be set.
pub fn try_set_tcp_keepalive(
    conn: &RawTcpStream,
    keepalive_duration: Duration,
    keepalive_interval: Duration,
) -> Result<()> {
    let tcp = match conn {
        Stream::Tcp(tcp) => tcp,
        #[cfg(unix)]
        Stream::Unix(_) => return Ok(()),
    };

    trace!(
        "Set TCP keepalive {:?} {:?}",
        keepalive_duration,
        keepalive_interval
    );
    let params = TcpKeepalive::new()
        .with_time(keepalive_duration)
        .with_interval(keepalive_interval);
    SockRef::from(tcp).set_tcp_keepalive(&params)?;
    Ok(())
}
pub async fn tcp_connect_with_proxy(addr: &AddrMaybeCached, proxy: Option<&Url>) -> Result<Stream> {
if let Some(url) = proxy {
let addr = &addr.addr;
let proxy_addr = format!(
"{}:{}",
url.host_str().expect("proxy url should have host field"),
url.port().expect("proxy url should have port field")
);
let mut s = Stream::connect(&NamedSocketAddr::from_str(&proxy_addr)?).await?;
let auth = if !url.username().is_empty() || url.password().is_some() {
Some(async_socks5::Auth {
username: url.username().into(),
password: url.password().unwrap_or("").into(),
})
} else {
None
};
match url.scheme() {
"socks5" => {
async_socks5::connect(&mut s, host_port_pair(addr)?, auth).await?;
}
"http" => {
let (host, port) = host_port_pair(addr)?;
match auth {
Some(auth) => {
http_connect_tokio_with_basic_auth(
&mut s,
host,
port,
&auth.username,
&auth.password,
)
.await?
}
None => http_connect_tokio(&mut s, host, port).await?,
}
}
_ => panic!("unknown proxy scheme"),
}
Ok(s)
} else {
Ok(match addr.socket_addr.as_ref() {
Some(s) => Stream::connect(s).await?,
None => Stream::connect(&NamedSocketAddr::from_str(&addr.addr)?).await?,
})
}
}