1use std::{
2 hash::{DefaultHasher, Hash, Hasher},
3 path::Path,
4 process::{self, Stdio},
5 time::Duration,
6};
7
8use cgroups_rs::{Cgroup, CgroupPid, cgroup_builder::CgroupBuilder, hierarchies};
9use tokio::{
10 io::AsyncWriteExt,
11 process::Command,
12 time::{Instant, timeout},
13};
14
15use crate::{CommandArgs, Result, metrics::Metrics};
16
17#[derive(Debug, Clone, Copy, Hash)]
18pub struct Runner<'a> {
19 pub args: CommandArgs<'a>,
20 pub max_memory: i64,
21 pub max_cpu_percentage: i64,
22}
23
24impl Runner<'_> {
25 fn get_cgroup_name(&self) -> String {
26 let mut hasher = DefaultHasher::new();
27 self.hash(&mut hasher);
28
29 format!("runner-{}", hasher.finish())
30 }
31
32 #[tracing::instrument(err)]
33 fn create_cgroup(&self) -> Result<Cgroup> {
34 let hier = hierarchies::auto();
35 let cgroup = CgroupBuilder::new(&self.get_cgroup_name())
36 .memory()
37 .memory_swap_limit(self.max_memory)
38 .memory_soft_limit(self.max_memory)
39 .memory_hard_limit(self.max_memory)
40 .done()
41 .cpu()
42 .quota(self.max_cpu_percentage * 1000)
43 .done()
44 .build(hier)?;
45 Ok(cgroup)
46 }
47
48 #[tracing::instrument(err)]
49 pub async fn run(
50 &self,
51 project_path: &Path,
52 input: &str,
53 time_limit: Duration,
54 ) -> Result<Metrics> {
55 let CommandArgs { binary, args } = self.args;
56
57 let cgroup = self.create_cgroup()?;
58
59 let mut child = Command::new(binary);
60 let child = child
61 .current_dir(project_path)
62 .args(args)
63 .stdin(Stdio::piped())
64 .stdout(Stdio::piped())
65 .stderr(Stdio::piped());
66 let child = unsafe {
67 child.pre_exec(move || {
68 cgroup
69 .add_task_by_tgid(CgroupPid::from(process::id() as u64))
70 .map_err(std::io::Error::other)
71 })
72 };
73 let start = Instant::now();
74 let mut child = child.spawn()?;
75
76 let child_stdin = child.stdin.as_mut().unwrap();
77 child_stdin.write_all(input.as_bytes()).await?;
78
79 let output = timeout(time_limit, child.wait_with_output()).await??;
80
81 Ok(Metrics {
82 run_time: start.elapsed(),
83 exit_status: output.status,
84 stdout: String::from_utf8(output.stdout)?.trim().to_string(),
85 stderr: String::from_utf8(output.stderr)?,
86 })
87 }
88}