use tokio::process::Command;
use crate::ExecError;
/// CPU bandwidth setting that is written to a cgroup's `cpu.max` file.
///
/// The file content is `"<quota> <period>"`; when `quota` is `None` the
/// literal `max` is written in its place.
///
/// NOTE(review): the cgroup v2 documentation describes both values as
/// microseconds — confirm against the kernel docs before relying on units.
#[derive(Debug, Clone, Copy)]
pub struct CpuMax {
    /// Quota part of `cpu.max`; `None` is emitted as `max`.
    pub quota: Option<u64>,
    /// Period part of `cpu.max`.
    pub period: u64,
}

impl Default for CpuMax {
    /// Returns a setting with no quota and a period of `100_000`.
    fn default() -> Self {
        const DEFAULT_PERIOD: u64 = 100_000;
        CpuMax {
            quota: None,
            period: DEFAULT_PERIOD,
        }
    }
}
/// Resource limits to apply to a spawned process through a dedicated cgroup.
///
/// Every limit is optional; a value of `None` leaves the corresponding
/// cgroup file untouched.
#[derive(Debug, Clone, Default)]
pub struct CgroupLimits {
    /// CPU bandwidth limit, written to `cpu.max`.
    pub cpu: Option<CpuMax>,
    /// Value written to `memory.max`.
    pub memory: Option<u64>,
    /// Value written to `pids.max`.
    pub pids: Option<u64>,
    /// When `true`, a missing cgroup v2 hierarchy or a failure to join the
    /// cgroup is reported as an error instead of being skipped.
    pub fail_on_error: bool,
}

impl CgroupLimits {
    /// Returns `true` when none of `cpu`, `memory` or `pids` is set.
    ///
    /// `fail_on_error` is not a limit and does not influence the result.
    #[inline]
    pub fn is_empty(&self) -> bool {
        matches!(
            (&self.cpu, &self.memory, &self.pids),
            (None, None, None)
        )
    }
}
/// Creates and configures the cgroup `cgroup_name` when any limit is requested.
///
/// Returns `Ok(false)` without touching the system when `limits` is empty.
/// On Linux the work is delegated to `linux_impl::prepare` and its result is
/// returned unchanged. On every other target a warning is logged and
/// `Ok(false)` is returned, i.e. the limits are ignored.
pub(crate) fn prepare_cgroup(cgroup_name: &str, limits: &CgroupLimits) -> Result<bool, ExecError> {
    if !limits.is_empty() {
        #[cfg(target_os = "linux")]
        {
            return linux_impl::prepare(cgroup_name, limits);
        }
        #[cfg(not(target_os = "linux"))]
        {
            tracing::warn!(
                "cgroup v2 limits requested for '{}', but OS={} does not support them; limits will be ignored",
                cgroup_name,
                std::env::consts::OS
            );
        }
    }
    Ok(false)
}
/// Arranges for the child spawned from `cmd` to join the cgroup `cgroup_name`.
///
/// Does nothing when `limits` is empty. On Linux a join hook is registered on
/// `cmd` through `linux_impl::attach_join_hook`, passing along
/// `limits.fail_on_error`. On other targets the arguments are ignored.
///
/// The function currently always returns `Ok(())`.
pub(crate) fn attach_cgroup(
    cmd: &mut Command,
    cgroup_name: &str,
    limits: &CgroupLimits,
) -> Result<(), ExecError> {
    if !limits.is_empty() {
        #[cfg(target_os = "linux")]
        {
            linux_impl::attach_join_hook(cmd, cgroup_name, limits.fail_on_error);
        }
        #[cfg(not(target_os = "linux"))]
        {
            // Nothing to attach here; touch the arguments to avoid unused warnings.
            let _ = (&cmd, cgroup_name, limits);
        }
    }
    Ok(())
}
/// Removes the directory `/sys/fs/cgroup/<cgroup_name>` on a best-effort basis.
///
/// Nothing is returned or propagated: success is logged at `debug`, a
/// directory that no longer exists at `trace`, and any other failure at
/// `debug` together with the raw OS error code.
#[cfg(target_os = "linux")]
pub fn cleanup_cgroup(cgroup_name: &str) {
    let target = std::path::Path::new("/sys/fs/cgroup").join(cgroup_name);

    let Err(err) = std::fs::remove_dir(&target) else {
        tracing::debug!("removed cgroup: {}", cgroup_name);
        return;
    };

    if err.kind() == std::io::ErrorKind::NotFound {
        tracing::trace!("cgroup '{}' not found (already removed)", cgroup_name);
    } else {
        tracing::debug!(
            "cgroup '{}' cleanup skipped: {} (errno={:?})",
            cgroup_name,
            err,
            err.raw_os_error(),
        );
    }
}

