kz-proxy 0.3.0

MITM proxy and subprocess sandbox for blind secret injection
Documentation
//! HTTPS MITM: on CONNECT, accept TLS with a generated cert, rewrite tokens in the request, forward to upstream over TLS.
//!
//! The subprocess must trust our CA (e.g. by pointing SSL_CERT_FILE at the generated CA PEM).

use std::sync::Arc;

use bytes::Bytes;
use http_body_util::{BodyExt, Empty, Full};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use rcgen::{CertificateParams, KeyPair};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio::net::TcpStream;

use crate::rewrite::{
    bad_gateway, bad_request, box_response, build_request, collect_and_rewrite_body,
    http1_handshake, replace_tokens_in_bytes, rewrite_headers, try_or_bad_gateway, BoxBodyType,
    CONNECT_TIMEOUT,
};

/// Root CA and cert generation for MITM. Built once per sandbox when MITM is enabled.
pub struct MitmConfig {
    /// Root issuer used to sign per-host certs.
    root_issuer: rcgen::Issuer<'static, rcgen::KeyPair>,
    /// Optional path to a PEM file with extra CA cert(s) to trust for upstream (e.g. test server).
    pub(crate) optional_extra_ca: Option<std::path::PathBuf>,
}

impl MitmConfig {
    /// Create a fresh self-signed root CA for this sandbox.
    ///
    /// Returns the config together with the CA certificate as PEM, intended to be exposed to the
    /// subprocess (e.g. through SSL_CERT_FILE). `optional_extra_ca` is stored as-is and, when
    /// `Some`, names a PEM file whose certs are additionally trusted for upstream connections.
    ///
    /// # Errors
    ///
    /// Fails if key generation or self-signing the CA certificate fails.
    pub fn new(
        optional_extra_ca: Option<std::path::PathBuf>,
    ) -> Result<(Self, String), Box<dyn std::error::Error + Send + Sync>> {
        // The provider may already be installed; that outcome is deliberately ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();

        // Subject of the root: a single CommonName entry.
        let mut ca_name = rcgen::DistinguishedName::new();
        ca_name.push(
            rcgen::DnType::CommonName,
            rcgen::DnValue::Utf8String("keyzero-mitm-ca".to_string()),
        );

        let mut ca_params = CertificateParams::default();
        ca_params.distinguished_name = ca_name;
        ca_params.key_usages = vec![
            rcgen::KeyUsagePurpose::KeyCertSign,
            rcgen::KeyUsagePurpose::CrlSign,
        ];
        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);

        let ca_key = KeyPair::generate()?;
        let ca_pem = ca_params.self_signed(&ca_key)?.pem();

        let config = Self {
            root_issuer: rcgen::Issuer::new(ca_params, ca_key),
            optional_extra_ca,
        };
        Ok((config, ca_pem))
    }

    /// Produce a rustls `ServerConfig` presenting a freshly generated leaf certificate for `host`.
    ///
    /// The leaf gets its own key pair, lists `host` as subject alternative name and CommonName,
    /// and is signed by this config's root issuer. No client authentication is requested.
    ///
    /// # Errors
    ///
    /// Fails if `host` is rejected as a certificate name, if key generation or signing fails, or
    /// if rustls rejects the resulting certificate/key pair.
    fn server_config(
        &self,
        host: &str,
    ) -> Result<rustls::ServerConfig, Box<dyn std::error::Error + Send + Sync>> {
        let mut leaf_params = CertificateParams::new(vec![host.to_string()])?;
        let mut subject = rcgen::DistinguishedName::new();
        subject.push(rcgen::DnType::CommonName, host);
        leaf_params.distinguished_name = subject;

        let leaf_key = KeyPair::generate()?;
        let leaf_cert = leaf_params.signed_by(&leaf_key, &self.root_issuer)?;

        // Single-element chain: only the leaf is presented to the client.
        let chain = vec![CertificateDer::from(leaf_cert.der().to_vec())];
        let private_key = PrivateKeyDer::Pkcs8(rustls::pki_types::PrivatePkcs8KeyDer::from(
            leaf_key.serialize_der(),
        ));

        Ok(rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, private_key)?)
    }
}

