#![cfg(target_os = "linux")]
use crate::bwrap_proxy::{STAGE2_REWRITE_KEYS_ENV_KEY, STAGE2_UDS_ENV_KEY, rewrite_proxy_url_port};
use anyhow::{Context, Result, bail};
use std::ffi::{CString, OsString};
use std::io::{Read, Write};
use std::net::{Ipv4Addr, TcpListener, TcpStream};
use std::os::unix::ffi::OsStringExt;
use std::os::unix::net::UnixStream;
use std::path::PathBuf;
/// Exit status (88) associated with a stage-2 setup failure.
///
/// NOTE(review): the code visible in this file only references this constant
/// from a unit test; presumably the binary's entry point exits with it when
/// [`run`] returns an error before the user command is exec'd — confirm
/// against the caller.
pub const STAGE2_SETUP_FAILED_EXIT: i32 = 88;
pub fn run(argv: Vec<OsString>) -> Result<()> {
if argv.len() < 2 {
bail!("stage2: missing user command (usage: koda-sandbox-stage2 <cmd> [args...])");
}
let uds_path = std::env::var(STAGE2_UDS_ENV_KEY)
.with_context(|| format!("stage2: {STAGE2_UDS_ENV_KEY} env var missing"))?;
let uds_path = PathBuf::from(uds_path);
if !uds_path.exists() {
bail!(
"stage2: UDS bridge socket {} not found inside sandbox (was it bind-mounted with --bind?)",
uds_path.display()
);
}
bring_up_loopback().context("stage2: bring up lo")?;
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
.context("stage2: bind in-netns TCP listener")?;
let local_port = listener
.local_addr()
.context("stage2: read in-netns listener port")?
.port();
let pid = unsafe { libc::fork() };
if pid < 0 {
bail!("stage2: fork failed: {}", std::io::Error::last_os_error());
}
if pid == 0 {
run_bridge_child(listener, uds_path);
unsafe { libc::_exit(0) };
}
let rewrite_keys = std::env::var(STAGE2_REWRITE_KEYS_ENV_KEY).unwrap_or_default();
for key in rewrite_keys.split(',').filter(|k| !k.is_empty()) {
let Ok(old) = std::env::var(key) else {
continue;
};
let Some(new) = rewrite_proxy_url_port(&old, local_port) else {
continue;
};
unsafe { std::env::set_var(key, new) };
}
unsafe {
std::env::remove_var(STAGE2_UDS_ENV_KEY);
std::env::remove_var(STAGE2_REWRITE_KEYS_ENV_KEY);
}
let prog = argv[1].clone();
let prog_c = osstring_to_cstring(&prog).context("stage2: convert program name")?;
let arg_cs: Vec<CString> = argv[1..]
.iter()
.map(osstring_to_cstring)
.collect::<Result<Vec<_>>>()
.context("stage2: convert command args")?;
let arg_ptrs: Vec<*const libc::c_char> = arg_cs
.iter()
.map(|c| c.as_ptr())
.chain(std::iter::once(std::ptr::null()))
.collect();
unsafe { libc::execvp(prog_c.as_ptr(), arg_ptrs.as_ptr()) };
let err = std::io::Error::last_os_error();
bail!("stage2: execvp({:?}) failed: {err}", prog);
}
/// Converts an OS string into a NUL-terminated C string.
///
/// # Errors
///
/// Fails when `s` contains an interior NUL byte, which a C string cannot
/// represent; the error message includes the offending value.
fn osstring_to_cstring(s: &OsString) -> Result<CString> {
    let raw_bytes: Vec<u8> = s.to_owned().into_vec();
    let converted = CString::new(raw_bytes);
    converted.with_context(|| format!("stage2: argv contains NUL byte: {s:?}"))
}
/// Makes sure the `lo` interface has `IFF_UP | IFF_RUNNING` set.
///
/// Opens a short-lived `AF_INET` datagram socket to issue interface ioctls
/// on, reads the current flags with `SIOCGIFFLAGS`, and only issues
/// `SIOCSIFFLAGS` when at least one of the two flags is missing.
///
/// # Errors
///
/// Returns an error carrying the OS error text when the socket cannot be
/// created or either ioctl fails.
fn bring_up_loopback() -> Result<()> {
    let sock = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, 0) };
    if sock < 0 {
        bail!("socket(AF_INET): {}", std::io::Error::last_os_error());
    }
    // Closes the socket on every exit path below.
    let _guard = CloseOnDrop(sock);

    // Start from an all-zero request and copy in the interface name,
    // including its terminating NUL.
    let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() };
    for (slot, byte) in ifr.ifr_name.iter_mut().zip(b"lo\0") {
        *slot = *byte as libc::c_char;
    }

    if unsafe { libc::ioctl(sock, libc::SIOCGIFFLAGS as libc::Ioctl, &mut ifr) } < 0 {
        bail!(
            "ioctl(SIOCGIFFLAGS, lo): {}",
            std::io::Error::last_os_error()
        );
    }

    let wanted = (libc::IFF_UP | libc::IFF_RUNNING) as libc::c_short;
    let current = unsafe { ifr.ifr_ifru.ifru_flags };
    if (current & wanted) != wanted {
        // Keep whatever flags were already set and add the missing ones.
        ifr.ifr_ifru.ifru_flags = current | wanted;
        if unsafe { libc::ioctl(sock, libc::SIOCSIFFLAGS as libc::Ioctl, &ifr) } < 0 {
            bail!(
                "ioctl(SIOCSIFFLAGS, lo|UP): {}",
                std::io::Error::last_os_error()
            );
        }
    }
    Ok(())
}
/// Guard around a raw file descriptor that calls `close(2)` on it when the
/// guard goes out of scope.
struct CloseOnDrop(libc::c_int);

