code_executor/
sandbox.rs

1use std::{
2    env,
3    ffi::{CString, c_int},
4    fs, iter,
5    path::Path,
6    time::{Duration, Instant},
7};
8
9use libseccomp::{ScmpAction, ScmpFilterContext, ScmpSyscall};
10use nix::{
11    libc::{self, WEXITSTATUS, WSTOPPED, WTERMSIG, wait4},
12    sys::{
13        resource::{Resource, setrlimit},
14        signal::{SaFlags, SigAction, SigHandler, SigSet, Signal, sigaction},
15    },
16    unistd::{ForkResult, alarm, dup2_stderr, dup2_stdin, dup2_stdout, execvp, fork},
17};
18use state_shift::{impl_state, type_state};
19
20use crate::{
21    CommandArgs, Error, Result,
22    metrics::{Metrics, Rusage, get_default_rusage},
23};
24
/// No-op handler that `Sandbox::spawn` installs for `SIGALRM` via `sigaction`.
///
/// Because the handler catches the signal and does nothing, `SIGALRM` is effectively ignored
/// in the process that installed it and in a forked child before it calls `execvp`.
/// NOTE(review): caught signals are reset to their default disposition across `execve`, so
/// once the sandboxed program is running, the child's alarm uses the default action.
extern "C" fn signal_handler(_signum: c_int) {}

