1use pin_project_lite::pin_project;
4use std::{
5 fmt,
6 net::SocketAddr,
7 path::Path,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12use tokio::{
13 io::{AsyncRead, AsyncWrite, ReadBuf},
14 net::{TcpStream, UnixStream},
15};
16use tokio_rustls::{
17 client::TlsStream,
18 rustls::{
19 client::WebPkiServerVerifier,
20 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
21 ClientConfig, RootCertStore,
22 },
23 TlsConnector,
24};
25use tokio_socks::{
26 tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
27 IntoTargetAddr, TargetAddr,
28};
29
30pub use tokio_rustls;
31
32mod danger;
33
// Errors returned by `StreamBuilder::connect`.
// NOTE(review): plain `//` comments are used on this enum because the `FoxError`
// derive may turn `///` doc comments into Display messages — confirm before
// converting these into rustdoc.
#[derive(Debug, foxerror::FoxError)]
#[non_exhaustive]
pub enum Error {
    // A client certificate was configured but TLS was not enabled.
    ClientCertNoTls,
    // Opening the base socket, or the TLS handshake, failed with an I/O error.
    #[err(from)]
    Connect(std::io::Error),
    // The SOCKS target address was invalid or the proxy handshake failed.
    #[err(from)]
    Socks(tokio_socks::Error),
    // rustls rejected the client configuration (raised by `with_client_auth_cert`,
    // e.g. for an unusable client key).
    #[err(from)]
    Rustls(tokio_rustls::rustls::Error),
}
50
pin_project! {
    #[derive(Debug)]
    /// An established client connection: a TCP or Unix-domain socket, optionally
    /// tunnelled through a SOCKS proxy and optionally wrapped in TLS.
    ///
    /// Created with [`Stream::new_tcp`] or [`Stream::new_unix`] followed by
    /// [`StreamBuilder::connect`]. Implements [`AsyncRead`] and [`AsyncWrite`] by
    /// forwarding to the innermost active layer.
    pub struct Stream {
        // The layered transport; pinned so poll calls can be forwarded through it.
        #[pin]
        inner: MaybeTls,
    }
}
59
60impl Stream {
61 pub fn new_tcp(addr: &SocketAddr) -> StreamBuilder<'_> {
71 StreamBuilder::new(BaseParams::Tcp(addr))
72 }
73 pub fn new_unix(path: &Path) -> StreamBuilder<'_> {
84 StreamBuilder::new(BaseParams::Unix(path))
85 }
86}
87
88impl AsyncRead for Stream {
89 #[inline]
90 fn poll_read(
91 self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 buf: &mut ReadBuf<'_>,
94 ) -> Poll<std::io::Result<()>> {
95 self.project().inner.poll_read(cx, buf)
96 }
97}
98
99impl AsyncWrite for Stream {
100 #[inline]
101 fn poll_write(
102 self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 buf: &[u8],
105 ) -> Poll<Result<usize, std::io::Error>> {
106 self.project().inner.poll_write(cx, buf)
107 }
108 #[inline]
109 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
110 self.project().inner.poll_flush(cx)
111 }
112 #[inline]
113 fn poll_shutdown(
114 self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 ) -> Poll<Result<(), std::io::Error>> {
117 self.project().inner.poll_shutdown(cx)
118 }
119}
120
pin_project! {
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    // clippy flags the size gap between `Tls` (which holds the rustls client
    // session) and `Plain`; the lint is silenced and the variant left unboxed.
    // NOTE(review): presumably to avoid an extra heap allocation per connection — confirm.
    #[allow(clippy::large_enum_variant)]
    // Outermost layer: the SOCKS/plain stream used as-is, or wrapped in a TLS client session.
    enum MaybeTls {
        // No TLS configured.
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // TLS client session running over the (possibly proxied) stream.
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
136
/// Implements `AsyncRead` and `AsyncWrite` for a `pin_project!` enum in which every
/// variant holds a single `#[pin] inner` field, forwarding each poll method to the
/// active variant's `inner`.
///
/// `$target` is the enum type. Each `$arm` is a variant path of that enum's
/// *projection* type (e.g. `MaybeTlsProj::Plain`). The list must name every variant,
/// otherwise the generated `match` is not exhaustive and will not compile.
// NOTE(review): only `poll_write`, `poll_flush` and `poll_shutdown` are forwarded;
// `poll_write_vectored` / `is_write_vectored` fall back to the trait defaults.
macro_rules! trivial_impl {
    ($target:ty, ($($arm:path),*)) => {
        impl AsyncRead for $target {
            #[inline]
            fn poll_read(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<std::io::Result<()>> {
                // Project to pinned references to the active variant's field.
                match self.project() {
                    $($arm { inner } => inner.poll_read(cx, buf),)*
                }
            }
        }

        impl AsyncWrite for $target {
            #[inline]
            fn poll_write(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<Result<usize, std::io::Error>> {
                match self.project() {
                    $($arm { inner } => inner.poll_write(cx, buf),)*
                }
            }
            #[inline]
            fn poll_flush(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                match self.project() {
                    $($arm { inner } => inner.poll_flush(cx),)*
                }
            }
            #[inline]
            fn poll_shutdown(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                match self.project() {
                    $($arm { inner } => inner.poll_shutdown(cx),)*
                }
            }
        }
    };
}
184
185trivial_impl!(MaybeTls, (MaybeTlsProj::Plain, MaybeTlsProj::Tls));
186
pin_project! {
    #[project = MaybeSocksProj]
    #[derive(Debug)]
    // Middle layer: the base socket used directly, or after a SOCKS4/SOCKS5
    // handshake with a proxy reached through that socket.
    enum MaybeSocks {
        // No proxy configured.
        Clear {
            #[pin]
            inner: BaseStream,
        },
        // Tunnelled through a SOCKS4 proxy.
        Socks4 {
            #[pin]
            inner: Socks4Stream<BaseStream>,
        },
        // Tunnelled through a SOCKS5 proxy.
        Socks5 {
            #[pin]
            inner: Socks5Stream<BaseStream>,
        },
    }
}
205
206trivial_impl!(
207 MaybeSocks,
208 (
209 MaybeSocksProj::Clear,
210 MaybeSocksProj::Socks4,
211 MaybeSocksProj::Socks5
212 )
213);
214
pin_project! {
    #[project = BaseStreamProj]
    #[derive(Debug)]
    // Innermost layer: the raw socket opened by `StreamBuilder::connect`.
    enum BaseStream {
        // Opened with `TcpStream::connect`.
        Tcp {
            #[pin]
            inner: TcpStream,
        },
        // Opened with `UnixStream::connect`.
        Unix {
            #[pin]
            inner: UnixStream,
        },
    }
}
229
230trivial_impl!(BaseStream, (BaseStreamProj::Tcp, BaseStreamProj::Unix));
231
232#[derive(Debug)]
234pub struct StreamBuilder<'a> {
235 base: BaseParams<'a>,
236 socks: Option<SocksParams<'a>>,
237 tls: Option<TlsParams>,
238 client_cert: Option<ClientCert>,
239}
240
241impl<'a> StreamBuilder<'a> {
242 fn new(base: BaseParams<'a>) -> Self {
243 Self {
244 base,
245 socks: None,
246 tls: None,
247 client_cert: None,
248 }
249 }
250
251 fn socks(
252 mut self,
253 version: SocksVersion,
254 target: impl IntoTargetAddr<'a>,
255 auth: Option<SocksAuth<'a>>,
256 ) -> Self {
257 self.socks = Some(SocksParams {
258 version,
259 target: target.into_target_addr(),
260 auth,
261 });
262 self
263 }
264
265 #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
277 pub fn socks4(self, target: impl IntoTargetAddr<'a>) -> Self {
278 self.socks(SocksVersion::Socks4, target, None)
279 }
280
281 #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
293 pub fn socks4_with_userid(self, target: impl IntoTargetAddr<'a>, userid: &'a str) -> Self {
294 self.socks(
295 SocksVersion::Socks4,
296 target,
297 Some(SocksAuth {
298 username: userid,
299 password: "h",
300 }),
301 )
302 }
303
304 #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
316 pub fn socks5(self, target: impl IntoTargetAddr<'a>) -> Self {
317 self.socks(SocksVersion::Socks5, target, None)
318 }
319
320 #[deprecated(note = "the current behavior is unintentional and will be replaced for v0.2.0")]
332 pub fn socks5_with_password(
333 self,
334 target: impl IntoTargetAddr<'a>,
335 username: &'a str,
336 password: &'a str,
337 ) -> Self {
338 self.socks(
339 SocksVersion::Socks5,
340 target,
341 Some(SocksAuth { username, password }),
342 )
343 }
344
345 fn tls(mut self, domain: ServerName<'static>, verification: TlsVerify) -> Self {
346 self.tls = Some(TlsParams {
347 domain,
348 verification,
349 });
350 self
351 }
352
353 pub fn tls_danger_insecure(self, domain: ServerName<'static>) -> Self {
366 self.tls(domain, TlsVerify::Insecure)
367 }
368
369 pub fn tls_with_root(
392 self,
393 domain: ServerName<'static>,
394 root: impl Into<Arc<RootCertStore>>,
395 ) -> Self {
396 self.tls(domain, TlsVerify::CaStore(root.into()))
397 }
398
399 pub fn tls_with_webpki(
401 self,
402 domain: ServerName<'static>,
403 webpki: Arc<WebPkiServerVerifier>,
404 ) -> Self {
405 self.tls(domain, TlsVerify::WebPki(webpki))
406 }
407
408 pub fn client_cert(
429 mut self,
430 cert_chain: Vec<CertificateDer<'static>>,
431 key_der: PrivateKeyDer<'static>,
432 ) -> Self {
433 self.client_cert = Some(ClientCert {
434 cert_chain,
435 key_der,
436 });
437 self
438 }
439
440 pub async fn connect(self) -> Result<Stream, Error> {
452 let stream = match self.base {
453 BaseParams::Tcp(addr) => BaseStream::Tcp {
454 inner: TcpStream::connect(addr).await?,
455 },
456 BaseParams::Unix(path) => BaseStream::Unix {
457 inner: UnixStream::connect(path).await?,
458 },
459 };
460 let stream = if let Some(params) = self.socks {
461 let target = params.target?;
462 match params.version {
463 SocksVersion::Socks4 => MaybeSocks::Socks4 {
464 inner: if let Some(SocksAuth { username, .. }) = params.auth {
465 Socks4Stream::connect_with_userid_and_socket(stream, target, username)
466 .await?
467 } else {
468 Socks4Stream::connect_with_socket(stream, target).await?
469 },
470 },
471 SocksVersion::Socks5 => MaybeSocks::Socks5 {
472 inner: if let Some(SocksAuth { username, password }) = params.auth {
473 Socks5Stream::connect_with_password_and_socket(
474 stream, target, username, password,
475 )
476 .await?
477 } else {
478 Socks5Stream::connect_with_socket(stream, target).await?
479 },
480 },
481 }
482 } else {
483 MaybeSocks::Clear { inner: stream }
484 };
485 let stream = if let Some(params) = self.tls {
486 let config = ClientConfig::builder();
487 let config = match params.verification {
488 TlsVerify::Insecure => {
489 let provider = config.crypto_provider().clone();
490 config
491 .dangerous()
492 .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
493 }
494 TlsVerify::CaStore(root) => config.with_root_certificates(root),
495 TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
496 };
497 let config = if let Some(ClientCert {
498 cert_chain,
499 key_der,
500 }) = self.client_cert
501 {
502 config.with_client_auth_cert(cert_chain, key_der)?
503 } else {
504 config.with_no_client_auth()
505 };
506 let connector = TlsConnector::from(Arc::new(config));
507 let inner = connector.connect(params.domain, stream).await?;
508 MaybeTls::Tls { inner }
509 } else {
510 if self.client_cert.is_some() {
511 return Err(Error::ClientCertNoTls);
512 }
513 MaybeTls::Plain { inner: stream }
514 };
515 Ok(Stream { inner: stream })
516 }
517}
518
/// Where the base socket connects.
#[derive(Debug)]
enum BaseParams<'a> {
    /// TCP socket address, passed to `TcpStream::connect`.
    Tcp(&'a SocketAddr),
    /// Unix-domain socket path, passed to `UnixStream::connect`.
    Unix(&'a Path),
}
525
/// SOCKS handshake settings recorded by the builder.
struct SocksParams<'a> {
    /// Protocol version spoken to the proxy.
    version: SocksVersion,
    /// Destination the proxy should connect to. An `IntoTargetAddr` conversion error
    /// is kept here and returned by `StreamBuilder::connect`.
    target: tokio_socks::Result<TargetAddr<'a>>,
    /// Optional credentials. SOCKS4 uses only `username` (as the user id); SOCKS5 uses both.
    auth: Option<SocksAuth<'a>>,
}
531
532impl fmt::Debug for SocksParams<'_> {
533 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534 fmt::Debug::fmt(&self.version, f)
535 }
536}
537
/// Credentials for the SOCKS proxy.
// NOTE(review): intentionally has no `Debug` derive (and `SocksParams`' Debug skips it),
// presumably so the password never appears in debug output — confirm.
struct SocksAuth<'a> {
    /// SOCKS5 username, or the SOCKS4 user id.
    username: &'a str,
    /// SOCKS5 password; unused for SOCKS4.
    password: &'a str,
}
542
/// SOCKS protocol version used for the proxy handshake.
#[derive(Debug)]
enum SocksVersion {
    /// Handshake via `tokio_socks`' `Socks4Stream`.
    Socks4,
    /// Handshake via `tokio_socks`' `Socks5Stream`.
    Socks5,
}
548
/// TLS settings recorded by the builder.
#[derive(Debug)]
struct TlsParams {
    /// Server name handed to `TlsConnector::connect` for the handshake.
    domain: ServerName<'static>,
    /// How the server certificate is checked.
    verification: TlsVerify,
}
554
/// Server-certificate verification strategy.
#[derive(Debug)]
enum TlsVerify {
    /// Uses `danger::PhonyVerify` via rustls' `dangerous()` API; no trust store is consulted.
    Insecure,
    /// Verifies against the given root certificate store.
    CaStore(Arc<RootCertStore>),
    /// Verifies with a caller-configured webpki verifier.
    WebPki(Arc<WebPkiServerVerifier>),
}
561
/// Client certificate chain and private key for TLS client authentication.
#[derive(Debug)]
struct ClientCert {
    /// Certificate chain passed to rustls' `with_client_auth_cert`.
    cert_chain: Vec<CertificateDer<'static>>,
    /// Private key matching the chain's end-entity certificate.
    key_der: PrivateKeyDer<'static>,
}