use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::time::Duration;
use crate::error::{SandboxError, SandlockError};
use crate::policy::Policy;
use crate::result::{ExitStatus, RunResult};
use crate::sandbox::Sandbox;
/// A single command together with the sandbox policy it runs under.
pub struct Stage {
    /// Policy used to build this stage's sandbox.
    pub policy: Policy,
    /// Command line: the program followed by its arguments.
    pub args: Vec<String>,
}

impl Stage {
    /// Creates a stage by cloning `policy` and copying `args` into owned strings.
    pub fn new(policy: &Policy, args: &[&str]) -> Self {
        let policy = policy.clone();
        let args = args.iter().map(|&arg| arg.to_owned()).collect();
        Self { policy, args }
    }

    /// Runs this stage by itself through `Sandbox::run_interactive`.
    ///
    /// When `timeout` is `Some`, a run that does not finish in time is reported as
    /// `Ok` with `ExitStatus::Timeout` and no captured output, rather than as an error.
    pub async fn run(self, timeout: Option<Duration>) -> Result<RunResult, SandlockError> {
        let argv: Vec<&str> = self.args.iter().map(String::as_str).collect();
        let execution = Sandbox::run_interactive(&self.policy, &argv);
        match timeout {
            None => execution.await,
            Some(limit) => match tokio::time::timeout(limit, execution).await {
                Ok(outcome) => outcome,
                Err(_elapsed) => Ok(RunResult {
                    exit_status: ExitStatus::Timeout,
                    stdout: None,
                    stderr: None,
                }),
            },
        }
    }
}
/// `first | second` starts a two-stage [`Pipeline`] in which `first` feeds `second`.
impl std::ops::BitOr<Stage> for Stage {
    type Output = Pipeline;

    fn bitor(self, rhs: Stage) -> Pipeline {
        let mut stages = Vec::with_capacity(2);
        stages.push(self);
        stages.push(rhs);
        Pipeline { stages }
    }
}
/// A linear chain of stages in which each stage's stdout feeds the next stage's stdin.
pub struct Pipeline {
    /// Stages in execution order; [`Pipeline::new`] insists on at least two.
    pub stages: Vec<Stage>,
}

impl Pipeline {
    /// Builds a pipeline from `stages`.
    ///
    /// # Errors
    ///
    /// Returns `SandlockError::Sandbox(SandboxError::Child(..))` when fewer than two
    /// stages are supplied.
    pub fn new(stages: Vec<Stage>) -> Result<Self, SandlockError> {
        match stages.len() {
            0 | 1 => Err(SandlockError::Sandbox(SandboxError::Child(
                "Pipeline requires at least 2 stages".into(),
            ))),
            _ => Ok(Self { stages }),
        }
    }

