// code-executor 2.7.0
//
// A library designed for the backend of competitive programming platforms.
// Documentation
use std::{io, path::Path, process::Stdio, time::Duration};

use cgroups_rs::{CgroupPid, fs::Cgroup};
use nix::{
    sys::signal::{self, Signal},
    unistd::Pid,
};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    process::Command,
    task::JoinHandle,
    time::{Instant, sleep},
};

use crate::{CommandArgs, ExitStatus, ResourceConfig, Result, cgroup, metrics::Metrics};

/// Runs one command inside a dedicated cgroup with a time limit.
///
/// The cgroup is built from `resource_config` in [`Runner::new`] and deleted
/// when the runner is dropped, so one runner can serve several
/// [`Runner::run`] calls.
#[derive(Debug)]
pub struct Runner<'a> {
    /// Binary and arguments of the command spawned by [`Runner::run`].
    pub args: CommandArgs<'a>,
    /// Working directory the command is spawned in.
    pub project_path: &'a Path,
    /// Time a single run may take before the child is killed and the run is
    /// reported as `ExitStatus::TimeLimitExceeded`.
    pub time_limit: Duration,
    /// Resource limits. The cgroup is derived from this value; its `quota`
    /// and `period` (milliseconds) fields also drive the CPU limiter.
    pub resource_config: ResourceConfig,
    /// Cgroup every spawned child is added to.
    pub cgroup: Cgroup,
}

impl<'a> Runner<'a> {
    #[tracing::instrument(err)]
    pub fn new(
        args: CommandArgs<'a>,
        project_path: &'a Path,
        time_limit: Duration,
        resource_config: ResourceConfig,
    ) -> Result<Self> {
        let cgroup = resource_config.try_into()?;

        Ok(Self {
            args,
            project_path,
            time_limit,
            resource_config,
            cgroup,
        })
    }

    fn new_cpu_limiter(&self, id: i32) -> JoinHandle<()> {
        let cg = self.cgroup.clone();
        let quota = self.resource_config.quota;
        let period = Duration::from_millis(self.resource_config.period);

        tokio::spawn(async move {
            let Some(mut prev_usage) = cgroup::get_cpu_usage(&cg) else {
                return;
            };

            loop {
                sleep(period).await;

                let Some(cur_usage) = cgroup::get_cpu_usage(&cg) else {
                    break;
                };
                if cur_usage - prev_usage <= quota {
                    prev_usage = cur_usage;
                    continue;
                }

                signal::kill(Pid::from_raw(id), Signal::SIGSTOP).unwrap();
                sleep(period).await;
                signal::kill(Pid::from_raw(id), Signal::SIGCONT).unwrap();

                match cgroup::get_cpu_usage(&cg) {
                    Some(v) => prev_usage = v,
                    None => break,
                }
            }
        })
    }

    #[tracing::instrument(err)]
    pub async fn run(&self, input: &[u8]) -> Result<Metrics> {
        let CommandArgs { binary, args } = self.args;

        let start = Instant::now();

        let mut child = Command::new(binary)
            .current_dir(self.project_path)
            .args(args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()?;

        let id = child.id().unwrap();
        self.cgroup.add_task_by_tgid(CgroupPid::from(id as u64))?;
        let cpu_limiter = self.new_cpu_limiter(id as i32);

        let mut stdin = child.stdin.take().unwrap();
        let mut stdout = child.stdout.take().unwrap();
        let mut stderr = child.stderr.take().unwrap();

        let stdout_observer = async move {
            let mut buffer = Vec::new();
            stdout.read_to_end(&mut buffer).await?;

            Ok::<_, io::Error>(buffer)
        };
        let stderr_observer = async move {
            let mut buffer = Vec::new();
            stderr.read_to_end(&mut buffer).await?;
            Ok::<_, io::Error>(buffer)
        };

        let exit_status = tokio::select! {
            exit_status = async {
                stdin.write_all(input).await?;
                drop(stdin);
                let exit_status = child.wait().await?;

                Ok::<_, io::Error>(exit_status)
            } => {
                exit_status.map(|raw| raw.into())
            }
            _ = sleep(self.time_limit) => {
                child.kill().await?;
                child.wait().await?;

                Ok(ExitStatus::TimeLimitExceeded)
            }
        }?;

        cpu_limiter.abort();

        let (stdout, stderr) = tokio::try_join!(stdout_observer, stderr_observer)?;

        Ok(Metrics {
            exit_status,
            stdout,
            stderr,
            run_time: start.elapsed(),
        })
    }
}

impl<'a> Drop for Runner<'a> {
    fn drop(&mut self) {
        let _ = self.cgroup.delete();
    }
}

#[cfg(test)]
mod test {
    use std::{
        path::{Path, PathBuf},
        time::Duration,
    };

    use bstr::ByteSlice;
    use byte_unit::Byte;
    use rstest::rstest;

    use crate::{
        CPP, ExitStatus, JAVA, Language, PYTHON, RUST, ResourceConfig, Runner,
        test::{read_code, read_test_cases},
    };

    /// For every language and every problem directory, the trimmed stdout of
    /// the compiled solution must equal the trimmed expected output of each
    /// test case.
    #[rstest]
    #[tokio::test]
    async fn should_output_correct(
        #[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>,

        #[dirs]
        #[files("tests/data/problem/*")]
        problem_path: PathBuf,
    ) {
        let cases = read_test_cases(&problem_path);

        let source = read_code(language, &problem_path);
        let project_path = language.compiler.compile(&source).await.unwrap();

        let limits = ResourceConfig::builder()
            .memory_limit(Byte::GIBIBYTE)
            .build();
        let runner = Runner::new(
            language.runner_args,
            &project_path,
            Duration::from_secs(2),
            limits,
        )
        .unwrap();

        for (input, expected) in cases {
            let metrics = runner.run(&input).await.unwrap();
            assert_eq!(metrics.stdout.trim(), expected.trim());
        }
    }

    /// The program under `tests/data/timeout` must be reported as exceeding
    /// the two second time limit in every language.
    #[rstest]
    #[tokio::test]
    async fn should_timeout(#[values(CPP, RUST, JAVA, PYTHON)] language: Language<'static>) {
        let source = read_code(language, Path::new("tests/data/timeout"));
        let project_path = language
            .compiler
            .compile(source.as_bytes())
            .await
            .unwrap();

        let limits = ResourceConfig::builder()
            .memory_limit(Byte::GIBIBYTE)
            .build();
        let runner = Runner::new(
            language.runner_args,
            &project_path,
            Duration::from_secs(2),
            limits,
        )
        .unwrap();

        let metrics = runner.run(b"").await.unwrap();

        assert_eq!(metrics.exit_status, ExitStatus::TimeLimitExceeded);
    }
}