ai-tournament 3.0.0

A modular Rust crate for running AI tournaments
Documentation
use std::{
    fs::File,
    process::Child,
    time::{Duration, Instant},
};

use anyhow::{self, Context};
use cgroups_rs::Cgroup;

use super::create_process;

/// Returns the numeric id of the current user, as printed by `id -u`.
///
/// # Errors
///
/// Fails if `id` cannot be launched, exits with a non-zero status, prints something that is
/// not valid UTF-8, or prints nothing at all.
pub fn get_current_user_id() -> anyhow::Result<String> {
    let output = std::process::Command::new("id")
        .arg("-u")
        .output()
        .context("Could not launch 'id -u'")?;
    // Without this check a failing `id` would silently produce an empty user id, which would
    // then be baked into a bogus cgroup path.
    if !output.status.success() {
        anyhow::bail!(
            "'id -u' failed ({}): {}",
            output.status,
            String::from_utf8_lossy(&output.stderr).trim()
        );
    }
    let user_id = std::str::from_utf8(&output.stdout)
        .context("id is not a valid string")?
        .trim();
    if user_id.is_empty() {
        anyhow::bail!("'id -u' did not print a user id");
    }
    Ok(user_id.to_string())
}

/// Formats the cgroup path `user.slice/user-<user_id>.slice/user@<user_id>.service/<group_name>`.
pub fn get_cgroup_path(user_id: &str, group_name: &str) -> String {
    let user_slice = format!("user-{user_id}.slice");
    let user_service = format!("user@{user_id}.service");
    [
        "user.slice",
        user_slice.as_str(),
        user_service.as_str(),
        group_name,
    ]
    .join("/")
}

/// Builds a new cgroup located at `path`, applying the requested resource limits.
///
/// * `max_memory` - Hard memory limit, in bytes. Ignored unless strictly positive.
/// * `max_pids` - Maximum number of PIDs alive in the cgroup at any time. Ignored unless
///   strictly positive.
/// * `cpus` - Which CPUs the members may run on, as comma separated CPU ranges ("1-5,7",
///   "1,3,4", ...). An empty string means no restriction.
///
/// # Errors
///
/// Returns an error if the cgroup could not be created. This can happen if the parameters are
/// incorrect or if cgroups are not available.
pub fn create_cgroup(
    path: &str,
    max_memory: i64,
    max_pids: i64,
    cpus: &str,
) -> anyhow::Result<cgroups_rs::Cgroup> {
    let builder = cgroups_rs::cgroup_builder::CgroupBuilder::new(path);

    // Each limit is optional: the builder is threaded through untouched when a limit is disabled.
    let builder = if max_memory > 0 {
        builder.memory().memory_hard_limit(max_memory).done()
    } else {
        builder
    };
    let builder = if max_pids > 0 {
        let limit = cgroups_rs::MaxValue::Value(max_pids);
        builder.pid().maximum_number_of_processes(limit).done()
    } else {
        builder
    };
    let builder = if cpus.is_empty() {
        builder
    } else {
        builder.cpu().cpus(cpus.to_string()).done()
    };

    let hierarchy = cgroups_rs::hierarchies::auto();
    builder.build(hierarchy).context("could not create cgroup")
}

/// Error signalling that an operation did not complete within its allotted time.
#[derive(Debug)]
pub struct TimeoutError {}

impl std::error::Error for TimeoutError {}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Timeout Error")
    }
}

/// Blocks until `pid` is no longer listed among the tasks of `cgroup`.
///
/// The task list is polled every 10 ms, or every tenth of `max_duration` if that is shorter.
///
/// # Errors
///
/// Returns a [`TimeoutError`] if `pid` is still listed after the deadline
/// (`max_duration` from the moment of the call) has passed.
pub fn wait_for_process_cleanup(
    cgroup: &cgroups_rs::Cgroup,
    pid: u64,
    max_duration: Duration,
) -> Result<(), TimeoutError> {
    let deadline = Instant::now() + max_duration;
    let poll_interval = Duration::from_millis(10).min(max_duration / 10);
    loop {
        let still_listed = cgroup.tasks().iter().any(|task| task.pid == pid);
        if !still_listed {
            return Ok(());
        }
        if Instant::now() > deadline {
            return Err(TimeoutError {});
        }
        std::thread::sleep(poll_interval);
    }
}