    /// Runs every stage and returns the last stage's exit status and captured output.
    ///
    /// When `timeout` is `Some` and elapses first, the result is `Ok` with
    /// `ExitStatus::Timeout` and no captured output.
    pub async fn run(self, timeout: Option<Duration>) -> Result<RunResult, SandlockError> {
        let execution = run_pipeline(self.stages);
        match timeout {
            None => execution.await,
            Some(limit) => tokio::time::timeout(limit, execution)
                .await
                .unwrap_or_else(|_elapsed| {
                    Ok(RunResult {
                        exit_status: ExitStatus::Timeout,
                        stdout: None,
                        stderr: None,
                    })
                }),
        }
    }
}
impl std::ops::BitOr<Stage> for Pipeline {
type Output = Pipeline;
fn bitor(mut self, rhs: Stage) -> Pipeline {
self.stages.push(rhs);
self
}
}
/// Creates an anonymous pipe with `O_CLOEXEC` set on both ends.
///
/// Returns `(read_end, write_end)` as owned descriptors, so each end is closed
/// automatically when dropped.
///
/// # Errors
///
/// Returns the OS error reported by `pipe2(2)` if the pipe cannot be created.
fn make_pipe() -> std::io::Result<(OwnedFd, OwnedFd)> {
    let mut ends: [libc::c_int; 2] = [0; 2];
    // SAFETY: `ends` is a writable array of exactly two `c_int`s, as pipe2 requires.
    let rc = unsafe { libc::pipe2(ends.as_mut_ptr(), libc::O_CLOEXEC) };
    if rc < 0 {
        return Err(std::io::Error::last_os_error());
    }
    let [read_raw, write_raw] = ends;
    // SAFETY: pipe2 succeeded, so both descriptors are open and owned by nobody else.
    let pair = unsafe { (OwnedFd::from_raw_fd(read_raw), OwnedFd::from_raw_fd(write_raw)) };
    Ok(pair)
}
async fn run_pipeline(stages: Vec<Stage>) -> Result<RunResult, SandlockError> {
let n = stages.len();
let mut inter_pipes: Vec<(OwnedFd, OwnedFd)> = Vec::with_capacity(n - 1);
for _ in 0..n - 1 {
inter_pipes.push(make_pipe().map_err(SandboxError::Io)?);
}
let (cap_stdout_r, cap_stdout_w) = make_pipe().map_err(SandboxError::Io)?;
let (cap_stderr_r, cap_stderr_w) = make_pipe().map_err(SandboxError::Io)?;
let mut sandboxes: Vec<Sandbox> = Vec::with_capacity(n);
for (i, stage) in stages.into_iter().enumerate() {
let mut sb = Sandbox::new(&stage.policy)?;
let stdin_fd: Option<RawFd> = if i == 0 {
None } else {
Some(inter_pipes[i - 1].0.as_raw_fd()) };
let stdout_fd: Option<RawFd> = if i == n - 1 {
Some(cap_stdout_w.as_raw_fd()) } else {
Some(inter_pipes[i].1.as_raw_fd()) };
let stderr_fd: Option<RawFd> = if i == n - 1 {
Some(cap_stderr_w.as_raw_fd()) } else {
None };
let cmd_refs: Vec<&str> = stage.args.iter().map(|s| s.as_str()).collect();
sb.spawn_with_io(&cmd_refs, stdin_fd, stdout_fd, stderr_fd)
.await?;
sandboxes.push(sb);
}
drop(inter_pipes);
drop(cap_stdout_w);
drop(cap_stderr_w);
let mut last_result = RunResult {
exit_status: ExitStatus::Killed,
stdout: None,
stderr: None,
};
for (i, mut sb) in sandboxes.into_iter().enumerate() {
let result = sb.wait().await?;
if i == n - 1 {
last_result.exit_status = result.exit_status;
}
}
last_result.stdout = Some(read_fd_to_end(cap_stdout_r));
last_result.stderr = Some(read_fd_to_end(cap_stderr_r));
Ok(last_result)
}
/// Reads everything from `fd` until EOF and returns the bytes, consuming (and
/// finally closing) the descriptor.
///
/// Best-effort: a read error simply ends collection, and any bytes read before the
/// error are still returned.
fn read_fd_to_end(fd: OwnedFd) -> Vec<u8> {
    use std::io::Read;

    let raw = fd.into_raw_fd();
    // SAFETY: `raw` was just released from an `OwnedFd`, so it is open and has no
    // other owner; the `File` takes over responsibility for closing it.
    let mut source = unsafe { std::fs::File::from_raw_fd(raw) };
    let mut collected = Vec::new();
    let _ = source.read_to_end(&mut collected);
    collected
}
/// A gather source: a stage plus the name under which the consumer sees its output.
///
/// `run_gather` advertises each source to the consumer as `name:fd`, with entries
/// joined by `,` in the `_SANDLOCK_GATHER` environment variable. A name containing
/// `,` or `:` would therefore presumably confuse the consumer's parser.
/// NOTE(review): confirm how the consumer side parses this.
pub struct NamedStage {
/// Label reported for this source in `_SANDLOCK_GATHER`.
pub name: String,
/// The command whose stdout is delivered to the consumer.
pub stage: Stage,
}
/// Fan-in combinator: several named source stages feeding a single consumer stage.
///
/// Built with [`Gather::new`], [`Gather::source`] and [`Gather::consumer`], then
/// executed with [`Gather::run`]. When run, the most recently added source feeds the
/// consumer's stdin, and earlier sources are paired with extra descriptors 3, 4, ….
/// That mapping is advertised to the consumer through the `_SANDLOCK_GATHER`
/// environment variable.
pub struct Gather {
    sources: Vec<NamedStage>,
    consumer: Option<Stage>,
}

impl Gather {
    /// Creates an empty gather with no sources and no consumer.
    pub fn new() -> Self {
        Self {
            sources: Vec::new(),
            consumer: None,
        }
    }

    /// Adds a source stage whose stdout is delivered to the consumer under `name`.
    pub fn source(mut self, name: &str, stage: Stage) -> Self {
        self.sources.push(NamedStage {
            name: name.to_string(),
            stage,
        });
        self
    }

    /// Sets the consumer stage, replacing any previously set consumer.
    pub fn consumer(mut self, stage: Stage) -> Self {
        self.consumer = Some(stage);
        self
    }

