1use pin_project_lite::pin_project;
8use std::{
9 fmt,
10 net::SocketAddr,
11 path::Path,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15};
16use tokio::{
17 io::{AsyncRead, AsyncWrite, ReadBuf},
18 net::{TcpStream, UnixStream},
19};
20use tokio_rustls::{
21 client::TlsStream,
22 rustls::{
23 client::WebPkiServerVerifier,
24 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
25 ClientConfig, RootCertStore,
26 },
27 TlsConnector,
28};
29use tokio_socks::{
30 tcp::{socks4::Socks4Stream, socks5::Socks5Stream},
31 IntoTargetAddr, TargetAddr,
32};
33
34pub use tokio_rustls;
35
36mod danger;
37
/// Former name of [`Connection`], kept as an alias for backward compatibility.
#[deprecated(since = "0.2.1", note = "Stream was renamed to Connection")]
pub type Stream = Connection;
/// Former name of [`ConnectionBuilder`], kept as an alias for backward compatibility.
#[deprecated(
    since = "0.2.1",
    note = "StreamBuilder was renamed to ConnectionBuilder"
)]
pub type StreamBuilder<'a> = ConnectionBuilder<'a>;
45
/// Errors returned by [`ConnectionBuilder::connect`].
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// A client certificate was configured, but TLS was not enabled.
    ClientCertNoTls,
    /// An I/O error while connecting (to the target or to the SOCKS proxy),
    /// including errors during the TLS handshake.
    Connect(std::io::Error),
    /// The SOCKS proxy handshake failed.
    Socks(tokio_socks::Error),
    /// rustls rejected the configured client certificate chain / private key.
    Rustls(tokio_rustls::rustls::Error),
    /// A SOCKS proxy was configured for a unix socket target.
    SocksToUnsupported,
    /// The TCP target could not be converted into a target address.
    InvalidTarget(tokio_socks::Error),
    /// TLS was enabled without a server name, and none could be derived from the target.
    NoServerName,
}
65
66impl fmt::Display for Error {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 match self {
69 Self::ClientCertNoTls => write!(f, "you specified a client cert without using tls"),
70 Self::Connect(e) => write!(f, "failed to connect: {e}"),
71 Self::Socks(e) => write!(f, "could not sock: {e}"),
72 Self::Rustls(e) => write!(f, "could not rustls: {e}"),
73 Self::SocksToUnsupported => write!(f, "socks cannot connect to unix sockets"),
74 Self::InvalidTarget(e) => write!(f, "invalid target address: {e}"),
75 Self::NoServerName => write!(f, "no tls servername provided and failed to guess it"),
76 }
77 }
78}
79
// NOTE(review): `source()` is not overridden, so the wrapped io/socks/rustls errors are not
// exposed through the error chain. The `Display` impl already embeds the wrapped error's
// message; overriding `source()` as well would make chain-printing reporters show it twice,
// so changing one of the two means changing the other.
impl std::error::Error for Error {}
81
82impl From<std::io::Error> for Error {
83 fn from(value: std::io::Error) -> Self {
84 Self::Connect(value)
85 }
86}
87
88impl From<tokio_socks::Error> for Error {
89 fn from(value: tokio_socks::Error) -> Self {
90 Self::Socks(value)
91 }
92}
93
94impl From<tokio_rustls::rustls::Error> for Error {
95 fn from(value: tokio_rustls::rustls::Error) -> Self {
96 Self::Rustls(value)
97 }
98}
99
pin_project! {
    #[derive(Debug)]
    /// An established connection: a TCP or unix socket, optionally tunnelled through a
    /// SOCKS proxy and optionally wrapped in TLS. Implements [`AsyncRead`] and
    /// [`AsyncWrite`] by forwarding to the underlying stream.
    ///
    /// Create one with [`Connection::new_tcp`] or [`Connection::new_unix`] and
    /// [`ConnectionBuilder::connect`].
    pub struct Connection {
        // The layered stream (TLS? -> SOCKS? -> TCP/unix); all I/O is forwarded to it.
        #[pin]
        inner: MaybeTls,
    }
}
108
109impl Connection {
110 pub fn new_tcp<'a>(addr: impl IntoTargetAddr<'a>) -> ConnectionBuilder<'a> {
120 ConnectionBuilder::new(BaseParams::Tcp(addr.into_target_addr()))
121 }
122 pub fn new_unix(path: &Path) -> ConnectionBuilder<'_> {
133 ConnectionBuilder::new(BaseParams::Unix(path))
134 }
135}
136
137impl AsyncRead for Connection {
138 #[inline]
139 fn poll_read(
140 self: Pin<&mut Self>,
141 cx: &mut Context<'_>,
142 buf: &mut ReadBuf<'_>,
143 ) -> Poll<std::io::Result<()>> {
144 self.project().inner.poll_read(cx, buf)
145 }
146}
147
148impl AsyncWrite for Connection {
149 #[inline]
150 fn poll_write(
151 self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 buf: &[u8],
154 ) -> Poll<Result<usize, std::io::Error>> {
155 self.project().inner.poll_write(cx, buf)
156 }
157 #[inline]
158 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
159 self.project().inner.poll_flush(cx)
160 }
161 #[inline]
162 fn poll_shutdown(
163 self: Pin<&mut Self>,
164 cx: &mut Context<'_>,
165 ) -> Poll<Result<(), std::io::Error>> {
166 self.project().inner.poll_shutdown(cx)
167 }
168}
169
// Optional TLS layer on top of the (possibly SOCKS-proxied) transport; wrapped by `Connection`.
pin_project! {
    #[project = MaybeTlsProj]
    #[derive(Debug)]
    // clippy's `large_enum_variant` fires here, presumably because `Tls` is far larger
    // than `Plain`. It is allowed rather than boxed — presumably to avoid an extra heap
    // allocation per connection (NOTE(review): confirm).
    #[allow(clippy::large_enum_variant)]
    enum MaybeTls {
        // No TLS: bytes go straight to the transport.
        Plain {
            #[pin]
            inner: MaybeSocks,
        },
        // TLS client session running over the transport.
        Tls {
            #[pin]
            inner: TlsStream<MaybeSocks>,
        },
    }
}
185
186macro_rules! trivial_impl {
187 ($target:ty, ($($arm:path),*)) => {
188 impl AsyncRead for $target {
189 #[inline]
190 fn poll_read(
191 self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 buf: &mut ReadBuf<'_>,
194 ) -> Poll<std::io::Result<()>> {
195 match self.project() {
196 $($arm { inner } => inner.poll_read(cx, buf),)*
197 }
198 }
199 }
200
201 impl AsyncWrite for $target {
202 #[inline]
203 fn poll_write(
204 self: Pin<&mut Self>,
205 cx: &mut Context<'_>,
206 buf: &[u8],
207 ) -> Poll<Result<usize, std::io::Error>> {
208 match self.project() {
209 $($arm { inner } => inner.poll_write(cx, buf),)*
210 }
211 }
212 #[inline]
213 fn poll_flush(
214 self: Pin<&mut Self>,
215 cx: &mut Context<'_>,
216 ) -> Poll<Result<(), std::io::Error>> {
217 match self.project() {
218 $($arm { inner } => inner.poll_flush(cx),)*
219 }
220 }
221 #[inline]
222 fn poll_shutdown(
223 self: Pin<&mut Self>,
224 cx: &mut Context<'_>,
225 ) -> Poll<Result<(), std::io::Error>> {
226 match self.project() {
227 $($arm { inner } => inner.poll_shutdown(cx),)*
228 }
229 }
230 }
231 };
232}
233
234trivial_impl!(MaybeTls, (MaybeTlsProj::Plain, MaybeTlsProj::Tls));
235
236pin_project! {
237 #[project = MaybeSocksProj]
238 #[derive(Debug)]
239 enum MaybeSocks {
240 Clear {
241 #[pin]
242 inner: BaseStream,
243 },
244 Socks4 {
245 #[pin]
246 inner: Socks4Stream<BaseStream>,
247 },
248 Socks5 {
249 #[pin]
250 inner: Socks5Stream<BaseStream>,
251 },
252 }
253}
254
255trivial_impl!(
256 MaybeSocks,
257 (
258 MaybeSocksProj::Clear,
259 MaybeSocksProj::Socks4,
260 MaybeSocksProj::Socks5
261 )
262);
263
264pin_project! {
265 #[project = BaseStreamProj]
266 #[derive(Debug)]
267 enum BaseStream {
268 Tcp {
269 #[pin]
270 inner: TcpStream,
271 },
272 Unix {
273 #[pin]
274 inner: UnixStream,
275 },
276 }
277}
278
279trivial_impl!(BaseStream, (BaseStreamProj::Tcp, BaseStreamProj::Unix));
280
/// Builder for a [`Connection`], created by [`Connection::new_tcp`] or
/// [`Connection::new_unix`].
///
/// Optionally configure a SOCKS proxy (`socks4*` / `socks5*`), TLS (`tls_*`) and a TLS
/// client certificate ([`client_cert`](ConnectionBuilder::client_cert)), then call
/// [`connect`](ConnectionBuilder::connect).
#[derive(Debug)]
#[must_use = "this does nothing unless you finish building"]
pub struct ConnectionBuilder<'a> {
    // Where to connect: a TCP target (conversion errors are kept and reported at connect
    // time) or a unix socket path.
    base: BaseParams<'a>,
    // SOCKS proxy to tunnel through, if any.
    socks: Option<SocksParams<'a>>,
    // TLS settings; `None` means plaintext.
    tls: Option<TlsParams>,
    // Client certificate for TLS client authentication; only valid together with `tls`.
    client_cert: Option<ClientCert>,
}
290
291impl<'a> ConnectionBuilder<'a> {
292 fn new(base: BaseParams<'a>) -> Self {
293 Self {
294 base,
295 socks: None,
296 tls: None,
297 client_cert: None,
298 }
299 }
300
301 fn socks(
302 mut self,
303 version: SocksVersion,
304 proxy: SocketAddr,
305 auth: Option<SocksAuth<'a>>,
306 ) -> Self {
307 self.socks = Some(SocksParams {
308 version,
309 proxy,
310 auth,
311 });
312 self
313 }
314
315 pub fn socks4(self, proxy: SocketAddr) -> Self {
326 self.socks(SocksVersion::Socks4, proxy, None)
327 }
328
329 pub fn socks4_with_userid(self, proxy: SocketAddr, userid: &'a str) -> Self {
340 self.socks(
341 SocksVersion::Socks4,
342 proxy,
343 Some(SocksAuth {
344 username: userid,
345 password: "h",
346 }),
347 )
348 }
349
350 pub fn socks5(self, proxy: SocketAddr) -> Self {
361 self.socks(SocksVersion::Socks5, proxy, None)
362 }
363
364 pub fn socks5_with_password(
376 self,
377 proxy: SocketAddr,
378 username: &'a str,
379 password: &'a str,
380 ) -> Self {
381 self.socks(
382 SocksVersion::Socks5,
383 proxy,
384 Some(SocksAuth { username, password }),
385 )
386 }
387
388 fn tls(mut self, domain: Option<ServerName<'static>>, verification: TlsVerify) -> Self {
389 self.tls = Some(TlsParams {
390 domain,
391 verification,
392 });
393 self
394 }
395
396 pub fn tls_danger_insecure(self, domain: Option<ServerName<'static>>) -> Self {
408 self.tls(domain, TlsVerify::Insecure)
409 }
410
411 pub fn tls_with_root(
433 self,
434 domain: Option<ServerName<'static>>,
435 root: impl Into<Arc<RootCertStore>>,
436 ) -> Self {
437 self.tls(domain, TlsVerify::CaStore(root.into()))
438 }
439
440 pub fn tls_with_webpki(
442 self,
443 domain: Option<ServerName<'static>>,
444 webpki: Arc<WebPkiServerVerifier>,
445 ) -> Self {
446 self.tls(domain, TlsVerify::WebPki(webpki))
447 }
448
449 pub fn client_cert(
470 mut self,
471 cert_chain: Vec<CertificateDer<'static>>,
472 key_der: PrivateKeyDer<'static>,
473 ) -> Self {
474 self.client_cert = Some(ClientCert {
475 cert_chain,
476 key_der,
477 });
478 self
479 }
480
481 pub async fn connect(self) -> Result<Connection, Error> {
496 let tls = if let Some(mut params) = self.tls {
497 params.domain = params.domain.or_else(|| match &self.base {
498 BaseParams::Tcp(Ok(TargetAddr::Ip(addr))) => Some(ServerName::from(addr.ip())),
499 BaseParams::Tcp(Ok(TargetAddr::Domain(d, _))) => {
500 ServerName::try_from(d.as_ref()).map(|s| s.to_owned()).ok()
501 }
502 _ => None,
503 });
504 Some(params)
505 } else {
506 None
507 };
508 let stream = if let Some(params) = self.socks {
509 let BaseParams::Tcp(target) = self.base else {
510 return Err(Error::SocksToUnsupported);
511 };
512 let target = target.map_err(Error::InvalidTarget)?;
513 let stream = BaseStream::Tcp {
514 inner: TcpStream::connect(params.proxy).await?,
515 };
516 match params.version {
517 SocksVersion::Socks4 => MaybeSocks::Socks4 {
518 inner: if let Some(SocksAuth { username, .. }) = params.auth {
519 Socks4Stream::connect_with_userid_and_socket(stream, target, username)
520 .await?
521 } else {
522 Socks4Stream::connect_with_socket(stream, target).await?
523 },
524 },
525 SocksVersion::Socks5 => MaybeSocks::Socks5 {
526 inner: if let Some(SocksAuth { username, password }) = params.auth {
527 Socks5Stream::connect_with_password_and_socket(
528 stream, target, username, password,
529 )
530 .await?
531 } else {
532 Socks5Stream::connect_with_socket(stream, target).await?
533 },
534 },
535 }
536 } else {
537 let stream = match self.base {
538 BaseParams::Tcp(addr) => {
539 let inner = match addr.map_err(Error::InvalidTarget)? {
542 TargetAddr::Ip(addr) => TcpStream::connect(addr).await?,
543 TargetAddr::Domain(domain, port) => {
544 TcpStream::connect((domain.as_ref(), port)).await?
545 }
546 };
547 BaseStream::Tcp { inner }
548 }
549 BaseParams::Unix(path) => BaseStream::Unix {
550 inner: UnixStream::connect(path).await?,
551 },
552 };
553 MaybeSocks::Clear { inner: stream }
554 };
555 let stream = if let Some(params) = tls {
556 let config = ClientConfig::builder();
557 let config = match params.verification {
558 TlsVerify::Insecure => {
559 let provider = config.crypto_provider().clone();
560 config
561 .dangerous()
562 .with_custom_certificate_verifier(danger::PhonyVerify::new(provider))
563 }
564 TlsVerify::CaStore(root) => config.with_root_certificates(root),
565 TlsVerify::WebPki(webpki) => config.with_webpki_verifier(webpki),
566 };
567 let config = if let Some(ClientCert {
568 cert_chain,
569 key_der,
570 }) = self.client_cert
571 {
572 config.with_client_auth_cert(cert_chain, key_der)?
573 } else {
574 config.with_no_client_auth()
575 };
576 let connector = TlsConnector::from(Arc::new(config));
577 let domain = params.domain.ok_or(Error::NoServerName)?;
578 let inner = connector.connect(domain, stream).await?;
579 MaybeTls::Tls { inner }
580 } else {
581 if self.client_cert.is_some() {
582 return Err(Error::ClientCertNoTls);
583 }
584 MaybeTls::Plain { inner: stream }
585 };
586 Ok(Connection { inner: stream })
587 }
588}
589
/// What to connect to, as given to [`Connection::new_tcp`] / [`Connection::new_unix`].
#[derive(Debug)]
enum BaseParams<'a> {
    /// TCP target. A conversion error is kept here and reported by
    /// `ConnectionBuilder::connect` as `Error::InvalidTarget`.
    Tcp(tokio_socks::Result<TargetAddr<'a>>),
    /// Path of a unix domain socket.
    Unix(&'a Path),
}
595
/// SOCKS proxy settings recorded by the `socks4*` / `socks5*` builder methods.
struct SocksParams<'a> {
    /// Protocol version to speak with the proxy.
    version: SocksVersion,
    /// Address of the proxy server itself.
    proxy: SocketAddr,
    /// Credentials to send to the proxy, if any.
    auth: Option<SocksAuth<'a>>,
}

