use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
/// Shared state for tracking in-progress backups, plus an interrupt flag.
///
/// Cloning is cheap and yields a handle to the *same* flag and operation set, because both
/// fields are reference-counted (`Arc`).
#[derive(Clone, Debug)]
pub struct BackupContext {
    /// Interrupt flag shared by every clone of this context.
    interrupt_flag: Arc<AtomicBool>,
    /// Paths of backups that are registered but not yet completed or cleaned up.
    active_operations: Arc<Mutex<HashSet<PathBuf>>>,
}
impl BackupContext {
    /// Creates a context with the interrupt flag cleared and no tracked operations.
    pub fn new() -> Self {
        Self {
            interrupt_flag: Arc::new(AtomicBool::new(false)),
            active_operations: Arc::new(Mutex::new(HashSet::new())),
        }
    }

    /// Returns a shared handle to the interrupt flag.
    ///
    /// Storing `true` through the returned handle is observed by [`Self::is_interrupted`]
    /// on this context and on every clone of it.
    pub fn interrupt_flag(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.interrupt_flag)
    }

    /// Returns `true` if the interrupt flag is currently set.
    pub fn is_interrupted(&self) -> bool {
        self.interrupt_flag.load(Ordering::SeqCst)
    }

    /// Sets or clears the interrupt flag shared by all clones of this context.
    pub fn set_interrupted(&self, interrupted: bool) {
        self.interrupt_flag.store(interrupted, Ordering::SeqCst);
    }

    /// Starts tracking `backup_path` as an in-progress backup and returns its guard.
    ///
    /// If the tracking lock is poisoned, the path is not recorded, but a guard is still
    /// returned.
    pub fn register_operation(&self, backup_path: PathBuf) -> BackupOperationGuard {
        if let Ok(mut operations) = self.active_operations.lock() {
            operations.insert(backup_path.clone());
        }
        BackupOperationGuard::new(backup_path, self.clone())
    }

    /// Returns a snapshot, in unspecified order, of the currently tracked backup paths.
    ///
    /// Returns an empty list if the tracking lock is poisoned.
    pub fn get_active_operations(&self) -> Vec<PathBuf> {
        self.active_operations
            .lock()
            .map(|operations| operations.iter().cloned().collect())
            .unwrap_or_default()
    }

    /// Deletes every tracked backup from disk and stops tracking it, reporting each
    /// outcome on stderr.
    pub fn cleanup_active_operations(&self) {
        self.cleanup_active_operations_with_mode(false);
    }

    /// Deletes every tracked backup from disk and stops tracking it.
    ///
    /// Directories are removed recursively; anything else is removed as a file. Paths
    /// whose metadata cannot be read (e.g. they no longer exist) are simply forgotten.
    /// When `silent` is `false`, the outcome for each remaining path is printed to
    /// stderr. A path is untracked even if deleting it fails.
    pub fn cleanup_active_operations_with_mode(&self, silent: bool) {
        // Take the whole set in a single critical section. Snapshotting and clearing
        // under two separate locks would drop any path registered in between without
        // ever deleting it.
        let pending: HashSet<PathBuf> = match self.active_operations.lock() {
            Ok(mut operations) => std::mem::take(&mut *operations),
            Err(_) => HashSet::new(),
        };

        for backup_path in pending {
            // One stat answers both "does it exist?" and "is it a directory?".
            let metadata = match std::fs::metadata(&backup_path) {
                Ok(metadata) => metadata,
                Err(_) => continue,
            };
            let cleanup_result = if metadata.is_dir() {
                std::fs::remove_dir_all(&backup_path)
            } else {
                std::fs::remove_file(&backup_path)
            };
            if silent {
                continue;
            }
            match cleanup_result {
                Ok(()) => eprintln!("Cleaned up incomplete backup: {}", backup_path.display()),
                Err(err) => eprintln!(
                    "Warning: Could not clean up incomplete backup: {} ({})",
                    backup_path.display(),
                    err
                ),
            }
        }
    }

    /// Stops tracking `backup_path` without touching the filesystem.
    fn remove_operation(&self, backup_path: &std::path::Path) {
        if let Ok(mut operations) = self.active_operations.lock() {
            operations.remove(backup_path);
        }
    }
}
impl Default for BackupContext {
fn default() -> Self {
Self::new()
}
}
/// Process-wide context used by the module-level free functions; `None` until one is installed.
static GLOBAL_CONTEXT: Mutex<Option<BackupContext>> = Mutex::new(None);
/// Installs `context` as the process-wide context, replacing any previous one.
///
/// Silently does nothing if the global lock is poisoned.
pub fn set_global_context(context: BackupContext) {
    let lock_result = GLOBAL_CONTEXT.lock();
    if let Ok(mut slot) = lock_result {
        *slot = Some(context);
    }
}
/// Returns a clone of the installed global context.
///
/// Returns `None` when no context is installed or the global lock is poisoned.
fn get_global_context() -> Option<BackupContext> {
    match GLOBAL_CONTEXT.lock() {
        Ok(slot) => slot.clone(),
        Err(_) => None,
    }
}
pub fn set_interrupt_flag(flag: Arc<AtomicBool>) {
let context = BackupContext {
interrupt_flag: flag,
active_operations: Arc::new(Mutex::new(HashSet::new())),
};
set_global_context(context);
}
/// Reports whether the global context's interrupt flag is set.
///
/// Returns `false` when no global context is installed.
pub fn is_interrupted() -> bool {
    if let Some(context) = get_global_context() {
        context.is_interrupted()
    } else {
        false
    }
}
/// Replaces the global context with a fresh one: flag cleared, nothing tracked.
#[cfg(test)]
pub fn reset_for_testing() {
    set_global_context(BackupContext::new());
}
/// Guard for a single backup path registered with a [`BackupContext`].
///
/// Call [`BackupOperationGuard::complete`] once the backup has finished. If the guard is
/// dropped instead, the path stops being tracked, unless the context's interrupt flag is
/// set; in that case the path stays tracked for a later cleanup.
#[must_use = "hold the guard for the whole backup; dropping it early stops tracking the path"]
pub struct BackupOperationGuard {
    /// Path whose tracking this guard is responsible for.
    backup_path: PathBuf,
    /// Context the path was registered with.
    context: BackupContext,
    /// Whether the path may still be tracked by `context` on this guard's behalf.
    registered: bool,
    /// Whether `complete()` has been called.
    completed: bool,
}
impl BackupOperationGuard {
pub fn new(backup_path: PathBuf, context: BackupContext) -> Self {
Self {
backup_path,
context,
registered: true,
completed: false,
}
}
pub fn complete(mut self) {
self.context.remove_operation(&self.backup_path);
self.registered = false;
self.completed = true;
}
}
impl Drop for BackupOperationGuard {
    /// Stops tracking an unfinished backup, unless the interrupt flag is set.
    ///
    /// On interrupt the path is deliberately left tracked, so a later cleanup of the
    /// context can delete the partial backup.
    fn drop(&mut self) {
        let unfinished = self.registered && !self.completed;
        if unfinished && !self.context.is_interrupted() {
            self.context.remove_operation(&self.backup_path);
        }
    }
}
/// Returns the paths tracked by the global context.
///
/// Returns an empty list when no global context is installed.
pub fn get_active_operations() -> Vec<PathBuf> {
    match get_global_context() {
        Some(context) => context.get_active_operations(),
        None => Vec::new(),
    }
}
pub fn cleanup_active_operations() {
cleanup_active_operations_with_mode(false);
}
/// Runs [`BackupContext::cleanup_active_operations_with_mode`] on the global context.
///
/// Does nothing when no global context is installed.
pub fn cleanup_active_operations_with_mode(silent: bool) {
    if let Some(global) = get_global_context() {
        global.cleanup_active_operations_with_mode(silent);
    }
}
pub fn create_backup_guard(backup_path: PathBuf) -> BackupOperationGuard {
if let Some(context) = get_global_context() {
context.register_operation(backup_path)
} else {
let context = BackupContext::new();
set_global_context(context.clone());
context.register_operation(backup_path)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A completed guard stops tracking its path and leaves the backup on disk.
    #[test]
    fn test_backup_operation_guard_normal_completion() {
        let dir = tempdir().unwrap();
        let backup_path = dir.path().join("test-backup");
        fs::create_dir(&backup_path).unwrap();
        let context = BackupContext::new();
        let guard = context.register_operation(backup_path.clone());
        let active_ops = context.get_active_operations();
        assert!(active_ops.contains(&backup_path));
        guard.complete();
        let active_ops = context.get_active_operations();
        assert!(!active_ops.contains(&backup_path));
        assert!(backup_path.exists());
    }

    /// A guard dropped without `complete()` and without an interrupt stops tracking its path.
    #[test]
    fn test_backup_operation_guard_dropped_without_completion() {
        let dir = tempdir().unwrap();
        let backup_path = dir.path().join("test-backup");
        fs::create_dir(&backup_path).unwrap();
        let context = BackupContext::new();
        {
            let _guard = context.register_operation(backup_path.clone());
            let active_ops = context.get_active_operations();
            assert!(active_ops.contains(&backup_path));
        }
        let active_ops = context.get_active_operations();
        assert!(!active_ops.contains(&backup_path));
    }

    /// Cleanup deletes tracked directories and files and stops tracking them.
    #[test]
    fn test_cleanup_active_operations() {
        let dir = tempdir().unwrap();
        let backup_path1 = dir.path().join("backup1");
        let backup_path2 = dir.path().join("backup2");
        fs::create_dir(&backup_path1).unwrap();
        fs::write(&backup_path2, "content").unwrap();
        let context = BackupContext::new();
        let _guard1 = context.register_operation(backup_path1.clone());
        let _guard2 = context.register_operation(backup_path2.clone());
        let active_ops = context.get_active_operations();
        assert!(active_ops.contains(&backup_path1));
        assert!(active_ops.contains(&backup_path2));
        assert!(backup_path1.exists());
        assert!(backup_path2.exists());
        context.cleanup_active_operations();
        let active_ops = context.get_active_operations();
        assert!(!active_ops.contains(&backup_path1));
        assert!(!active_ops.contains(&backup_path2));
        assert!(!backup_path1.exists());
        assert!(!backup_path2.exists());
    }

    /// Cleanup forgets tracked paths that do not exist on disk.
    #[test]
    fn test_cleanup_nonexistent_operations() {
        let dir = tempdir().unwrap();
        let backup_path = dir.path().join("nonexistent-backup");
        let context = BackupContext::new();
        let _guard = context.register_operation(backup_path.clone());
        context.cleanup_active_operations();
        let active_ops = context.get_active_operations();
        assert!(!active_ops.contains(&backup_path));
    }

    /// An interrupted guard keeps its path tracked so that cleanup can delete it.
    #[test]
    fn test_interrupt_race_condition_fix() {
        let dir = tempdir().unwrap();
        let backup_path = dir.path().join("test-backup-interrupted");
        fs::create_dir_all(&backup_path).unwrap();
        fs::write(backup_path.join("partial.txt"), "partial content").unwrap();
        let context = BackupContext::new();
        {
            let _guard = context.register_operation(backup_path.clone());
            assert!(backup_path.exists());
            let active_ops = context.get_active_operations();
            assert!(active_ops.contains(&backup_path));
            context.set_interrupted(true);
            assert!(context.is_interrupted(), "Interrupt flag should be set");
        }
        let active_ops = context.get_active_operations();
        assert!(
            active_ops.contains(&backup_path),
            "Operation should still be tracked after interrupted guard drops"
        );
        assert!(backup_path.exists());
        context.cleanup_active_operations_with_mode(true);
        assert!(
            !backup_path.exists(),
            "Backup should be cleaned up by signal handler"
        );
        let remaining_ops = context.get_active_operations();
        assert!(
            remaining_ops.is_empty(),
            "No operations should remain tracked"
        );
    }
}