impl Drop for CloseOnDrop {
    /// Closes the wrapped descriptor; the return value of `close` is ignored.
    fn drop(&mut self) {
        let fd: libc::c_int = self.0;
        unsafe {
            libc::close(fd);
        }
    }
}
/// Accept loop of the forked bridge process.
///
/// Every TCP connection accepted on `listener` gets its own thread, which
/// connects to the Unix socket at `uds_path` and relays bytes between the
/// two via [`bridge_two_streams`]. If the Unix socket connection fails, the
/// TCP connection is simply dropped.
///
/// Returns when `accept` reports an error, or immediately when the parent
/// PID is already 1.
fn run_bridge_child(listener: TcpListener, uds_path: PathBuf) {
    // Ask the kernel to send SIGTERM to this process when its parent dies.
    // The prctl result is ignored.
    unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM, 0, 0, 0) };

    // Presumably guards against the parent having exited before the prctl
    // above took effect (the process would then be reparented).
    // NOTE(review): if the stage-2 parent itself runs as PID 1 of a PID
    // namespace, this check is always true and the bridge never serves —
    // confirm how the sandbox launches this helper.
    let parent_is_init = unsafe { libc::getppid() } == 1;
    if parent_is_init {
        return;
    }

    // NOTE(review): the first accept error of any kind ends the bridge.
    while let Ok((client, _)) = listener.accept() {
        let socket_path = uds_path.clone();
        std::thread::spawn(move || {
            if let Ok(upstream) = UnixStream::connect(&socket_path) {
                let _ = bridge_two_streams(client, upstream);
            }
        });
    }
}
/// Relays bytes in both directions between a TCP connection and a Unix
/// socket connection until both directions have finished.
///
/// One direction (TCP → UDS) runs on a helper thread, the other (UDS → TCP)
/// on the calling thread. When a direction ends — by EOF or by an I/O
/// error — the write half of its destination is shut down so the peer on
/// that side observes EOF. Without this, the reader clones keep each socket
/// open and the far side would never learn that the near side finished
/// sending, leaving the relay blocked until the peer gives up.
///
/// # Errors
///
/// Only cloning the stream handles can fail. Errors that occur while
/// relaying or shutting down are deliberately ignored (best-effort bridge).
fn bridge_two_streams(tcp: TcpStream, uds: UnixStream) -> std::io::Result<()> {
    let tcp_r = tcp.try_clone()?;
    let mut tcp_w = tcp;
    let uds_r = uds.try_clone()?;
    let mut uds_w = uds;

    let tcp_to_uds = std::thread::spawn(move || {
        let _ = copy_until_eof(tcp_r, &mut uds_w);
        // Propagate "client is done sending" to the Unix-socket peer.
        let _ = uds_w.shutdown(std::net::Shutdown::Write);
    });

    let _ = copy_until_eof(uds_r, &mut tcp_w);
    // Propagate "upstream is done sending" to the TCP client.
    let _ = tcp_w.shutdown(std::net::Shutdown::Write);

    let _ = tcp_to_uds.join();
    Ok(())
}
/// Copies everything from `r` into `w` until `r` reports end-of-file.
///
/// Data is moved through a fixed 8 KiB stack buffer. Reads that fail with
/// `ErrorKind::Interrupted` are retried.
///
/// # Errors
///
/// Returns the first non-`Interrupted` read error, or the first error
/// reported by `write_all`.
///
/// # Returns
///
/// The total number of bytes written to `w`.
fn copy_until_eof<R: Read, W: Write>(mut r: R, w: &mut W) -> std::io::Result<u64> {
    const CHUNK_LEN: usize = 8 * 1024;
    let mut chunk = [0u8; CHUNK_LEN];
    let mut copied: u64 = 0;
    loop {
        match r.read(&mut chunk) {
            // A zero-length read signals end-of-file.
            Ok(0) => break,
            Ok(len) => {
                w.write_all(&chunk[..len])?;
                copied += len as u64;
            }
            // Interrupted reads transferred nothing; just try again.
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
    Ok(copied)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A value with an interior NUL byte cannot become a C string.
    #[test]
    fn osstring_to_cstring_rejects_nul() {
        let with_nul = OsString::from_vec(Vec::from(&b"foo\0bar"[..]));
        let outcome = osstring_to_cstring(&with_nul);
        assert!(outcome.is_err());
    }

    /// Ordinary command names convert without modification.
    #[test]
    fn osstring_to_cstring_accepts_normal_strings() {
        let shell = OsString::from("sh");
        let converted = osstring_to_cstring(&shell).unwrap();
        assert_eq!(converted.to_bytes(), b"sh");
    }

    /// A short payload is copied verbatim and its length reported.
    #[test]
    fn copy_until_eof_handles_simple_payload() {
        let payload = b"hello world";
        let mut source = std::io::Cursor::new(payload.to_vec());
        let mut sink: Vec<u8> = Vec::new();
        let copied = copy_until_eof(&mut source, &mut sink).unwrap();
        assert_eq!(copied, 11);
        assert_eq!(&sink, payload);
    }

    /// Pins the numeric value of the setup-failure exit status.
    #[test]
    fn stage2_setup_failed_exit_is_88() {
        assert_eq!(STAGE2_SETUP_FAILED_EXIT, 88);
    }
}