use std::path::{Path, PathBuf};
use globset::GlobMatcher;
/// Summary returned by a successful `TransactionSnapshot::rollback`.
#[derive(Debug)]
pub(crate) struct RollbackReport {
/// Number of previously existing files copied back from their backups.
pub restored_count: usize,
/// Tally for paths that were absent when the snapshot was captured and were cleared during
/// rollback (see `TransactionSnapshot::rollback` for exactly what is counted).
pub deleted_count: usize,
}
/// How a tracked path is treated when the snapshot is rolled back.
#[derive(Debug)]
enum EntryKind {
/// The path held a file at capture time; `backup_path` is its copy inside the snapshot's
/// temporary directory and is copied back over the path on rollback.
Existing { backup_path: PathBuf },
/// The path's metadata could not be read at capture time (typically: it did not exist), so a
/// file found there on rollback is removed.
New,
}
/// One path tracked by a `TransactionSnapshot`.
#[derive(Debug)]
struct SnapshotEntry {
/// The path exactly as it was passed to `TransactionSnapshot::capture`.
original: PathBuf,
/// Whether rollback restores this path from a backup or deletes it.
kind: EntryKind,
}
/// Backup of a set of paths taken by `capture`, which `rollback` can later restore.
pub(crate) struct TransactionSnapshot {
// Never read directly. Holding the `TempDir` keeps the directory that contains every
// `EntryKind::Existing::backup_path` alive for as long as the snapshot exists; `tempfile`
// deletes the directory when the `TempDir` is dropped.
#[allow(dead_code)]
backup_dir: tempfile::TempDir,
/// Tracked paths in the order they were given to `capture` (skipped paths are absent).
entries: Vec<SnapshotEntry>,
}
/// Debug output shows only how many entries are tracked; the backup directory and the
/// per-entry paths are deliberately left out (hence the non-exhaustive marker).
impl std::fmt::Debug for TransactionSnapshot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.entries.len();
        let mut builder = f.debug_struct("TransactionSnapshot");
        builder.field("entry_count", &entry_count);
        builder.finish_non_exhaustive()
    }
}
impl TransactionSnapshot {
pub(crate) fn capture(paths: &[PathBuf], max_bytes: u64) -> Result<Self, std::io::Error> {
let backup_dir = tempfile::TempDir::new()?;
let mut entries = Vec::with_capacity(paths.len());
let mut cumulative_bytes: u64 = 0;
for (i, original) in paths.iter().enumerate() {
match original.symlink_metadata() {
Err(_) => {
entries.push(SnapshotEntry {
original: original.clone(),
kind: EntryKind::New,
});
continue;
}
Ok(meta) if meta.file_type().is_symlink() => {
tracing::debug!(
path = %original.display(),
"transaction snapshot: skipping symlink"
);
continue;
}
Ok(_) => {}
}
let backup_path = backup_dir
.path()
.join(format!("{i}_{}", file_name(original)));
std::fs::copy(original, &backup_path)?;
let meta = std::fs::metadata(original)?;
std::fs::set_permissions(&backup_path, meta.permissions())?;
cumulative_bytes += meta.len();
if max_bytes > 0 && cumulative_bytes > max_bytes {
return Err(std::io::Error::other(format!(
"snapshot size {cumulative_bytes} exceeds limit {max_bytes}"
)));
}
entries.push(SnapshotEntry {
original: original.clone(),
kind: EntryKind::Existing { backup_path },
});
}
Ok(Self {
backup_dir,
entries,
})
}
pub(crate) fn file_count(&self) -> usize {
self.entries.len()
}
pub(crate) fn total_bytes(&self) -> u64 {
self.entries
.iter()
.filter_map(|e| {
if let EntryKind::Existing { backup_path } = &e.kind {
std::fs::metadata(backup_path).map(|m| m.len()).ok()
} else {
None
}
})
.sum()
}
pub(crate) fn rollback(self) -> Result<RollbackReport, std::io::Error> {
let mut restored_count = 0usize;
let mut deleted_count = 0usize;
let mut first_error: Option<std::io::Error> = None;
for entry in &self.entries {
let result = match &entry.kind {
EntryKind::Existing { backup_path } => {
let dir_result = entry
.original
.parent()
.map_or(Ok(()), std::fs::create_dir_all);
dir_result
.and_then(|()| std::fs::copy(backup_path, &entry.original).map(|_| ()))
}
EntryKind::New => {
if entry.original.exists() {
std::fs::remove_file(&entry.original)
} else {
Ok(())
}
}
};
match result {
Ok(()) => match &entry.kind {
EntryKind::Existing { .. } => restored_count += 1,
EntryKind::New => {
if !entry.original.exists() {
deleted_count += 1;
}
}
},
Err(e) => {
tracing::warn!(
path = %entry.original.display(),
err = %e,
"rollback: failed to restore entry, continuing"
);
if first_error.is_none() {
first_error = Some(e);
}
}
}
}
if let Some(e) = first_error {
Err(e)
} else {
Ok(RollbackReport {
restored_count,
deleted_count,
})
}
}
}
/// Returns the final component of `path` as an owned string.
///
/// Non-UTF-8 bytes are replaced lossily. Paths with no final component (such as `/` or a
/// path ending in `..`) fall back to the literal `"file"`.
fn file_name(path: &Path) -> String {
    match path.file_name() {
        Some(name) => name.to_string_lossy().into_owned(),
        None => String::from("file"),
    }
}
/// Lowercase substrings whose presence marks a shell command as potentially writing to the
/// filesystem (consumed by `is_write_command`).
///
/// Matching is plain substring search, so entries overlap and over-match. For example, `">"`
/// already covers `">>"` and descriptor redirections such as `2>&1`, and `"cp "` also matches
/// inside `"scp "`. The trailing spaces reduce, but do not eliminate, matches inside longer words.
const WRITE_INDICATORS: &[&str] = &[
">",
">>",
"tee ",
"mv ",
"cp ",
"rm ",
"mkdir ",
"touch ",
"sed -i",
"chmod ",
"chown ",
"git checkout",
"cargo fmt",
"patch ",
];
/// Heuristically reports whether `command` may modify the filesystem.
///
/// The command is lowercased and scanned for any entry of `WRITE_INDICATORS` as a plain
/// substring. The test deliberately over-approximates: any `>` anywhere in the command counts.
pub(crate) fn is_write_command(command: &str) -> bool {
    let haystack = command.to_lowercase();
    for needle in WRITE_INDICATORS {
        if haystack.contains(needle) {
            return true;
        }
    }
    false
}
/// Extracts the file operands of output redirections in a shell command line.
///
/// The command is split on whitespace (no quote or escape handling). Each token is checked for
/// a leading output-redirection operator: `>`, `>>`, `>|`, `&>`, `&>>` or `>&`, optionally
/// preceded by a file-descriptor number (`2>`, `1>>`, …). The operand is either the text glued
/// to the operator (`>out.txt`) or, when nothing is glued, the following token (`> out.txt`).
///
/// Not reported as targets:
/// - descriptor duplications/closures such as `2>&1`, `>&2`, `2>& 1` or `>&-`;
/// - operands starting with `-`.
///
/// Targets are returned in the order they appear; duplicates are kept.
pub(crate) fn extract_redirection_targets(command: &str) -> Vec<String> {
    let tokens: Vec<&str> = command.split_whitespace().collect();
    let mut targets = Vec::new();
    let mut i = 0;
    while let Some(&tok) = tokens.get(i) {
        let Some((duplicates_fd, glued)) = split_output_redirection(tok) else {
            i += 1;
            continue;
        };
        let word = if glued.is_empty() {
            // Detached operand (`> out.txt`): consume it together with the operator.
            let Some(&next) = tokens.get(i + 1) else {
                break;
            };
            i += 2;
            next
        } else {
            i += 1;
            glued
        };
        // `>&N` names a descriptor, not a file. `word` is never empty here because it is
        // either non-empty glued text or a whitespace-split token.
        let is_fd_number = word.bytes().all(|b| b.is_ascii_digit());
        if (duplicates_fd && is_fd_number) || word.starts_with('-') {
            continue;
        }
        targets.push(word.to_owned());
    }
    targets
}

