use std::{
fs::File,
process::Child,
time::{Duration, Instant},
};
use anyhow::{self, Context};
use cgroups_rs::Cgroup;
use super::create_process;
/// Returns the numeric id of the current user, as printed by `id -u`,
/// with surrounding whitespace trimmed.
///
/// # Errors
/// Fails if `id` cannot be launched, exits unsuccessfully, prints non-UTF-8
/// output, or prints an empty id.
pub fn get_current_user_id() -> anyhow::Result<String> {
    let output = std::process::Command::new("id")
        .arg("-u")
        .output()
        .context("Could not launch 'id -u'")?;
    // `output()` only fails if the command cannot be spawned; a failing `id`
    // still returns Ok with empty stdout, which would silently yield a bogus
    // cgroup path. Check the exit status explicitly.
    anyhow::ensure!(
        output.status.success(),
        "'id -u' failed with {}: {}",
        output.status,
        String::from_utf8_lossy(&output.stderr).trim()
    );
    let untrimmed_id =
        std::str::from_utf8(&output.stdout).context("id is not a valid string")?;
    let id = untrimmed_id.trim();
    anyhow::ensure!(!id.is_empty(), "'id -u' printed an empty user id");
    Ok(id.to_string())
}
/// Builds the cgroup path (relative to the hierarchy root) for `group_name`,
/// nested as `user.slice/user-<uid>.slice/user@<uid>.service/<group_name>`
/// (presumably the systemd per-user manager's subtree, which an unprivileged
/// user can write to — confirm against the deployment).
pub fn get_cgroup_path(user_id: &str, group_name: &str) -> String {
    let user_slice = format!("user-{user_id}.slice");
    let user_service = format!("user@{user_id}.service");
    [
        "user.slice",
        user_slice.as_str(),
        user_service.as_str(),
        group_name,
    ]
    .join("/")
}
/// Builds the cgroup at `path` on the hierarchy chosen by
/// `cgroups_rs::hierarchies::auto()`, applying only the limits that are set:
///
/// * `max_memory` — memory hard limit, applied when strictly positive;
/// * `max_pids` — maximum number of processes, applied when strictly positive;
/// * `cpus` — cpuset list (e.g. `"1-3,5"`), applied when non-empty.
///
/// # Errors
/// Fails if the cgroup cannot be built on the hierarchy.
pub fn create_cgroup(
    path: &str,
    max_memory: i64,
    max_pids: i64,
    cpus: &str,
) -> anyhow::Result<cgroups_rs::Cgroup> {
    let builder = cgroups_rs::cgroup_builder::CgroupBuilder::new(path);
    let builder = if max_memory > 0 {
        builder.memory().memory_hard_limit(max_memory).done()
    } else {
        builder
    };
    let builder = if max_pids > 0 {
        let pid_limit = cgroups_rs::MaxValue::Value(max_pids);
        builder.pid().maximum_number_of_processes(pid_limit).done()
    } else {
        builder
    };
    let builder = if cpus.is_empty() {
        builder
    } else {
        builder.cpu().cpus(cpus.to_owned()).done()
    };
    let hierarchy = cgroups_rs::hierarchies::auto();
    builder.build(hierarchy).context("could not create cgroup")
}
/// Error returned by `wait_for_process_cleanup` when the watched process is
/// still listed in its cgroup once the allotted time has elapsed.
#[derive(Debug)]
pub struct TimeoutError {}

