kz-proxy 0.3.0

MITM proxy and subprocess sandbox for blind secret injection
Documentation
//! Proxy server: accept loop, connection handler, HTTP forwarding.

use std::collections::HashMap;
use std::sync::Arc;

use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::service::service_fn;
use hyper::{Method, Request, Response};
use hyper_util::rt::TokioIo;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;

use crate::mitm;
use crate::rewrite::{
    bad_gateway, bad_request, box_response, build_request, collect_and_rewrite_body,
    connection_allowed, http1_handshake, is_private_authority, replace_tokens_in_bytes,
    rewrite_headers, try_or_bad_gateway, validate_secret, BoxBodyType, CONNECT_TIMEOUT,
};
use crate::types::{ConnectionPolicy, SecretMapping, StringMapping};

/// Runs `program` with `args` behind the token-rewriting MITM proxy and returns its exit status.
///
/// * Every secret (`secret_mappings`) and string (`string_mappings`) value is checked with
///   `validate_secret`; a failure is reported prefixed with the mapping's var / token name.
/// * Each secret env var gets a fresh masked token `<var lowercased, '_' -> '-'>-<uuid v4>`.
///   The `(var, masked)` pairs are handed to `crate::enforce::run_child`, while the proxy maps
///   the masked token back to the real value. String mappings are used verbatim as
///   token -> value.
/// * The MITM CA certificate is written to a `kz-ca-*.pem` temp file (mode 0600 on Unix) whose
///   `TempPath` is moved into the child task, so it is deleted once `run_child` returns.
/// * The proxy listens on an ephemeral `127.0.0.1` port; its `http://` URL is passed to
///   `run_child`. `allow_private_connect` and `connection_policies` are forwarded to every
///   request handler, and `upstream_ca` to `mitm::MitmConfig::new`.
///
/// # Errors
///
/// Fails when both mapping lists are empty, a value fails validation, the MITM config or CA
/// temp file cannot be created, the listener cannot be bound, the blocking child task cannot
/// be joined, or `run_child` itself fails. The proxy accept loop is stopped and joined before
/// any child error is returned.
pub(crate) async fn run_impl(
    program: &str,
    args: &[String],
    secret_mappings: Vec<SecretMapping>,
    string_mappings: Vec<StringMapping>,
    allow_private_connect: bool,
    upstream_ca: Option<std::path::PathBuf>,
    connection_policies: Vec<ConnectionPolicy>,
) -> Result<std::process::ExitStatus, Box<dyn std::error::Error + Send + Sync>> {
    if secret_mappings.is_empty() && string_mappings.is_empty() {
        return Err("sandbox requires at least one secret mapping or string mapping".into());
    }

    for m in &secret_mappings {
        validate_secret(&m.value).map_err(|e| format!("{}: {}", m.var, e))?;
    }
    for m in &string_mappings {
        validate_secret(&m.value).map_err(|e| format!("{}: {}", m.token, e))?;
    }

    let mut env_vars_with_masked: Vec<(String, String)> = Vec::with_capacity(secret_mappings.len());
    let mut proxy_map: HashMap<String, String> =
        HashMap::with_capacity(secret_mappings.len() + string_mappings.len());

    for m in &secret_mappings {
        let masked = format!(
            "{}-{}",
            m.var.to_lowercase().replace('_', "-"),
            uuid::Uuid::new_v4()
        );
        env_vars_with_masked.push((m.var.clone(), masked.clone()));
        proxy_map.insert(masked, m.value.clone());
    }
    for m in &string_mappings {
        proxy_map.insert(m.token.clone(), m.value.clone());
    }

    let token_map = Arc::new(replacement_order(proxy_map));

    let (mitm_config, ssl_cert_file) = {
        let (config, ca_pem) =
            mitm::MitmConfig::new(upstream_ca).map_err(|e| format!("MITM config: {}", e))?;
        let temp = {
            use std::io::Write;
            let mut f = tempfile::Builder::new()
                .prefix("kz-ca-")
                .suffix(".pem")
                .tempfile()
                .map_err(|e| format!("create CA cert temp file: {}", e))?;
            #[cfg(unix)]
            {
                use std::os::unix::fs::PermissionsExt;
                std::fs::set_permissions(f.path(), std::fs::Permissions::from_mode(0o600))
                    .map_err(|e| format!("set CA cert permissions: {}", e))?;
            }
            f.write_all(ca_pem.as_bytes())
                .map_err(|e| format!("write CA cert: {}", e))?;
            f.into_temp_path()
        };
        (Arc::new(config), temp)
    };

    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let port = listener.local_addr()?.port();
    let proxy_url = format!("http://127.0.0.1:{}", port);

    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
    let server_handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                // Fires on an explicit shutdown message and also when the sender is dropped.
                _ = shutdown_rx.recv() => break,
                accept_result = listener.accept() => {
                    let stream = match accept_result {
                        Ok((stream, _)) => stream,
                        Err(e) => {
                            // Errors such as EMFILE tend to repeat immediately; pause briefly
                            // instead of spinning the loop at full speed.
                            tracing::warn!("proxy accept error: {}", e);
                            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                            continue;
                        }
                    };
                    let token_map = Arc::clone(&token_map);
                    let allow_private = allow_private_connect;
                    let mitm_config = Arc::clone(&mitm_config);
                    let connection_policies = connection_policies.clone();
                    tokio::spawn(async move {
                        let io = TokioIo::new(stream);
                        let service = service_fn(move |req| {
                            let token_map = Arc::clone(&token_map);
                            let mitm_config = Arc::clone(&mitm_config);
                            let connection_policies = connection_policies.clone();
                            async move { proxy_handler(req, token_map, allow_private, mitm_config, connection_policies).await }
                        });
                        let conn = hyper::server::conn::http1::Builder::new()
                            .serve_connection(io, service)
                            .with_upgrades();
                        if let Err(e) = conn.await {
                            tracing::error!("proxy connection error: {}", e);
                        }
                    });
                }
            }
        }
    });

    let program = program.to_string();
    let args = args.to_vec();
    let child_result = tokio::task::spawn_blocking(move || {
        crate::enforce::run_child(
            &program,
            &args,
            &proxy_url,
            &env_vars_with_masked,
            &ssl_cert_file,
        )
    })
    .await;

    // Stop and join the accept loop on every path, not only when the child succeeded.
    let _ = shutdown_tx.send(()).await;
    let _ = server_handle.await;

    let exit_status = child_result.map_err(|e| format!("subprocess join: {}", e))??;
    Ok(exit_status)
}

