use std::collections::BTreeMap;
use std::fs::File;
use std::io::{self, Read, Write};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::Path;
use std::path::PathBuf;
use std::process::ExitStatus;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use serde::Serialize;
// Frame type tags. Frames written by this file consist of one tag byte, a
// 4-byte big-endian payload length, and then the payload bytes.

/// Tag of the frame `spawn` sends first; its payload is the JSON-encoded request.
const FRAME_REQUEST: u8 = 0xff;
/// Tag of the JSON control frame sent by `send_control_with_ack`; the ack is
/// expected to carry the same tag.
const FRAME_CONTROL: u8 = 0xfe;
/// Tag for stdin data written through `ExecStdin`. An empty payload is what
/// `ExecStdin` sends when it is closed or dropped (presumably end-of-input on
/// the agent side — confirm against the agent).
const FRAME_STDIN: u8 = 0;
/// Tag of incoming frames whose payload the demux thread copies into the stdout pipe.
const FRAME_STDOUT: u8 = 1;
/// Tag of incoming frames whose payload the demux thread copies into the stderr pipe.
const FRAME_STDERR: u8 = 2;
/// Tag of the resize frame; payload is `cols` then `rows`, each a big-endian `u16`.
const FRAME_RESIZE: u8 = 3;
/// Tag of the signal frame; payload is a single byte holding the signal number.
const FRAME_SIGNAL: u8 = 4;
/// Tag of the incoming exit frame; payload is a big-endian `u32` exit code,
/// optionally followed by a big-endian `u64` read as the peak RSS value.
const FRAME_EXIT: u8 = 5;
/// Tag of the incoming error frame; payload is decoded (lossily) as UTF-8 text.
const FRAME_ERROR: u8 = 6;
/// Builder for one exec request sent over the Unix socket at `exec_path`.
///
/// Configure it through the chained setters, then call `spawn` to get a
/// streaming `ExecChild`, or `output` to run to completion and collect the
/// captured output.
pub struct ExecBuilder {
/// Filesystem path of the Unix socket that `spawn` connects to.
exec_path: PathBuf,
/// Command line sent in the request; `spawn` rejects an empty list.
argv: Vec<String>,
/// Environment entries sent in the request.
env: BTreeMap<String, String>,
/// Optional working directory sent in the request (omitted when `None`).
cwd: Option<String>,
/// `tty` flag sent in the request.
tty: bool,
/// Optional column count sent in the request (set together with `rows` by `winsize`).
cols: Option<u16>,
/// Optional row count sent in the request (set together with `cols` by `winsize`).
rows: Option<u16>,
/// Time limit applied by `output` only; it is not part of the request payload.
timeout: Option<Duration>,
/// Files sent base64-encoded alongside the request.
stage_files: Vec<StageFile>,
/// Additional argv lists forwarded verbatim in the request; how the agent
/// interprets them is not visible in this file.
chain: Vec<Vec<String>>,
}
/// A file queued on the builder to be shipped with the request.
#[derive(Clone)]
struct StageFile {
/// Destination path string, forwarded to the agent as-is.
path: String,
/// Raw file contents; base64-encoded when the request is built.
data: Vec<u8>,
/// Optional mode value forwarded to the agent (presumably Unix permission
/// bits — confirm against the agent).
mode: Option<u32>,
}
impl ExecBuilder {
    /// Creates a builder targeting the exec socket at `exec_path` with every
    /// option unset.
    pub fn new(exec_path: PathBuf) -> Self {
        Self {
            exec_path,
            argv: Vec::new(),
            env: BTreeMap::new(),
            cwd: None,
            tty: false,
            cols: None,
            rows: None,
            timeout: None,
            stage_files: Vec::new(),
            chain: Vec::new(),
        }
    }

