Skip to main content

git_filter_tree/
lib.rs

1#![doc = include_str!("../README.md")]
2pub mod cli;
3pub mod exe;
4pub use git2::{Error, Repository};
5use globset::{GlobSet, GlobSetBuilder};
6
pub trait FilterTree {
    /// Returns a new tree that keeps only the blobs of `tree` whose paths match at least one of
    /// `patterns`, using gitattributes-style patterns.
    ///
    /// The tree is walked recursively. Each pattern is matched against an entry's full path,
    /// relative to the root of `tree` and joined with `/` (for example `src/lib.rs`). A directory
    /// is kept only when at least one of its descendants survives the filter. A pattern ending in
    /// `/` selects everything beneath that directory.
    ///
    /// `patterns` is a slice of string slices rather than a glob type because Git's glob syntax
    /// differs from standard shell syntax.
    ///
    /// # Errors
    ///
    /// The implementation for [`git2::Repository`] returns an error in three cases:
    ///
    /// * `patterns` is empty.
    /// * A pattern cannot be compiled.
    /// * Reading or writing a git object fails.
    fn filter_by_patterns<'a>(
        &'a self,
        tree: &'a git2::Tree<'a>,
        patterns: &[&str], // TODO create a `git-glob` crate to handle patterns more gracefully
    ) -> Result<git2::Tree<'a>, Error>;
}
20
21impl FilterTree for git2::Repository {
22    fn filter_by_patterns<'a>(
23        &'a self,
24        tree: &'a git2::Tree<'a>,
25        patterns: &[&str],
26    ) -> Result<git2::Tree<'a>, Error> {
27        if patterns.is_empty() {
28            return Err(Error::from_str("At least one pattern is required"));
29        }
30
31        // Build GlobSet matcher
32        let mut glob_builder = GlobSetBuilder::new();
33        for pattern in patterns {
34            // A trailing `/` means "this directory" in gitattributes/gitignore
35            // semantics.  Normalize to `dir/**` so globset matches all files
36            // under the directory recursively.
37            let normalized: String;
38            let pat = if pattern.ends_with('/') {
39                normalized = format!("{}**", pattern);
40                normalized.as_str()
41            } else {
42                pattern
43            };
44            let glob = globset::Glob::new(pat)
45                .map_err(|e| Error::from_str(&format!("Invalid pattern '{}': {}", pattern, e)))?;
46            glob_builder.add(glob);
47        }
48
49        let matcher = glob_builder
50            .build()
51            .map_err(|e| Error::from_str(&e.to_string()))?;
52
53        // Recursively filter the tree
54        filter_tree_recursive(self, tree, None, &matcher)
55    }
56}
57
58/// Recursively filters a tree, matching patterns against full paths.
59/// Returns a new tree containing only entries that match or have matching descendants.
60fn filter_tree_recursive<'a>(
61    repo: &'a Repository,
62    tree: &'a git2::Tree<'a>,
63    prefix: Option<&str>,
64    matcher: &GlobSet,
65) -> Result<git2::Tree<'a>, Error> {
66    let mut builder = repo.treebuilder(None)?;
67
68    for entry in tree.iter() {
69        let Some(name) = entry.name() else {
70            return Err(Error::from_str("name has invalid UTF-8"));
71        };
72
73        let full_path = match prefix {
74            Some(subdir) => format!("{}/{}", subdir, name),
75            None => name.to_string(),
76        };
77
78        match entry.kind() {
79            Some(git2::ObjectType::Blob) => {
80                if matcher.is_match(&full_path) {
81                    builder.insert(name, entry.id(), entry.filemode())?;
82                }
83            }
84            Some(git2::ObjectType::Tree) => {
85                let subtree = entry.to_object(repo)?.peel_to_tree()?;
86                let filtered_subtree =
87                    filter_tree_recursive(repo, &subtree, Some(&full_path), matcher)?;
88                if !filtered_subtree.is_empty() {
89                    builder.insert(name, filtered_subtree.id(), entry.filemode())?;
90                }
91            }
92            // Skip submodule commit pointers, tags, and any other unexpected
93            // object types that can appear as tree entries.
94            _ => continue,
95        }
96    }
97
98    let tree_oid = builder.write()?;
99    repo.find_tree(tree_oid)
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Per-process counter combined with the process id in temp directory names.
    ///
    /// This keeps repositories distinct across parallel test threads and across concurrently
    /// running test processes. A thread id alone repeats between processes.
    static NEXT_REPO_ID: AtomicUsize = AtomicUsize::new(0);

    /// Deletes the wrapped directory when dropped.
    ///
    /// Cleanup therefore also happens when a test fails, whether through a panicking assertion
    /// or an early `?` return.
    struct TempDirGuard(PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best effort: a leftover temp directory must not turn into a test failure.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a fresh bare repository in a unique temporary directory.
    ///
    /// Bind the returned guard to a named variable (e.g. `_dir`, not `_`) so the directory
    /// lives until the end of the test.
    fn setup_test_repo() -> (Repository, TempDirGuard) {
        let unique = NEXT_REPO_ID.fetch_add(1, Ordering::Relaxed);
        let temp_path = std::env::temp_dir().join(format!(
            "git-filter-tree-test-{}-{}",
            std::process::id(),
            unique
        ));
        // Remove any stale directory left by a killed earlier run with the same pid.
        let _ = fs::remove_dir_all(&temp_path);
        fs::create_dir_all(&temp_path).expect("failed to create temp dir for test repo");
        let repo = Repository::init_bare(&temp_path).expect("failed to init bare test repo");
        (repo, TempDirGuard(temp_path))
    }

    /// Builds a flat tree with `file1.txt`, `file2.rs` and `test.md`.
    fn create_test_tree<'a>(repo: &'a Repository) -> Result<git2::Tree<'a>, Error> {
        let mut tree_builder = repo.treebuilder(None)?;

        let blob1 = repo.blob(b"content1")?;
        let blob2 = repo.blob(b"content2")?;
        let blob3 = repo.blob(b"content3")?;

        tree_builder.insert("file1.txt", blob1, 0o100644)?;
        tree_builder.insert("file2.rs", blob2, 0o100644)?;
        tree_builder.insert("test.md", blob3, 0o100644)?;

        let tree_oid = tree_builder.write()?;
        repo.find_tree(tree_oid)
    }

    #[test]
    fn test_filter_single_pattern() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo)?;
        assert_eq!(tree.len(), 3);

        // Filter for .txt files only
        let filtered = repo.filter_by_patterns(&tree, &["*.txt"])?;
        assert_eq!(filtered.len(), 1);
        assert!(filtered.get_name("file1.txt").is_some());
        assert!(filtered.get_name("file2.rs").is_none());
        assert!(filtered.get_name("test.md").is_none());

        Ok(())
    }

    #[test]
    fn test_filter_multiple_patterns() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo)?;

        // Filter for .txt and .rs files
        let filtered = repo.filter_by_patterns(&tree, &["*.txt", "*.rs"])?;
        assert_eq!(filtered.len(), 2);
        assert!(filtered.get_name("file1.txt").is_some());
        assert!(filtered.get_name("file2.rs").is_some());
        assert!(filtered.get_name("test.md").is_none());

        Ok(())
    }

    #[test]
    fn test_filter_exact_match() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo)?;

        // Filter for exact filename
        let filtered = repo.filter_by_patterns(&tree, &["file1.txt"])?;
        assert_eq!(filtered.len(), 1);
        assert!(filtered.get_name("file1.txt").is_some());
        assert!(filtered.get_name("file2.rs").is_none());

        Ok(())
    }

    #[test]
    fn test_filter_wildcard_patterns() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo)?;

        // Filter with wildcard pattern
        let filtered = repo.filter_by_patterns(&tree, &["file*"])?;
        assert_eq!(filtered.len(), 2);
        assert!(filtered.get_name("file1.txt").is_some());
        assert!(filtered.get_name("file2.rs").is_some());
        assert!(filtered.get_name("test.md").is_none());

        Ok(())
    }

    #[test]
    fn test_filter_no_matches() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo)?;

        // Filter with pattern that matches nothing
        let filtered = repo.filter_by_patterns(&tree, &["*.nonexistent"])?;
        assert_eq!(filtered.len(), 0);

        Ok(())
    }

    #[test]
    fn test_filter_all_matches() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo)?;

        // Filter with pattern that matches everything
        let filtered = repo.filter_by_patterns(&tree, &["*"])?;
        assert_eq!(filtered.len(), 3);

        Ok(())
    }

    #[test]
    fn test_filter_empty_patterns_error() {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo).unwrap();

        // Empty patterns should return an error
        let result = repo.filter_by_patterns(&tree, &[]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().message(),
            "At least one pattern is required"
        );
    }

    #[test]
    fn test_filter_invalid_pattern_error() {
        let (repo, _dir) = setup_test_repo();

        let tree = create_test_tree(&repo).unwrap();

        // Invalid glob pattern should return an error
        let result = repo.filter_by_patterns(&tree, &["[invalid"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_filter_with_nested_tree() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let mut tree_builder = repo.treebuilder(None)?;

        // Create a nested tree
        let mut subtree_builder = repo.treebuilder(None)?;
        let blob = repo.blob(b"nested content")?;
        subtree_builder.insert("nested.txt", blob, 0o100644)?;
        let subtree_oid = subtree_builder.write()?;

        // Add files and subtree to main tree
        let blob1 = repo.blob(b"content1")?;
        tree_builder.insert("file1.txt", blob1, 0o100644)?;
        tree_builder.insert("subdir", subtree_oid, 0o040000)?;

        let tree_oid = tree_builder.write()?;
        let tree = repo.find_tree(tree_oid)?;

        // Filter - should keep both file and directory
        let filtered = repo.filter_by_patterns(&tree, &["*"])?;
        assert_eq!(filtered.len(), 2);

        // The kept directory must still contain its nested file.
        let subdir = filtered.get_name("subdir").expect("subdir should be kept");
        let subdir_tree = repo.find_tree(subdir.id())?;
        assert!(subdir_tree.get_name("nested.txt").is_some());

        Ok(())
    }

    #[test]
    fn test_filter_preserves_empty_tree() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        // Create an empty tree
        let tree_builder = repo.treebuilder(None)?;
        let tree_oid = tree_builder.write()?;
        let tree = repo.find_tree(tree_oid)?;

        assert_eq!(tree.len(), 0);

        // Filter empty tree
        let filtered = repo.filter_by_patterns(&tree, &["*"])?;
        assert_eq!(filtered.len(), 0);

        Ok(())
    }

    #[test]
    fn test_filter_case_sensitive() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let mut tree_builder = repo.treebuilder(None)?;
        let blob1 = repo.blob(b"content1")?;
        let blob2 = repo.blob(b"content2")?;

        tree_builder.insert("File.txt", blob1, 0o100644)?;
        tree_builder.insert("file.txt", blob2, 0o100644)?;

        let tree_oid = tree_builder.write()?;
        let tree = repo.find_tree(tree_oid)?;

        // Filter with exact case match
        let filtered = repo.filter_by_patterns(&tree, &["file.txt"])?;
        assert_eq!(filtered.len(), 1);
        assert!(filtered.get_name("file.txt").is_some());
        assert!(filtered.get_name("File.txt").is_none());

        Ok(())
    }

    #[test]
    fn test_filter_complex_patterns() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        let mut tree_builder = repo.treebuilder(None)?;
        let blob = repo.blob(b"content")?;

        tree_builder.insert("test1.txt", blob, 0o100644)?;
        tree_builder.insert("test2.rs", blob, 0o100644)?;
        tree_builder.insert("data.json", blob, 0o100644)?;
        tree_builder.insert("README.md", blob, 0o100644)?;

        let tree_oid = tree_builder.write()?;
        let tree = repo.find_tree(tree_oid)?;

        // Multiple patterns with different wildcards
        let filtered = repo.filter_by_patterns(&tree, &["test*", "*.md"])?;
        assert_eq!(filtered.len(), 3);
        assert!(filtered.get_name("test1.txt").is_some());
        assert!(filtered.get_name("test2.rs").is_some());
        assert!(filtered.get_name("README.md").is_some());
        assert!(filtered.get_name("data.json").is_none());

        Ok(())
    }

    #[test]
    fn test_filter_trailing_slash_matches_directory_contents() -> Result<(), Error> {
        let (repo, _dir) = setup_test_repo();

        // Build a tree with a subdirectory: pyo3/Cargo.toml, pyo3/src/lib.rs,
        // and a top-level file that should NOT match.
        let blob = repo.blob(b"content")?;

        let mut src_builder = repo.treebuilder(None)?;
        src_builder.insert("lib.rs", blob, 0o100644)?;
        let src_oid = src_builder.write()?;

        let mut pyo3_builder = repo.treebuilder(None)?;
        pyo3_builder.insert("Cargo.toml", blob, 0o100644)?;
        pyo3_builder.insert("src", src_oid, 0o040000)?;
        let pyo3_oid = pyo3_builder.write()?;

        let mut root_builder = repo.treebuilder(None)?;
        root_builder.insert("pyo3", pyo3_oid, 0o040000)?;
        root_builder.insert("README.md", blob, 0o100644)?;
        let root_oid = root_builder.write()?;
        let tree = repo.find_tree(root_oid)?;

        // "pyo3/" (trailing slash) must match all files under pyo3/.
        let filtered = repo.filter_by_patterns(&tree, &["pyo3/"])?;
        assert_eq!(filtered.len(), 1, "only the pyo3 dir should remain");
        assert!(filtered.get_name("pyo3").is_some());
        assert!(filtered.get_name("README.md").is_none());

        // The pyo3 subtree itself must retain both entries.
        let pyo3_entry = filtered.get_name("pyo3").expect("pyo3 should be kept");
        let pyo3_tree = repo.find_tree(pyo3_entry.id())?;
        assert!(pyo3_tree.get_name("Cargo.toml").is_some());
        assert!(pyo3_tree.get_name("src").is_some());

        Ok(())
    }
}