1use std::{io, path::Path, process::Stdio, time::Duration};
2
3use cgroups_rs::{CgroupPid, fs::Cgroup};
4use nix::{
5 sys::signal::{self, Signal},
6 unistd::Pid,
7};
8use tokio::{
9 io::{AsyncReadExt, AsyncWriteExt},
10 process::Command,
11 task::JoinHandle,
12 time::{Instant, sleep},
13};
14
15use crate::{CommandArgs, ExitStatus, ResourceConfig, Result, cgroup, metrics::Metrics};
16
/// Executes a prepared program under a wall-clock time limit and cgroup resource limits.
///
/// Each call to [`Runner::run`] spawns `args` inside `project_path`, attaches the new
/// process to `cgroup`, feeds it stdin and collects its output. Deletion of the cgroup
/// is attempted when the runner is dropped.
#[derive(Debug)]
pub struct Runner<'a> {
    /// Program and arguments spawned on every run.
    pub args: CommandArgs<'a>,
    /// Working directory the program is started in.
    pub project_path: &'a Path,
    /// Wall-clock limit for a single run; the process is killed once it elapses.
    pub time_limit: Duration,
    /// Limits the `cgroup` is built from. `quota` and `period` (milliseconds) also drive
    /// the signal-based CPU limiter.
    pub resource_config: ResourceConfig,
    /// Cgroup every spawned process is attached to.
    pub cgroup: Cgroup,
}
25
26impl<'a> Runner<'a> {
27 #[tracing::instrument(err)]
28 pub fn new(
29 args: CommandArgs<'a>,
30 project_path: &'a Path,
31 time_limit: Duration,
32 resource_config: ResourceConfig,
33 ) -> Result<Self> {
34 let cgroup = resource_config.try_into()?;
35
36 Ok(Self {
37 args,
38 project_path,
39 time_limit,
40 resource_config,
41 cgroup,
42 })
43 }
44
45 fn new_cpu_limiter(&self, id: i32) -> JoinHandle<()> {
46 let cg = self.cgroup.clone();
47 let quota = self.resource_config.quota;
48 let period = Duration::from_millis(self.resource_config.period);
49
50 tokio::spawn(async move {
51 let Some(mut prev_usage) = cgroup::get_cpu_usage(&cg) else {
52 return;
53 };
54
55 loop {
56 sleep(period).await;
57
58 let Some(cur_usage) = cgroup::get_cpu_usage(&cg) else {
59 break;
60 };
61 if cur_usage - prev_usage <= quota {
62 prev_usage = cur_usage;
63 continue;
64 }
65
66 signal::kill(Pid::from_raw(id), Signal::SIGSTOP).unwrap();
67 sleep(period).await;
68 signal::kill(Pid::from_raw(id), Signal::SIGCONT).unwrap();
69
70 match cgroup::get_cpu_usage(&cg) {
71 Some(v) => prev_usage = v,
72 None => break,
73 }
74 }
75 })
76 }
77
78 #[tracing::instrument(err)]
79 pub async fn run(&self, input: &[u8]) -> Result<Metrics> {
80 let CommandArgs { binary, args } = self.args;
81
82 let start = Instant::now();
83
84 let mut child = Command::new(binary)
85 .current_dir(self.project_path)
86 .args(args)
87 .stdin(Stdio::piped())
88 .stdout(Stdio::piped())
89 .stderr(Stdio::piped())
90 .spawn()?;
91
92 let id = child.id().unwrap();
93 self.cgroup.add_task_by_tgid(CgroupPid::from(id as u64))?;
94 let cpu_limiter = self.new_cpu_limiter(id as i32);
95
96 let mut stdin = child.stdin.take().unwrap();
97 let mut stdout = child.stdout.take().unwrap();
98 let mut stderr = child.stderr.take().unwrap();
99
100 let stdout_observer = async move {
101 let mut buffer = Vec::new();
102 stdout.read_to_end(&mut buffer).await?;
103
104 Ok::<_, io::Error>(buffer)
105 };
106 let stderr_observer = async move {
107 let mut buffer = Vec::new();
108 stderr.read_to_end(&mut buffer).await?;
109 Ok::<_, io::Error>(buffer)
110 };
111
112 let exit_status = tokio::select! {
113 exit_status = async {
114 stdin.write_all(input).await?;
115 drop(stdin);
116 let exit_status = child.wait().await?;
117
118 Ok::<_, io::Error>(exit_status)
119 } => {
120 exit_status.map(|raw| raw.into())
121 }
122 _ = sleep(self.time_limit) => {
123 child.kill().await?;
124 child.wait().await?;
125
126 Ok(ExitStatus::TimeLimitExceeded)
127 }
128 }?;
129
130 cpu_limiter.abort();
131
132 let (stdout, stderr) = tokio::try_join!(stdout_observer, stderr_observer)?;
133
134 Ok(Metrics {
135 exit_status,
136 stdout,
137 stderr,
138 run_time: start.elapsed(),
139 })
140 }
141}
142
143impl<'a> Drop for Runner<'a> {
144 fn drop(&mut self) {
145 let _ = self.cgroup.delete();
146 }
147}
148
#[cfg(test)]
mod test {
    use std::{
        path::{Path, PathBuf},
        time::Duration,
    };

    use bstr::ByteSlice;
    use byte_unit::Byte;
    use rstest::rstest;

    use crate::{
        CPP, ExitStatus, JAVA, Language, PYTHON, RUST, ResourceConfig, Runner,
        test::{read_code, read_test_cases},
    };

    /// Wall-clock limit applied to every run in these tests.
    const TIME_LIMIT: Duration = Duration::from_secs(2);

    /// Resource configuration shared by the tests; only the memory limit (1 GiB) is set
    /// explicitly on the builder.
    fn one_gib_config() -> ResourceConfig {
        ResourceConfig::builder()
            .memory_limit(Byte::GIBIBYTE)
            .build()
    }

    /// Every supported language must print the expected answer for each test case of
    /// each problem directory.
    #[rstest]
    #[tokio::test]
    async fn should_output_correct(
        #[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>,

        #[dirs]
        #[files("tests/data/problem/*")]
        problem_path: PathBuf,
    ) {
        let cases = read_test_cases(&problem_path);

        let source = read_code(language, &problem_path);
        let project_path = language.compiler.compile(&source).await.unwrap();

        let runner =
            Runner::new(language.runner_args, &project_path, TIME_LIMIT, one_gib_config())
                .unwrap();

        for (input, expected) in cases {
            let metrics = runner.run(&input).await.unwrap();
            assert_eq!(metrics.stdout.trim(), expected.trim());
        }
    }

    /// A program that never finishes must be reported as exceeding the time limit.
    #[rstest]
    #[tokio::test]
    async fn should_timeout(#[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>) {
        let source = read_code(language, Path::new("tests/data/timeout"));
        let project_path = language.compiler.compile(source.as_bytes()).await.unwrap();

        let runner =
            Runner::new(language.runner_args, &project_path, TIME_LIMIT, one_gib_config())
                .unwrap();

        let status = runner.run(b"").await.unwrap().exit_status;
        assert_eq!(status, ExitStatus::TimeLimitExceeded)
    }
}