// code_executor/runner.rs

1use std::{io, path::Path, process::Stdio, time::Duration};
2
3use cgroups_rs::{CgroupPid, fs::Cgroup};
4use nix::{
5    sys::signal::{self, Signal},
6    unistd::Pid,
7};
8use tokio::{
9    io::{AsyncReadExt, AsyncWriteExt},
10    process::Command,
11    task::JoinHandle,
12    time::{Instant, sleep},
13};
14
15use crate::{CommandArgs, ExitStatus, ResourceConfig, Result, cgroup, metrics::Metrics};
16
17#[derive(Debug)]
18pub struct Runner<'a> {
19    pub args: CommandArgs<'a>,
20    pub project_path: &'a Path,
21    pub time_limit: Duration,
22    pub resource_config: ResourceConfig,
23    pub cgroup: Cgroup,
24}
25
26impl<'a> Runner<'a> {
27    #[tracing::instrument(err)]
28    pub fn new(
29        args: CommandArgs<'a>,
30        project_path: &'a Path,
31        time_limit: Duration,
32        resource_config: ResourceConfig,
33    ) -> Result<Self> {
34        let cgroup = resource_config.try_into()?;
35
36        Ok(Self {
37            args,
38            project_path,
39            time_limit,
40            resource_config,
41            cgroup,
42        })
43    }
44
45    fn new_cpu_limiter(&self, id: i32) -> JoinHandle<()> {
46        let cg = self.cgroup.clone();
47        let quota = self.resource_config.quota;
48        let period = Duration::from_millis(self.resource_config.period);
49
50        tokio::spawn(async move {
51            let Some(mut prev_usage) = cgroup::get_cpu_usage(&cg) else {
52                return;
53            };
54
55            loop {
56                sleep(period).await;
57
58                let Some(cur_usage) = cgroup::get_cpu_usage(&cg) else {
59                    break;
60                };
61                if cur_usage - prev_usage <= quota {
62                    prev_usage = cur_usage;
63                    continue;
64                }
65
66                signal::kill(Pid::from_raw(id), Signal::SIGSTOP).unwrap();
67                sleep(period).await;
68                signal::kill(Pid::from_raw(id), Signal::SIGCONT).unwrap();
69
70                match cgroup::get_cpu_usage(&cg) {
71                    Some(v) => prev_usage = v,
72                    None => break,
73                }
74            }
75        })
76    }
77
78    #[tracing::instrument(err)]
79    pub async fn run(&self, input: &[u8]) -> Result<Metrics> {
80        let CommandArgs { binary, args } = self.args;
81
82        let start = Instant::now();
83
84        let mut child = Command::new(binary)
85            .current_dir(self.project_path)
86            .args(args)
87            .stdin(Stdio::piped())
88            .stdout(Stdio::piped())
89            .stderr(Stdio::piped())
90            .spawn()?;
91
92        let id = child.id().unwrap();
93        self.cgroup.add_task_by_tgid(CgroupPid::from(id as u64))?;
94        let cpu_limiter = self.new_cpu_limiter(id as i32);
95
96        let mut stdin = child.stdin.take().unwrap();
97        let mut stdout = child.stdout.take().unwrap();
98        let mut stderr = child.stderr.take().unwrap();
99
100        let stdout_observer = async move {
101            let mut buffer = Vec::new();
102            stdout.read_to_end(&mut buffer).await?;
103
104            Ok::<_, io::Error>(buffer)
105        };
106        let stderr_observer = async move {
107            let mut buffer = Vec::new();
108            stderr.read_to_end(&mut buffer).await?;
109            Ok::<_, io::Error>(buffer)
110        };
111
112        let exit_status = tokio::select! {
113            exit_status = async {
114                stdin.write_all(input).await?;
115                drop(stdin);
116                let exit_status = child.wait().await?;
117
118                Ok::<_, io::Error>(exit_status)
119            } => {
120                exit_status.map(|raw| raw.into())
121            }
122            _ = sleep(self.time_limit) => {
123                child.kill().await?;
124                child.wait().await?;
125
126                Ok(ExitStatus::TimeLimitExceeded)
127            }
128        }?;
129
130        cpu_limiter.abort();
131
132        let (stdout, stderr) = tokio::try_join!(stdout_observer, stderr_observer)?;
133
134        Ok(Metrics {
135            exit_status,
136            stdout,
137            stderr,
138            run_time: start.elapsed(),
139        })
140    }
141}
142
143impl<'a> Drop for Runner<'a> {
144    fn drop(&mut self) {
145        let _ = self.cgroup.delete();
146    }
147}
148
#[cfg(test)]
mod test {
    use std::{
        path::{Path, PathBuf},
        time::Duration,
    };

    use bstr::ByteSlice;
    use byte_unit::Byte;
    use rstest::rstest;

    use crate::{
        CPP, ExitStatus, JAVA, Language, PYTHON, RUST, ResourceConfig, Runner,
        test::{read_code, read_test_cases},
    };

    /// Every test case of every problem must produce the expected stdout,
    /// ignoring surrounding whitespace, in every supported language.
    #[rstest]
    #[tokio::test]
    async fn should_output_correct(
        #[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>,

        #[dirs]
        #[files("tests/data/problem/*")]
        problem_path: PathBuf,
    ) {
        let test_cases = read_test_cases(&problem_path);

        let code = read_code(language, &problem_path);
        let project_path = language.compiler.compile(&code).await.unwrap();

        let runner = Runner::new(
            language.runner_args,
            &project_path,
            Duration::from_secs(2),
            ResourceConfig::builder()
                .memory_limit(Byte::GIBIBYTE)
                .build(),
        )
        .unwrap();
        for (case, (input, output)) in test_cases.into_iter().enumerate() {
            let metrics = runner.run(&input).await.unwrap();
            // Compare as `BStr` so a mismatch prints as text rather than raw
            // bytes, and report which case of which problem failed.
            assert_eq!(
                metrics.stdout.trim().as_bstr(),
                output.trim().as_bstr(),
                "wrong output for test case #{case} in {}",
                problem_path.display(),
            );
        }
    }

    /// A program that never finishes must be reported as exceeding the time limit.
    #[rstest]
    #[tokio::test]
    async fn should_timeout(#[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>) {
        let code = read_code(language, Path::new("tests/data/timeout"));
        let project_path = language.compiler.compile(code.as_bytes()).await.unwrap();

        let runner = Runner::new(
            language.runner_args,
            &project_path,
            Duration::from_secs(2),
            ResourceConfig::builder()
                .memory_limit(Byte::GIBIBYTE)
                .build(),
        )
        .unwrap();

        let metrics = runner.run(b"").await.unwrap();

        assert_eq!(
            metrics.exit_status,
            ExitStatus::TimeLimitExceeded,
            "expected a timeout; stderr was: {}",
            metrics.stderr.as_bstr(),
        );
    }
}