code_executor/
runner.rs

use std::{
    hash::{DefaultHasher, Hash, Hasher},
    path::Path,
    process::{self, Stdio},
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
    time::Duration,
};

use cgroups_rs::{Cgroup, CgroupPid, cgroup_builder::CgroupBuilder, hierarchies};
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
    process::{ChildStdin, Command},
    time::{Instant, timeout},
};
14
15use crate::{CommandArgs, Result, metrics::Metrics};
16
17#[derive(Debug, Clone, Copy, Hash)]
18pub struct Runner<'a> {
19    pub args: CommandArgs<'a>,
20    pub max_memory: i64,
21    pub max_cpu_percentage: i64,
22}
23
24impl Runner<'_> {
25    fn get_cgroup_name(&self) -> String {
26        let mut hasher = DefaultHasher::new();
27        self.hash(&mut hasher);
28
29        format!("runner-{}", hasher.finish())
30    }
31
32    #[tracing::instrument(err)]
33    fn create_cgroup(&self) -> Result<Cgroup> {
34        let hier = hierarchies::auto();
35        let cgroup = CgroupBuilder::new(&self.get_cgroup_name())
36            .memory()
37            .memory_swap_limit(self.max_memory)
38            .memory_soft_limit(self.max_memory)
39            .memory_hard_limit(self.max_memory)
40            .done()
41            .cpu()
42            .quota(self.max_cpu_percentage * 1000)
43            .done()
44            .build(hier)?;
45        Ok(cgroup)
46    }
47
48    #[tracing::instrument(err)]
49    pub async fn run(
50        &self,
51        project_path: &Path,
52        input: &str,
53        time_limit: Duration,
54    ) -> Result<Metrics> {
55        let CommandArgs { binary, args } = self.args;
56
57        let cgroup = self.create_cgroup()?;
58
59        let mut child = Command::new(binary);
60        let child = child
61            .current_dir(project_path)
62            .args(args)
63            .stdin(Stdio::piped())
64            .stdout(Stdio::piped())
65            .stderr(Stdio::piped());
66        let child = unsafe {
67            child.pre_exec(move || {
68                cgroup
69                    .add_task_by_tgid(CgroupPid::from(process::id() as u64))
70                    .map_err(std::io::Error::other)
71            })
72        };
73        let start = Instant::now();
74        let mut child = child.spawn()?;
75
76        let child_stdin = child.stdin.as_mut().unwrap();
77        child_stdin.write_all(input.as_bytes()).await?;
78
79        let output = timeout(time_limit, child.wait_with_output()).await??;
80
81        Ok(Metrics {
82            run_time: start.elapsed(),
83            exit_status: output.status,
84            stdout: String::from_utf8(output.stdout)?.trim().to_string(),
85            stderr: String::from_utf8(output.stderr)?,
86        })
87    }
88}