/// Orders token -> value pairs for replacement: longest token first, ties broken by token text.
///
/// Longest-first stops a token that is a substring of another from being replaced inside it.
/// The lexical tie-break makes the order independent of `HashMap` iteration order, which is
/// randomized per process. Keys come from a map and are unique, so an unstable sort is
/// fully deterministic.
fn replacement_order(proxy_map: HashMap<String, String>) -> Vec<(String, String)> {
    let mut pairs: Vec<(String, String)> = proxy_map.into_iter().collect();
    pairs.sort_unstable_by(|a, b| b.0.len().cmp(&a.0.len()).then_with(|| a.0.cmp(&b.0)));
    pairs
}

/// Per-request service function of the proxy.
///
/// `CONNECT` requests must carry an authority. Their host must pass `connection_policies`
/// and, unless `allow_private_connect` is set, the `host:port` authority must not be
/// private/local. Accepted tunnels are handed to `mitm::handle_connect_mitm`. Every other
/// method is forwarded by [`handle_forward`], which enforces the same two checks.
async fn proxy_handler(
    req: Request<hyper::body::Incoming>,
    token_map: Arc<Vec<(String, String)>>,
    allow_private_connect: bool,
    mitm_config: Arc<mitm::MitmConfig>,
    connection_policies: Vec<ConnectionPolicy>,
) -> Result<Response<BoxBodyType>, hyper::Error> {
    if req.method() == Method::CONNECT {
        let authority = match req.uri().authority() {
            Some(a) => a.clone(),
            None => return Ok(bad_request("CONNECT must include authority (host:port)")),
        };
        let host = authority.host().to_string();
        let authority_str = authority.to_string();
        if !connection_allowed(&host, Some(&connection_policies)) {
            return Ok(bad_request("CONNECT to this host is not allowed by policy"));
        }
        if !allow_private_connect && is_private_authority(&authority_str) {
            return Ok(bad_request("CONNECT to private/local address not allowed"));
        }
        return mitm::handle_connect_mitm(req, token_map, mitm_config).await;
    }
    handle_forward(req, token_map, allow_private_connect, connection_policies).await
}

