use std::path::PathBuf;
use clap::Args;
use microsandbox::sandbox::SandboxBuilder;
use crate::ui;
#[derive(Debug, Args)]
pub struct SandboxOpts {
#[arg(short, long)]
pub name: Option<String>,
#[arg(short = 'c', long)]
pub cpus: Option<u8>,
#[arg(short, long)]
pub memory: Option<String>,
#[arg(short, long)]
pub volume: Vec<String>,
#[arg(short, long)]
pub workdir: Option<String>,
#[arg(long)]
pub shell: Option<String>,
#[arg(short, long)]
pub env: Vec<String>,
#[arg(long)]
pub replace: bool,
#[arg(short, long)]
pub quiet: bool,
#[arg(long)]
pub tmpfs: Vec<String>,
#[arg(long)]
pub script: Vec<String>,
#[arg(long)]
pub entrypoint: Option<String>,
#[arg(short = 'H', long)]
pub hostname: Option<String>,
#[arg(short = 'u', long)]
pub user: Option<String>,
#[arg(long)]
pub pull: Option<String>,
#[arg(long)]
pub log_level: Option<String>,
#[arg(long)]
pub max_duration: Option<String>,
#[arg(long)]
pub idle_timeout: Option<String>,
#[cfg(feature = "net")]
#[arg(short, long)]
pub port: Vec<String>,
#[cfg(feature = "net")]
#[arg(long)]
pub no_network: bool,
#[cfg(feature = "net")]
#[arg(long)]
pub dns_block_domain: Vec<String>,
#[cfg(feature = "net")]
#[arg(long)]
pub dns_block_suffix: Vec<String>,
#[cfg(feature = "net")]
#[arg(long)]
pub no_dns_rebind_protection: bool,
#[cfg(feature = "net")]
#[arg(long)]
pub max_connections: Option<usize>,
#[cfg(feature = "net")]
#[arg(long)]
pub tls_intercept: bool,
#[cfg(feature = "net")]
#[arg(long)]
pub tls_intercept_port: Vec<u16>,
#[cfg(feature = "net")]
#[arg(long)]
pub tls_bypass: Vec<String>,
#[cfg(feature = "net")]
#[arg(long)]
pub no_block_quic: bool,
#[cfg(feature = "net")]
#[arg(long)]
pub tls_ca_cert: Option<PathBuf>,
#[cfg(feature = "net")]
#[arg(long)]
pub tls_ca_key: Option<PathBuf>,
#[cfg(feature = "net")]
#[arg(long)]
pub secret: Vec<String>,
#[cfg(feature = "net")]
#[arg(long)]
pub on_secret_violation: Option<String>,
}
impl SandboxOpts {
pub fn has_creation_flags(&self) -> bool {
let base = self.cpus.is_some()
|| self.memory.is_some()
|| !self.volume.is_empty()
|| self.workdir.is_some()
|| self.shell.is_some()
|| !self.env.is_empty()
|| !self.tmpfs.is_empty()
|| !self.script.is_empty()
|| self.entrypoint.is_some()
|| self.hostname.is_some()
|| self.user.is_some()
|| self.pull.is_some()
|| self.log_level.is_some()
|| self.max_duration.is_some()
|| self.idle_timeout.is_some();
#[cfg(feature = "net")]
let net = !self.port.is_empty()
|| self.no_network
|| !self.dns_block_domain.is_empty()
|| !self.dns_block_suffix.is_empty()
|| self.no_dns_rebind_protection
|| self.max_connections.is_some()
|| self.tls_intercept
|| !self.tls_intercept_port.is_empty()
|| !self.tls_bypass.is_empty()
|| self.no_block_quic
|| self.tls_ca_cert.is_some()
|| self.tls_ca_key.is_some()
|| !self.secret.is_empty()
|| self.on_secret_violation.is_some();
#[cfg(not(feature = "net"))]
let net = false;
base || net
}
}
pub fn apply_sandbox_opts(
mut builder: SandboxBuilder,
opts: &SandboxOpts,
) -> anyhow::Result<SandboxBuilder> {
if let Some(cpus) = opts.cpus {
builder = builder.cpus(cpus);
}
if let Some(ref mem) = opts.memory {
builder = builder.memory(ui::parse_size_mib(mem).map_err(anyhow::Error::msg)?);
}
if let Some(ref workdir) = opts.workdir {
builder = builder.workdir(workdir);
}
if let Some(ref shell) = opts.shell {
builder = builder.shell(shell);
}
if opts.replace {
builder = builder.replace();
}
for env_str in &opts.env {
let (k, v) = ui::parse_env(env_str).map_err(anyhow::Error::msg)?;
builder = builder.env(k, v);
}
for vol_str in &opts.volume {
builder = apply_volume(builder, vol_str)?;
}
for tmpfs_str in &opts.tmpfs {
let (path, size) = parse_tmpfs(tmpfs_str)?;
builder = if let Some(size_mib) = size {
builder.volume(&path, |m| m.tmpfs().size(size_mib))
} else {
builder.volume(&path, |m| m.tmpfs())
};
}
for script_str in &opts.script {
let (name, content) = parse_script(script_str)?;
builder = builder.script(name, content);
}
if let Some(ref ep) = opts.entrypoint {
builder = builder.entrypoint(vec![ep.clone()]);
}
if let Some(ref hostname) = opts.hostname {
builder = builder.hostname(hostname);
}
if let Some(ref user) = opts.user {
builder = builder.user(user);
}
if let Some(ref pull) = opts.pull {
builder = builder.pull_policy(parse_pull_policy(pull)?);
}
if let Some(ref level) = opts.log_level {
builder = builder.log_level(parse_log_level(level)?);
}
if let Some(ref dur) = opts.max_duration {
builder = builder.max_duration(parse_duration_secs(dur)?);
}
if let Some(ref dur) = opts.idle_timeout {
builder = builder.idle_timeout(parse_duration_secs(dur)?);
}
#[cfg(feature = "net")]
{
builder = apply_network_opts(builder, opts)?;
}
Ok(builder)
}
pub fn apply_volume(builder: SandboxBuilder, spec: &str) -> anyhow::Result<SandboxBuilder> {
let (source, guest) = spec
.split_once(':')
.ok_or_else(|| anyhow::anyhow!("volume must be in format source:guest"))?;
if source.starts_with('/') || source.starts_with("./") || source.starts_with("../") {
Ok(builder.volume(guest, |m| m.bind(source)))
} else {
Ok(builder.volume(guest, |m| m.named(source)))
}
}
#[cfg(feature = "net")]
fn apply_network_opts(
mut builder: SandboxBuilder,
opts: &SandboxOpts,
) -> anyhow::Result<SandboxBuilder> {
for port_str in &opts.port {
let (host, guest, udp) = parse_port(port_str)?;
builder = if udp {
builder.port_udp(host, guest)
} else {
builder.port(host, guest)
};
}
if opts.no_network {
builder = builder.disable_network();
}
for secret_str in &opts.secret {
let (env_var, value, host) = parse_secret(secret_str)?;
builder = builder.secret_env(env_var, value, host);
}
let has_network_config = !opts.dns_block_domain.is_empty()
|| !opts.dns_block_suffix.is_empty()
|| opts.no_dns_rebind_protection
|| opts.max_connections.is_some()
|| opts.tls_intercept
|| !opts.tls_intercept_port.is_empty()
|| !opts.tls_bypass.is_empty()
|| opts.no_block_quic
|| opts.tls_ca_cert.is_some()
|| opts.tls_ca_key.is_some()
|| opts.on_secret_violation.is_some();
if has_network_config {
let dns_block_domain = opts.dns_block_domain.clone();
let dns_block_suffix = opts.dns_block_suffix.clone();
let no_dns_rebind = opts.no_dns_rebind_protection;
let max_conn = opts.max_connections;
let tls_intercept = opts.tls_intercept;
let tls_ports = opts.tls_intercept_port.clone();
let tls_bypass = opts.tls_bypass.clone();
let no_block_quic = opts.no_block_quic;
let ca_cert = opts.tls_ca_cert.clone();
let ca_key = opts.tls_ca_key.clone();
let violation_action = parse_violation_action(&opts.on_secret_violation)?;
builder = builder.network(move |mut n| {
for domain in &dns_block_domain {
n = n.block_domain(domain);
}
for suffix in &dns_block_suffix {
n = n.block_domain_suffix(suffix);
}
if no_dns_rebind {
n = n.dns_rebind_protection(false);
}
if let Some(max) = max_conn {
n = n.max_connections(max);
}
if let Some(action) = violation_action {
n = n.on_secret_violation(action);
}
let has_tls = tls_intercept
|| !tls_ports.is_empty()
|| !tls_bypass.is_empty()
|| no_block_quic
|| ca_cert.is_some()
|| ca_key.is_some();
if has_tls {
let tls_ports = tls_ports.clone();
let tls_bypass = tls_bypass.clone();
let ca_cert = ca_cert.clone();
let ca_key = ca_key.clone();
n = n.tls(move |mut t| {
if !tls_ports.is_empty() {
t = t.intercepted_ports(tls_ports);
}
for domain in &tls_bypass {
t = t.bypass(domain);
}
if no_block_quic {
t = t.block_quic(false);
}
if let Some(ref cert) = ca_cert {
t = t.ca_cert(cert);
}
if let Some(ref key) = ca_key {
t = t.ca_key(key);
}
t
});
}
n
});
}
Ok(builder)
}
pub fn parse_duration_secs(s: &str) -> anyhow::Result<u64> {
let s = s.trim();
if let Some(n) = s.strip_suffix('s') {
Ok(n.trim().parse::<u64>()?)
} else if let Some(n) = s.strip_suffix('m') {
Ok(n.trim().parse::<u64>()? * 60)
} else if let Some(n) = s.strip_suffix('h') {
Ok(n.trim().parse::<u64>()? * 3600)
} else {
Ok(s.parse::<u64>()?)
}
}
#[cfg(feature = "net")]
fn parse_port(spec: &str) -> anyhow::Result<(u16, u16, bool)> {
let (port_part, udp) = if let Some(p) = spec.strip_suffix("/udp") {
(p, true)
} else if let Some(p) = spec.strip_suffix("/tcp") {
(p, false)
} else {
(spec, false)
};
let (host_str, guest_str) = port_part
.split_once(':')
.ok_or_else(|| anyhow::anyhow!("port must be in format HOST:GUEST[/udp]"))?;
let host: u16 = host_str
.trim()
.parse()
.map_err(|_| anyhow::anyhow!("invalid host port: {host_str}"))?;
let guest: u16 = guest_str
.trim()
.parse()
.map_err(|_| anyhow::anyhow!("invalid guest port: {guest_str}"))?;
Ok((host, guest, udp))
}
#[cfg(feature = "net")]
fn parse_secret(spec: &str) -> anyhow::Result<(String, String, String)> {
let eq_pos = spec
.find('=')
.ok_or_else(|| anyhow::anyhow!("secret must be in format ENV=VALUE@HOST"))?;
let env_var = spec[..eq_pos].to_string();
let rest = &spec[eq_pos + 1..];
let at_pos = rest
.rfind('@')
.ok_or_else(|| anyhow::anyhow!("secret must be in format ENV=VALUE@HOST"))?;
let value = rest[..at_pos].to_string();
let host = rest[at_pos + 1..].to_string();
if env_var.is_empty() || value.is_empty() || host.is_empty() {
anyhow::bail!("secret must be in format ENV=VALUE@HOST (all parts required)");
}
Ok((env_var, value, host))
}
#[cfg(feature = "net")]
fn parse_violation_action(
s: &Option<String>,
) -> anyhow::Result<Option<microsandbox_network::secrets::config::ViolationAction>> {
use microsandbox_network::secrets::config::ViolationAction;
match s.as_deref() {
None => Ok(None),
Some("block") => Ok(Some(ViolationAction::Block)),
Some("block-and-log") => Ok(Some(ViolationAction::BlockAndLog)),
Some("block-and-terminate") => Ok(Some(ViolationAction::BlockAndTerminate)),
Some(other) => anyhow::bail!(
"invalid violation action: {other} (expected: block, block-and-log, block-and-terminate)"
),
}
}
fn parse_tmpfs(spec: &str) -> anyhow::Result<(String, Option<u32>)> {
if let Some((path, size_str)) = spec.split_once(':') {
let size_mib = ui::parse_size_mib(size_str).map_err(anyhow::Error::msg)?;
Ok((path.to_string(), Some(size_mib)))
} else {
Ok((spec.to_string(), None))
}
}
fn parse_script(spec: &str) -> anyhow::Result<(String, String)> {
let (name, path) = spec
.split_once(':')
.ok_or_else(|| anyhow::anyhow!("script must be in format NAME:PATH"))?;
let content = std::fs::read_to_string(path)
.map_err(|e| anyhow::anyhow!("failed to read script file '{path}': {e}"))?;
Ok((name.to_string(), content))
}
fn parse_pull_policy(s: &str) -> anyhow::Result<microsandbox::sandbox::PullPolicy> {
use microsandbox::sandbox::PullPolicy;
match s {
"always" => Ok(PullPolicy::Always),
"if-missing" => Ok(PullPolicy::IfMissing),
"never" => Ok(PullPolicy::Never),
_ => anyhow::bail!("invalid pull policy: {s} (expected: always, if-missing, never)"),
}
}
fn parse_log_level(s: &str) -> anyhow::Result<microsandbox::LogLevel> {
use microsandbox::LogLevel;
match s {
"error" => Ok(LogLevel::Error),
"warn" => Ok(LogLevel::Warn),
"info" => Ok(LogLevel::Info),
"debug" => Ok(LogLevel::Debug),
"trace" => Ok(LogLevel::Trace),
_ => anyhow::bail!("invalid log level: {s} (expected: error, warn, info, debug, trace)"),
}
}
pub fn parse_rlimit(
spec: &str,
) -> anyhow::Result<(microsandbox::sandbox::RlimitResource, u64, u64)> {
use microsandbox::sandbox::RlimitResource;
let (res_str, limit_str) = spec
.split_once('=')
.ok_or_else(|| anyhow::anyhow!("rlimit must be in format RESOURCE=LIMIT"))?;
let resource = RlimitResource::try_from(res_str).map_err(|e| anyhow::anyhow!("{e}"))?;
let (soft, hard) = if let Some((s, h)) = limit_str.split_once(':') {
let soft = s
.parse::<u64>()
.map_err(|e| anyhow::anyhow!("invalid soft limit: {e}"))?;
let hard = h
.parse::<u64>()
.map_err(|e| anyhow::anyhow!("invalid hard limit: {e}"))?;
(soft, hard)
} else {
let limit = limit_str
.parse::<u64>()
.map_err(|e| anyhow::anyhow!("invalid limit: {e}"))?;
(limit, limit)
};
Ok((resource, soft, hard))
}