#![doc = include_str!("../README.md")]
pub mod http;
pub mod middleware;
#[cfg(feature = "config")]
pub mod config;
#[cfg(feature = "redis")]
mod redis;
pub use rustls_mitm::CertificateAuthority;
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use ::http::{Request, Response};
use base64::Engine;
use hyper::body::Incoming;
use hyper_util::rt::TokioIo;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig, RootCertStore, ServerConfig};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tower::Service;
use tracing::Instrument;
use http::{
Body, BoxError, ForwardService, HttpService, UpstreamClient, UpstreamConnector, UpstreamScheme,
empty_body, incoming_to_body,
};
/// Boxed function that takes an [`HttpService`] and returns an [`HttpService`];
/// `ProxyBuilder::layer` stores one of these per registered layer.
type LayerFn = Box<dyn Fn(HttpService) -> HttpService + Send + Sync>;
/// Shared predicate over a request, used to select a route's upstream.
type RoutePredicate = Arc<dyn Fn(&Request<Body>) -> bool + Send + Sync>;
/// [`HttpService`] wrapped in a `tower::buffer::Buffer`; this is the type
/// `HyperServiceAdapter` clones for every incoming request.
type BufferedHttpService =
tower::buffer::Buffer<Request<Body>, <HttpService as Service<Request<Body>>>::Future>;
/// Certificate verifier that accepts every upstream certificate and signature.
///
/// Installed by `ProxyBuilder::build` only when
/// `danger_accept_invalid_upstream_certs` was requested. Every check below
/// returns an unconditional success assertion, so upstream connections made
/// with this verifier are not authenticated.
#[derive(Debug)]
struct NoCertVerifier;
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
// Accepts any certificate chain for any server name; all inputs are ignored.
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
// Accepts any TLS 1.2 handshake signature without inspecting it.
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
// Accepts any TLS 1.3 handshake signature without inspecting it.
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
// Advertises the signature schemes of the `ring` crypto provider.
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// Operating mode chosen by `ProxyBuilder::build`.
///
/// `Reverse` is selected when an upstream target was configured via
/// `reverse_proxy`; otherwise the proxy runs in `Forward` mode.
#[derive(Clone)]
enum ProxyMode {
// Forward proxy: clients issue CONNECT and the tunnel is TLS-terminated
// with `mitm_acceptor`.
Forward {
// Server config used to terminate TLS inside the CONNECT tunnel.
mitm_acceptor: Arc<ServerConfig>,
// Accepted (username, password) pairs; `None` disables proxy auth.
credentials: Option<Arc<Vec<(String, String)>>>,
// Optional TLS for the listener socket itself.
tls_acceptor: Option<Arc<TlsAcceptor>>,
},
// Reverse proxy: every request is sent to one fixed upstream.
Reverse {
upstream_authority: ::http::uri::Authority,
upstream_scheme: UpstreamScheme,
// Optional TLS for the listener socket.
tls_acceptor: Option<Arc<TlsAcceptor>>,
},
}
/// Builder for [`Proxy`]; obtain one with `Proxy::builder()` and finish with
/// `build()`.
pub struct ProxyBuilder {
// CA for the forward-mode MITM config; required unless `upstream_target` is set.
ca: Option<CertificateAuthority>,
// Layer constructors in registration order.
layers: Vec<LayerFn>,
// When true, upstream TLS uses `NoCertVerifier`.
accept_invalid_upstream_certs: bool,
handshake_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
// Becomes the permit count of a semaphore guarding accepted connections.
max_connections: Option<usize>,
// Upper bound on waiting for open connections after shutdown.
drain_timeout: Option<Duration>,
// (username, password) pairs for forward-proxy Basic authentication.
credentials: Vec<(String, String)>,
// Passed to the upstream hyper client pool.
pool_max_idle_per_host: usize,
pool_idle_timeout: Duration,
// Set by `reverse_proxy`; its presence selects reverse mode in `build`.
upstream_target: Option<(::http::uri::Authority, UpstreamScheme)>,
// Certificate chain and key for serving TLS on the listener.
tls_identity: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
// Predicate/upstream pairs turned into a `middleware::Router` layer in `build`.
routes: Vec<(RoutePredicate, middleware::Upstream)>,
}
impl ProxyBuilder {
pub fn ca_pem(mut self, cert_pem: &str, key_pem: &str) -> anyhow::Result<Self> {
self.ca = Some(CertificateAuthority::from_pem(cert_pem, key_pem)?);
Ok(self)
}
pub fn ca_pem_files(
mut self,
cert_path: impl AsRef<std::path::Path>,
key_path: impl AsRef<std::path::Path>,
) -> anyhow::Result<Self> {
self.ca = Some(CertificateAuthority::from_pem_files(cert_path, key_path)?);
Ok(self)
}
pub fn ca(mut self, ca: CertificateAuthority) -> Self {
self.ca = Some(ca);
self
}
pub fn layer<L>(mut self, layer: L) -> Self
where
L: tower::Layer<HttpService> + Send + Sync + 'static,
L::Service:
Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Send + 'static,
<L::Service as Service<Request<Body>>>::Future: Send,
{
self.layers.push(Box::new(move |svc| {
tower::util::BoxService::new(layer.layer(svc))
}));
self
}
/// Registers an async function as a layer via `middleware::from_fn`.
///
/// The function receives the request and a `middleware::Next` handle.
pub fn middleware<F, Fut>(self, f: F) -> Self
where
    F: Fn(Request<Body>, middleware::Next) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<Response<Body>, BoxError>> + Send + 'static,
{
    let fn_layer = middleware::from_fn(f);
    self.layer(fn_layer)
}
/// Adds a `middleware::TrafficLogger` layer created with `TrafficLogger::new()`.
pub fn traffic_logger(self) -> Self {
    let logger = middleware::TrafficLogger::new();
    self.layer(logger)
}
/// Adds a `middleware::LatencyInjector::fixed(delay)` layer.
pub fn latency(self, delay: std::time::Duration) -> Self {
    let injector = middleware::LatencyInjector::fixed(delay);
    self.layer(injector)
}
/// Adds a `middleware::BandwidthThrottle::new(bytes_per_second)` layer.
pub fn bandwidth(self, bytes_per_second: u64) -> Self {
    let throttle = middleware::BandwidthThrottle::new(bytes_per_second);
    self.layer(throttle)
}
/// Adds a `middleware::RateLimiter::global(count, window)` layer.
pub fn rate_limit(self, count: u64, window: Duration) -> Self {
    let limiter = middleware::RateLimiter::global(count, window);
    self.layer(limiter)
}
/// Adds a `middleware::SlidingWindow::global(count, window)` layer.
pub fn sliding_window(self, count: u64, window: Duration) -> Self {
    let limiter = middleware::SlidingWindow::global(count, window);
    self.layer(limiter)
}
/// Adds a default `middleware::Retry` layer with its `max_retries` overridden.
pub fn retry(self, max_retries: u32) -> Self {
    let policy = middleware::Retry::default().max_retries(max_retries);
    self.layer(policy)
}
/// Adds a `middleware::CircuitBreaker::global(threshold, recovery)` layer.
pub fn circuit_breaker(self, threshold: u32, recovery: Duration) -> Self {
    let breaker = middleware::CircuitBreaker::global(threshold, recovery);
    self.layer(breaker)
}
/// Adds a `middleware::ModifyHeaders` layer configured with
/// `set_request(name, value)`.
pub fn set_request_header(self, name: impl AsRef<str>, value: impl AsRef<str>) -> Self {
    let headers = middleware::ModifyHeaders::new().set_request(name, value);
    self.layer(headers)
}
/// Adds a `middleware::ModifyHeaders` layer configured with `remove_request(name)`.
pub fn remove_request_header(self, name: impl AsRef<str>) -> Self {
    let headers = middleware::ModifyHeaders::new().remove_request(name);
    self.layer(headers)
}
/// Adds a `middleware::ModifyHeaders` layer configured with
/// `set_response(name, value)`.
pub fn set_response_header(self, name: impl AsRef<str>, value: impl AsRef<str>) -> Self {
    let headers = middleware::ModifyHeaders::new().set_response(name, value);
    self.layer(headers)
}
/// Adds a `middleware::ModifyHeaders` layer configured with `remove_response(name)`.
pub fn remove_response_header(self, name: impl AsRef<str>) -> Self {
    let headers = middleware::ModifyHeaders::new().remove_response(name);
    self.layer(headers)
}
/// Adds a `middleware::BlockList::hosts(patterns)` layer.
///
/// # Errors
/// Returns the `globset::Error` produced while building the block list.
pub fn block_hosts(
    self,
    patterns: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<Self, globset::Error> {
    let blocklist = middleware::BlockList::hosts(patterns)?;
    Ok(self.layer(blocklist))
}
/// Adds a `middleware::UrlRewrite::path(pattern, replacement)` layer.
///
/// # Panics
/// Panics with `invalid rewrite pattern` if `UrlRewrite::path` returns an error.
pub fn rewrite_path(self, pattern: &str, replacement: &str) -> Self {
    let rewrite =
        middleware::UrlRewrite::path(pattern, replacement).expect("invalid rewrite pattern");
    self.layer(rewrite)
}
/// Adds a `middleware::BlockList::paths(patterns)` layer.
///
/// # Errors
/// Returns the `globset::Error` produced while building the block list.
pub fn block_paths(
    self,
    patterns: impl IntoIterator<Item = impl AsRef<str>>,
) -> Result<Self, globset::Error> {
    let blocklist = middleware::BlockList::paths(patterns)?;
    Ok(self.layer(blocklist))
}
pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
self.handshake_timeout = Some(timeout);
self
}
pub fn idle_timeout(mut self, timeout: Duration) -> Self {
self.idle_timeout = Some(timeout);
self
}
pub fn max_connections(mut self, max: usize) -> Self {
self.max_connections = Some(max);
self
}
pub fn drain_timeout(mut self, timeout: Duration) -> Self {
self.drain_timeout = Some(timeout);
self
}
/// Adds one accepted username/password pair for forward-proxy Basic auth.
///
/// May be called repeatedly; `build` enables authentication only when at least
/// one pair was added.
pub fn credential(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
    let pair = (username.into(), password.into());
    self.credentials.push(pair);
    self
}
pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
self.pool_max_idle_per_host = max;
self
}
pub fn pool_idle_timeout(mut self, timeout: Duration) -> Self {
self.pool_idle_timeout = timeout;
self
}
pub fn danger_accept_invalid_upstream_certs(mut self) -> Self {
self.accept_invalid_upstream_certs = true;
self
}
pub fn reverse_proxy(mut self, url: &str) -> anyhow::Result<Self> {
let uri: ::http::Uri = url
.parse()
.map_err(|e| anyhow::anyhow!("invalid upstream URL: {e}"))?;
let scheme = uri.scheme().cloned().unwrap_or(::http::uri::Scheme::HTTP);
let authority = uri
.authority()
.cloned()
.ok_or_else(|| anyhow::anyhow!("upstream URL must contain a host"))?;
self.upstream_target = Some((authority, scheme));
Ok(self)
}
/// Loads the certificate chain and private key used to serve TLS on the
/// listener socket.
///
/// Both files are read as PEM. All certificates found in `cert_path` form the
/// chain; the first private key found in `key_path` is used.
///
/// # Errors
/// Fails if either file cannot be read or parsed, if `cert_path` contains no
/// certificate, or if `key_path` contains no private key.
pub fn tls_identity(
    mut self,
    cert_path: impl AsRef<std::path::Path>,
    key_path: impl AsRef<std::path::Path>,
) -> anyhow::Result<Self> {
    let cert_pem = std::fs::read(cert_path)?;
    let key_pem = std::fs::read(key_path)?;
    let certs: Vec<CertificateDer<'static>> =
        rustls_pemfile::certs(&mut &*cert_pem).collect::<Result<_, _>>()?;
    // An empty chain can never be served; reject it here, next to the file
    // that caused it, instead of failing later with a less specific error.
    if certs.is_empty() {
        return Err(anyhow::anyhow!("no certificates found in PEM file"));
    }
    let key = rustls_pemfile::private_key(&mut &*key_pem)?
        .ok_or_else(|| anyhow::anyhow!("no private key found in PEM file"))?;
    self.tls_identity = Some((certs, key));
    Ok(self)
}
pub fn route(
mut self,
predicate: impl Fn(&Request<Body>) -> bool + Send + Sync + 'static,
upstream: middleware::Upstream,
) -> Self {
self.routes.push((Arc::new(predicate), upstream));
self
}
/// Routes requests whose URI path starts with `prefix` to the single
/// upstream `url`.
///
/// # Errors
/// Returns the error from `middleware::Upstream::new` if it rejects `url`.
pub fn route_prefix(self, prefix: impl Into<String>, url: &str) -> anyhow::Result<Self> {
    let upstream = middleware::Upstream::new([url])?;
    let prefix: String = prefix.into();
    let has_prefix = move |req: &Request<Body>| req.uri().path().starts_with(&prefix);
    Ok(self.route(has_prefix, upstream))
}
/// Routes requests whose URI path starts with `prefix` to an upstream built
/// with `middleware::Upstream::balanced(urls)`.
///
/// # Errors
/// Returns the error from `middleware::Upstream::balanced` if it fails.
pub fn route_prefix_balanced(
    self,
    prefix: impl Into<String>,
    urls: impl IntoIterator<Item = impl AsRef<str>>,
) -> anyhow::Result<Self> {
    let upstream = middleware::Upstream::balanced(urls)?;
    let prefix: String = prefix.into();
    let has_prefix = move |req: &Request<Body>| req.uri().path().starts_with(&prefix);
    Ok(self.route(has_prefix, upstream))
}
pub fn build(self) -> anyhow::Result<Proxy> {
let mut client_config = if self.accept_invalid_upstream_certs {
ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoCertVerifier))
.with_no_client_auth()
} else {
let mut root_store = RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth()
};
client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let connector = UpstreamConnector {
tls: TlsConnector::from(Arc::new(client_config)),
};
let upstream_client =
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
.pool_idle_timeout(self.pool_idle_timeout)
.pool_max_idle_per_host(self.pool_max_idle_per_host)
.pool_timer(hyper_util::rt::TokioTimer::new())
.build(connector);
let mode = if let Some((authority, scheme)) = self.upstream_target {
let tls_acceptor = if let Some((certs, key)) = self.tls_identity {
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Some(Arc::new(TlsAcceptor::from(Arc::new(config))))
} else {
None
};
ProxyMode::Reverse {
upstream_authority: authority,
upstream_scheme: scheme,
tls_acceptor,
}
} else {
let ca = self
.ca
.ok_or_else(|| anyhow::anyhow!("CertificateAuthority must be set"))?;
let tls_acceptor = if let Some((certs, key)) = self.tls_identity {
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Some(Arc::new(TlsAcceptor::from(Arc::new(config))))
} else {
None
};
ProxyMode::Forward {
mitm_acceptor: {
let mut config = rustls_mitm::MitmCertResolver::new(ca).into_server_config();
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(config)
},
credentials: if self.credentials.is_empty() {
None
} else {
Some(Arc::new(self.credentials))
},
tls_acceptor,
}
};
let mut layers = self.layers;
if !self.routes.is_empty() {
let mut router = middleware::Router::new();
for (predicate, upstream) in self.routes {
let pred = predicate;
router = router.route(move |req: &Request<Body>| pred(req), upstream);
}
let router_layer: LayerFn = Box::new(move |svc| {
tower::util::BoxService::new(tower::Layer::layer(&router, svc))
});
layers.push(router_layer);
}
Ok(Proxy {
mode,
layers: Arc::new(layers),
upstream_client,
handshake_timeout: self.handshake_timeout,
idle_timeout: self.idle_timeout,
max_connections: self.max_connections.map(|n| Arc::new(Semaphore::new(n))),
drain_timeout: self.drain_timeout,
})
}
}
/// Adapts the buffered tower service to hyper's `Service` trait and relays
/// protocol upgrades between client and upstream.
struct HyperServiceAdapter {
inner: BufferedHttpService,
}
impl hyper::service::Service<Request<Incoming>> for HyperServiceAdapter {
type Response = Response<Body>;
type Error = BoxError;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Response<Body>, BoxError>> + Send>,
>;
// Forwards one request through the service chain. If the request asked for a
// connection upgrade and the response is 101, the two upgraded streams are
// bridged in a background task.
fn call(&self, mut req: Request<Incoming>) -> Self::Future {
// Each call works on its own clone of the buffer handle.
let mut inner = self.inner.clone();
let span = tracing::debug_span!("request", method = %req.method(), uri = %req.uri());
// True when the `Connection` header lists an `upgrade` token (case-insensitive).
let is_upgrade = req
.headers()
.get(::http::header::CONNECTION)
.and_then(|v| v.to_str().ok())
.map(|v| {
v.split(',')
.any(|t| t.trim().eq_ignore_ascii_case("upgrade"))
})
.unwrap_or(false);
// The client-side upgrade future is taken before the request is handed to
// the inner service.
let client_upgrade = if is_upgrade {
Some(hyper::upgrade::on(&mut req))
} else {
None
};
Box::pin(
async move {
let req = req.map(incoming_to_body);
// Wait for buffer capacity before dispatching.
std::future::poll_fn(|cx| inner.poll_ready(cx)).await?;
let mut resp = inner.call(req).await?;
if resp.status() == ::http::StatusCode::SWITCHING_PROTOCOLS
&& let Some(client_upgrade) = client_upgrade
{
let upstream_upgrade = hyper::upgrade::on(&mut resp);
// Bridge the streams in a detached task; the 101 response itself is
// returned below without waiting for it.
tokio::spawn(async move {
let (client_io, upstream_io) =
match tokio::try_join!(client_upgrade, upstream_upgrade) {
Ok(io) => io,
Err(e) => {
tracing::warn!(error = %e, "upgrade failed");
return;
}
};
let mut client_io = TokioIo::new(client_io);
let mut upstream_io = TokioIo::new(upstream_io);
if let Err(e) =
tokio::io::copy_bidirectional(&mut client_io, &mut upstream_io).await
{
tracing::debug!(error = %e, "upgrade stream ended");
}
});
}
Ok(resp)
}
.instrument(span),
)
}
}
/// A configured proxy, created with `Proxy::builder()`.
///
/// The accept loop clones the proxy once per accepted connection.
#[derive(Clone)]
pub struct Proxy {
mode: ProxyMode,
// Layer constructors, applied when a connection's service chain is built.
layers: Arc<Vec<LayerFn>>,
upstream_client: UpstreamClient,
handshake_timeout: Option<Duration>,
idle_timeout: Option<Duration>,
// Semaphore limiting concurrently handled connections, when configured.
max_connections: Option<Arc<Semaphore>>,
drain_timeout: Option<Duration>,
}
impl Proxy {
/// Creates a builder with default settings.
///
/// Defaults: no CA, no layers or routes, upstream certificate verification
/// enabled, no timeouts or connection limit, no credentials, at most 8 idle
/// pooled connections per host and a 90-second pool idle timeout.
pub fn builder() -> ProxyBuilder {
    let default_pool_idle = Duration::from_secs(90);
    ProxyBuilder {
        // Identity and mode.
        ca: None,
        upstream_target: None,
        tls_identity: None,
        accept_invalid_upstream_certs: false,
        credentials: Vec::new(),
        // Request processing.
        layers: Vec::new(),
        routes: Vec::new(),
        // Limits and timeouts.
        handshake_timeout: None,
        idle_timeout: None,
        max_connections: None,
        drain_timeout: None,
        // Upstream connection pool.
        pool_max_idle_per_host: 8,
        pool_idle_timeout: default_pool_idle,
    }
}
/// Binds `addr` and serves connections with a shutdown signal that never fires.
pub async fn listen(&self, addr: impl ToSocketAddrs) -> anyhow::Result<()> {
    let never = std::future::pending();
    self.listen_with_shutdown(addr, never).await
}
/// Binds `addr` and serves connections until `shutdown` completes.
///
/// # Errors
/// Fails if the address cannot be bound, or with any error from the accept loop.
pub async fn listen_with_shutdown(
    &self,
    addr: impl ToSocketAddrs,
    shutdown: impl Future<Output = ()>,
) -> anyhow::Result<()> {
    let bound = TcpListener::bind(addr).await?;
    self.listen_on_with_shutdown(bound, shutdown).await
}
/// Serves connections from an already-bound listener with a shutdown signal
/// that never fires.
pub async fn listen_on(&self, listener: TcpListener) -> anyhow::Result<()> {
    let never = std::future::pending();
    self.listen_on_with_shutdown(listener, never).await
}
/// Accepts connections from `listener` until `shutdown` completes, then drains.
///
/// Each accepted connection runs in its own task. When a connection limit is
/// configured, a connection waits for a permit before its task starts and
/// holds the permit until the task ends. After the shutdown signal, open
/// connections are told to shut down gracefully; if a drain timeout is set and
/// elapses, the remaining tasks are aborted.
///
/// # Errors
/// Fails if the local address cannot be read, if `accept` returns an error, or
/// if the connection semaphore has been closed.
pub async fn listen_on_with_shutdown(
    &self,
    listener: TcpListener,
    shutdown: impl Future<Output = ()>,
) -> anyhow::Result<()> {
    let local_addr = listener.local_addr()?;
    match &self.mode {
        ProxyMode::Forward { tls_acceptor, .. } => {
            if tls_acceptor.is_some() {
                tracing::info!(%local_addr, "listening (forward proxy, TLS)");
            } else {
                tracing::info!(%local_addr, "listening (forward proxy)");
            }
        }
        ProxyMode::Reverse {
            upstream_authority,
            upstream_scheme,
            ..
        } => {
            tracing::info!(
                %local_addr,
                upstream = %format_args!("{upstream_scheme}://{upstream_authority}"),
                "listening (reverse proxy)",
            );
        }
    }

    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
    let mut tasks = JoinSet::new();
    tokio::pin!(shutdown);

    loop {
        tokio::select! {
            result = listener.accept() => {
                let (stream, addr) = result?;
                // Race the permit wait against shutdown: at the connection
                // limit the wait can last as long as the longest-lived
                // connection, and shutdown must not be blocked behind it.
                let permit = match &self.max_connections {
                    Some(sem) => tokio::select! {
                        permit = Arc::clone(sem).acquire_owned() => Some(permit?),
                        () = &mut shutdown => break,
                    },
                    None => None,
                };
                let proxy = self.clone();
                let rx = shutdown_rx.clone();
                let span = tracing::info_span!(
                    "connection",
                    client = %addr,
                    target = tracing::field::Empty,
                );
                tasks.spawn(
                    async move {
                        if let Err(e) = proxy
                            .handle_connection_inner(stream, rx)
                            .await
                        {
                            tracing::warn!(error = %e, "connection error");
                        }
                        drop(permit);
                    }
                    .instrument(span),
                );
            }
            // Reap finished tasks as they complete so the set does not grow
            // for the lifetime of the listener. When the set is empty,
            // `join_next` yields `None`, the pattern does not match, and this
            // branch is disabled for the current iteration.
            Some(joined) = tasks.join_next() => {
                if let Err(e) = joined {
                    if e.is_panic() {
                        tracing::warn!(error = %e, "connection task panicked");
                    }
                }
            }
            () = &mut shutdown => {
                break;
            }
        }
    }

    tracing::info!("shutdown signal received, draining connections");
    let _ = shutdown_tx.send(true);
    drop(shutdown_rx);
    if let Some(timeout) = self.drain_timeout {
        if tokio::time::timeout(timeout, async {
            while tasks.join_next().await.is_some() {}
        })
        .await
        .is_err()
        {
            tracing::warn!("drain timeout reached, aborting remaining connections");
            tasks.abort_all();
        }
    } else {
        while tasks.join_next().await.is_some() {}
    }
    tracing::info!("all connections closed");
    Ok(())
}
/// Handles a single, externally accepted connection.
///
/// The connection runs inside a `connection` tracing span carrying
/// `client_addr`. The shutdown channel created here is never signalled, and
/// its sender stays alive until this function returns.
pub async fn handle_connection(
    &self,
    stream: TcpStream,
    client_addr: SocketAddr,
) -> anyhow::Result<()> {
    let (_tx, rx) = tokio::sync::watch::channel(false);
    let span = tracing::info_span!(
        "connection",
        client = %client_addr,
        target = tracing::field::Empty,
    );
    self.handle_connection_inner(stream, rx)
        .instrument(span)
        .await
}
/// Dispatches a connection to the forward or reverse handler according to the
/// proxy mode, passing clones of the mode's settings.
async fn handle_connection_inner(
    &self,
    stream: TcpStream,
    shutdown_rx: tokio::sync::watch::Receiver<bool>,
) -> anyhow::Result<()> {
    match &self.mode {
        ProxyMode::Forward {
            mitm_acceptor: mitm,
            credentials: creds,
            tls_acceptor: tls,
        } => {
            let (mitm, creds, tls) = (mitm.clone(), creds.clone(), tls.clone());
            self.handle_forward_connection(stream, shutdown_rx, mitm, creds, tls)
                .await
        }
        ProxyMode::Reverse {
            upstream_authority: authority,
            upstream_scheme: scheme,
            tls_acceptor: tls,
        } => {
            let (authority, scheme, tls) = (authority.clone(), scheme.clone(), tls.clone());
            self.handle_reverse_connection(stream, shutdown_rx, authority, scheme, tls)
                .await
        }
    }
}
/// Builds the service stack for one connection.
///
/// Starts from a `ForwardService` for the given upstream and wraps it with the
/// registered layers in reverse registration order, so the first registered
/// layer ends up outermost.
fn build_service_chain(
    &self,
    authority: ::http::uri::Authority,
    scheme: UpstreamScheme,
) -> HttpService {
    let base: HttpService = tower::util::BoxService::new(ForwardService::new(
        self.upstream_client.clone(),
        authority,
        scheme,
    ));
    self.layers
        .iter()
        .rev()
        .fold(base, |wrapped, layer_fn| layer_fn(wrapped))
}
/// Serves HTTP (with upgrade support) on `client_io` until the connection ends
/// or shutdown is signalled.
///
/// With an idle timeout configured, it is used as the HTTP/1 header read
/// timeout and the HTTP/2 keep-alive timeout (keep-alive interval is half of
/// it). On shutdown the connection is asked to finish gracefully and is then
/// awaited; an error during that phase is only logged.
async fn serve_client<
    I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
>(
    &self,
    client_io: I,
    hyper_service: HyperServiceAdapter,
    mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
) -> anyhow::Result<()> {
    let io = TokioIo::new(client_io);
    let mut server =
        hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new());
    if let Some(idle) = self.idle_timeout {
        server
            .http1()
            .timer(hyper_util::rt::TokioTimer::new())
            .header_read_timeout(idle);
        server
            .http2()
            .timer(hyper_util::rt::TokioTimer::new())
            .keep_alive_interval(Some(idle / 2))
            .keep_alive_timeout(idle);
    }

