// kz-proxy 0.3.0
//
// MITM proxy and subprocess sandbox for blind secret injection.
// Documentation
//! Request/response transformation and policy: token replacement, header rewriting,
//! HTTP helpers, connection policy evaluation, and SSRF checks.

/// Unwrap a `Result` inside a request handler.
///
/// Evaluates to the `Ok` value. On `Err`, the error is logged together with `$msg`
/// and the enclosing function returns early with `Ok(bad_gateway($msg))`, so the
/// macro is only usable where `Ok(<bad-gateway response>)` is a valid return value.
macro_rules! try_or_bad_gateway {
    ($expr:expr, $msg:expr) => {
        match $expr {
            Err(err) => {
                tracing::error!("{}: {}", $msg, err);
                return Ok($crate::rewrite::bad_gateway($msg));
            }
            Ok(value) => value,
        }
    };
}
pub(crate) use try_or_bad_gateway;

use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response};

use crate::types::ConnectionPolicy;

// ---------------------------------------------------------------------------
// Connection policy & SSRF
// ---------------------------------------------------------------------------

/// Decide whether an outbound connection to `host` is permitted.
///
/// * `None` or an empty policy list: every host is allowed.
/// * Otherwise the first policy whose pattern matches `host` decides via its `allow` flag.
/// * When rules exist but none matches, the connection is denied (allowlist behavior).
pub(crate) fn connection_allowed(host: &str, policies: Option<&[ConnectionPolicy]>) -> bool {
    match policies {
        None | Some([]) => true,
        Some(rules) => rules
            .iter()
            .find(|rule| rule.pattern.matches(host))
            .map_or(false, |rule| rule.allow),
    }
}

/// Check if the authority host is a private/local address (SSRF mitigation).
///
/// `authority` is expected in `[userinfo@]host[:port]` form, where `host` may be a
/// DNS name, an IPv4 literal, or an IPv6 literal (normally bracketed: `[::1]:443`).
///
/// Returns `true` when the host is empty, is `localhost` (case-insensitive, with or
/// without a trailing dot), or is an IP literal for which [`is_private_ip`] is true.
/// Malformed bracketed literals are treated as private (fail closed). Host names that
/// are not IP literals are not resolved here and yield `false`.
pub(crate) fn is_private_authority(authority: &str) -> bool {
    // Drop any `userinfo@` prefix so `user@127.0.0.1` cannot hide the real host.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, rest)| rest);

    let host = if let Some(rest) = host_port.strip_prefix('[') {
        // Bracketed IPv6 literal: the host is everything up to the closing bracket.
        // Splitting on the first ':' (as for host:port) would cut the address apart.
        let Some((literal, _port)) = rest.split_once(']') else {
            // Unterminated bracket: refuse rather than guess.
            return true;
        };
        // A zone identifier (`%eth0` / `%25eth0`) is not part of the address itself.
        literal.split_once('%').map_or(literal, |(addr, _zone)| addr)
    } else if host_port.parse::<std::net::Ipv6Addr>().is_ok() {
        // Bare IPv6 literal without brackets or port.
        host_port
    } else {
        // `host` or `host:port`; host names and IPv4 literals contain no ':'.
        host_port.split_once(':').map_or(host_port, |(h, _port)| h)
    };

    // A trailing dot only marks the name as fully qualified; ignore it.
    let host = host.trim_end_matches('.');
    if host.is_empty() || host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    host.parse::<std::net::IpAddr>().map_or(false, is_private_ip)
}

/// Returns true if `ip` is loopback, unspecified, private, link-local or otherwise
/// not a public unicast destination according to the checks below.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by their embedded
/// IPv4 address, so `::ffff:127.0.0.1` is treated like `127.0.0.1`.
pub(crate) fn is_private_ip(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(a) => is_private_ipv4(a),
        std::net::IpAddr::V6(a) => {
            if let Some(mapped) = a.to_ipv4_mapped() {
                return is_private_ipv4(mapped);
            }
            let first = a.segments()[0];
            a.is_loopback()
                || a.is_unspecified()
                // Link-local fe80::/10
                || (first & 0xffc0) == 0xfe80
                // Unique local fc00::/7
                || (first & 0xfe00) == 0xfc00
        }
    }
}