impl std::error::Error for TimeoutError {}

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Timeout Error")
    }
}
/// Blocks until process `pid` is no longer listed in `cgroup.tasks()`.
///
/// Polls every 10 ms, or every tenth of `max_duration` when that is shorter.
/// A `max_duration` too large to be turned into a deadline (e.g.
/// `Duration::MAX`) means "wait without a time limit".
///
/// # Errors
/// Returns [`TimeoutError`] if `pid` is still listed after `max_duration`.
pub fn wait_for_process_cleanup(
    cgroup: &cgroups_rs::Cgroup,
    pid: u64,
    max_duration: Duration,
) -> Result<(), TimeoutError> {
    // `Instant + Duration` panics on overflow; an unrepresentable deadline is
    // treated as "no deadline" instead of crashing the caller.
    let deadline = Instant::now().checked_add(max_duration);
    let poll_interval = std::cmp::min(Duration::from_millis(10), max_duration / 10);
    while cgroup.tasks().iter().any(|task| task.pid == pid) {
        if matches!(deadline, Some(limit) if Instant::now() > limit) {
            return Err(TimeoutError {});
        }
        std::thread::sleep(poll_interval);
    }
    Ok(())
}
/// Spawns `command` with `args` via `create_process` and moves the new
/// process into `group`.
///
/// If the process cannot be added to the cgroup, it is killed and reaped (so
/// neither an unconstrained process nor a zombie is left behind) and an error
/// is returned.
///
/// NOTE(review): the child is spawned before it is added to `group`, so
/// anything it does (including forking) in that window is not constrained by
/// the cgroup.
///
/// # Errors
/// Fails if the process cannot be created or cannot be added to `group`.
pub fn create_process_in_cgroup(
    command: &str,
    args: &[String],
    group: &cgroups_rs::Cgroup,
    allow_stderr: bool,
    log_file: &Option<File>,
) -> anyhow::Result<std::process::Child> {
    let mut child = create_process(command, args, allow_stderr, log_file)?;
    let pid = u64::from(child.id());
    let addition = group.add_task_by_tgid(cgroups_rs::CgroupPid { pid });
    if addition.is_err() {
        // Kill, then reap: without `wait` the killed child would remain a zombie.
        let cleanup = child.kill().and_then(|()| child.wait().map(|_status| ()));
        addition.with_context(|| match cleanup {
            Ok(()) => "could not add process to cgroup".to_string(),
            Err(err) => format!(
                "could not add process to cgroup, and process could not be killed either ({err})"
            ),
        })?;
    }
    Ok(child)
}
/// A spawned child process, optionally confined to its own cgroup.
///
/// Unless `try_kill` has already succeeded, dropping it kills the process
/// (and, when present, the processes in its cgroup).
#[derive(Debug)]
pub struct LimitedProcess {
    /// The spawned process.
    pub child: Child,
    /// The cgroup the process was placed in; `None` when started through
    /// `launch_without_container`.
    cgroup: Option<Cgroup>,
    /// Set once `try_kill` succeeded, so that `Drop` does not kill again.
    cleaned_up: bool,
}
impl LimitedProcess {
pub fn launch(
command: &str,
args: &[String],
max_memory: i64,
cpus: &str,
allow_stderr: bool,
log_file: &Option<File>,
) -> anyhow::Result<LimitedProcess> {
static COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(1); let user_id = get_current_user_id().context("could not get user id")?;
let group_name = "CGROUP_MANAGER_".to_owned()
+ &COUNTER
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
.to_string();
let path = get_cgroup_path(&user_id, &group_name);
let group =
create_cgroup(&path, max_memory, 100, cpus).context("could not create cgroup")?;
let child = create_process_in_cgroup(command, args, &group, allow_stderr, log_file)
.with_context(|| {
let _ = group.delete();
"could not create process in cgroup"
})?;
Ok(LimitedProcess {
child,
cgroup: Some(group),
cleaned_up: false,
})
}
pub fn try_kill(&mut self, max_duration: Duration) -> anyhow::Result<()> {
match &mut self.cgroup {
Some(cgroup) => {
self.child.kill().context("could not kill child process")?; cgroup.kill().context("could not kill process")?;
wait_for_process_cleanup(cgroup, self.child.id() as u64, max_duration)
.context("process cleanup timed out")?;
self.cleaned_up = true;
if let Err(e) = cgroup.delete() {
tracing::warn!("Failed to remove cgroup. If this happens a lot, it may slow down the computer. {e}");
}
Ok(())
}
None => {
self.child.kill().context("could not kill process")?;
self.cleaned_up = true;
Ok(())
}
}
}
pub fn launch_without_container(
command: &str,
args: &[String],
allow_stderr: bool,
log_file: &Option<File>,
) -> anyhow::Result<LimitedProcess> {
let child = create_process(command, args, allow_stderr, log_file)
.context("could not create process")?;
Ok(LimitedProcess {
child,
cgroup: None,
cleaned_up: false,
})
}
#[allow(dead_code)]
pub(crate) fn try_debug_cgroup(&mut self) {
let pid = self.child.id();
let mut p = String::new();
p += "/sys/fs/cgroup/";
p += self.cgroup.as_ref().unwrap().path();
println!("Path: {p:?}");
Self::exec(&format!("lsof +D {p}"));
Self::exec(&format!("cat {p}/cgroup.procs"));
Self::exec(&format!("cat {p}/cgroup.stat"));
Self::exec(&format!("cat {p}/pids.current"));
Self::exec(&format!("ps -Flww -p {pid}"));
Self::exec(&format!("cat /proc/{pid}/status"));
if let Err(e) = self.try_kill(Duration::from_millis(100)) {
println!("failed to kill again: {e:#}");
} else {
println!("successfully killed this time ??");
}
Self::exec(&format!("rmdir {p}"));
}
#[allow(dead_code)]
fn exec(cmd: &str) {
let mut iter = cmd.split(" ");
let program = iter.next().unwrap();
let args = iter.collect::<Vec<_>>();
let output = std::process::Command::new(program)
.args(&args)
.output()
.unwrap();
println!(
"$ {cmd}\n\x1b[31m{}\x1b[39m{}",
std::str::from_utf8(&output.stderr).unwrap(),
std::str::from_utf8(&output.stdout).unwrap()
);
}
}
impl Drop for LimitedProcess {
fn drop(&mut self) {
static CLEANUP_DURATION: Duration = Duration::from_secs(1);
if !self.cleaned_up {
match self.try_kill(CLEANUP_DURATION) {
Ok(_) => { }
Err(e) => {
if std::env::var("DEBUG_CGROUP").is_ok() {
self.try_debug_cgroup();
}
panic!("could not kill process/cgroup on LimitedProcess::drop: {e}");
}
}
}
}
}
#[cfg(test)]
mod cgroup_manager_tests {
    use std::{io::Read, process::Stdio, time::Duration};