// Hand-written so that credentials never appear in debug output (`SocksAuth` deliberately
// has no `Debug`). The version and proxy address are not secret and are shown, and `auth`
// shows only whether credentials are configured.
impl fmt::Debug for SocksParams<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SocksParams")
            .field("version", &self.version)
            .field("proxy", &self.proxy)
            .field("auth", &self.auth.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

/// Credentials sent to the SOCKS proxy.
struct SocksAuth<'a> {
    /// SOCKS5 user name, or the SOCKS4 user id.
    username: &'a str,
    /// SOCKS5 password; not used for SOCKS4.
    password: &'a str,
}

/// SOCKS protocol version.
#[derive(Debug)]
enum SocksVersion {
    Socks4,
    Socks5,
}
618
/// TLS settings recorded by the `tls_*` builder methods.
#[derive(Debug)]
struct TlsParams {
    /// Server name passed to the TLS connector. `None` means `connect` tries to derive it
    /// from the TCP target.
    domain: Option<ServerName<'static>>,
    /// How the server certificate is verified.
    verification: TlsVerify,
}
624
/// Server-certificate verification strategy.
#[derive(Debug)]
enum TlsVerify {
    /// Verification is delegated to `danger::PhonyVerify` (set by `tls_danger_insecure`);
    /// presumably accepts any certificate.
    Insecure,
    /// Verify against these root CA certificates.
    CaStore(Arc<RootCertStore>),
    /// Verify with this preconfigured WebPKI verifier.
    WebPki(Arc<WebPkiServerVerifier>),
}
631
/// Client certificate chain and private key used for TLS client authentication.
#[derive(Debug)]
struct ClientCert {
    /// Certificate chain presented to the server.
    cert_chain: Vec<CertificateDer<'static>>,
    /// Private key belonging to the certificate chain.
    key_der: PrivateKeyDer<'static>,
}