gitai/remote/cache/
filter.rs

1use std::{fs, path::Path};
2
/// Copies a filtered subset of a repository tree into a destination directory.
///
/// The type carries no state; it only groups the filtering helpers.
#[derive(Debug, Clone, Copy, Default)]
pub struct RepositoryFilter;

impl RepositoryFilter {
    /// Copy only the paths listed in `filters` from `source_path` to `destination_path`.
    ///
    /// The destination directory is created if it does not exist. Each filter is
    /// interpreted relative to `source_path`:
    ///
    /// * a filter naming a directory is copied recursively to the same relative
    ///   path under the destination;
    /// * a filter naming a file is copied directly into the destination root,
    ///   keeping only its file name (two file filters with the same file name
    ///   therefore overwrite each other);
    /// * a filter that does not exist, or that resolves (through symlinks) to a
    ///   location outside `source_path`, is skipped with a warning on stderr;
    /// * a filter that normalizes to the empty path (`""`, `"."`, `"/"`, `".."`)
    ///   selects the whole source directory.
    ///
    /// # Errors
    ///
    /// Returns an error if the destination cannot be created or if any
    /// directory read, path resolution, or file copy fails.
    pub fn filter_repository_content(
        &self,
        source_path: &str,
        destination_path: &str,
        filters: &[String],
    ) -> Result<(), Box<dyn std::error::Error>> {
        fs::create_dir_all(destination_path)?;

        for filter_path in filters {
            self.copy_filtered_content(source_path, destination_path, filter_path)?;
        }

        Ok(())
    }

    /// Copy the content selected by a single `filter_path` from source to destination.
    ///
    /// Missing paths and paths that escape the source directory are reported on
    /// stderr and skipped; they are not treated as errors.
    fn copy_filtered_content(
        &self,
        source_path: &str,
        destination_path: &str,
        filter_path: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let source_dir = Path::new(source_path);
        let dest_dir = Path::new(destination_path);

        // Lexical sanitization: strips `..`, `.`, roots and prefixes.
        let normalized_filter = Self::normalize_path(filter_path);
        let source_filtered_path = source_dir.join(&normalized_filter);

        if !source_filtered_path.exists() {
            // The filter path doesn't exist in the source; skip it.
            eprintln!("Warning: Filter path '{filter_path}' does not exist in source repository");
            return Ok(());
        }

        // Lexical sanitization cannot see symlinks. `exists`, `is_file` and
        // `is_dir` all follow them, so a symlink inside the repository could
        // otherwise make us copy content from anywhere on the filesystem.
        // Resolve both paths and require the target to stay under the root.
        let canonical_root = fs::canonicalize(source_dir)?;
        let canonical_target = fs::canonicalize(&source_filtered_path)?;
        if !canonical_target.starts_with(&canonical_root) {
            eprintln!(
                "Warning: Filter path '{filter_path}' resolves outside the source repository; skipping"
            );
            return Ok(());
        }

        if source_filtered_path.is_file() {
            // Copy a single file, preserving only the filename.
            let file_name = source_filtered_path
                .file_name()
                .ok_or("Invalid file path")?;
            fs::copy(&source_filtered_path, dest_dir.join(file_name))?;
        } else if source_filtered_path.is_dir() {
            // Copy the entire directory recursively, preserving the path structure.
            let dest_subdir = dest_dir.join(&normalized_filter);
            self.copy_dir_all(&source_filtered_path, &dest_subdir)?;
        }

        Ok(())
    }

