use pin_project_lite::pin_project;
use std::{
fmt,
net::SocketAddr,
path::Path,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::{TcpStream, UnixStream},
};
use tokio_rustls::{
client::TlsStream,
rustls::{
client::WebPkiServerVerifier,
pki_types::{CertificateDer, PrivateKeyDer, ServerName},
ClientConfig, RootCertStore,
},
TlsConnector,
};
use tokio_socks::{
tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
IntoTargetAddr, TargetAddr,
};
pub use tokio_rustls;
mod danger;
// Errors that can be returned while connecting a `Stream`.
// NOTE(review): plain `//` comments are used instead of `///` in case
// `foxerror::FoxError` derives messages from doc attributes — confirm before
// converting these to doc comments.
#[derive(Debug, foxerror::FoxError)]
#[non_exhaustive]
pub enum Error {
// A client certificate was configured but TLS was not enabled.
ClientCertNoTls,
// I/O error while opening the socket (to the target or to the proxy) or
// during the TLS handshake.
#[err(from)]
Connect(std::io::Error),
// Failure reported by tokio-socks during the proxy handshake.
#[err(from)]
Socks(tokio_socks::Error),
// rustls rejected the TLS client configuration (e.g. the client certificate
// chain or private key).
#[err(from)]
Rustls(tokio_rustls::rustls::Error),
// A SOCKS proxy was requested for a non-TCP (Unix socket) target.
SocksToUnsupported,
// The TCP target could not be converted into a `TargetAddr`.
InvalidTarget(tokio_socks::Error),
// TLS is enabled, but no server name was given and none could be derived
// from the target.
NoServerName,
}
pin_project! {
/// An established client connection: a TCP or Unix socket, optionally
/// tunnelled through a SOCKS4/SOCKS5 proxy and optionally wrapped in TLS.
///
/// Start with [`Stream::new_tcp`] or [`Stream::new_unix`], then call
/// [`StreamBuilder::connect`]. Reads and writes are forwarded to the
/// outermost layer.
#[derive(Debug)]
pub struct Stream {
// Outermost transport layer (TLS or plain).
#[pin]
inner: MaybeTls,
}
}
impl Stream {
    /// Starts a [`StreamBuilder`] for a TCP connection to `addr`.
    ///
    /// The address conversion runs now. If it fails, the error is reported
    /// (as [`Error::InvalidTarget`]) only when [`StreamBuilder::connect`]
    /// is called.
    pub fn new_tcp<'a>(addr: impl IntoTargetAddr<'a>) -> StreamBuilder<'a> {
        let target = addr.into_target_addr();
        StreamBuilder::new(BaseParams::Tcp(target))
    }

    /// Starts a [`StreamBuilder`] for a Unix domain socket at `path`.
    pub fn new_unix(path: &Path) -> StreamBuilder<'_> {
        let base = BaseParams::Unix(path);
        StreamBuilder::new(base)
    }
}
impl AsyncRead for Stream {
    /// Reads from the outermost negotiated layer (TLS or plain).
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();
        this.inner.poll_read(cx, buf)
    }
}
impl AsyncWrite for Stream {
    // Every method pin-projects to the inner layer and forwards the call as-is.