/// Forwards a plain-HTTP (absolute-form) proxy request upstream with tokens replaced.
///
/// Tokens from `token_map` are replaced in the URI, in the headers (via `rewrite_headers`)
/// and in the fully buffered body (via `collect_and_rewrite_body`). The rewritten URI's host
/// must pass `connection_policies` and, unless `allow_private_connect` is set, its
/// `host:port` must not be private/local. The upstream port defaults to 80.
///
/// Rejections and upstream failures are returned as `Ok` error responses (`bad_request`,
/// `bad_gateway`, or whatever the rewrite helpers return). Only a failure while reading the
/// client body surfaces as `hyper::Error`.
async fn handle_forward(
    req: Request<hyper::body::Incoming>,
    token_map: Arc<Vec<(String, String)>>,
    allow_private_connect: bool,
    connection_policies: Vec<ConnectionPolicy>,
) -> Result<Response<BoxBodyType>, hyper::Error> {
    let (parts, body) = req.into_parts();

    // Tokens may appear anywhere in the URI, so rewrite before deciding where to connect.
    // If the result no longer parses, refuse instead of silently sending the un-rewritten URI.
    let rewritten_uri = replace_tokens_in_bytes(parts.uri.to_string().as_bytes(), &token_map);
    let modified_uri = match String::from_utf8(rewritten_uri)
        .ok()
        .and_then(|s| s.parse::<hyper::Uri>().ok())
    {
        Some(u) => u,
        None => return Ok(bad_request("Request URI is invalid after token replacement")),
    };

    let host = match modified_uri.host() {
        Some(h) => h.to_string(),
        None => return Ok(bad_request("Request URI has no host")),
    };
    if !connection_allowed(host.as_str(), Some(&connection_policies)) {
        return Ok(bad_request(
            "Connection to this host is not allowed by policy",
        ));
    }
    let port = modified_uri.port_u16().unwrap_or(80);
    // Same private-address guard as the CONNECT path. `host` keeps IPv6 brackets, so this
    // mirrors the `host:port` authority string CONNECT passes to `is_private_authority`.
    if !allow_private_connect && is_private_authority(&format!("{}:{}", host, port)) {
        return Ok(bad_request("Connection to private/local address not allowed"));
    }

    let body_bytes = body.collect().await?.to_bytes();
    let modified_body = match collect_and_rewrite_body(&body_bytes, &token_map) {
        Ok(b) => b,
        Err(resp) => return Ok(resp),
    };

    let mut new_headers = rewrite_headers(&parts.headers, &token_map);
    // The body is fully buffered and re-sent with an exact length, so chunked framing from
    // the client no longer applies. Sending both framing headers would be ambiguous.
    new_headers.remove(http::header::TRANSFER_ENCODING);
    new_headers.insert(
        http::header::CONTENT_LENGTH,
        http::HeaderValue::from(modified_body.len()),
    );

    let body = Full::new(Bytes::from(modified_body));
    let new_req = match build_request(parts.method, &modified_uri, new_headers, body) {
        Ok(r) => r,
        Err(resp) => return Ok(resp),
    };

    // `Uri::host` returns IPv6 literals in brackets ("[::1]"), which socket-address
    // resolution rejects, so strip them before connecting.
    let connect_host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host.as_str());
    let stream = match tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect((connect_host, port))).await {
        Ok(r) => try_or_bad_gateway!(r, "Upstream connection failed"),
        Err(_) => return Ok(bad_gateway("Upstream connect timeout")),
    };
    let io = TokioIo::new(stream);
    let mut sender = match http1_handshake(io).await {
        Ok(s) => s,
        Err(resp) => return Ok(resp),
    };

    let resp = try_or_bad_gateway!(sender.send_request(new_req).await, "Upstream request failed");
    Ok(box_response(resp))
}