pub fn create_process_in_cgroup(
    command: &str,
    args: &[String],
    group: &cgroups_rs::Cgroup,
    allow_stderr: bool,
    log_file: &Option<File>,
) -> anyhow::Result<std::process::Child> {
    let mut child = create_process(command, args, allow_stderr, log_file)?;

    let pid = child.id() as u64;
    let addition = group.add_task_by_tgid(cgroups_rs::CgroupPid { pid });
    if addition.is_err() {
        let kill = child.kill();

        addition.with_context(|| {
            if let Err(err) = kill {
                format!(
                    "could not add process to cgroup, and process could not be killed either ({err})"
                )
            } else {
                "could not add process to cgroup".to_string()
            }
        })?;
    }
    Ok(child)
}

/// A spawned child process, optionally confined to its own cgroup.
///
/// Dropping a `LimitedProcess` that has not been cleaned up through `try_kill` triggers a kill
/// attempt (see the `Drop` implementation).
#[derive(Debug)]
pub struct LimitedProcess {
    /// Handle to the spawned process.
    pub child: Child,
    /// Cgroup the process was placed in; `None` when created by `launch_without_container`.
    cgroup: Option<Cgroup>,
    /// Set by `try_kill` once the process has been killed; checked by `Drop` to skip the cleanup.
    cleaned_up: bool,
}

impl LimitedProcess {
    pub fn launch(
        command: &str,
        args: &[String],
        max_memory: i64,
        cpus: &str,
        allow_stderr: bool,
        log_file: &Option<File>,
    ) -> anyhow::Result<LimitedProcess> {
        static COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(1); // lazy cell ? (if multiple evaluations at the same time !)
        let user_id = get_current_user_id().context("could not get user id")?;
        // generate a new cgroup name for each Limited Process
        let group_name = "CGROUP_MANAGER_".to_owned()
            + &COUNTER
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
                .to_string();
        let path = get_cgroup_path(&user_id, &group_name);
        let group =
            create_cgroup(&path, max_memory, 100, cpus).context("could not create cgroup")?;
        let child = create_process_in_cgroup(command, args, &group, allow_stderr, log_file)
            .with_context(|| {
                let _ = group.delete();
                "could not create process in cgroup"
            })?;

        Ok(LimitedProcess {
            child,
            cgroup: Some(group),
            cleaned_up: false,
        })
    }

    pub fn try_kill(&mut self, max_duration: Duration) -> anyhow::Result<()> {
        match &mut self.cgroup {
            Some(cgroup) => {
                self.child.kill().context("could not kill child process")?; // start with (blocking) process kill
                cgroup.kill().context("could not kill process")?;
                wait_for_process_cleanup(cgroup, self.child.id() as u64, max_duration)
                    .context("process cleanup timed out")?;
                // at this point, the process is killed. Even so the cgroup cleanup fail, it is
                // 'safe' (probably) to continue
                self.cleaned_up = true;
                if let Err(e) = cgroup.delete() {
                    // Oh well... Whatever...
                    tracing::warn!("Failed to remove cgroup. If this happens a lot, it may slow down the computer. {e}");
                }
                Ok(())
            }
            None => {
                self.child.kill().context("could not kill process")?;
                self.cleaned_up = true;
                Ok(())
            }
        }
    }

    /// Spawns `command` with `args` as a plain child process, without placing it in a cgroup.
    ///
    /// `allow_stderr` and `log_file` are forwarded as-is to the process creation.
    ///
    /// # Errors
    ///
    /// Fails if the process could not be created.
    pub fn launch_without_container(
        command: &str,
        args: &[String],
        allow_stderr: bool,
        log_file: &Option<File>,
    ) -> anyhow::Result<LimitedProcess> {
        create_process(command, args, allow_stderr, log_file)
            .context("could not create process")
            .map(|child| LimitedProcess {
                child,
                cgroup: None,
                cleaned_up: false,
            })
    }

