use crate::error::{NucleusError, Result, StateTransition};
use crate::resources::{CgroupState, ResourceLimits};
use std::fs;
use std::path::{Path, PathBuf};
use tracing::{debug, info};
/// Mount point of the unified (v2) cgroup hierarchy. Cgroups are created by joining
/// their name onto this directory.
const CGROUP_V2_ROOT: &str = "/sys/fs/cgroup";

/// Handle to one cgroup v2 directory together with its lifecycle state.
///
/// If the handle is dropped before reaching a terminal state, the directory is removed
/// on a best-effort basis.
pub struct Cgroup {
    /// Directory of this cgroup under [`CGROUP_V2_ROOT`].
    path: PathBuf,
    /// Current lifecycle state; changed via `CgroupState::transition` or set on cleanup.
    state: CgroupState,
}
impl Cgroup {
    /// Creates (or adopts, if it already exists) the cgroup directory
    /// `CGROUP_V2_ROOT/<name>`.
    ///
    /// `name` is a path relative to the cgroup v2 root. It may contain `/` to nest
    /// cgroups; missing intermediate directories are created.
    ///
    /// # Errors
    ///
    /// Propagates a failed `Nonexistent -> Created` state transition. Returns
    /// [`NucleusError::CgroupError`] in two cases:
    /// - `name` is empty or contains anything other than plain path components
    ///   (a root, a prefix, `..`, or a leading `.`);
    /// - the directory cannot be created.
    pub fn create(name: &str) -> Result<Self> {
        let state = CgroupState::Nonexistent.transition(CgroupState::Created)?;

        // `PathBuf::join` replaces the base when handed an absolute path and does not
        // resolve `..`, so an unchecked name could create (and later rmdir) a directory
        // anywhere on the filesystem. An empty name would alias the root cgroup itself.
        let is_confined = !name.is_empty()
            && Path::new(name)
                .components()
                .all(|component| matches!(component, std::path::Component::Normal(_)));
        if !is_confined {
            return Err(NucleusError::CgroupError(format!(
                "Invalid cgroup name {:?}: must be a non-empty relative path of plain components",
                name
            )));
        }

        let path = PathBuf::from(CGROUP_V2_ROOT).join(name);
        info!("Creating cgroup at {:?}", path);
        fs::create_dir_all(&path).map_err(|e| {
            NucleusError::CgroupError(format!("Failed to create cgroup directory: {}", e))
        })?;
        Ok(Self { path, state })
    }

    /// Writes every limit present in `limits` to its cgroup v2 interface file.
    ///
    /// `None` limits are left untouched. Files are written in this order:
    /// `memory.max`, `memory.high`, `memory.swap.max`, `cpu.max`, `cpu.weight`,
    /// `pids.max`, then one `io.max` write per entry of `limits.io_limits`.
    ///
    /// # Errors
    ///
    /// Fails if the cgroup may not move to `Configured` from its current state, or if
    /// any write fails. The state is only advanced after all writes succeed, so a
    /// partially applied configuration leaves the handle in its previous state.
    pub fn set_limits(&mut self, limits: &ResourceLimits) -> Result<()> {
        // Reject an illegal transition before touching the filesystem, but only record
        // it once every limit has actually been written.
        let next_state = self.state.transition(CgroupState::Configured)?;
        info!("Configuring cgroup limits: {:?}", limits);

        // (interface file, value to write); the order here is the write order.
        let scalar_limits = [
            ("memory.max", limits.memory_bytes.as_ref().map(ToString::to_string)),
            ("memory.high", limits.memory_high.as_ref().map(ToString::to_string)),
            (
                "memory.swap.max",
                limits.memory_swap_max.as_ref().map(ToString::to_string),
            ),
            (
                "cpu.max",
                // cpu.max takes "<quota> <period>"; the period is only written with a quota.
                limits
                    .cpu_quota_us
                    .as_ref()
                    .map(|quota| format!("{} {}", quota, limits.cpu_period_us)),
            ),
            ("cpu.weight", limits.cpu_weight.as_ref().map(ToString::to_string)),
            ("pids.max", limits.pids_max.as_ref().map(ToString::to_string)),
        ];
        for (file, value) in &scalar_limits {
            if let Some(value) = value {
                self.write_value(file, value)?;
                debug!("Set {} = {}", file, value);
            }
        }

        // io.max is written once per device line.
        for io_limit in &limits.io_limits {
            let line = io_limit.to_io_max_line();
            self.write_value("io.max", &line)?;
            debug!("Set io.max: {}", line);
        }

        self.state = next_state;
        info!("Successfully configured cgroup limits");
        Ok(())
    }

    /// Moves process `pid` into this cgroup by writing it to `cgroup.procs`.
    ///
    /// # Errors
    ///
    /// Fails if the cgroup may not move to `Attached` from its current state, or if the
    /// write is rejected. The state is only advanced after the write succeeds.
    pub fn attach_process(&mut self, pid: u32) -> Result<()> {
        let next_state = self.state.transition(CgroupState::Attached)?;
        info!("Attaching process {} to cgroup", pid);
        self.write_value("cgroup.procs", &pid.to_string())?;
        self.state = next_state;
        info!("Successfully attached process to cgroup");
        Ok(())
    }

