use std::collections::HashMap;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::service::service_fn;
use hyper::{Method, Request, Response};
use hyper_util::rt::TokioIo;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use crate::mitm;
use crate::rewrite::{
bad_gateway, bad_request, box_response, build_request, collect_and_rewrite_body,
connection_allowed, http1_handshake, is_private_authority, replace_tokens_in_bytes,
rewrite_headers, try_or_bad_gateway, validate_secret, BoxBodyType, CONNECT_TIMEOUT,
};
use crate::types::{ConnectionPolicy, SecretMapping, StringMapping};
/// Runs `program` with `args` as a child process behind a local rewriting proxy.
///
/// Steps, in order:
/// 1. Rejects the call when there are no mappings at all, and validates every
///    mapping value with `validate_secret`.
/// 2. For each secret mapping, generates a masked placeholder of the form
///    `<var lowercased, '_' replaced by '-'>-<uuid v4>`. The `(var, placeholder)`
///    pairs are handed to `crate::enforce::run_child`; the placeholder -> real
///    value pairs (plus all string mappings) become the proxy's token map.
/// 3. Builds the MITM configuration and writes its CA certificate (PEM) to a
///    temporary `kz-ca-*.pem` file.
/// 4. Binds a proxy listener on an ephemeral `127.0.0.1` port and serves every
///    accepted connection through [`proxy_handler`].
/// 5. Runs the child on a blocking thread, then signals the accept loop to stop
///    and waits for it before returning.
///
/// # Errors
///
/// Returns an error when no mapping is supplied, when a mapping value fails
/// validation, when the MITM configuration or the CA temp file cannot be
/// created, when the listener cannot be bound, or when running/joining the
/// child fails.
pub(crate) async fn run_impl(
    program: &str,
    args: &[String],
    secret_mappings: Vec<SecretMapping>,
    string_mappings: Vec<StringMapping>,
    allow_private_connect: bool,
    upstream_ca: Option<std::path::PathBuf>,
    connection_policies: Vec<ConnectionPolicy>,
) -> Result<std::process::ExitStatus, Box<dyn std::error::Error + Send + Sync>> {
    // Pause after a failed `accept()` so a persistent error (for example file
    // descriptor exhaustion) does not turn the accept loop into a busy spin.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(50);

    if secret_mappings.is_empty() && string_mappings.is_empty() {
        return Err("sandbox requires at least one secret mapping or string mapping".into());
    }
    for m in &secret_mappings {
        validate_secret(&m.value).map_err(|e| format!("{}: {}", m.var, e))?;
    }
    for m in &string_mappings {
        validate_secret(&m.value).map_err(|e| format!("{}: {}", m.token, e))?;
    }

    // (env var name, masked placeholder) pairs passed on to the child runner.
    let mut env_vars_with_masked: Vec<(String, String)> =
        Vec::with_capacity(secret_mappings.len());
    // token -> real value, for everything the proxy should substitute.
    let mut proxy_map: HashMap<String, String> =
        HashMap::with_capacity(secret_mappings.len() + string_mappings.len());
    for m in &secret_mappings {
        let masked = format!(
            "{}-{}",
            m.var.to_lowercase().replace('_', "-"),
            uuid::Uuid::new_v4()
        );
        env_vars_with_masked.push((m.var.clone(), masked.clone()));
        proxy_map.insert(masked, m.value.clone());
    }
    for m in &string_mappings {
        proxy_map.insert(m.token.clone(), m.value.clone());
    }

    // Longest tokens first, so a token that contains a shorter one is replaced
    // before the shorter one can match inside it. Equal-length tokens are
    // ordered lexicographically: `HashMap` iteration order is random, and
    // without a tie-break the replacement order would vary from run to run.
    let replacement_order: Vec<(String, String)> = {
        let mut v: Vec<_> = proxy_map.into_iter().collect();
        v.sort_unstable_by(|a, b| b.0.len().cmp(&a.0.len()).then_with(|| a.0.cmp(&b.0)));
        v
    };
    let token_map = Arc::new(replacement_order);

    let (mitm_config, ssl_cert_file) = {
        let (config, ca_pem) =
            mitm::MitmConfig::new(upstream_ca).map_err(|e| format!("MITM config: {}", e))?;
        let temp = {
            use std::io::Write;
            let mut f = tempfile::Builder::new()
                .prefix("kz-ca-")
                .suffix(".pem")
                .tempfile()
                .map_err(|e| format!("create CA cert temp file: {}", e))?;
            #[cfg(unix)]
            {
                use std::os::unix::fs::PermissionsExt;
                // Owner read/write only.
                std::fs::set_permissions(f.path(), std::fs::Permissions::from_mode(0o600))
                    .map_err(|e| format!("set CA cert permissions: {}", e))?;
            }
            f.write_all(ca_pem.as_bytes())
                .map_err(|e| format!("write CA cert: {}", e))?;
            // Close the handle but keep the file on disk until the path guard drops.
            f.into_temp_path()
        };
        (Arc::new(config), temp)
    };

    // Port 0 lets the OS pick a free port; the proxy is reachable on loopback only.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let port = listener.local_addr()?.port();
    let proxy_url = format!("http://127.0.0.1:{}", port);

    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
    let server_handle = tokio::spawn(async move {
        loop {
            tokio::select! {
                // Fires on an explicit shutdown message and also when the sender is dropped.
                _ = shutdown_rx.recv() => break,
                accept_result = listener.accept() => {
                    let (stream, _) = match accept_result {
                        Ok(x) => x,
                        Err(e) => {
                            tracing::warn!("proxy accept error: {}", e);
                            tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                            continue;
                        }
                    };
                    let token_map = Arc::clone(&token_map);
                    let allow_private = allow_private_connect;
                    let mitm_config = Arc::clone(&mitm_config);
                    let connection_policies = connection_policies.clone();
                    // One task per connection; the service closure runs once per request.
                    tokio::spawn(async move {
                        let io = TokioIo::new(stream);
                        let service = service_fn(move |req| {
                            let token_map = Arc::clone(&token_map);
                            let mitm_config = Arc::clone(&mitm_config);
                            let connection_policies = connection_policies.clone();
                            async move {
                                proxy_handler(
                                    req,
                                    token_map,
                                    allow_private,
                                    mitm_config,
                                    connection_policies,
                                )
                                .await
                            }
                        });
                        let conn = hyper::server::conn::http1::Builder::new()
                            .serve_connection(io, service)
                            .with_upgrades();
                        if let Err(e) = conn.await {
                            tracing::error!("proxy connection error: {}", e);
                        }
                    });
                }
            }
        }
    });

    let program = program.to_string();
    let args = args.to_vec();
    // `run_child` is called synchronously, so keep it off the async worker threads.
    let exit_status = tokio::task::spawn_blocking(move || {
        crate::enforce::run_child(
            &program,
            &args,
            &proxy_url,
            &env_vars_with_masked,
            &ssl_cert_file,
        )
    })
    .await
    .map_err(|e| format!("subprocess join: {}", e))??;

    // Best effort: stop accepting and wait for the accept loop to finish.
    let _ = shutdown_tx.send(()).await;
    let _ = server_handle.await;
    Ok(exit_status)
}
/// Dispatches one proxied request.
///
/// Non-`CONNECT` requests go straight to [`handle_forward`]. A `CONNECT`
/// request must carry an authority; its host is checked against
/// `connection_policies`, and (unless `allow_private_connect` is set) the
/// authority is rejected when `is_private_authority` reports it as private.
/// Requests that pass both checks are handed to `mitm::handle_connect_mitm`.
///
/// Every rejection is answered with a `bad_request` response rather than an
/// error, so the connection itself stays usable.
///
/// NOTE(review): the private-address check is applied to `CONNECT` only; the
/// plain-HTTP forward path is not given `allow_private_connect`. Confirm this
/// is intended.
async fn proxy_handler(
    req: Request<hyper::body::Incoming>,
    token_map: Arc<Vec<(String, String)>>,
    allow_private_connect: bool,
    mitm_config: Arc<mitm::MitmConfig>,
    connection_policies: Vec<ConnectionPolicy>,
) -> Result<Response<BoxBodyType>, hyper::Error> {
    // Anything that is not a tunnel request is an ordinary forward-proxy request.
    if req.method() != Method::CONNECT {
        return handle_forward(req, token_map, connection_policies).await;
    }

    let target = match req.uri().authority().cloned() {
        None => return Ok(bad_request("CONNECT must include authority (host:port)")),
        Some(found) => found,
    };

    // Policy check first, on the bare host name.
    let target_host = target.host().to_string();
    let permitted_by_policy = connection_allowed(&target_host, Some(&connection_policies));
    if !permitted_by_policy {
        return Ok(bad_request("CONNECT to this host is not allowed by policy"));
    }

    // The private-address check is skipped entirely when the caller opted in.
    let target_text = target.to_string();
    let private_ok = allow_private_connect || !is_private_authority(&target_text);
    if !private_ok {
        return Ok(bad_request("CONNECT to private/local address not allowed"));
    }

    mitm::handle_connect_mitm(req, token_map, mitm_config).await
}
/// Forwards a non-`CONNECT` proxy request upstream after token substitution.
///
/// Tokens from `token_map` are substituted in the request URI, the headers
/// (via `rewrite_headers`) and the body (via `collect_and_rewrite_body`). The
/// destination host is taken from the rewritten URI and must be allowed by
/// `connection_policies`. The request is then sent over a fresh TCP connection
/// to that host and the upstream response is returned.
///
/// Client-visible failures are returned as responses (`bad_request` /
/// `bad_gateway`) rather than as errors.
///
/// # Errors
///
/// Returns `hyper::Error` only when reading the incoming request body fails.
async fn handle_forward(
    req: Request<hyper::body::Incoming>,
    token_map: Arc<Vec<(String, String)>>,
    connection_policies: Vec<ConnectionPolicy>,
) -> Result<Response<BoxBodyType>, hyper::Error> {
    let (parts, body) = req.into_parts();

    // `parts.uri` is not needed again, so move it out instead of cloning it.
    let original_uri = parts.uri;
    let rewritten_uri_bytes =
        replace_tokens_in_bytes(original_uri.to_string().as_bytes(), &token_map);
    // Fall back to the unmodified URI when the rewritten bytes are not valid
    // UTF-8 or no longer parse as a URI.
    let modified_uri = match String::from_utf8(rewritten_uri_bytes) {
        Ok(s) => s.parse().unwrap_or(original_uri),
        Err(_) => original_uri,
    };

    let host = match modified_uri.host() {
        Some(h) => h.to_string(),
        None => return Ok(bad_request("Request URI has no host")),
    };
    // The policy is checked against the host of the rewritten URI, i.e. the
    // host that is actually connected to below.
    if !connection_allowed(host.as_str(), Some(&connection_policies)) {
        return Ok(bad_request(
            "Connection to this host is not allowed by policy",
        ));
    }
    // NOTE(review): the default is port 80 whatever the URI scheme is — confirm
    // that only `http://` URIs are expected on this path.
    let port = modified_uri.port_u16().unwrap_or(80);

    let body_bytes = body.collect().await?.to_bytes();
    let modified_body = match collect_and_rewrite_body(&body_bytes, &token_map) {
        Ok(b) => b,
        Err(resp) => return Ok(resp),
    };

    let mut new_headers = rewrite_headers(&parts.headers, &token_map);
    // Substitution can change the body size, so Content-Length is always
    // recomputed. `HeaderValue: From<usize>` is infallible, so no unwrap is needed.
    new_headers.insert(
        http::header::CONTENT_LENGTH,
        http::HeaderValue::from(modified_body.len()),
    );

    let body = Full::new(Bytes::from(modified_body));
    let new_req = match build_request(parts.method, &modified_uri, new_headers, body) {
        Ok(r) => r,
        Err(resp) => return Ok(resp),
    };

    // Bound the upstream connect so an unreachable host cannot stall the request.
    let connect = TcpStream::connect((host.as_str(), port));
    let stream = match tokio::time::timeout(CONNECT_TIMEOUT, connect).await {
        Ok(r) => try_or_bad_gateway!(r, "Upstream connection failed"),
        Err(_) => return Ok(bad_gateway("Upstream connect timeout")),
    };

    let io = TokioIo::new(stream);
    let mut sender = match http1_handshake(io).await {
        Ok(s) => s,
        Err(resp) => return Ok(resp),
    };
    let resp = try_or_bad_gateway!(sender.send_request(new_req).await, "Upstream request failed");
    Ok(box_response(resp))
}
#[cfg(test)]
mod tests {
    use crate::types::{Sandbox, SandboxConfig, SecretMapping, StringMapping};