/// No-op counterpart of the Linux implementation for all other targets.
#[cfg(not(target_os = "linux"))]
pub fn cleanup_cgroup(_cgroup_name: &str) {}
/// Builds a cgroup name of the form `<runner_tag>-<slot>-<seq>-<timestamp>`.
///
/// `seq` and `timestamp` are rendered as lowercase hexadecimal without
/// padding; `runner_tag` and `slot` are inserted verbatim (dashes included).
pub fn build_cgroup_name(runner_tag: &str, slot: &str, seq: u64, timestamp: u64) -> String {
    format!("{runner_tag}-{slot}-{seq:x}-{timestamp:x}")
}
#[cfg(target_os = "linux")]
mod linux_impl {
use super::{CgroupLimits, CpuMax};
use crate::utils::log::{pre_exec_log, pre_exec_log_errno};
use std::{
fs, io,
path::{Path, PathBuf},
};
use tokio::process::Command;
/// File looked up directly under the cgroup root; its presence as a regular
/// file is what `is_cgroup_v2` treats as "cgroup v2 is available".
const CONTROLLERS_FILE: &str = "cgroup.controllers";
/// Directory under which the per-task cgroup directories are created.
const CGROUP_ROOT: &str = "/sys/fs/cgroup";
/// Suffix (leading slash included) appended to a cgroup directory path to
/// reach the file the child writes its PID into.
const CGROUP_PROCS_SUFFIX: &str = "/cgroup.procs";
pub fn prepare(cgroup_name: &str, limits: &CgroupLimits) -> Result<bool, crate::ExecError> {
if !is_cgroup_v2(Path::new(CGROUP_ROOT)) {
tracing::warn!("cgroup v2 not detected at /sys/fs/cgroup; limits will be ignored");
return if limits.fail_on_error {
Err(crate::ExecError::InvalidRunnerConfig(
"cgroup v2 not available".into(),
))
} else {
Ok(false)
};
}
let cg_dir = Path::new(CGROUP_ROOT).join(cgroup_name);
fs::create_dir_all(&cg_dir).map_err(|e| {
crate::ExecError::Io(io::Error::other(format!(
"failed to create cgroup directory '{}': {e}",
cg_dir.display()
)))
})?;
apply_limits(&cg_dir, limits).map_err(|e| {
crate::ExecError::Io(io::Error::other(format!(
"failed to apply cgroup limits for '{}': {e}",
cg_dir.display()
)))
})?;
Ok(true)
}
/// Capacity, in bytes, of the fixed buffer holding a `cgroup.procs` path,
/// trailing NUL included.
const MAX_PROCS_PATH: usize = 256;

/// NUL-terminated `cgroup.procs` path stored inline in a fixed-size buffer.
///
/// The type is `Copy` and owns no heap memory, so it can be moved into the
/// closure registered by `attach_join_hook`.
#[derive(Clone, Copy)]
struct ProcsPath {
    buf: [u8; MAX_PROCS_PATH],
    len: usize,
}

impl ProcsPath {
    /// Assembles `<CGROUP_ROOT>/<cgroup_name><CGROUP_PROCS_SUFFIX>\0`.
    ///
    /// Returns `None` when the result, terminator included, would not fit
    /// into `MAX_PROCS_PATH` bytes.
    fn build(cgroup_name: &str) -> Option<Self> {
        let segments: [&[u8]; 5] = [
            CGROUP_ROOT.as_bytes(),
            b"/",
            cgroup_name.as_bytes(),
            CGROUP_PROCS_SUFFIX.as_bytes(),
            b"\0",
        ];

        let needed: usize = segments.iter().map(|segment| segment.len()).sum();
        if needed > MAX_PROCS_PATH {
            return None;
        }

        let mut buf = [0u8; MAX_PROCS_PATH];
        let mut cursor = 0usize;
        for segment in segments {
            let end = cursor + segment.len();
            buf[cursor..end].copy_from_slice(segment);
            cursor = end;
        }

        Some(Self { buf, len: cursor })
    }