    /// Replaces the command line with `argv`.
    pub fn argv<I, S>(self, argv: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            argv: argv.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Sets (or overwrites) one environment entry.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, val) = (key.into(), value.into());
        self.env.insert(name, val);
        self
    }

    /// Sets the working directory sent with the request.
    pub fn cwd(self, path: impl Into<String>) -> Self {
        Self {
            cwd: Some(path.into()),
            ..self
        }
    }

    /// Sets the `tty` flag sent with the request.
    pub fn tty(self, on: bool) -> Self {
        Self { tty: on, ..self }
    }

    /// Sets both window dimensions sent with the request.
    pub fn winsize(self, cols: u16, rows: u16) -> Self {
        Self {
            cols: Some(cols),
            rows: Some(rows),
            ..self
        }
    }

    /// Sets the time limit enforced by [`ExecBuilder::output`].
    pub fn timeout(self, d: Duration) -> Self {
        Self {
            timeout: Some(d),
            ..self
        }
    }

    /// Queues a file to be shipped with the request, without a mode.
    pub fn stage_file(self, path: impl Into<String>, bytes: impl Into<Vec<u8>>) -> Self {
        self.push_staged(path.into(), bytes.into(), None)
    }

    /// Queues a file to be shipped with the request, with an explicit mode.
    pub fn stage_file_mode(
        self,
        path: impl Into<String>,
        bytes: impl Into<Vec<u8>>,
        mode: u32,
    ) -> Self {
        self.push_staged(path.into(), bytes.into(), Some(mode))
    }

    /// Shared tail of the two staging setters: appends one entry.
    fn push_staged(mut self, path: String, data: Vec<u8>, mode: Option<u32>) -> Self {
        self.stage_files.push(StageFile { path, data, mode });
        self
    }

    /// Appends one more argv list to the `chain` field of the request.
    pub fn chain<I, S>(mut self, argv: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let link: Vec<String> = argv.into_iter().map(Into::into).collect();
        self.chain.push(link);
        self
    }

    /// Connects to the exec socket and starts the command.
    ///
    /// # Errors
    /// Returns `InvalidInput` when no argv was set, otherwise whatever the
    /// connection / request setup reports.
    pub fn spawn(self) -> io::Result<ExecChild> {
        if !self.argv.is_empty() {
            return spawn(self);
        }
        Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "exec: argv is empty",
        ))
    }

    /// Runs the command to completion, capturing stdout and stderr.
    ///
    /// Stdin is closed immediately. When a timeout was configured, a watchdog
    /// thread sends signal 9 once it elapses and the outcome is flagged as
    /// timed out.
    ///
    /// # Errors
    /// Propagates spawn failures and any error reported while waiting for the
    /// exit status.
    pub fn output(self) -> io::Result<ExecOutcome> {
        use std::sync::atomic::{AtomicBool, Ordering};

        /// Reads `src` to its end on a background thread; read errors simply
        /// end the capture with whatever was collected so far.
        fn drain_in_background<R>(mut src: R) -> JoinHandle<Vec<u8>>
        where
            R: Read + Send + 'static,
        {
            thread::spawn(move || {
                let mut collected = Vec::new();
                let _ = src.read_to_end(&mut collected);
                collected
            })
        }

        /// Joins a capture thread, treating "no thread" and "thread panicked"
        /// as empty output.
        fn finish(handle: Option<JoinHandle<Vec<u8>>>) -> Vec<u8> {
            match handle {
                Some(h) => h.join().unwrap_or_default(),
                None => Vec::new(),
            }
        }

        let limit = self.timeout;
        let started = Instant::now();
        let mut child = self.spawn()?;

        // Nothing will be written to the command's stdin: close it right away.
        drop(child.stdin());

        let stdout_reader = child.stdout();
        let stderr_reader = child.stderr();
        let out_thread = stdout_reader.map(drain_in_background);
        let err_thread = stderr_reader.map(drain_in_background);

        let (cancel_tx, cancel_rx) = channel::<()>();
        let fired = Arc::new(AtomicBool::new(false));
        let watchdog = match limit {
            None => None,
            Some(d) => {
                let fired = Arc::clone(&fired);
                let writer = Arc::clone(&child.sock_w);
                let handle = thread::spawn(move || {
                    // A message (or any non-timeout wake-up that succeeds)
                    // means the command finished in time.
                    if cancel_rx.recv_timeout(d).is_ok() {
                        return;
                    }
                    fired.store(true, Ordering::SeqCst);
                    let r = send_frame(&writer, FRAME_SIGNAL, &[9u8]);
                    if std::env::var_os("SUPERMACHINE_TIMEOUT_TRACE").is_some() {
                        eprintln!("[exec.timeout] fired at +{:?}, send_frame={:?}", d, r);
                    }
                });
                Some(handle)
            }
        };

        let exit = child.wait_with_rss();

        // Stand the watchdog down before collecting anything else.
        let _ = cancel_tx.send(());
        if let Some(handle) = watchdog {
            let _ = handle.join();
        }

        let stdout_bytes = finish(out_thread);
        let stderr_bytes = finish(err_thread);
        let (status, peak_rss_kib) = exit?;

        Ok(ExecOutcome {
            status,
            stdout: stdout_bytes,
            stderr: stderr_bytes,
            duration: started.elapsed(),
            peak_rss_kib,
            timed_out: fired.load(Ordering::SeqCst),
        })
    }
}
/// Result of running a command to completion via `ExecBuilder::output`.
#[derive(Debug)]
pub struct ExecOutcome {
    /// Exit status synthesized from the code the agent reported.
    pub status: ExitStatus,
    /// Everything captured from the command's stdout.
    pub stdout: Vec<u8>,
    /// Everything captured from the command's stderr.
    pub stderr: Vec<u8>,
    /// Wall-clock time measured by the caller around the whole run.
    pub duration: Duration,
    /// Peak RSS value reported with the exit frame, when the agent sent one.
    pub peak_rss_kib: Option<u64>,
    /// Whether the timeout watchdog fired during the run.
    pub timed_out: bool,
}