    /// Config whose only mapping is a single secret `var` -> `value`.
    fn single_secret_config(var: &str, value: &str) -> SandboxConfig {
        let mapping = SecretMapping {
            var: var.to_string(),
            value: value.to_string(),
        };
        SandboxConfig {
            secrets: vec![mapping],
            ..SandboxConfig::default()
        }
    }

    /// Argument list for `sh -c <script>`.
    fn sh_c(script: &str) -> [String; 2] {
        ["-c".to_string(), script.to_string()]
    }

    /// A cloned sandbox can still run a command.
    #[tokio::test]
    async fn sandbox_clone() {
        let original = Sandbox::new(single_secret_config("K", "v"));
        let copy = original.clone();
        let outcome = copy.run("true", &[]).await;
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().success());
    }

    /// Running with no mappings at all is rejected with a descriptive error.
    #[tokio::test]
    async fn run_requires_at_least_one_mapping() {
        let sandbox = Sandbox::new(SandboxConfig::default());
        let outcome = sandbox.run("true", &[]).await;
        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        let text = failure.to_string();
        assert!(
            text.contains("at least one secret mapping or string mapping"),
            "expected message about mapping, got: {}",
            failure
        );
    }

    /// A string mapping alone is enough to run.
    #[tokio::test]
    async fn run_with_string_mapping_only() {
        let mapping = StringMapping {
            token: "__TOKEN__".to_string(),
            value: "replaced".to_string(),
        };
        let sandbox = Sandbox::new(SandboxConfig {
            strings: vec![mapping],
            ..SandboxConfig::default()
        });
        let outcome = sandbox.run("true", &[]).await;
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().success());
    }

    /// `true` exits successfully and that is reported back.
    #[tokio::test]
    async fn run_returns_exit_status_success() {
        let sandbox = Sandbox::new(single_secret_config("X", "x"));
        let outcome = sandbox.run("true", &[]).await;
        assert!(outcome.is_ok(), "{:?}", outcome.err());
        assert!(outcome.unwrap().success());
    }

    /// `false` exits unsuccessfully; the run itself still returns `Ok`.
    #[tokio::test]
    async fn run_returns_exit_status_failure() {
        let sandbox = Sandbox::new(single_secret_config("X", "x"));
        let outcome = sandbox.run("false", &[]).await;
        assert!(outcome.is_ok());
        assert!(!outcome.unwrap().success());
    }

    /// The child's exact exit code is passed through.
    #[tokio::test]
    async fn run_forward_exit_code() {
        let sandbox = Sandbox::new(single_secret_config("X", "x"));
        let outcome = sandbox.run("sh", &sh_c("exit 42")).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().code(), Some(42));
    }

    /// The child sees a masked `api-key-*` placeholder, never the real secret.
    #[tokio::test]
    async fn run_command_sees_masked_env() {
        let sandbox = Sandbox::new(single_secret_config("API_KEY", "real-secret"));
        let script = "v=$API_KEY; if [ \"$v\" = \"real-secret\" ]; then exit 1; fi; case \"$v\" in api-key-*) exit 0;; *) exit 2;; esac";
        let outcome = sandbox.run("sh", &sh_c(script)).await;
        assert!(outcome.is_ok(), "run failed");
        assert_eq!(
            outcome.unwrap().code(),
            Some(0),
            "expected masked token pattern api-key-*"
        );
    }

    /// Both `HTTP_PROXY` and `HTTPS_PROXY` point at the loopback proxy.
    #[tokio::test]
    async fn run_sets_http_proxy_env() {
        let sandbox = Sandbox::new(single_secret_config("X", "x"));

        let http_script = "case \"$HTTP_PROXY\" in http://127.0.0.1*) exit 0;; *) exit 1;; esac";
        let outcome = sandbox.run("sh", &sh_c(http_script)).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().code(), Some(0));

        let https_script = "case \"$HTTPS_PROXY\" in http://127.0.0.1*) exit 0;; *) exit 1;; esac";
        let outcome = sandbox.run("sh", &sh_c(https_script)).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().code(), Some(0));
    }
}