    let conn = server.serve_connection_with_upgrades(io, hyper_service);
    tokio::pin!(conn);
    tokio::select! {
        outcome = conn.as_mut() => {
            if let Err(e) = outcome {
                return Err(anyhow::anyhow!(e));
            }
        }
        _ = shutdown_rx.changed() => {
            conn.as_mut().graceful_shutdown();
            match conn.await {
                Ok(()) => {}
                Err(e) => tracing::debug!(error = %e, "connection closed during shutdown"),
            }
        }
    }
    tracing::debug!("connection closed");
    Ok(())
}
async fn handle_forward_connection(
&self,
stream: TcpStream,
shutdown_rx: tokio::sync::watch::Receiver<bool>,
mitm_acceptor: Arc<ServerConfig>,
credentials: Option<Arc<Vec<(String, String)>>>,
tls_acceptor: Option<Arc<TlsAcceptor>>,
) -> anyhow::Result<()> {
if let Some(acceptor) = tls_acceptor {
let tls_stream = acceptor.accept(stream).await?;
let (hyper_service, client_tls) = {
let handshake = self.handshake(tls_stream, mitm_acceptor, credentials);
if let Some(timeout) = self.handshake_timeout {
tokio::time::timeout(timeout, handshake)
.await
.map_err(|_| anyhow::anyhow!("handshake timed out"))??
} else {
handshake.await?
}
};
self.serve_client(client_tls, hyper_service, shutdown_rx)
.await
} else {
let (hyper_service, client_tls) = {
let handshake = self.handshake(stream, mitm_acceptor, credentials);
if let Some(timeout) = self.handshake_timeout {
tokio::time::timeout(timeout, handshake)
.await
.map_err(|_| anyhow::anyhow!("handshake timed out"))??
} else {
handshake.await?
}
};
self.serve_client(client_tls, hyper_service, shutdown_rx)
.await
}
}
/// Handles one reverse-proxy connection: builds the service chain for the
/// fixed upstream, optionally terminates TLS, then serves HTTP.
///
/// When listener TLS is enabled and a handshake timeout is configured, the TLS
/// handshake is bounded by that timeout so a stalled client cannot hold the
/// connection open before any request is served.
///
/// # Errors
/// Fails with `handshake timed out` when the TLS handshake exceeds the
/// timeout, or with the underlying TLS or serving error.
async fn handle_reverse_connection(
    &self,
    stream: TcpStream,
    shutdown_rx: tokio::sync::watch::Receiver<bool>,
    upstream_authority: ::http::uri::Authority,
    upstream_scheme: UpstreamScheme,
    tls_acceptor: Option<Arc<TlsAcceptor>>,
) -> anyhow::Result<()> {
    let service = self.build_service_chain(upstream_authority, upstream_scheme);
    let hyper_service = HyperServiceAdapter {
        inner: tower::buffer::Buffer::new(service, 1024),
    };
    match tls_acceptor {
        Some(acceptor) => {
            let accept = acceptor.accept(stream);
            let tls_stream = match self.handshake_timeout {
                Some(timeout) => tokio::time::timeout(timeout, accept)
                    .await
                    .map_err(|_| anyhow::anyhow!("handshake timed out"))??,
                None => accept.await?,
            };
            self.serve_client(tls_stream, hyper_service, shutdown_rx)
                .await
        }
        None => self.serve_client(stream, hyper_service, shutdown_rx).await,
    }
}
async fn handshake<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static>(
&self,
stream: S,
mitm_acceptor: Arc<ServerConfig>,
credentials: Option<Arc<Vec<(String, String)>>>,
) -> anyhow::Result<(
HyperServiceAdapter,
tokio_rustls::server::TlsStream<TokioIo<hyper::upgrade::Upgraded>>,
)> {
type ConnectInfo = anyhow::Result<(String, hyper::upgrade::Upgraded)>;
let (connect_tx, connect_rx) = tokio::sync::oneshot::channel::<ConnectInfo>();
let connect_tx = Arc::new(Mutex::new(Some(connect_tx)));
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
let connect_tx = connect_tx.clone();
let credentials = credentials.clone();
async move {
let send_err = |msg: &str| {
if let Some(tx) = connect_tx.lock().unwrap().take() {
let _ = tx.send(Err(anyhow::anyhow!("{msg}")));
}
};
if req.method() != ::http::Method::CONNECT {
send_err(&format!("expected CONNECT, got {}", req.method()));
return Ok::<_, Infallible>(
Response::builder()
.status(::http::StatusCode::BAD_REQUEST)
.body(empty_body())
.unwrap(),
);
}
let authority = match req.uri().authority() {
Some(a) => a.to_string(),
None => {
send_err("missing authority in CONNECT URI");
return Ok(Response::builder()
.status(::http::StatusCode::BAD_REQUEST)
.body(empty_body())
.unwrap());
}
};
if let Some(ref creds) = credentials {
let authenticated = req
.headers()
.get(::http::header::PROXY_AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|auth| auth.strip_prefix("Basic "))
.and_then(|b64| base64::engine::general_purpose::STANDARD.decode(b64).ok())
.and_then(|bytes| String::from_utf8(bytes).ok())
.and_then(|decoded| {
let (u, p) = decoded.split_once(':')?;
Some(creds.iter().any(|(eu, ep)| eu == u && ep == p))
})
.unwrap_or(false);
if !authenticated {
send_err("proxy authentication failed");
return Ok(Response::builder()
.status(::http::StatusCode::PROXY_AUTHENTICATION_REQUIRED)
.header(::http::header::PROXY_AUTHENTICATE, "Basic realm=\"noxy\"")
.body(empty_body())
.unwrap());
}
}
tokio::spawn(async move {
let result = hyper::upgrade::on(req)
.await
.map(|upgraded| (authority, upgraded))
.map_err(|e| anyhow::anyhow!(e));
if let Some(tx) = connect_tx.lock().unwrap().take() {
let _ = tx.send(result);
}
});
Ok(Response::new(empty_body()))
}
});
hyper::server::conn::http1::Builder::new()
.serve_connection(TokioIo::new(stream), service)
.with_upgrades()
.await
.map_err(|e| anyhow::anyhow!(e))?;
let (authority, upgraded) = connect_rx
.await
.map_err(|_| anyhow::anyhow!("CONNECT handler did not complete"))??;
let (host, port) = if let Some(colon) = authority.rfind(':') {
(&authority[..colon], authority[colon + 1..].parse::<u16>()?)
} else {
(authority.as_str(), 443u16)
};
tracing::Span::current().record(
"target",
tracing::field::display(format_args!("{host}:{port}")),
);
tracing::debug!("CONNECT");
let acceptor = TlsAcceptor::from(mitm_acceptor);
let client_tls = acceptor.accept(TokioIo::new(upgraded)).await?;
let authority: ::http::uri::Authority = format!("{host}:{port}").parse()?;
let service = self.build_service_chain(authority, ::http::uri::Scheme::HTTPS);
let hyper_service = HyperServiceAdapter {
inner: tower::buffer::Buffer::new(service, 1024),
};
Ok((hyper_service, client_tls))
}
}