impl ExecOutcome {
    /// Returns `true` only when the command exited successfully and the
    /// timeout watchdog did not fire.
    pub fn success(&self) -> bool {
        if self.timed_out {
            return false;
        }
        self.status.success()
    }
}
/// JSON body of the REQUEST frame, borrowing its data from the `ExecBuilder`.
///
/// Optional and empty members are left out of the serialized object.
#[derive(Serialize)]
struct RequestPayload<'a> {
/// Command line to run.
argv: &'a [String],
/// Environment entries for the command.
env: &'a BTreeMap<String, String>,
/// Working directory; omitted when unset.
#[serde(skip_serializing_if = "Option::is_none")]
cwd: Option<&'a str>,
/// `tty` flag as configured on the builder.
tty: bool,
/// Column count; omitted when unset.
#[serde(skip_serializing_if = "Option::is_none")]
cols: Option<u16>,
/// Row count; omitted when unset.
#[serde(skip_serializing_if = "Option::is_none")]
rows: Option<u16>,
/// Base64-encoded staged files; omitted when there are none.
#[serde(skip_serializing_if = "Vec::is_empty")]
stage_files: Vec<StageFilePayload<'a>>,
/// Additional argv lists from the builder's `chain`; omitted when empty.
#[serde(skip_serializing_if = "<[_]>::is_empty")]
chain: &'a [Vec<String>],
}
/// Wire form of one staged file inside `RequestPayload`.
#[derive(Serialize)]
struct StageFilePayload<'a> {
/// Destination path string, borrowed from the builder's entry.
path: &'a str,
/// File contents encoded as base64 text.
data_b64: String,
/// Optional mode value; omitted from the JSON when unset.
#[serde(skip_serializing_if = "Option::is_none")]
mode: Option<u32>,
}
/// Handle to a command started through `ExecBuilder::spawn`.
pub struct ExecChild {
/// Write side of the exec socket, behind a mutex; the same `Arc` is handed
/// to `ExecStdin` and to the timeout watchdog in `ExecBuilder::output`.
sock_w: Arc<Mutex<UnixStream>>,
/// Read end of the pipe the demux thread copies STDOUT frames into;
/// `None` once taken via `stdout()`.
stdout_r: Option<OwnedFd>,
/// Read end of the pipe the demux thread copies STDERR frames into;
/// `None` once taken via `stderr()`.
stderr_r: Option<OwnedFd>,
/// Stdin writer; `None` once taken via `stdin()` or dropped by wait/Drop.
stdin: Option<ExecStdin>,
/// Join handle of the demux thread; joined after the exit result arrives.
demux: Option<JoinHandle<()>>,
/// Receives the single final result produced by the demux thread.
exit_rx: Receiver<DemuxExit>,
}
/// Final result the demux thread reports back before it stops.
enum DemuxExit {
/// An EXIT frame arrived. `code` is its leading big-endian `u32` (0 when the
/// payload was shorter than 4 bytes); `peak_rss_kib` is present only when the
/// payload held at least 12 bytes.
Status { code: u32, peak_rss_kib: Option<u64> },
/// Reading a frame hit end-of-stream before any EXIT or ERROR frame.
EofBeforeExit,
/// An ERROR frame arrived; carries its payload decoded lossily as UTF-8.
Error(String),
/// Reading a frame failed with any other I/O error.
Io(io::Error),
}
impl ExecChild {
    /// Takes the stdin writer. Returns `None` on every call after the first.
    pub fn stdin(&mut self) -> Option<ExecStdin> {
        self.stdin.take()
    }

