/// Unwraps a `Result`, or logs the error and early-returns a `502 Bad Gateway`.
///
/// On `Err(err)` this logs `"<msg>: <err>"` at error level and executes
/// `return Ok(bad_gateway(msg))` from the *enclosing* function, so the enclosing
/// function must return a `Result` whose `Ok` type is the response produced by
/// `crate::rewrite::bad_gateway`.
///
/// Note: `$msg` is expanded twice (once for the log line, once for the response body).
macro_rules! try_or_bad_gateway {
    ($expr:expr, $msg:expr) => {
        match $expr {
            Err(err) => {
                tracing::error!("{}: {}", $msg, err);
                return Ok($crate::rewrite::bad_gateway($msg));
            }
            Ok(value) => value,
        }
    };
}
pub(crate) use try_or_bad_gateway;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response};
use crate::types::ConnectionPolicy;
/// Decides whether a connection to `host` is permitted by `policies`.
///
/// * `None`, or an empty policy list, allows every host.
/// * Otherwise policies are evaluated in order and the first one whose pattern
///   matches `host` decides, via its `allow` flag.
/// * A non-empty list with no matching policy denies the host (default-deny).
pub(crate) fn connection_allowed(host: &str, policies: Option<&[ConnectionPolicy]>) -> bool {
    match policies {
        None | Some([]) => true,
        Some(rules) => rules
            .iter()
            .find_map(|rule| rule.pattern.matches(host).then_some(rule.allow))
            .unwrap_or(false),
    }
}
/// Returns `true` when `authority` (`host[:port]`, optionally prefixed with
/// `userinfo@`, with IPv6 literals either bracketed or bare) names a local,
/// private, or otherwise non-public destination.
///
/// Only `localhost` names and literal IP addresses are recognised. Ordinary DNS
/// names return `false` because nothing is resolved here, so a public name that
/// resolves to a private address is NOT caught by this check.
///
/// NOTE(review): legacy numeric IPv4 spellings such as `127.1` or `2130706433` are
/// rejected by `IpAddr` parsing but may be accepted by the system resolver. Confirm
/// that the connector refuses them, or add handling here.
pub(crate) fn is_private_authority(authority: &str) -> bool {
    // Ignore any `userinfo@` prefix so `user:pw@127.0.0.1` cannot hide the real host.
    let authority = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, host_port)| host_port);
    let host = if let Some(rest) = authority.strip_prefix('[') {
        // Bracketed IPv6 literal: `[addr]` or `[addr]:port`. Splitting on the first
        // ':' would cut the address itself, e.g. `[0::1]` -> "0".
        rest.split_once(']').map_or(rest, |(addr, _port)| addr)
    } else if authority.matches(':').count() > 1 {
        // Several colons and no brackets: a bare IPv6 literal without a port.
        authority
    } else {
        authority.split_once(':').map_or(authority, |(name, _port)| name)
    };
    // Drop an IPv6 zone id (`fe80::1%eth0`, or `%25eth0` in URIs) so the address parses.
    let host = host.split_once('%').map_or(host, |(addr, _zone)| addr);
    // Host names are case-insensitive and may carry a trailing root dot.
    let host = host.trim_end_matches('.').to_ascii_lowercase();
    // RFC 6761: `localhost` and every `*.localhost` name are loopback names.
    if host.is_empty() || host == "localhost" || host.ends_with(".localhost") {
        return true;
    }
    match host.parse::<std::net::IpAddr>() {
        Ok(ip) => is_private_ip(ip),
        Err(_) => false,
    }
}
/// Returns `true` for addresses the proxy must not connect to.
///
/// IPv4 ranges covered: loopback, RFC 1918 private, link-local, broadcast,
/// documentation, and `0.0.0.0/8`.
///
/// IPv6 ranges covered: loopback, unspecified, link-local `fe80::/10`, and
/// unique-local `fc00::/7`. IPv4-mapped addresses (`::ffff:a.b.c.d`) are judged
/// by the IPv4 rules.
pub(crate) fn is_private_ip(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(a) => {
            let [first_octet, ..] = a.octets();
            a.is_loopback()
                || a.is_private()
                || a.is_link_local()
                || a.is_broadcast()
                || a.is_documentation()
                // 0.0.0.0/8 ("this network"); connecting to 0.0.0.0 reaches the local
                // host on common OSes.
                || first_octet == 0
        }
        std::net::IpAddr::V6(a) => {
            // `::ffff:127.0.0.1` is delivered to 127.0.0.1 by dual-stack sockets.
            if let Some(v4) = a.to_ipv4_mapped() {
                return is_private_ip(std::net::IpAddr::V4(v4));
            }
            let first_segment = a.segments()[0];
            a.is_loopback()
                || a.is_unspecified()
                // fe80::/10 link-local
                || (first_segment & 0xffc0) == 0xfe80
                // fc00::/7 unique local
                || (first_segment & 0xfe00) == 0xfc00
        }
    }
}
/// Bytes that may never appear in a secret value: CR and LF (which would let a
/// substituted value break out of an HTTP header line) and NUL.
pub(crate) const FORBIDDEN_IN_SECRET: &[u8] = b"\r\n\0";
/// Checks that `value` is safe to substitute into outgoing requests.
///
/// # Errors
/// Returns a human-readable message if `value` contains any byte listed in
/// `FORBIDDEN_IN_SECRET`.
pub(crate) fn validate_secret(value: &str) -> Result<(), String> {
    let offending = value
        .bytes()
        .find(|byte| FORBIDDEN_IN_SECRET.contains(byte));
    match offending {
        Some(_) => Err("secret must not contain CR, LF, or NUL".to_string()),
        None => Ok(()),
    }
}
/// Returns a copy of `buf` with every non-overlapping occurrence of `from`
/// replaced by `to`, scanning left to right.
///
/// An empty `from` (or an empty `buf`) returns `buf` unchanged.
pub(crate) fn replace_bytes(buf: &[u8], from: &[u8], to: &[u8]) -> Vec<u8> {
    if from.is_empty() || buf.is_empty() {
        return buf.to_vec();
    }
    let mut out = Vec::with_capacity(buf.len());
    let mut rest = buf;
    // Copy unmatched runs in bulk. `windows` yields nothing once `rest` is shorter
    // than `from`, which ends the loop.
    while let Some(pos) = rest.windows(from.len()).position(|window| window == from) {
        out.extend_from_slice(&rest[..pos]);
        out.extend_from_slice(to);
        rest = &rest[pos + from.len()..];
    }
    out.extend_from_slice(rest);
    out
}
/// Applies each `(masked, real)` pair of `replacement_order` to `buf`, in order.
///
/// Replacement is sequential: pair *k* operates on the output of the pairs before
/// it. Callers with overlapping tokens should therefore list longer tokens first.
///
/// NOTE(review): because passes are chained, a `real` value that happens to contain
/// a later `masked` token will itself be rewritten. Confirm this is acceptable.
pub(crate) fn replace_tokens_in_bytes(
    buf: &[u8],
    replacement_order: &[(String, String)],
) -> Vec<u8> {
    replacement_order
        .iter()
        .fold(buf.to_vec(), |current, (masked, real)| {
            replace_bytes(&current, masked.as_bytes(), real.as_bytes())
        })
}
/// Substitutes masked tokens inside a header value and re-validates the result.
///
/// Returns `None` if the rewritten bytes are not valid UTF-8, or if they contain a
/// byte from `FORBIDDEN_IN_SECRET` (CR, LF, NUL). This keeps a substituted secret
/// from splitting or injecting header lines.
pub(crate) fn replace_tokens_in_header_value(
    value: &str,
    replacement_order: &[(String, String)],
) -> Option<String> {
    let rewritten_bytes = replace_tokens_in_bytes(value.as_bytes(), replacement_order);
    String::from_utf8(rewritten_bytes).ok().filter(|rewritten| {
        !rewritten
            .bytes()
            .any(|byte| FORBIDDEN_IN_SECRET.contains(&byte))
    })
}
/// Largest request body (100 MiB) accepted by `collect_and_rewrite_body`.
pub(crate) const MAX_BODY_SIZE: usize = 100 << 20;
/// 10-second connection timeout. It is not used in this module; presumably
/// callers apply it to upstream connects (TODO confirm).
pub(crate) const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// Type-erased response/request body used throughout this module.
pub(crate) type BoxBodyType = BoxBody<Bytes, hyper::Error>;
/// Wraps an in-memory `chunk` as a [`BoxBodyType`].
///
/// `Full` cannot fail, so its uninhabited error is eliminated with an empty `match`.
pub(crate) fn full_body(chunk: Bytes) -> BoxBodyType {
    let infallible_body = Full::new(chunk);
    infallible_body
        .map_err(|unreachable| match unreachable {})
        .boxed()
}
/// Builds a response with `status` whose body is `msg` verbatim. No headers are set.
fn status_response(status: http::StatusCode, msg: &str) -> Response<BoxBodyType> {
    let mut response = Response::new(full_body(Bytes::copy_from_slice(msg.as_bytes())));
    *response.status_mut() = status;
    response
}
/// `400 Bad Request` response whose body is `msg`.
pub(crate) fn bad_request(msg: &str) -> Response<BoxBodyType> {
    status_response(http::StatusCode::BAD_REQUEST, msg)
}
/// `502 Bad Gateway` response whose body is `msg`.
pub(crate) fn bad_gateway(msg: &str) -> Response<BoxBodyType> {
    status_response(http::StatusCode::BAD_GATEWAY, msg)
}
/// Copies `original_headers` into a fresh map, substituting masked tokens from
/// `token_map` in every value.
///
/// * Hop-by-hop headers (`Connection`, `Proxy-Connection`, `Keep-Alive`,
///   `Proxy-Authorization`, `Transfer-Encoding`) are dropped. They describe the
///   client-to-proxy hop and must not be forwarded; `Proxy-Authorization` in
///   particular would leak proxy credentials upstream.
/// * A value is dropped if it is not visible ASCII (`to_str` fails), if
///   `replace_tokens_in_header_value` rejects the rewritten form, or if the
///   rewritten form is not a valid `HeaderValue`.
/// * Repeated headers keep all of their values, in their original order.
///
/// NOTE(review): `Content-Length` is copied verbatim. If the body is rewritten too
/// and real tokens differ in length from masked ones, the caller must fix it up.
pub(crate) fn rewrite_headers(
    original_headers: &hyper::header::HeaderMap,
    token_map: &[(String, String)],
) -> hyper::header::HeaderMap {
    let mut new_headers = http::HeaderMap::new();
    for (name, value) in original_headers.iter() {
        if is_hop_by_hop_header(name) {
            continue;
        }
        // Values with obs-text / opaque bytes cannot be rewritten as text; skip them.
        let Ok(value_str) = value.to_str() else {
            continue;
        };
        let Some(rewritten) = replace_tokens_in_header_value(value_str, token_map) else {
            continue;
        };
        if let Ok(header_value) = http::header::HeaderValue::from_str(&rewritten) {
            // `append`, not `insert`: `insert` would discard the earlier values of a
            // repeated header (e.g. several `Accept` or `Cache-Control` lines).
            new_headers.append(name.clone(), header_value);
        }
    }
    new_headers
}
/// Headers scoped to the client-to-proxy connection that must not be forwarded.
fn is_hop_by_hop_header(name: &http::header::HeaderName) -> bool {
    // `HeaderName::as_str` is always lowercase, so exact matching is case-insensitive.
    matches!(
        name.as_str(),
        "connection"
            | "proxy-connection"
            | "keep-alive"
            | "proxy-authorization"
            | "transfer-encoding"
    )
}
/// Assembles a request from `method`, `uri` and `body`, then installs `headers`
/// in place of the builder's (empty) header map.
///
/// # Errors
/// Returns a `502 Bad Gateway` response with body `"Invalid request"` when the
/// request builder rejects its inputs.
pub(crate) fn build_request<B>(
    method: hyper::Method,
    uri: &http::Uri,
    headers: http::HeaderMap,
    body: B,
) -> Result<Request<B>, Response<BoxBodyType>> {
    let mut request = match Request::builder().method(method).uri(uri).body(body) {
        Ok(request) => request,
        Err(_) => return Err(bad_gateway("Invalid request")),
    };
    *request.headers_mut() = headers;
    Ok(request)
}
pub(crate) async fn http1_handshake<I>(
io: I,
) -> Result<hyper::client::conn::http1::SendRequest<Full<Bytes>>, Response<BoxBodyType>>
where
I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
{
let (sender, conn) = hyper::client::conn::http1::Builder::new()
.handshake(io)
.await
.map_err(|e| {
tracing::error!("upstream handshake error: {}", e);
bad_gateway("Upstream handshake failed")
})?;
tokio::spawn(async move {
let _ = conn.await;
});
Ok(sender)
}
/// Substitutes masked tokens in an already-collected request body.
///
/// Note that the size limit is checked against bytes the caller has already
/// buffered; it does not bound how much was read.
///
/// # Errors
/// Returns `400 Bad Request` ("Request body too large") when `body_bytes` is longer
/// than `MAX_BODY_SIZE`.
pub(crate) fn collect_and_rewrite_body(
    body_bytes: &[u8],
    token_map: &[(String, String)],
) -> Result<Vec<u8>, Response<BoxBodyType>> {
    let within_limit = body_bytes.len() <= MAX_BODY_SIZE;
    if within_limit {
        Ok(replace_tokens_in_bytes(body_bytes, token_map))
    } else {
        Err(bad_request("Request body too large"))
    }
}
/// Converts an upstream response into one with a type-erased [`BoxBodyType`] body.
///
/// Status, version, headers and extensions are kept as they are.
pub(crate) fn box_response(
    resp: Response<hyper::body::Incoming>,
) -> Response<BoxBodyType> {
    resp.map(|incoming| incoming.boxed())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::HostPattern;

    /// `localhost`, loopback and RFC 1918 authorities are private; public names are not.
    #[test]
    fn is_private_authority_blocks_localhost() {
        for authority in ["localhost:443", "127.0.0.1:8080", "10.0.0.1:80"] {
            assert!(is_private_authority(authority), "{authority} should be private");
        }
        assert!(!is_private_authority("example.com:443"));
    }

    /// With no policy list, or an empty one, every host is allowed.
    #[test]
    fn connection_allowed_no_policies() {
        assert!(connection_allowed("example.com", None));
        assert!(connection_allowed("evil.com", Some(&[])));
    }

    /// When several policies match, the earliest one decides.
    #[test]
    fn connection_allowed_first_match_wins() {
        let rules = vec![
            ConnectionPolicy::deny(HostPattern::exact("blocked.com")),
            ConnectionPolicy::allow(HostPattern::exact("blocked.com")),
        ];
        assert!(!connection_allowed("blocked.com", Some(&rules)));
    }

    /// Regex and exact patterns can be mixed; unmatched hosts fall through to deny.
    #[test]
    fn connection_allowed_regex() {
        let rules = vec![
            ConnectionPolicy::deny(HostPattern::regex(r"^internal\.").unwrap()),
            ConnectionPolicy::allow(HostPattern::exact("api.example.com")),
        ];
        let allowed = |host: &str| connection_allowed(host, Some(&rules));
        assert!(!allowed("internal.service"));
        assert!(allowed("api.example.com"));
        assert!(!allowed("other.com"));
    }

    #[test]
    fn replace_bytes_basic() {
        let replaced = replace_bytes(b"hello world", b"o", b"X");
        assert_eq!(replaced.as_slice(), b"hellX wXrld");
    }

    /// Listing the longer token first keeps it from being clobbered by its prefix.
    #[test]
    fn replace_tokens_longest_first() {
        let order = vec![
            ("api-key-long".to_string(), "real-long".to_string()),
            ("api-key".to_string(), "real-short".to_string()),
        ];
        let replaced = replace_tokens_in_bytes(b"prefix api-key-long suffix", &order);
        assert_eq!(replaced.as_slice(), b"prefix real-long suffix");
    }

    #[test]
    fn replace_tokens_in_bytes_string_mapping() {
        let order = vec![("__TOKEN__".to_string(), "real-value".to_string())];
        let replaced = replace_tokens_in_bytes(b"Bearer __TOKEN__", &order);
        assert_eq!(replaced.as_slice(), b"Bearer real-value");
    }

    /// CR, LF and NUL are each rejected on their own.
    #[test]
    fn validate_secret_rejects_crlf() {
        assert!(validate_secret("ok").is_ok());
        for bad in ["no\rcr", "no\nlf", "no\0nul"] {
            assert!(validate_secret(bad).is_err());
        }
    }
}