1use std::sync::RwLock;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use tokio::sync::Notify;
15
16#[cfg(unix)]
17use nix::sys::signal::{Signal, kill};
18#[cfg(unix)]
19use nix::unistd::Pid;
20
/// Strategy used by `FreezeState` to pause its target processes.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FreezeMode {
    /// Freezing is disabled; freeze requests are no-ops that report `Ok(false)`.
    #[default]
    None,
    /// Pause with `SIGSTOP` and resume with `SIGCONT` on Unix; other platforms
    /// report `FreezeError::UnsupportedPlatform`.
    Process,
}
31
/// Failures that can occur while stopping or resuming tracked processes.
#[derive(Clone, Debug, PartialEq)]
pub enum FreezeError {
    /// Process freezing was requested but no PIDs are registered.
    NoPidConfigured,
    /// Signal delivery reported that no process with this PID exists (`ESRCH`).
    ProcessNotFound(i32),
    /// The caller is not allowed to signal this PID (`EPERM`).
    PermissionDenied(i32),
    /// Process freezing is not implemented on the current platform.
    UnsupportedPlatform,
    /// Any other signalling failure; the message identifies the PID involved.
    SignalFailed(String),
    /// Several failures at once, one entry per PID that failed.
    Multiple(Vec<FreezeError>),
}
48
49impl std::fmt::Display for FreezeError {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 FreezeError::NoPidConfigured => write!(f, "No PIDs configured for process freezing"),
53 FreezeError::ProcessNotFound(pid) => write!(f, "Process {} not found", pid),
54 FreezeError::PermissionDenied(pid) => {
55 write!(f, "Permission denied to send signal to process {}", pid)
56 }
57 FreezeError::UnsupportedPlatform => {
58 write!(f, "Process freezing not supported on this platform")
59 }
60 FreezeError::SignalFailed(msg) => write!(f, "Signal failed: {}", msg),
61 FreezeError::Multiple(errors) => {
62 write!(f, "Multiple freeze errors: ")?;
63 for (i, err) in errors.iter().enumerate() {
64 if i > 0 {
65 write!(f, ", ")?;
66 }
67 write!(f, "{}", err)?;
68 }
69 Ok(())
70 }
71 }
72 }
73}
74
75impl std::error::Error for FreezeError {}
76
77pub struct FreezeState {
87 mode: FreezeMode,
89
90 pids: RwLock<Vec<i32>>,
93
94 is_frozen: AtomicBool,
96
97 freeze_epoch: AtomicU64,
100
101 state_changed: Notify,
103}
104
105impl std::fmt::Debug for FreezeState {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.debug_struct("FreezeState")
108 .field("mode", &self.mode)
109 .field("pids", &*self.pids.read().unwrap())
110 .field("is_frozen", &self.is_frozen.load(Ordering::SeqCst))
111 .field("freeze_epoch", &self.freeze_epoch.load(Ordering::SeqCst))
112 .finish()
113 }
114}
115
116impl FreezeState {
117 pub fn new(mode: FreezeMode, pid: Option<u32>) -> Self {
119 Self {
120 mode,
121 pids: RwLock::new(pid.map(|p| vec![p as i32]).unwrap_or_default()),
122 is_frozen: AtomicBool::new(false),
123 freeze_epoch: AtomicU64::new(0),
124 state_changed: Notify::new(),
125 }
126 }
127
128 pub fn with_pids(mode: FreezeMode, pids: Vec<u32>) -> Self {
133 Self {
134 mode,
135 pids: RwLock::new(pids.into_iter().map(|p| p as i32).collect()),
136 is_frozen: AtomicBool::new(false),
137 freeze_epoch: AtomicU64::new(0),
138 state_changed: Notify::new(),
139 }
140 }
141
142 pub fn mode(&self) -> FreezeMode {
144 self.mode
145 }
146
147 pub fn pid_count(&self) -> usize {
149 self.pids.read().unwrap().len()
150 }
151
152 pub fn pid(&self) -> Option<i32> {
154 self.pids.read().unwrap().first().copied()
155 }
156
157 pub fn register_pid(&self, pid: u32) {
162 self.pids.write().unwrap().push(pid as i32);
163 }
164
165 pub fn is_frozen(&self) -> bool {
167 self.is_frozen.load(Ordering::SeqCst)
168 }
169
170 pub fn current_epoch(&self) -> u64 {
172 self.freeze_epoch.load(Ordering::SeqCst)
173 }
174
175 pub async fn wait_for_state_change(&self) {
177 self.state_changed.notified().await;
178 }
179
180 #[cfg(unix)]
190 pub fn freeze_at_epoch(&self, epoch: u64) -> Result<bool, FreezeError> {
191 if self.mode != FreezeMode::Process {
192 return Ok(false);
193 }
194
195 let current_epoch = self.freeze_epoch.load(Ordering::SeqCst);
196 if current_epoch != epoch {
197 return Ok(false);
198 }
199
200 let pids = self.pids.read().unwrap();
201 if pids.is_empty() {
202 return Err(FreezeError::NoPidConfigured);
203 }
204
205 if self
206 .is_frozen
207 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
208 .is_err()
209 {
210 return Ok(true);
211 }
212
213 let mut errors = Vec::new();
214 let mut any_success = false;
215
216 for &pid in pids.iter() {
217 match kill(Pid::from_raw(pid), Signal::SIGSTOP) {
218 Ok(()) => {
219 any_success = true;
220 }
221 Err(nix::errno::Errno::ESRCH) => {
222 errors.push(FreezeError::ProcessNotFound(pid));
223 }
224 Err(nix::errno::Errno::EPERM) => {
225 errors.push(FreezeError::PermissionDenied(pid));
226 }
227 Err(e) => {
228 errors.push(FreezeError::SignalFailed(format!("PID {}: {}", pid, e)));
229 }
230 }
231 }
232
233 if !errors.is_empty() && !any_success {
234 self.is_frozen.store(false, Ordering::SeqCst);
235 if errors.len() == 1 {
236 return Err(errors.remove(0));
237 }
238 return Err(FreezeError::Multiple(errors));
239 }
240
241 self.state_changed.notify_waiters();
242 Ok(true)
243 }
244
245 #[cfg(not(unix))]
246 pub fn freeze_at_epoch(&self, _epoch: u64) -> Result<bool, FreezeError> {
247 if self.mode == FreezeMode::Process {
248 Err(FreezeError::UnsupportedPlatform)
249 } else {
250 Ok(false)
251 }
252 }
253
254 #[cfg(unix)]
259 pub fn unfreeze(&self) -> Result<(), FreezeError> {
260 self.freeze_epoch.fetch_add(1, Ordering::SeqCst);
261
262 if !self.is_frozen.swap(false, Ordering::SeqCst) {
263 return Ok(());
264 }
265
266 let pids = self.pids.read().unwrap();
267 if pids.is_empty() {
268 return Ok(());
269 }
270
271 let mut errors = Vec::new();
272
273 for &pid in pids.iter() {
274 match kill(Pid::from_raw(pid), Signal::SIGCONT) {
275 Ok(()) => {}
276 Err(nix::errno::Errno::ESRCH) => {
277 errors.push(FreezeError::ProcessNotFound(pid));
278 }
279 Err(nix::errno::Errno::EPERM) => {
280 errors.push(FreezeError::PermissionDenied(pid));
281 }
282 Err(e) => {
283 errors.push(FreezeError::SignalFailed(format!("PID {}: {}", pid, e)));
284 }
285 }
286 }
287
288 self.state_changed.notify_waiters();
289
290 if errors.is_empty() {
291 Ok(())
292 } else if errors.len() == 1 {
293 Err(errors.remove(0))
294 } else {
295 Err(FreezeError::Multiple(errors))
296 }
297 }
298
299 #[cfg(not(unix))]
300 pub fn unfreeze(&self) -> Result<(), FreezeError> {
301 self.freeze_epoch.fetch_add(1, Ordering::SeqCst);
302 self.is_frozen.store(false, Ordering::SeqCst);
303 self.state_changed.notify_waiters();
304 Ok(())
305 }
306
307 #[cfg(unix)]
311 pub fn force_unfreeze(&self) {
312 self.freeze_epoch.fetch_add(1, Ordering::SeqCst);
313 self.is_frozen.store(false, Ordering::SeqCst);
314
315 let pids = self.pids.read().unwrap();
316 for &pid in pids.iter() {
317 let _ = kill(Pid::from_raw(pid), Signal::SIGCONT);
318 }
319 self.state_changed.notify_waiters();
320 }
321
322 #[cfg(not(unix))]
323 pub fn force_unfreeze(&self) {
324 self.freeze_epoch.fetch_add(1, Ordering::SeqCst);
325 self.is_frozen.store(false, Ordering::SeqCst);
326 self.state_changed.notify_waiters();
327 }
328}
329
330impl Default for FreezeState {
331 fn default() -> Self {
332 Self::new(FreezeMode::default(), None)
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_freeze_mode_default_is_none() {
        assert_eq!(FreezeMode::default(), FreezeMode::None);
    }

    #[test]
    fn test_freeze_state_default() {
        let state = FreezeState::default();
        assert_eq!(state.mode(), FreezeMode::None);
        assert_eq!(state.pid(), None);
        assert_eq!(state.pid_count(), 0);
        assert!(!state.is_frozen());
        assert_eq!(state.current_epoch(), 0);
    }

    #[test]
    fn test_epoch_increments_on_unfreeze() {
        let state = FreezeState::new(FreezeMode::None, None);
        assert_eq!(state.current_epoch(), 0);

        state.unfreeze().unwrap();
        assert_eq!(state.current_epoch(), 1);

        state.unfreeze().unwrap();
        assert_eq!(state.current_epoch(), 2);
    }

    #[test]
    fn test_freeze_without_process_mode_returns_false() {
        let state = FreezeState::new(FreezeMode::None, Some(12345));
        let result = state.freeze_at_epoch(0).unwrap();
        assert!(!result);
        assert!(!state.is_frozen());
    }

    #[cfg(unix)]
    #[test]
    fn test_freeze_without_pid_returns_error() {
        let state = FreezeState::new(FreezeMode::Process, None);
        let result = state.freeze_at_epoch(0);
        assert!(matches!(result, Err(FreezeError::NoPidConfigured)));
    }

    // Off Unix, `Process` mode always reports `UnsupportedPlatform`, so the
    // epoch check is only observable on Unix.
    #[cfg(unix)]
    #[test]
    fn test_epoch_mismatch_prevents_freeze() {
        let state = FreezeState::new(FreezeMode::Process, Some(12345));

        state.unfreeze().unwrap();
        assert_eq!(state.current_epoch(), 1);

        let result = state.freeze_at_epoch(0);
        assert!(matches!(result, Ok(false)));
        assert!(!state.is_frozen());
    }

    #[cfg(not(unix))]
    #[test]
    fn test_process_mode_unsupported_off_unix() {
        let state = FreezeState::new(FreezeMode::Process, Some(12345));
        assert!(matches!(
            state.freeze_at_epoch(0),
            Err(FreezeError::UnsupportedPlatform)
        ));
        assert!(!state.is_frozen());
    }

    #[test]
    fn test_with_pids_creates_multi_pid_state() {
        let state = FreezeState::with_pids(FreezeMode::Process, vec![111, 222, 333]);
        assert_eq!(state.pid_count(), 3);
        assert_eq!(state.pid(), Some(111));
    }

    #[test]
    fn test_register_pid() {
        let state = FreezeState::new(FreezeMode::Process, Some(111));
        assert_eq!(state.pid_count(), 1);

        state.register_pid(222);
        assert_eq!(state.pid_count(), 2);

        state.register_pid(333);
        assert_eq!(state.pid_count(), 3);
    }

    // No PIDs are registered, so no real signal is ever sent here.
    #[test]
    fn test_force_unfreeze_without_pids_bumps_epoch() {
        let state = FreezeState::default();
        state.force_unfreeze();
        assert_eq!(state.current_epoch(), 1);
        assert!(!state.is_frozen());
    }

    #[test]
    fn test_multiple_error_display() {
        let err = FreezeError::Multiple(vec![
            FreezeError::ProcessNotFound(1),
            FreezeError::PermissionDenied(2),
        ]);
        assert_eq!(
            err.to_string(),
            "Multiple freeze errors: Process 1 not found, Permission denied to send signal to process 2"
        );
    }
}