Skip to main content

gravityfile_ops/
copy.rs

1//! Async copy operation with progress reporting.
2
3use std::collections::HashSet;
4use std::fs;
5use std::path::{Path, PathBuf};
6
7use tokio::sync::mpsc;
8use tokio_util::sync::CancellationToken;
9
10use crate::conflict::{Conflict, ConflictKind, ConflictResolution, auto_rename_path};
11use crate::progress::{OperationComplete, OperationProgress, OperationType};
12use crate::{OPERATION_CHANNEL_SIZE, OperationError};
13
14/// Result sent through the channel during copy operations.
15#[derive(Debug)]
16pub enum CopyResult {
17    /// Progress update.
18    Progress(OperationProgress),
19    /// A conflict was detected that needs resolution.
20    Conflict(Conflict),
21    /// The operation completed.
22    Complete(OperationComplete),
23}
24
25/// Options for copy operations.
26#[derive(Debug, Clone, Default)]
27pub struct CopyOptions {
28    /// How to handle conflicts (None means ask for each).
29    pub conflict_resolution: Option<ConflictResolution>,
30    /// Whether to preserve timestamps.
31    pub preserve_timestamps: bool,
32}
33
34/// Start an async copy operation.
35///
36/// Returns a receiver for progress updates and results.
37pub fn start_copy(
38    sources: Vec<PathBuf>,
39    destination: PathBuf,
40    options: CopyOptions,
41    token: CancellationToken,
42) -> mpsc::Receiver<CopyResult> {
43    let (tx, rx) = mpsc::channel(OPERATION_CHANNEL_SIZE);
44
45    if sources.is_empty() {
46        // Send immediate completion for empty sources
47        let complete = OperationComplete {
48            operation_type: OperationType::Copy,
49            succeeded: 0,
50            failed: 0,
51            bytes_processed: 0,
52            errors: vec![],
53        };
54        tokio::spawn(async move {
55            let _ = tx.send(CopyResult::Complete(complete)).await;
56        });
57        return rx;
58    }
59
60    tokio::spawn(async move {
61        copy_impl(sources, destination, options, token, tx).await;
62    });
63
64    rx
65}
66
67/// Internal implementation of copy operation.
68async fn copy_impl(
69    sources: Vec<PathBuf>,
70    destination: PathBuf,
71    options: CopyOptions,
72    token: CancellationToken,
73    tx: mpsc::Sender<CopyResult>,
74) {
75    // First, calculate total size and file count
76    let (total_files, total_bytes) = calculate_totals(&sources);
77
78    let mut progress = OperationProgress::new(OperationType::Copy, total_files, total_bytes);
79    let global_resolution: Option<ConflictResolution> = options.conflict_resolution;
80    let mut succeeded = 0;
81    let mut failed = 0;
82
83    // Ensure destination exists and is a directory
84    if !destination.exists()
85        && let Err(e) = fs::create_dir_all(&destination)
86    {
87        progress.add_error(OperationError::new(
88            destination.clone(),
89            format!("Failed to create destination: {}", e),
90        ));
91        let _ = tx
92            .send(CopyResult::Complete(OperationComplete {
93                operation_type: OperationType::Copy,
94                succeeded: 0,
95                failed: sources.len(),
96                bytes_processed: 0,
97                errors: progress.errors.clone(),
98            }))
99            .await;
100        return;
101    }
102
103    for source in sources {
104        // HIGH-3: check for cancellation before each item
105        if token.is_cancelled() {
106            break;
107        }
108
109        // MED-5: return error when file_name() is None (e.g. path is "/" or ends with "..")
110        let file_name = match source.file_name() {
111            Some(n) => n.to_owned(),
112            None => {
113                progress.add_error(OperationError::new(
114                    source.clone(),
115                    "Source path has no filename component".to_string(),
116                ));
117                failed += 1;
118                continue;
119            }
120        };
121        let dest_path = destination.join(&file_name);
122
123        // Check for conflicts using symlink_metadata so we see the link itself
124        let dest_meta = fs::symlink_metadata(&dest_path).ok();
125        if dest_meta.is_some() {
126            let conflict_kind = if dest_meta.as_ref().is_some_and(|m| m.is_dir()) {
127                ConflictKind::DirectoryExists
128            } else {
129                ConflictKind::FileExists
130            };
131
132            let resolution = if let Some(res) = global_resolution {
133                res.to_single()
134            } else {
135                // Send conflict and wait (in real impl, would need response channel)
136                // For now, default to skip
137                let _ = tx
138                    .send(CopyResult::Conflict(Conflict::new(
139                        source.clone(),
140                        dest_path.clone(),
141                        conflict_kind,
142                    )))
143                    .await;
144                ConflictResolution::Skip
145            };
146
147            match resolution {
148                ConflictResolution::Skip | ConflictResolution::SkipAll => {
149                    failed += 1;
150                    continue;
151                }
152                ConflictResolution::Abort => {
153                    let _ = tx
154                        .send(CopyResult::Complete(OperationComplete {
155                            operation_type: OperationType::Copy,
156                            succeeded,
157                            failed: failed + 1,
158                            bytes_processed: progress.bytes_processed,
159                            errors: progress.errors.clone(),
160                        }))
161                        .await;
162                    return;
163                }
164                ConflictResolution::AutoRename => {
165                    let new_dest = auto_rename_path(&dest_path);
166                    if let Err(e) = copy_item(&source, &new_dest, &mut progress, &tx).await {
167                        progress.add_error(OperationError::new(source.clone(), e));
168                        failed += 1;
169                    } else {
170                        succeeded += 1;
171                    }
172                    continue;
173                }
174                ConflictResolution::Overwrite | ConflictResolution::OverwriteAll => {
175                    // CRIT-3: handle removal failure — record error and skip this source
176                    let remove_result = if let Some(ref m) = dest_meta {
177                        if m.is_symlink() || !m.is_dir() {
178                            fs::remove_file(&dest_path)
179                        } else {
180                            fs::remove_dir_all(&dest_path)
181                        }
182                    } else {
183                        Ok(())
184                    };
185                    if let Err(e) = remove_result {
186                        progress.add_error(OperationError::new(
187                            dest_path.clone(),
188                            format!("Failed to remove existing destination: {}", e),
189                        ));
190                        failed += 1;
191                        continue;
192                    }
193                }
194            }
195        }
196
197        // Perform the copy
198        progress.set_current_file(Some(source.clone()));
199        let _ = tx.send(CopyResult::Progress(progress.clone())).await;
200
201        if let Err(e) = copy_item(&source, &dest_path, &mut progress, &tx).await {
202            progress.add_error(OperationError::new(source.clone(), e));
203            failed += 1;
204        } else {
205            succeeded += 1;
206        }
207    }
208
209    // Send completion
210    let _ = tx
211        .send(CopyResult::Complete(OperationComplete {
212            operation_type: OperationType::Copy,
213            succeeded,
214            failed,
215            bytes_processed: progress.bytes_processed,
216            errors: progress.errors,
217        }))
218        .await;
219}
220
221/// Copy a single item (file, directory, or symlink).
222async fn copy_item(
223    source: &Path,
224    dest: &Path,
225    progress: &mut OperationProgress,
226    tx: &mpsc::Sender<CopyResult>,
227) -> Result<(), String> {
228    let source = source.to_path_buf();
229    let dest = dest.to_path_buf();
230
231    let result = tokio::task::spawn_blocking(move || {
232        // Use symlink_metadata to avoid following symlinks
233        let metadata =
234            fs::symlink_metadata(&source).map_err(|e| format!("Failed to read metadata: {}", e))?;
235
236        if metadata.is_symlink() {
237            // For symlinks, read the target and recreate at destination
238            let target =
239                fs::read_link(&source).map_err(|e| format!("Failed to read symlink: {}", e))?;
240            #[cfg(unix)]
241            {
242                std::os::unix::fs::symlink(&target, &dest)
243                    .map_err(|e| format!("Failed to create symlink: {}", e))?;
244            }
245            #[cfg(windows)]
246            {
247                if target.is_dir() {
248                    std::os::windows::fs::symlink_dir(&target, &dest)
249                        .map_err(|e| format!("Failed to create symlink: {}", e))?;
250                } else {
251                    std::os::windows::fs::symlink_file(&target, &dest)
252                        .map_err(|e| format!("Failed to create symlink: {}", e))?;
253                }
254            }
255            Ok(0u64) // Symlinks have no real size
256        } else if metadata.is_dir() {
257            copy_dir_recursive(&source, &dest, &mut HashSet::new())
258        } else {
259            copy_file(&source, &dest)
260        }
261    })
262    .await
263    .map_err(|e| format!("Task failed: {}", e))?;
264
265    match result {
266        Ok(bytes) => {
267            progress.complete_file(bytes);
268            let _ = tx.send(CopyResult::Progress(progress.clone())).await;
269            Ok(())
270        }
271        Err(e) => Err(e),
272    }
273}
274
/// Copy the regular file `source` to `dest`, returning the number of bytes copied.
///
/// The source is inspected with `symlink_metadata` (the link itself is never
/// followed) and anything that is not a regular file — directory, symlink,
/// FIFO, socket, device — is rejected. Besides giving a clear error, this keeps
/// `fs::copy` from opening a FIFO, which would block the calling thread until
/// some writer opens the other end.
///
/// # Errors
/// Returns a message if the metadata cannot be read, the source is not a
/// regular file, or the copy itself fails.
fn copy_file(source: &Path, dest: &Path) -> Result<u64, String> {
    let metadata =
        fs::symlink_metadata(source).map_err(|e| format!("Failed to read metadata: {}", e))?;
    if !metadata.is_file() {
        return Err(format!("Not a regular file: {}", source.display()));
    }

    // `fs::copy` reports the bytes it actually wrote, which stays accurate even
    // if the file changed size after the metadata read above.
    fs::copy(source, dest).map_err(|e| format!("Failed to copy: {}", e))
}
288
289/// Recursively copy a directory.
290///
291/// CRIT-2: uses `entry.file_type()` (no symlink-follow), handles symlinks as a distinct
292/// branch, and tracks inodes via a visited set to detect hard-link / symlink loops on Unix.
293fn copy_dir_recursive(
294    source: &PathBuf,
295    dest: &PathBuf,
296    visited: &mut HashSet<u64>,
297) -> Result<u64, String> {
298    // Loop detection via inode on Unix
299    #[cfg(unix)]
300    {
301        use std::os::unix::fs::MetadataExt;
302        if let Ok(meta) = fs::symlink_metadata(source) {
303            let inode = meta.ino();
304            if !visited.insert(inode) {
305                // Already visited this inode — skip to break the loop
306                return Ok(0);
307            }
308        }
309    }
310
311    fs::create_dir_all(dest).map_err(|e| format!("Failed to create directory: {}", e))?;
312
313    let mut total_bytes = 0u64;
314
315    let entries = fs::read_dir(source).map_err(|e| format!("Failed to read directory: {}", e))?;
316
317    for entry in entries {
318        let entry = entry.map_err(|e| format!("Failed to read entry: {}", e))?;
319        // CRIT-2: use entry.file_type() — does NOT follow symlinks
320        let file_type = entry
321            .file_type()
322            .map_err(|e| format!("Failed to read file type: {}", e))?;
323        let path = entry.path();
324        let dest_path = dest.join(entry.file_name());
325
326        if file_type.is_symlink() {
327            // Recreate the symlink rather than following it
328            let target =
329                fs::read_link(&path).map_err(|e| format!("Failed to read symlink: {}", e))?;
330            #[cfg(unix)]
331            {
332                std::os::unix::fs::symlink(&target, &dest_path)
333                    .map_err(|e| format!("Failed to create symlink: {}", e))?;
334            }
335            #[cfg(windows)]
336            {
337                // Choose symlink_dir vs symlink_file based on what the
338                // original link pointed at, and log failures.
339                let result = if path.is_dir() {
340                    std::os::windows::fs::symlink_dir(&target, &dest_path)
341                } else {
342                    std::os::windows::fs::symlink_file(&target, &dest_path)
343                };
344                if let Err(e) = result {
345                    tracing::warn!(
346                        "Failed to create symlink {} -> {}: {}",
347                        dest_path.display(),
348                        target.display(),
349                        e
350                    );
351                }
352            }
353        } else if file_type.is_dir() {
354            total_bytes += copy_dir_recursive(&path, &dest_path, visited)?;
355        } else {
356            total_bytes += copy_file(&path, &dest_path)?;
357        }
358    }
359
360    Ok(total_bytes)
361}
362
363/// Calculate total files and bytes for a list of sources.
364///
365/// HIGH-2: use symlink_metadata; skip symlinks in size calculations.
366fn calculate_totals(sources: &[PathBuf]) -> (usize, u64) {
367    let mut files = 0;
368    let mut bytes = 0u64;
369
370    for source in sources {
371        match fs::symlink_metadata(source) {
372            Ok(meta) if meta.is_dir() => {
373                let (f, b) = calculate_dir_totals(source);
374                files += f;
375                bytes += b;
376            }
377            Ok(meta) if !meta.is_symlink() => {
378                files += 1;
379                bytes += meta.len();
380            }
381            _ => {} // skip symlinks and inaccessible entries
382        }
383    }
384
385    (files, bytes)
386}
387
/// Recursively count the non-symlink, non-directory entries under `dir` and
/// their total length in bytes.
///
/// Symlinks are skipped entirely (neither counted nor followed). Entries or
/// subdirectories that cannot be read are ignored, and an unreadable `dir`
/// yields `(0, 0)`; the result is a best-effort estimate for progress display.
fn calculate_dir_totals(dir: &Path) -> (usize, u64) {
    let Ok(entries) = fs::read_dir(dir) else {
        return (0, 0);
    };

    entries
        .flatten()
        .filter_map(|entry| {
            // DirEntry::file_type / DirEntry::metadata do not traverse symlinks.
            let ft = entry.file_type().ok()?;
            if ft.is_symlink() {
                None
            } else if ft.is_dir() {
                Some(calculate_dir_totals(&entry.path()))
            } else {
                entry.metadata().ok().map(|m| (1, m.len()))
            }
        })
        .fold((0usize, 0u64), |(files, bytes), (f, b)| {
            (files + f, bytes + b)
        })
}