#[cfg(test)]
mod tests {
    use crate::types::{Sandbox, SandboxConfig, SecretMapping, StringMapping};

    /// Config with a single secret mapping `var` -> `value`; every other field defaulted.
    fn secret_config(var: &str, value: &str) -> SandboxConfig {
        SandboxConfig {
            secrets: vec![SecretMapping {
                var: var.to_string(),
                value: value.to_string(),
            }],
            ..SandboxConfig::default()
        }
    }

    /// Argument list that runs `script` through `sh -c`.
    fn sh_args(script: &str) -> Vec<String> {
        vec!["-c".to_string(), script.to_string()]
    }

    #[tokio::test]
    async fn sandbox_clone() {
        let sandbox = Sandbox::new(secret_config("K", "v"));
        let cloned = sandbox.clone();
        let status = cloned
            .run("true", &[])
            .await
            .expect("run on cloned sandbox failed");
        assert!(status.success());
    }

    #[tokio::test]
    async fn run_requires_at_least_one_mapping() {
        let sandbox = Sandbox::new(SandboxConfig::default());
        let err = sandbox
            .run("true", &[])
            .await
            .expect_err("run without any mapping should fail");
        assert!(
            err.to_string()
                .contains("at least one secret mapping or string mapping"),
            "expected message about mapping, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn run_with_string_mapping_only() {
        let config = SandboxConfig {
            strings: vec![StringMapping {
                token: "__TOKEN__".to_string(),
                value: "replaced".to_string(),
            }],
            ..SandboxConfig::default()
        };
        let sandbox = Sandbox::new(config);
        let status = sandbox
            .run("true", &[])
            .await
            .expect("run with only a string mapping failed");
        assert!(status.success());
    }

    #[tokio::test]
    async fn run_returns_exit_status_success() {
        let sandbox = Sandbox::new(secret_config("X", "x"));
        let status = sandbox.run("true", &[]).await.expect("run `true` failed");
        assert!(status.success());
    }

    #[tokio::test]
    async fn run_returns_exit_status_failure() {
        let sandbox = Sandbox::new(secret_config("X", "x"));
        let status = sandbox.run("false", &[]).await.expect("run `false` failed");
        assert!(!status.success());
    }

    #[tokio::test]
    async fn run_forward_exit_code() {
        let sandbox = Sandbox::new(secret_config("X", "x"));
        let status = sandbox
            .run("sh", &sh_args("exit 42"))
            .await
            .expect("run `sh -c 'exit 42'` failed");
        assert_eq!(status.code(), Some(42));
    }

    #[tokio::test]
    async fn run_command_sees_masked_env() {
        let sandbox = Sandbox::new(secret_config("API_KEY", "real-secret"));
        // Exit 1: real secret leaked; exit 2: not the masked pattern; exit 0: masked token.
        let script = "v=$API_KEY; if [ \"$v\" = \"real-secret\" ]; then exit 1; fi; case \"$v\" in api-key-*) exit 0;; *) exit 2;; esac";
        let status = sandbox
            .run("sh", &sh_args(script))
            .await
            .unwrap_or_else(|e| panic!("run failed: {:?}", e));
        assert_eq!(
            status.code(),
            Some(0),
            "expected masked token pattern api-key-*"
        );
    }

    #[tokio::test]
    async fn run_sets_http_proxy_env() {
        let sandbox = Sandbox::new(secret_config("X", "x"));
        for var in ["HTTP_PROXY", "HTTPS_PROXY"] {
            let script = format!(
                "case \"${}\" in http://127.0.0.1*) exit 0;; *) exit 1;; esac",
                var
            );
            let status = sandbox
                .run("sh", &sh_args(&script))
                .await
                .unwrap_or_else(|e| panic!("run checking {} failed: {:?}", var, e));
            assert_eq!(status.code(), Some(0), "{} should point at the local proxy", var);
        }
    }
}