    /// Recursively copy the directory `src` into `dst`, creating `dst` if needed.
    ///
    /// Only regular files and directories are copied. Symlinks, sockets and
    /// other special entries are skipped.
    // `self` is only forwarded to the recursive call; it is kept so the helper
    // stays a method like its siblings.
    #[allow(clippy::only_used_in_recursion)]
    fn copy_dir_all(&self, src: &Path, dst: &Path) -> Result<(), Box<dyn std::error::Error>> {
        fs::create_dir_all(dst)?;

        for entry in fs::read_dir(src)? {
            let entry = entry?;
            // `DirEntry::file_type` does not follow symlinks, so a symlink is
            // reported as a symlink and falls through both branches below.
            let file_type = entry.file_type()?;
            let dest_entry_path = dst.join(entry.file_name());

            if file_type.is_dir() {
                self.copy_dir_all(&entry.path(), &dest_entry_path)?;
            } else if file_type.is_file() {
                fs::copy(entry.path(), &dest_entry_path)?;
            }
            // Intentionally skip symlinks, sockets, etc.
        }

        Ok(())
    }

    /// Lexically normalize `path` into a relative, `/`-separated path.
    ///
    /// Root and prefix components (leading `/`, Windows drive letters) and `.`
    /// are dropped. `..` removes the previous component, or is ignored when
    /// there is none, so the result can never climb above its starting point.
    /// Components are split with the platform's path rules, so Windows
    /// separators are recognized on Windows.
    ///
    /// This does not touch the filesystem and therefore does not resolve symlinks.
    fn normalize_path(path: &str) -> String {
        let mut stack: Vec<&str> = Vec::new();

        for component in Path::new(path).components() {
            match component {
                std::path::Component::Normal(part) => {
                    // `path` is a `&str`, so every component is valid UTF-8.
                    if let Some(part) = part.to_str() {
                        stack.push(part);
                    }
                }
                std::path::Component::ParentDir => {
                    // Go up one level if possible; popping an empty stack is a no-op.
                    stack.pop();
                }
                std::path::Component::Prefix(_)
                | std::path::Component::RootDir
                | std::path::Component::CurDir => {
                    // Never let the filter become absolute; `.` carries no information.
                }
            }
        }

        stack.join("/")
    }
}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// `normalize_path` drops `.` segments and neutralizes `..` segments.
    #[test]
    fn test_normalize_path() {
        // (input, expected) pairs: plain paths first, then traversal attempts.
        let cases = [
            ("src/main.rs", "src/main.rs"),
            ("./src/main.rs", "src/main.rs"),
            ("src/./main.rs", "src/main.rs"),
            ("../src/main.rs", "src/main.rs"),
            ("src/../lib/utils.rs", "lib/utils.rs"),
        ];

        for &(input, expected) in &cases {
            assert_eq!(RepositoryFilter::normalize_path(input), expected);
        }
    }

    /// Filtering on `src/` copies that directory and nothing else.
    #[test]
    fn test_filter_repository_content() {
        // Source fixture: two directories with one file each, plus a top-level file.
        let src_dir = TempDir::new().expect("Failed to create temporary source directory");
        let src_root = src_dir.path();
        let code_dir = src_root.join("src");
        let docs_dir = src_root.join("docs");

        fs::create_dir_all(&code_dir).expect("Failed to create src directory");
        fs::create_dir_all(&docs_dir).expect("Failed to create docs directory");
        fs::write(code_dir.join("main.rs"), "fn main() {}").expect("Failed to write main.rs");
        fs::write(docs_dir.join("README.md"), "# Docs").expect("Failed to write README.md");
        fs::write(src_root.join("LICENSE"), "MIT License").expect("Failed to write LICENSE");

        // Empty destination to receive the filtered copy.
        let dest_dir = TempDir::new().expect("Failed to create temporary destination directory");
        let dest_root = dest_dir.path();

        let source_arg = src_root.to_str().expect("Source path is not valid UTF-8");
        let destination_arg = dest_root
            .to_str()
            .expect("Destination path is not valid UTF-8");
        let filters = vec!["src/".to_string()];

        RepositoryFilter
            .filter_repository_content(source_arg, destination_arg, &filters)
            .expect("Failed to filter repository content");

        // Only the `src` tree may be present in the destination.
        let copied_src = dest_root.join("src");
        assert!(copied_src.exists());
        assert!(copied_src.join("main.rs").exists());
        assert!(!dest_root.join("docs").exists());
        assert!(!dest_root.join("LICENSE").exists());
    }
}