/// One resource limit, passed to `setrlimit` by `SandboxConfig::apply`.
#[derive(Debug, Clone, Copy)]
pub struct RlimitConfig {
    /// Which resource to limit.
    pub resource: Resource,
    /// Soft limit; its unit depends on `resource` (see `setrlimit(2)`).
    pub soft_limit: u64,
    /// Hard limit; its unit depends on `resource` (see `setrlimit(2)`).
    pub hard_limit: u64,
}
32
/// Restrictions that `SandboxConfig::apply` installs on the calling process.
#[derive(Debug, Clone, Copy)]
pub struct SandboxConfig<'a> {
    /// Syscall names, as accepted by `ScmpSyscall::from_name`.
    /// Invoking any of them kills the process (`ScmpAction::KillProcess`).
    /// Every syscall not listed is allowed.
    pub scmp_black_list: &'a [&'a str],
    /// Resource limits, applied with `setrlimit` in slice order.
    pub rlimit_configs: &'a [RlimitConfig],
}
38
39impl SandboxConfig<'_> {
40    fn apply(&self) -> Result<()> {
41        for rlimit in self.rlimit_configs {
42            setrlimit(rlimit.resource, rlimit.soft_limit, rlimit.hard_limit)?;
43        }
44
45        let mut scmp_filter = ScmpFilterContext::new(ScmpAction::Allow)?;
46        for s in self.scmp_black_list {
47            let syscall = ScmpSyscall::from_name(s)?;
48            scmp_filter.add_rule_exact(ScmpAction::KillProcess, syscall)?;
49        }
50
51        scmp_filter.load()?;
52
53        Ok(())
54    }
55}
56
/// Runs one command in a forked child process with rlimits, a seccomp blacklist and a
/// time alarm applied.
///
/// The `type_state` attribute (from `state_shift`) gives the type two states:
/// - `Initial`: built by `new` and consumed by `spawn`.
/// - `Running`: returned by `spawn` and consumed by `wait`.
#[type_state(
    states = (Initial, Running),
    slots = (Initial)
)]
#[derive(Debug)]
pub struct Sandbox<'a> {
    /// Rlimits and syscall blacklist applied in the child before `execvp`.
    config: SandboxConfig<'a>,
    /// Directory the child changes into before running the command.
    project_path: &'a Path,
    /// Binary and arguments handed to `execvp`.
    args: CommandArgs<'a>,
    /// File opened read-only and installed as the child's stdin.
    stdin: &'a Path,
    /// File (created if missing, truncated) installed as the child's stdout.
    /// `wait` reads it back.
    stdout: &'a Path,
    /// File (created if missing, truncated) installed as the child's stderr.
    /// `wait` treats any content in it as a runtime error.
    stderr: &'a Path,
    /// Wall-clock budget, armed in the child as an `alarm` in whole seconds.
    time_limit: Duration,

    /// Raw pid of the forked child; `-1` until `spawn` succeeds.
    child_pid: i32,
    /// Instant taken just before `fork` in `spawn`.
    /// `wait` reports the elapsed time from it.
    start: Instant,
}
74
75#[impl_state]
76impl<'a> Sandbox<'a> {
77    #[require(Initial)]
78    pub fn new(
79        config: SandboxConfig<'a>,
80        project_path: &'a Path,
81        args: CommandArgs<'a>,
82        stdin: &'a Path,
83        stdout: &'a Path,
84        stderr: &'a Path,
85        time_limit: Duration,
86    ) -> Self {
87        Self {
88            config,
89            project_path,
90            args,
91            stdin,
92            stdout,
93            stderr,
94            time_limit,
95            child_pid: -1,
96            start: Instant::now(),
97            _state: (::core::marker::PhantomData),
98        }
99    }
100
101    #[require(Initial)]
102    fn load_io(&self) -> Result<()> {
103        let stdin = fs::OpenOptions::new().read(true).open(self.stdin)?;
104        dup2_stdin(stdin)?;
105
106        let stdout = fs::OpenOptions::new()
107            .create(true)
108            .truncate(true)
109            .write(true)
110            .open(self.stdout)?;
111        dup2_stdout(stdout)?;
112
113        let stderr = fs::OpenOptions::new()
114            .create(true)
115            .truncate(true)
116            .write(true)
117            .open(self.stderr)?;
118        dup2_stderr(stderr)?;
119
120        Ok(())
121    }
122
123    #[require(Initial)]
124    #[switch_to(Running)]
125    pub fn spawn(self) -> Result<Sandbox<'a, Running>> {
126        unsafe {
127            sigaction(
128                Signal::SIGALRM,
129                &SigAction::new(
130                    SigHandler::Handler(signal_handler),
131                    SaFlags::empty(),
132                    SigSet::empty(),
133                ),
134            )
135            .unwrap();
136        }
137
138        let start = Instant::now();
139        match unsafe { fork() } {
140            Ok(ForkResult::Parent { child, .. }) => Ok(Sandbox {
141                config: self.config,
142                project_path: self.project_path,
143                args: self.args,
144                stdin: self.stdin,
145                stdout: self.stdout,
146                stderr: self.stderr,
147                time_limit: self.time_limit,
148                child_pid: child.as_raw(),
149                start,
150                _state: (::core::marker::PhantomData),
151            }),
152            // child process should not return to do things outside `spawn()`
153            Ok(ForkResult::Child) => {
154                if env::set_current_dir(self.project_path).is_err() {
155                    eprintln!("Failed to load change to project directory");
156                    unsafe { libc::_exit(100) };
157                }
158
159                if self.load_io().is_err() {
160                    eprintln!("Failed to load I/O");
161                    unsafe { libc::_exit(1) };
162                }
163
164                if self.config.apply().is_err() {
165                    eprintln!("Failed to load config");
166                    unsafe { libc::_exit(1) };
167                }
168
169                alarm::set(self.time_limit.as_secs() as u32);
170
171                let CommandArgs { binary, args } = self.args;
172                let args: Vec<_> = iter::once(binary)
173                    .chain(args.iter().copied())
174                    .map(|arg| CString::new(arg.as_bytes()).unwrap())
175                    .collect();
176                let binary = CString::new(binary.as_bytes()).unwrap();
177
178                let error = execvp(&binary, &args).unwrap_err();
179                eprintln!("{}", error);
180
181                unsafe { libc::_exit(0) };
182            }
183            Err(e) => Err(e.into()),
184        }
185    }
186
187    #[require(Running)]
188    pub fn wait(self) -> Result<Metrics> {
189        let mut status: c_int = 0;
190        let mut usage = get_default_rusage();
191        unsafe {
192            wait4(self.child_pid, &mut status, WSTOPPED, &mut usage);
193        }
194
195        let error = fs::read_to_string(self.stderr)?;
196        if !error.is_empty() {
197            return Err(Error::Runtime { message: error });
198        }
199
200        let output = fs::read_to_string(self.stdout)?.trim().to_string();
201        Ok(Metrics {
202            exit_status: status,
203            exit_signal: WTERMSIG(status),
204            exit_code: WEXITSTATUS(status),
205            real_time_cost: self.start.elapsed(),
206            resource_usage: Rusage::from(usage),
207            output,
208        })
209    }
210}