use std::ffi::{CStr, CString};
use std::io;
use std::os::unix::ffi::OsStringExt;
use std::os::unix::process::CommandExt;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::process::{Child, Command};
use tokio::time::{Instant, sleep};
use crate::Mechanism;
#[cfg(feature = "process-control")]
use crate::Signal;
#[cfg(feature = "limits")]
use crate::limits::ResourceLimits;
#[cfg(feature = "stats")]
use crate::stats::ProcessGroupStats;
#[cfg(feature = "stats")]
use crate::sys::ProcMetrics;
use crate::sys::pgroup::ProcessGroup;
/// Pause between emptiness checks while `graceful_shutdown` waits for a
/// job's cgroup to drain.
const POLL_INTERVAL: Duration = Duration::from_micros(20_000);
/// Per-process counter that, together with the pid, gives every cgroup
/// directory created by this process a distinct name.
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
/// Kernel mechanism that keeps a [`Job`]'s processes together.
enum Backend {
    /// A cgroup v2 directory created for, and owned by, this job.
    Cgroup(Cgroup),
    /// Fallback used when `Job::new` could not create a cgroup.
    ProcessGroup(ProcessGroup),
}

/// A set of processes that are spawned, signalled and torn down as one unit.
pub(crate) struct Job {
    /// Containment mechanism chosen once, when the job was created.
    backend: Backend,
}
impl Job {
    /// Creates an empty job, preferring a dedicated cgroup v2 and falling back
    /// to a process group when the cgroup cannot be created.
    ///
    /// With the `limits` feature, limits are only ever applied through the
    /// cgroup. So when any limit is requested, the cgroup creation error is
    /// returned instead of silently falling back without limits.
    pub(crate) fn new(#[cfg(feature = "limits")] limits: &ResourceLimits) -> io::Result<Self> {
        let backend = match Cgroup::create(
            #[cfg(feature = "limits")]
            limits,
        ) {
            Ok(cg) => Backend::Cgroup(cg),
            // `_e` is only read when the `limits` feature is enabled.
            Err(_e) => {
                #[cfg(feature = "limits")]
                if limits.any() {
                    return Err(_e);
                }
                Backend::ProcessGroup(ProcessGroup::new())
            }
        };
        Ok(Job { backend })
    }

