Skip to main content

cargo_test_filter/
filter.rs

1use crate::cli::TestFilterArgs;
2use crate::discovery::{TestFunction, TestTarget, TestType};
3use regex::Regex;
4
5pub struct TestFilter<'a> {
6    args: &'a TestFilterArgs,
7}
8
9impl<'a> TestFilter<'a> {
10    pub fn new(args: &'a TestFilterArgs) -> Self {
11        Self { args }
12    }
13
14    /// Filter individual test functions based on the provided arguments
15    pub fn filter_functions(&self, functions: Vec<TestFunction>) -> Vec<TestFunction> {
16        functions
17            .into_iter()
18            .filter(|func| self.matches_function(func))
19            .collect()
20    }
21
22    /// Check if a test function matches all filter criteria
23    fn matches_function(&self, func: &TestFunction) -> bool {
24        // Filter by test type
25        if self.args.integration && func.test_type != TestType::Integration {
26            return false;
27        }
28        if self.args.unit && func.test_type != TestType::Unit {
29            return false;
30        }
31
32        // Filter by tags (include)
33        if !self.args.tag.is_empty() {
34            let has_any_tag = self.args.tag.iter().any(|filter_tag| {
35                func.tags.iter().any(|test_tag| test_tag == filter_tag)
36            });
37            if !has_any_tag {
38                return false;
39            }
40        }
41
42        // Filter by tags (exclude)
43        if !self.args.exclude_tag.is_empty() {
44            let has_excluded_tag = self.args.exclude_tag.iter().any(|filter_tag| {
45                func.tags.iter().any(|test_tag| test_tag == filter_tag)
46            });
47            if has_excluded_tag {
48                return false;
49            }
50        }
51
52        // Filter by name pattern (matches function name)
53        if let Some(ref name_pattern) = self.args.name {
54            if let Ok(re) = Regex::new(&format!(".*{}.*", regex::escape(name_pattern))) {
55                if !re.is_match(&func.name) {
56                    return false;
57                }
58            }
59        }
60
61        // Filter by path pattern
62        if let Some(ref path_pattern) = self.args.path {
63            let path_str = func.file_path.to_string_lossy();
64            if let Ok(re) = Regex::new(&format!(".*{}.*", regex::escape(path_pattern))) {
65                if !re.is_match(&path_str) {
66                    return false;
67                }
68            }
69        }
70
71        true
72    }
73
74    /// Legacy: Filter tests based on the provided arguments (file-level)
75    pub fn filter_tests(&self, tests: Vec<TestTarget>) -> Vec<TestTarget> {
76        tests
77            .into_iter()
78            .filter(|test| self.matches_test(test))
79            .collect()
80    }
81
82    /// Legacy: Check if a test matches all filter criteria (file-level)
83    fn matches_test(&self, test: &TestTarget) -> bool {
84        // Filter by test type
85        if self.args.integration && test.test_type != TestType::Integration {
86            return false;
87        }
88        if self.args.unit && test.test_type != TestType::Unit {
89            return false;
90        }
91
92        // Filter by tags (include)
93        if !self.args.tag.is_empty() {
94            let has_any_tag = self.args.tag.iter().any(|filter_tag| {
95                test.tags.iter().any(|test_tag| test_tag == filter_tag)
96            });
97            if !has_any_tag {
98                return false;
99            }
100        }
101
102        // Filter by tags (exclude)
103        if !self.args.exclude_tag.is_empty() {
104            let has_excluded_tag = self.args.exclude_tag.iter().any(|filter_tag| {
105                test.tags.iter().any(|test_tag| test_tag == filter_tag)
106            });
107            if has_excluded_tag {
108                return false;
109            }
110        }
111
112        // Filter by name pattern
113        if let Some(ref name_pattern) = self.args.name {
114            if let Ok(re) = Regex::new(&format!(".*{}.*", regex::escape(name_pattern))) {
115                if !re.is_match(&test.name) {
116                    return false;
117                }
118            }
119        }
120
121        // Filter by path pattern
122        if let Some(ref path_pattern) = self.args.path {
123            let path_str = test.path.to_string_lossy();
124            if let Ok(re) = Regex::new(&format!(".*{}.*", regex::escape(path_pattern))) {
125                if !re.is_match(&path_str) {
126                    return false;
127                }
128            }
129        }
130
131        true
132    }
133
134    /// Get a summary of what filters are active
135    pub fn get_filter_summary(&self) -> String {
136        let mut parts = Vec::new();
137
138        if self.args.integration {
139            parts.push("integration tests only".to_string());
140        }
141        if self.args.unit {
142            parts.push("unit tests only".to_string());
143        }
144        if !self.args.tag.is_empty() {
145            parts.push(format!("tags: {}", self.args.tag.join(", ")));
146        }
147        if !self.args.exclude_tag.is_empty() {
148            parts.push(format!("excluding tags: {}", self.args.exclude_tag.join(", ")));
149        }
150        if let Some(ref name) = self.args.name {
151            parts.push(format!("name contains: {}", name));
152        }
153        if let Some(ref path) = self.args.path {
154            parts.push(format!("path contains: {}", path));
155        }
156
157        if parts.is_empty() {
158            "all tests".to_string()
159        } else {
160            parts.join(", ")
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Arguments with every filter switched off; tests override single fields.
    fn create_test_args() -> TestFilterArgs {
        TestFilterArgs {
            integration: false,
            unit: false,
            tag: vec![],
            exclude_tag: vec![],
            name: None,
            path: None,
            timeout: None,
            list: false,
            verbose: false,
            test_args: vec![],
        }
    }

    /// Turn string literals into the owned tag list the fixtures take.
    fn owned_tags(tags: &[&str]) -> Vec<String> {
        tags.iter().map(|tag| String::from(*tag)).collect()
    }

    /// File-level fixture whose path is `<name>.rs`.
    fn create_test_target(name: &str, test_type: TestType, tags: Vec<String>) -> TestTarget {
        let path = PathBuf::from(format!("{}.rs", name));
        TestTarget {
            name: name.to_owned(),
            path,
            test_type,
            tags,
        }
    }

    /// Function-level fixture living in `tests/<target_name>.rs`.
    fn create_test_function(name: &str, target_name: &str, test_type: TestType, tags: Vec<String>) -> TestFunction {
        let file_path = PathBuf::from(format!("tests/{}.rs", target_name));
        TestFunction {
            name: name.to_owned(),
            file_path,
            target_name: target_name.to_owned(),
            test_type,
            tags,
        }
    }

    #[test]
    fn test_filter_integration_only() {
        let args = TestFilterArgs {
            integration: true,
            ..create_test_args()
        };
        let candidates = vec![
            create_test_target("unit_test", TestType::Unit, Vec::new()),
            create_test_target("integration_test", TestType::Integration, Vec::new()),
        ];

        let kept = TestFilter::new(&args).filter_tests(candidates);

        // Exactly one survivor, and it is the integration target.
        let names: Vec<&str> = kept.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["integration_test"]);
    }

    #[test]
    fn test_filter_by_tag() {
        let args = TestFilterArgs {
            tag: owned_tags(&["fast"]),
            ..create_test_args()
        };
        let candidates = vec![
            create_test_target("test1", TestType::Unit, owned_tags(&["fast"])),
            create_test_target("test2", TestType::Unit, owned_tags(&["slow"])),
        ];

        let kept = TestFilter::new(&args).filter_tests(candidates);

        let names: Vec<&str> = kept.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["test1"]);
    }

    #[test]
    fn test_exclude_tag() {
        let args = TestFilterArgs {
            exclude_tag: owned_tags(&["slow"]),
            ..create_test_args()
        };
        let candidates = vec![
            create_test_target("test1", TestType::Unit, owned_tags(&["fast"])),
            create_test_target("test2", TestType::Unit, owned_tags(&["slow"])),
        ];

        let kept = TestFilter::new(&args).filter_tests(candidates);

        let names: Vec<&str> = kept.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, ["test1"]);
    }

    #[test]
    fn test_filter_functions_by_tag() {
        let args = TestFilterArgs {
            tag: owned_tags(&["fast"]),
            ..create_test_args()
        };
        let candidates = vec![
            create_test_function("test_fast_api", "api_test", TestType::Integration, owned_tags(&["fast", "api"])),
            create_test_function("test_slow_db", "db_test", TestType::Integration, owned_tags(&["slow", "database"])),
            create_test_function("test_fast_db", "db_test", TestType::Integration, owned_tags(&["fast", "database"])),
        ];

        let kept = TestFilter::new(&args).filter_functions(candidates);

        // Both "fast" functions survive and nothing else does.
        assert_eq!(kept.len(), 2);
        assert!(kept.iter().all(|f| f.tags.iter().any(|t| t == "fast")));
    }

    #[test]
    fn test_filter_functions_exclude_tag() {
        let args = TestFilterArgs {
            exclude_tag: owned_tags(&["slow"]),
            ..create_test_args()
        };
        let candidates = vec![
            create_test_function("test_fast_api", "api_test", TestType::Integration, owned_tags(&["fast"])),
            create_test_function("test_slow_db", "db_test", TestType::Integration, owned_tags(&["slow"])),
        ];

        let kept = TestFilter::new(&args).filter_functions(candidates);

        let names: Vec<&str> = kept.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, ["test_fast_api"]);
    }

    #[test]
    fn test_filter_functions_by_name() {
        let args = TestFilterArgs {
            name: Some(String::from("api")),
            ..create_test_args()
        };
        let candidates = vec![
            create_test_function("test_api_endpoint", "api_test", TestType::Integration, Vec::new()),
            create_test_function("test_database_query", "db_test", TestType::Integration, Vec::new()),
        ];

        let kept = TestFilter::new(&args).filter_functions(candidates);

        let names: Vec<&str> = kept.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, ["test_api_endpoint"]);
    }
}