1#![allow(missing_docs)]
60
61mod rt;
62mod stream;
63mod tunnel;
64
65use std::{fmt, io, sync::Arc};
66use std::{
67 future::Future,
68 pin::Pin,
69 task::{Context, Poll},
70};
71
72use futures_util::future::TryFutureExt;
73use headers::{authorization::Credentials, Authorization, HeaderMapExt, ProxyAuthorization};
74use http::header::{HeaderMap, HeaderName, HeaderValue};
75use hyper::rt::{Read, Write};
76use hyper::Uri;
77use tower_service::Service;
78
79pub use stream::ProxyStream;
80
81#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
82use native_tls::TlsConnector as NativeTlsConnector;
83
84#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
85use tokio_native_tls::TlsConnector;
86
87#[cfg(feature = "__rustls")]
88use hyper_rustls::ConfigBuilderExt;
89
90#[cfg(feature = "__rustls")]
91use tokio_rustls::TlsConnector;
92
93#[cfg(feature = "__rustls")]
94use tokio_rustls::rustls::pki_types::ServerName;
95
/// Boxed, thread-safe error; the inner connector's `Service::Error` must convert into it.
type BoxError = Box<dyn std::error::Error + Sync + Send>;
97
98#[derive(Debug, Clone)]
100pub enum Intercept {
101 All,
103 Http,
105 Https,
107 None,
109 Custom(Custom),
111}
112
/// Minimal view of a connection destination, as needed to evaluate an [`Intercept`] rule.
pub trait Dst {
    /// The destination's scheme (e.g. `"https"`), if it has one.
    fn scheme(&self) -> Option<&'_ str>;
    /// The destination's host, if it has one.
    fn host(&self) -> Option<&'_ str>;
    /// The destination's port, if known.
    fn port(&self) -> Option<u16>;
}
122
123impl Dst for Uri {
124 fn scheme(&self) -> Option<&str> {
125 self.scheme_str()
126 }
127
128 fn host(&self) -> Option<&str> {
129 self.host()
130 }
131
132 fn port(&self) -> Option<u16> {
133 self.port_u16()
134 }
135}
136
/// Wraps any boxable error into an [`io::Error`] of kind [`io::ErrorKind::Other`].
#[inline]
pub(crate) fn io_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
141
/// Signature of a custom intercept predicate: `(scheme, host, port) -> should_proxy`.
pub type CustomProxyCallback =
    dyn Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Sync + Send;
