#![allow(missing_docs)]
mod rt;
mod stream;
mod tunnel;
use std::{fmt, io, sync::Arc};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures_util::future::TryFutureExt;
use headers::{authorization::Credentials, Authorization, HeaderMapExt, ProxyAuthorization};
use http::header::{HeaderMap, HeaderName, HeaderValue};
use hyper::rt::{Read, Write};
use hyper::Uri;
use tower_service::Service;
pub use stream::ProxyStream;
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
use native_tls::TlsConnector as NativeTlsConnector;
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
use tokio_native_tls::TlsConnector;
#[cfg(feature = "__rustls")]
use hyper_rustls::ConfigBuilderExt;
#[cfg(feature = "__rustls")]
use tokio_rustls::TlsConnector;
#[cfg(feature = "__rustls")]
use tokio_rustls::rustls::pki_types::ServerName;
type BoxError = Box<dyn std::error::Error + Send + Sync>;
#[derive(Debug, Clone)]
pub enum Intercept {
All,
Http,
Https,
None,
Custom(Custom),
}
pub trait Dst {
fn scheme(&self) -> Option<&str>;
fn host(&self) -> Option<&str>;
fn port(&self) -> Option<u16>;
}
impl Dst for Uri {
fn scheme(&self) -> Option<&str> {
self.scheme_str()
}
fn host(&self) -> Option<&str> {
self.host()
}
fn port(&self) -> Option<u16> {
self.port_u16()
}
}
#[inline]
pub(crate) fn io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
pub type CustomProxyCallback =
dyn Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync;
#[derive(Clone)]
pub struct Custom(Arc<CustomProxyCallback>);
impl fmt::Debug for Custom {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(f, "_")
}
}
impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
for Custom
{
fn from(f: F) -> Custom {
Custom(Arc::new(f))
}
}
impl Intercept {
pub fn matches<D: Dst>(&self, uri: &D) -> bool {
match (self, uri.scheme()) {
(&Intercept::All, _)
| (&Intercept::Http, Some("http"))
| (&Intercept::Https, Some("https")) => true,
(&Intercept::Custom(Custom(ref f)), _) => f(uri.scheme(), uri.host(), uri.port()),
_ => false,
}
}
}
impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
for Intercept
{
fn from(f: F) -> Intercept {
Intercept::Custom(f.into())
}
}
#[derive(Clone, Debug)]
pub struct Proxy {
intercept: Intercept,
force_connect: bool,
headers: HeaderMap,
uri: Uri,
}
impl Proxy {
pub fn new<I: Into<Intercept>>(intercept: I, uri: Uri) -> Proxy {
let mut proxy = Proxy {
intercept: intercept.into(),
uri: uri.clone(),
headers: HeaderMap::new(),
force_connect: false,
};
if let Some((user, pass)) = extract_user_pass(&uri) {
proxy.set_authorization(Authorization::basic(user, pass));
}
proxy
}
pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization<C>) {
match self.intercept {
Intercept::Http => {
self.headers.typed_insert(Authorization(credentials.0));
}
Intercept::Https => {
self.headers.typed_insert(ProxyAuthorization(credentials.0));
}
_ => {
self.headers
.typed_insert(Authorization(credentials.0.clone()));
self.headers.typed_insert(ProxyAuthorization(credentials.0));
}
}
}
pub fn force_connect(&mut self) {
self.force_connect = true;
}
pub fn set_header(&mut self, name: HeaderName, value: HeaderValue) {
self.headers.insert(name, value);
}
pub fn intercept(&self) -> &Intercept {
&self.intercept
}
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
pub fn uri(&self) -> &Uri {
&self.uri
}
}
#[derive(Clone)]
pub struct ProxyConnector<C> {
proxies: Vec<Proxy>,
connector: C,
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
tls: Option<NativeTlsConnector>,
#[cfg(feature = "__rustls")]
tls: Option<TlsConnector>,
#[cfg(not(feature = "__tls"))]
tls: Option<()>,
}
impl<C: fmt::Debug> fmt::Debug for ProxyConnector<C> {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(
f,
"ProxyConnector {}{{ proxies: {:?}, connector: {:?} }}",
if self.tls.is_some() {
""
} else {
"(unsecured)"
},
self.proxies,
self.connector
)
}
}
impl<C> ProxyConnector<C> {
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
pub fn new(connector: C) -> Result<Self, io::Error> {
let tls = NativeTlsConnector::builder()
.build()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok(ProxyConnector {
proxies: Vec::new(),
connector: connector,
tls: Some(tls),
})
}
#[cfg(feature = "__rustls")]
pub fn new(connector: C) -> Result<Self, io::Error> {
let config = tokio_rustls::rustls::ClientConfig::builder();
#[cfg(all(
feature = "rustls-tls-native-roots",
not(feature = "rustls-tls-webpki-roots")
))]
let config = config.with_native_roots()?;
#[cfg(feature = "rustls-tls-webpki-roots")]
let config = config.with_webpki_roots();
let cfg = Arc::new(config.with_no_client_auth());
let tls = TlsConnector::from(cfg);
Ok(ProxyConnector {
proxies: Vec::new(),
connector,
tls: Some(tls),
})
}
pub fn unsecured(connector: C) -> Self {
ProxyConnector {
proxies: Vec::new(),
connector,
tls: None,
}
}
#[cfg(feature = "__tls")]
pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
let mut c = ProxyConnector::new(connector)?;
c.proxies.push(proxy);
Ok(c)
}
pub fn from_proxy_unsecured(connector: C, proxy: Proxy) -> Self {
let mut c = ProxyConnector::unsecured(connector);
c.proxies.push(proxy);
c
}
pub fn with_connector<CC>(self, connector: CC) -> ProxyConnector<CC> {
ProxyConnector {
connector,
proxies: self.proxies,
tls: self.tls,
}
}
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
pub fn set_tls(&mut self, tls: Option<NativeTlsConnector>) {
self.tls = tls;
}
#[cfg(feature = "__rustls")]
pub fn set_tls(&mut self, tls: Option<TlsConnector>) {
self.tls = tls;
}
pub fn proxies(&self) -> &[Proxy] {
&self.proxies
}
pub fn add_proxy(&mut self, proxy: Proxy) {
self.proxies.push(proxy);
}
pub fn extend_proxies<I: IntoIterator<Item = Proxy>>(&mut self, proxies: I) {
self.proxies.extend(proxies)
}
pub fn http_headers(&self, uri: &Uri) -> Option<&HeaderMap> {
if uri.scheme_str() != Some("http") {
return None;
}
self.match_proxy(uri).map(|p| &p.headers)
}
fn match_proxy<D: Dst>(&self, uri: &D) -> Option<&Proxy> {
self.proxies.iter().find(|p| p.intercept.matches(uri))
}
}
macro_rules! mtry {
($e:expr) => {
match $e {
Ok(v) => v,
Err(e) => break Err(e.into()),
}
};
}
impl<C> Service<Uri> for ProxyConnector<C>
where
C: Service<Uri>,
C::Response: Read + Write + Send + Unpin + 'static,
C::Future: Send + 'static,
C::Error: Into<BoxError>,
{
type Response = ProxyStream<C::Response>;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.connector.poll_ready(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(io_err(e.into()))),
Poll::Pending => Poll::Pending,
}
}
fn call(&mut self, uri: Uri) -> Self::Future {
if let (Some(p), Some(host)) = (self.match_proxy(&uri), uri.host()) {
if uri.scheme() == Some(&http::uri::Scheme::HTTPS) || p.force_connect {
let host = host.to_owned();
let port =
uri.port_u16()
.unwrap_or(if uri.scheme() == Some(&http::uri::Scheme::HTTP) {
80
} else {
443
});
let tunnel = tunnel::new(&host, port, &p.headers);
let connection =
proxy_dst(&uri, &p.uri).map(|proxy_url| self.connector.call(proxy_url));
let tls = if uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
self.tls.clone()
} else {
None
};
Box::pin(async move {
#[allow(clippy::never_loop)]
loop {
let proxy_stream = mtry!(mtry!(connection).await.map_err(io_err));
let tunnel_stream = mtry!(tunnel.with_stream(proxy_stream).await);
break match tls {
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
Some(tls) => {
use hyper_util::rt::TokioIo;
let tls = TlsConnector::from(tls);
let secure_stream = mtry!(tls
.connect(&host, TokioIo::new(tunnel_stream))
.await
.map_err(io_err));
Ok(ProxyStream::Secured(Box::new(TokioIo::new(secure_stream))))
}
#[cfg(feature = "__rustls")]
Some(tls) => {
use hyper_util::rt::TokioIo;
let server_name =
mtry!(ServerName::try_from(host.to_string()).map_err(io_err));
let secure_stream = mtry!(tls
.connect(server_name, TokioIo::new(tunnel_stream))
.await
.map_err(io_err));
Ok(ProxyStream::Secured(Box::new(TokioIo::new(secure_stream))))
}
#[cfg(not(feature = "__tls",))]
Some(_) => panic!("hyper-proxy was not built with TLS support"),
None => Ok(ProxyStream::Regular(tunnel_stream)),
};
}
})
} else {
match proxy_dst(&uri, &p.uri) {
Ok(proxy_uri) => Box::pin(
self.connector
.call(proxy_uri)
.map_ok(ProxyStream::Regular)
.map_err(|err| io_err(err.into())),
),
Err(err) => Box::pin(futures_util::future::err(io_err(err))),
}
}
} else {
Box::pin(
self.connector
.call(uri)
.map_ok(ProxyStream::NoProxy)
.map_err(|err| io_err(err.into())),
)
}
}
}
fn proxy_dst(dst: &Uri, proxy: &Uri) -> io::Result<Uri> {
Uri::builder()
.scheme(
proxy
.scheme_str()
.ok_or_else(|| io_err(format!("proxy uri missing scheme: {}", proxy)))?,
)
.authority(
proxy
.authority()
.ok_or_else(|| io_err(format!("proxy uri missing host: {}", proxy)))?
.clone(),
)
.path_and_query(
dst.path_and_query()
.ok_or_else(|| io_err(format!("dst uri missing path: {}", proxy)))?
.clone(),
)
.build()
.map_err(|err| io_err(format!("other error: {}", err)))
}
fn extract_user_pass(uri: &Uri) -> Option<(&str, &str)> {
let authority = uri.authority()?.as_str();
let (userinfo, _) = authority.rsplit_once('@')?;
let (username, password) = userinfo.split_once(':')?;
Some((username, password))
}
#[cfg(test)]
mod tests {
use http::Uri;
use crate::{Intercept, Proxy};
#[test]
fn test_new_proxy_with_authorization() {
let proxy = Proxy::new(
Intercept::All,
Uri::from_static("https://bob:secret@my-proxy:8080"),
);
assert_eq!(
proxy
.headers()
.get("authorization")
.unwrap()
.to_str()
.unwrap(),
"Basic Ym9iOnNlY3JldA=="
);
}
#[test]
fn test_new_proxy_without_authorization() {
let proxy = Proxy::new(Intercept::All, Uri::from_static("https://my-proxy:8080"));
assert_eq!(proxy.headers().get("authorization"), None);
}
}