use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::error::SandboxError;
use crate::output::CapturedOutput;
/// Lifecycle state recorded for a tracked process.
///
/// The enum is `#[non_exhaustive]`, so matches written outside this crate
/// must include a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ProcessStatus {
    /// No termination has been recorded yet; this is the state assigned by
    /// [`ProcessRecord::new`].
    Running,
    /// The process terminated and reported an exit code.
    Exited {
        /// Exit code as passed to [`ProcessRegistry::mark_exited`].
        code: i32,
    },
    /// The process was recorded as terminated by a signal.
    Signaled {
        /// Signal number as passed to [`ProcessRegistry::mark_signaled`].
        signal: i32,
    },
    /// Catch-all state. Nothing in this module assigns it; presumably meant
    /// for processes whose state cannot be determined.
    Unknown,
}
/// Everything the registry remembers about one tracked process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessRecord {
    /// Process id; [`ProcessRegistry::insert`] uses it as the map key.
    pub pid: u32,
    /// Command associated with the process.
    pub command: String,
    /// Arguments associated with `command`.
    pub args: Vec<String>,
    /// UTC timestamp captured when the record was constructed.
    pub started_at: DateTime<Utc>,
    /// Optional cgroup path; `None` until a caller fills it in.
    pub cgroup_path: Option<String>,
    /// Current lifecycle state.
    pub status: ProcessStatus,
    /// Presumably cumulative CPU time in nanoseconds; never written in this
    /// module, so confirm against whoever populates it.
    pub cpu_usage_ns: Option<u64>,
    /// Presumably memory usage in bytes; never written in this module, so
    /// confirm against whoever populates it.
    pub memory_bytes: Option<u64>,
    /// Shared handle to captured output. Excluded from (de)serialization via
    /// `#[serde(skip)]`, so a deserialized record always has `None` here.
    #[serde(skip)]
    pub output: Option<Arc<CapturedOutput>>,
}
impl ProcessRecord {
    /// Builds a record in the [`ProcessStatus::Running`] state.
    ///
    /// `started_at` is sampled from `Utc::now()` at the moment of the call,
    /// and every optional field (`cgroup_path`, `cpu_usage_ns`,
    /// `memory_bytes`, `output`) starts out as `None`.
    #[must_use]
    pub fn new(pid: u32, command: impl Into<String>, args: Vec<String>) -> Self {
        // Resolve the two computed fields up front so the struct literal
        // below is pure data.
        let command = command.into();
        let started_at = Utc::now();

        Self {
            pid,
            command,
            args,
            started_at,
            status: ProcessStatus::Running,
            cgroup_path: None,
            cpu_usage_ns: None,
            memory_bytes: None,
            output: None,
        }
    }
}
/// In-memory table of [`ProcessRecord`]s keyed by pid.
///
/// The derived `Default` yields an empty registry with no cap.
#[derive(Debug, Default)]
pub struct ProcessRegistry {
    /// Tracked records, keyed by the pid they were inserted under.
    entries: HashMap<u32, ProcessRecord>,
    /// Optional cap that `insert` checks against the number of entries in
    /// the `Running` state; `None` disables the check.
    max_tracked: Option<usize>,
}
impl ProcessRegistry {
    /// Creates an empty registry.
    ///
    /// `max_tracked` caps how many entries may be in the
    /// [`ProcessStatus::Running`] state at once; `None` means no cap.
    #[must_use]
    pub fn new(max_tracked: Option<usize>) -> Self {
        Self {
            entries: HashMap::new(),
            max_tracked,
        }
    }

    /// Stores `record` under its pid, replacing any record already held for
    /// that pid.
    ///
    /// # Errors
    ///
    /// Returns [`SandboxError::RegistryFull`] when storing the record would
    /// leave more than `max_tracked` entries in the `Running` state. The cap
    /// is evaluated against the state *after* the insert, so:
    ///
    /// * a record that is not `Running` is always accepted, and
    /// * an existing entry for the same pid is not counted, because it is
    ///   about to be overwritten.
    pub fn insert(&mut self, record: ProcessRecord) -> Result<(), SandboxError> {
        let pid = record.pid;

        if let Some(max) = self.max_tracked {
            // Only a running record can raise the running count.
            if record.status == ProcessStatus::Running {
                // Compare on the map key rather than the record's `pid`
                // field: the key is what `HashMap::insert` will overwrite.
                let other_running = self
                    .entries
                    .iter()
                    .filter(|(tracked_pid, tracked)| {
                        **tracked_pid != pid && tracked.status == ProcessStatus::Running
                    })
                    .count();
                if other_running >= max {
                    return Err(SandboxError::RegistryFull { max_tracked: max });
                }
            }
        }

        // A previous record for this pid (if any) is intentionally discarded.
        let _ = self.entries.insert(pid, record);
        Ok(())
    }

    /// Returns the record stored for `pid`, if any.
    #[must_use]
    pub fn get(&self, pid: u32) -> Option<&ProcessRecord> {
        self.entries.get(&pid)
    }

    /// Returns a mutable reference to the record stored for `pid`, if any.
    pub fn get_mut(&mut self, pid: u32) -> Option<&mut ProcessRecord> {
        self.entries.get_mut(&pid)
    }

