1use std::{
2 io,
3 path::Path,
4 process::{self, Stdio},
5 sync::Arc,
6 time::Duration,
7};
8
9use cached::proc_macro::cached;
10use cgroups_rs::{Cgroup, CgroupPid, cgroup_builder::CgroupBuilder, hierarchies};
11use tokio::{
12 io::{AsyncReadExt, AsyncWriteExt},
13 process::Command,
14 time::{Instant, sleep},
15};
16
17use crate::{CommandArgs, Error, Result, metrics::Metrics};
18
19#[cached(result = true)]
20fn create_cgroup(memory_limit: i64) -> Result<Cgroup> {
21 let cgroup_name = format!("runner/{}", memory_limit);
22 let hier = hierarchies::auto();
23 let cgroup = CgroupBuilder::new(&cgroup_name)
24 .memory()
25 .memory_swap_limit(memory_limit)
26 .memory_soft_limit(memory_limit)
27 .memory_hard_limit(memory_limit)
28 .done()
29 .build(hier)?;
30
31 Ok(cgroup)
32}
33
34#[derive(Debug)]
35pub struct Runner<'a> {
36 pub args: CommandArgs<'a>,
37 pub project_path: &'a Path,
38 pub time_limit: Duration,
39 pub cgroup: Arc<Cgroup>,
40}
41
42impl<'a> Runner<'a> {
43 #[tracing::instrument(err)]
44 pub fn new(
45 args: CommandArgs<'a>,
46 project_path: &'a Path,
47 time_limit: Duration,
48 memory_limit: i64,
49 ) -> Result<Self> {
50 let cgroup = create_cgroup(memory_limit)?;
51
52 Ok(Self {
53 args,
54 project_path,
55 cgroup: Arc::new(cgroup),
56 time_limit,
57 })
58 }
59
60 #[tracing::instrument(err)]
61 pub async fn run(&self, input: &[u8]) -> Result<Metrics> {
62 let CommandArgs { binary, args } = self.args;
63
64 let cgroup = self.cgroup.clone();
65
66 let mut child = Command::new(binary);
67 let child = child
68 .current_dir(self.project_path)
69 .args(args)
70 .stdin(Stdio::piped())
71 .stdout(Stdio::piped())
72 .stderr(Stdio::piped());
73 let child = unsafe {
74 child.pre_exec(move || {
75 cgroup
76 .add_task_by_tgid(CgroupPid::from(process::id() as u64))
77 .map_err(std::io::Error::other)
78 })
79 };
80 let start = Instant::now();
81 let mut child = child.spawn()?;
82 let mut stdin = child.stdin.take().unwrap();
83 let mut stdout = child.stdout.take().unwrap();
84 let mut stderr = child.stderr.take().unwrap();
85
86 let stdout_observer = async move {
87 let mut buffer = Vec::new();
88 stdout.read_to_end(&mut buffer).await?;
89
90 Ok::<_, io::Error>(buffer)
91 };
92 let stderr_observer = async move {
93 let mut buffer = Vec::new();
94 stderr.read_to_end(&mut buffer).await?;
95 Ok::<_, io::Error>(buffer)
96 };
97
98 tokio::select! {
99 exit_status = async {
100 stdin.write_all(input).await?;
101 let exit_status = child.wait().await?;
102
103 Ok::<_, io::Error>(exit_status)
104 } => {
105 let (stdout, stderr) = tokio::try_join! {
106 stdout_observer,
107 stderr_observer
108 }?;
109
110 Ok(Metrics {
111 exit_status: exit_status?,
112 stdout,
113 stderr,
114 run_time: start.elapsed()
115 })
116 }
117 _ = sleep(self.time_limit) => {
118 child.kill().await?;
119 child.wait().await?;
120
121 Err(Error::Timeout)
122 }
123 }
124 }
125}