use std::{
io::{self, ErrorKind},
pin::Pin,
task::{self, Poll},
};
use bytes::{BufMut, BytesMut};
use cfg_if::cfg_if;
use futures::ready;
use log::trace;
use once_cell::sync::Lazy;
use pin_project::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
time,
};
#[cfg(feature = "aead-cipher-2022")]
use crate::relay::get_aead_2022_padding_size;
use crate::{
config::ServerConfig,
context::SharedContext,
crypto::CipherKind,
net::{ConnectOpts, TcpStream as OutboundTcpStream},
relay::{
socks5::Address,
tcprelay::crypto_io::{CryptoRead, CryptoStream, CryptoWrite, StreamType},
},
};
/// Write-side handshake state of a `ProxyClientStream`.
///
/// The target address is never sent on its own: it is prepended to the first
/// payload the caller writes (see `make_first_packet_buffer` and `poll_write`).
enum ProxyClientStreamWriteState {
    /// Nothing written yet; holds the target address to embed in the first packet.
    Connect(Address),
    /// The first packet (address header + first payload) has been built and is being
    /// written; it is kept here so a write that returned `Poll::Pending` can be retried.
    Connecting(BytesMut),
    /// Header sent; later writes are forwarded directly to the encrypting stream.
    Connected,
}
/// Read-side state of a `ProxyClientStream`.
enum ProxyClientStreamReadState {
    /// AEAD-2022 only: the response has not yet been checked against the salt
    /// (nonce) this client sent with its request.
    #[cfg(feature = "aead-cipher-2022")]
    CheckRequestNonce,
    /// Reads pass straight through to the decrypting stream.
    Established,
}
/// Client side of an encrypted proxy TCP tunnel.
///
/// Wraps a transport `S` in a `CryptoStream`. It sends the target address lazily,
/// together with the first write, and (for AEAD-2022 methods) validates the
/// server's first response before passing reads through.
#[pin_project]
pub struct ProxyClientStream<S> {
    // Encrypting/decrypting wrapper around the raw transport; pinned for polling.
    #[pin]
    stream: CryptoStream<S>,
    // Progress of sending the target-address header.
    writer_state: ProxyClientStreamWriteState,
    // Whether the server's response still needs its request-salt check.
    reader_state: ProxyClientStreamReadState,
    // Shared context handed to the decrypt path on every read.
    context: SharedContext,
}
static DEFAULT_CONNECT_OPTS: Lazy<ConnectOpts> = Lazy::new(Default::default);
impl ProxyClientStream<OutboundTcpStream> {
    /// Connects to the server described by `svr_cfg` with the default `ConnectOpts` and
    /// prepares a tunnel to `addr` over a plain outbound TCP stream.
    ///
    /// # Errors
    /// Propagates connection failures, including a `TimedOut` error when the server
    /// config carries a timeout that elapses.
    pub async fn connect<A>(
        context: SharedContext,
        svr_cfg: &ServerConfig,
        addr: A,
    ) -> io::Result<ProxyClientStream<OutboundTcpStream>>
    where
        A: Into<Address>,
    {
        let default_opts: &ConnectOpts = &DEFAULT_CONNECT_OPTS;
        Self::connect_with_opts(context, svr_cfg, addr, default_opts).await
    }