    /// Writes to the outermost negotiated layer (TLS or plain).
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        this.inner.poll_write(cx, buf)
    }

    /// Flushes the outermost negotiated layer.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx)
    }

    /// Shuts down the outermost negotiated layer.
    #[inline]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.project();
        this.inner.poll_shutdown(cx)
    }
}
pin_project! {
// Outermost transport layer: either the (possibly proxied) stream used as-is,
// or the same stream wrapped in a tokio-rustls client session.
#[project = MaybeTlsProj]
#[derive(Debug)]
// Clippy flags the size gap between `Plain` and `Tls`; the lint is silenced
// instead of boxing `Tls`. NOTE(review): presumably to avoid a heap
// allocation per connection — confirm.
#[allow(clippy::large_enum_variant)] enum MaybeTls {
Plain {
#[pin]
inner: MaybeSocks,
},
Tls {
#[pin]
inner: TlsStream<MaybeSocks>,
},
}
}
// Implements `AsyncRead` and `AsyncWrite` for a `pin_project!` enum whose
// variants all have the shape `{ #[pin] inner: T }`.
//
// `$ty` is the enum itself. The parenthesised list names the variants of its
// *projection* enum (the `#[project = ...]` type). Each generated method
// pin-projects `self` and forwards the call unchanged to the active variant's
// `inner` field.
macro_rules! trivial_impl {
    ($ty:ty, ($($variant:path),*)) => {
        impl AsyncWrite for $ty {
            #[inline]
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_write(cx, buf),)*
                }
            }

            #[inline]
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_flush(cx),)*
                }
            }

            #[inline]
            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_shutdown(cx),)*
                }
            }
        }

        impl AsyncRead for $ty {
            #[inline]
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                let projected = self.project();
                match projected {
                    $($variant { inner } => inner.poll_read(cx, buf),)*
                }
            }
        }
    };
}
// `AsyncRead`/`AsyncWrite` for `MaybeTls`: forward to whichever variant is active.
trivial_impl!(MaybeTls, (MaybeTlsProj::Plain, MaybeTlsProj::Tls));
pin_project! {
// Middle layer: the base socket used directly (`Clear`), or a socket to the
// proxy on which a SOCKS4/SOCKS5 handshake has been completed.
#[project = MaybeSocksProj]
#[derive(Debug)]
enum MaybeSocks {
Clear {
#[pin]
inner: BaseStream,
},
Socks4 {
#[pin]
inner: Socks4Stream<BaseStream>,
},
Socks5 {
#[pin]
inner: Socks5Stream<BaseStream>,
},
}
}
// `AsyncRead`/`AsyncWrite` for `MaybeSocks`: forward to whichever variant is active.
trivial_impl!(
MaybeSocks,
(
MaybeSocksProj::Clear,
MaybeSocksProj::Socks4,
MaybeSocksProj::Socks5
)
);
pin_project! {
// Innermost layer: the OS socket, either TCP or a Unix domain socket.
#[project = BaseStreamProj]
#[derive(Debug)]
enum BaseStream {
Tcp {
#[pin]
inner: TcpStream,
},
Unix {
#[pin]
inner: UnixStream,
},
}
}
// `AsyncRead`/`AsyncWrite` for `BaseStream`: forward to whichever variant is active.
trivial_impl!(BaseStream, (BaseStreamProj::Tcp, BaseStreamProj::Unix));
/// Builder for a [`Stream`], returned by [`Stream::new_tcp`] and
/// [`Stream::new_unix`].
///
/// Optionally configure a SOCKS proxy, TLS and a TLS client certificate,
/// then call [`StreamBuilder::connect`].
#[derive(Debug)]
#[must_use = "this does nothing unless you finish building"]
pub struct StreamBuilder<'a> {
/// Where to connect: a TCP target or a Unix socket path.
base: BaseParams<'a>,
/// SOCKS proxy to tunnel through, if any (only valid for TCP targets).
socks: Option<SocksParams<'a>>,
/// TLS settings, if the connection should be wrapped in TLS.
tls: Option<TlsParams>,
/// Certificate for TLS client authentication; only valid together with `tls`.
client_cert: Option<ClientCert>,
}
impl<'a> StreamBuilder<'a> {
fn new(base: BaseParams<'a>) -> Self {
Self {
base,
socks: None,
tls: None,
client_cert: None,
}
}
fn socks(
mut self,
version: SocksVersion,
proxy: SocketAddr,
auth: Option<SocksAuth<'a>>,
) -> Self {
self.socks = Some(SocksParams {
version,
proxy,
auth,
});
self
}
pub fn socks4(self, proxy: SocketAddr) -> Self {
self.socks(SocksVersion::Socks4, proxy, None)
}
pub fn socks4_with_userid(self, proxy: SocketAddr, userid: &'a str) -> Self {
self.socks(
SocksVersion::Socks4,
proxy,
Some(SocksAuth {
username: userid,
password: "h",
}),
)
}
pub fn socks5(self, proxy: SocketAddr) -> Self {
self.socks(SocksVersion::Socks5, proxy, None)
}
pub fn socks5_with_password(
self,
proxy: SocketAddr,
username: &'a str,
password: &'a str,
) -> Self {
self.socks(
SocksVersion::Socks5,
proxy,
Some(SocksAuth { username, password }),
)
}
fn tls(mut self, domain: Option<ServerName<'static>>, verification: TlsVerify) -> Self {
self.tls = Some(TlsParams {
domain,
verification,
});
self
}
pub fn tls_danger_insecure(self, domain: Option<ServerName<'static>>) -> Self {
self.tls(domain, TlsVerify::Insecure)
}
pub fn tls_with_root(
self,
domain: Option<ServerName<'static>>,
root: impl Into<Arc<RootCertStore>>,
) -> Self {
self.tls(domain, TlsVerify::CaStore(root.into()))
}
pub fn tls_with_webpki(
self,
domain: Option<ServerName<'static>>,
webpki: Arc<WebPkiServerVerifier>,
) -> Self {
self.tls(domain, TlsVerify::WebPki(webpki))
}
pub fn client_cert(
mut self,
cert_chain: Vec<CertificateDer<'static>>,
key_der: PrivateKeyDer<'static>,
) -> Self {
self.client_cert = Some(ClientCert {
cert_chain,
key_der,
});
self
}
pub async fn connect(self) -> Result<Stream, Error> {
let tls = if let Some(mut params) = self.tls {
params.domain = params.domain.or_else(|| match &self.base {
BaseParams::Tcp(Ok(TargetAddr::Ip(addr))) => Some(ServerName::from(addr.ip())),
BaseParams::Tcp(Ok(TargetAddr::Domain(d, _))) => {
ServerName::try_from(d.as_ref()).map(|s| s.to_owned()).ok()
}
_ => None,
});
Some(params)
} else {
None
};
let stream = if let Some(params) = self.socks {
let BaseParams::Tcp(target) = self.base else {
return Err(Error::SocksToUnsupported);
};
let target = target.map_err(Error::InvalidTarget)?;
let stream = BaseStream::Tcp {
inner: TcpStream::connect(params.proxy).await?,
};
match params.version {
SocksVersion::Socks4 => MaybeSocks::Socks4 {
inner: if let Some(SocksAuth { username, .. }) = params.auth {
Socks4Stream::connect_with_userid_and_socket(stream, target, username)
.await?
} else {
Socks4Stream::connect_with_socket(stream, target).await?
},
},
SocksVersion::Socks5 => MaybeSocks::Socks5 {
inner: if let Some(SocksAuth { username, password }) = params.auth {
Socks5Stream::connect_with_password_and_socket(
stream, target, username, password,
)
.await?
} else {
Socks5Stream::connect_with_socket(stream, target).await?
},
},
}
} else {
let stream = match self.base {
BaseParams::Tcp(addr) => {
let inner = match addr.map_err(Error::InvalidTarget)? {
TargetAddr::Ip(addr) => TcpStream::connect(addr).await?,
TargetAddr::Domain(domain, port) => {
TcpStream::connect((domain.as_ref(), port)).await?
}
};
BaseStream::Tcp { inner }
}
BaseParams::Unix(path) => BaseStream::Unix {
inner: UnixStream::connect(path).await?,
},
};
MaybeSocks::Clear { inner: stream }
};
let stream = if let Some(params) = tls {
let config = ClientConfig::builder();
let config = match params.verification {
TlsVerify::Insecure => {
let provider = config.crypto_provider().clone();
config
.dangerous()
.with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
}
TlsVerify::CaStore(root) => config.with_root_certificates(root),
TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
};
let config = if let Some(ClientCert {
cert_chain,
key_der,
}) = self.client_cert
{
config.with_client_auth_cert(cert_chain, key_der)?
} else {
config.with_no_client_auth()
};
let connector = TlsConnector::from(Arc::new(config));
let domain = params.domain.ok_or(Error::NoServerName)?;
let inner = connector.connect(domain, stream).await?;
MaybeTls::Tls { inner }
} else {
if self.client_cert.is_some() {
return Err(Error::ClientCertNoTls);
}
MaybeTls::Plain { inner: stream }
};
Ok(Stream { inner: stream })
}
}
/// Where the underlying socket connects.
#[derive(Debug)]
enum BaseParams<'a> {
/// TCP target. A failed address conversion is stored here and reported as
/// `Error::InvalidTarget` when connecting.
Tcp(tokio_socks::Result<TargetAddr<'a>>),
/// Path of a Unix domain socket.
Unix(&'a Path),
}
/// SOCKS proxy settings chosen through the `socks*` builder methods.
struct SocksParams<'a> {
    /// Protocol version spoken to the proxy.
    version: SocksVersion,
    /// Address of the proxy server itself.
    proxy: SocketAddr,
    /// Optional credentials (SOCKS4 user id, or SOCKS5 username/password).
    auth: Option<SocksAuth<'a>>,
}