    /// Spawns `cmd` as a member of this job.
    ///
    /// Cgroup backend: a `pre_exec` hook makes the child write its own pid
    /// into the job's `cgroup.procs` before `exec`. The child, and everything
    /// it forks afterwards, therefore starts inside the cgroup.
    ///
    /// Process-group backend: spawning is delegated to `ProcessGroup::spawn`.
    ///
    /// In both cases, when `opts.kill_on_parent_death` is set, an extra hook
    /// arms `PR_SET_PDEATHSIG` via `arm_pdeathsig`.
    ///
    /// # Errors
    /// `InvalidInput` if the cgroup path contains a NUL byte; otherwise any
    /// spawn error, including an error returned by a `pre_exec` hook.
    pub(crate) fn spawn(
        &self,
        cmd: &mut Command,
        opts: &crate::sys::SpawnOptions,
    ) -> io::Result<Child> {
        let arm = |cmd: &mut Command| {
            if opts.kill_on_parent_death {
                // Captured in the parent; the child compares it with getppid().
                let spawner_pid = std::process::id();
                // SAFETY: `arm_pdeathsig` only issues raw syscalls
                // (prctl/getppid/_exit) and does not allocate, so it may run
                // between fork and exec.
                unsafe {
                    cmd.as_std_mut()
                        .pre_exec(move || arm_pdeathsig(spawner_pid));
                }
            }
        };
        match &self.backend {
            Backend::Cgroup(cg) => {
                // Built in the parent so the hook itself does not allocate.
                let procs = CString::new(cg.path.join("cgroup.procs").into_os_string().into_vec())
                    .map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "cgroup path contains NUL")
                    })?;
                // SAFETY: `write_self_pid` only uses open/write/close/getpid and
                // a stack buffer, all of which are fine between fork and exec.
                unsafe {
                    cmd.as_std_mut()
                        .pre_exec(move || write_self_pid(procs.as_c_str()));
                }
                arm(cmd);
                cmd.spawn()
            }
            Backend::ProcessGroup(pg) => {
                arm(cmd);
                pg.spawn(cmd, opts)
            }
        }
    }

    /// Moves an already-running `child` into this job.
    ///
    /// The cgroup backend writes the child's pid to `cgroup.procs`. The
    /// process-group backend delegates to `ProcessGroup::adopt`.
    ///
    /// # Errors
    /// Fails if `child` no longer has a pid (it has presumably already
    /// exited), or if the backend cannot move it.
    #[cfg(feature = "process-control")]
    pub(crate) fn adopt(&self, child: &Child) -> io::Result<()> {
        let pid = child
            .id()
            .ok_or_else(|| io::Error::other("child has no pid (already exited?)"))?
            as i32;
        match &self.backend {
            Backend::Cgroup(cg) => {
                std::fs::write(cg.path.join("cgroup.procs"), pid.to_string().as_bytes())
            }
            Backend::ProcessGroup(pg) => pg.adopt(child),
        }
    }

    /// Immediately SIGKILLs every member of the job.
    pub(crate) fn kill_all(&self) -> io::Result<()> {
        match &self.backend {
            Backend::Cgroup(cg) => cg.kill(),
            Backend::ProcessGroup(pg) => pg.kill_all(),
        }
    }

    /// Delivers `sig` to every member of the job.
    ///
    /// On the cgroup backend, SIGKILL is routed through `Cgroup::kill`, which
    /// prefers `cgroup.kill`. Any other signal is sent to each current member
    /// individually.
    #[cfg(feature = "process-control")]
    pub(crate) fn signal(&self, sig: Signal) -> io::Result<()> {
        match &self.backend {
            Backend::Cgroup(cg) if sig.raw() == libc::SIGKILL => cg.kill(),
            Backend::Cgroup(cg) => cg.signal(sig.raw()),
            Backend::ProcessGroup(pg) => pg.signal(sig.raw()),
        }
    }

    /// Pauses every member of the job.
    ///
    /// The cgroup backend writes `cgroup.freeze`, falling back to SIGSTOP.
    #[cfg(feature = "process-control")]
    pub(crate) fn suspend(&self) -> io::Result<()> {
        match &self.backend {
            Backend::Cgroup(cg) => cg.freeze(true),
            Backend::ProcessGroup(pg) => pg.suspend(),
        }
    }

    /// Resumes members paused by [`Job::suspend`].
    ///
    /// The cgroup backend clears `cgroup.freeze`, falling back to SIGCONT.
    #[cfg(feature = "process-control")]
    pub(crate) fn resume(&self) -> io::Result<()> {
        match &self.backend {
            Backend::Cgroup(cg) => cg.freeze(false),
            Backend::ProcessGroup(pg) => pg.resume(),
        }
    }

    /// Pids of the job's current members, as reported by the backend.
    #[cfg(feature = "process-control")]
    pub(crate) fn members(&self) -> io::Result<Vec<u32>> {
        let pids = match &self.backend {
            Backend::Cgroup(cg) => cg.members(),
            Backend::ProcessGroup(pg) => pg.members(),
        };
        Ok(pids.into_iter().map(|pid| pid as u32).collect())
    }

    /// Asks every member to terminate, waits up to `timeout`, and optionally
    /// force-kills whatever is left.
    ///
    /// Cgroup backend, step by step:
    /// 1. Send SIGTERM to every member.
    /// 2. Thaw the cgroup and send SIGCONT, so that suspended members can run
    ///    their SIGTERM handlers.
    /// 3. Poll every `POLL_INTERVAL` until the cgroup is empty or the deadline
    ///    passes.
    /// 4. If `escalate` is true and members remain, kill the cgroup.
    ///
    /// # Errors
    /// Signalling is best-effort. Only the final kill's result is propagated.
    pub(crate) async fn graceful_shutdown(
        &self,
        timeout: Duration,
        escalate: bool,
    ) -> io::Result<()> {
        match &self.backend {
            Backend::Cgroup(cg) => {
                let _ = cg.signal(libc::SIGTERM);
                // A member that is frozen (`cgroup.freeze`) or stopped (the
                // SIGSTOP fallback of `freeze`) cannot run a SIGTERM handler
                // until it is resumed. Without this, a suspended job would sit
                // out the whole timeout and then be SIGKILLed, or, with
                // `escalate == false`, survive. Both steps are harmless for a
                // running job. The write may fail on kernels without
                // `cgroup.freeze`, which is fine.
                let _ = std::fs::write(cg.path.join("cgroup.freeze"), b"0");
                let _ = cg.signal(libc::SIGCONT);
                let deadline = Instant::now() + timeout;
                while !cg.is_empty() && Instant::now() < deadline {
                    sleep(POLL_INTERVAL).await;
                }
                if escalate && !cg.is_empty() {
                    cg.kill()?;
                }
                Ok(())
            }
            Backend::ProcessGroup(pg) => pg.graceful_shutdown(timeout, escalate).await,
        }
    }

    /// Resource usage of the job.
    ///
    /// Cgroup backend: CPU time and peak memory are summed over the processes
    /// that are members at the time of the call, using `process_metrics`.
    /// Processes that already exited are not counted. The memory figure is the
    /// sum of each live member's own `VmHWM`, not a true group-wide peak.
    /// Either total is `None` when no member reported that metric.
    #[cfg(feature = "stats")]
    pub(crate) fn stats(&self) -> io::Result<ProcessGroupStats> {
        match &self.backend {
            Backend::Cgroup(cg) => {
                let pids = cg.members();
                let active = pids.len();
                let mut cpu = Duration::ZERO;
                let mut have_cpu = false;
                let mut mem = 0u64;
                let mut have_mem = false;
                for pid in pids {
                    let m = process_metrics(pid as u32);
                    if let Some(c) = m.cpu_time {
                        cpu += c;
                        have_cpu = true;
                    }
                    if let Some(p) = m.peak_memory_bytes {
                        mem += p;
                        have_mem = true;
                    }
                }
                Ok(ProcessGroupStats {
                    active_process_count: active,
                    total_cpu_time: have_cpu.then_some(cpu),
                    peak_memory_bytes: have_mem.then_some(mem),
                })
            }
            Backend::ProcessGroup(pg) => pg.stats(),
        }
    }

    /// Which containment mechanism backs this job.
    pub(crate) fn mechanism(&self) -> Mechanism {
        match &self.backend {
            Backend::Cgroup(_) => Mechanism::CgroupV2,
            Backend::ProcessGroup(_) => Mechanism::ProcessGroup,
        }
    }
}
/// Samples one process's CPU time and peak resident memory from procfs.
///
/// * `cpu_time` is `utime + stime` from `/proc/<pid>/stat`, converted from
///   clock ticks using `sysconf(_SC_CLK_TCK)`.
/// * `peak_memory_bytes` is the `VmHWM` entry of `/proc/<pid>/status`
///   (reported in kB) multiplied by 1024.
///
/// A field is left at its `ProcMetrics::default()` value when its file cannot
/// be read or the entry is missing or malformed, e.g. because the process has
/// already exited.
#[cfg(feature = "stats")]
pub(crate) fn process_metrics(pid: u32) -> ProcMetrics {
    let mut metrics = ProcMetrics::default();

    if let Ok(stat) = std::fs::read_to_string(format!("/proc/{pid}/stat"))
        && let Some(comm_end) = stat.rfind(')')
    {
        // `comm` (field 2) is parenthesised and may itself contain spaces or
        // ')', so only the text after the last ')' is split. That tail starts
        // at field 3 (`state`), which puts utime (field 14) at offset 11 and
        // stime (field 15) at offset 12.
        let mut tail = stat[comm_end + 1..].split_whitespace().skip(11);
        let utime = tail.next().map(str::parse::<u64>);
        let stime = tail.next().map(str::parse::<u64>);
        if let (Some(Ok(utime)), Some(Ok(stime))) = (utime, stime) {
            // SAFETY: sysconf takes no pointers and has no preconditions.
            let ticks_per_sec = unsafe { libc::sysconf(libc::_SC_CLK_TCK) };
            if ticks_per_sec > 0 {
                // Widen before multiplying so ticks -> ns cannot overflow.
                let ticks = u128::from(utime + stime);
                let nanos = ticks * 1_000_000_000 / ticks_per_sec as u128;
                metrics.cpu_time = Some(Duration::from_nanos(nanos as u64));
            }
        }
    }

    if let Ok(status) = std::fs::read_to_string(format!("/proc/{pid}/status"))
        && let Some(vm_hwm) = status.lines().find_map(|line| line.strip_prefix("VmHWM:"))
        && let Some(kib) = vm_hwm
            .split_whitespace()
            .next()
            .and_then(|s| s.parse::<u64>().ok())
    {
        metrics.peak_memory_bytes = Some(kib * 1024);
    }

    metrics
}
impl Drop for Job {
    /// Best-effort teardown of a cgroup-backed job.
    ///
    /// Kills every member, then checks up to 50 times, 2 ms apart (about
    /// 100 ms in total), whether the cgroup has drained before removing its
    /// directory. Errors are ignored because `drop` cannot report them. Jobs
    /// backed by a process group need no cleanup here.
    fn drop(&mut self) {
        let cgroup = match &self.backend {
            Backend::Cgroup(cg) => cg,
            Backend::ProcessGroup(_) => return,
        };
        let _ = cgroup.kill();
        let mut polls_left = 50;
        while polls_left > 0 && !cgroup.is_empty() {
            std::thread::sleep(Duration::from_millis(2));
            polls_left -= 1;
        }
        // Removal fails while members are still alive; that is ignored too.
        let _ = std::fs::remove_dir(&cgroup.path);
    }
}
/// Handle to a cgroup v2 directory that was created for a single job.
struct Cgroup {
    /// Absolute path of the cgroup directory below `/sys/fs/cgroup`.
    path: PathBuf,
}
impl Cgroup {
fn create(#[cfg(feature = "limits")] limits: &ResourceLimits) -> io::Result<Self> {
let root = Path::new("/sys/fs/cgroup");
if !root.join("cgroup.controllers").exists() {
return Err(io::Error::new(
io::ErrorKind::Unsupported,
"cgroup v2 not mounted",
));
}
let self_cgroup = std::fs::read_to_string("/proc/self/cgroup")?;
let rel = self_cgroup
.lines()
.find_map(|line| line.strip_prefix("0::"))
.unwrap_or("/")
.trim();
let parent = root.join(rel.trim_start_matches('/'));
let name = format!(
"processkit-{}-{}",
std::process::id(),
NEXT_ID.fetch_add(1, Ordering::Relaxed)
);
let path = parent.join(name);
std::fs::create_dir(&path)?;
let cg = Cgroup { path };
#[cfg(feature = "limits")]
if limits.any()
&& let Err(e) = cg.apply_limits(&parent, limits)
{
let _ = std::fs::remove_dir(&cg.path);
return Err(e);
}
Ok(cg)
}
#[cfg(feature = "limits")]
fn apply_limits(&self, parent: &Path, limits: &ResourceLimits) -> io::Result<()> {
let mut spec = String::new();
if limits.memory_max.is_some() {
spec.push_str("+memory ");
}
if limits.max_processes.is_some() {
spec.push_str("+pids ");
}
if limits.cpu_quota.is_some() {
spec.push_str("+cpu ");
}
let spec = spec.trim_end();
if !spec.is_empty() {
let file = parent.join("cgroup.subtree_control");
std::fs::write(&file, spec).map_err(|e| {
io::Error::new(
e.kind(),
format!(
"enabling cgroup controllers ({spec}) via {} failed: {e} — \
resource limits require a delegated cgroup (run as root, in a \
container, or under a systemd unit with Delegate=yes)",
file.display()
),
)
})?;
}
if let Some(bytes) = limits.memory_max {
std::fs::write(self.path.join("memory.max"), bytes.to_string())?;
}
if let Some(n) = limits.max_processes {
std::fs::write(self.path.join("pids.max"), n.to_string())?;
}
if let Some(cores) = limits.cpu_quota {
std::fs::write(self.path.join("cpu.max"), cpu_max_value(cores))?;
}
Ok(())
}
fn members(&self) -> Vec<i32> {
match std::fs::read_to_string(self.path.join("cgroup.procs")) {
Ok(procs) => procs
.lines()
.filter_map(|l| l.trim().parse::<i32>().ok())
.collect(),
Err(_) => Vec::new(),
}
}
fn is_empty(&self) -> bool {
self.members().is_empty()
}
fn signal(&self, sig: i32) -> io::Result<()> {
for pid in self.members() {
unsafe {
libc::kill(pid, sig);
}
}
Ok(())
}
#[cfg(feature = "process-control")]
fn freeze(&self, frozen: bool) -> io::Result<()> {
let val: &[u8] = if frozen { b"1" } else { b"0" };
if std::fs::write(self.path.join("cgroup.freeze"), val).is_ok() {
return Ok(());
}
let sig = if frozen { libc::SIGSTOP } else { libc::SIGCONT };
self.signal(sig)
}
fn kill(&self) -> io::Result<()> {
if std::fs::write(self.path.join("cgroup.kill"), b"1").is_ok() {
return Ok(());
}
for _ in 0..50 {
let members = self.members();
if members.is_empty() {
break;
}
for pid in members {
unsafe {
libc::kill(pid, libc::SIGKILL);
}
}
std::thread::sleep(Duration::from_millis(2));
}
Ok(())
}
}
/// Formats a CPU limit, expressed in cores, as a cgroup v2 `cpu.max` value.
///
/// The period is fixed at 100 ms (`100000` µs) and the quota is
/// `cores * period`, rounded to the nearest microsecond.
///
/// The kernel rejects quotas below 1 ms (`1000` µs) with `EINVAL`. Smaller,
/// zero, negative or NaN requests are therefore clamped up to that minimum.
///
/// Quotas above the kernel's upper bound, including `+inf`, are written as
/// `max`, i.e. unlimited. Such a quota would already exceed any real CPU count.
#[cfg(feature = "limits")]
fn cpu_max_value(cores: f64) -> String {
    const PERIOD: u64 = 100_000;
    // Kernel minimum quota (min_cfs_quota_period = 1 ms).
    const MIN_QUOTA: u64 = 1_000;
    // Kernel maximum quota in µs (MAX_BW = 2^44 - 1 with BW_SHIFT = 20).
    // Larger values are rejected with EINVAL.
    const MAX_QUOTA: u64 = (1u64 << 44) - 1;
    let raw = (cores * PERIOD as f64).round();
    // NaN compares false here and falls through to the MIN_QUOTA clamp.
    if raw > MAX_QUOTA as f64 {
        return format!("max {PERIOD}");
    }
    let quota = raw.max(MIN_QUOTA as f64) as u64;
    format!("{quota} {PERIOD}")
}