/// IPv4 part of [`is_private_ip`]: loopback, RFC 1918 private, link-local,
/// broadcast, documentation ranges, and the unspecified address `0.0.0.0`.
fn is_private_ipv4(a: std::net::Ipv4Addr) -> bool {
    a.is_loopback()
        || a.is_unspecified()
        || a.is_private()
        || a.is_link_local()
        || a.is_broadcast()
        || a.is_documentation()
}

// ---------------------------------------------------------------------------
// Token replacement
// ---------------------------------------------------------------------------

/// Bytes that may never occur in a real or masked secret: CR, LF and NUL.
/// Rejecting them prevents header/body injection through substituted values.
pub(crate) const FORBIDDEN_IN_SECRET: &[u8] = b"\r\n\0";

/// Validate a secret value.
///
/// Returns `Ok(())` when `value` contains none of the [`FORBIDDEN_IN_SECRET`] bytes,
/// otherwise an `Err` carrying a human-readable explanation.
pub(crate) fn validate_secret(value: &str) -> Result<(), String> {
    let has_forbidden_byte = value
        .as_bytes()
        .iter()
        .any(|byte| FORBIDDEN_IN_SECRET.contains(byte));
    if has_forbidden_byte {
        Err("secret must not contain CR, LF, or NUL".to_string())
    } else {
        Ok(())
    }
}

/// Byte-wise search and replace: every non-overlapping occurrence of `from` in `buf`
/// (scanning left to right) is replaced by `to`.
///
/// An empty `from` or an empty `buf` yields an unchanged copy of `buf`.
pub(crate) fn replace_bytes(buf: &[u8], from: &[u8], to: &[u8]) -> Vec<u8> {
    if buf.is_empty() || from.is_empty() {
        return buf.to_vec();
    }
    let mut result = Vec::with_capacity(buf.len());
    let mut rest = buf;
    // While a full match could still fit, either consume one match or one byte.
    while rest.len() >= from.len() {
        match rest.strip_prefix(from) {
            Some(after_match) => {
                result.extend_from_slice(to);
                rest = after_match;
            }
            None => {
                result.push(rest[0]);
                rest = &rest[1..];
            }
        }
    }
    // The tail is shorter than `from`, so it cannot contain a match.
    result.extend_from_slice(rest);
    result
}

/// Apply every `(masked, real)` replacement to `buf`, one pair after another in the
/// order given.
///
/// Callers are expected to pass the pairs longest-masked-token first so a short token
/// cannot clobber part of a longer one. Each pair is applied to the output of the
/// previous pair.
pub(crate) fn replace_tokens_in_bytes(
    buf: &[u8],
    replacement_order: &[(String, String)],
) -> Vec<u8> {
    replacement_order
        .iter()
        .fold(buf.to_vec(), |acc, (masked, real)| {
            replace_bytes(&acc, masked.as_bytes(), real.as_bytes())
        })
}

/// Substitute tokens inside a header value.
///
/// Returns `None` when the substituted bytes are not valid UTF-8 or contain any of the
/// [`FORBIDDEN_IN_SECRET`] bytes (CR, LF, NUL); otherwise returns the rewritten value.
pub(crate) fn replace_tokens_in_header_value(
    value: &str,
    replacement_order: &[(String, String)],
) -> Option<String> {
    let rewritten = replace_tokens_in_bytes(value.as_bytes(), replacement_order);
    match String::from_utf8(rewritten) {
        Ok(text) if !text.bytes().any(|b| FORBIDDEN_IN_SECRET.contains(&b)) => Some(text),
        _ => None,
    }
}

// ---------------------------------------------------------------------------
// HTTP helpers
// ---------------------------------------------------------------------------

/// Maximum request body size to prevent memory exhaustion (100 MB).
///
/// Expressed in bytes (`100 * 1024 * 1024`).
pub(crate) const MAX_BODY_SIZE: usize = 100 * 1024 * 1024;

/// Timeout for establishing upstream TCP connections (10 seconds).
pub(crate) const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