    /// Takes the stdout reader. Returns `None` on every call after the first.
    pub fn stdout(&mut self) -> Option<ExecStdout> {
        self.stdout_r.take().map(ExecStdout::from_fd)
    }

    /// Takes the stderr reader. Returns `None` on every call after the first.
    pub fn stderr(&mut self) -> Option<ExecStderr> {
        self.stderr_r.take().map(ExecStderr::from_fd)
    }

    /// Sends a SIGNAL frame carrying `signum`.
    ///
    /// # Errors
    /// Returns `InvalidInput` when `signum` does not fit the one-byte wire
    /// field (previously such values were silently truncated into a
    /// *different* signal number), otherwise any socket write error.
    pub fn signal(&self, signum: i32) -> io::Result<()> {
        let byte = u8::try_from(signum).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("exec: signal number {signum} does not fit in one byte"),
            )
        })?;
        send_frame(&self.sock_w, FRAME_SIGNAL, &[byte])
    }

    /// Sends a RESIZE frame: `cols` then `rows`, each as a big-endian `u16`.
    pub fn resize(&self, cols: u16, rows: u16) -> io::Result<()> {
        let mut payload = [0u8; 4];
        payload[0..2].copy_from_slice(&cols.to_be_bytes());
        payload[2..4].copy_from_slice(&rows.to_be_bytes());
        send_frame(&self.sock_w, FRAME_RESIZE, &payload)
    }

    /// Waits for the command to finish and returns its exit status.
    ///
    /// See [`ExecChild::wait_with_rss`] for the details.
    pub fn wait(self) -> io::Result<ExitStatus> {
        Ok(self.wait_with_rss()?.0)
    }

    /// Waits for the command to finish and returns its exit status together
    /// with the peak RSS value, when the agent reported one.
    ///
    /// Stdin is closed first if it was never taken. Output readers that were
    /// never taken are closed as well: nobody can read them once `self` is
    /// consumed, and leaving them open would let the pipe fill up, block the
    /// demux thread and make this call hang forever. Readers that *were*
    /// taken must still be drained by the caller while waiting.
    ///
    /// # Errors
    /// Fails when the demux thread vanished, the agent closed the connection
    /// before sending EXIT, the agent sent an ERROR frame, or reading from
    /// the socket failed.
    pub fn wait_with_rss(mut self) -> io::Result<(ExitStatus, Option<u64>)> {
        drop(self.stdin.take());
        // With the read ends gone, the demux thread's pipe writes fail fast
        // (and are ignored there) instead of blocking on a full pipe.
        drop(self.stdout_r.take());
        drop(self.stderr_r.take());

        let exit = self
            .exit_rx
            .recv()
            .map_err(|_| io::Error::new(io::ErrorKind::Other, "exec: demux thread died"))?;
        if let Some(h) = self.demux.take() {
            let _ = h.join();
        }
        match exit {
            DemuxExit::Status { code, peak_rss_kib } => {
                Ok((synthesize_exit(code), peak_rss_kib))
            }
            DemuxExit::EofBeforeExit => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "exec: agent closed connection before sending EXIT",
            )),
            DemuxExit::Error(msg) => Err(io::Error::new(io::ErrorKind::Other, msg)),
            DemuxExit::Io(e) => Err(e),
        }
    }
}
impl Drop for ExecChild {
    /// Drops the stdin writer if it is still held, which lets its own `Drop`
    /// send the closing frame.
    fn drop(&mut self) {
        drop(self.stdin.take());
    }
}
/// Writer for the command's stdin; each write is sent as one STDIN frame.
pub struct ExecStdin {
/// Shared write side of the exec socket.
sock_w: Arc<Mutex<UnixStream>>,
/// Set once `close` has sent the empty closing frame, so `Drop` does not
/// send it a second time.
closed: bool,
}
impl ExecStdin {
    /// Wraps the shared socket writer in a not-yet-closed stdin handle.
    fn new(sock_w: Arc<Mutex<UnixStream>>) -> Self {
        Self {
            sock_w,
            closed: false,
        }
    }

