ai-tournament 2.0.0

A modular Rust crate for running AI tournaments
Documentation
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};

use anyhow::{anyhow, bail, Context};
use tracing::{error, instrument};

use crate::agent::Agent;
use crate::cgroup_manager::LimitedProcess;
use crate::constraints::Constraints;

/// Server-side handle for one running agent: its child process plus the TCP
/// connection accepted on the loopback listener whose port was handed to it.
///
/// Built by `ClientHandler::init`; dropping the handler attempts to kill the
/// child process (failures are logged, not panicked on).
#[derive(Debug)]
pub struct ClientHandler {
    /// Connection accepted from the agent; used for request/response traffic.
    stream: TcpStream,
    /// The launched agent process, started with or without cgroup limits
    /// depending on host support.
    process: LimitedProcess,
    // config: Configuration,
}

impl ClientHandler {
    const RESPONSE_TIMEOUT_DURATION: Duration = Duration::from_secs(1);

    /// launch a child process running agent with given constraints.
    ///
    /// Child process is killed on drop. Child process's cgroup is cleaned up on drop.
    #[instrument(skip_all,fields(Agent=agent.name))]
    pub fn init(
        agent: Arc<Agent>,
        resources: &Constraints,
        allow_uncontained: bool,
        debug_process_stderr: bool,
    ) -> anyhow::Result<ClientHandler> {
        assert_eq!(
            resources.total_ram, resources.agent_ram,
            "incorrect ram to launch agent"
        );
        assert_eq!(
            resources.cpus.len(),
            resources.cpus_per_agent,
            "incorrect cpus to launch agents"
        );

        static HAVE_TASKSET: std::sync::LazyLock<bool> =
            std::sync::LazyLock::new(ClientHandler::test_taskset);
        static HAVE_CGROUPS_V2: std::sync::LazyLock<bool> =
            std::sync::LazyLock::new(ClientHandler::test_cgroups);

        // return early if agent has no binary
        let path = agent
            .path_to_exe
            .clone()
            .context("no path to executable")?
            .into_os_string()
            .into_string()
            .map_err(|_| anyhow!("path is not a valid string"))?;

        let listener = TcpListener::bind("127.0.0.1:0")
            .context("server error: could not create TcpListener")?;
        let port_arg = listener.local_addr()?.port().to_string();
        let time_budget_arg = (resources.time_budget.as_micros() as u64).to_string();
        let action_timeout_arg = (resources.action_timeout.as_micros() as u64).to_string();

        let max_memory = resources.total_ram;
        let cpus = resources
            .cpus
            .iter()
            .map(u8::to_string)
            .collect::<Vec<_>>()
            .join(",");

        if !*HAVE_TASKSET && !allow_uncontained {
            bail!(
                "taskset {}unavailable. Consider setting allow_uncontained to true.",
                if *HAVE_CGROUPS_V2 {
                    ""
                } else {
                    "and cgroups v2 "
                }
            );
        }

        if !*HAVE_CGROUPS_V2 && !allow_uncontained {
            bail!("cgroups v2 unavailable. Consider setting allow_uncontained to true");
        }

        let mut full_command = if *HAVE_TASKSET {
            vec![
                "taskset".to_string(),
                "-c".to_string(),
                cpus.clone(),
                path,
                port_arg,
                time_budget_arg,
                action_timeout_arg,
            ]
        } else {
            vec![path, port_arg, time_budget_arg, action_timeout_arg]
        };

        // append agent's arguments (from config file) to the args
        if let Some(args) = &agent.args {
            full_command.extend_from_slice(args);
        }
        let mut full_command = full_command.into_iter();

        let command = full_command.next().unwrap();
        let args = full_command.collect::<Vec<_>>();

        let mut process = if *HAVE_CGROUPS_V2 {
            LimitedProcess::launch(
                &command,
                &args,
                max_memory as i64,
                &cpus,
                debug_process_stderr,
            )
            .context("server error: child + cgroup creation failed")?
        } else {
            LimitedProcess::launch_without_container(&command, &args, debug_process_stderr)?
        };

        listener
            .set_nonblocking(true)
            .context("server error: setting non-blocking to true")?;

        let response_timeout = Instant::now() + Self::RESPONSE_TIMEOUT_DURATION;
        while Instant::now() < response_timeout {
            if let Ok((stream, _addr)) = listener.accept() {
                return Ok(ClientHandler {
                    stream,
                    process,
                    // config,
                });
            }
            // at least 10 tries
            thread::sleep(Duration::from_millis(10).min(Self::RESPONSE_TIMEOUT_DURATION / 10));
        }

        //FIXME: panic
        process.try_kill(Duration::from_secs(1)).unwrap();
        Err(anyhow!("no connection made to server"))
    }

    #[instrument]
    pub fn send_and_recv(
        &mut self,
        msg: &[u8],
        buf: &mut [u8],
        max_duration: Duration,
    ) -> anyhow::Result<usize> {
        self.stream
            .set_nonblocking(true)
            .context("server error: setting non-blocking for 'write'")?;

        match self.stream.write(msg) {
            Ok(0) => {
                return Err(anyhow!("connection closed by client"));
            }
            Ok(n) => {
                if n < msg.len() {
                    error!(
                        "only {}/{} bytes of {} were sent",
                        n,
                        msg.len(),
                        std::str::from_utf8(msg).unwrap_or("NON_VALID_UTF8")
                    );
                    return Err(anyhow!(
                        "msg transmission error: only {}/{} bytes sent",
                        n,
                        msg.len()
                    ));
                }
            }
            Err(e) => {
                return Err(e).context("I/O error while sending msg");
            }
        }
        self.stream
            .set_nonblocking(false)
            .context("server error: setting blocking for 'read'")?;

        self.stream
            .set_read_timeout(Some(max_duration))
            .context("server error: setting read timeout")?;

        let n = self
            .stream
            .read(buf)
            .context("server could not read stream")?;
        Ok(n)
    }

    fn kill_child_process(&mut self) -> anyhow::Result<()> {
        self.process.try_kill(Duration::from_secs(1))
    }

    fn test_taskset() -> bool {
        std::process::Command::new("taskset")
            .arg("-V")
            .output()
            .is_ok()
    }

    #[cfg(unix)]
    fn test_cgroups() -> bool {
        match LimitedProcess::launch("pwd", &[], 1000, "0", false) {
            Ok(mut p) => {
                let _ = p.child.wait();
                let _ = p.try_kill(Duration::from_secs(1));
                true
            }
            Err(_) => false,
        }
    }

    #[cfg(not(unix))]
    fn test_cgroups() -> bool {
        false
    }
}

/// Best-effort teardown: when the handler goes out of scope, try to kill the
/// agent's child process.
///
/// A failure is only logged, together with the child's debug state. Panicking
/// inside `drop` could abort the whole process if it happened during unwinding.
impl Drop for ClientHandler {
    fn drop(&mut self) {
        let Err(e) = self.kill_child_process() else {
            // Child terminated cleanly; nothing to report.
            return;
        };
        error!(
            "POTENTIAL RESOURCE LEAK: COULD NOT KILL PROCESS CHILD: {e:#?},\n {:#?}",
            self.process.child
        );
    }
}