use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use anyhow::{anyhow, bail, Context};
use tracing::{error, instrument};
use crate::agent::Agent;
use crate::cgroup_manager::LimitedProcess;
use crate::constraints::Constraints;
/// Server-side handle on one launched agent process.
///
/// Bundles the TCP connection the process opened back to the server with the
/// handle of the process itself. Dropping the handler attempts to kill the
/// process.
#[derive(Debug)]
pub struct ClientHandler {
/// Connection accepted from the launched process; used to send messages and
/// read replies.
stream: TcpStream,
/// The launched process, started through `LimitedProcess`.
process: LimitedProcess,
}
impl ClientHandler {
/// How long `init` keeps polling for the launched process to connect before
/// giving up; also bounds the polling sleep interval (at most a tenth of it).
const RESPONSE_TIMEOUT_DURATION: Duration = Duration::from_secs(1);
#[instrument(skip_all,fields(Agent=agent.name))]
pub fn init(
agent: Arc<Agent>,
resources: &Constraints,
allow_uncontained: bool,
debug_process_stderr: bool,
) -> anyhow::Result<ClientHandler> {
assert_eq!(
resources.total_ram, resources.agent_ram,
"incorrect ram to launch agent"
);
assert_eq!(
resources.cpus.len(),
resources.cpus_per_agent,
"incorrect cpus to launch agents"
);
static HAVE_TASKSET: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(ClientHandler::test_taskset);
static HAVE_CGROUPS_V2: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(ClientHandler::test_cgroups);
let path = agent
.path_to_exe
.clone()
.context("no path to executable")?
.into_os_string()
.into_string()
.map_err(|_| anyhow!("path is not a valid string"))?;
let listener = TcpListener::bind("127.0.0.1:0")
.context("server error: could not create TcpListener")?;
let port_arg = listener.local_addr()?.port().to_string();
let time_budget_arg = (resources.time_budget.as_micros() as u64).to_string();
let action_timeout_arg = (resources.action_timeout.as_micros() as u64).to_string();
let max_memory = resources.total_ram;
let cpus = resources
.cpus
.iter()
.map(u8::to_string)
.collect::<Vec<_>>()
.join(",");
if !*HAVE_TASKSET && !allow_uncontained {
bail!(
"taskset {}unavailable. Consider setting allow_uncontained to true.",
if *HAVE_CGROUPS_V2 {
""
} else {
"and cgroups v2 "
}
);
}
if !*HAVE_CGROUPS_V2 && !allow_uncontained {
bail!("cgroups v2 unavailable. Consider setting allow_uncontained to true");
}
let mut full_command = if *HAVE_TASKSET {
vec![
"taskset".to_string(),
"-c".to_string(),
cpus.clone(),
path,
port_arg,
time_budget_arg,
action_timeout_arg,
]
} else {
vec![path, port_arg, time_budget_arg, action_timeout_arg]
};
if let Some(args) = &agent.args {
full_command.extend_from_slice(args);
}
let mut full_command = full_command.into_iter();
let command = full_command.next().unwrap();
let args = full_command.collect::<Vec<_>>();
let mut process = if *HAVE_CGROUPS_V2 {
LimitedProcess::launch(
&command,
&args,
max_memory as i64,
&cpus,
debug_process_stderr,
)
.context("server error: child + cgroup creation failed")?
} else {
LimitedProcess::launch_without_container(&command, &args, debug_process_stderr)?
};
listener
.set_nonblocking(true)
.context("server error: setting non-blocking to true")?;
let response_timeout = Instant::now() + Self::RESPONSE_TIMEOUT_DURATION;
while Instant::now() < response_timeout {
if let Ok((stream, _addr)) = listener.accept() {
return Ok(ClientHandler {
stream,
process,
});
}
thread::sleep(Duration::from_millis(10).min(Self::RESPONSE_TIMEOUT_DURATION / 10));
}
process.try_kill(Duration::from_secs(1)).unwrap();
Err(anyhow!("no connection made to server"))
}
#[instrument]
pub fn send_and_recv(
&mut self,
msg: &[u8],
buf: &mut [u8],
max_duration: Duration,
) -> anyhow::Result<usize> {
self.stream
.set_nonblocking(true)
.context("server error: setting non-blocking for 'write'")?;
match self.stream.write(msg) {
Ok(0) => {
return Err(anyhow!("connection closed by client"));
}
Ok(n) => {
if n < msg.len() {
error!(
"only {}/{} bytes of {} were sent",
n,
msg.len(),
std::str::from_utf8(msg).unwrap_or("NON_VALID_UTF8")
);
return Err(anyhow!(
"msg transmission error: only {}/{} bytes sent",
n,
msg.len()
));
}
}
Err(e) => {
return Err(e).context("I/O error while sending msg");
}
}
self.stream
.set_nonblocking(false)
.context("server error: setting blocking for 'read'")?;
self.stream
.set_read_timeout(Some(max_duration))
.context("server error: setting read timeout")?;
let n = self
.stream
.read(buf)
.context("server could not read stream")?;
Ok(n)
}
/// Asks the wrapped process to terminate by calling `try_kill` with a
/// one-second duration, forwarding its result unchanged.
fn kill_child_process(&mut self) -> anyhow::Result<()> {
    // NOTE(review): presumably the time `try_kill` is allowed to take — confirm.
    let one_second = Duration::from_secs(1);
    self.process.try_kill(one_second)
}
/// Reports whether a `taskset` executable can be run on this machine.
///
/// Runs `taskset -V` and returns `true` when the program could be started and
/// its output collected; the exit status is not inspected.
fn test_taskset() -> bool {
    let mut probe = std::process::Command::new("taskset");
    probe.arg("-V");
    matches!(probe.output(), Ok(_))
}
/// Reports whether a process can be started through `LimitedProcess::launch`.
///
/// Launches `pwd` as a throwaway process; if that succeeds the process is
/// waited on and then killed, ignoring the outcome of both steps.
#[cfg(unix)]
fn test_cgroups() -> bool {
    let Ok(mut probe) = LimitedProcess::launch("pwd", &[], 1000, "0", false) else {
        return false;
    };
    // Clean-up is best effort: the launch succeeding is all that matters here.
    let _ = probe.child.wait();
    let _ = probe.try_kill(Duration::from_secs(1));
    true
}
/// On non-unix targets the probe always reports `false`.
#[cfg(not(unix))]
fn test_cgroups() -> bool {
    false
}
}
impl Drop for ClientHandler {
    /// Attempts to kill the child process when the handler goes away.
    ///
    /// `drop` cannot return an error, so a failed kill is reported through the
    /// log together with the child handle.
    fn drop(&mut self) {
        let Err(e) = self.kill_child_process() else {
            // Process was killed; nothing to report.
            return;
        };
        error!(
            "POTENTIAL RESOURCE LEAK: COULD NOT KILL PROCESS CHILD: {e:#?},\n {:#?}",
            self.process.child
        );
    }
}