    /// Sends the empty STDIN frame and consumes the handle, reporting any
    /// write error (unlike dropping, which ignores it).
    pub fn close(mut self) -> io::Result<()> {
        if !self.closed {
            // Mark first so the `Drop` that runs on return stays silent.
            self.closed = true;
            return send_frame(&self.sock_w, FRAME_STDIN, &[]);
        }
        Ok(())
    }
}
impl Write for ExecStdin {
    /// Sends `buf` as a single STDIN frame and reports it fully written.
    ///
    /// An empty `buf` sends nothing, because an empty STDIN frame is what
    /// `close` uses.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let len = buf.len();
        if len != 0 {
            send_frame(&self.sock_w, FRAME_STDIN, buf)?;
        }
        Ok(len)
    }

    /// Nothing is buffered locally, so there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
impl Drop for ExecStdin {
    /// Sends the empty STDIN frame on a best-effort basis unless `close`
    /// already did.
    fn drop(&mut self) {
        if self.closed {
            return;
        }
        let _ = send_frame(&self.sock_w, FRAME_STDIN, &[]);
    }
}
/// Reader over the command's stdout, backed by the read end of a pipe.
pub struct ExecStdout {
    file: File,
}

impl ExecStdout {
    /// Takes ownership of `fd` and exposes it as a readable file.
    fn from_fd(fd: OwnedFd) -> Self {
        let raw = fd.into_raw_fd();
        // SAFETY: `raw` was just released by an `OwnedFd`, so it is an open
        // descriptor that nothing else owns.
        Self {
            file: unsafe { File::from_raw_fd(raw) },
        }
    }
}

impl Read for ExecStdout {
    /// Reads directly from the underlying descriptor.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.file, buf)
    }
}
/// Reader over the command's stderr, backed by the read end of a pipe.
pub struct ExecStderr {
    file: File,
}

impl ExecStderr {
    /// Takes ownership of `fd` and exposes it as a readable file.
    fn from_fd(fd: OwnedFd) -> Self {
        let raw = fd.into_raw_fd();
        // SAFETY: `raw` was just released by an `OwnedFd`, so it is an open
        // descriptor that nothing else owns.
        Self {
            file: unsafe { File::from_raw_fd(raw) },
        }
    }
}

