use super::H3Client;
use super::super::proxy::{
ProxyClient, SkipServerVerification, UpstreamBody, build_backend_uri,
convert_response, set_forwarding_headers, strip_hop_by_hop,
};
#[cfg(unix)]
use super::super::proxy::UnixConnector;
use crate::config::ProxyProtocolVersion;
use crate::error::{HttpResponse, ReqBody, response_502};
use crate::listener::{LocalAddr, LocalUnixPath};
use crate::proxy_proto;
use hyper::body::Incoming;
use hyper::header::HeaderValue;
use hyper::{Request, Response, Uri, Version};
use http_body_util::BodyExt;
use hyper_rustls::ConfigBuilderExt;
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::net::SocketAddr;
use std::sync::Arc;
/// Fixed GUID that RFC 6455 §1.3 appends to the client's `Sec-WebSocket-Key`
/// before hashing, when deriving `Sec-WebSocket-Accept`.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Derives the `Sec-WebSocket-Accept` header value for a client key.
///
/// The value is `base64(SHA-1(key ++ WS_GUID))`. Returns `None` when the key
/// cannot be viewed as a `&str` (non-visible-ASCII bytes), or when the
/// encoded digest is rejected as a header value.
fn compute_ws_accept(key: &HeaderValue) -> Option<HeaderValue> {
    use base64::Engine as _;
    use sha1::{Digest, Sha1};

    let Ok(client_key) = key.to_str() else {
        return None;
    };
    let mut sha = Sha1::new();
    for chunk in [client_key.as_bytes(), WS_GUID.as_bytes()] {
        sha.update(chunk);
    }
    let accept = base64::engine::general_purpose::STANDARD.encode(sha.finalize());
    HeaderValue::from_str(&accept).ok()
}
use tokio::io::AsyncWriteExt;
/// Reverse-proxy client for one configured upstream.
///
/// Holds a single `ProxyClient` transport (pooled HTTP/1+2, a Unix socket on
/// unix, or HTTP/3). For HTTP/1+2 `https` upstreams it can additionally switch
/// to a lazily built HTTP/3 client once the upstream advertises `h3` in an
/// `Alt-Svc` response header.
pub(crate) struct InnerProxyClient {
/// Transport used for ordinary (non-PROXY-protocol) requests.
client: ProxyClient,
/// Base URI requests are rewritten against; for `unix:` upstreams this is a
/// fixed `http://localhost` placeholder.
pub(super) upstream: Uri,
/// Passed to `build_backend_uri` to decide whether the matched route prefix
/// is removed from the forwarded path.
strip_prefix: bool,
/// When set, `serve` opens a fresh connection per request and writes a PROXY
/// protocol header of this version before the HTTP/1.1 request.
proxy_protocol: Option<ProxyProtocolVersion>,
/// Socket path for `unix:` upstreams; `None` for network upstreams.
#[cfg(unix)]
unix_path: Option<std::path::PathBuf>,
/// Most recent `h3` advertisement learned from upstream `Alt-Svc`, if any.
h3_hint: Arc<tokio::sync::Mutex<Option<H3Hint>>>,
/// HTTP/3 client built on demand from `h3_hint` and then reused.
h3_lazy: Arc<tokio::sync::Mutex<Option<Arc<H3Client>>>>,
/// Upstream and connection settings captured in `new` for building the
/// lazy HTTP/3 client.
h3_params: H3LazyParams,
/// Whether Alt-Svc-driven HTTP/3 upgrades are attempted; only set for the
/// HTTP/1+2 transport with an `https` upstream.
auto_h3_enabled: bool,
/// Configured upstream scheme; `serve_upgrade` uses an h2c tunnel when this
/// is `H2c` and an HTTP/1.1 upgrade tunnel otherwise.
pub(super) upgrade_scheme: crate::config::ProxyUpstreamScheme,
/// Optional metrics sink, installed via `set_metrics`.
metrics: Option<Arc<crate::metrics::Metrics>>,
}
/// Settings captured at construction time for building the on-demand HTTP/3
/// client used by Alt-Svc auto-upgrade.
#[derive(Clone)]
struct H3LazyParams {
/// Original upstream URI. Its host is reused for the HTTP/3 endpoint, and its
/// port is the reference for the privileged-port check in `absorb_alt_svc`.
upstream: Uri,
/// When true, the HTTP/3 client is built with `H3Client::new_skip_verify`.
skip_verify: bool,
/// Idle timeout handed to the HTTP/3 client constructor.
pool_idle: Option<std::time::Duration>,
/// Connect timeout copied onto the HTTP/3 client.
connect_timeout: Option<std::time::Duration>,
}
/// An `h3` alternative learned from an upstream `Alt-Svc` header.
struct H3Hint {
/// Port advertised for `h3` on the upstream's own host.
port: u16,
/// Instant after which the hint is discarded (advertised `ma`, capped at
/// `MAX_ALT_SVC_MA_SECS`).
expires_at: std::time::Instant,
}
/// Upper bound on how long an Alt-Svc `h3` hint is trusted, whatever `ma`
/// the upstream advertises.
const MAX_ALT_SVC_MA_SECS: u64 = 24 * 3600;
/// Extracts the first usable same-host `h3` alternative from an `Alt-Svc`
/// header value.
///
/// The value is a comma-separated list of `proto="[host]:port"; param=value`
/// entries (RFC 7838 §3). Only `h3` entries whose alt-authority is a bare
/// port (`":443"`, or `"443"`) are accepted. Entries naming a different host
/// are skipped, because the caller always dials the upstream's own host.
///
/// Returns `Some((port, max_age_secs))` for the first acceptable entry.
/// A missing `ma` defaults to 24 hours, as RFC 7838 §3.1 specifies. `ma` may
/// be a token or a quoted string. Entries with `ma=0`, an unparseable `ma`,
/// or malformed syntax are skipped instead of aborting the whole header, so a
/// later valid `h3` entry is still found. Returns `None` if no entry
/// qualifies (including for `clear`).
pub(super) fn parse_alt_svc_h3(value: &str) -> Option<(u16, u64)> {
    // RFC 7838 §3.1: "ma" defaults to 24 hours when omitted.
    const DEFAULT_MA_SECS: u64 = 24 * 3600;

    value.split(',').find_map(|entry| {
        let mut parts = entry.split(';').map(str::trim);
        // `clear` and other malformed entries have no `=`: skip them.
        let (proto, authority) = parts.next()?.split_once('=')?;
        if proto.trim() != "h3" {
            return None;
        }
        let authority = authority.trim().trim_matches('"');
        let port_str = authority.strip_prefix(':').unwrap_or(authority);
        // An alt-authority with a host part fails to parse as a port and is skipped.
        let port: u16 = port_str.parse().ok()?;
        let mut ma = Some(DEFAULT_MA_SECS);
        for param in parts {
            if let Some(raw) = param.strip_prefix("ma=") {
                ma = raw.trim().trim_matches('"').parse().ok();
            }
        }
        // `ma=0` means "already expired"; an unparseable `ma` is untrusted.
        ma.filter(|&secs| secs != 0).map(|secs| (port, secs))
    })
}
impl InnerProxyClient {
/// Builds a proxy client for `upstream_str`.
///
/// `upstream_str` is either `unix:<path>` (unix only) or an absolute
/// `http://` / `https://` URL that includes a host.
///
/// - `unix:` upstreams use a pooled client over a `UnixConnector` and a fixed
///   `http://localhost` base URI.
/// - With `scheme == ProxyUpstreamScheme::H3`, the upstream is reached over
///   HTTP/3 only.
/// - Otherwise a pooled HTTP/1+HTTP/2 client is used. `https` upstreams
///   (scheme matched case-insensitively by the URI parser) also become
///   eligible for Alt-Svc-driven HTTP/3 auto-upgrade.
///
/// `pool_idle_timeout_secs`, `pool_max_idle` and `connect_timeout_secs` tune
/// the connection pools. `skip_verify` installs `SkipServerVerification` as
/// the TLS verifier (or selects `H3Client::new_skip_verify`).
///
/// # Errors
///
/// Fails for `unix:` upstreams on non-unix platforms, and for URLs that do
/// not parse, use a scheme other than `http`/`https`, or lack a host. Also
/// fails when the HTTP/3 client cannot be constructed.
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
    upstream_str: &str,
    strip_prefix: bool,
    proxy_protocol: Option<ProxyProtocolVersion>,
    scheme: crate::config::ProxyUpstreamScheme,
    pool_idle_timeout_secs: Option<u64>,
    pool_max_idle: Option<u32>,
    skip_verify: bool,
    connect_timeout_secs: Option<u64>,
) -> anyhow::Result<Self> {
    let connect_timeout = connect_timeout_secs.map(std::time::Duration::from_secs);
    let pool_idle = pool_idle_timeout_secs.map(std::time::Duration::from_secs);
    let mut http_builder = Client::builder(TokioExecutor::new());
    if let Some(d) = pool_idle {
        http_builder.pool_idle_timeout(d);
    }
    if let Some(n) = pool_max_idle {
        http_builder.pool_max_idle_per_host(n as usize);
    }

    #[cfg(unix)]
    if let Some(path) = upstream_str.strip_prefix("unix:") {
        let client = http_builder.build(UnixConnector { path: path.into() });
        // Unix upstreams have no URL of their own. A fixed base URI is used
        // for request rewriting, and the connector is built from the socket path.
        let upstream: Uri = "http://localhost".parse().expect("static URI is valid");
        return Ok(Self {
            client: ProxyClient::Unix(client),
            upstream: upstream.clone(),
            strip_prefix,
            proxy_protocol,
            unix_path: Some(path.into()),
            h3_hint: Arc::new(tokio::sync::Mutex::new(None)),
            h3_lazy: Arc::new(tokio::sync::Mutex::new(None)),
            // HTTP/3 is never attempted for Unix upstreams (auto-upgrade is off).
            h3_params: H3LazyParams {
                upstream,
                skip_verify: false,
                pool_idle,
                connect_timeout,
            },
            auto_h3_enabled: false,
            upgrade_scheme: scheme,
            metrics: None,
        });
    }
    #[cfg(not(unix))]
    if upstream_str.starts_with("unix:") {
        anyhow::bail!("unix: upstream not supported on this platform");
    }

    let upstream: Uri = upstream_str
        .parse()
        .map_err(|_| anyhow::anyhow!("invalid upstream URL: {upstream_str}"))?;
    match upstream.scheme_str() {
        Some("http") | Some("https") => {}
        _ => anyhow::bail!("upstream '{upstream_str}' must use http or https scheme"),
    }
    if upstream.authority().is_none() {
        anyhow::bail!("upstream '{upstream_str}' must include a host");
    }

    if scheme == crate::config::ProxyUpstreamScheme::H3 {
        let mut h3 = if skip_verify {
            H3Client::new_skip_verify(&upstream, pool_idle)?
        } else {
            H3Client::new(&upstream, pool_idle)?
        };
        h3.connect_timeout = connect_timeout;
        return Ok(Self {
            client: ProxyClient::H3(h3),
            upstream: upstream.clone(),
            strip_prefix,
            proxy_protocol,
            #[cfg(unix)]
            unix_path: None,
            h3_hint: Arc::new(tokio::sync::Mutex::new(None)),
            h3_lazy: Arc::new(tokio::sync::Mutex::new(None)),
            h3_params: H3LazyParams {
                upstream,
                skip_verify,
                pool_idle,
                connect_timeout,
            },
            auto_h3_enabled: false,
            upgrade_scheme: scheme,
            metrics: None,
        });
    }

    let mut http_conn = HttpConnector::new();
    // Let the HTTPS wrapper below handle `https` URIs as well.
    http_conn.enforce_http(false);
    if let Some(d) = connect_timeout {
        http_conn.set_connect_timeout(Some(d));
    }
    let https_builder = if skip_verify {
        let crypto = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
            .with_no_client_auth();
        hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(crypto)
    } else {
        hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots()
    };
    let connector = https_builder
        .https_or_http()
        .enable_http1()
        .enable_http2()
        .wrap_connector(http_conn);
    let client = http_builder.build(connector);
    // Use the parsed scheme, which was validated above: the URI parser
    // accepts `HTTPS://` too, so a raw `starts_with("https://")` would wrongly
    // disable auto-upgrade for such upstreams.
    let auto_h3_enabled = upstream.scheme_str() == Some("https");
    Ok(Self {
        client: ProxyClient::Http(client),
        upstream: upstream.clone(),
        strip_prefix,
        proxy_protocol,
        #[cfg(unix)]
        unix_path: None,
        h3_hint: Arc::new(tokio::sync::Mutex::new(None)),
        h3_lazy: Arc::new(tokio::sync::Mutex::new(None)),
        h3_params: H3LazyParams {
            upstream,
            skip_verify,
            pool_idle,
            connect_timeout,
        },
        auto_h3_enabled,
        upgrade_scheme: scheme,
        metrics: None,
    })
}
/// Installs the metrics sink. An HTTP/3 transport receives a shared handle to
/// the same sink, so it records into the same registry.
pub(crate) fn set_metrics(&mut self, metrics: Arc<crate::metrics::Metrics>) {
    let shared = Some(metrics);
    if let ProxyClient::H3(h3) = &mut self.client {
        h3.metrics.clone_from(&shared);
    }
    self.metrics = shared;
}
/// Increments `proxy_upstream_connect_errors_total` when metrics are
/// installed; does nothing otherwise.
fn count_connect_error(&self) {
    let Some(metrics) = self.metrics.as_deref() else {
        return;
    };
    metrics
        .proxy_upstream_connect_errors_total
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
/// Proxies a protocol-upgrade request and splices the client and upstream
/// streams together.
///
/// The upstream tunnel is opened over h2c when `upgrade_scheme` is `H2c`
/// (success = `200 OK`), and over HTTP/1.1 otherwise (success = `101`).
///
/// If the upstream refuses, its status and end-to-end headers are relayed
/// with an empty body. `content-length` is dropped because the upstream body
/// is not forwarded.
///
/// On success, a background task waits for the inbound upgrade and then
/// pumps bytes both ways. For WebSocket, frames are masked or unmasked when
/// the two sides disagree on masking (`MaskMode::Unmask` for HTTP/1.1
/// inbound, `MaskMode::Mask` otherwise). HTTP/3 inbound is not wired yet:
/// the task logs a warning and drops the upstream stream.
///
/// The returned response is `101` for HTTP/1.1 clients and `200` for HTTP/2
/// and HTTP/3 clients. Upstream headers are copied with `append`, so repeated
/// headers (e.g. `set-cookie`) keep every value. Failure to build the request
/// or open the tunnel yields 502.
pub(crate) async fn serve_upgrade(
    &self,
    req: Request<ReqBody>,
    matched_prefix: &str,
    marker: super::upgrade::UpgradeRequest,
) -> HttpResponse {
    use super::upgrade::{InboundProtocol, MaskMode, h1_upgraded, pump};
    use crate::error::bytes_body;

    let backend_req = match self.prepare_backend_upgrade_request(req, matched_prefix) {
        Ok(r) => r,
        Err(e) => {
            tracing::error!(
                "upgrade: failed to build backend URI: {e}"
            );
            return crate::error::response_502();
        }
    };
    let outbound_is_h2c = matches!(
        self.upgrade_scheme,
        crate::config::ProxyUpstreamScheme::H2c,
    );
    let (parts, upstream_stream) = if outbound_is_h2c {
        match super::upgrade::open_h2c_upstream_tunnel(
            &self.upstream,
            backend_req,
            &marker.protocol,
        )
        .await
        {
            Ok(t) => t,
            Err(e) => {
                tracing::warn!(
                    upstream = %self.upstream,
                    "h2c upgrade tunnel: open failed: {e:#}"
                );
                return crate::error::response_502();
            }
        }
    } else {
        match super::upgrade::open_h1_upstream_tunnel(&self.upstream, backend_req).await {
            Ok(t) => t,
            Err(e) => {
                tracing::warn!(
                    upstream = %self.upstream,
                    "upgrade tunnel: open failed: {e:#}"
                );
                return crate::error::response_502();
            }
        }
    };

    let upstream_ok = if outbound_is_h2c {
        parts.status == hyper::StatusCode::OK
    } else {
        parts.status == hyper::StatusCode::SWITCHING_PROTOCOLS
    };
    if !upstream_ok {
        // The upstream refused the upgrade. Relay its status and end-to-end
        // headers. The body is not forwarded, so the upstream's
        // `content-length` must go too; otherwise the client would wait for
        // bytes that never arrive.
        let mut resp = hyper::Response::builder()
            .status(parts.status)
            .body(bytes_body(bytes::Bytes::new()))
            .unwrap_or_else(|_| crate::error::response_502());
        for (n, v) in parts.headers.iter() {
            if !matches!(
                n.as_str(),
                "connection"
                    | "keep-alive"
                    | "proxy-authenticate"
                    | "proxy-authorization"
                    | "te"
                    | "trailers"
                    | "transfer-encoding"
                    | "content-length"
            ) {
                // `append`, not `insert`: keep every value of repeated
                // headers such as `set-cookie` / `www-authenticate`.
                resp.headers_mut().append(n, v.clone());
            }
        }
        return resp;
    }

    let is_websocket = marker
        .protocol
        .to_str()
        .map(|s| s.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    let inbound_masks = marker.inbound == InboundProtocol::H1;
    let outbound_masks = !outbound_is_h2c;
    // Frames only need rewriting when exactly one side uses masking.
    let ws_mode = (is_websocket && inbound_masks != outbound_masks).then(|| {
        if inbound_masks {
            MaskMode::Unmask
        } else {
            MaskMode::Mask
        }
    });

    let on_upgrade_arc = marker.on_upgrade.clone();
    let inbound = marker.inbound;
    tokio::spawn(async move {
        let inbound_on = match on_upgrade_arc.lock().expect("proxy upgrade mutex").take() {
            Some(f) => f,
            None => return,
        };
        let inbound_stream = match inbound {
            InboundProtocol::H1 | InboundProtocol::H2 => match inbound_on.await {
                Ok(u) => h1_upgraded(u),
                Err(e) => {
                    tracing::warn!(
                        "upgrade: inbound handoff failed: {e}"
                    );
                    return;
                }
            },
            InboundProtocol::H3 => {
                tracing::warn!(
                    "upgrade: h3 inbound dispatch not yet wired"
                );
                return;
            }
        };
        let pump_res = match ws_mode {
            Some(mode) => {
                super::upgrade::pump_websocket(inbound_stream, upstream_stream, mode).await
            }
            None => pump(inbound_stream, upstream_stream).await.map(|_| ()),
        };
        if let Err(e) = pump_res {
            tracing::debug!("upgrade tunnel pump exited: {e}");
        }
    });

    let inbound_status = match marker.inbound {
        InboundProtocol::H1 => hyper::StatusCode::SWITCHING_PROTOCOLS,
        InboundProtocol::H2 | InboundProtocol::H3 => hyper::StatusCode::OK,
    };
    let mut resp = hyper::Response::builder()
        .status(inbound_status)
        .body(bytes_body(bytes::Bytes::new()))
        .unwrap_or_else(|_| crate::error::response_502());
    for (n, v) in parts.headers.iter() {
        resp.headers_mut().append(n, v.clone());
    }
    // An HTTP/1.1 client tunnelled over h2c received a 200 from upstream, so
    // synthesize the HTTP/1.1 upgrade handshake headers ourselves.
    if marker.inbound == InboundProtocol::H1 && outbound_is_h2c {
        resp.headers_mut().insert(
            hyper::header::CONNECTION,
            hyper::header::HeaderValue::from_static("upgrade"),
        );
        resp.headers_mut()
            .insert(hyper::header::UPGRADE, marker.protocol.clone());
        if is_websocket
            && let Some(accept_key) = marker.ws_key.as_ref().and_then(compute_ws_accept)
        {
            resp.headers_mut().insert(
                hyper::header::HeaderName::from_static("sec-websocket-accept"),
                accept_key,
            );
        }
    }
    resp
}
/// Forwards an ordinary (non-upgrade) request to the upstream.
///
/// Dispatch order:
/// 1. PROXY protocol configured: delegate to `serve_with_proxy_protocol`.
/// 2. HTTP/3 transport: send over HTTP/3.
/// 3. Alt-Svc auto-upgrade armed: send over the lazily built HTTP/3 client.
///    That attempt consumes the request body, so a failure cannot be retried
///    over HTTP/1+2. Instead the hint and the cached HTTP/3 client are
///    dropped, so later requests use HTTP/1+2, and this request gets 502.
/// 4. Otherwise: the pooled HTTP/1+2 (or Unix) client. Successful responses
///    may arm the HTTP/3 hint through their `Alt-Svc` header.
///
/// Every failure is logged and mapped to a 502 response. Connect failures on
/// the pooled client are counted in metrics.
pub(crate) async fn serve(
    &self,
    req: Request<ReqBody>,
    matched_prefix: &str,
) -> HttpResponse {
    if let Some(version) = self.proxy_protocol {
        return self
            .serve_with_proxy_protocol(req, matched_prefix, version)
            .await;
    }
    let backend_req = match self.prepare_backend_request(req, matched_prefix) {
        Ok(r) => r,
        Err(e) => {
            tracing::error!("failed to build backend URI: {e}");
            return response_502();
        }
    };
    if let ProxyClient::H3(h3) = &self.client {
        return match h3.request(backend_req).await {
            Ok(resp) => resp,
            Err(e) => {
                tracing::error!("h3: backend request failed: {e:#}");
                response_502()
            }
        };
    }
    if self.auto_h3_enabled
        && let Some(h3) = self.try_upgrade_to_h3().await
    {
        return match h3.request(backend_req).await {
            Ok(resp) => resp,
            Err(e) => {
                // The body went to the h3 attempt, so this request cannot be
                // replayed over h1/h2. Disarm auto-upgrade so later requests
                // use h1/h2 until a fresh Alt-Svc re-arms it. Drop the cached
                // client too, so a re-armed hint gets a freshly built one.
                // The locks are taken one at a time.
                tracing::warn!(
                    "h3 auto-upgrade request failed; disarming h3 hint so later requests use h1/h2: {e:#}"
                );
                *self.h3_hint.lock().await = None;
                *self.h3_lazy.lock().await = None;
                response_502()
            }
        };
    }
    let result = match &self.client {
        ProxyClient::Http(c) => c.request(backend_req).await,
        #[cfg(unix)]
        ProxyClient::Unix(c) => c.request(backend_req).await,
        ProxyClient::H3(_) => unreachable!("H3 handled above"),
    };
    match result {
        Ok(resp) => {
            if self.auto_h3_enabled {
                self.absorb_alt_svc(resp.headers()).await;
            }
            convert_response(resp)
        }
        Err(e) => {
            if e.is_connect() {
                self.count_connect_error();
            }
            tracing::error!("backend request failed: {e}");
            response_502()
        }
    }
}
async fn absorb_alt_svc(&self, headers: &hyper::HeaderMap) {
let original_port = self
.h3_params
.upstream
.port_u16()
.unwrap_or(443);
for v in headers.get_all(hyper::header::ALT_SVC).iter() {
let Ok(s) = v.to_str() else { continue };
if let Some((port, ma)) = parse_alt_svc_h3(s) {
if port < 1024 && port != original_port {
tracing::warn!(
port,
"ignoring Alt-Svc h3 redirect to privileged \
port that doesn't match the upstream URL"
);
continue;
}
let expires_at = std::time::Instant::now()
+ std::time::Duration::from_secs(
ma.min(MAX_ALT_SVC_MA_SECS),
);
*self.h3_hint.lock().await =
Some(H3Hint { port, expires_at });
tracing::debug!(
port,
ma,
"armed h3 auto-upgrade hint from upstream Alt-Svc"
);
return;
}
}
}
/// Returns the HTTP/3 client to use for auto-upgrade, if a hint is armed.
///
/// An expired hint is cleared and yields `None`. Otherwise the cached lazy
/// client is returned. If none is cached yet, one is built first for
/// `https://{host}:{port}/`, using the stored skip-verify, pool-idle and
/// connect-timeout settings. If the client cannot be built, the result is
/// `None` and nothing is cached.
async fn try_upgrade_to_h3(&self) -> Option<Arc<H3Client>> {
    let armed_port = {
        let mut slot = self.h3_hint.lock().await;
        let hint = slot.as_ref()?;
        if hint.expires_at > std::time::Instant::now() {
            hint.port
        } else {
            *slot = None;
            return None;
        }
    };

    let mut cached = self.h3_lazy.lock().await;
    if cached.is_none() {
        let params = &self.h3_params;
        let host = params.upstream.host()?;
        let alt_url: Uri = format!("https://{host}:{armed_port}/").parse().ok()?;
        let mut client = if params.skip_verify {
            H3Client::new_skip_verify(&alt_url, params.pool_idle).ok()?
        } else {
            H3Client::new(&alt_url, params.pool_idle).ok()?
        };
        client.connect_timeout = params.connect_timeout;
        *cached = Some(Arc::new(client));
    }
    (*cached).clone()
}
/// Rewrites `req` for the upstream as an ordinary request, with hop-by-hop
/// headers stripped.
pub(super) fn prepare_backend_request(
    &self,
    req: Request<ReqBody>,
    matched_prefix: &str,
) -> anyhow::Result<Request<UpstreamBody>> {
    let is_upgrade = false;
    self.prepare_backend_request_inner(req, matched_prefix, is_upgrade)
}
/// Rewrites `req` for the upstream as an upgrade handshake. Hop-by-hop
/// stripping is skipped, so the handshake headers are passed through.
pub(super) fn prepare_backend_upgrade_request(
    &self,
    req: Request<ReqBody>,
    matched_prefix: &str,
) -> anyhow::Result<Request<UpstreamBody>> {
    let is_upgrade = true;
    self.prepare_backend_request_inner(req, matched_prefix, is_upgrade)
}
/// Converts an inbound request into one addressed to the upstream.
///
/// Steps:
/// - The URI is rebuilt by `build_backend_uri`, honoring `strip_prefix`.
/// - Hop-by-hop headers are stripped unless `is_upgrade` is true.
/// - Forwarding headers are set from the peer `SocketAddr` stored in the
///   request extensions, if any.
/// - The version is reset to `Version::default()`.
/// - `Host` is overwritten with the upstream authority whenever that
///   authority is a valid header value.
///
/// # Errors
///
/// Propagates a failure from `build_backend_uri`.
fn prepare_backend_request_inner(
    &self,
    req: Request<ReqBody>,
    matched_prefix: &str,
    is_upgrade: bool,
) -> anyhow::Result<Request<UpstreamBody>> {
    let client_ip: Option<String> = req
        .extensions()
        .get::<SocketAddr>()
        .map(|addr| addr.ip().to_string());
    let target = build_backend_uri(&self.upstream, req.uri(), matched_prefix, self.strip_prefix)?;

    let (mut head, body) = req.into_parts();
    if !is_upgrade {
        strip_hop_by_hop(&mut head.headers);
    }
    set_forwarding_headers(&mut head.headers, client_ip.as_deref());
    head.uri = target;
    head.version = Version::default();

    let host_value = self
        .upstream
        .authority()
        .and_then(|authority| HeaderValue::from_str(authority.as_str()).ok());
    if let Some(value) = host_value {
        head.headers.insert(hyper::header::HOST, value);
    }
    Ok(Request::from_parts(head, body.boxed_unsync()))
}
async fn serve_with_proxy_protocol(
&self,
req: Request<ReqBody>,
matched_prefix: &str,
version: ProxyProtocolVersion,
) -> HttpResponse {
let src = req.extensions().get::<SocketAddr>().copied();
let dst_tcp = req.extensions().get::<LocalAddr>().map(|a| a.0);
let dst_unix =
req.extensions().get::<LocalUnixPath>().map(|p| p.0.clone());
let backend_req =
match self.prepare_backend_request(req, matched_prefix) {
Ok(r) => r,
Err(e) => {
tracing::error!("failed to build backend URI: {e}");
return response_502();
}
};
let header = match src {
Some(src_addr) => {
let dst = dst_tcp
.unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)));
proxy_proto::build_header(version, src_addr, dst)
}
None => match version {
ProxyProtocolVersion::V1 => proxy_proto::build_v1_unknown(),
ProxyProtocolVersion::V2 => match dst_unix.as_deref() {
Some(p) => proxy_proto::build_v2_unix(None, Some(p)),
None => proxy_proto::build_v2_unspec(),
},
},
};
#[cfg(unix)]
if let Some(path) = &self.unix_path {
let mut stream = match tokio::net::UnixStream::connect(path).await {
Ok(s) => s,
Err(e) => {
self.count_connect_error();
tracing::error!("unix upstream connect failed: {e}");
return response_502();
}
};
if let Err(e) = stream.write_all(&header).await {
tracing::error!("writing PROXY header failed: {e}");
return response_502();
}
return match send_http1_request(TokioIo::new(stream), backend_req)
.await
{
Ok(r) => convert_response(r),
Err(e) => {
tracing::error!("backend request failed: {e}");
response_502()
}
};
}
let authority = self
.upstream
.authority()
.expect("upstream authority validated in new()")
.as_str();
let mut stream = match tokio::net::TcpStream::connect(authority).await {
Ok(s) => s,
Err(e) => {
self.count_connect_error();
tracing::error!("upstream connect failed: {e}");
return response_502();
}
};
if let Err(e) = stream.write_all(&header).await {
tracing::error!("writing PROXY header failed: {e}");
return response_502();
}
let resp = if self.upstream.scheme_str() == Some("https") {
let host = self.upstream.host().unwrap_or("");
let server_name = match rustls::pki_types::ServerName::try_from(
host.to_owned(),
) {
Ok(n) => n,
Err(e) => {
tracing::error!(
"invalid upstream hostname '{host}': {e}"
);
return response_502();
}
};
let tls_cfg = Arc::new(
rustls::ClientConfig::builder()
.with_webpki_roots()
.with_no_client_auth(),
);
let tls_stream = match tokio_rustls::TlsConnector::from(tls_cfg)
.connect(server_name, stream)
.await
{
Ok(s) => s,
Err(e) => {
tracing::error!("TLS handshake failed: {e}");
return response_502();
}
};
send_http1_request(TokioIo::new(tls_stream), backend_req).await
} else {
send_http1_request(TokioIo::new(stream), backend_req).await
};
match resp {
Ok(r) => convert_response(r),
Err(e) => {
tracing::error!("backend request failed: {e}");
response_502()
}
}
}
}
async fn send_http1_request<I>(
io: I,
req: Request<UpstreamBody>,
) -> anyhow::Result<Response<Incoming>>
where
I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
{
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::spawn(conn);
Ok(sender.send_request(req).await?)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The worked example from RFC 6455 §1.3 must round-trip exactly.
    #[test]
    fn ws_accept_matches_rfc6455_example() {
        let sample_key = HeaderValue::from_static("dGhlIHNhbXBsZSBub25jZQ==");
        let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(compute_ws_accept(&sample_key).expect("computed"), expected);
    }

    /// Keys with bytes outside visible ASCII cannot be read as `&str` and are rejected.
    #[test]
    fn ws_accept_rejects_non_ascii_key() {
        let raw_key = HeaderValue::from_bytes(b"key\xff").unwrap();
        assert_eq!(compute_ws_accept(&raw_key), None);
    }
}