    use super::*;

    #[test]
    fn launch_something() {
        use std::process;
        let mut proc = process::Command::new("echo")
            .args(["Hello", "World"])
            .stdout(Stdio::piped())
            .spawn()
            .expect("Could not spawn child");
        let mut res = proc.stdout.take().expect("No result ?");
        let mut buffer = String::new();
        res.read_to_string(&mut buffer)
            .expect("Could not make a string ?");
        // Reap the child so the test does not leave a zombie behind.
        proc.wait().expect("Could not wait for child");
        println!("{buffer}");
        assert_eq!(buffer, "Hello World\n");
    }

    #[test]
    fn test_create_cgroup() {
        assert_eq!(
            std::env::consts::OS,
            "linux",
            "Cgroups are only implemented on linux."
        );
        let my_hierarchy = cgroups_rs::hierarchies::auto();
        if my_hierarchy.v2() {
            println!("V2 Hierarchy");
        } else {
            println!("V1 Hierarchy /!\\ THIS CASE IS UNTESTED");
        }
        let my_id = get_current_user_id().expect("Could not get user ID");
        println!("User id: {my_id}");
        let group_name = "my_cgroup";
        let new_group_path = get_cgroup_path(&my_id, group_name);
        println!("Future new group path: {new_group_path}");
        let my_group = create_cgroup(&new_group_path, 1024 * 1024, 3, "1-3,5")
            .expect("Could not create cgroup...");
        println!("path: {}", my_group.path());
        my_group.delete().expect("Could not delete cgroup")
    }

    #[test]
    fn test_create_process_in_cgroup() {
        let id = get_current_user_id().unwrap();
        let path = get_cgroup_path(&id, "rust_group");
        let group = create_cgroup(&path, 1024 * 1024, 0, "").unwrap();
        println!("Cgroup created");
        match std::process::Command::new("sleep").arg("10").spawn() {
            Ok(mut child) => {
                let pid = u64::from(child.id());
                println!("Process {pid} created");
                match group.add_task_by_tgid(cgroups_rs::CgroupPid { pid }) {
                    Err(e) => println!("Could not add task to cgroup: {e}"),
                    Ok(_) => {
                        println!("Task added to cgroup");
                        println!("Attempting to kill process");
                        group.kill().unwrap_or_else(|e| {
                            println!("Could not kill cgroup, killing the process directly. Error: {e}");
                        });
                    }
                }
                // Whatever happened above, make sure `sleep` is dead and reaped, so it
                // neither outlives the test nor keeps the cgroup from being deleted.
                let _ = child.kill();
                child.wait().expect("Could not reap child process");
                wait_for_process_cleanup(&group, pid, Duration::from_millis(100))
                    .unwrap_or_else(|e| println!("Process cleanup did not end well: {e}"));
            }
            Err(error) => println!("Process creation failed: {error}"),
        }
        println!("Deleting cgroup.");
        group.delete().unwrap_or_else(|e| {
            println!("Could not delete cgroup ! Is there any descendant left ? ({e})");
            let procs = group.tasks();
            println!("PIDS: {:?}", procs);
        });
    }
}