/// `pre_exec` hook that asks the kernel to SIGKILL this child when its parent
/// dies, then closes the race where the parent died before the request was
/// registered.
///
/// `spawner_pid` is the pid of the process that spawned the child. If
/// `getppid()` no longer matches it after arming, the parent is already gone
/// and the child calls `_exit(0)` instead of exec'ing.
///
/// Runs between fork and exec, so it only makes raw syscalls and does not
/// allocate.
///
/// NOTE(review): per prctl(2), the "parent" for `PR_SET_PDEATHSIG` is the
/// thread that forked the child. The signal can therefore also fire when
/// that thread exits while the spawning process lives on.
fn arm_pdeathsig(spawner_pid: u32) -> io::Result<()> {
    // SAFETY: prctl(PR_SET_PDEATHSIG) takes only integer arguments.
    if unsafe { libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL, 0, 0, 0) } != 0 {
        return Err(io::Error::last_os_error());
    }
    // SAFETY: getppid has no preconditions; _exit is async-signal-safe and
    // never returns.
    unsafe {
        if libc::getppid() as u32 != spawner_pid {
            libc::_exit(0);
        }
    }
    Ok(())
}

/// `pre_exec` hook that moves the calling (child) process into the cgroup
/// whose `cgroup.procs` file is `path`, by writing its own pid there.
///
/// Runs between fork and exec. The pid is therefore formatted into a stack
/// buffer rather than with `format!`/`to_string`, so nothing is heap-allocated.
///
/// # Errors
/// The `open` or `write` errno. A write that does not transfer every byte
/// yields `WriteZero`.
fn write_self_pid(path: &CStr) -> io::Result<()> {
    // Render the pid right-to-left; u32::MAX has 10 digits, so 12 bytes suffice.
    let mut buf = [0u8; 12];
    let mut i = buf.len();
    // SAFETY: getpid has no preconditions and cannot fail.
    let mut v = unsafe { libc::getpid() } as u32;
    loop {
        i -= 1;
        buf[i] = b'0' + (v % 10) as u8;
        v /= 10;
        if v == 0 {
            break;
        }
    }
    let bytes = &buf[i..];

    // SAFETY: `path` is a valid NUL-terminated string for the whole call.
    let fd = unsafe { libc::open(path.as_ptr(), libc::O_WRONLY | libc::O_CLOEXEC) };
    if fd < 0 {
        return Err(io::Error::last_os_error());
    }
    // SAFETY: `fd` is an open descriptor owned here; `bytes` is initialized
    // and lives for the duration of the call.
    let written = unsafe { libc::write(fd, bytes.as_ptr().cast(), bytes.len()) };
    // Capture errno before `close` can overwrite it.
    let werr = io::Error::last_os_error();
    // SAFETY: `fd` is owned here and closed exactly once.
    unsafe {
        libc::close(fd);
    }
    if written < 0 {
        return Err(werr);
    }
    if written as usize != bytes.len() {
        // Constructing a simple io::Error from a kind does not allocate.
        return Err(io::Error::from(io::ErrorKind::WriteZero));
    }
    Ok(())
}

#[cfg(all(test, feature = "limits"))]
mod tests {
    use super::cpu_max_value;

    #[test]
    fn cpu_max_formats_quota_and_period() {
        assert_eq!(cpu_max_value(0.5), "50000 100000");
        assert_eq!(cpu_max_value(2.0), "200000 100000");
    }

    #[test]
    fn cpu_max_clamps_to_kernel_minimum_quota() {
        assert_eq!(cpu_max_value(0.000_001), "1000 100000");
        assert_eq!(cpu_max_value(0.0), "1000 100000");
        assert_eq!(cpu_max_value(-3.0), "1000 100000");
        assert_eq!(cpu_max_value(f64::NAN), "1000 100000");
        assert_eq!(cpu_max_value(0.01), "1000 100000");
    }

    #[test]
    fn cpu_max_is_unlimited_above_kernel_bound() {
        assert_eq!(cpu_max_value(f64::INFINITY), "max 100000");
        assert_eq!(cpu_max_value(1e12), "max 100000");
    }
}