code_executor/
runner.rs

1use std::{
2    io,
3    path::Path,
4    process::{self, Stdio},
5    sync::Arc,
6    time::Duration,
7};
8
9use cached::proc_macro::cached;
10use cgroups_rs::{Cgroup, CgroupPid, cgroup_builder::CgroupBuilder, hierarchies};
11use tokio::{
12    io::{AsyncReadExt, AsyncWriteExt},
13    process::Command,
14    time::{Instant, sleep},
15};
16
17use crate::{CommandArgs, Error, Result, metrics::Metrics};
18
19#[cached(result = true)]
20fn create_cgroup(memory_limit: i64) -> Result<Cgroup> {
21    let cgroup_name = format!("runner/{}", memory_limit);
22    let hier = hierarchies::auto();
23    let cgroup = CgroupBuilder::new(&cgroup_name)
24        .memory()
25        .memory_swap_limit(memory_limit)
26        .memory_soft_limit(memory_limit)
27        .memory_hard_limit(memory_limit)
28        .done()
29        .build(hier)?;
30
31    Ok(cgroup)
32}
33
34#[derive(Debug)]
35pub struct Runner<'a> {
36    pub args: CommandArgs<'a>,
37    pub project_path: &'a Path,
38    pub time_limit: Duration,
39    pub cgroup: Arc<Cgroup>,
40}
41
42impl<'a> Runner<'a> {
43    #[tracing::instrument(err)]
44    pub fn new(
45        args: CommandArgs<'a>,
46        project_path: &'a Path,
47        time_limit: Duration,
48        memory_limit: i64,
49    ) -> Result<Self> {
50        let cgroup = create_cgroup(memory_limit)?;
51
52        Ok(Self {
53            args,
54            project_path,
55            cgroup: Arc::new(cgroup),
56            time_limit,
57        })
58    }
59
60    #[tracing::instrument(err)]
61    pub async fn run(&self, input: &[u8]) -> Result<Metrics> {
62        let CommandArgs { binary, args } = self.args;
63
64        let cgroup = self.cgroup.clone();
65
66        let mut child = Command::new(binary);
67        let child = child
68            .current_dir(self.project_path)
69            .args(args)
70            .stdin(Stdio::piped())
71            .stdout(Stdio::piped())
72            .stderr(Stdio::piped());
73        let child = unsafe {
74            child.pre_exec(move || {
75                cgroup
76                    .add_task_by_tgid(CgroupPid::from(process::id() as u64))
77                    .map_err(std::io::Error::other)
78            })
79        };
80        let start = Instant::now();
81        let mut child = child.spawn()?;
82        let mut stdin = child.stdin.take().unwrap();
83        let mut stdout = child.stdout.take().unwrap();
84        let mut stderr = child.stderr.take().unwrap();
85
86        let stdout_observer = async move {
87            let mut buffer = Vec::new();
88            stdout.read_to_end(&mut buffer).await?;
89
90            Ok::<_, io::Error>(buffer)
91        };
92        let stderr_observer = async move {
93            let mut buffer = Vec::new();
94            stderr.read_to_end(&mut buffer).await?;
95            Ok::<_, io::Error>(buffer)
96        };
97
98        tokio::select! {
99            exit_status = async {
100                stdin.write_all(input).await?;
101                let exit_status = child.wait().await?;
102
103                Ok::<_, io::Error>(exit_status)
104            } => {
105                let (stdout, stderr) = tokio::try_join! {
106                    stdout_observer,
107                    stderr_observer
108                }?;
109
110                Ok(Metrics {
111                    exit_status: exit_status?,
112                    stdout,
113                    stderr,
114                    run_time: start.elapsed()
115                })
116            }
117            _ = sleep(self.time_limit) => {
118                child.kill().await?;
119                child.wait().await?;
120
121                Err(Error::Timeout)
122            }
123        }
124    }
125}