    /// Prints as much diagnostic information as possible about the process and, when it has
    /// one, about its cgroup; then retries the kill and tries to remove the cgroup directory.
    ///
    /// Debugging aid only: everything goes to standard output.
    #[allow(dead_code)]
    pub(crate) fn try_debug_cgroup(&mut self) {
        let pid = self.child.id();
        // A process launched without a container has no cgroup: skip the cgroup-specific
        // commands instead of panicking on the missing value.
        let cgroup_dir = self
            .cgroup
            .as_ref()
            .map(|cgroup| format!("/sys/fs/cgroup/{}", cgroup.path()));
        match &cgroup_dir {
            Some(p) => {
                println!("Path: {p:?}");
                Self::exec(&format!("lsof +D {p}"));
                Self::exec(&format!("cat {p}/cgroup.procs"));
                Self::exec(&format!("cat {p}/cgroup.stat"));
                Self::exec(&format!("cat {p}/pids.current"));
            }
            None => println!("No cgroup attached to process {pid}"),
        }
        Self::exec(&format!("ps -Flww -p {pid}"));
        Self::exec(&format!("cat /proc/{pid}/status"));
        if let Err(e) = self.try_kill(Duration::from_millis(100)) {
            println!("failed to kill again: {e:#}");
        } else {
            println!("successfully killed this time ??");
        }
        if let Some(p) = &cgroup_dir {
            Self::exec(&format!("rmdir {p}"));
        }
    }

    /// Runs `cmd` and prints its standard error (in red) followed by its standard output.
    ///
    /// `cmd` is split on whitespace: the first word is the program, the rest are its arguments.
    /// No shell is involved, so quoting is not supported. This is a debugging aid: a command
    /// that cannot be run is reported on standard output instead of panicking, and output that
    /// is not valid UTF-8 is printed lossily.
    #[allow(dead_code)]
    fn exec(cmd: &str) {
        let mut words = cmd.split_whitespace();
        let program = match words.next() {
            Some(program) => program,
            None => {
                println!("$ {cmd}\n\x1b[31mempty command\x1b[39m");
                return;
            }
        };
        match std::process::Command::new(program).args(words).output() {
            Ok(output) => println!(
                "$ {cmd}\n\x1b[31m{}\x1b[39m{}",
                String::from_utf8_lossy(&output.stderr),
                String::from_utf8_lossy(&output.stdout)
            ),
            Err(err) => println!("$ {cmd}\n\x1b[31mcould not run command: {err}\x1b[39m"),
        }
    }
}

impl Drop for LimitedProcess {
    fn drop(&mut self) {
        static CLEANUP_DURATION: Duration = Duration::from_secs(1);
        if !self.cleaned_up {
            // warn!(
            //     "Process {} was not cleaned up before dropping. Trying to clean up for up to {:?}...",
            //     self.child.id(),
            //     CLEANUP_DURATION
            // );
            match self.try_kill(CLEANUP_DURATION) {
                Ok(_) => { /* happy dance */ }
                Err(e) => {
                    if std::env::var("DEBUG_CGROUP").is_ok() {
                        self.try_debug_cgroup();
                    }
                    panic!("could not kill process/cgroup on LimitedProcess::drop: {e}");
                }
            }
        }
    }
}

#[cfg(test)]
mod cgroup_manager_tests {

    use std::{io::Read, process::Stdio, time::Duration};

    use super::*;

    /// Sanity check: a child process can be spawned and its standard output captured.
    #[test]
    fn launch_something() {
        use std::process;

        let mut proc = process::Command::new("echo")
            .args(["Hello", "World"])
            .stdout(Stdio::piped())
            .spawn()
            .expect("Could not spawn child");
        // `take()` leaves `proc` whole so that the child can still be waited on below.
        let mut res = proc.stdout.take().expect("No result ?");

        let mut buffer = String::new();
        res.read_to_string(&mut buffer)
            .expect("Could not make a string ?");
        // Reap the child so the test does not leave a zombie behind.
        let status = proc.wait().expect("Could not wait for child");

        println!("{buffer}");
        assert!(status.success(), "echo exited with {status}");
        assert_eq!(buffer, "Hello World\n");
    }