impl Read for ExecStderr {
    /// Reads directly from the underlying descriptor.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.file, buf)
    }
}
/// Sends a CONTROL action to the agent at `exec_path` and waits for a
/// successful ack, discarding the ack's contents.
///
/// Uses the default read timeout of `send_control_with_ack`.
pub fn send_control(exec_path: &Path, action_json: &serde_json::Value) -> io::Result<()> {
    send_control_with_ack(exec_path, action_json, None).map(|_ack| ())
}
pub(crate) fn send_control_with_ack(
exec_path: &Path,
action_json: &serde_json::Value,
read_timeout: Option<std::time::Duration>,
) -> io::Result<serde_json::Value> {
use std::io::Read;
let mut sock = UnixStream::connect(exec_path)?;
sock.set_read_timeout(read_timeout.or(Some(std::time::Duration::from_secs(5))))?;
sock.set_write_timeout(Some(std::time::Duration::from_secs(5)))?;
let body = serde_json::to_vec(action_json)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("encode CONTROL: {e}")))?;
let mut header = [0u8; 5];
header[0] = FRAME_CONTROL;
header[1..5].copy_from_slice(&(body.len() as u32).to_be_bytes());
sock.write_all(&header)?;
if !body.is_empty() {
sock.write_all(&body)?;
}
let mut ack_hdr = [0u8; 5];
sock.read_exact(&mut ack_hdr)?;
let kind = ack_hdr[0];
let len = u32::from_be_bytes([ack_hdr[1], ack_hdr[2], ack_hdr[3], ack_hdr[4]]) as usize;
if len > 16 * 1024 * 1024 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("CONTROL ack frame too large: {len}"),
));
}
let mut ack_body = vec![0u8; len];
if !ack_body.is_empty() {
sock.read_exact(&mut ack_body)?;
}
if kind != FRAME_CONTROL {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected CONTROL ack frame, got {kind:#x}"),
));
}
let ack: serde_json::Value = serde_json::from_slice(&ack_body)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("CONTROL ack JSON: {e}")))?;
if ack.get("ok").and_then(|v| v.as_bool()) == Some(true) {
return Ok(ack);
}
let msg = ack
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("CONTROL: agent reported failure")
.to_owned();
Err(io::Error::new(io::ErrorKind::Other, msg))
}
/// Connects to the builder's exec socket, sends the REQUEST frame and starts
/// the demux thread that feeds the stdout/stderr pipes.
///
/// # Errors
/// Returns connection, socket-clone, request-write, pipe-creation or
/// thread-spawn failures; a request that cannot be encoded is `InvalidInput`.
fn spawn(builder: ExecBuilder) -> io::Result<ExecChild> {
    // One socket, two handles: the clone is read by the demux thread, the
    // original is shared (behind a mutex) by everything that writes.
    let write_half = UnixStream::connect(&builder.exec_path)?;
    let read_half = write_half.try_clone()?;
    let sock_w = Arc::new(Mutex::new(write_half));

    let mut staged = Vec::with_capacity(builder.stage_files.len());
    for entry in &builder.stage_files {
        staged.push(StageFilePayload {
            path: &entry.path,
            data_b64: crate::api::b64_encode(&entry.data),
            mode: entry.mode,
        });
    }

    let request = RequestPayload {
        argv: &builder.argv,
        env: &builder.env,
        cwd: builder.cwd.as_deref(),
        tty: builder.tty,
        cols: builder.cols,
        rows: builder.rows,
        stage_files: staged,
        chain: &builder.chain,
    };
    let encoded = match serde_json::to_vec(&request) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("exec: encode REQUEST: {e}"),
            ));
        }
    };
    send_frame(&sock_w, FRAME_REQUEST, &encoded)?;

    let (stdout_r, stdout_w) = pipe()?;
    let (stderr_r, stderr_w) = pipe()?;
    let (exit_tx, exit_rx) = channel::<DemuxExit>();

    let demux = match thread::Builder::new()
        .name("supermachine-exec-demux".into())
        .spawn(move || demux_loop(read_half, stdout_w, stderr_w, exit_tx))
    {
        Ok(handle) => handle,
        Err(e) => {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!("exec: demux thread: {e}"),
            ));
        }
    };

    let stdin = ExecStdin::new(Arc::clone(&sock_w));
    Ok(ExecChild {
        sock_w,
        stdout_r: Some(stdout_r),
        stderr_r: Some(stderr_r),
        stdin: Some(stdin),
        demux: Some(demux),
        exit_rx,
    })
}
fn demux_loop(
mut sock: UnixStream,
stdout_w: OwnedFd,
stderr_w: OwnedFd,
exit_tx: Sender<DemuxExit>,
) {
let stdout_fd = stdout_w.as_raw_fd();
let stderr_fd = stderr_w.as_raw_fd();
loop {
let (kind, payload) = match read_frame(&mut sock) {
Ok(f) => f,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
let _ = exit_tx.send(DemuxExit::EofBeforeExit);
return;
}
Err(e) => {
let _ = exit_tx.send(DemuxExit::Io(e));
return;
}
};
match kind {
FRAME_STDOUT => {
let _ = pipe_write(stdout_fd, &payload);
}
FRAME_STDERR => {
let _ = pipe_write(stderr_fd, &payload);
}
FRAME_EXIT => {
if payload.len() >= 4 {
let code = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
let peak_rss_kib = if payload.len() >= 12 {
Some(u64::from_be_bytes([
payload[4], payload[5], payload[6], payload[7],
payload[8], payload[9], payload[10], payload[11],
]))
} else {
None
};
let _ = exit_tx.send(DemuxExit::Status { code, peak_rss_kib });
} else {
let _ = exit_tx.send(DemuxExit::Status { code: 0, peak_rss_kib: None });
}
return;
}
FRAME_ERROR => {
let msg = String::from_utf8_lossy(&payload).into_owned();
let _ = exit_tx.send(DemuxExit::Error(msg));
return;
}
_ => {
}
}
}
}
/// Writes one frame to the shared socket: the `kind` tag, the payload length
/// as a big-endian `u32`, then the payload itself.
///
/// The mutex is held for the whole frame so concurrent writers cannot
/// interleave their bytes.
///
/// # Errors
/// * `InvalidInput` when `payload` is too long for the 32-bit length field.
///   Nothing is written in that case; truncating the length would corrupt
///   the framing of everything sent afterwards.
/// * `Other` when the socket mutex is poisoned.
/// * Any error from writing to the socket.
fn send_frame(sock: &Arc<Mutex<UnixStream>>, kind: u8, payload: &[u8]) -> io::Result<()> {
    let len = u32::try_from(payload.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("exec: frame payload too large: {} bytes", payload.len()),
        )
    })?;

    let mut header = [0u8; 5];
    header[0] = kind;
    header[1..5].copy_from_slice(&len.to_be_bytes());

    let mut g = sock.lock().map_err(|_| {
        io::Error::new(io::ErrorKind::Other, "exec: socket mutex poisoned")
    })?;
    g.write_all(&header)?;
    if !payload.is_empty() {
        g.write_all(payload)?;
    }
    Ok(())
}
/// Reads one frame from `sock` and returns its tag and payload.
///
/// # Errors
/// Returns `InvalidData` when the announced length exceeds 16 MiB, and any
/// read error (including `UnexpectedEof` on a short stream) otherwise.
fn read_frame(sock: &mut UnixStream) -> io::Result<(u8, Vec<u8>)> {
    const MAX_PAYLOAD: usize = 16 * 1024 * 1024;

    let mut raw = [0u8; 5];
    sock.read_exact(&mut raw)?;
    let [kind, b0, b1, b2, b3] = raw;
    let announced = u32::from_be_bytes([b0, b1, b2, b3]) as usize;

    // Refuse before allocating anything for a length the peer controls.
    if announced > MAX_PAYLOAD {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("exec: frame len {} > 16 MiB", announced),
        ));
    }

    let mut payload = vec![0u8; announced];
    if announced != 0 {
        sock.read_exact(&mut payload)?;
    }
    Ok((kind, payload))
}
fn pipe() -> io::Result<(OwnedFd, OwnedFd)> {
let mut fds = [0 as RawFd; 2];
let r = unsafe { libc::pipe(fds.as_mut_ptr()) };
if r < 0 {
return Err(io::Error::last_os_error());
}
Ok((unsafe { OwnedFd::from_raw_fd(fds[0]) }, unsafe {
OwnedFd::from_raw_fd(fds[1])
}))
}
/// Writes all of `buf` to the raw descriptor `fd`, continuing after short
/// writes and retrying when the call is interrupted (`EINTR`).
///
/// # Errors
/// Returns the first OS error other than `EINTR`.
fn pipe_write(fd: RawFd, mut buf: &[u8]) -> io::Result<()> {
    loop {
        if buf.is_empty() {
            return Ok(());
        }
        // SAFETY: the pointer/length pair describes the live slice `buf`.
        let written = unsafe { libc::write(fd, buf.as_ptr().cast::<libc::c_void>(), buf.len()) };
        if written >= 0 {
            // Skip what the kernel accepted and go around for the rest.
            buf = &buf[written as usize..];
            continue;
        }
        let err = io::Error::last_os_error();
        if err.raw_os_error() != Some(libc::EINTR) {
            return Err(err);
        }
    }
}
/// Builds an `ExitStatus` describing a normal exit with the low 8 bits of
/// `code` as the exit code.
#[cfg(unix)]
fn synthesize_exit(code: u32) -> ExitStatus {
    use std::os::unix::process::ExitStatusExt;
    // In a raw Unix wait status the exit code sits in bits 8..16.
    let exit_code = (code & 0xff) as i32;
    ExitStatus::from_raw(exit_code << 8)
}