144
145#[derive(Clone)]
147pub struct Custom(Arc<CustomProxyCallback>);
148
149impl fmt::Debug for Custom {
150 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
151 write!(f, "_")
152 }
153}
154
155impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
156 for Custom
157{
158 fn from(f: F) -> Custom {
159 Custom(Arc::new(f))
160 }
161}
162
163impl Intercept {
164 pub fn matches<D: Dst>(&self, uri: &D) -> bool {
166 match (self, uri.scheme()) {
167 (&Intercept::All, _)
168 | (&Intercept::Http, Some("http"))
169 | (&Intercept::Https, Some("https")) => true,
170 (&Intercept::Custom(Custom(ref f)), _) => f(uri.scheme(), uri.host(), uri.port()),
171 _ => false,
172 }
173 }
174}
175
176impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
177 for Intercept
178{
179 fn from(f: F) -> Intercept {
180 Intercept::Custom(f.into())
181 }
182}
183
184#[derive(Clone, Debug)]
186pub struct Proxy {
187 intercept: Intercept,
188 force_connect: bool,
189 headers: HeaderMap,
190 uri: Uri,
191}
192
193impl Proxy {
194 pub fn new<I: Into<Intercept>>(intercept: I, uri: Uri) -> Proxy {
196 let mut proxy = Proxy {
197 intercept: intercept.into(),
198 uri: uri.clone(),
199 headers: HeaderMap::new(),
200 force_connect: false,
201 };
202
203 if let Some((user, pass)) = extract_user_pass(&uri) {
204 proxy.set_authorization(Authorization::basic(user, pass));
205 }
206
207 proxy
208 }
209
210 pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization<C>) {
212 match self.intercept {
213 Intercept::Http => {
214 self.headers.typed_insert(Authorization(credentials.0));
215 }
216 Intercept::Https => {
217 self.headers.typed_insert(ProxyAuthorization(credentials.0));
218 }
219 _ => {
220 self.headers
221 .typed_insert(Authorization(credentials.0.clone()));
222 self.headers.typed_insert(ProxyAuthorization(credentials.0));
223 }
224 }
225 }
226
227 pub fn force_connect(&mut self) {
229 self.force_connect = true;
230 }
231
232 pub fn set_header(&mut self, name: HeaderName, value: HeaderValue) {
234 self.headers.insert(name, value);
235 }
236
237 pub fn intercept(&self) -> &Intercept {
239 &self.intercept
240 }
241
242 pub fn headers(&self) -> &HeaderMap {
244 &self.headers
245 }
246
247 pub fn uri(&self) -> &Uri {
249 &self.uri
250 }
251}
252
253#[derive(Clone)]
255pub struct ProxyConnector<C> {
256 proxies: Vec<Proxy>,
257 connector: C,
258
259 #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
260 tls: Option<NativeTlsConnector>,
261
262 #[cfg(feature = "__rustls")]
263 tls: Option<TlsConnector>,
264
265 #[cfg(not(feature = "__tls"))]
266 tls: Option<()>,
267}
268
269impl<C: fmt::Debug> fmt::Debug for ProxyConnector<C> {
270 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
271 write!(
272 f,
273 "ProxyConnector {}{{ proxies: {:?}, connector: {:?} }}",
274 if self.tls.is_some() {
275 ""
276 } else {
277 "(unsecured)"
278 },
279 self.proxies,
280 self.connector
281 )
282 }
283}
284
285impl<C> ProxyConnector<C> {
286 #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
288 pub fn new(connector: C) -> Result<Self, io::Error> {
289 let tls = NativeTlsConnector::builder()
290 .build()
291 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
292
293 Ok(ProxyConnector {
294 proxies: Vec::new(),
295 connector: connector,
296 tls: Some(tls),
297 })
298 }
299
300 #[cfg(feature = "__rustls")]
302 pub fn new(connector: C) -> Result<Self, io::Error> {
303 let config = tokio_rustls::rustls::ClientConfig::builder();
304
305 #[cfg(all(
306 feature = "rustls-tls-native-roots",
307 not(feature = "rustls-tls-webpki-roots")
308 ))]
309 let config = config.with_native_roots()?;
310
311 #[cfg(feature = "rustls-tls-webpki-roots")]
312 let config = config.with_webpki_roots();
313
314 let cfg = Arc::new(config.with_no_client_auth());
315 let tls = TlsConnector::from(cfg);
316
317 Ok(ProxyConnector {
318 proxies: Vec::new(),
319 connector,
320 tls: Some(tls),
321 })
322 }
323
324 pub fn unsecured(connector: C) -> Self {
326 ProxyConnector {
327 proxies: Vec::new(),
328 connector,
329 tls: None,
330 }
331 }
332
333 #[cfg(feature = "__tls")]
335 pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
336 let mut c = ProxyConnector::new(connector)?;
337 c.proxies.push(proxy);
338 Ok(c)
339 }
340
341 pub fn from_proxy_unsecured(connector: C, proxy: Proxy) -> Self {
343 let mut c = ProxyConnector::unsecured(connector);
344 c.proxies.push(proxy);
345 c
346 }
347
348 pub fn with_connector<CC>(self, connector: CC) -> ProxyConnector<CC> {
350 ProxyConnector {
351 connector,
352 proxies: self.proxies,
353 tls: self.tls,
354 }
355 }
356
357 #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
359 pub fn set_tls(&mut self, tls: Option<NativeTlsConnector>) {
360 self.tls = tls;
361 }
362
363 #[cfg(feature = "__rustls")]
365 pub fn set_tls(&mut self, tls: Option<TlsConnector>) {
366 self.tls = tls;
367 }
368
369 pub fn proxies(&self) -> &[Proxy] {
371 &self.proxies
372 }
373
374 pub fn add_proxy(&mut self, proxy: Proxy) {
376 self.proxies.push(proxy);
377 }
378
379 pub fn extend_proxies<I: IntoIterator<Item = Proxy>>(&mut self, proxies: I) {
381 self.proxies.extend(proxies)
382 }
383
384 pub fn http_headers(&self, uri: &Uri) -> Option<&HeaderMap> {
389 if uri.scheme_str() != Some("http") {
390 return None;
391 }
392
393 self.match_proxy(uri).map(|p| &p.headers)
394 }
395
396 fn match_proxy<D: Dst>(&self, uri: &D) -> Option<&Proxy> {
397 self.proxies.iter().find(|p| p.intercept.matches(uri))
398 }
399}
400
/// `?`-style early exit for `loop`-wrapped async blocks: yields the `Ok` value, or
/// `break`s out of the enclosing loop with the error converted via `Into`.
macro_rules! mtry {
    ($result:expr) => {
        match $result {
            Err(err) => break Err(err.into()),
            Ok(value) => value,
        }
    };
}
409
410impl<C> Service<Uri> for ProxyConnector<C>
411where
412 C: Service<Uri>,
413 C::Response: Read + Write + Send + Unpin + 'static,
414 C::Future: Send + 'static,
415 C::Error: Into<BoxError>,
416{
417 type Response = ProxyStream<C::Response>;
418 type Error = io::Error;
419 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
420
421 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
422 match self.connector.poll_ready(cx) {
423 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
424 Poll::Ready(Err(e)) => Poll::Ready(Err(io_err(e.into()))),
425 Poll::Pending => Poll::Pending,
426 }
427 }
428
429 fn call(&mut self, uri: Uri) -> Self::Future {
430 if let (Some(p), Some(host)) = (self.match_proxy(&uri), uri.host()) {
431 if uri.scheme() == Some(&http::uri::Scheme::HTTPS) || p.force_connect {
432 let host = host.to_owned();
433 let port =
434 uri.port_u16()
435 .unwrap_or(if uri.scheme() == Some(&http::uri::Scheme::HTTP) {
436 80
437 } else {
438 443
439 });
440
441 let tunnel = tunnel::new(&host, port, &p.headers);
442 let connection =
443 proxy_dst(&uri, &p.uri).map(|proxy_url| self.connector.call(proxy_url));
444 let tls = if uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
445 self.tls.clone()
446 } else {
447 None
448 };
449
450 Box::pin(async move {
451 #[allow(clippy::never_loop)]
453 loop {
454 let proxy_stream = mtry!(mtry!(connection).await.map_err(io_err));
455 let tunnel_stream = mtry!(tunnel.with_stream(proxy_stream).await);
456
457 break match tls {
458 #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
459 Some(tls) => {
460 use hyper_util::rt::TokioIo;
461 let tls = TlsConnector::from(tls);
462 let secure_stream = mtry!(tls
463 .connect(&host, TokioIo::new(tunnel_stream))
464 .await
465 .map_err(io_err));
466
467 Ok(ProxyStream::Secured(Box::new(TokioIo::new(secure_stream))))
468 }
469
470 #[cfg(feature = "__rustls")]
471 Some(tls) => {
472 use hyper_util::rt::TokioIo;
473 let server_name =
474 mtry!(ServerName::try_from(host.to_string()).map_err(io_err));
475 let secure_stream = mtry!(tls
476 .connect(server_name, TokioIo::new(tunnel_stream))
477 .await
478 .map_err(io_err));
479
480 Ok(ProxyStream::Secured(Box::new(TokioIo::new(secure_stream))))
481 }
482
483 #[cfg(not(feature = "__tls",))]
484 Some(_) => panic!("hyper-proxy was not built with TLS support"),
485
486 None => Ok(ProxyStream::Regular(tunnel_stream)),
487 };
488 }
489 })
490 } else {
491 match proxy_dst(&uri, &p.uri) {
492 Ok(proxy_uri) => Box::pin(
493 self.connector
494 .call(proxy_uri)
495 .map_ok(ProxyStream::Regular)
496 .map_err(|err| io_err(err.into())),
497 ),
498 Err(err) => Box::pin(futures_util::future::err(io_err(err))),
499 }
500 }
501 } else {
502 Box::pin(
503 self.connector
504 .call(uri)
505 .map_ok(ProxyStream::NoProxy)
506 .map_err(|err| io_err(err.into())),
507 )
508 }
509 }
510}
511
512fn proxy_dst(dst: &Uri, proxy: &Uri) -> io::Result<Uri> {
513 Uri::builder()
514 .scheme(
515 proxy
516 .scheme_str()
517 .ok_or_else(|| io_err(format!("proxy uri missing scheme: {}", proxy)))?,
518 )
519 .authority(
520 proxy
521 .authority()
522 .ok_or_else(|| io_err(format!("proxy uri missing host: {}", proxy)))?
523 .clone(),
524 )
525 .path_and_query(
526 dst.path_and_query()
527 .ok_or_else(|| io_err(format!("dst uri missing path: {}", proxy)))?
528 .clone(),
529 )
530 .build()
531 .map_err(|err| io_err(format!("other error: {}", err)))
532}
533
534fn extract_user_pass(uri: &Uri) -> Option<(&str, &str)> {
536 let authority = uri.authority()?.as_str();
537 let (userinfo, _) = authority.rsplit_once('@')?;
538 let (username, password) = userinfo.split_once(':')?;
539
540 Some((username, password))
541}
542
#[cfg(test)]
mod tests {
    use http::Uri;

    use crate::{Intercept, Proxy};

    /// The `authorization` header of `proxy` as text, if present.
    fn authorization(proxy: &Proxy) -> Option<&str> {
        proxy
            .headers()
            .get("authorization")
            .map(|value| value.to_str().unwrap())
    }

    #[test]
    fn test_new_proxy_with_authorization() {
        let uri = Uri::from_static("https://bob:secret@my-proxy:8080");
        let proxy = Proxy::new(Intercept::All, uri);

        // base64("bob:secret")
        assert_eq!(authorization(&proxy), Some("Basic Ym9iOnNlY3JldA=="));
    }

    #[test]
    fn test_new_proxy_without_authorization() {
        let uri = Uri::from_static("https://my-proxy:8080");
        let proxy = Proxy::new(Intercept::All, uri);

        assert_eq!(authorization(&proxy), None);
    }
}