/// Handle CONNECT with MITM: accept TLS from client, run HTTP with token replacement, forward to upstream over TLS.
///
/// The CONNECT request must carry an authority (`host:port`); otherwise a bad-request response
/// is returned and nothing is spawned. On success an empty response is returned immediately and
/// the tunnel is driven by a detached task that:
///
/// 1. waits for the connection upgrade,
/// 2. terminates TLS using a certificate generated for the requested host,
/// 3. serves HTTP/1 on the decrypted stream, passing every request to `mitm_forward`.
///
/// Failures inside the task are logged and end the tunnel; they are not reported to the caller.
pub async fn handle_connect_mitm(
    req: hyper::Request<hyper::body::Incoming>,
    token_map: Arc<Vec<(String, String)>>,
    mitm_config: Arc<MitmConfig>,
) -> Result<Response<BoxBodyType>, hyper::Error> {
    // Extract owned copies up front so `req` can be moved into the spawned task.
    let (authority_str, host) = match req.uri().authority() {
        Some(authority) => {
            // `Authority::host` keeps the square brackets around IPv6 literals. Strip them so the
            // generated certificate is issued for the address itself rather than for the
            // bracketed text, which no client would accept as a match.
            let raw_host = authority.host();
            let host = raw_host
                .strip_prefix('[')
                .and_then(|h| h.strip_suffix(']'))
                .unwrap_or(raw_host);
            (authority.to_string(), host.to_string())
        }
        None => {
            return Ok(bad_request(
                "CONNECT must include authority (host:port)",
            ));
        }
    };

    tokio::spawn(async move {
        let upgraded = match hyper::upgrade::on(req).await {
            Ok(u) => u,
            Err(e) => {
                tracing::error!("CONNECT MITM upgrade error: {}", e);
                return;
            }
        };

        let server_config = match mitm_config.server_config(&host) {
            Ok(c) => c,
            Err(e) => {
                tracing::error!("CONNECT MITM cert gen for {}: {}", host, e);
                return;
            }
        };
        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
        let tls_stream = match acceptor.accept(TokioIo::new(upgraded)).await {
            Ok(s) => s,
            Err(e) => {
                tracing::error!("CONNECT MITM TLS accept for {}: {}", host, e);
                return;
            }
        };

        // The task already owns `token_map`, `authority_str` and `mitm_config`; the closure takes
        // them over and hands a cheap clone to each per-request future.
        let service = service_fn(move |req: Request<hyper::body::Incoming>| {
            let token_map = Arc::clone(&token_map);
            let authority = authority_str.clone();
            let mitm_config = Arc::clone(&mitm_config);
            async move { mitm_forward(req, token_map, &authority, &mitm_config).await }
        });

        if let Err(e) = http1::Builder::new()
            .serve_connection(TokioIo::new(tls_stream), service)
            .await
        {
            tracing::error!("CONNECT MITM serve for {}: {}", host, e);
        }
    });

    // Empty-bodied response to the CONNECT itself; the upgrade proceeds in the spawned task.
    Ok(Response::new(
        Empty::<Bytes>::new()
            .map_err(|never: std::convert::Infallible| match never {})
            .boxed(),
    ))
}

async fn mitm_forward(
    req: Request<hyper::body::Incoming>,
    token_map: Arc<Vec<(String, String)>>,
    authority: &str,
    config: &MitmConfig,
) -> Result<Response<BoxBodyType>, hyper::Error> {
    let (parts, body) = req.into_parts();
    let body_bytes = body.collect().await?.to_bytes();
    let modified_body = match collect_and_rewrite_body(&body_bytes, &token_map) {
        Ok(b) => b,
        Err(resp) => return Ok(resp),
    };

    let mut new_headers = rewrite_headers(&parts.headers, &token_map);
    new_headers.insert(
        http::header::CONTENT_LENGTH,
        http::HeaderValue::from_str(&modified_body.len().to_string()).unwrap(),
    );

    // Build URI for upstream and apply token replacement (path + query, same as HTTP path).
    let path = parts.uri.path();
    let query = parts
        .uri
        .query()
        .map(|q| format!("?{}", q))
        .unwrap_or_default();
    let uri_str = format!("https://{}{}{}", authority, path, query);
    let modified_uri_bytes = replace_tokens_in_bytes(uri_str.as_bytes(), &token_map);
    let modified_uri_str = match String::from_utf8(modified_uri_bytes) {
        Ok(s) => s,
        Err(_) => uri_str.clone(),
    };
    let uri = match modified_uri_str
        .parse::<http::Uri>()
        .or_else(|_| uri_str.parse::<http::Uri>())
    {
        Ok(u) => u,
        Err(_) => return Ok(bad_gateway("bad upstream URI")),
    };

    let new_req = match build_request(
        parts.method,
        &uri,
        new_headers,
        Full::new(Bytes::from(modified_body)),
    ) {
        Ok(r) => r,
        Err(resp) => return Ok(resp),
    };

    let (host, port) = authority
        .split_once(':')
        .map(|(h, p)| (h, p.parse().unwrap_or(443)))
        .unwrap_or((authority, 443));

    let tcp = match tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect((host, port))).await {
        Ok(r) => try_or_bad_gateway!(r, "upstream connection failed"),
        Err(_) => return Ok(bad_gateway("upstream connect timeout")),
    };

    let mut root_store =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    if let Some(path) = &config.optional_extra_ca {
        let pem = std::fs::read(path).map_err(|e| {
            tracing::error!("MITM extra CA read error: {}", e);
        });
        if let Ok(pem) = pem {
            for cert in rustls_pemfile::certs(&mut std::io::Cursor::new(pem)).flatten() {
                let _ = root_store.add(cert);
            }
        }
    }
    let client_config = rustls::ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
    let domain = match host.to_string().try_into() {
        Ok(d) => d,
        Err(_) => return Ok(bad_gateway("invalid upstream host")),
    };
    let tls = try_or_bad_gateway!(connector.connect(domain, tcp).await, "upstream TLS failed");

    let io = TokioIo::new(tls);
    let mut sender = match http1_handshake(io).await {
        Ok(s) => s,
        Err(resp) => return Ok(resp),
    };

    let resp = try_or_bad_gateway!(sender.send_request(new_req).await, "upstream request failed");
    Ok(box_response(resp))
}