1use std::{
2 hash::{DefaultHasher, Hash, Hasher},
3 path::Path,
4 process::{self, Stdio},
5 time::Duration,
6};
7
8use cgroups_rs::{Cgroup, CgroupPid, cgroup_builder::CgroupBuilder, hierarchies};
9use tokio::{
10 io::AsyncWriteExt,
11 process::Command,
12 time::{Instant, timeout},
13};
14
15use crate::{CommandArgs, Result, metrics::Metrics};
16
17#[derive(Debug, Clone, Copy, Hash)]
18pub struct Runner<'a> {
19 pub args: CommandArgs<'a>,
20}
21
22impl Runner<'_> {
23 fn get_cgroup_name(&self, project_path: &Path) -> String {
24 let mut hasher = DefaultHasher::new();
25 self.hash(&mut hasher);
26 project_path.hash(&mut hasher);
27
28 format!("runner/{}", hasher.finish())
29 }
30
31 #[tracing::instrument(err)]
32 fn create_cgroup(&self, project_path: &Path, max_memory: i64) -> Result<Cgroup> {
33 let hier = hierarchies::auto();
34 let cgroup = CgroupBuilder::new(&self.get_cgroup_name(project_path))
35 .memory()
36 .memory_swap_limit(max_memory)
37 .memory_soft_limit(max_memory)
38 .memory_hard_limit(max_memory)
39 .done()
40 .build(hier)?;
41 Ok(cgroup)
42 }
43
44 #[tracing::instrument(err)]
45 pub async fn run(
46 &self,
47 project_path: &Path,
48 input: &str,
49 time_limit: Duration,
50 max_memory: i64,
51 ) -> Result<Metrics> {
52 let CommandArgs { binary, args } = self.args;
53
54 let cgroup = self.create_cgroup(project_path, max_memory)?;
55
56 let mut child = Command::new(binary);
57 let child = child
58 .current_dir(project_path)
59 .args(args)
60 .stdin(Stdio::piped())
61 .stdout(Stdio::piped())
62 .stderr(Stdio::piped());
63 let child = unsafe {
64 child.pre_exec(move || {
65 cgroup
66 .add_task_by_tgid(CgroupPid::from(process::id() as u64))
67 .map_err(std::io::Error::other)
68 })
69 };
70 let start = Instant::now();
71 let mut child = child.spawn()?;
72
73 let child_stdin = child.stdin.as_mut().unwrap();
74 child_stdin.write_all(input.as_bytes()).await?;
75
76 let output = timeout(time_limit, child.wait_with_output()).await??;
77
78 Ok(Metrics {
79 run_time: start.elapsed(),
80 exit_status: output.status,
81 stdout: String::from_utf8(output.stdout)?.trim().to_string(),
82 stderr: String::from_utf8(output.stderr)?,
83 })
84 }
85}