//! chkpt_core/ops/restore.rs
//!
//! Restore a workspace to a previously captured snapshot state.
1use crate::config::{project_id_from_path, StoreLayout};
2use crate::error::{ChkpttError, Result};
3use crate::index::FileIndex;
4use crate::ops::io_order::sort_scanned_for_locality;
5use crate::ops::lock::ProjectLock;
6use crate::scanner::ScannedFile;
7use crate::store::blob::hash_path_bytes;
8use crate::store::catalog::{ManifestEntry, MetadataCatalog};
9use crate::store::pack::{PackLocation, PackSet};
10use crate::store::tree::{EntryType, TreeStore};
11use std::collections::{BTreeMap, HashMap, HashSet};
12use std::io::{BufWriter, Write};
13use std::path::Path;
14use std::sync::atomic::{AtomicU64, Ordering};
15
16use crate::ops::progress::{emit, ProgressCallback, ProgressEvent};
17
/// Caller-supplied options controlling how [`restore`] behaves.
#[derive(Default)]
pub struct RestoreOptions {
    /// When true, compute and report the diff without touching the workspace.
    pub dry_run: bool,
    /// Callback invoked with [`ProgressEvent`]s as the restore advances.
    pub progress: ProgressCallback,
}
23
/// Summary of what a restore did (or, for dry runs, would do).
#[derive(Debug)]
pub struct RestoreResult {
    /// Fully resolved snapshot id that was restored.
    pub snapshot_id: String,
    /// Files present in the snapshot but absent from the workspace.
    pub files_added: u64,
    /// Files present in both but with differing content or file type.
    pub files_changed: u64,
    /// Files present in the workspace but absent from the snapshot.
    pub files_removed: u64,
    /// Files already identical to the snapshot state.
    pub files_unchanged: u64,
}
32
/// Content identity of a file as it exists in the workspace right now.
struct CurrentFileState {
    // 32-byte content hash (symlinks hash their target path bytes — see restore_file).
    hash: [u8; 32],
    // Whether the workspace entry is a symlink rather than a regular file.
    is_symlink: bool,
}
37
/// Content identity of a file as recorded in the target snapshot.
/// Kept as a separate type from `CurrentFileState` so the diff cannot
/// accidentally mix up which side a state came from.
struct TargetFileState {
    // 32-byte blob hash recorded in the snapshot tree/manifest.
    hash: [u8; 32],
    // Whether the snapshot entry is a symlink rather than a regular file.
    is_symlink: bool,
}
42
/// Outcome of comparing the target snapshot state against the workspace.
struct RestoreDiff {
    // Paths that exist only in the snapshot and must be written out.
    files_to_add: Vec<String>,
    // Paths present on both sides whose content or type differs.
    files_to_change: Vec<String>,
    // Paths that exist only in the workspace and must be deleted.
    files_to_remove: Vec<String>,
    // Count of paths already identical on both sides.
    files_unchanged: u64,
}
49
/// Where a blob's bytes can be read from during restore.
/// Currently blobs always live inside packs; the enum leaves room for
/// other storage forms later.
#[derive(Debug, Clone, Copy)]
enum RestoreSource {
    Packed(PackLocation),
}
54
/// One file (or symlink) to materialize into the workspace.
#[derive(Debug, Clone)]
struct RestoreTask {
    // Workspace-relative path of the entry to write.
    path: String,
    // True when the entry must be created as a symlink.
    is_symlink: bool,
    // Where to read the entry's content from.
    source: RestoreSource,
}
61
/// Convert a [u8; 32] to a 64-char lowercase hex string.
fn bytes_to_hex(bytes: &[u8; 32]) -> String {
    use std::fmt::Write;

    let mut hex = String::with_capacity(64);
    for byte in bytes {
        // Writing into a String is infallible.
        write!(hex, "{byte:02x}").expect("formatting into a String cannot fail");
    }
    hex
}
66
/// Join `name` onto `prefix` with a `/` separator; an empty prefix yields
/// `name` alone so root-level entries never gain a leading slash.
fn join_relative_path(prefix: &str, name: &str) -> String {
    match prefix {
        "" => name.to_owned(),
        _ => format!("{prefix}/{name}"),
    }
}
78
79/// Recursively walk a tree and collect all file entries as (relative_path, blob_hash_hex).
80fn collect_tree_files(
81    tree_store: &TreeStore,
82    tree_hash_hex: &str,
83    prefix: &str,
84    result: &mut BTreeMap<String, TargetFileState>,
85) -> Result<()> {
86    let entries = tree_store.read(tree_hash_hex)?;
87    for entry in &entries {
88        let path = join_relative_path(prefix, &entry.name);
89        match entry.entry_type {
90            EntryType::File => {
91                result.insert(
92                    path,
93                    TargetFileState {
94                        hash: entry.hash,
95                        is_symlink: false,
96                    },
97                );
98            }
99            EntryType::Dir => {
100                let subtree_hash_hex = bytes_to_hex(&entry.hash);
101                collect_tree_files(tree_store, &subtree_hash_hex, &path, result)?;
102            }
103            EntryType::Symlink => {
104                result.insert(
105                    path,
106                    TargetFileState {
107                        hash: entry.hash,
108                        is_symlink: true,
109                    },
110                );
111            }
112        }
113    }
114    Ok(())
115}
116
117fn target_state_from_manifest(manifest: &[ManifestEntry]) -> BTreeMap<String, TargetFileState> {
118    manifest
119        .iter()
120        .map(|entry| {
121            (
122                entry.path.clone(),
123                TargetFileState {
124                    hash: entry.blob_hash,
125                    is_symlink: mode_is_symlink(entry.mode),
126                },
127            )
128        })
129        .collect()
130}
131
132/// Scan the current workspace to get a mapping of (relative_path -> content_hash_hex).
133///
134/// This uses the scanner to discover files, then hashes each file to get the current
135/// content hash for comparison with the target snapshot state.
136fn scan_current_state(
137    workspace_root: &Path,
138    cached_entries: &HashMap<String, crate::index::FileEntry>,
139    include_deps: bool,
140) -> Result<BTreeMap<String, CurrentFileState>> {
141    let scanned = crate::scanner::scan_workspace_with_options(workspace_root, None, include_deps)?;
142    let mut state = BTreeMap::new();
143    let mut stale_files = Vec::with_capacity(scanned.len());
144
145    for file in scanned {
146        if let Some(hash) = cached_hash_bytes(&file, cached_entries) {
147            state.insert(
148                file.relative_path.clone(),
149                CurrentFileState {
150                    hash,
151                    is_symlink: file.is_symlink,
152                },
153            );
154        } else {
155            stale_files.push(file);
156        }
157    }
158
159    for (file, hash) in hash_scanned_files(stale_files)? {
160        state.insert(
161            file.relative_path.clone(),
162            CurrentFileState {
163                hash,
164                is_symlink: file.is_symlink,
165            },
166        );
167    }
168    Ok(state)
169}
170
/// Write all `restore_tasks` into the workspace, in parallel when worthwhile.
///
/// Tasks are split into contiguous chunks (preserving their pack-locality
/// order from `build_restore_tasks`) across scoped worker threads. Progress
/// is reported through a shared atomic counter so events stay monotonic
/// across threads. Returns the paths that were actually restored.
fn restore_files(
    workspace_root: &Path,
    restore_tasks: &[RestoreTask],
    pack_set: &PackSet,
    progress: &ProgressCallback,
    progress_counter: &AtomicU64,
    restore_total: u64,
) -> Result<Vec<String>> {
    if restore_tasks.is_empty() {
        return Ok(Vec::new());
    }

    // Never spawn more workers than tasks.
    let worker_count = std::thread::available_parallelism()
        .map(|count| count.get())
        .unwrap_or(1)
        .min(restore_tasks.len());
    if worker_count <= 1 {
        // Serial fast path: no thread-scope overhead for tiny restores.
        let mut restored = Vec::with_capacity(restore_tasks.len());
        for task in restore_tasks {
            restore_file(workspace_root, task, pack_set)?;
            let completed = progress_counter.fetch_add(1, Ordering::Relaxed) + 1;
            emit(
                progress,
                ProgressEvent::RestoreFile {
                    completed,
                    total: restore_total,
                },
            );
            restored.push(task.path.clone());
        }
        return Ok(restored);
    }

    // div_ceil so the last chunk absorbs the remainder rather than
    // spawning an extra worker.
    let chunk_size = restore_tasks.len().div_ceil(worker_count);
    std::thread::scope(|scope| {
        let workers: Vec<_> = restore_tasks
            .chunks(chunk_size)
            .map(|chunk| {
                scope.spawn(move || -> Result<Vec<String>> {
                    let mut restored = Vec::with_capacity(chunk.len());
                    for task in chunk {
                        restore_file(workspace_root, task, pack_set)?;
                        let completed = progress_counter.fetch_add(1, Ordering::Relaxed) + 1;
                        emit(
                            progress,
                            ProgressEvent::RestoreFile {
                                completed,
                                total: restore_total,
                            },
                        );
                        restored.push(task.path.clone());
                    }
                    Ok(restored)
                })
            })
            .collect();

        // Join in spawn order so restored_paths keeps the task ordering;
        // a panicking worker is surfaced as an error rather than propagated.
        let mut restored_paths = Vec::with_capacity(restore_tasks.len());
        for worker in workers {
            let chunk = worker
                .join()
                .map_err(|_| ChkpttError::Other("restore worker thread panicked".into()))??;
            restored_paths.extend(chunk);
        }
        Ok(restored_paths)
    })
}
238
/// Materialize a single file or symlink from its pack location into the
/// workspace, creating any missing parent directories.
fn restore_file(workspace_root: &Path, task: &RestoreTask, pack_set: &PackSet) -> Result<()> {
    let file_path = workspace_root.join(&task.path);
    if let Some(parent) = file_path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    // Remove any pre-existing symlink (or any entry that must become a
    // symlink) first: writing through File::create would follow an existing
    // link, and creating a new symlink fails if the path already exists.
    // symlink_metadata is used so the link itself is inspected, not its target.
    if let Ok(metadata) = std::fs::symlink_metadata(&file_path) {
        if metadata.file_type().is_symlink() || task.is_symlink {
            std::fs::remove_file(&file_path)?;
        }
    }

    match task.source {
        RestoreSource::Packed(location) => {
            if task.is_symlink {
                // For symlinks the blob content is the link target path bytes.
                let mut content = Vec::new();
                pack_set.copy_to_writer(&location, &mut content)?;
                restore_symlink(&file_path, &content)?;
            } else {
                let file = std::fs::File::create(&file_path)?;
                let mut writer = BufWriter::with_capacity(256 * 1024, file);
                pack_set.copy_to_writer(&location, &mut writer)?;
                // Flush explicitly: BufWriter's Drop swallows write errors.
                writer.flush()?;
            }
        }
    }

    Ok(())
}
268
/// Create a symlink at `path` pointing at `target_bytes` (raw OS bytes, so
/// non-UTF-8 targets round-trip unchanged). Unix-only implementation.
#[cfg(unix)]
fn restore_symlink(path: &Path, target_bytes: &[u8]) -> Result<()> {
    use std::os::unix::ffi::OsStrExt;
    let target = std::ffi::OsStr::from_bytes(target_bytes);
    std::os::unix::fs::symlink(target, path)?;
    Ok(())
}
276
/// Non-unix fallback: symlink restoration is not supported, so restoring a
/// snapshot containing symlinks fails with an explicit error.
#[cfg(not(unix))]
fn restore_symlink(_path: &Path, _target_bytes: &[u8]) -> Result<()> {
    Err(ChkpttError::RestoreFailed(
        "symlink restore is only supported on unix platforms".into(),
    ))
}
283
/// Map every blob hash needed by the restore to a readable pack location.
///
/// Deduplicates hashes across add/change lists (many paths may share one
/// blob), asks the catalog where each blob lives, opens only the packs that
/// are actually referenced, and returns the opened `PackSet` together with a
/// hash -> `RestoreSource` lookup table.
fn resolve_restore_sources(
    files_to_add: &[String],
    files_to_change: &[String],
    target_state: &BTreeMap<String, TargetFileState>,
    catalog: &MetadataCatalog,
    packs_dir: &Path,
) -> Result<(PackSet, HashMap<[u8; 32], RestoreSource>)> {
    let candidate_count = files_to_add.len() + files_to_change.len();
    let mut sources = HashMap::with_capacity(candidate_count);
    let mut seen_hashes = HashSet::with_capacity(candidate_count);
    let mut packed_hashes = Vec::with_capacity(candidate_count);

    // Collect the set of distinct blob hashes the restore will read.
    for path in files_to_add.iter().chain(files_to_change.iter()) {
        let target = target_state
            .get(path)
            .expect("target hash missing for restore source");
        if !seen_hashes.insert(target.hash) {
            continue;
        }
        packed_hashes.push(target.hash);
    }

    if packed_hashes.is_empty() {
        return Ok((PackSet::empty(), sources));
    }

    // One batched catalog query resolves all blob locations at once.
    let blob_locations = catalog.blob_locations_for_hashes(&packed_hashes)?;
    let mut selected_pack_hashes = HashSet::with_capacity(packed_hashes.len());
    for hash in &packed_hashes {
        let location = blob_locations
            .get(hash)
            .ok_or_else(|| ChkpttError::ObjectNotFound(bytes_to_hex(hash)))?;
        // Every restorable blob must live inside a pack; anything else means
        // the store metadata is inconsistent.
        let pack_hash = location.pack_hash.as_ref().ok_or_else(|| {
            ChkpttError::StoreCorrupted(format!(
                "blob {} is not stored in a pack",
                bytes_to_hex(hash)
            ))
        })?;
        selected_pack_hashes.insert(pack_hash.clone());
    }

    // Sort for a deterministic open order, then open only the needed packs.
    let mut pack_hashes: Vec<_> = selected_pack_hashes.into_iter().collect();
    pack_hashes.sort_unstable();
    let pack_set = PackSet::open_selected(packs_dir, &pack_hashes)?;

    // Second pass: now that the packs are open, pin down each blob's exact
    // in-pack location. The expect() is safe because the loop above already
    // errored out on any blob without a pack hash.
    for hash in packed_hashes {
        let pack_hash = blob_locations
            .get(&hash)
            .and_then(|location| location.pack_hash.as_ref())
            .expect("pack hash missing after validation");
        let location = pack_set
            .locate_in_pack_bytes(pack_hash, &hash)
            .ok_or_else(|| ChkpttError::ObjectNotFound(bytes_to_hex(&hash)))?;
        sources.insert(hash, RestoreSource::Packed(location));
    }

    Ok((pack_set, sources))
}
342
343fn build_restore_tasks(
344    files_to_add: &[String],
345    files_to_change: &[String],
346    target_state: &BTreeMap<String, TargetFileState>,
347    restore_sources: &HashMap<[u8; 32], RestoreSource>,
348) -> Result<Vec<RestoreTask>> {
349    let mut tasks = Vec::with_capacity(files_to_add.len() + files_to_change.len());
350
351    for path in files_to_add.iter().chain(files_to_change.iter()) {
352        let target = target_state
353            .get(path)
354            .expect("target hash missing for restore task");
355        let source = *restore_sources
356            .get(&target.hash)
357            .ok_or_else(|| ChkpttError::ObjectNotFound(bytes_to_hex(&target.hash)))?;
358
359        tasks.push(RestoreTask {
360            path: path.clone(),
361            is_symlink: target.is_symlink,
362            source,
363        });
364    }
365
366    tasks.sort_unstable_by(|left, right| match (&left.source, &right.source) {
367        (RestoreSource::Packed(left_location), RestoreSource::Packed(right_location)) => (
368            left_location.reader_index,
369            left_location.offset,
370            left.path.as_str(),
371        )
372            .cmp(&(
373                right_location.reader_index,
374                right_location.offset,
375                right.path.as_str(),
376            )),
377    });
378    Ok(tasks)
379}
380
381fn diff_restore_states(
382    target_state: &BTreeMap<String, TargetFileState>,
383    current_state: &BTreeMap<String, CurrentFileState>,
384) -> RestoreDiff {
385    let mut files_to_add = Vec::with_capacity(target_state.len());
386    let mut files_to_change = Vec::with_capacity(target_state.len().min(current_state.len()));
387    let mut files_to_remove = Vec::with_capacity(current_state.len());
388    let mut files_unchanged = 0;
389
390    let mut target_iter = target_state.iter().peekable();
391    let mut current_iter = current_state.iter().peekable();
392
393    loop {
394        match (target_iter.peek(), current_iter.peek()) {
395            (Some((target_path, target_file)), Some((current_path, current_file))) => {
396                match target_path.cmp(current_path) {
397                    std::cmp::Ordering::Less => {
398                        files_to_add.push((*target_path).clone());
399                        target_iter.next();
400                    }
401                    std::cmp::Ordering::Greater => {
402                        files_to_remove.push((*current_path).clone());
403                        current_iter.next();
404                    }
405                    std::cmp::Ordering::Equal => {
406                        if target_file.hash != current_file.hash
407                            || target_file.is_symlink != current_file.is_symlink
408                        {
409                            files_to_change.push((*target_path).clone());
410                        } else {
411                            files_unchanged += 1;
412                        }
413                        target_iter.next();
414                        current_iter.next();
415                    }
416                }
417            }
418            (Some((target_path, _)), None) => {
419                files_to_add.push((*target_path).clone());
420                target_iter.next();
421            }
422            (None, Some((current_path, _))) => {
423                files_to_remove.push((*current_path).clone());
424                current_iter.next();
425            }
426            (None, None) => break,
427        }
428    }
429
430    RestoreDiff {
431        files_to_add,
432        files_to_change,
433        files_to_remove,
434        files_unchanged,
435    }
436}
437
/// Restore workspace to a snapshot state.
///
/// This is the main restore function that:
/// 1. Resolves the snapshot ID ("latest" or prefix match)
/// 2. Loads the snapshot and reconstructs the target file state from the tree
/// 3. Compares target state vs current workspace state
/// 4. Either reports what would change (dry_run) or performs the actual restore
///
/// Holds the project lock for the whole operation so no concurrent snapshot
/// or restore can interleave with it.
pub fn restore(
    workspace_root: &Path,
    snapshot_id: &str,
    options: RestoreOptions,
) -> Result<RestoreResult> {
    // 1. Compute project_id, create StoreLayout
    let project_id = project_id_from_path(workspace_root);
    let layout = StoreLayout::new(&project_id);
    layout.ensure_dirs()?;

    // 2. Acquire project lock, then open the metadata catalog
    let _lock = ProjectLock::acquire(&layout.locks_dir())?;
    let catalog = MetadataCatalog::open(layout.catalog_path())?;

    // 3. Resolve snapshot ID
    let resolved_snapshot = catalog.resolve_snapshot_ref(snapshot_id)?;
    let resolved_id = resolved_snapshot.id.clone();

    // 4. Load snapshot's target state (path -> hash). Preference order:
    //    empty snapshot -> empty map; no manifest rows -> walk the tree
    //    store from root_tree_hash; otherwise use the flat manifest.
    let manifest = catalog.snapshot_manifest(&resolved_id)?;
    let target_state = if resolved_snapshot.stats.total_files == 0 {
        BTreeMap::new()
    } else if manifest.is_empty() {
        let tree_store = TreeStore::new(layout.trees_dir());
        let root_tree_hash = resolved_snapshot.root_tree_hash.ok_or_else(|| {
            ChkpttError::StoreCorrupted(format!(
                "snapshot '{}' is missing both manifest entries and root_tree_hash",
                resolved_id
            ))
        })?;
        let root_tree_hash_hex = bytes_to_hex(&root_tree_hash);
        let mut state = BTreeMap::new();
        collect_tree_files(&tree_store, &root_tree_hash_hex, "", &mut state)?;
        state
    } else {
        target_state_from_manifest(&manifest)
    };
    // Scan dependency directories only if the snapshot itself tracked them,
    // so both sides of the diff cover the same file universe.
    let target_includes_deps = target_state
        .keys()
        .any(|path| path_contains_dependency_dir(path));

    // 5. Scan current workspace to get current state (path -> content hash),
    //    reusing index-cached hashes where file metadata is unchanged.
    let mut index = FileIndex::open(layout.index_path())?;
    let cached_entries = index.entries();
    let current_state = scan_current_state(workspace_root, &cached_entries, target_includes_deps)?;
    emit(
        &options.progress,
        ProgressEvent::ScanCurrentComplete {
            file_count: current_state.len() as u64,
        },
    );

    // 6. Compare target state vs current state
    let diff = diff_restore_states(&target_state, &current_state);
    let files_to_add = diff.files_to_add;
    let files_to_change = diff.files_to_change;
    let files_to_remove = diff.files_to_remove;
    let files_unchanged = diff.files_unchanged;

    let result = RestoreResult {
        snapshot_id: resolved_id.clone(),
        files_added: files_to_add.len() as u64,
        files_changed: files_to_change.len() as u64,
        files_removed: files_to_remove.len() as u64,
        files_unchanged,
    };

    // 7. If dry_run, return result without modifying workspace
    if options.dry_run {
        return Ok(result);
    }

    // 8. Perform actual restore: resolve pack locations for every blob first
    //    so missing objects fail the whole restore before any file is touched.
    let packs_dir = layout.packs_dir();
    let (pack_set, restore_sources) = resolve_restore_sources(
        &files_to_add,
        &files_to_change,
        &target_state,
        &catalog,
        &packs_dir,
    )?;

    let restore_total = (files_to_add.len() + files_to_change.len() + files_to_remove.len()) as u64;
    emit(
        &options.progress,
        ProgressEvent::RestoreStart {
            add: files_to_add.len() as u64,
            change: files_to_change.len() as u64,
            remove: files_to_remove.len() as u64,
        },
    );

    // 8a. Restore files that need to be added or changed (parallel)
    let restore_tasks = build_restore_tasks(
        &files_to_add,
        &files_to_change,
        &target_state,
        &restore_sources,
    )?;
    let restore_progress = AtomicU64::new(0);
    let restored_paths = restore_files(
        workspace_root,
        &restore_tasks,
        &pack_set,
        &options.progress,
        &restore_progress,
        restore_total,
    )?;

    // 8b. Remove files that are not in the target snapshot; an already
    //     missing file is not an error.
    for path in &files_to_remove {
        let file_path = workspace_root.join(path);
        match std::fs::remove_file(&file_path) {
            Ok(()) => {}
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
            Err(error) => return Err(error.into()),
        }
        let completed = restore_progress.fetch_add(1, Ordering::Relaxed) + 1;
        emit(
            &options.progress,
            ProgressEvent::RestoreFile {
                completed,
                total: restore_total,
            },
        );
    }

    // 8c. Clean up empty directories affected by removed files only.
    cleanup_removed_file_parents(workspace_root, &files_to_remove)?;

    // 9. Update the file index so the next scan can reuse fresh metadata.
    let file_entries = restored_index_entries(workspace_root, &restored_paths, &target_state)?;
    index.apply_changes(&files_to_remove, &file_entries)?;

    Ok(result)
}
580
/// Whether any component of a workspace-relative path is a well-known
/// dependency/virtual-env directory name.
fn path_contains_dependency_dir(relative_path: &str) -> bool {
    // Keep in sync with the scanner's notion of dependency directories.
    const DEPENDENCY_DIRS: [&str; 8] = [
        "node_modules",
        ".venv",
        "venv",
        "__pypackages__",
        ".tox",
        ".nox",
        ".gradle",
        ".m2",
    ];

    relative_path
        .split('/')
        .any(|component| DEPENDENCY_DIRS.contains(&component))
}
596
/// Whether the file-type bits of a POSIX `mode` mark the entry as a symlink.
fn mode_is_symlink(mode: u32) -> bool {
    // POSIX file-type mask and symlink type value (S_IFMT / S_IFLNK).
    const FILE_TYPE_MASK: u32 = 0o170000;
    const SYMLINK_TYPE: u32 = 0o120000;
    mode & FILE_TYPE_MASK == SYMLINK_TYPE
}
600
601fn restored_index_entries(
602    workspace_root: &Path,
603    restored_paths: &[String],
604    target_state: &BTreeMap<String, TargetFileState>,
605) -> Result<Vec<crate::index::FileEntry>> {
606    let mut file_entries = Vec::with_capacity(restored_paths.len());
607    for path in restored_paths {
608        let absolute_path = workspace_root.join(path);
609        let metadata = std::fs::symlink_metadata(&absolute_path)?;
610        let target = target_state.get(path).ok_or_else(|| {
611            ChkpttError::RestoreFailed(format!("Missing target hash for {}", path))
612        })?;
613        let scanned = scanned_file_from_metadata(path.clone(), absolute_path, &metadata);
614
615        file_entries.push(crate::index::FileEntry {
616            path: scanned.relative_path,
617            blob_hash: target.hash,
618            size: scanned.size,
619            mtime_secs: scanned.mtime_secs,
620            mtime_nanos: scanned.mtime_nanos,
621            inode: scanned.inode,
622            mode: scanned.mode,
623        });
624    }
625    Ok(file_entries)
626}
627
628fn cached_hash_bytes(
629    file: &ScannedFile,
630    cached_entries: &HashMap<String, crate::index::FileEntry>,
631) -> Option<[u8; 32]> {
632    let cached = cached_entries.get(&file.relative_path)?;
633    if cached.mtime_secs == file.mtime_secs
634        && cached.mtime_nanos == file.mtime_nanos
635        && cached.size == file.size
636        && cached.inode == file.inode
637        && cached.mode == file.mode
638    {
639        Some(cached.blob_hash)
640    } else {
641        None
642    }
643}
644
/// Hash the content of all `scanned_files`, in parallel when worthwhile.
///
/// Files are first sorted for I/O locality, then split into contiguous
/// chunks across scoped worker threads. Returns each file paired with its
/// 32-byte content hash; the first hashing error aborts the whole batch.
fn hash_scanned_files(scanned_files: Vec<ScannedFile>) -> Result<Vec<(ScannedFile, [u8; 32])>> {
    if scanned_files.is_empty() {
        return Ok(Vec::new());
    }
    let mut scanned_files = scanned_files;
    sort_scanned_for_locality(&mut scanned_files);

    // Never spawn more workers than files.
    let worker_count = std::thread::available_parallelism()
        .map(|count| count.get())
        .unwrap_or(1)
        .min(scanned_files.len());
    if worker_count <= 1 {
        // Serial fast path: avoid thread-scope overhead for small batches.
        return scanned_files
            .into_iter()
            .map(|file| {
                Ok((
                    file.clone(),
                    hash_path_bytes(&file.absolute_path, file.is_symlink)?,
                ))
            })
            .collect();
    }

    // div_ceil so the remainder lands in the last chunk instead of an
    // extra worker.
    let chunk_size = scanned_files.len().div_ceil(worker_count);
    std::thread::scope(|scope| {
        let mut workers = Vec::with_capacity(scanned_files.len().div_ceil(chunk_size));
        for chunk in scanned_files.chunks(chunk_size) {
            workers.push(
                scope.spawn(move || -> Result<Vec<(ScannedFile, [u8; 32])>> {
                    chunk
                        .iter()
                        .map(|file| {
                            Ok((
                                file.clone(),
                                hash_path_bytes(&file.absolute_path, file.is_symlink)?,
                            ))
                        })
                        .collect()
                }),
            );
        }

        // Join in spawn order to keep results in locality order; a panicking
        // worker is surfaced as an error rather than propagated.
        let mut hashed = Vec::with_capacity(scanned_files.len());
        for worker in workers {
            let chunk = worker
                .join()
                .map_err(|_| ChkpttError::Other("restore worker thread panicked".into()))??;
            hashed.extend(chunk);
        }
        Ok(hashed)
    })
}
697
/// Build a `ScannedFile` from already-fetched filesystem metadata (unix).
///
/// Uses the real inode/device/mode from `MetadataExt` so index cache checks
/// in `cached_hash_bytes` compare against the same fields a scan would record.
#[cfg(unix)]
fn scanned_file_from_metadata(
    relative_path: String,
    absolute_path: std::path::PathBuf,
    metadata: &std::fs::Metadata,
) -> ScannedFile {
    use std::os::unix::fs::MetadataExt;

    ScannedFile {
        relative_path,
        absolute_path,
        size: metadata.len(),
        mtime_secs: metadata.mtime(),
        mtime_nanos: metadata.mtime_nsec(),
        device: Some(metadata.dev()),
        inode: Some(metadata.ino()),
        mode: metadata.mode(),
        is_symlink: metadata.file_type().is_symlink(),
    }
}
718
/// Build a `ScannedFile` from already-fetched filesystem metadata (non-unix).
///
/// Platforms without unix metadata get no device/inode, a mtime derived from
/// `modified()` (falling back to 0 if unavailable or pre-epoch), and a
/// synthetic mode: 0o120000 for symlinks, 0o644 otherwise.
#[cfg(not(unix))]
fn scanned_file_from_metadata(
    relative_path: String,
    absolute_path: std::path::PathBuf,
    metadata: &std::fs::Metadata,
) -> ScannedFile {
    use std::time::UNIX_EPOCH;

    let (mtime_secs, mtime_nanos) = metadata
        .modified()
        .ok()
        .and_then(|time| time.duration_since(UNIX_EPOCH).ok())
        .map(|duration| (duration.as_secs() as i64, duration.subsec_nanos() as i64))
        .unwrap_or((0, 0));

    let is_symlink = metadata.file_type().is_symlink();
    ScannedFile {
        relative_path,
        absolute_path,
        size: metadata.len(),
        mtime_secs,
        mtime_nanos,
        device: None,
        inode: None,
        mode: if is_symlink { 0o120000 } else { 0o644 },
        is_symlink,
    }
}
747
748fn cleanup_removed_file_parents(root: &Path, removed_paths: &[String]) -> Result<()> {
749    if removed_paths.is_empty() {
750        return Ok(());
751    }
752
753    let mut candidates = HashSet::with_capacity(removed_paths.len());
754    for removed_path in removed_paths {
755        let mut current = root.join(removed_path);
756        while let Some(parent) = current.parent() {
757            if parent == root {
758                candidates.insert(parent.to_path_buf());
759                break;
760            }
761            if !parent.starts_with(root) {
762                break;
763            }
764            candidates.insert(parent.to_path_buf());
765            current = parent.to_path_buf();
766        }
767    }
768
769    let mut candidates: Vec<_> = candidates.into_iter().filter(|dir| dir != root).collect();
770    candidates.sort_unstable_by(|left, right| {
771        right
772            .components()
773            .count()
774            .cmp(&left.components().count())
775            .then_with(|| left.cmp(right))
776    });
777
778    for dir in candidates {
779        match std::fs::remove_dir(&dir) {
780            Ok(()) => {}
781            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
782            Err(error) if error.kind() == std::io::ErrorKind::DirectoryNotEmpty => {}
783            Err(error) => return Err(error.into()),
784        }
785    }
786
787    Ok(())
788}
789
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_cleanup_removed_file_parents_removes_only_empty_ancestor_chain() {
        let dir = TempDir::new().unwrap();
        let root = dir.path();

        // a/b/c once held gone.txt; after removal the whole a/b/c chain is empty.
        let empty_leaf = root.join("a/b/c");
        std::fs::create_dir_all(&empty_leaf).unwrap();
        std::fs::write(empty_leaf.join("gone.txt"), b"gone").unwrap();
        std::fs::remove_file(empty_leaf.join("gone.txt")).unwrap();

        // a/keep still holds a file, so "a" must survive the cleanup.
        let non_empty_leaf = root.join("a/keep");
        std::fs::create_dir_all(&non_empty_leaf).unwrap();
        std::fs::write(non_empty_leaf.join("keep.txt"), b"keep").unwrap();

        cleanup_removed_file_parents(root, &[String::from("a/b/c/gone.txt")]).unwrap();

        assert!(!root.join("a/b/c").exists());
        assert!(!root.join("a/b").exists());
        assert!(root.join("a").exists());
        assert!(root.join("a/keep/keep.txt").exists());
    }

    #[test]
    fn test_cleanup_removed_file_parents_skips_non_empty_directories() {
        let dir = TempDir::new().unwrap();
        let root = dir.path();

        // "shared" keeps another file, so it must not be removed even though
        // a file inside it was deleted.
        let shared = root.join("shared");
        std::fs::create_dir_all(&shared).unwrap();
        std::fs::write(shared.join("still-here.txt"), b"keep").unwrap();

        cleanup_removed_file_parents(root, &[String::from("shared/gone.txt")]).unwrap();

        assert!(root.join("shared").exists());
        assert!(root.join("shared/still-here.txt").exists());
    }

    #[test]
    fn test_diff_restore_states_classifies_paths() {
        // a.txt: only in target (add); b.txt: differing hash and type (change);
        // c.txt: identical (unchanged); d.txt: only in current (remove).
        let target_state = BTreeMap::from([
            (
                "a.txt".to_string(),
                TargetFileState {
                    hash: hash_bytes("hash-a"),
                    is_symlink: false,
                },
            ),
            (
                "b.txt".to_string(),
                TargetFileState {
                    hash: hash_bytes("hash-b-target"),
                    is_symlink: true,
                },
            ),
            (
                "c.txt".to_string(),
                TargetFileState {
                    hash: hash_bytes("hash-c"),
                    is_symlink: false,
                },
            ),
        ]);
        let current_state = BTreeMap::from([
            (
                "b.txt".to_string(),
                CurrentFileState {
                    hash: hash_bytes("hash-b-current"),
                    is_symlink: false,
                },
            ),
            (
                "c.txt".to_string(),
                CurrentFileState {
                    hash: hash_bytes("hash-c"),
                    is_symlink: false,
                },
            ),
            (
                "d.txt".to_string(),
                CurrentFileState {
                    hash: hash_bytes("hash-d"),
                    is_symlink: false,
                },
            ),
        ]);

        let diff = diff_restore_states(&target_state, &current_state);
        assert_eq!(diff.files_to_add, vec!["a.txt".to_string()]);
        assert_eq!(diff.files_to_change, vec!["b.txt".to_string()]);
        assert_eq!(diff.files_to_remove, vec!["d.txt".to_string()]);
        assert_eq!(diff.files_unchanged, 1);
    }

    #[test]
    fn test_diff_restore_states_handles_empty_inputs() {
        let target_state: BTreeMap<String, TargetFileState> = BTreeMap::new();
        let current_state: BTreeMap<String, CurrentFileState> = BTreeMap::new();
        let diff = diff_restore_states(&target_state, &current_state);

        assert!(diff.files_to_add.is_empty());
        assert!(diff.files_to_change.is_empty());
        assert!(diff.files_to_remove.is_empty());
        assert_eq!(diff.files_unchanged, 0);
    }

    #[test]
    fn test_diff_restore_states_detects_type_changes() {
        // Same content hash but symlink vs regular file must count as changed.
        let target_state = BTreeMap::from([(
            "link".to_string(),
            TargetFileState {
                hash: hash_bytes("same-hash"),
                is_symlink: true,
            },
        )]);
        let current_state = BTreeMap::from([(
            "link".to_string(),
            CurrentFileState {
                hash: hash_bytes("same-hash"),
                is_symlink: false,
            },
        )]);

        let diff = diff_restore_states(&target_state, &current_state);
        assert_eq!(diff.files_to_change, vec!["link".to_string()]);
    }

    // Deterministic 32-byte test hash derived from a short label.
    fn hash_bytes(label: &str) -> [u8; 32] {
        *blake3::hash(label.as_bytes()).as_bytes()
    }
}