synwire_sandbox/
process_registry.rs1use std::collections::HashMap;
10use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use tokio::sync::RwLock;
15
16use crate::error::SandboxError;
17use crate::output::CapturedOutput;
18
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23#[non_exhaustive]
24pub enum ProcessStatus {
25 Running,
27 Exited {
29 code: i32,
31 },
32 Signaled {
34 signal: i32,
36 },
37 Unknown,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ProcessRecord {
46 pub pid: u32,
48 pub command: String,
50 pub args: Vec<String>,
52 pub started_at: DateTime<Utc>,
54 pub cgroup_path: Option<String>,
56 pub status: ProcessStatus,
58 pub cpu_usage_ns: Option<u64>,
60 pub memory_bytes: Option<u64>,
62 #[serde(skip)]
67 pub output: Option<Arc<CapturedOutput>>,
68}
69
70impl ProcessRecord {
71 #[must_use]
73 pub fn new(pid: u32, command: impl Into<String>, args: Vec<String>) -> Self {
74 Self {
75 pid,
76 command: command.into(),
77 args,
78 started_at: Utc::now(),
79 cgroup_path: None,
80 status: ProcessStatus::Running,
81 cpu_usage_ns: None,
82 memory_bytes: None,
83 output: None,
84 }
85 }
86}
87
88#[derive(Debug, Default)]
95pub struct ProcessRegistry {
96 entries: HashMap<u32, ProcessRecord>,
98 max_tracked: Option<usize>,
100}
101
102impl ProcessRegistry {
103 #[must_use]
105 pub fn new(max_tracked: Option<usize>) -> Self {
106 Self {
107 entries: HashMap::new(),
108 max_tracked,
109 }
110 }
111
112 pub fn insert(&mut self, record: ProcessRecord) -> Result<(), SandboxError> {
116 if let Some(max) = self.max_tracked {
117 let running = self
118 .entries
119 .values()
120 .filter(|r| r.status == ProcessStatus::Running)
121 .count();
122 if running >= max {
123 return Err(SandboxError::RegistryFull { max_tracked: max });
124 }
125 }
126 let _ = self.entries.insert(record.pid, record);
127 Ok(())
128 }
129
130 #[must_use]
132 pub fn get(&self, pid: u32) -> Option<&ProcessRecord> {
133 self.entries.get(&pid)
134 }
135
136 pub fn get_mut(&mut self, pid: u32) -> Option<&mut ProcessRecord> {
138 self.entries.get_mut(&pid)
139 }
140
141 pub fn running(&self) -> impl Iterator<Item = &ProcessRecord> {
143 self.entries
144 .values()
145 .filter(|r| r.status == ProcessStatus::Running)
146 }
147
148 pub fn all(&self) -> impl Iterator<Item = &ProcessRecord> {
150 self.entries.values()
151 }
152
153 pub fn mark_exited(&mut self, pid: u32, code: i32) {
155 if let Some(r) = self.entries.get_mut(&pid) {
156 r.status = ProcessStatus::Exited { code };
157 }
158 }
159
160 pub fn mark_signaled(&mut self, pid: u32, signal: i32) {
162 if let Some(r) = self.entries.get_mut(&pid) {
163 r.status = ProcessStatus::Signaled { signal };
164 }
165 }
166
167 pub fn gc(&mut self) {
170 self.entries
171 .retain(|_, r| r.status == ProcessStatus::Running);
172 }
173}
174
175pub fn monitor_child(
184 mut child: tokio::process::Child,
185 pid: u32,
186 registry: Arc<RwLock<ProcessRegistry>>,
187) {
188 let _handle = tokio::spawn(async move {
189 match child.wait().await {
190 Ok(status) => {
191 let code = status
192 .code()
193 .unwrap_or_else(|| if status.success() { 0 } else { -1 });
194 let mut reg = registry.write().await;
195 if status.code().is_some() {
196 reg.mark_exited(pid, code);
197 } else {
198 reg.mark_signaled(pid, -1);
200 }
201 }
202 Err(_) => {
203 registry.write().await.mark_signaled(pid, -1);
204 }
205 }
206 });
207}
208
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a running record for `pid` with a fixed command line.
    fn make_record(pid: u32) -> ProcessRecord {
        ProcessRecord::new(pid, "sleep", vec!["1".into()])
    }

    #[test]
    fn new_record_defaults() {
        let r = make_record(5);
        assert_eq!(r.pid, 5);
        assert_eq!(r.command, "sleep");
        assert_eq!(r.args, vec!["1".to_string()]);
        assert_eq!(r.status, ProcessStatus::Running);
        assert!(r.cgroup_path.is_none());
        assert!(r.cpu_usage_ns.is_none());
        assert!(r.memory_bytes.is_none());
        assert!(r.output.is_none());
    }

    #[test]
    fn insert_and_retrieve() {
        let mut reg = ProcessRegistry::new(None);
        reg.insert(make_record(100)).unwrap();
        let r = reg.get(100).unwrap();
        assert_eq!(r.pid, 100);
        assert_eq!(r.status, ProcessStatus::Running);
    }

    #[test]
    fn max_tracked_enforced() {
        let mut reg = ProcessRegistry::new(Some(2));
        reg.insert(make_record(1)).unwrap();
        reg.insert(make_record(2)).unwrap();
        let err = reg.insert(make_record(3)).unwrap_err();
        assert!(matches!(err, SandboxError::RegistryFull { max_tracked: 2 }));
        // A rejected insert must leave the registry untouched.
        assert!(reg.get(3).is_none());
        assert_eq!(reg.all().count(), 2);
    }

    #[test]
    fn max_tracked_counts_only_running() {
        let mut reg = ProcessRegistry::new(Some(2));
        reg.insert(make_record(1)).unwrap();
        reg.insert(make_record(2)).unwrap();
        reg.mark_exited(1, 0);
        reg.insert(make_record(3)).unwrap();
        assert_eq!(reg.get(1).unwrap().status, ProcessStatus::Exited { code: 0 });
        assert_eq!(reg.running().count(), 2);
        // The exited record is still tracked until `gc` runs.
        assert_eq!(reg.all().count(), 3);
    }

    #[test]
    fn mark_exited_and_gc() {
        let mut reg = ProcessRegistry::new(None);
        reg.insert(make_record(10)).unwrap();
        reg.insert(make_record(11)).unwrap();
        reg.mark_exited(10, 0);
        reg.gc();
        assert!(reg.get(10).is_none());
        assert!(reg.get(11).is_some());
    }

    #[test]
    fn gc_removes_every_non_running_status() {
        let mut reg = ProcessRegistry::new(None);
        for pid in 30..34 {
            reg.insert(make_record(pid)).unwrap();
        }
        reg.mark_exited(30, 1);
        reg.mark_signaled(31, 15);
        reg.get_mut(32).unwrap().status = ProcessStatus::Unknown;
        reg.gc();
        let remaining: Vec<u32> = reg.all().map(|r| r.pid).collect();
        assert_eq!(remaining, vec![33]);
    }

    #[test]
    fn marking_unknown_pid_is_noop() {
        let mut reg = ProcessRegistry::new(None);
        reg.insert(make_record(40)).unwrap();
        reg.mark_exited(99, 1);
        reg.mark_signaled(98, 9);
        assert!(reg.get(99).is_none());
        assert!(reg.get(98).is_none());
        assert_eq!(reg.get(40).unwrap().status, ProcessStatus::Running);
        assert_eq!(reg.all().count(), 1);
    }

    #[test]
    fn running_iterator_excludes_exited() {
        let mut reg = ProcessRegistry::new(None);
        reg.insert(make_record(20)).unwrap();
        reg.insert(make_record(21)).unwrap();
        reg.mark_signaled(20, 9);
        assert_eq!(reg.get(20).unwrap().status, ProcessStatus::Signaled { signal: 9 });
        let running: Vec<_> = reg.running().collect();
        assert_eq!(running.len(), 1);
        assert_eq!(running[0].pid, 21);
    }
}