Skip to main content

toggle/
io.rs

1// File I/O operations for the Toggle CLI
2
3use crate::journal::{self, Journal, JournalEntry, JOURNAL_FILENAME, LOCK_FILENAME};
4use crate::platform;
5use similar::TextDiff;
6use std::fs::File;
7use std::io::{self, Read, Write};
8use std::path::{Path, PathBuf};
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11use tempfile::NamedTempFile;
12
/// Read the entire file at `path` into a `String`, requiring valid UTF-8.
///
/// The bytes are taken verbatim: a leading UTF-8 BOM, if present, is kept as
/// `'\u{FEFF}'` at the start of the returned string.
///
/// # Errors
/// Any open/read error, or `InvalidData` if the file is not valid UTF-8.
pub fn read_file(path: &Path) -> io::Result<String> {
    std::fs::read_to_string(path)
}
20
21/// Read file content with a specified encoding.
22/// Supports any encoding label recognized by the Encoding Standard
23/// (e.g., "utf-8", "latin-1", "iso-8859-1", "windows-1252", "ascii").
24pub fn read_file_encoded(path: &Path, encoding: &str) -> io::Result<String> {
25    if encoding.eq_ignore_ascii_case("utf-8") {
26        return read_file(path);
27    }
28    let bytes = std::fs::read(path)?;
29    let enc = resolve_encoding(encoding)?;
30    let (decoded, _, had_errors) = enc.decode(&bytes);
31    if had_errors {
32        return Err(io::Error::new(
33            io::ErrorKind::InvalidData,
34            format!("Failed to decode file as {}", encoding),
35        ));
36    }
37    Ok(decoded.into_owned())
38}
39
40/// Resolve an encoding label to an encoding_rs::Encoding.
41/// Handles common aliases like "latin-1" that encoding_rs doesn't directly recognize.
42fn resolve_encoding(label: &str) -> io::Result<&'static encoding_rs::Encoding> {
43    // Try direct lookup first
44    if let Some(enc) = encoding_rs::Encoding::for_label(label.as_bytes()) {
45        return Ok(enc);
46    }
47    // Handle common aliases not in the Encoding Standard
48    let alias = match label.to_ascii_lowercase().as_str() {
49        "latin-1" | "latin1" => Some("iso-8859-1"),
50        "ascii" | "us-ascii" => Some("windows-1252"),
51        _ => None,
52    };
53    if let Some(alias_label) = alias {
54        if let Some(enc) = encoding_rs::Encoding::for_label(alias_label.as_bytes()) {
55            return Ok(enc);
56        }
57    }
58    Err(io::Error::new(
59        io::ErrorKind::InvalidInput,
60        format!("Unsupported encoding: {}", label),
61    ))
62}
63
64/// Check if an encoding label is valid/supported.
65pub fn is_valid_encoding(label: &str) -> bool {
66    if label.eq_ignore_ascii_case("utf-8") {
67        return true;
68    }
69    resolve_encoding(label).is_ok()
70}
71
72/// Encode a string into bytes using the specified encoding.
73fn encode_string(content: &str, encoding: &str) -> io::Result<Vec<u8>> {
74    if encoding.eq_ignore_ascii_case("utf-8") {
75        return Ok(content.as_bytes().to_vec());
76    }
77    let enc = resolve_encoding(encoding)?;
78    let (encoded, _, had_errors) = enc.encode(content);
79    if had_errors {
80        return Err(io::Error::new(
81            io::ErrorKind::InvalidData,
82            format!("Failed to encode content as {}", encoding),
83        ));
84    }
85    Ok(encoded.into_owned())
86}
87
/// Return `true` if `path` itself is a symbolic link. The link is not
/// followed, so dangling links count.
///
/// Any error reading the link's own metadata (e.g. the path does not exist)
/// yields `false`.
pub fn is_symlink(path: &Path) -> bool {
    std::fs::symlink_metadata(path).is_ok_and(|meta| meta.file_type().is_symlink())
}
94
/// Resolve the symlink at `path` to the final path it refers to, following
/// chains of links (a -> b -> c) until the result is not itself a symlink.
///
/// Each relative link target is resolved against the directory of the link
/// that contains it. Unlike `fs::canonicalize`, the final target does not have
/// to exist: a dangling link resolves to the path it names. Stopping after one
/// hop would return an intermediate link, and a later rename onto it would
/// replace that link instead of writing through to the real file.
///
/// # Errors
/// - Any `fs::read_link` error for `path`, e.g. when `path` is not a symlink.
/// - `InvalidInput` if more than 40 links are followed (a cycle or an
///   unreasonably long chain).
fn resolve_symlink(path: &Path) -> io::Result<PathBuf> {
    // Same order of magnitude as the kernel's ELOOP limit on Linux.
    const MAX_SYMLINK_HOPS: usize = 40;

    let mut current = path.to_path_buf();
    for _ in 0..MAX_SYMLINK_HOPS {
        let target = std::fs::read_link(&current)?;
        let next = if target.is_absolute() {
            target
        } else {
            current.parent().unwrap_or(Path::new(".")).join(target)
        };
        let next_is_link = next
            .symlink_metadata()
            .map(|m| m.file_type().is_symlink())
            .unwrap_or(false);
        if !next_is_link {
            return Ok(next);
        }
        current = next;
    }
    Err(io::Error::new(
        io::ErrorKind::InvalidInput,
        format!("Too many levels of symbolic links: {}", path.display()),
    ))
}
106
107/// Write file content atomically using a temp file + rename.
108/// If `temp_suffix` is provided, uses `path.<suffix>` as the temp file name.
109/// Otherwise uses a NamedTempFile in the same directory.
110/// If `no_dereference` is true and path is a symlink, writes to the symlink's
111/// target instead of replacing the symlink.
112pub fn write_file(path: &Path, content: &str, temp_suffix: Option<&str>) -> io::Result<()> {
113    write_bytes_impl(path, content.as_bytes(), temp_suffix, false)
114}
115
116/// Write file with optional symlink-aware behavior.
117pub fn write_file_no_deref(
118    path: &Path,
119    content: &str,
120    temp_suffix: Option<&str>,
121    no_dereference: bool,
122) -> io::Result<()> {
123    let bytes = content.as_bytes();
124    write_bytes_impl(path, bytes, temp_suffix, no_dereference)
125}
126
127/// Write file with encoding and symlink support.
128pub fn write_file_encoded(
129    path: &Path,
130    content: &str,
131    temp_suffix: Option<&str>,
132    no_dereference: bool,
133    encoding: &str,
134) -> io::Result<()> {
135    let bytes = encode_string(content, encoding)?;
136    write_bytes_impl(path, &bytes, temp_suffix, no_dereference)
137}
138
139fn write_bytes_impl(
140    path: &Path,
141    bytes: &[u8],
142    temp_suffix: Option<&str>,
143    no_dereference: bool,
144) -> io::Result<()> {
145    let write_path = if no_dereference && is_symlink(path) {
146        resolve_symlink(path)?
147    } else {
148        path.to_path_buf()
149    };
150    let dir = write_path.parent().unwrap_or(Path::new("."));
151
152    if let Some(suffix) = temp_suffix {
153        // Use explicit temp file name: file.py.tmp (append suffix, not replace extension)
154        let mut temp_name = write_path.as_os_str().to_os_string();
155        temp_name.push(".");
156        temp_name.push(suffix);
157        let temp_path = std::path::PathBuf::from(temp_name);
158        let mut file = File::create(&temp_path)?;
159        file.write_all(bytes)?;
160        file.sync_all()?;
161        std::fs::rename(&temp_path, &write_path)?;
162    } else {
163        // Use tempfile crate for safe atomic write
164        let mut tmp = NamedTempFile::new_in(dir)?;
165        tmp.write_all(bytes)?;
166        tmp.as_file().sync_all()?;
167        tmp.persist(&write_path).map_err(|e| e.error)?;
168    }
169
170    Ok(())
171}
172
173/// Print a unified diff between original and modified content.
174/// No-ops if content is identical.
175pub fn print_diff(path: &Path, original: &str, modified: &str) {
176    if original == modified {
177        return;
178    }
179    let diff = TextDiff::from_lines(original, modified);
180    let path_str = path.display().to_string();
181    print!(
182        "{}",
183        diff.unified_diff()
184            .header(&format!("a/{}", path_str), &format!("b/{}", path_str))
185    );
186}
187
/// Copy the file at `path` to a sibling whose name is `path` with `extension`
/// appended verbatim. For example, `create_backup("file.py", ".bak")` writes
/// "file.py.bak". An existing backup is overwritten.
///
/// # Errors
/// Any error from `std::fs::copy`, e.g. when `path` does not exist.
pub fn create_backup(path: &Path, extension: &str) -> io::Result<()> {
    let backup = {
        let mut name = path.as_os_str().to_owned();
        name.push(extension);
        PathBuf::from(name)
    };
    std::fs::copy(path, &backup).map(|_bytes_copied| ())
}
196
/// Rewrite every line ending in `content` according to `eol`.
///
/// - `"lf"`: each `\r\n`, lone `\r` and `\n` becomes `\n`.
/// - `"crlf"`: each of those becomes `\r\n`.
/// - anything else (including `"preserve"`): the content is returned unchanged.
///
/// A `\r` immediately followed by `\n` counts as a single line ending.
pub fn normalize_eol(content: &str, eol: &str) -> String {
    let newline = match eol {
        "lf" => "\n",
        "crlf" => "\r\n",
        _ => return content.to_string(),
    };

    let mut out = String::with_capacity(content.len());
    let mut chars = content.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '\r' => {
                // Swallow the '\n' of a CRLF pair so it yields one line ending.
                chars.next_if_eq(&'\n');
                out.push_str(newline);
            }
            '\n' => out.push_str(newline),
            other => out.push(other),
        }
    }
    out
}
212
/// Return `true` if `content` starts with the UTF-8 byte-order mark
/// (`EF BB BF`).
pub fn has_utf8_bom(content: &[u8]) -> bool {
    matches!(content, [0xEF, 0xBB, 0xBF, ..])
}
217
/// Find lines that must never be toggled: a shebang and a source-encoding
/// pragma.
///
/// Only the first two non-blank lines are examined. Blank lines are skipped
/// when counting, but still count towards the returned indices. Within those
/// two lines:
/// - the first one is protected if, after trimming, it starts with `#!`;
/// - either one is protected if, after trimming, it starts with `#` and
///   contains `coding:` or `coding=` (the PEP 263 magic-comment markers).
///
/// This approximates the real rules, which concern physical lines 1 and 2.
///
/// Returns 0-based line indices (as counted by `str::lines`), ascending and
/// without duplicates.
pub fn detect_protected_lines(content: &str) -> Vec<usize> {
    content
        .lines()
        .enumerate()
        .filter(|(_, line)| !line.trim().is_empty())
        .take(2)
        .enumerate()
        .filter_map(|(rank, (line_idx, line))| {
            let text = line.trim();
            let is_shebang = rank == 0 && text.starts_with("#!");
            let is_pragma = text.starts_with('#')
                && (text.contains("coding:") || text.contains("coding="));
            (is_shebang || is_pragma).then_some(line_idx)
        })
        .collect()
}
253
254/// Encode a string for atomic mode staging. Public wrapper around encode_string.
255pub fn encode_for_atomic(content: &str, encoding: &str) -> io::Result<Vec<u8>> {
256    encode_string(content, encoding)
257}
258
259// ── Atomic multi-file batch operations ──
260
/// Suffix appended to a target's full path to name its hard-link backup
/// during an atomic commit.
const ATOMIC_BACKUP_EXT: &str = ".toggle-atomic-backup";

