Skip to main content

cargo_test_filter/
discovery.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::fs;
4use std::path::{Path, PathBuf};
5use walkdir::WalkDir;
6
7/// Represents an individual test function with its metadata
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct TestFunction {
10    /// The name of the test function (e.g., "test_database_connection")
11    pub name: String,
12    /// The file containing this test
13    pub file_path: PathBuf,
14    /// The name of the test target (file stem for integration tests, "lib" for unit tests)
15    pub target_name: String,
16    /// The type of test (unit, integration, etc.)
17    pub test_type: TestType,
18    /// Tags associated with this specific test function
19    pub tags: Vec<String>,
20}
21
22/// Legacy struct for file-level test targets (kept for compatibility)
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct TestTarget {
25    pub name: String,
26    pub path: PathBuf,
27    pub test_type: TestType,
28    pub tags: Vec<String>,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32pub enum TestType {
33    Unit,
34    Integration,
35    Doc,
36}
37
/// Scans a Cargo project's `src/` and `tests/` directories for test functions.
#[derive(Debug, Clone)]
pub struct TestDiscovery {
    /// Directory whose `src/` and `tests/` subdirectories are scanned.
    project_root: PathBuf,
}
41
42impl TestDiscovery {
43    pub fn new(project_root: PathBuf) -> Self {
44        Self { project_root }
45    }
46
47    /// Discover all individual test functions in the project
48    pub fn discover_test_functions(&self) -> Result<Vec<TestFunction>> {
49        let mut functions = Vec::new();
50
51        // Discover integration test functions
52        functions.extend(self.discover_integration_test_functions()?);
53
54        // Discover unit test functions in src/
55        functions.extend(self.discover_unit_test_functions()?);
56
57        Ok(functions)
58    }
59
60    /// Legacy method: Discover test targets (file-level)
61    pub fn discover_tests(&self) -> Result<Vec<TestTarget>> {
62        let mut tests = Vec::new();
63
64        // Discover integration tests
65        tests.extend(self.discover_integration_tests()?);
66
67        // Discover unit tests in src/
68        tests.extend(self.discover_unit_tests()?);
69
70        Ok(tests)
71    }
72
73    /// Discover integration test functions in tests/ directory
74    fn discover_integration_test_functions(&self) -> Result<Vec<TestFunction>> {
75        let tests_dir = self.project_root.join("tests");
76        if !tests_dir.exists() {
77            return Ok(Vec::new());
78        }
79
80        let mut functions = Vec::new();
81
82        for entry in WalkDir::new(&tests_dir)
83            .min_depth(1)
84            .max_depth(3)
85            .into_iter()
86            .filter_map(|e| e.ok())
87        {
88            let path = entry.path();
89            if path.is_file() && path.extension().is_some_and(|e| e == "rs") {
90                let target_name = path
91                    .file_stem()
92                    .and_then(|s| s.to_str())
93                    .unwrap_or("unknown")
94                    .to_string();
95
96                let file_functions = self.parse_test_functions(path, &target_name, TestType::Integration)?;
97                functions.extend(file_functions);
98            }
99        }
100
101        Ok(functions)
102    }
103
104    /// Discover unit test functions in src/ directory
105    fn discover_unit_test_functions(&self) -> Result<Vec<TestFunction>> {
106        let src_dir = self.project_root.join("src");
107        if !src_dir.exists() {
108            return Ok(Vec::new());
109        }
110
111        let mut functions = Vec::new();
112
113        for entry in WalkDir::new(&src_dir)
114            .into_iter()
115            .filter_map(|e| e.ok())
116        {
117            let path = entry.path();
118            if path.is_file() && path.extension().is_some_and(|e| e == "rs") {
119                let content = fs::read_to_string(path)
120                    .with_context(|| format!("Failed to read {}", path.display()))?;
121
122                if self.contains_tests(&content) {
123                    let file_functions = self.parse_test_functions(path, "lib", TestType::Unit)?;
124                    functions.extend(file_functions);
125                }
126            }
127        }
128
129        Ok(functions)
130    }
131
132    /// Parse a file to extract individual test functions with their tags
133    fn parse_test_functions(&self, path: &Path, target_name: &str, test_type: TestType) -> Result<Vec<TestFunction>> {
134        let content = fs::read_to_string(path)
135            .with_context(|| format!("Failed to read {}", path.display()))?;
136
137        let mut functions = Vec::new();
138        let lines: Vec<&str> = content.lines().collect();
139        let mut i = 0;
140
141        while i < lines.len() {
142            let line = lines[i].trim();
143
144            // Look for #[test] or #[tokio::test] or similar test attributes
145            if self.is_test_attribute(line) {
146                // Collect tags from preceding lines
147                let tags = self.collect_preceding_tags(&lines, i);
148
149                // Find the function name from the next lines
150                let func_name = self.find_function_name(&lines, i + 1);
151
152                if let Some(name) = func_name {
153                    functions.push(TestFunction {
154                        name,
155                        file_path: path.to_path_buf(),
156                        target_name: target_name.to_string(),
157                        test_type: test_type.clone(),
158                        tags,
159                    });
160                }
161            }
162
163            i += 1;
164        }
165
166        Ok(functions)
167    }
168
169    /// Check if a line is a test attribute
170    fn is_test_attribute(&self, line: &str) -> bool {
171        let line = line.trim();
172        // Match various test attributes
173        line == "#[test]"
174            || line.starts_with("#[test(")
175            || line.starts_with("#[tokio::test")
176            || line.starts_with("#[async_std::test")
177            || line.starts_with("#[rstest")
178            || line.starts_with("#[test_case")
179    }
180
181    /// Collect tags from lines preceding a test attribute
182    fn collect_preceding_tags(&self, lines: &[&str], test_line_idx: usize) -> Vec<String> {
183        let mut tags = Vec::new();
184
185        // Look backwards from the test attribute for tag comments
186        let mut j = test_line_idx;
187        while j > 0 {
188            j -= 1;
189            let line = lines[j].trim();
190
191            // Stop if we hit an empty line or something that's not a comment/attribute
192            if line.is_empty() {
193                break;
194            }
195
196            // Parse tag from comment or attribute
197            if let Some(tag) = self.parse_tag_line(line) {
198                tags.push(tag);
199            } else if !line.starts_with("//") && !line.starts_with("#[") {
200                // Hit non-comment, non-attribute line - stop looking
201                break;
202            }
203        }
204
205        tags
206    }
207
208    /// Parse a tag from a line (supports multiple formats)
209    fn parse_tag_line(&self, line: &str) -> Option<String> {
210        let line = line.trim();
211
212        // Support comment-based tags: // @tag: tagname or //@tag: tagname
213        if line.starts_with("// @tag:") || line.starts_with("//@tag:") {
214            let parts: Vec<&str> = line.splitn(2, ':').collect();
215            if parts.len() >= 2 {
216                return Some(parts[1].trim().to_string());
217            }
218        }
219
220        // Support attribute-based tags: #[test_tag("tagname")]
221        if line.starts_with("#[test_tag(") && line.ends_with(")]") {
222            let start = line.find('"')?;
223            let end = line.rfind('"')?;
224            if start < end {
225                return Some(line[start + 1..end].to_string());
226            }
227        }
228
229        None
230    }
231
232    /// Find the function name after a test attribute
233    fn find_function_name(&self, lines: &[&str], start_idx: usize) -> Option<String> {
234        for line in lines.iter().skip(start_idx).take(5) {
235            let line = line.trim();
236
237            // Skip additional attributes
238            if line.starts_with("#[") {
239                continue;
240            }
241
242            // Look for function definition
243            if line.starts_with("fn ") || line.starts_with("pub fn ") || line.starts_with("async fn ") || line.starts_with("pub async fn ") {
244                // Extract function name
245                let without_prefix = line
246                    .trim_start_matches("pub ")
247                    .trim_start_matches("async ")
248                    .trim_start_matches("fn ");
249
250                // Get the function name (everything before '(' or '<')
251                let name_end = without_prefix.find(['(', '<', ' ']).unwrap_or(without_prefix.len());
252                let name = without_prefix[..name_end].trim().to_string();
253
254                if !name.is_empty() {
255                    return Some(name);
256                }
257            }
258        }
259
260        None
261    }
262
263    /// Legacy: Discover integration tests in tests/ directory (file-level)
264    fn discover_integration_tests(&self) -> Result<Vec<TestTarget>> {
265        let tests_dir = self.project_root.join("tests");
266        if !tests_dir.exists() {
267            return Ok(Vec::new());
268        }
269
270        let mut targets = Vec::new();
271
272        for entry in WalkDir::new(&tests_dir)
273            .min_depth(1)
274            .max_depth(3)
275            .into_iter()
276            .filter_map(|e| e.ok())
277        {
278            let path = entry.path();
279            if path.is_file() && path.extension().is_some_and(|e| e == "rs") {
280                let tags = self.extract_file_tags(path)?;
281                let name = path
282                    .file_stem()
283                    .and_then(|s| s.to_str())
284                    .unwrap_or("unknown")
285                    .to_string();
286
287                targets.push(TestTarget {
288                    name,
289                    path: path.to_path_buf(),
290                    test_type: TestType::Integration,
291                    tags,
292                });
293            }
294        }
295
296        Ok(targets)
297    }
298
299    /// Legacy: Discover unit tests in src/ directory (file-level)
300    fn discover_unit_tests(&self) -> Result<Vec<TestTarget>> {
301        let src_dir = self.project_root.join("src");
302        if !src_dir.exists() {
303            return Ok(Vec::new());
304        }
305
306        let mut targets = Vec::new();
307
308        for entry in WalkDir::new(&src_dir)
309            .into_iter()
310            .filter_map(|e| e.ok())
311        {
312            let path = entry.path();
313            if path.is_file() && path.extension().is_some_and(|e| e == "rs") {
314                let content = fs::read_to_string(path)
315                    .with_context(|| format!("Failed to read {}", path.display()))?;
316
317                if self.contains_tests(&content) {
318                    let tags = self.extract_file_tags(path)?;
319                    let name = path
320                        .file_stem()
321                        .and_then(|s| s.to_str())
322                        .unwrap_or("unknown")
323                        .to_string();
324
325                    targets.push(TestTarget {
326                        name,
327                        path: path.to_path_buf(),
328                        test_type: TestType::Unit,
329                        tags,
330                    });
331                }
332            }
333        }
334
335        Ok(targets)
336    }
337
338    /// Check if a file contains test functions
339    fn contains_tests(&self, content: &str) -> bool {
340        content.contains("#[test]") || content.contains("#[cfg(test)]")
341    }
342
343    /// Extract all tags from a file (for legacy file-level support)
344    fn extract_file_tags(&self, path: &Path) -> Result<Vec<String>> {
345        let content = fs::read_to_string(path)
346            .with_context(|| format!("Failed to read {}", path.display()))?;
347
348        let mut tags = Vec::new();
349
350        for line in content.lines() {
351            if let Some(tag) = self.parse_tag_line(line) {
352                if !tags.contains(&tag) {
353                    tags.push(tag);
354                }
355            }
356        }
357
358        Ok(tags)
359    }
360
361    /// Get the project root by looking for Cargo.toml
362    pub fn find_project_root() -> Result<PathBuf> {
363        let current_dir = std::env::current_dir()
364            .context("Failed to get current directory")?;
365
366        let mut dir = current_dir.as_path();
367        loop {
368            if dir.join("Cargo.toml").exists() {
369                return Ok(dir.to_path_buf());
370            }
371            dir = dir.parent().context("Failed to find Cargo.toml in parent directories")?;
372        }
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture. The helpers exercised below only inspect the strings they
    /// are handed, so the project root does not matter.
    fn discovery() -> TestDiscovery {
        TestDiscovery::new(PathBuf::from("."))
    }

    #[test]
    fn test_is_test_attribute() {
        let d = discovery();

        for attr in ["#[test]", "#[tokio::test]", "#[async_std::test]"] {
            assert!(d.is_test_attribute(attr));
        }

        for line in ["fn test_something()", "// #[test]"] {
            assert!(!d.is_test_attribute(line));
        }
    }

    #[test]
    fn test_parse_tag_line() {
        let d = discovery();

        let cases = [
            ("// @tag: fast", Some("fast")),
            ("//@tag: slow", Some("slow")),
            ("#[test_tag(\"database\")]", Some("database")),
            ("fn test()", None),
        ];
        for (input, expected) in cases {
            assert_eq!(d.parse_tag_line(input).as_deref(), expected);
        }
    }

    #[test]
    fn test_find_function_name() {
        let d = discovery();

        let cases: [(&[&str], &str); 3] = [
            (
                &["fn test_something() {", "    assert!(true);", "}"],
                "test_something",
            ),
            (&["pub fn test_public() {"], "test_public"),
            (&["async fn test_async() {"], "test_async"),
        ];
        for (lines, expected) in cases {
            assert_eq!(d.find_function_name(lines, 0).as_deref(), Some(expected));
        }
    }
}