use std::{io, path::Path, process::Stdio, time::Duration};
use cgroups_rs::{CgroupPid, fs::Cgroup};
use nix::{
sys::signal::{self, Signal},
unistd::Pid,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
process::Command,
task::JoinHandle,
time::{Instant, sleep},
};
use crate::{CommandArgs, ExitStatus, ResourceConfig, Result, cgroup, metrics::Metrics};
/// Executes a program under a wall-clock time limit inside a cgroup built from its
/// resource configuration.
///
/// The cgroup is created in [`Runner::new`] and deleted (best-effort) when the runner
/// is dropped.
#[derive(Debug)]
pub struct Runner<'a> {
    /// Binary to execute and the arguments passed to it.
    pub args: CommandArgs<'a>,
    /// Working directory the program is started in.
    pub project_path: &'a Path,
    /// Wall-clock time after which a still-running program is killed.
    pub time_limit: Duration,
    /// Limits used to build `cgroup`; its `quota` and `period` also drive CPU throttling.
    pub resource_config: ResourceConfig,
    /// Control group every spawned program is added to.
    pub cgroup: Cgroup,
}
impl<'a> Runner<'a> {
    /// Creates a runner for `args`, started in `project_path` and killed after `time_limit`.
    ///
    /// The runner's cgroup is built from `resource_config` immediately.
    ///
    /// # Errors
    ///
    /// Returns an error if `resource_config` cannot be converted into a cgroup.
    #[tracing::instrument(err)]
    pub fn new(
        args: CommandArgs<'a>,
        project_path: &'a Path,
        time_limit: Duration,
        resource_config: ResourceConfig,
    ) -> Result<Self> {
        let cgroup = resource_config.try_into()?;
        Ok(Self {
            args,
            project_path,
            time_limit,
            resource_config,
            cgroup,
        })
    }

    /// Spawns a background task that throttles process `id` to the configured CPU quota.
    ///
    /// Every `period` milliseconds the task samples the usage reported by
    /// `cgroup::get_cpu_usage`. When the usage accumulated since the previous sample
    /// exceeds `quota`, the process is stopped with `SIGSTOP` for one more period and then
    /// resumed with `SIGCONT`. The task ends on its own once the usage can no longer be
    /// read or the process can no longer be signalled (e.g. it has already been reaped).
    fn new_cpu_limiter(&self, id: i32) -> JoinHandle<()> {
        let cg = self.cgroup.clone();
        let quota = self.resource_config.quota;
        let period = Duration::from_millis(self.resource_config.period);
        let pid = Pid::from_raw(id);
        tokio::spawn(async move {
            let Some(mut prev_usage) = cgroup::get_cpu_usage(&cg) else {
                return;
            };
            loop {
                sleep(period).await;
                let Some(cur_usage) = cgroup::get_cpu_usage(&cg) else {
                    break;
                };
                if cur_usage - prev_usage <= quota {
                    prev_usage = cur_usage;
                    continue;
                }
                // A failed signal means the process is gone: there is nothing left to
                // throttle, so stop instead of panicking the task.
                if signal::kill(pid, Signal::SIGSTOP).is_err() {
                    break;
                }
                sleep(period).await;
                if signal::kill(pid, Signal::SIGCONT).is_err() {
                    break;
                }
                match cgroup::get_cpu_usage(&cg) {
                    Some(v) => prev_usage = v,
                    None => break,
                }
            }
        })
    }

