use crate::fs::FsError;
use crate::ipc::{Request, Response, read_message, write_message};
use crate::policy::SandboxPolicy;
use crate::proxy::{ProxyHandle, ca_bundle_for_policy, proxy_env_vars};
use anyhow::{Context, Result, bail};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncBufReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::UnixStream;
use tokio::process::{Child, Command};
use tracing::debug;
/// Process-wide counter that makes each path returned by
/// [`unique_socket_path`] distinct within this process.
static SLOT_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Build a fresh Unix-socket path inside the system temp directory.
///
/// The file name has the form `koda-fs-worker-<pid>-<n>.sock`. `<pid>` is this
/// process's id, and `<n>` is the next value of [`SLOT_COUNTER`], so repeated
/// calls in one process never return the same path. Only the path is computed;
/// nothing is created on disk.
fn unique_socket_path() -> PathBuf {
    let pid = std::process::id();
    // Uniqueness only needs the atomicity of fetch_add, not any ordering
    // with other memory operations.
    let n = SLOT_COUNTER.fetch_add(1, Ordering::Relaxed);
    let file_name = format!("koda-fs-worker-{pid}-{n}.sock");
    std::env::temp_dir().join(file_name)
}
fn worker_binary() -> Result<PathBuf> {
if let Ok(p) = std::env::var("KODA_FS_WORKER_BIN") {
return Ok(PathBuf::from(p));
}
if let Ok(p) = std::env::var("CARGO_BIN_EXE_koda-fs-worker") {
return Ok(PathBuf::from(p));
}
let mut p = std::env::current_exe().context("can't locate koda executable")?;
p.set_file_name("koda-fs-worker");
if p.exists() {
return Ok(p);
}
bail!(
"koda-fs-worker not found next to {}; set KODA_FS_WORKER_BIN to override",
p.display()
)
}
/// Client for a spawned `koda-fs-worker` child process.
///
/// Holds the child process and a connected Unix socket. The socket is split
/// into a buffered read half and a write half so requests and responses can
/// be exchanged with the worker.
pub struct WorkerClient {
    /// The spawned worker process.
    child: Child,
    /// Path of the Unix socket passed to the worker via `--socket` and
    /// connected to by this client.
    socket_path: PathBuf,
    /// Buffered read half of the socket; responses are read from here.
    reader: BufReader<ReadHalf<UnixStream>>,
    /// Write half of the socket; requests are written here.
    writer: WriteHalf<UnixStream>,
}
impl WorkerClient {
pub async fn spawn() -> Result<Self> {
Self::spawn_inner(None, None).await
}
pub async fn spawn_with_policy(writable_root: PathBuf, policy: &SandboxPolicy) -> Result<Self> {
Self::spawn_inner(Some((writable_root, policy)), None).await
}
pub async fn spawn_with_policy_and_proxy(
writable_root: PathBuf,
policy: &SandboxPolicy,
proxy: Option<&ProxyHandle>,
) -> Result<Self> {
let env = proxy.map(|p| {
let ca = p.ca_bundle().or_else(|| ca_bundle_for_policy(&policy.net));
proxy_env_vars(p.port, ca)
});
Self::spawn_inner(Some((writable_root, policy)), env).await
}
async fn spawn_inner(
policy_args: Option<(PathBuf, &SandboxPolicy)>,
extra_env: Option<Vec<(String, String)>>,
) -> Result<Self> {
let socket_path = unique_socket_path();
let bin = worker_binary()?;
let mut cmd = Command::new(&bin);
cmd.arg("--socket")
.arg(&socket_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit());
if let Some((ref root, policy)) = policy_args {
cmd.arg("--root").arg(root);
let policy_json =
serde_json::to_string(policy).context("serialize SandboxPolicy for worker env")?;
cmd.env("KODA_FS_WORKER_POLICY", policy_json);
}
if let Some(env) = extra_env {
cmd.envs(env);
}
let mut child = cmd
.spawn()
.with_context(|| format!("spawn {}", bin.display()))?;
let stdout = child.stdout.take().expect("stdout piped");
let mut lines = BufReader::new(stdout).lines();
tokio::time::timeout(std::time::Duration::from_secs(5), async {
while let Some(line) = lines.next_line().await? {
if line.trim() == "ready" {
return Ok::<_, std::io::Error>(());
}
}
Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"worker exited before signalling ready",
))
})
.await
.context("worker readiness timeout (5 s)")?
.context("reading worker stdout")?;
drop(lines);
let stream = tokio::time::timeout(
std::time::Duration::from_secs(2),
UnixStream::connect(&socket_path),
)
.await
.context("Unix socket connect timeout (2 s)")?
.context("UnixStream::connect")?;
let (r, writer) = tokio::io::split(stream);
let reader = BufReader::new(r);
debug!("worker_client: connected to {}", socket_path.display());
Ok(Self {
child,
socket_path,
reader,
writer,
})
}
pub fn socket_path(&self) -> &std::path::Path {
&self.socket_path
}
pub async fn request(&mut self, req: &Request) -> Result<Response, FsError> {
write_message(&mut self.writer, req)
.await
.map_err(|e| FsError::Transport {
message: format!("write: {e}"),
})?;
let resp: Response = read_message(&mut self.reader)
.await
.map_err(|e| FsError::Transport {
message: format!("read: {e}"),
})?
.ok_or_else(|| FsError::Transport {
message: "worker closed connection unexpectedly".into(),
})?;
Ok(resp)
}
}
impl Drop for WorkerClient {
    /// Best-effort teardown of the worker.
    ///
    /// Sends the child process a kill signal without waiting for it to exit,
    /// then unlinks the socket file. Failures are discarded because `drop`
    /// has no way to report them.
    fn drop(&mut self) {
        self.child.start_kill().ok();
        std::fs::remove_file(&self.socket_path).ok();
    }
}