/// Number of staged files above which `AtomicBatch::warn_if_large_batch`
/// prints a performance warning. Every staged file is synced individually.
const BATCH_SIZE_WARNING_THRESHOLD: usize = 500;
266
/// One staged write: the content already sits, fsynced, in `temp_path` and
/// only needs to be renamed over `target_path`.
#[derive(Debug)]
pub struct StagedWrite {
    /// Path of the staged temp file. Its handle has already been closed via
    /// `into_temp_path`, and `keep()` disabled auto-deletion.
    pub temp_path: PathBuf,
    /// Final destination the temp file will be renamed to.
    pub target_path: PathBuf,
    /// SHA-256 hex digest of the staged bytes, recorded in the journal.
    pub content_sha256: String,
    /// Permissions of the target at staging time (`None` if it did not
    /// exist), re-applied to the temp file right before the rename.
    pub original_permissions: Option<std::fs::Permissions>,
}
278
279/// Manages a two-phase atomic commit of multiple file writes.
280pub struct AtomicBatch {
281    staged: Vec<StagedWrite>,
282    journal_path: PathBuf,
283    lock_path: PathBuf,
284    _lock: Option<fd_lock::RwLock<File>>,
285    backup_enabled: bool,
286    interrupted: Arc<AtomicBool>,
287}
288
289impl AtomicBatch {
290    /// Create a new atomic batch. Acquires the lock file immediately.
291    /// `targets` is used to determine the journal directory.
292    /// `backup_enabled` controls whether hard-link backups are created.
293    /// `interrupted` is an AtomicBool set by signal handlers.
294    pub fn new(
295        targets: &[PathBuf],
296        backup_enabled: bool,
297        interrupted: Arc<AtomicBool>,
298    ) -> io::Result<Self> {
299        let dir = journal::journal_dir(targets)?;
300        let lock_path = dir.join(LOCK_FILENAME);
301        let journal_path = dir.join(JOURNAL_FILENAME);
302
303        // Acquire exclusive lock.
304        // We keep the RwLock (and its write guard implicitly via try_write)
305        // alive for the lifetime of the batch by storing the RwLock itself.
306        let lock_file = File::create(&lock_path)?;
307        let mut lock = fd_lock::RwLock::new(lock_file);
308        // Test that we can acquire the lock; this will fail if another
309        // atomic operation is running. The write guard is dropped immediately,
310        // but the underlying file descriptor (held by the RwLock) keeps the
311        // advisory lock on some platforms. We re-acquire below.
312        {
313            let _guard = lock.try_write().map_err(|_| {
314                io::Error::new(
315                    io::ErrorKind::WouldBlock,
316                    "Another atomic operation is already in progress in this directory. \
317                     Wait for it to complete or remove .toggle-atomic.lock if the previous \
318                     process crashed.",
319                )
320            })?;
321            // Guard dropped here but we keep the RwLock (and its fd) alive
322        }
323
324        Ok(Self {
325            staged: Vec::new(),
326            journal_path,
327            lock_path,
328            _lock: Some(lock),
329            backup_enabled,
330            interrupted,
331        })
332    }
333
334    /// Stage a single file write: write content to a temp file in the same
335    /// directory as the target, fsync it, then release the fd.
336    pub fn stage(&mut self, target_path: &Path, content: &[u8], _encoding: &str) -> io::Result<()> {
337        let target_dir = target_path.parent().unwrap_or(Path::new("."));
338        let mut tmp = NamedTempFile::new_in(target_dir)?;
339        let encoded = content.to_vec();
340        tmp.write_all(&encoded)?;
341        platform::durable_sync(tmp.as_file())?;
342
343        // Copy permissions from original file if it exists
344        let original_permissions = if target_path.exists() {
345            let meta = std::fs::metadata(target_path)?;
346            let perms = meta.permissions();
347            tmp.as_file().set_permissions(perms.clone()).ok();
348            Some(perms)
349        } else {
350            None
351        };
352
353        let content_sha256 = journal::sha256_hex(&encoded);
354
355        // Release the fd but keep the path for later rename
356        let temp_path_obj = tmp.into_temp_path();
357        let temp_path = temp_path_obj.to_path_buf();
358        // Prevent TempPath from deleting the file on drop — we manage it ourselves
359        temp_path_obj
360            .keep()
361            .map_err(|e| io::Error::other(format!("Failed to keep temp path: {}", e)))?;
362
363        self.staged.push(StagedWrite {
364            temp_path,
365            target_path: target_path.to_path_buf(),
366            content_sha256,
367            original_permissions,
368        });
369
370        Ok(())
371    }
372
373    /// Emit a warning if the batch size exceeds the threshold.
374    pub fn warn_if_large_batch(&self) {
375        if self.staged.len() > BATCH_SIZE_WARNING_THRESHOLD {
376            eprintln!(
377                "Warning: Staging {} files in atomic mode. Large batches may be \
378                 slow due to fsync overhead. Consider splitting into smaller \
379                 batches if performance is critical.",
380                self.staged.len()
381            );
382        }
383    }
384
385    /// Execute the two-phase commit: create backups, write journal, rename all.
386    /// Returns Ok(()) if all renames succeed. On failure, attempts rollback
387    /// if backups are enabled.
388    pub fn commit(self) -> io::Result<()> {
389        if self.staged.is_empty() {
390            self.cleanup_lock();
391            return Ok(());
392        }
393
394        self.warn_if_large_batch();
395
396        // Build journal entries
397        let mut journal_entries: Vec<JournalEntry> = Vec::with_capacity(self.staged.len());
398        for sw in &self.staged {
399            let backup_path = if self.backup_enabled {
400                let mut bp = sw.target_path.as_os_str().to_os_string();
401                bp.push(ATOMIC_BACKUP_EXT);
402                Some(PathBuf::from(bp))
403            } else {
404                None
405            };
406            journal_entries.push(JournalEntry {
407                target_path: sw.target_path.clone(),
408                temp_path: sw.temp_path.clone(),
409                backup_path,
410                content_sha256: sw.content_sha256.clone(),
411                rename_completed: false,
412            });
413        }
414
415        let mut j = Journal::new(journal_entries, self.backup_enabled);
416
417        // Persist journal in Staged state
418        journal::persist_journal(&j, &self.journal_path)?;
419
420        // Create hard-link backups if enabled
421        if self.backup_enabled {
422            for entry in &j.entries {
423                if let Some(ref backup_path) = entry.backup_path {
424                    if entry.target_path.exists() {
425                        if let Err(e) = std::fs::hard_link(&entry.target_path, backup_path) {
426                            eprintln!(
427                                "Error: failed to create backup for '{}': {}",
428                                entry.target_path.display(),
429                                e
430                            );
431                            self.rollback_staged(&j);
432                            return Err(e);
433                        }
434                    }
435                }
436            }
437        }
438
439        // Transition to Committing
440        j.transition_to_committing();
441        journal::persist_journal(&j, &self.journal_path)?;
442
443        if !self.backup_enabled {
444            eprintln!(
445                "Warning: Running without backups. If the rename phase fails, \
446                 rollback is not possible."
447            );
448        }
449
450        // Phase 2: Rename all temp files to targets
451        let entry_count = j.entries.len();
452        for idx in 0..entry_count {
453            // Check for signal interrupt between renames
454            if self.interrupted.load(Ordering::Relaxed) {
455                eprintln!("Interrupted. Journal preserved for recovery.");
456                journal::persist_journal(&j, &self.journal_path)?;
457                return Err(io::Error::new(
458                    io::ErrorKind::Interrupted,
459                    "Atomic commit interrupted by signal. \
460                     Run with --recover to clean up.",
461                ));
462            }
463
464            let temp_path = j.entries[idx].temp_path.clone();
465            let target_path = j.entries[idx].target_path.clone();
466
467            // Copy permissions before rename
468            if let Some(ref perms) = self.staged[idx].original_permissions {
469                let _ = std::fs::set_permissions(&temp_path, perms.clone());
470            }
471
472            match platform::rename_with_retry(&temp_path, &target_path) {
473                Ok(()) => {
474                    j.mark_entry_completed(idx);
475                    journal::persist_journal_best_effort(&j, &self.journal_path);
476                }
477                Err(e) => {
478                    eprintln!(
479                        "Error: rename failed for '{}': {}",
480                        target_path.display(),
481                        e
482                    );
483                    if self.backup_enabled {
484                        eprintln!("Attempting rollback...");
485                        if let Err(rb_err) = journal::recover_rollback(&j, &self.journal_path) {
486                            eprintln!("Rollback also failed: {}", rb_err);
487                        }
488                    } else {
489                        let _ = journal::persist_journal(&j, &self.journal_path);
490                        eprintln!(
491                            "No backups available. Journal preserved at '{}' for manual recovery.",
492                            self.journal_path.display()
493                        );
494                    }
495                    return Err(e);
496                }
497            }
498        }
499
500        // Finalization: fsync parent directories
501        let mut synced_dirs = std::collections::HashSet::new();
502        for entry in &j.entries {
503            if let Some(parent) = entry.target_path.parent() {
504                if synced_dirs.insert(parent.to_path_buf()) {
505                    let _ = platform::sync_dir(parent);
506                }
507            }
508        }
509
510        // Delete journal
511        journal::delete_journal(&self.journal_path)?;
512
513        // Clean up atomic backup files
514        if self.backup_enabled {
515            for entry in &j.entries {
516                if let Some(ref backup_path) = entry.backup_path {
517                    let _ = std::fs::remove_file(backup_path);
518                }
519            }
520        }
521
522        self.cleanup_lock();
523        Ok(())
524    }
525
526    /// Rollback from Staged state: delete all temp files and backups, delete journal.
527    fn rollback_staged(&self, journal: &Journal) {
528        for entry in &journal.entries {
529            if entry.temp_path.exists() {
530                let _ = std::fs::remove_file(&entry.temp_path);
531            }
532            if let Some(ref backup_path) = entry.backup_path {
533                if backup_path.exists() {
534                    let _ = std::fs::remove_file(backup_path);
535                }
536            }
537        }
538        let _ = journal::delete_journal(&self.journal_path);
539        self.cleanup_lock();
540    }
541
542    /// Clean up the lock file.
543    fn cleanup_lock(&self) {
544        let _ = std::fs::remove_file(&self.lock_path);
545    }
546}
547
548impl Drop for AtomicBatch {
549    fn drop(&mut self) {
550        // If we're being dropped without commit() having cleaned up,
551        // the lock file should still be removed.
552        // Note: staged temp files are NOT cleaned up on drop since we called
553        // keep() on them. The journal (if written) provides recovery info.
554    }
555}
556
/// Filesystem operations used by the atomic-commit machinery, put behind a
/// trait so tests can substitute implementations that inject failures.
/// `RealFileOps` is the production implementation.
pub trait FileOps {
    /// Rename `from` to `to`.
    fn rename(&self, from: &Path, to: &Path) -> io::Result<()>;

    /// Create a hard link at `dst` referring to `src`.
    fn hard_link(&self, src: &Path, dst: &Path) -> io::Result<()>;

    /// Delete the file at `path`.
    fn remove_file(&self, path: &Path) -> io::Result<()>;

    /// Sync the directory at `path` (production: `platform::sync_dir`).
    fn sync_dir(&self, path: &Path) -> io::Result<()>;
}
565
566/// Production filesystem operations using std::fs.
567pub struct RealFileOps;
568
569impl FileOps for RealFileOps {
570    fn rename(&self, from: &Path, to: &Path) -> io::Result<()> {
571        platform::rename_with_retry(from, to)
572    }
573
574    fn hard_link(&self, src: &Path, dst: &Path) -> io::Result<()> {
575        std::fs::hard_link(src, dst)
576    }
577
578    fn remove_file(&self, path: &Path) -> io::Result<()> {
579        std::fs::remove_file(path)
580    }
581
582    fn sync_dir(&self, path: &Path) -> io::Result<()> {
583        platform::sync_dir(path)
584    }
585}