1use crate::error::{NucleusError, Result, StateTransition};
2use crate::resources::{CgroupState, ResourceLimits};
3use nix::sys::signal::{kill, Signal};
4use nix::unistd::Pid;
5use std::fs;
6use std::path::{Path, PathBuf};
7use std::thread;
8use std::time::Duration;
9use tracing::{debug, info, warn};
10
/// Default cgroup root, used when no override is supplied through the environment.
const CGROUP_V2_ROOT: &str = "/sys/fs/cgroup";
/// Name of the environment variable that may replace the default root with another
/// absolute path.
const NUCLEUS_CGROUP_ROOT_ENV: &str = "NUCLEUS_CGROUP_ROOT";
/// Maximum number of times the drain loop checks whether the cgroup is empty before
/// reporting a timeout.
const CGROUP_CLEANUP_RETRIES: usize = 50;
/// Pause between two consecutive drain checks.
const CGROUP_CLEANUP_SLEEP: Duration = Duration::from_millis(20);
15
/// Handle to one cgroup directory together with its lifecycle state.
///
/// Dropping a handle whose state is not terminal triggers a best-effort teardown of the
/// directory (see the `Drop` implementation).
pub struct Cgroup {
    /// Directory of this cgroup: the cgroup root joined with the name given at creation.
    path: PathBuf,
    /// Lifecycle state, advanced by creation, limit configuration, process attachment and
    /// cleanup.
    state: CgroupState,
}
24
25impl Cgroup {
26 pub fn create(name: &str) -> Result<Self> {
30 let state = CgroupState::Nonexistent.transition(CgroupState::Created)?;
31 let path = Self::root_path()?.join(name);
32
33 info!("Creating cgroup at {:?}", path);
34
35 fs::create_dir_all(&path).map_err(|e| {
37 NucleusError::CgroupError(format!("Failed to create cgroup directory: {}", e))
38 })?;
39
40 Ok(Self { path, state })
41 }
42
43 fn root_path() -> Result<PathBuf> {
44 Self::root_path_from_override(std::env::var_os(NUCLEUS_CGROUP_ROOT_ENV))
45 }
46
47 fn root_path_from_override(raw: Option<std::ffi::OsString>) -> Result<PathBuf> {
48 match raw {
49 Some(raw) if !raw.as_os_str().is_empty() => {
50 let path = PathBuf::from(raw);
51 if !path.is_absolute() {
52 return Err(NucleusError::CgroupError(format!(
53 "{} must be an absolute path",
54 NUCLEUS_CGROUP_ROOT_ENV
55 )));
56 }
57 Ok(path)
58 }
59 _ => Ok(PathBuf::from(CGROUP_V2_ROOT)),
60 }
61 }
62
63 pub fn set_limits(&mut self, limits: &ResourceLimits) -> Result<()> {
67 self.state = self.state.transition(CgroupState::Configured)?;
68
69 info!("Configuring cgroup limits: {:?}", limits);
70
71 if let Some(memory_bytes) = limits.memory_bytes {
73 self.write_value("memory.max", &memory_bytes.to_string())?;
74 debug!("Set memory.max = {}", memory_bytes);
75 }
76
77 if let Some(memory_high) = limits.memory_high {
79 self.write_value("memory.high", &memory_high.to_string())?;
80 debug!("Set memory.high = {}", memory_high);
81 }
82
83 if let Some(swap_max) = limits.memory_swap_max {
85 self.write_value("memory.swap.max", &swap_max.to_string())?;
86 debug!("Set memory.swap.max = {}", swap_max);
87 }
88 if limits.memory_bytes.is_some()
89 || limits.memory_high.is_some()
90 || limits.memory_swap_max.is_some()
91 {
92 self.write_value("memory.oom.group", "1")?;
93 debug!("Set memory.oom.group = 1");
94 }
95
96 if let Some(cpu_quota_us) = limits.cpu_quota_us {
98 let cpu_max = format!("{} {}", cpu_quota_us, limits.cpu_period_us);
99 self.write_value("cpu.max", &cpu_max)?;
100 debug!("Set cpu.max = {}", cpu_max);
101 }
102
103 if let Some(cpu_weight) = limits.cpu_weight {
105 self.write_value("cpu.weight", &cpu_weight.to_string())?;
106 debug!("Set cpu.weight = {}", cpu_weight);
107 }
108
109 if let Some(pids_max) = limits.pids_max {
111 self.write_value("pids.max", &pids_max.to_string())?;
112 debug!("Set pids.max = {}", pids_max);
113 }
114
115 for io_limit in &limits.io_limits {
117 let line = io_limit.to_io_max_line();
118 self.write_value("io.max", &line)?;
119 debug!("Set io.max: {}", line);
120 }
121
122 info!("Successfully configured cgroup limits");
123
124 Ok(())
125 }
126
127 pub fn attach_process(&mut self, pid: u32) -> Result<()> {
131 self.state = self.state.transition(CgroupState::Attached)?;
132
133 info!("Attaching process {} to cgroup", pid);
134
135 self.write_value("cgroup.procs", &pid.to_string())?;
136
137 info!("Successfully attached process to cgroup");
138
139 Ok(())
140 }
141
142 fn write_value(&self, file: &str, value: &str) -> Result<()> {
144 let file_path = self.path.join(file);
145 fs::write(&file_path, value).map_err(|e| {
146 NucleusError::CgroupError(format!(
147 "Failed to write {} to {:?}: {}",
148 value, file_path, e
149 ))
150 })?;
151 Ok(())
152 }
153
154 fn read_value(&self, file: &str) -> Result<String> {
156 let file_path = self.path.join(file);
157 fs::read_to_string(&file_path).map_err(|e| {
158 NucleusError::CgroupError(format!("Failed to read {:?}: {}", file_path, e))
159 })
160 }
161
162 fn set_frozen(&self, frozen: bool) -> Result<bool> {
163 let freeze_path = self.path.join("cgroup.freeze");
164 if !freeze_path.exists() {
165 return Ok(false);
166 }
167 self.write_value("cgroup.freeze", if frozen { "1" } else { "0" })?;
168 debug!("Set cgroup.freeze = {}", if frozen { 1 } else { 0 });
169 Ok(true)
170 }
171
172 fn parse_cgroup_events_populated(events: &str) -> Result<bool> {
173 for line in events.lines() {
174 if let Some(value) = line.strip_prefix("populated ") {
175 return match value.trim() {
176 "0" => Ok(false),
177 "1" => Ok(true),
178 other => Err(NucleusError::CgroupError(format!(
179 "Unexpected populated value in cgroup.events: {}",
180 other
181 ))),
182 };
183 }
184 }
185 Err(NucleusError::CgroupError(
186 "Missing populated entry in cgroup.events".to_string(),
187 ))
188 }
189
190 fn read_pids(&self) -> Result<Vec<Pid>> {
191 let file_path = self.path.join("cgroup.procs");
192 if !file_path.exists() {
193 return Ok(Vec::new());
194 }
195 let content = fs::read_to_string(&file_path).map_err(|e| {
196 NucleusError::CgroupError(format!("Failed to read {:?}: {}", file_path, e))
197 })?;
198 content
199 .lines()
200 .filter(|line| !line.trim().is_empty())
201 .map(|line| {
202 line.trim().parse::<i32>().map(Pid::from_raw).map_err(|e| {
203 NucleusError::CgroupError(format!(
204 "Failed to parse pid '{}' from {:?}: {}",
205 line.trim(),
206 file_path,
207 e
208 ))
209 })
210 })
211 .collect()
212 }
213
214 fn is_populated(&self) -> Result<bool> {
215 let events_path = self.path.join("cgroup.events");
216 if events_path.exists() {
217 let events = fs::read_to_string(&events_path).map_err(|e| {
218 NucleusError::CgroupError(format!("Failed to read {:?}: {}", events_path, e))
219 })?;
220 return Self::parse_cgroup_events_populated(&events);
221 }
222 Ok(!self.read_pids()?.is_empty())
223 }
224
225 fn kill_visible_processes(&self) -> Result<()> {
226 for pid in self.read_pids()? {
227 match kill(pid, Signal::SIGKILL) {
228 Ok(()) => {}
229 Err(nix::errno::Errno::ESRCH) => {}
230 Err(e) => {
231 return Err(NucleusError::CgroupError(format!(
232 "Failed to SIGKILL pid {} in {:?}: {}",
233 pid, self.path, e
234 )))
235 }
236 }
237 }
238 Ok(())
239 }
240
241 fn kill_all_processes(&self) -> Result<()> {
242 let kill_path = self.path.join("cgroup.kill");
243 if kill_path.exists() {
244 self.write_value("cgroup.kill", "1")?;
245 debug!("Triggered cgroup.kill for {:?}", self.path);
246 }
247 self.kill_visible_processes()
248 }
249
250 fn wait_until_empty(&self) -> Result<()> {
251 for attempt in 0..CGROUP_CLEANUP_RETRIES {
252 if !self.is_populated()? {
253 return Ok(());
254 }
255 if attempt + 1 < CGROUP_CLEANUP_RETRIES {
256 self.kill_visible_processes()?;
257 thread::sleep(CGROUP_CLEANUP_SLEEP);
258 }
259 }
260
261 let remaining = self
262 .read_pids()?
263 .into_iter()
264 .map(|pid| pid.to_string())
265 .collect::<Vec<_>>();
266 Err(NucleusError::CgroupError(format!(
267 "Timed out waiting for cgroup {:?} to drain (remaining pids: {})",
268 self.path,
269 if remaining.is_empty() {
270 "<unknown>".to_string()
271 } else {
272 remaining.join(", ")
273 }
274 )))
275 }
276
277 pub fn memory_current(&self) -> Result<u64> {
279 let value = self.read_value("memory.current")?;
280 value.trim().parse().map_err(|e| {
281 NucleusError::CgroupError(format!("Failed to parse memory.current: {}", e))
282 })
283 }
284
285 pub fn path(&self) -> &Path {
287 &self.path
288 }
289
290 pub fn state(&self) -> CgroupState {
292 self.state
293 }
294
    /// Tears the cgroup down and consumes the handle.
    ///
    /// When the directory exists, the cgroup is frozen (if it supports freezing), its
    /// processes are killed, the call waits for the cgroup to drain, and the directory is
    /// deleted. If any of those steps fails after a successful freeze, the cgroup is thawed
    /// again before the error is returned. The state is set to `Removed` only on success
    /// (or when the directory was already absent).
    ///
    /// # Errors
    ///
    /// Returns the first error from freezing, killing, draining or deleting. On every error
    /// path the state is left unchanged, so the handle is dropped in its previous state.
    pub fn cleanup(mut self) -> Result<()> {
        info!("Cleaning up cgroup {:?}", self.path);

        if self.path.exists() {
            // `froze` records whether a freeze was actually written, so that it can be
            // undone if teardown fails.
            let froze = self.set_frozen(true)?;
            // Run the fallible teardown steps as one unit so that any failure can be
            // followed by a thaw before the error is propagated.
            let cleanup_result: Result<()> = (|| {
                self.kill_all_processes()?;
                self.wait_until_empty()?;
                fs::remove_dir(&self.path).map_err(|e| {
                    NucleusError::CgroupError(format!("Failed to remove cgroup: {}", e))
                })?;
                Ok(())
            })();
            // Do not leave surviving processes frozen when teardown did not complete; a
            // failure to thaw is only logged so the original error stays visible.
            if cleanup_result.is_err() && froze {
                if let Err(e) = self.set_frozen(false) {
                    warn!(
                        "Failed to unfreeze cgroup {:?} after cleanup error: {}",
                        self.path, e
                    );
                }
            }
            cleanup_result?;
        }

        // Reached only after success; presumably this state is terminal, which makes the
        // `Drop` implementation skip its own teardown attempt.
        self.state = CgroupState::Removed;
        info!("Successfully cleaned up cgroup");

        Ok(())
    }
330}
331
332impl Drop for Cgroup {
333 fn drop(&mut self) {
334 if !self.state.is_terminal() && self.path.exists() {
335 let froze = self.set_frozen(true).unwrap_or(false);
336 let _ = self.kill_all_processes();
337 let _ = self.wait_until_empty();
338 let _ = fs::remove_dir(&self.path);
339 if self.path.exists() && froze {
340 let _ = self.set_frozen(false);
341 }
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;

    /// Returns the source text of the first function whose declaration contains
    /// `signature`, from the signature through the closing brace of its body.
    ///
    /// Braces are matched by plain counting, so the scanned function must keep its braces
    /// balanced inside string literals and comments as well.
    fn function_source<'a>(source: &'a str, signature: &str) -> &'a str {
        let start = source
            .find(signature)
            .unwrap_or_else(|| panic!("`{}` not found in source", signature));
        let text = &source[start..];
        let open = text.find('{').expect("function must have a body");

        let mut depth = 0usize;
        for (offset, ch) in text[open..].char_indices() {
            match ch {
                '{' => depth += 1,
                '}' => {
                    // Cannot underflow: the scan starts on an opening brace.
                    depth -= 1;
                    if depth == 0 {
                        return &text[..open + offset + 1];
                    }
                }
                _ => {}
            }
        }
        panic!("unbalanced braces after `{}`", signature);
    }

    #[test]
    fn test_resource_limits_unlimited() {
        let limits = ResourceLimits::unlimited();
        assert!(limits.memory_bytes.is_none());
        assert!(limits.memory_high.is_none());
        assert!(limits.memory_swap_max.is_none());
        assert!(limits.cpu_quota_us.is_none());
        assert!(limits.cpu_weight.is_none());
        assert!(limits.pids_max.is_none());
        assert!(limits.io_limits.is_empty());
    }

    #[test]
    fn test_cgroup_root_override_requires_absolute_path() {
        assert_eq!(
            Cgroup::root_path_from_override(None).unwrap(),
            PathBuf::from(CGROUP_V2_ROOT)
        );
        assert_eq!(
            Cgroup::root_path_from_override(Some(OsString::from(""))).unwrap(),
            PathBuf::from(CGROUP_V2_ROOT)
        );
        assert_eq!(
            Cgroup::root_path_from_override(Some(OsString::from("/sys/fs/cgroup/example.service")))
                .unwrap(),
            PathBuf::from("/sys/fs/cgroup/example.service")
        );
        assert!(Cgroup::root_path_from_override(Some(OsString::from("relative"))).is_err());
    }

    /// Source-level guard: the terminal state may only be assigned after the directory
    /// removal call appears in the cleanup function.
    #[test]
    fn test_cleanup_sets_removed_only_after_success() {
        let source = include_str!("cgroup.rs");
        let cleanup_body = function_source(source, "pub fn cleanup");
        let removed_pos = cleanup_body
            .find("Removed")
            .expect("must reference Removed state");
        let remove_dir_pos = cleanup_body
            .find("remove_dir")
            .expect("must call remove_dir");
        assert!(
            removed_pos > remove_dir_pos,
            "CgroupState::Removed must be set AFTER remove_dir succeeds, not before"
        );
    }

    #[test]
    fn test_parse_cgroup_events_populated() {
        assert!(Cgroup::parse_cgroup_events_populated("populated 1\nfrozen 0\n").unwrap());
        assert!(!Cgroup::parse_cgroup_events_populated("frozen 0\npopulated 0\n").unwrap());
        // A missing entry and an unknown value are both reported as errors.
        assert!(Cgroup::parse_cgroup_events_populated("frozen 0\n").is_err());
        assert!(Cgroup::parse_cgroup_events_populated("populated 2\n").is_err());
    }

    /// Source-level guard: configuring limits must reference the group OOM control file.
    #[test]
    fn test_set_limits_source_enables_memory_oom_group() {
        let source = include_str!("cgroup.rs");
        let body = function_source(source, "pub fn set_limits");
        assert!(
            body.contains("memory.oom.group"),
            "set_limits must enable memory.oom.group when memory controls are configured"
        );
    }

    /// Source-level guard: cleanup must freeze, then kill, then remove the directory.
    #[test]
    fn test_cleanup_source_kills_processes_before_remove_dir() {
        let source = include_str!("cgroup.rs");
        let body = function_source(source, "pub fn cleanup");
        let freeze_pos = body.find("set_frozen(true)").unwrap();
        let kill_pos = body.find("kill_all_processes").unwrap();
        let remove_dir_pos = body.find("remove_dir").unwrap();
        assert!(
            freeze_pos < kill_pos && kill_pos < remove_dir_pos,
            "cleanup must freeze and kill the cgroup before attempting remove_dir"
        );
    }
}