    /// Creates a cgroup with memory, PID and CPU limits under the current user's path, then
    /// deletes it again.
    #[test]
    fn test_create_cgroup() {
        //NOTE: future work: implement the Windows equivalent: "Job Object"
        assert_eq!(
            std::env::consts::OS,
            "linux",
            "Cgroups are only implemented on linux."
        );

        let hierarchy = cgroups_rs::hierarchies::auto();
        let hierarchy_label = if hierarchy.v2() {
            "V2 Hierarchy"
        } else {
            "V1 Hierarchy /!\\ THIS CASE IS UNTESTED"
        };
        println!("{hierarchy_label}");

        let user_id = get_current_user_id().expect("Could not get user ID");
        println!("User id: {user_id}");

        let cgroup_path = get_cgroup_path(&user_id, "my_cgroup");
        println!("Future new group path: {cgroup_path}");

        // 1 MiB of memory, at most 3 PIDs, CPUs 1 to 3 and 5.
        let cgroup = create_cgroup(&cgroup_path, 1024 * 1024, 3, "1-3,5")
            .expect("Could not create cgroup...");
        println!("path: {}", cgroup.path());

        cgroup.delete().expect("Could not delete cgroup");
    }

    /// Exploratory test: spawns `sleep 10`, moves it into a fresh cgroup, kills it through the
    /// cgroup and finally deletes the cgroup.
    ///
    /// Apart from the two `unwrap()`s during setup, failures are only printed, so this test
    /// mostly documents the expected sequence of operations rather than asserting on it.
    #[test]
    fn test_create_process_in_cgroup() {
        let id = get_current_user_id().unwrap();
        let path = get_cgroup_path(&id, "rust_group");
        // 1 MiB memory limit; `0` PIDs and empty `cpus` disable those two restrictions.
        let group = create_cgroup(&path, 1024 * 1024, 0, "").unwrap();
        println!("Cgroup created");
        let process = std::process::Command::new("sleep").arg("10").spawn();
        if let Ok(mut child) = process {
            let pid = child.id() as u64;
            println!("Process {pid} created");
            match group.add_task_by_tgid(cgroups_rs::CgroupPid { pid }) {
                Err(e) => {
                    println!("Could not add task to cgroup: {e}");
                }
                _ => {
                    println!("Task added to cgroup");
                    println!("Waiting for response...");
                    // sleep for ...ms and then try get result ?
                    // BUT loss time if it finishes "early"
                    println!("Finished waiting");
                    // NOTE(review): the command is spawned without piping stdout, so
                    // `child.stdout` is presumably `None` here and the "late" branch below is
                    // the one that actually runs.
                    let result = child.stdout.take();
                    let is_late_or_incorrect = match result {
                        Some(_answer) => {
                            println!(
                                "The process responded on time and the response is acceptable"
                            );
                            false
                        } // !is_answer_ok(answer)
                        None => {
                            println!("Process is late !");
                            true
                        }
                    };
                    if is_late_or_incorrect {
                        println!("Attempting to kill process");
                        // Kill every process of the cgroup. If that fails, fall back to waiting
                        // for `sleep 10` to end on its own so the cgroup can still be deleted.
                        group.kill().unwrap_or_else(|e| {
                        println!("Could not kill process. Must wait 10s to let it \"die by itself\", to avoid error in cgroup.delete(). Error: {e}");
                        std::thread::sleep(Duration::from_secs(10));
                    });
                        // Give the killed process up to 100 ms to leave the cgroup's task list.
                        wait_for_process_cleanup(&group, pid, Duration::from_millis(100))
                            .unwrap_or_else(|e| println!("Process cleanup did not end well: {e}"));
                    } else {
                        // release (auto ?)
                    }
                }
            }
        } else {
            let error = process.unwrap_err();
            println!("Process creation failed: {}", error);
        }
        println!("Deleting cgroup.");
        // On failure, list the tasks still attached to the cgroup to help diagnose why.
        group.delete().unwrap_or_else(|e| {
            println!("Could not delete cgroup ! Is there any descendant left ? ({e})");
            let procs = group.tasks();
            println!("PIDS: {:?}", procs);
        });
    }
}