    /// Runs the configured program once, feeding it `input` on stdin.
    ///
    /// The program is added to this runner's cgroup and throttled by the CPU limiter.
    /// Its stdout and stderr are drained concurrently with execution, so large output
    /// cannot stall it. A program that exits without reading all of `input` is not an
    /// error. If the program is still running after `time_limit`, it is killed and the
    /// exit status is [`ExitStatus::TimeLimitExceeded`]. Any output produced before the
    /// kill is still returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the program cannot be spawned, added to the cgroup, fed its
    /// input, waited on or killed, or if reading its output fails. On every early return
    /// (and if this future is cancelled) the program is killed and the limiter stopped.
    #[tracing::instrument(err, skip(input), fields(input_len = input.len()))]
    pub async fn run(&self, input: &[u8]) -> Result<Metrics> {
        // Aborts the wrapped task when dropped, so the CPU limiter stops on every exit
        // path of `run`, including early `?` returns and cancellation.
        struct AbortOnDrop(JoinHandle<()>);

        impl Drop for AbortOnDrop {
            fn drop(&mut self) {
                self.0.abort();
            }
        }

        let CommandArgs { binary, args } = self.args;
        let start = Instant::now();
        let mut child = Command::new(binary)
            .current_dir(self.project_path)
            .args(args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            // Never leave the program running if `run` returns early or is cancelled.
            .kill_on_drop(true)
            .spawn()?;
        let id = child
            .id()
            .expect("a child that has not been polled yet still reports its pid");
        self.cgroup.add_task_by_tgid(CgroupPid::from(u64::from(id)))?;
        // Kernel pids are `pid_t` (i32) values, so a real pid always fits.
        let cpu_limiter = AbortOnDrop(self.new_cpu_limiter(id as i32));

        let mut stdin = child.stdin.take().expect("stdin is configured as piped");
        let mut stdout = child.stdout.take().expect("stdout is configured as piped");
        let mut stderr = child.stderr.take().expect("stderr is configured as piped");
        // Drain both pipes while the program runs. If nobody reads them, a program that
        // writes more than a pipe buffer holds blocks forever and is misreported as a
        // time-limit failure.
        let stdout_reader = tokio::spawn(async move {
            let mut buffer = Vec::new();
            stdout.read_to_end(&mut buffer).await.map(|_| buffer)
        });
        let stderr_reader = tokio::spawn(async move {
            let mut buffer = Vec::new();
            stderr.read_to_end(&mut buffer).await.map(|_| buffer)
        });

        let exit_status = tokio::select! {
            exit_status = async {
                match stdin.write_all(input).await {
                    // The program may exit without consuming all of its input; the
                    // resulting broken pipe is not a failure of the runner.
                    Err(error) if error.kind() != io::ErrorKind::BrokenPipe => return Err(error),
                    _ => {}
                }
                // Close stdin so the program sees end-of-file.
                drop(stdin);
                child.wait().await
            } => exit_status.map(|raw| raw.into()),
            _ = sleep(self.time_limit) => {
                // `Child::kill` sends SIGKILL and then waits for the process to be reaped.
                child.kill().await.map(|()| ExitStatus::TimeLimitExceeded)
            }
        };
        // Stop throttling as soon as the program is gone: once it has been reaped, its
        // pid may be reused by an unrelated process.
        drop(cpu_limiter);
        let exit_status = exit_status?;

        let stdout = stdout_reader.await.map_err(io::Error::other)??;
        let stderr = stderr_reader.await.map_err(io::Error::other)??;
        Ok(Metrics {
            exit_status,
            stdout,
            stderr,
            run_time: start.elapsed(),
        })
    }
}
impl<'a> Drop for Runner<'a> {
fn drop(&mut self) {
let _ = self.cgroup.delete();
}
}
#[cfg(test)]
mod test {
    use std::{
        path::{Path, PathBuf},
        time::Duration,
    };

    use bstr::ByteSlice;
    use byte_unit::Byte;
    use rstest::rstest;

    use crate::{
        CPP, ExitStatus, JAVA, Language, PYTHON, RUST, ResourceConfig, Runner,
        test::{read_code, read_test_cases},
    };

    /// Builds a runner for `language` in `project_path` with the limits shared by every
    /// test in this module: a 2 second time limit and a 1 GiB memory limit.
    fn runner_for<'a>(language: Language<'static>, project_path: &'a Path) -> Runner<'a> {
        let limits = ResourceConfig::builder()
            .memory_limit(Byte::GIBIBYTE)
            .build();
        Runner::new(
            language.runner_args,
            project_path,
            Duration::from_secs(2),
            limits,
        )
        .unwrap()
    }

    /// Every sample problem must produce the expected output (ignoring surrounding
    /// whitespace) in every supported language.
    #[rstest]
    #[tokio::test]
    async fn should_output_correct(
        #[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>,
        #[dirs]
        #[files("tests/data/problem/*")]
        problem_path: PathBuf,
    ) {
        let cases = read_test_cases(&problem_path);
        let source = read_code(language, &problem_path);
        let built = language.compiler.compile(&source).await.unwrap();
        let runner = runner_for(language, &built);
        for (stdin, expected) in cases {
            let result = runner.run(&stdin).await.unwrap();
            assert_eq!(result.stdout.trim(), expected.trim());
        }
    }

    /// A program that never finishes must be reported as exceeding the time limit.
    #[rstest]
    #[tokio::test]
    async fn should_timeout(#[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>) {
        let source = read_code(language, Path::new("tests/data/timeout"));
        let built = language.compiler.compile(source.as_bytes()).await.unwrap();
        let runner = runner_for(language, &built);
        let result = runner.run(b"").await.unwrap();
        assert_eq!(result.exit_status, ExitStatus::TimeLimitExceeded)
    }
}