/// Type-erased body used for responses in this module: a boxed body that yields
/// `Bytes` chunks and reports failures as `hyper::Error`.
pub(crate) type BoxBodyType = BoxBody<Bytes, hyper::Error>;

/// Wrap a single, already-complete chunk of bytes as a [`BoxBodyType`].
///
/// The error type of `Full` has no values, so the `map_err` closure only adapts the
/// type to `hyper::Error`; it can never actually run.
pub(crate) fn full_body(chunk: Bytes) -> BoxBodyType {
    let body = Full::new(chunk);
    let body = body.map_err(|impossible| match impossible {});
    body.boxed()
}

/// Build a `400 Bad Request` response whose body is the bytes of `msg`.
pub(crate) fn bad_request(msg: &str) -> Response<BoxBodyType> {
    let payload = Bytes::copy_from_slice(msg.as_bytes());
    let mut response = Response::new(full_body(payload));
    *response.status_mut() = http::StatusCode::BAD_REQUEST;
    response
}

/// Build a `502 Bad Gateway` response whose body is the bytes of `msg`.
pub(crate) fn bad_gateway(msg: &str) -> Response<BoxBodyType> {
    let payload = Bytes::copy_from_slice(msg.as_bytes());
    let mut response = Response::new(full_body(payload));
    *response.status_mut() = http::StatusCode::BAD_GATEWAY;
    response
}

/// Filter hop-by-hop headers and apply token replacements to header values.
///
/// Produces a new header map from `original_headers`:
///
/// * `Connection`, `Proxy-Connection` and `Transfer-Encoding` are dropped.
/// * Every other value has the `token_map` replacements applied.
/// * A header is skipped when its value cannot be read with `to_str`, when the
///   replaced value is rejected by `replace_tokens_in_header_value`, or when the
///   result does not parse as a header value.
///
/// Headers that occur several times keep all of their values: entries are added with
/// `append`, because `insert` would replace earlier values of the same name and leave
/// only the last one.
pub(crate) fn rewrite_headers(
    original_headers: &hyper::header::HeaderMap,
    token_map: &[(String, String)],
) -> hyper::header::HeaderMap {
    let mut new_headers = http::HeaderMap::new();
    for (name, value) in original_headers.iter() {
        let is_hop_by_hop = name == http::header::CONNECTION
            || name.as_str().eq_ignore_ascii_case("proxy-connection")
            || name == http::header::TRANSFER_ENCODING;
        if is_hop_by_hop {
            continue;
        }
        let Ok(value_str) = value.to_str() else {
            continue;
        };
        let Some(replaced) = replace_tokens_in_header_value(value_str, token_map) else {
            continue;
        };
        if let Ok(hv) = replaced.parse::<http::HeaderValue>() {
            // `append` keeps earlier values for the same header name.
            new_headers.append(name.clone(), hv);
        }
    }
    new_headers
}

/// Assemble a request from method, URI, headers and body.
///
/// The request is first built without headers and its header map is then replaced
/// wholesale with `headers`, so the result carries exactly the headers given.
/// If the builder rejects the method/URI, a bad-gateway response with the text
/// "Invalid request" is returned as the error.
pub(crate) fn build_request<B>(
    method: hyper::Method,
    uri: &http::Uri,
    headers: http::HeaderMap,
    body: B,
) -> Result<Request<B>, Response<BoxBodyType>> {
    match Request::builder().method(method).uri(uri).body(body) {
        Ok(mut request) => {
            *request.headers_mut() = headers;
            Ok(request)
        }
        Err(_) => Err(bad_gateway("Invalid request")),
    }
}

/// Perform HTTP/1.1 handshake and spawn the connection driver.
pub(crate) async fn http1_handshake<I>(
    io: I,
) -> Result<hyper::client::conn::http1::SendRequest<Full<Bytes>>, Response<BoxBodyType>>
where
    I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
{
    let (sender, conn) = hyper::client::conn::http1::Builder::new()
        .handshake(io)
        .await
        .map_err(|e| {
            tracing::error!("upstream handshake error: {}", e);
            bad_gateway("Upstream handshake failed")
        })?;
    tokio::spawn(async move {
        let _ = conn.await;
    });
    Ok(sender)
}