    /// Iterates over the records whose status is `Running`, in unspecified
    /// (hash map) order.
    pub fn running(&self) -> impl Iterator<Item = &ProcessRecord> {
        self.entries
            .values()
            .filter(|r| r.status == ProcessStatus::Running)
    }

    /// Iterates over every tracked record, in unspecified (hash map) order.
    pub fn all(&self) -> impl Iterator<Item = &ProcessRecord> {
        self.entries.values()
    }

    /// Sets the status of `pid` to `Exited { code }`.
    ///
    /// Does nothing when `pid` is not tracked.
    pub fn mark_exited(&mut self, pid: u32, code: i32) {
        if let Some(r) = self.entries.get_mut(&pid) {
            r.status = ProcessStatus::Exited { code };
        }
    }

    /// Sets the status of `pid` to `Signaled { signal }`.
    ///
    /// Does nothing when `pid` is not tracked.
    pub fn mark_signaled(&mut self, pid: u32, signal: i32) {
        if let Some(r) = self.entries.get_mut(&pid) {
            r.status = ProcessStatus::Signaled { signal };
        }
    }

    /// Removes every record whose status is anything other than `Running`.
    pub fn gc(&mut self) {
        self.entries
            .retain(|_, r| r.status == ProcessStatus::Running);
    }
}
/// Spawns a detached Tokio task that waits for `child` to terminate and
/// records the outcome for `pid` in `registry`.
///
/// Outcomes are recorded as follows:
///
/// * the child reports an exit code: `mark_exited(pid, code)`;
/// * the child ends without an exit code: `mark_signaled(pid, signal)`, where
///   `signal` is the terminating signal number on Unix, or `-1` when no
///   signal number is available (including on non-Unix targets);
/// * waiting on the child fails: `mark_signaled(pid, -1)`.
///
/// The registry's write lock is taken only after the child has finished, so
/// it is not held while the process is still running.
///
/// # Panics
///
/// `tokio::spawn` panics when called outside of a Tokio runtime.
pub fn monitor_child(
    mut child: tokio::process::Child,
    pid: u32,
    registry: Arc<RwLock<ProcessRegistry>>,
) {
    // The join handle is deliberately not awaited: the task runs detached.
    let _handle = tokio::spawn(async move {
        let waited = child.wait().await;
        let mut reg = registry.write().await;

        match waited {
            Ok(status) => match status.code() {
                Some(code) => reg.mark_exited(pid, code),
                None => {
                    // No exit code means the child did not exit normally.
                    // On Unix the exit status carries the signal that ended
                    // it; record that instead of a placeholder.
                    #[cfg(unix)]
                    let signal =
                        std::os::unix::process::ExitStatusExt::signal(&status).unwrap_or(-1);
                    #[cfg(not(unix))]
                    let signal = -1;

                    reg.mark_signaled(pid, signal);
                }
            },
            // NOTE(review): a failed `wait` says nothing about how the child
            // ended; `ProcessStatus::Unknown` may describe this better than a
            // placeholder signal. Kept as-is so observable statuses for this
            // path do not change.
            Err(_) => reg.mark_signaled(pid, -1),
        }
    });
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a `Running` record for a dummy `sleep 1` invocation.
    fn make_record(pid: u32) -> ProcessRecord {
        ProcessRecord::new(pid, "sleep", vec![String::from("1")])
    }

    /// A freshly inserted record can be looked up and is still `Running`.
    #[test]
    fn insert_and_retrieve() {
        let mut registry = ProcessRegistry::new(None);
        registry.insert(make_record(100)).unwrap();

        let stored = registry.get(100).unwrap();
        assert_eq!(stored.pid, 100);
        assert_eq!(stored.status, ProcessStatus::Running);
    }

    /// With a cap of two, the third running process is rejected.
    #[test]
    fn max_tracked_enforced() {
        let mut registry = ProcessRegistry::new(Some(2));
        for pid in 1..=2 {
            registry.insert(make_record(pid)).unwrap();
        }

        let err = registry.insert(make_record(3)).unwrap_err();
        assert!(matches!(err, SandboxError::RegistryFull { max_tracked: 2 }));
    }

    /// An exited process no longer counts toward the cap.
    #[test]
    fn max_tracked_counts_only_running() {
        let mut registry = ProcessRegistry::new(Some(2));
        for pid in 1..=2 {
            registry.insert(make_record(pid)).unwrap();
        }

        registry.mark_exited(1, 0);
        registry.insert(make_record(3)).unwrap();
    }

    /// `gc` drops exited records and keeps running ones.
    #[test]
    fn mark_exited_and_gc() {
        let mut registry = ProcessRegistry::new(None);
        for pid in 10..=11 {
            registry.insert(make_record(pid)).unwrap();
        }

        registry.mark_exited(10, 0);
        registry.gc();

        assert!(registry.get(10).is_none());
        assert!(registry.get(11).is_some());
    }

    /// `running()` skips records that were marked as signaled.
    #[test]
    fn running_iterator_excludes_exited() {
        let mut registry = ProcessRegistry::new(None);
        for pid in 20..=21 {
            registry.insert(make_record(pid)).unwrap();
        }

        registry.mark_signaled(20, 9);

        let running_pids: Vec<u32> = registry.running().map(|r| r.pid).collect();
        assert_eq!(running_pids, vec![21]);
    }
}