1use pin_project_lite::pin_project;
4use std::{
5 fmt,
6 net::SocketAddr,
7 path::Path,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12use tokio::{
13 io::{AsyncRead, AsyncWrite, ReadBuf},
14 net::{TcpStream, UnixStream},
15};
16use tokio_rustls::{
17 client::TlsStream,
18 rustls::{
19 client::WebPkiServerVerifier,
20 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
21 ClientConfig, RootCertStore,
22 },
23 TlsConnector,
24};
25use tokio_socks::{
26 tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
27 IntoTargetAddr, TargetAddr,
28};
29
30pub use tokio_rustls;
31
32mod danger;
33
// Errors returned by `StreamBuilder::connect`.
//
// NOTE(review): plain `//` comments are used on purpose — the FoxError derive
// may derive `Display` messages from `///` doc comments, so switching these to
// doc comments could change the error text. Confirm before converting.
#[derive(Debug, foxerror::FoxError)]
#[non_exhaustive]
pub enum Error {
    // A client certificate was configured, but TLS was not enabled.
    ClientCertNoTls,
    // I/O failure while opening the TCP/Unix socket or during the TLS handshake.
    #[err(from)]
    Connect(std::io::Error),
    // The SOCKS target address was invalid or the SOCKS handshake failed.
    #[err(from)]
    Socks(tokio_socks::Error),
    // rustls rejected the TLS client configuration (e.g. the client cert/key).
    #[err(from)]
    Rustls(tokio_rustls::rustls::Error),
}
50
pin_project! {
    /// An async byte stream over TCP or a Unix-domain socket, optionally
    /// tunnelled through a SOCKS proxy and optionally wrapped in TLS.
    ///
    /// Create one with [`Stream::new_tcp`] or [`Stream::new_unix`] and finish
    /// with [`StreamBuilder::connect`].
    #[derive(Debug)]
    pub struct Stream {
        // Layered transport: TLS-or-plain over SOCKS-or-clear over TCP/Unix.
        #[pin]
        inner: MaybeTls,
    }
}
59
60impl Stream {
61 pub fn new_tcp(addr: &SocketAddr) -> StreamBuilder<'_> {
71 StreamBuilder::new(BaseParams::Tcp(addr))
72 }
73 pub fn new_unix(path: &Path) -> StreamBuilder<'_> {
84 StreamBuilder::new(BaseParams::Unix(path))
85 }
86}
87
88impl AsyncRead for Stream {
89 #[inline]
90 fn poll_read(
91 self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 buf: &mut ReadBuf<'_>,
94 ) -> Poll<std::io::Result<()>> {
95 self.project().inner.poll_read(cx, buf)
96 }
97}
98
99impl AsyncWrite for Stream {
100 #[inline]
101 fn poll_write(
102 self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 buf: &[u8],
105 ) -> Poll<Result<usize, std::io::Error>> {
106 self.project().inner.poll_write(cx, buf)
107 }
108 #[inline]
109 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
110 self.project().inner.poll_flush(cx)
111 }
112 #[inline]
113 fn poll_shutdown(
114 self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 ) -> Poll<Result<(), std::io::Error>> {
117 self.project().inner.poll_shutdown(cx)
118 }
119}
120
pin_project! {
    // Optional TLS layer on top of the (possibly SOCKS-proxied) connection.
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    // NOTE(review): the `Tls` variant is much larger than `Plain`; the lint is
    // silenced instead of boxing the stream, presumably because only one value
    // exists per connection so the wasted space is negligible — confirm.
    #[allow(clippy::large_enum_variant)]
    enum MaybeTls {
        // No TLS: bytes pass straight through to the SOCKS/clear layer.
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // TLS client session running over the SOCKS/clear layer.
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
136
137impl AsyncRead for MaybeTls {
138 #[inline]
139 fn poll_read(
140 self: Pin<&mut Self>,
141 cx: &mut Context,
142 buf: &mut ReadBuf<'_>,
143 ) -> Poll<std::io::Result<()>> {
144 match self.project() {
145 MaybeTlsProj::Plain { inner } => inner.poll_read(cx, buf),
146 MaybeTlsProj::Tls { inner } => inner.poll_read(cx, buf),
147 }
148 }
149}
150
151impl AsyncWrite for MaybeTls {
152 #[inline]
153 fn poll_write(
154 self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 buf: &[u8],
157 ) -> Poll<Result<usize, std::io::Error>> {
158 match self.project() {
159 MaybeTlsProj::Plain { inner } => inner.poll_write(cx, buf),
160 MaybeTlsProj::Tls { inner } => inner.poll_write(cx, buf),
161 }
162 }
163 #[inline]
164 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
165 match self.project() {
166 MaybeTlsProj::Plain { inner } => inner.poll_flush(cx),
167 MaybeTlsProj::Tls { inner } => inner.poll_flush(cx),
168 }
169 }
170 #[inline]
171 fn poll_shutdown(
172 self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 ) -> Poll<Result<(), std::io::Error>> {
175 match self.project() {
176 MaybeTlsProj::Plain { inner } => inner.poll_shutdown(cx),
177 MaybeTlsProj::Tls { inner } => inner.poll_shutdown(cx),
178 }
179 }
180}
181
pin_project! {
    // Optional SOCKS proxy layer on top of the base TCP/Unix socket.
    #[project = MaybeSocksProj]
    #[derive(Debug)]
    enum MaybeSocks {
        // No proxy: the base socket is used directly.
        Clear {
            #[pin]
            inner: BaseStream,
        },
        // Connection tunnelled through a SOCKS4 proxy.
        Socks4 {
            #[pin]
            inner: Socks4Stream<BaseStream>,
        },
        // Connection tunnelled through a SOCKS5 proxy.
        Socks5 {
            #[pin]
            inner: Socks5Stream<BaseStream>,
        },
    }
}
200
201impl AsyncRead for MaybeSocks {
202 #[inline]
203 fn poll_read(
204 self: Pin<&mut Self>,
205 cx: &mut Context<'_>,
206 buf: &mut ReadBuf<'_>,
207 ) -> Poll<std::io::Result<()>> {
208 match self.project() {
209 MaybeSocksProj::Clear { inner } => inner.poll_read(cx, buf),
210 MaybeSocksProj::Socks4 { inner } => inner.poll_read(cx, buf),
211 MaybeSocksProj::Socks5 { inner } => inner.poll_read(cx, buf),
212 }
213 }
214}
215
216impl AsyncWrite for MaybeSocks {
217 #[inline]
218 fn poll_write(
219 self: Pin<&mut Self>,
220 cx: &mut Context<'_>,
221 buf: &[u8],
222 ) -> Poll<Result<usize, std::io::Error>> {
223 match self.project() {
224 MaybeSocksProj::Clear { inner } => inner.poll_write(cx, buf),
225 MaybeSocksProj::Socks4 { inner } => inner.poll_write(cx, buf),
226 MaybeSocksProj::Socks5 { inner } => inner.poll_write(cx, buf),
227 }
228 }
229 #[inline]
230 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
231 match self.project() {
232 MaybeSocksProj::Clear { inner } => inner.poll_flush(cx),
233 MaybeSocksProj::Socks4 { inner } => inner.poll_flush(cx),
234 MaybeSocksProj::Socks5 { inner } => inner.poll_flush(cx),
235 }
236 }
237 #[inline]
238 fn poll_shutdown(
239 self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 ) -> Poll<Result<(), std::io::Error>> {
242 match self.project() {
243 MaybeSocksProj::Clear { inner } => inner.poll_shutdown(cx),
244 MaybeSocksProj::Socks4 { inner } => inner.poll_shutdown(cx),
245 MaybeSocksProj::Socks5 { inner } => inner.poll_shutdown(cx),
246 }
247 }
248}
249
pin_project! {
    // The raw transport socket every other layer is built on.
    #[project = BaseStreamProj]
    #[derive(Debug)]
    enum BaseStream {
        // TCP connection opened with `TcpStream::connect`.
        Tcp {
            #[pin]
            inner: TcpStream,
        },
        // Unix-domain connection opened with `UnixStream::connect`.
        Unix {
            #[pin]
            inner: UnixStream,
        },
    }
}
264
265impl AsyncRead for BaseStream {
266 #[inline]
267 fn poll_read(
268 self: Pin<&mut Self>,
269 cx: &mut Context,
270 buf: &mut ReadBuf<'_>,
271 ) -> Poll<std::io::Result<()>> {
272 match self.project() {
273 BaseStreamProj::Tcp { inner } => inner.poll_read(cx, buf),
274 BaseStreamProj::Unix { inner } => inner.poll_read(cx, buf),
275 }
276 }
277}
278
279impl AsyncWrite for BaseStream {
280 #[inline]
281 fn poll_write(
282 self: Pin<&mut Self>,
283 cx: &mut Context<'_>,
284 buf: &[u8],
285 ) -> Poll<Result<usize, std::io::Error>> {
286 match self.project() {
287 BaseStreamProj::Tcp { inner } => inner.poll_write(cx, buf),
288 BaseStreamProj::Unix { inner } => inner.poll_write(cx, buf),
289 }
290 }
291 #[inline]
292 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
293 match self.project() {
294 BaseStreamProj::Tcp { inner } => inner.poll_flush(cx),
295 BaseStreamProj::Unix { inner } => inner.poll_flush(cx),
296 }
297 }
298 #[inline]
299 fn poll_shutdown(
300 self: Pin<&mut Self>,
301 cx: &mut Context<'_>,
302 ) -> Poll<Result<(), std::io::Error>> {
303 match self.project() {
304 BaseStreamProj::Tcp { inner } => inner.poll_shutdown(cx),
305 BaseStreamProj::Unix { inner } => inner.poll_shutdown(cx),
306 }
307 }
308}
309
310#[derive(Debug)]
312pub struct StreamBuilder<'a> {
313 base: BaseParams<'a>,
314 socks: Option<SocksParams<'a>>,
315 tls: Option<TlsParams>,
316 client_cert: Option<ClientCert>,
317}
318
319impl<'a> StreamBuilder<'a> {
320 fn new(base: BaseParams<'a>) -> Self {
321 Self {
322 base,
323 socks: None,
324 tls: None,
325 client_cert: None,
326 }
327 }
328
329 fn socks(
330 mut self,
331 version: SocksVersion,
332 target: impl IntoTargetAddr<'a>,
333 auth: Option<SocksAuth<'a>>,
334 ) -> Self {
335 self.socks = Some(SocksParams {
336 version,
337 target: target.into_target_addr(),
338 auth,
339 });
340 self
341 }
342
343 pub fn socks4(self, target: impl IntoTargetAddr<'a>) -> Self {
355 self.socks(SocksVersion::Socks4, target, None)
356 }
357
358 pub fn socks4_with_userid(self, target: impl IntoTargetAddr<'a>, userid: &'a str) -> Self {
370 self.socks(
371 SocksVersion::Socks4,
372 target,
373 Some(SocksAuth {
374 username: userid,
375 password: "h",
376 }),
377 )
378 }
379
380 pub fn socks5(self, target: impl IntoTargetAddr<'a>) -> Self {
392 self.socks(SocksVersion::Socks5, target, None)
393 }
394
395 pub fn socks5_with_password(
407 self,
408 target: impl IntoTargetAddr<'a>,
409 username: &'a str,
410 password: &'a str,
411 ) -> Self {
412 self.socks(
413 SocksVersion::Socks5,
414 target,
415 Some(SocksAuth { username, password }),
416 )
417 }
418
419 fn tls(mut self, domain: ServerName<'static>, verification: TlsVerify) -> Self {
420 self.tls = Some(TlsParams {
421 domain,
422 verification,
423 });
424 self
425 }
426
427 pub fn tls_danger_insecure(self, domain: ServerName<'static>) -> Self {
440 self.tls(domain, TlsVerify::Insecure)
441 }
442
443 pub fn tls_with_root(
466 self,
467 domain: ServerName<'static>,
468 root: impl Into<Arc<RootCertStore>>,
469 ) -> Self {
470 self.tls(domain, TlsVerify::CaStore(root.into()))
471 }
472
473 pub fn tls_with_webpki(
475 self,
476 domain: ServerName<'static>,
477 webpki: Arc<WebPkiServerVerifier>,
478 ) -> Self {
479 self.tls(domain, TlsVerify::WebPki(webpki))
480 }
481
482 pub fn client_cert(
503 mut self,
504 cert_chain: Vec<CertificateDer<'static>>,
505 key_der: PrivateKeyDer<'static>,
506 ) -> Self {
507 self.client_cert = Some(ClientCert {
508 cert_chain,
509 key_der,
510 });
511 self
512 }
513
514 pub async fn connect(self) -> Result<Stream, Error> {
526 let stream = match self.base {
527 BaseParams::Tcp(addr) => BaseStream::Tcp {
528 inner: TcpStream::connect(addr).await?,
529 },
530 BaseParams::Unix(path) => BaseStream::Unix {
531 inner: UnixStream::connect(path).await?,
532 },
533 };
534 let stream = if let Some(params) = self.socks {
535 let target = params.target?;
536 match params.version {
537 SocksVersion::Socks4 => MaybeSocks::Socks4 {
538 inner: if let Some(SocksAuth { username, .. }) = params.auth {
539 Socks4Stream::connect_with_userid_and_socket(stream, target, username)
540 .await?
541 } else {
542 Socks4Stream::connect_with_socket(stream, target).await?
543 },
544 },
545 SocksVersion::Socks5 => MaybeSocks::Socks5 {
546 inner: if let Some(SocksAuth { username, password }) = params.auth {
547 Socks5Stream::connect_with_password_and_socket(
548 stream, target, username, password,
549 )
550 .await?
551 } else {
552 Socks5Stream::connect_with_socket(stream, target).await?
553 },
554 },
555 }
556 } else {
557 MaybeSocks::Clear { inner: stream }
558 };
559 let stream = if let Some(params) = self.tls {
560 let config = ClientConfig::builder();
561 let config = match params.verification {
562 TlsVerify::Insecure => {
563 let provider = config.crypto_provider().clone();
564 config
565 .dangerous()
566 .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
567 }
568 TlsVerify::CaStore(root) => config.with_root_certificates(root),
569 TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
570 };
571 let config = if let Some(ClientCert {
572 cert_chain,
573 key_der,
574 }) = self.client_cert
575 {
576 config.with_client_auth_cert(cert_chain, key_der)?
577 } else {
578 config.with_no_client_auth()
579 };
580 let connector = TlsConnector::from(Arc::new(config));
581 let inner = connector.connect(params.domain, stream).await?;
582 MaybeTls::Tls { inner }
583 } else {
584 if self.client_cert.is_some() {
585 return Err(Error::ClientCertNoTls);
586 }
587 MaybeTls::Plain { inner: stream }
588 };
589 Ok(Stream { inner: stream })
590 }
591}
592
/// Underlying transport to open before any SOCKS/TLS layering.
#[derive(Debug)]
enum BaseParams<'a> {
    /// Connect over TCP to this socket address.
    Tcp(&'a SocketAddr),
    /// Connect to the Unix-domain socket at this filesystem path.
    Unix(&'a Path),
}
599
/// SOCKS handshake requested through the builder.
struct SocksParams<'a> {
    /// Protocol version to speak to the proxy.
    version: SocksVersion,
    /// Destination the proxy should connect to. A conversion error is kept
    /// here and surfaced by `StreamBuilder::connect`.
    target: tokio_socks::Result<TargetAddr<'a>>,
    /// Optional credentials; for SOCKS4 only `username` is used.
    auth: Option<SocksAuth<'a>>,
}
605
606impl fmt::Debug for SocksParams<'_> {
607 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
608 fmt::Debug::fmt(&self.version, f)
609 }
610}
611
/// Credentials sent during the SOCKS handshake.
struct SocksAuth<'a> {
    /// SOCKS5 username, or the SOCKS4 user ID.
    username: &'a str,
    /// SOCKS5 password; not sent for SOCKS4.
    password: &'a str,
}
616
/// SOCKS protocol version used for the proxy handshake.
#[derive(Debug)]
enum SocksVersion {
    /// SOCKS4, optionally with a user ID.
    Socks4,
    /// SOCKS5, optionally with username/password authentication.
    Socks5,
}
622
/// TLS settings requested through the builder.
#[derive(Debug)]
struct TlsParams {
    /// Server name passed to the TLS handshake.
    domain: ServerName<'static>,
    /// How the server certificate is verified.
    verification: TlsVerify,
}
628
/// Server-certificate verification strategy.
#[derive(Debug)]
enum TlsVerify {
    /// Use `danger::PhonyVerify` as the verifier instead of real verification.
    Insecure,
    /// Verify against the trust anchors in this root certificate store.
    CaStore(Arc<RootCertStore>),
    /// Verify with this pre-configured WebPKI verifier.
    WebPki(Arc<WebPkiServerVerifier>),
}
635
/// Client certificate presented during the TLS handshake.
#[derive(Debug)]
struct ClientCert {
    /// DER-encoded certificate chain handed to rustls (end-entity first, per
    /// rustls' `with_client_auth_cert` contract).
    cert_chain: Vec<CertificateDer<'static>>,
    /// Private key for the end-entity certificate.
    key_der: PrivateKeyDer<'static>,
}