/// Recognises an output-redirection operator at the start of `tok`.
///
/// Returns `None` when `tok` does not begin with one. Otherwise returns whether the operator
/// is the descriptor-duplication form `>&`, together with the text glued after the operator
/// (empty when the operand is a separate token).
fn split_output_redirection(tok: &str) -> Option<(bool, &str)> {
    // Optional source: `&` (stdout and stderr) or an explicit descriptor number.
    let after_source = tok
        .strip_prefix('&')
        .unwrap_or_else(|| tok.trim_start_matches(|c: char| c.is_ascii_digit()));
    let rest = after_source.strip_prefix('>')?;
    // `>>` appends and `>|` overrides noclobber; both still name a file.
    if let Some(glued) = rest.strip_prefix('>').or_else(|| rest.strip_prefix('|')) {
        return Some((false, glued));
    }
    match rest.strip_prefix('&') {
        Some(glued) => Some((true, glued)),
        None => Some((false, rest)),
    }
}
pub(crate) fn affected_paths(command: &str, scope: &[GlobMatcher]) -> Vec<PathBuf> {
let mut raw: Vec<String> = super::extract_paths(command);
raw.extend(extract_redirection_targets(command));
raw.sort_unstable();
raw.dedup();
raw.into_iter()
.map(PathBuf::from)
.filter(|p| scope.is_empty() || scope.iter().any(|m| m.is_match(p)))
.collect()
}
/// Compiles `transaction_scope` glob patterns into matchers, preserving their order.
///
/// Invalid patterns are logged with `tracing::warn!` and dropped rather than treated as fatal.
pub(crate) fn build_scope_matchers(patterns: &[String]) -> Vec<GlobMatcher> {
    let mut matchers = Vec::with_capacity(patterns.len());
    for pattern in patterns {
        match globset::Glob::new(pattern) {
            Ok(glob) => matchers.push(glob.compile_matcher()),
            Err(err) => {
                tracing::warn!(pattern = %pattern, err = %err, "invalid transaction_scope glob");
            }
        }
    }
    matchers
}