// Written by hand so credentials never reach logs. The proxy address and
// whether authentication is configured are shown; the credentials are not.
impl fmt::Debug for SocksParams<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SocksParams")
            .field("version", &self.version)
            .field("proxy", &self.proxy)
            .field("auth", &self.auth.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

/// Credentials presented to the proxy.
///
/// Not `Debug`, so they cannot be printed by accident.
struct SocksAuth<'a> {
    /// SOCKS5 username, or the SOCKS4 user id.
    username: &'a str,
    /// SOCKS5 password; not used for SOCKS4.
    password: &'a str,
}

/// Supported SOCKS protocol versions.
#[derive(Debug)]
enum SocksVersion {
    Socks4,
    Socks5,
}
/// TLS settings chosen through the `tls_*` builder methods.
#[derive(Debug)]
struct TlsParams {
/// Server name passed to `TlsConnector::connect`. When `None`, `connect`
/// tries to derive it from an IP or domain TCP target.
domain: Option<ServerName<'static>>,
/// How the server certificate is verified.
verification: TlsVerify,
}
/// Server-certificate verification strategy.
#[derive(Debug)]
enum TlsVerify {
/// Verification delegated to `danger::PhonyVerify`.
/// NOTE(review): presumably accepts any certificate — see the `danger` module.
Insecure,
/// Verify against the given root certificate store.
CaStore(Arc<RootCertStore>),
/// Verify with a caller-supplied WebPKI verifier.
WebPki(Arc<WebPkiServerVerifier>),
}
/// Certificate chain and private key for TLS client authentication.
// NOTE(review): the derived `Debug` includes `key_der`; this relies on
// `PrivateKeyDer`'s `Debug` impl not printing key material — confirm.
#[derive(Debug)]
struct ClientCert {
cert_chain: Vec<CertificateDer<'static>>,
key_der: PrivateKeyDer<'static>,
}