/// Enforce the body size limit on already-collected bytes and apply token replacements.
///
/// Returns the rewritten body, or a bad-request response with the text
/// "Request body too large" when `body_bytes` is longer than `MAX_BODY_SIZE`.
pub(crate) fn collect_and_rewrite_body(
    body_bytes: &[u8],
    token_map: &[(String, String)],
) -> Result<Vec<u8>, Response<BoxBodyType>> {
    if body_bytes.len() <= MAX_BODY_SIZE {
        Ok(replace_tokens_in_bytes(body_bytes, token_map))
    } else {
        Err(bad_request("Request body too large"))
    }
}

/// Convert an upstream response into one whose body is a [`BoxBodyType`].
///
/// Status, headers and the other response parts are kept; only the body is boxed.
pub(crate) fn box_response(
    resp: Response<hyper::body::Incoming>,
) -> Response<BoxBodyType> {
    resp.map(|incoming| incoming.boxed())
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::HostPattern;

    // -- policy tests --

    /// Loopback and RFC 1918 hosts count as private; a public host name does not.
    #[test]
    fn is_private_authority_blocks_localhost() {
        for private in ["localhost:443", "127.0.0.1:8080", "10.0.0.1:80"] {
            assert!(is_private_authority(private), "{private} should be private");
        }
        assert!(!is_private_authority("example.com:443"));
    }

    /// Without any policy (absent or empty list) every host is allowed.
    #[test]
    fn connection_allowed_no_policies() {
        assert!(connection_allowed("example.com", None));
        assert!(connection_allowed("evil.com", Some(&[])));
    }

    /// When two rules match the same host, the earlier one decides.
    #[test]
    fn connection_allowed_first_match_wins() {
        let rules = [
            ConnectionPolicy::deny(HostPattern::exact("blocked.com")),
            ConnectionPolicy::allow(HostPattern::exact("blocked.com")),
        ];
        let allowed = connection_allowed("blocked.com", Some(rules.as_slice()));
        assert!(!allowed);
    }

    /// Regex and exact patterns can be mixed; unmatched hosts are denied.
    #[test]
    fn connection_allowed_regex() {
        let rules = [
            ConnectionPolicy::deny(HostPattern::regex(r"^internal\.").unwrap()),
            ConnectionPolicy::allow(HostPattern::exact("api.example.com")),
        ];
        let expectations = [
            ("internal.service", false),
            ("api.example.com", true),
            // no match = deny (allowlist behavior)
            ("other.com", false),
        ];
        for (host, expected) in expectations {
            assert_eq!(
                connection_allowed(host, Some(rules.as_slice())),
                expected,
                "unexpected decision for {host}"
            );
        }
    }

    // -- token tests --

    /// Every occurrence of a single byte is replaced.
    #[test]
    fn replace_bytes_basic() {
        let replaced = replace_bytes(b"hello world", b"o", b"X");
        assert_eq!(replaced, b"hellX wXrld");
    }

    /// With the longer token listed first, the shorter one must not clobber it.
    #[test]
    fn replace_tokens_longest_first() {
        let order = [
            ("api-key-long".to_string(), "real-long".to_string()),
            ("api-key".to_string(), "real-short".to_string()),
        ];
        let replaced = replace_tokens_in_bytes(b"prefix api-key-long suffix", &order);
        assert_eq!(replaced, b"prefix real-long suffix");
    }

    /// A single masked token inside a larger value is substituted.
    #[test]
    fn replace_tokens_in_bytes_string_mapping() {
        let order = [("__TOKEN__".to_string(), "real-value".to_string())];
        let replaced = replace_tokens_in_bytes(b"Bearer __TOKEN__", &order);
        assert_eq!(replaced, b"Bearer real-value");
    }

    /// CR, LF and NUL are rejected; ordinary text is accepted.
    #[test]
    fn validate_secret_rejects_crlf() {
        assert!(validate_secret("ok").is_ok());
        for bad in ["no\rcr", "no\nlf", "no\0nul"] {
            assert!(validate_secret(bad).is_err(), "{bad:?} should be rejected");
        }
    }
}