1use std::{
2 hash::{DefaultHasher, Hash, Hasher},
3 path::Path,
4 process::{self, Stdio},
5 time::Duration,
6};
7
8use cgroups_rs::{Cgroup, CgroupPid, cgroup_builder::CgroupBuilder, hierarchies};
9use tokio::{
10 io::AsyncWriteExt,
11 process::Command,
12 time::{Instant, timeout},
13};
14
15use crate::{CommandArgs, Result, metrics::Metrics};
16
17#[derive(Debug)]
18pub struct Runner<'a> {
19 args: CommandArgs<'a>,
20 project_path: &'a Path,
21 time_limit: Duration,
22 cgroup: Cgroup,
23}
24
25impl<'a> Runner<'a> {
26 #[tracing::instrument(err)]
27 pub fn new(
28 args: CommandArgs<'a>,
29 project_path: &'a Path,
30 time_limit: Duration,
31 memory_limit: i64,
32 ) -> Result<Self> {
33 let mut hasher = DefaultHasher::new();
34 (args, memory_limit, time_limit).hash(&mut hasher);
35 let cgroup_name = format!("runner/{}", hasher.finish());
36 let hier = hierarchies::auto();
37 let cgroup = CgroupBuilder::new(&cgroup_name)
38 .memory()
39 .memory_swap_limit(memory_limit)
40 .memory_soft_limit(memory_limit)
41 .memory_hard_limit(memory_limit)
42 .done()
43 .build(hier)?;
44
45 Ok(Self {
46 args,
47 project_path,
48 cgroup,
49 time_limit,
50 })
51 }
52
53 #[tracing::instrument(err)]
54 pub async fn run(&self, input: &[u8]) -> Result<Metrics> {
55 let CommandArgs { binary, args } = self.args;
56
57 let cgroup = self.cgroup.clone();
58
59 let mut child = Command::new(binary);
60 let child = child
61 .current_dir(self.project_path)
62 .args(args)
63 .stdin(Stdio::piped())
64 .stdout(Stdio::piped())
65 .stderr(Stdio::piped());
66 let child = unsafe {
67 child.pre_exec(move || {
68 cgroup
69 .add_task_by_tgid(CgroupPid::from(process::id() as u64))
70 .map_err(std::io::Error::other)
71 })
72 };
73 let start = Instant::now();
74 let mut child = child.spawn()?;
75
76 let child_stdin = child.stdin.as_mut().unwrap();
77 child_stdin.write_all(input).await?;
78
79 let output = timeout(self.time_limit, child.wait_with_output()).await??;
80
81 Ok(Metrics {
82 run_time: start.elapsed(),
83 exit_status: output.status,
84 stdout: output.stdout,
85 stderr: output.stderr,
86 })
87 }
88}