    /// Runs all sources and the consumer, returning the consumer's exit status and
    /// captured stdout/stderr.
    ///
    /// When `timeout` is `Some` and elapses first, the result is `Ok` with
    /// `ExitStatus::Timeout` and no captured output.
    ///
    /// # Errors
    ///
    /// Returns `SandboxError::Child` (wrapped in `SandlockError::Sandbox`) when no
    /// consumer was set or no source was added. Also propagates any error from
    /// setting up, spawning, or waiting on the sandboxes.
    pub async fn run(self, timeout: Option<Duration>) -> Result<RunResult, SandlockError> {
        let consumer = self.consumer.ok_or_else(|| {
            SandlockError::Sandbox(SandboxError::Child("Gather requires a consumer".into()))
        })?;
        if self.sources.is_empty() {
            return Err(SandlockError::Sandbox(SandboxError::Child(
                "Gather requires at least one source".into(),
            )));
        }
        if let Some(dur) = timeout {
            match tokio::time::timeout(dur, run_gather(self.sources, consumer)).await {
                Ok(result) => result,
                Err(_) => Ok(RunResult {
                    exit_status: ExitStatus::Timeout,
                    stdout: None,
                    stderr: None,
                }),
            }
        } else {
            run_gather(self.sources, consumer).await
        }
    }
}

/// Equivalent to [`Gather::new`]. A public type with an argument-less `new` is
/// expected to implement `Default` (clippy `new_without_default`).
impl Default for Gather {
    fn default() -> Self {
        Self::new()
    }
}
async fn run_gather(
sources: Vec<NamedStage>,
consumer: Stage,
) -> Result<RunResult, SandlockError> {
let n = sources.len();
let mut source_pipes: Vec<(OwnedFd, OwnedFd)> = Vec::with_capacity(n);
for _ in 0..n {
source_pipes.push(make_pipe().map_err(SandboxError::Io)?);
}
let mut fd_assignments: Vec<(String, i32)> = Vec::with_capacity(n);
let mut next_fd = 3i32;
for (i, ns) in sources.iter().enumerate() {
let target_fd = if i == n - 1 { 0 } else { let fd = next_fd; next_fd += 1; fd };
fd_assignments.push((ns.name.clone(), target_fd));
}
let gather_env: String = fd_assignments
.iter()
.map(|(name, fd)| format!("{}:{}", name, fd))
.collect::<Vec<_>>()
.join(",");
let (cap_stdout_r, cap_stdout_w) = make_pipe().map_err(SandboxError::Io)?;
let (cap_stderr_r, cap_stderr_w) = make_pipe().map_err(SandboxError::Io)?;
let mut sandboxes: Vec<Sandbox> = Vec::with_capacity(n + 1);
for (i, ns) in sources.into_iter().enumerate() {
let mut sb = Sandbox::new(&ns.stage.policy)?;
let stdout_fd = source_pipes[i].1.as_raw_fd();
let cmd_refs: Vec<&str> = ns.stage.args.iter().map(|s| s.as_str()).collect();
sb.spawn_with_io(&cmd_refs, None, Some(stdout_fd), None).await?;
sandboxes.push(sb);
}
let mut consumer_policy = consumer.policy.clone();
consumer_policy.env.insert("_SANDLOCK_GATHER".to_string(), gather_env);
let mut consumer_sb = Sandbox::new(&consumer_policy)?;
let stdin_fd = source_pipes[n - 1].0.as_raw_fd();
let mut extra_fds = Vec::new();
for (i, (_, target_fd)) in fd_assignments.iter().enumerate() {
if i < n - 1 {
let read_fd = source_pipes[i].0.as_raw_fd();
unsafe {
let flags = libc::fcntl(read_fd, libc::F_GETFD);
libc::fcntl(read_fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC);
}
extra_fds.push((*target_fd, read_fd));
}
}
let cmd_refs: Vec<&str> = consumer.args.iter().map(|s| s.as_str()).collect();
consumer_sb.spawn_with_gather_io(
&cmd_refs,
Some(stdin_fd),
Some(cap_stdout_w.as_raw_fd()),
Some(cap_stderr_w.as_raw_fd()),
extra_fds,
).await?;
sandboxes.push(consumer_sb);
drop(source_pipes);
drop(cap_stdout_w);
drop(cap_stderr_w);
let total = sandboxes.len();
let mut last_result = RunResult {
exit_status: ExitStatus::Killed,
stdout: None,
stderr: None,
};
for (i, mut sb) in sandboxes.into_iter().enumerate() {
let result = sb.wait().await?;
if i == total - 1 {
last_result.exit_status = result.exit_status;
}
}
last_result.stdout = Some(read_fd_to_end(cap_stdout_r));
last_result.stderr = Some(read_fd_to_end(cap_stderr_r));
Ok(last_result)
}