1use pin_project_lite::pin_project;
4use std::{
5 fmt,
6 net::SocketAddr,
7 path::Path,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12use tokio::{
13 io::{AsyncRead, AsyncWrite, ReadBuf},
14 net::{TcpStream, UnixStream},
15};
16use tokio_rustls::{
17 client::TlsStream,
18 rustls::{
19 client::WebPkiServerVerifier,
20 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
21 ClientConfig, RootCertStore,
22 },
23 TlsConnector,
24};
25use tokio_socks::{
26 tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
27 IntoTargetAddr, TargetAddr,
28};
29
30pub use tokio_rustls;
31
32mod danger;
33
// Errors reported while configuring or establishing a `Stream`.
//
// NOTE(review): the notes in this block are plain `//` comments rather than `///`
// on purpose: the `FoxError` derive may read attributes (possibly including doc
// comments) when generating its impls. Confirm against the foxerror documentation
// before converting them to rustdoc.
#[derive(Debug, foxerror::FoxError)]
#[non_exhaustive]
pub enum Error {
    // Returned by `StreamBuilder::connect` when a client certificate was configured
    // but no TLS mode was selected.
    ClientCertNoTls,
    // I/O failure; `connect` produces this via `?` on `TcpStream::connect`,
    // `UnixStream::connect` and the TLS connector's `connect`.
    #[err(from)]
    Connect(std::io::Error),
    // Failure reported by `tokio_socks` while connecting through the proxy.
    #[err(from)]
    Socks(tokio_socks::Error),
    // Failure reported by rustls; `connect` produces this when installing the
    // client certificate into the TLS configuration fails.
    #[err(from)]
    Rustls(tokio_rustls::rustls::Error),
    // A SOCKS proxy was requested for a base transport that is not TCP.
    SocksToUnsupported,
    // The address given to `Stream::new_tcp` could not be converted into a
    // `TargetAddr`; carries the conversion error.
    InvalidTarget(tokio_socks::Error),
    // TLS was requested but no server name was supplied and none could be derived
    // from the TCP target.
    NoServerName,
}
56
pin_project! {
    #[derive(Debug)]
    /// A connected byte stream produced by [`StreamBuilder::connect`].
    ///
    /// Implements [`AsyncRead`] and [`AsyncWrite`] by forwarding to the wrapped
    /// transport, which may or may not be TLS-wrapped.
    pub struct Stream {
        // Structurally pinned so the trait impls can forward through `project()`.
        #[pin]
        inner: MaybeTls,
    }
}
65
66impl Stream {
67 pub fn new_tcp<'a>(addr: impl IntoTargetAddr<'a>) -> StreamBuilder<'a> {
77 StreamBuilder::new(BaseParams::Tcp(addr.into_target_addr()))
78 }
79 pub fn new_unix(path: &Path) -> StreamBuilder<'_> {
90 StreamBuilder::new(BaseParams::Unix(path))
91 }
92}
93
94impl AsyncRead for Stream {
95 #[inline]
96 fn poll_read(
97 self: Pin<&mut Self>,
98 cx: &mut Context<'_>,
99 buf: &mut ReadBuf<'_>,
100 ) -> Poll<std::io::Result<()>> {
101 self.project().inner.poll_read(cx, buf)
102 }
103}
104
105impl AsyncWrite for Stream {
106 #[inline]
107 fn poll_write(
108 self: Pin<&mut Self>,
109 cx: &mut Context<'_>,
110 buf: &[u8],
111 ) -> Poll<Result<usize, std::io::Error>> {
112 self.project().inner.poll_write(cx, buf)
113 }
114 #[inline]
115 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
116 self.project().inner.poll_flush(cx)
117 }
118 #[inline]
119 fn poll_shutdown(
120 self: Pin<&mut Self>,
121 cx: &mut Context<'_>,
122 ) -> Poll<Result<(), std::io::Error>> {
123 self.project().inner.poll_shutdown(cx)
124 }
125}
126
// Transport layer that is either passed through untouched or wrapped in a TLS
// client session. Both variants name their payload `inner` so `trivial_impl!`
// can match on them uniformly.
pin_project! {
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    // NOTE(review): presumably silenced because the TLS variant is much larger
    // than the plain one — confirm this is still wanted rather than boxing.
    #[allow(clippy::large_enum_variant)]
    enum MaybeTls {
        // No TLS: bytes go straight to the (possibly proxied) transport.
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // TLS client stream layered over the (possibly proxied) transport.
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
142
// Generates `AsyncRead` and `AsyncWrite` impls for a pin-projected enum by
// delegating every poll method to the `inner` field of whichever variant is active.
//
// Usage: `trivial_impl!(Type, (Proj::VariantA, Proj::VariantB, ...));` where each
// listed path is a variant of the type's projection enum with a pinned `inner` field.
macro_rules! trivial_impl {
    ($ty:ty, ($($variant:path),*)) => {
        impl AsyncRead for $ty {
            #[inline]
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                match self.project() {
                    $($variant { inner } => AsyncRead::poll_read(inner, cx, buf),)*
                }
            }
        }

        impl AsyncWrite for $ty {
            #[inline]
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                match self.project() {
                    $($variant { inner } => AsyncWrite::poll_write(inner, cx, buf),)*
                }
            }

            #[inline]
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                match self.project() {
                    $($variant { inner } => AsyncWrite::poll_flush(inner, cx),)*
                }
            }

            #[inline]
            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                match self.project() {
                    $($variant { inner } => AsyncWrite::poll_shutdown(inner, cx),)*
                }
            }
        }
    };
}
190
// `MaybeTls` reads and writes by delegating to the active variant's `inner`.
trivial_impl!(MaybeTls, (MaybeTlsProj::Plain, MaybeTlsProj::Tls));
192
// Transport layer that is either a direct connection or one tunnelled through a
// SOCKS4/SOCKS5 proxy. Every variant names its payload `inner` so
// `trivial_impl!` can match on them uniformly.
pin_project! {
    #[project = MaybeSocksProj]
    #[derive(Debug)]
    enum MaybeSocks {
        // Direct connection, no proxy in between.
        Clear {
            #[pin]
            inner: BaseStream,
        },
        // Connection tunnelled through a SOCKS4 proxy.
        Socks4 {
            #[pin]
            inner: Socks4Stream<BaseStream>,
        },
        // Connection tunnelled through a SOCKS5 proxy.
        Socks5 {
            #[pin]
            inner: Socks5Stream<BaseStream>,
        },
    }
}
211
// `MaybeSocks` reads and writes by delegating to the active variant's `inner`.
trivial_impl!(
    MaybeSocks,
    (
        MaybeSocksProj::Clear,
        MaybeSocksProj::Socks4,
        MaybeSocksProj::Socks5
    )
);
220
// The underlying OS-level socket: either TCP or a Unix domain socket. Both
// variants name their payload `inner` so `trivial_impl!` can match on them
// uniformly.
pin_project! {
    #[project = BaseStreamProj]
    #[derive(Debug)]
    enum BaseStream {
        // TCP socket (to the destination, or to the SOCKS proxy when one is used).
        Tcp {
            #[pin]
            inner: TcpStream,
        },
        // Unix domain socket.
        Unix {
            #[pin]
            inner: UnixStream,
        },
    }
}
235
// `BaseStream` reads and writes by delegating to the active variant's `inner`.
trivial_impl!(BaseStream, (BaseStreamProj::Tcp, BaseStreamProj::Unix));
237
/// Builder that collects connection options and finally opens a [`Stream`].
///
/// Obtained from [`Stream::new_tcp`] or [`Stream::new_unix`]; nothing is connected
/// until [`StreamBuilder::connect`] is awaited.
#[derive(Debug)]
#[must_use = "this does nothing unless you finish building"]
pub struct StreamBuilder<'a> {
    /// Final destination: a TCP target or a Unix socket path.
    base: BaseParams<'a>,
    /// SOCKS proxy to tunnel through, if any.
    socks: Option<SocksParams<'a>>,
    /// TLS settings, if TLS was requested.
    tls: Option<TlsParams>,
    /// Client certificate to present during the TLS handshake, if any.
    client_cert: Option<ClientCert>,
}
247
248impl<'a> StreamBuilder<'a> {
249 fn new(base: BaseParams<'a>) -> Self {
250 Self {
251 base,
252 socks: None,
253 tls: None,
254 client_cert: None,
255 }
256 }
257
258 fn socks(
259 mut self,
260 version: SocksVersion,
261 proxy: SocketAddr,
262 auth: Option<SocksAuth<'a>>,
263 ) -> Self {
264 self.socks = Some(SocksParams {
265 version,
266 proxy,
267 auth,
268 });
269 self
270 }
271
272 pub fn socks4(self, proxy: SocketAddr) -> Self {
283 self.socks(SocksVersion::Socks4, proxy, None)
284 }
285
286 pub fn socks4_with_userid(self, proxy: SocketAddr, userid: &'a str) -> Self {
297 self.socks(
298 SocksVersion::Socks4,
299 proxy,
300 Some(SocksAuth {
301 username: userid,
302 password: "h",
303 }),
304 )
305 }
306
307 pub fn socks5(self, proxy: SocketAddr) -> Self {
318 self.socks(SocksVersion::Socks5, proxy, None)
319 }
320
321 pub fn socks5_with_password(
333 self,
334 proxy: SocketAddr,
335 username: &'a str,
336 password: &'a str,
337 ) -> Self {
338 self.socks(
339 SocksVersion::Socks5,
340 proxy,
341 Some(SocksAuth { username, password }),
342 )
343 }
344
345 fn tls(mut self, domain: Option<ServerName<'static>>, verification: TlsVerify) -> Self {
346 self.tls = Some(TlsParams {
347 domain,
348 verification,
349 });
350 self
351 }
352
353 pub fn tls_danger_insecure(self, domain: Option<ServerName<'static>>) -> Self {
365 self.tls(domain, TlsVerify::Insecure)
366 }
367
368 pub fn tls_with_root(
390 self,
391 domain: Option<ServerName<'static>>,
392 root: impl Into<Arc<RootCertStore>>,
393 ) -> Self {
394 self.tls(domain, TlsVerify::CaStore(root.into()))
395 }
396
397 pub fn tls_with_webpki(
399 self,
400 domain: Option<ServerName<'static>>,
401 webpki: Arc<WebPkiServerVerifier>,
402 ) -> Self {
403 self.tls(domain, TlsVerify::WebPki(webpki))
404 }
405
406 pub fn client_cert(
427 mut self,
428 cert_chain: Vec<CertificateDer<'static>>,
429 key_der: PrivateKeyDer<'static>,
430 ) -> Self {
431 self.client_cert = Some(ClientCert {
432 cert_chain,
433 key_der,
434 });
435 self
436 }
437
438 pub async fn connect(self) -> Result<Stream, Error> {
453 let tls = if let Some(mut params) = self.tls {
454 params.domain = params.domain.or_else(|| match &self.base {
455 BaseParams::Tcp(Ok(TargetAddr::Ip(addr))) => Some(ServerName::from(addr.ip())),
456 BaseParams::Tcp(Ok(TargetAddr::Domain(d, _))) => {
457 ServerName::try_from(d.as_ref()).map(|s| s.to_owned()).ok()
458 }
459 _ => None,
460 });
461 Some(params)
462 } else {
463 None
464 };
465 let stream = if let Some(params) = self.socks {
466 let BaseParams::Tcp(target) = self.base else {
467 return Err(Error::SocksToUnsupported);
468 };
469 let target = target.map_err(Error::InvalidTarget)?;
470 let stream = BaseStream::Tcp {
471 inner: TcpStream::connect(params.proxy).await?,
472 };
473 match params.version {
474 SocksVersion::Socks4 => MaybeSocks::Socks4 {
475 inner: if let Some(SocksAuth { username, .. }) = params.auth {
476 Socks4Stream::connect_with_userid_and_socket(stream, target, username)
477 .await?
478 } else {
479 Socks4Stream::connect_with_socket(stream, target).await?
480 },
481 },
482 SocksVersion::Socks5 => MaybeSocks::Socks5 {
483 inner: if let Some(SocksAuth { username, password }) = params.auth {
484 Socks5Stream::connect_with_password_and_socket(
485 stream, target, username, password,
486 )
487 .await?
488 } else {
489 Socks5Stream::connect_with_socket(stream, target).await?
490 },
491 },
492 }
493 } else {
494 let stream = match self.base {
495 BaseParams::Tcp(addr) => {
496 let inner = match addr.map_err(Error::InvalidTarget)? {
499 TargetAddr::Ip(addr) => TcpStream::connect(addr).await?,
500 TargetAddr::Domain(domain, port) => {
501 TcpStream::connect((domain.as_ref(), port)).await?
502 }
503 };
504 BaseStream::Tcp { inner }
505 }
506 BaseParams::Unix(path) => BaseStream::Unix {
507 inner: UnixStream::connect(path).await?,
508 },
509 };
510 MaybeSocks::Clear { inner: stream }
511 };
512 let stream = if let Some(params) = tls {
513 let config = ClientConfig::builder();
514 let config = match params.verification {
515 TlsVerify::Insecure => {
516 let provider = config.crypto_provider().clone();
517 config
518 .dangerous()
519 .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
520 }
521 TlsVerify::CaStore(root) => config.with_root_certificates(root),
522 TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
523 };
524 let config = if let Some(ClientCert {
525 cert_chain,
526 key_der,
527 }) = self.client_cert
528 {
529 config.with_client_auth_cert(cert_chain, key_der)?
530 } else {
531 config.with_no_client_auth()
532 };
533 let connector = TlsConnector::from(Arc::new(config));
534 let domain = params.domain.ok_or(Error::NoServerName)?;
535 let inner = connector.connect(domain, stream).await?;
536 MaybeTls::Tls { inner }
537 } else {
538 if self.client_cert.is_some() {
539 return Err(Error::ClientCertNoTls);
540 }
541 MaybeTls::Plain { inner: stream }
542 };
543 Ok(Stream { inner: stream })
544 }
545}
546
/// Final destination of a connection, before any proxy or TLS layering.
#[derive(Debug)]
enum BaseParams<'a> {
    /// TCP destination. Holds the *result* of the address conversion so that a
    /// conversion failure can be reported later by `StreamBuilder::connect` (as
    /// `Error::InvalidTarget`) instead of by the infallible builder constructor.
    Tcp(tokio_socks::Result<TargetAddr<'a>>),
    /// Path of a Unix domain socket.
    Unix(&'a Path),
}
552
/// SOCKS proxy settings collected by the builder.
///
/// `Debug` is implemented by hand for this type and prints only `version`.
struct SocksParams<'a> {
    /// Which SOCKS protocol to speak to the proxy.
    version: SocksVersion,
    /// Address of the proxy itself (not the final destination).
    proxy: SocketAddr,
    /// Optional credentials to send to the proxy.
    auth: Option<SocksAuth<'a>>,
}
558
559impl fmt::Debug for SocksParams<'_> {
560 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
561 fmt::Debug::fmt(&self.version, f)
562 }
563}
564
/// Credentials sent to a SOCKS proxy.
///
/// NOTE(review): there is no `Debug` derive here — presumably to keep credentials
/// out of debug output; confirm before adding one.
struct SocksAuth<'a> {
    /// SOCKS5 user name, or the SOCKS4 user id.
    username: &'a str,
    /// SOCKS5 password. The SOCKS4 path in `StreamBuilder::connect` ignores it.
    password: &'a str,
}
569
/// SOCKS protocol version used to talk to the proxy.
#[derive(Debug)]
enum SocksVersion {
    /// SOCKS4; selected by `socks4` and `socks4_with_userid`.
    Socks4,
    /// SOCKS5; selected by `socks5` and `socks5_with_password`.
    Socks5,
}
575
/// TLS settings collected by the builder.
#[derive(Debug)]
struct TlsParams {
    /// Server name to use for the handshake. When `None`, `StreamBuilder::connect`
    /// tries to derive one from the TCP target and fails with
    /// `Error::NoServerName` if it cannot.
    domain: Option<ServerName<'static>>,
    /// How the server certificate is verified.
    verification: TlsVerify,
}
581
/// Strategy for verifying the server's TLS certificate.
#[derive(Debug)]
enum TlsVerify {
    /// Use the crate's `danger::PhonyVerify` verifier.
    /// NOTE(review): presumably accepts any certificate — confirm in `danger`.
    Insecure,
    /// Verify against the given root certificate store.
    CaStore(Arc<RootCertStore>),
    /// Verify with a caller-supplied WebPKI verifier.
    WebPki(Arc<WebPkiServerVerifier>),
}
588
/// Client certificate material presented during the TLS handshake.
#[derive(Debug)]
struct ClientCert {
    /// Certificate chain passed to rustls' `with_client_auth_cert`.
    cert_chain: Vec<CertificateDer<'static>>,
    /// Private key matching the certificate chain.
    key_der: PrivateKeyDer<'static>,
}