    /// Returns the path bytes, including the trailing NUL.
    fn as_bytes(&self) -> &[u8] {
        &self.buf[..self.len]
    }
}
pub fn attach_join_hook(cmd: &mut Command, cgroup_name: &str, fail_on_error: bool) {
let procs_path = match ProcsPath::build(cgroup_name) {
Some(p) => p,
None => {
pre_exec_log(b"solti-exec: cgroup path exceeds 256 bytes, skipping join\n");
return;
}
};
unsafe {
cmd.pre_exec(move || join_cgroup_raw(procs_path.as_bytes(), fail_on_error));
}
}
/// Reports whether `root` contains a regular file named `CONTROLLERS_FILE`.
///
/// Any error while inspecting the path (missing file, permission problem, …)
/// is treated as "not cgroup v2".
fn is_cgroup_v2(root: &Path) -> bool {
    let controllers = root.join(CONTROLLERS_FILE);
    fs::metadata(controllers)
        .map(|meta| meta.is_file())
        .unwrap_or(false)
}
/// Writes every limit that is set in `limits` into its file under `dir`.
///
/// Files are written in the order `cpu.max`, `memory.max`, `pids.max`; the
/// first write error aborts the sequence and is returned.
fn apply_limits(dir: &Path, limits: &CgroupLimits) -> io::Result<()> {
    if let Some(cpu) = limits.cpu {
        write_cpu_max(dir.join("cpu.max"), cpu)?;
    }

    let scalar_limits = [("memory.max", limits.memory), ("pids.max", limits.pids)];
    for (file_name, requested) in scalar_limits {
        if let Some(value) = requested {
            write_limit(dir.join(file_name), value)?;
        }
    }

    Ok(())
}
/// Writes `limit` to `path` as `"<quota> <period>\n"`.
///
/// A missing quota is written as the literal `max`.
fn write_cpu_max(path: PathBuf, limit: CpuMax) -> io::Result<()> {
    let quota = limit
        .quota
        .map_or_else(|| String::from("max"), |q| q.to_string());
    fs::write(path, format!("{quota} {}\n", limit.period))
}
/// Writes `val` in decimal, followed by a newline, to `path`.
fn write_limit(path: PathBuf, val: u64) -> io::Result<()> {
    let mut line = val.to_string();
    line.push('\n');
    fs::write(path, line)
}
/// Writes the calling process's PID into the `cgroup.procs` file at
/// `procs_path_cstr`, using raw `libc` calls only.
///
/// `procs_path_cstr` is handed to `libc::open` as a C string, so it must end
/// with a NUL byte (`ProcsPath::build` appends one).
///
/// Failures to open or write the file are logged through `pre_exec_log` /
/// `pre_exec_log_errno`. They are returned as `Err` only when `fail_on_error`
/// is `true`; otherwise the function reports `Ok(())` and the process simply
/// stays outside the cgroup.
fn join_cgroup_raw(procs_path_cstr: &[u8], fail_on_error: bool) -> io::Result<()> {
// SAFETY: the pointer refers to the caller's slice, which stays alive for the
// duration of the call; NUL termination is the caller's responsibility.
let fd = unsafe {
libc::open(
procs_path_cstr.as_ptr() as *const libc::c_char,
libc::O_WRONLY,
)
};
if fd < 0 {
// Capture errno immediately, before any other call can overwrite it.
let e = io::Error::last_os_error();
pre_exec_log(b"solti-exec: failed to open cgroup.procs: ");
if let Some(code) = e.raw_os_error() {
pre_exec_log_errno(code);
}
return if fail_on_error { Err(e) } else { Ok(()) };
}
// SAFETY: `getpid` takes no arguments and has no preconditions.
let pid = unsafe { libc::getpid() };
// Stack buffer: the PID is formatted without heap allocation.
let mut buf = [0u8; 24];
let pid_str = super::format_pid(pid, &mut buf);
// SAFETY: `fd` was returned by a successful `open`; the pointer/length pair
// describes `pid_str`, which borrows the live stack buffer above.
let written =
unsafe { libc::write(fd, pid_str.as_ptr() as *const libc::c_void, pid_str.len()) };
// Record the write error before `close`, which could otherwise clobber errno.
// NOTE(review): a short write (0 <= written < len) is treated as success.
let write_err = if written < 0 {
Some(io::Error::last_os_error())
} else {
None
};
// SAFETY: `fd` is a descriptor opened above and not used afterwards.
// The result of `close` is intentionally not inspected.
unsafe { libc::close(fd) };
if let Some(e) = write_err {
pre_exec_log(b"solti-exec: failed to write PID to cgroup.procs: ");
if let Some(code) = e.raw_os_error() {
pre_exec_log_errno(code);
}
return if fail_on_error { Err(e) } else { Ok(()) };
}
Ok(())
}
}
/// Formats `pid` as ASCII decimal followed by `\n`, without allocating.
///
/// The text is written right-aligned into `buf` and the returned slice covers
/// exactly the digits plus the trailing newline. `pid` is reinterpreted as
/// `u32`, so a negative value is printed as its wrapped unsigned form.
#[cfg_attr(not(target_os = "linux"), allow(dead_code))]
fn format_pid(pid: i32, buf: &mut [u8; 24]) -> &[u8] {
    let newline_at = buf.len() - 1;
    buf[newline_at] = b'\n';

    let mut remaining = pid as u32;
    let mut start = newline_at;
    // Emit digits from least to most significant; running the body at least
    // once covers `0` without a separate branch.
    loop {
        start -= 1;
        buf[start] = b'0' + (remaining % 10) as u8;
        remaining /= 10;
        if remaining == 0 {
            break;
        }
    }

    &buf[start..]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Timestamp shared by the tests that only need a plausible value.
    const SAMPLE_TS: u64 = 1733045913;

    /// Runs `format_pid` on a scratch buffer and returns the text it produced.
    fn fmt_pid(pid: i32) -> String {
        let mut scratch = [0u8; 24];
        String::from_utf8_lossy(format_pid(pid, &mut scratch)).into_owned()
    }

    // Empty limits: attaching must succeed without doing anything.
    #[test]
    fn empty_limits_are_noop() {
        let limits = CgroupLimits::default();
        assert!(limits.is_empty());

        let mut cmd = Command::new("sh");
        assert!(attach_cgroup(&mut cmd, "test-cgroup", &limits).is_ok());
    }

    // Name layout: tag, slot, then hex sequence number and hex timestamp.
    #[test]
    fn build_cgroup_name_simple_case() {
        let name = build_cgroup_name("runner", "slot", 42, 1000);
        assert_eq!(name, "runner-slot-2a-3e8");

        let parts: Vec<&str> = name.split('-').collect();
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0], "runner");
        assert_eq!(parts[1], "slot");
        assert_eq!(u64::from_str_radix(parts[2], 16).unwrap(), 42);
        assert_eq!(u64::from_str_radix(parts[3], 16).unwrap(), 1000);
    }

    // Dashes inside the tag and slot are passed through untouched.
    #[test]
    fn build_cgroup_name_with_dashes() {
        let name = build_cgroup_name("prod-runner", "demo-task", 42, SAMPLE_TS);
        let expected_suffix = format!("-{:x}", SAMPLE_TS);

        assert!(name.starts_with("prod-runner-"));
        assert!(name.contains("-demo-task-"));
        assert!(name.contains("-2a-"));
        assert!(name.ends_with(&expected_suffix));
    }

    #[test]
    fn build_cgroup_name_hex_values() {
        let cases = [
            (0u64, 0u64, "r-s-0-0"),
            (255, 255, "r-s-ff-ff"),
            (4096, 65536, "r-s-1000-10000"),
        ];
        for (seq, timestamp, expected) in cases {
            assert_eq!(build_cgroup_name("r", "s", seq, timestamp), expected);
        }
    }

    // Registering the join hook must not fail, whatever happens at spawn time.
    #[cfg(target_os = "linux")]
    #[test]
    fn attach_with_limits_does_not_error() {
        let limits = CgroupLimits {
            cpu: Some(CpuMax::default()),
            memory: Some(128 * 1024 * 1024),
            pids: Some(32),
            ..Default::default()
        };
        let name = build_cgroup_name("test", "slot", 1, SAMPLE_TS);

        let mut cmd = Command::new("true");
        assert!(attach_cgroup(&mut cmd, &name, &limits).is_ok());
    }

    #[cfg(not(target_os = "linux"))]
    #[test]
    fn non_linux_platforms_ignore_limits() {
        let limits = CgroupLimits {
            cpu: Some(CpuMax::default()),
            memory: Some(1),
            pids: Some(1),
            ..Default::default()
        };

        let mut cmd = Command::new("true");
        let outcome = attach_cgroup(&mut cmd, "test-cgroup", &limits);
        assert!(
            outcome.is_ok(),
            "non-Linux must ignore limits but still return Ok"
        );
    }

    // Removing a cgroup that was never created is a silent no-op.
    #[cfg(target_os = "linux")]
    #[test]
    fn cleanup_nonexistent_cgroup_does_not_panic() {
        let name = build_cgroup_name("test", "nonexistent", 999, SAMPLE_TS);
        cleanup_cgroup(&name);
    }

    #[test]
    fn format_pid_one() {
        assert_eq!(fmt_pid(1), "1\n");
    }

    #[test]
    fn format_pid_single_digit() {
        assert_eq!(fmt_pid(9), "9\n");
    }

    #[test]
    fn format_pid_two_digits() {
        for (pid, expected) in [(10, "10\n"), (99, "99\n")] {
            assert_eq!(fmt_pid(pid), expected);
        }
    }

    #[test]
    fn format_pid_typical() {
        assert_eq!(fmt_pid(32768), "32768\n");
    }

    #[test]
    fn format_pid_large() {
        assert_eq!(fmt_pid(4_194_304), "4194304\n");
    }

    #[test]
    fn format_pid_zero() {
        assert_eq!(fmt_pid(0), "0\n");
    }
}