    /// Writes `value` to the interface file `file` inside this cgroup's directory.
    fn write_value(&self, file: &str, value: &str) -> Result<()> {
        let file_path = self.path.join(file);
        fs::write(&file_path, value).map_err(|e| {
            NucleusError::CgroupError(format!(
                "Failed to write {} to {:?}: {}",
                value, file_path, e
            ))
        })?;
        Ok(())
    }

    /// Reads the full contents of the interface file `file` inside this cgroup.
    fn read_value(&self, file: &str) -> Result<String> {
        let file_path = self.path.join(file);
        fs::read_to_string(&file_path).map_err(|e| {
            NucleusError::CgroupError(format!("Failed to read {:?}: {}", file_path, e))
        })
    }

    /// Returns the value of `memory.current`, parsed as a `u64` after trimming whitespace.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be read or its contents do not parse as an integer.
    pub fn memory_current(&self) -> Result<u64> {
        let value = self.read_value("memory.current")?;
        value.trim().parse().map_err(|e| {
            NucleusError::CgroupError(format!("Failed to parse memory.current: {}", e))
        })
    }

    /// Directory of this cgroup.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Current lifecycle state of this handle.
    pub fn state(&self) -> CgroupState {
        self.state
    }

    /// Removes the cgroup directory and marks the handle as removed.
    ///
    /// A directory that has already disappeared counts as successfully removed.
    ///
    /// # Errors
    ///
    /// Returns [`NucleusError::CgroupError`] if the directory exists but cannot be
    /// removed, for example while it still holds processes or child cgroups. The handle
    /// is then dropped without reaching a terminal state, so `Drop` makes one more
    /// best-effort attempt.
    pub fn cleanup(mut self) -> Result<()> {
        info!("Cleaning up cgroup {:?}", self.path);
        // Attempt the rmdir directly instead of checking `exists()` first, which races
        // with anything else deleting the directory in between.
        match fs::remove_dir(&self.path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => {
                return Err(NucleusError::CgroupError(format!(
                    "Failed to remove cgroup: {}",
                    e
                )));
            }
        }
        // Only mark the handle terminal once the directory is really gone.
        self.state = CgroupState::Removed;
        info!("Successfully cleaned up cgroup");
        Ok(())
    }
}
impl Drop for Cgroup {
    /// Best-effort fallback: removes the cgroup directory if the handle never reached a
    /// terminal state (for example, `cleanup` was never called or did not succeed).
    fn drop(&mut self) {
        if self.state.is_terminal() {
            return;
        }
        if self.path.exists() {
            // `drop` has no way to report failure, so the result is deliberately ignored.
            let _ = fs::remove_dir(&self.path);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An "unlimited" configuration must leave every knob unset, so `set_limits`
    /// writes no limit files.
    #[test]
    fn test_resource_limits_unlimited() {
        let unlimited = ResourceLimits::unlimited();
        assert!(unlimited.memory_bytes.is_none());
        assert!(unlimited.memory_high.is_none());
        assert!(unlimited.memory_swap_max.is_none());
        assert!(unlimited.cpu_quota_us.is_none());
        assert!(unlimited.cpu_weight.is_none());
        assert!(unlimited.pids_max.is_none());
        assert!(unlimited.io_limits.is_empty());
    }

    /// Returns the text of the first `cleanup` method in `source`, from its signature
    /// through the brace that closes its body.
    ///
    /// Braces are counted naively (literals are not skipped). That suffices while every
    /// literal in the body keeps its braces balanced.
    fn cleanup_fn_text(source: &str) -> &str {
        let start = source.find("pub fn cleanup").unwrap();
        let tail = &source[start..];
        let body_open = tail.find('{').unwrap();
        let mut depth = 0u32;
        let body_close = tail[body_open..]
            .char_indices()
            .find_map(|(offset, ch)| {
                match ch {
                    '{' => depth += 1,
                    '}' => {
                        depth -= 1;
                        if depth == 0 {
                            return Some(body_open + offset + 1);
                        }
                    }
                    _ => {}
                }
                None
            })
            .unwrap_or(body_open);
        &tail[..body_close]
    }

    /// Textual guard on this file's own source. Inside `cleanup`, the first mention of
    /// `Removed` must come after the first mention of `remove_dir`. This ensures the
    /// terminal state is only recorded once removal has been attempted successfully.
    #[test]
    fn test_cleanup_sets_removed_only_after_success() {
        let cleanup_body = cleanup_fn_text(include_str!("cgroup.rs"));
        let removed_pos = cleanup_body
            .find("Removed")
            .expect("must reference Removed state");
        let remove_dir_pos = cleanup_body
            .find("remove_dir")
            .expect("must call remove_dir");
        assert!(
            removed_pos > remove_dir_pos,
            "CgroupState::Removed must be set AFTER remove_dir succeeds, not before"
        );
    }
}