    /// Same as `connect`, but with caller-supplied connect options.
    ///
    /// # Errors
    /// Propagates connection failures, including a `TimedOut` error when the server
    /// config carries a timeout that elapses.
    pub async fn connect_with_opts<A>(
        context: SharedContext,
        svr_cfg: &ServerConfig,
        addr: A,
        opts: &ConnectOpts,
    ) -> io::Result<ProxyClientStream<OutboundTcpStream>>
    where
        A: Into<Address>,
    {
        // No transport mapping: the raw outbound TCP stream is used as-is.
        Self::connect_with_opts_map(context, svr_cfg, addr, opts, |tcp_stream| tcp_stream).await
    }
}
impl<S> ProxyClientStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub async fn connect_map<A, F>(
context: SharedContext,
svr_cfg: &ServerConfig,
addr: A,
map_fn: F,
) -> io::Result<ProxyClientStream<S>>
where
A: Into<Address>,
F: FnOnce(OutboundTcpStream) -> S,
{
ProxyClientStream::connect_with_opts_map(context, svr_cfg, addr, &DEFAULT_CONNECT_OPTS, map_fn).await
}
pub async fn connect_with_opts_map<A, F>(
context: SharedContext,
svr_cfg: &ServerConfig,
addr: A,
opts: &ConnectOpts,
map_fn: F,
) -> io::Result<ProxyClientStream<S>>
where
A: Into<Address>,
F: FnOnce(OutboundTcpStream) -> S,
{
let stream = match svr_cfg.timeout() {
Some(d) => {
match time::timeout(
d,
OutboundTcpStream::connect_server_with_opts(&context, svr_cfg.external_addr(), opts),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => return Err(e),
Err(..) => {
return Err(io::Error::new(
ErrorKind::TimedOut,
format!("connect {} timeout", svr_cfg.addr()),
))
}
}
}
None => OutboundTcpStream::connect_server_with_opts(&context, svr_cfg.external_addr(), opts).await?,
};
trace!(
"connected tcp remote {} (outbound: {}) with {:?}",
svr_cfg.addr(),
svr_cfg.external_addr(),
opts
);
Ok(ProxyClientStream::from_stream(context, map_fn(stream), svr_cfg, addr))
}
pub fn from_stream<A>(context: SharedContext, stream: S, svr_cfg: &ServerConfig, addr: A) -> ProxyClientStream<S>
where
A: Into<Address>,
{
let addr = addr.into();
let stream = CryptoStream::from_stream_with_identity(
&context,
stream,
StreamType::Client,
svr_cfg.method(),
svr_cfg.key(),
svr_cfg.identity_keys(),
None,
);
#[cfg(not(feature = "aead-cipher-2022"))]
let reader_state = ProxyClientStreamReadState::Established;
#[cfg(feature = "aead-cipher-2022")]
let reader_state = if svr_cfg.method().is_aead_2022() {
ProxyClientStreamReadState::CheckRequestNonce
} else {
ProxyClientStreamReadState::Established
};
ProxyClientStream {
stream,
writer_state: ProxyClientStreamWriteState::Connect(addr),
reader_state,
context,
}
}
pub fn get_ref(&self) -> &S {
self.stream.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.stream.get_mut()
}
pub fn into_inner(self) -> S {
self.stream.into_inner()
}
}
impl<S> AsyncRead for ProxyClientStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads decrypted data from the server into `buf`.
    ///
    /// For AEAD-2022 streams, reads stay in the checking state until the stream reports a data
    /// chunk. At that point the request salt echoed by the server must equal the salt this client
    /// sent, or the read fails with `ErrorKind::Other`. After a successful check, reads pass
    /// straight through.
    #[inline]
    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        // `mut` is only needed by the AEAD-2022 arm (`as_mut()`), hence the allow.
        #[allow(unused_mut)]
        let mut proj = self.project();

        match proj.reader_state {
            ProxyClientStreamReadState::Established => {
                proj.stream.poll_read_decrypted(cx, proj.context, buf).map_err(Into::into)
            }
            #[cfg(feature = "aead-cipher-2022")]
            ProxyClientStreamReadState::CheckRequestNonce => {
                ready!(proj.stream.as_mut().poll_read_decrypted(cx, proj.context, buf))?;

                // Nothing to verify until the stream reports a data chunk; stay in this state.
                let (decoded_chunks, _) = proj.stream.current_data_chunk_remaining();
                if decoded_chunks == 0 {
                    return Poll::Ready(Ok(()));
                }

                let our_nonce = proj.stream.sent_nonce();
                let expected = if our_nonce.is_empty() { None } else { Some(our_nonce) };
                if expected != proj.stream.received_request_nonce() {
                    return Poll::Ready(Err(io::Error::new(
                        ErrorKind::Other,
                        "received TCP response header with unmatched salt",
                    )));
                }

                *proj.reader_state = ProxyClientStreamReadState::Established;
                Poll::Ready(Ok(()))
            }
        }
    }
}
/// Builds the plaintext of the first packet written on a fresh client stream.
///
/// Layout:
/// - AEAD-2022 methods (only with the `aead-cipher-2022` feature):
///   `[address][u16 BE padding length][padding][buf]`
/// - otherwise: `[address][buf]`
///
/// The buffer is allocated with its exact final size up-front.
#[inline]
fn make_first_packet_buffer(method: CipherKind, addr: &Address, buf: &[u8]) -> BytesMut {
    let addr_length = addr.serialized_len();

    cfg_if! {
        if #[cfg(feature = "aead-cipher-2022")] {
            let is_aead_2022 = method.is_aead_2022();
            // Padding only exists in the AEAD-2022 header, so don't compute it for other methods.
            let padding_size = if is_aead_2022 { get_aead_2022_padding_size(buf) } else { 0 };
            let header_length = if is_aead_2022 {
                addr_length + 2 + padding_size + buf.len()
            } else {
                addr_length + buf.len()
            };
        } else {
            let _ = method;
            let header_length = addr_length + buf.len();
        }
    }

    let mut buffer = BytesMut::with_capacity(header_length);
    addr.write_to_buf(&mut buffer);

    #[cfg(feature = "aead-cipher-2022")]
    if is_aead_2022 {
        // The length field is a u16; the padding size must never exceed it.
        debug_assert!(padding_size <= u16::MAX as usize);
        buffer.put_u16(padding_size as u16);
        // Zero-fill the padding. Previously `advance_mut` exposed uninitialized capacity as
        // initialized bytes, which were then read (encrypted) — violating `advance_mut`'s
        // safety contract. The padding is encrypted with the rest of the header either way.
        buffer.resize(buffer.len() + padding_size, 0);
    }

    buffer.put_slice(buf);
    buffer
}
impl<S> AsyncWrite for ProxyClientStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes `buf` through the encrypting stream.
    ///
    /// On the first call the target address header is prepended to `buf`, and the combined
    /// packet is written in one encrypted write. Once it is accepted, `buf.len()` is reported as
    /// written. If that write returns `Pending`, the prepared packet is kept and retried on the
    /// next call. Note that the retry reports the length of the `buf` passed at that time, so
    /// callers are expected to resubmit the same data. All later calls are forwarded directly.
    fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
        let proj = self.project();
        loop {
            let next_state = match proj.writer_state {
                ProxyClientStreamWriteState::Connected => {
                    return proj.stream.poll_write_encrypted(cx, buf).map_err(Into::into);
                }
                ProxyClientStreamWriteState::Connecting(ref packet) => {
                    let written = ready!(proj.stream.poll_write_encrypted(cx, packet))?;
                    // The first packet is expected to be accepted in full.
                    debug_assert!(written == packet.len());
                    *proj.writer_state = ProxyClientStreamWriteState::Connected;
                    return Poll::Ready(Ok(buf.len()));
                }
                ProxyClientStreamWriteState::Connect(ref target) => {
                    let first_packet = make_first_packet_buffer(proj.stream.method(), target, buf);
                    ProxyClientStreamWriteState::Connecting(first_packet)
                }
            };
            *proj.writer_state = next_state;
        }
    }

    /// Delegates flushing to the inner `CryptoStream`.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
        let inner = self.project().stream;
        inner.poll_flush(cx).map_err(Into::into)
    }

    /// Delegates shutdown to the inner `CryptoStream`.
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
        let inner = self.project().stream;
        inner.poll_shutdown(cx).map_err(Into::into)
    }
}