use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use walkdir::WalkDir;
/// A single test function discovered by scanning Rust source text.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TestFunction {
    /// Function name as written after the `fn` keyword.
    pub name: String,
    /// Source file that declares the function.
    pub file_path: PathBuf,
    /// Test target the function was attributed to by `TestDiscovery`.
    pub target_name: String,
    /// Kind of test (unit, integration or doc).
    pub test_type: TestType,
    /// Tags attached via `// @tag: name` comments or `#[test_tag("name")]` attributes.
    pub tags: Vec<String>,
}
/// A test target: one source file containing tests, together with the tags found in it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TestTarget {
    /// Target name (derived from the file stem by `TestDiscovery`).
    pub name: String,
    /// Source file backing the target.
    pub path: PathBuf,
    /// Kind of test contained in the file.
    pub test_type: TestType,
    /// Distinct tags found anywhere in the file.
    pub tags: Vec<String>,
}
/// Kind of test, classified by where it was found.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum TestType {
    /// Test found in a file under `src/`.
    Unit,
    /// Test found in a file under `tests/`.
    Integration,
    /// Documentation test; none of the scanners in `TestDiscovery` produce this variant.
    Doc,
}
/// Locates Rust tests by scanning the `src/` and `tests/` directories of a project.
#[derive(Debug, Clone)]
pub struct TestDiscovery {
    /// Directory containing the project's `src/` and `tests/` folders.
    project_root: PathBuf,
}
impl TestDiscovery {
    /// Deepest directory level walked below `tests/` (`tests/a/b/c.rs` is the deepest file seen).
    const INTEGRATION_MAX_DEPTH: usize = 3;

    /// Number of lines after a test attribute that are searched for the `fn` it decorates.
    const FN_SEARCH_WINDOW: usize = 5;

    /// Creates a discovery helper; `src/` and `tests/` are resolved relative to `project_root`.
    pub fn new(project_root: PathBuf) -> Self {
        Self { project_root }
    }

    /// Lists individual test functions: integration tests from `tests/` first, then unit tests
    /// from `src/`. Directory entries are visited in file-name order, so output is deterministic.
    ///
    /// # Errors
    ///
    /// Returns an error if a candidate `.rs` file cannot be read.
    pub fn discover_test_functions(&self) -> Result<Vec<TestFunction>> {
        let mut functions = self.discover_integration_test_functions()?;
        functions.extend(self.discover_unit_test_functions()?);
        Ok(functions)
    }

    /// Lists test targets, one per source file: every `.rs` file under `tests/`, then each file
    /// under `src/` that appears to contain tests.
    ///
    /// # Errors
    ///
    /// Returns an error if a candidate `.rs` file cannot be read.
    pub fn discover_tests(&self) -> Result<Vec<TestTarget>> {
        let mut tests = self.discover_integration_tests()?;
        tests.extend(self.discover_unit_tests()?);
        Ok(tests)
    }

    /// Parses test functions from every `.rs` file under `tests/`.
    fn discover_integration_test_functions(&self) -> Result<Vec<TestFunction>> {
        let tests_dir = self.project_root.join("tests");
        let mut functions = Vec::new();
        for path in Self::rust_files(&tests_dir, Self::INTEGRATION_MAX_DEPTH) {
            // NOTE(review): nested files (e.g. `tests/<name>/main.rs` or shared helper modules)
            // are named after their own stem, which may differ from the Cargo target name.
            let target_name = Self::file_stem_name(&path);
            functions.extend(self.parse_test_functions(
                &path,
                &target_name,
                TestType::Integration,
            )?);
        }
        Ok(functions)
    }

    /// Parses test functions from every `.rs` file under `src/` that appears to contain tests.
    fn discover_unit_test_functions(&self) -> Result<Vec<TestFunction>> {
        let src_dir = self.project_root.join("src");
        let mut functions = Vec::new();
        for path in Self::rust_files(&src_dir, usize::MAX) {
            // Read once: the same text feeds both the quick test check and the parser.
            let content = Self::read_source(&path)?;
            if self.contains_tests(&content) {
                // NOTE(review): every unit test is attributed to a "lib" target, including files
                // that may belong to a binary crate — confirm this is intended.
                functions.extend(self.collect_test_functions(
                    &content,
                    &path,
                    "lib",
                    &TestType::Unit,
                ));
            }
        }
        Ok(functions)
    }

    /// Reads `path` and returns the test functions it declares.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read.
    fn parse_test_functions(
        &self,
        path: &Path,
        target_name: &str,
        test_type: TestType,
    ) -> Result<Vec<TestFunction>> {
        let content = Self::read_source(path)?;
        Ok(self.collect_test_functions(&content, path, target_name, &test_type))
    }

    /// Records one [`TestFunction`] per test-attributed `fn` found in `content`.
    ///
    /// Tags come from the comment/attribute run directly above the first test attribute and
    /// from any lines between that attribute and the `fn` signature.
    fn collect_test_functions(
        &self,
        content: &str,
        path: &Path,
        target_name: &str,
        test_type: &TestType,
    ) -> Vec<TestFunction> {
        let lines: Vec<&str> = content.lines().collect();
        let mut functions = Vec::new();
        let mut i = 0;
        while i < lines.len() {
            if !self.is_test_attribute(lines[i]) {
                i += 1;
                continue;
            }
            let Some((fn_idx, name)) = self.locate_function(&lines, i + 1) else {
                i += 1;
                continue;
            };
            let mut tags = self.collect_preceding_tags(&lines, i);
            // Tag attributes may also sit between the test attribute and the signature.
            for tag in lines[i + 1..fn_idx]
                .iter()
                .filter_map(|line| self.parse_tag_line(line))
            {
                Self::push_unique(&mut tags, tag);
            }
            functions.push(TestFunction {
                name,
                file_path: path.to_path_buf(),
                target_name: target_name.to_string(),
                test_type: test_type.clone(),
                tags,
            });
            // Resume after the signature: stacked attributes such as repeated
            // `#[test_case(..)]` lines would otherwise register the same function again.
            i = fn_idx + 1;
        }
        functions
    }

    /// Returns `true` if `line` (after trimming) is a recognised test attribute.
    fn is_test_attribute(&self, line: &str) -> bool {
        const PREFIXES: [&str; 5] = [
            "#[test(",
            "#[tokio::test",
            "#[async_std::test",
            "#[rstest",
            "#[test_case",
        ];
        let line = line.trim();
        line == "#[test]" || PREFIXES.iter().any(|&prefix| line.starts_with(prefix))
    }

    /// Collects tags from the contiguous run of comment/attribute lines directly above
    /// `test_line_idx`, stopping at a blank line or any other code. Tags are returned in
    /// source order without duplicates.
    fn collect_preceding_tags(&self, lines: &[&str], test_line_idx: usize) -> Vec<String> {
        let above = &lines[..test_line_idx.min(lines.len())];
        let mut found: Vec<String> = above
            .iter()
            .rev()
            .map(|line| line.trim())
            .take_while(|line| {
                !line.is_empty() && (line.starts_with("//") || line.starts_with("#["))
            })
            .filter_map(|line| self.parse_tag_line(line))
            .collect();
        // The scan ran bottom-up; restore source order.
        found.reverse();
        let mut tags = Vec::with_capacity(found.len());
        for tag in found {
            Self::push_unique(&mut tags, tag);
        }
        tags
    }

    /// Parses a tag marker: `// @tag: name`, `//@tag: name` or `#[test_tag("name")]`.
    /// Returns `None` for any other line and for an empty tag.
    fn parse_tag_line(&self, line: &str) -> Option<String> {
        let line = line.trim();
        let tag = if let Some(rest) = line
            .strip_prefix("// @tag:")
            .or_else(|| line.strip_prefix("//@tag:"))
        {
            rest.trim()
        } else if line.starts_with("#[test_tag(") && line.ends_with(")]") {
            // Everything between the first and the last double quote.
            let start = line.find('"')?;
            let end = line.rfind('"')?;
            if start >= end {
                return None;
            }
            &line[start + 1..end]
        } else {
            return None;
        };
        (!tag.is_empty()).then(|| tag.to_string())
    }

    /// Returns the name of the first `fn` declared within `FN_SEARCH_WINDOW` lines starting at
    /// `start_idx`.
    // Only the unit tests need the name without its line index; discovery uses `locate_function`.
    #[cfg(test)]
    fn find_function_name(&self, lines: &[&str], start_idx: usize) -> Option<String> {
        self.locate_function(lines, start_idx).map(|(_, name)| name)
    }

    /// Returns `(line index, name)` of the first `fn` signature within `FN_SEARCH_WINDOW`
    /// lines starting at `start_idx`; attribute and other non-signature lines are passed over.
    fn locate_function(&self, lines: &[&str], start_idx: usize) -> Option<(usize, String)> {
        lines
            .iter()
            .enumerate()
            .skip(start_idx)
            .take(Self::FN_SEARCH_WINDOW)
            .find_map(|(idx, line)| Self::fn_name_from_signature(line).map(|name| (idx, name)))
    }

    /// Extracts the function name from a signature line such as `fn a()`, `pub async fn b<T>()`,
    /// `pub(crate) unsafe fn c()` or `extern "C" fn d()`. Returns `None` for anything else.
    fn fn_name_from_signature(line: &str) -> Option<String> {
        let mut rest = line.trim();
        if let Some(after_pub) = rest.strip_prefix("pub") {
            let after_pub_trimmed = after_pub.trim_start();
            if after_pub_trimmed.starts_with('(') {
                // Restricted visibility: `pub(crate)`, `pub(super)`, `pub(in path)`.
                rest = &after_pub_trimmed[after_pub_trimmed.find(')')? + 1..];
            } else if after_pub.starts_with(char::is_whitespace) {
                rest = after_pub_trimmed;
            }
        }
        let mut tokens = rest.split_whitespace();
        loop {
            match tokens.next()? {
                "fn" => break,
                "const" | "async" | "unsafe" | "extern" => {}
                // ABI string following `extern`, e.g. `"C"`.
                abi if abi.starts_with('"') => {}
                _ => return None,
            }
        }
        let name_token = tokens.next()?;
        let name_end = name_token.find(['(', '<']).unwrap_or(name_token.len());
        let name = &name_token[..name_end];
        (!name.is_empty()).then(|| name.to_string())
    }

    /// Builds one integration [`TestTarget`] per `.rs` file under `tests/`.
    fn discover_integration_tests(&self) -> Result<Vec<TestTarget>> {
        let tests_dir = self.project_root.join("tests");
        let mut targets = Vec::new();
        for path in Self::rust_files(&tests_dir, Self::INTEGRATION_MAX_DEPTH) {
            let tags = self.extract_file_tags(&path)?;
            targets.push(TestTarget {
                name: Self::file_stem_name(&path),
                path,
                test_type: TestType::Integration,
                tags,
            });
        }
        Ok(targets)
    }

    /// Builds one unit [`TestTarget`] per `.rs` file under `src/` that appears to contain tests.
    fn discover_unit_tests(&self) -> Result<Vec<TestTarget>> {
        let src_dir = self.project_root.join("src");
        let mut targets = Vec::new();
        for path in Self::rust_files(&src_dir, usize::MAX) {
            // Read once and reuse the text for the test check and the tag scan.
            let content = Self::read_source(&path)?;
            if self.contains_tests(&content) {
                targets.push(TestTarget {
                    name: Self::file_stem_name(&path),
                    tags: self.extract_tags(&content),
                    path,
                    test_type: TestType::Unit,
                });
            }
        }
        Ok(targets)
    }

    /// Cheap textual check for whether `content` contains any tests.
    fn contains_tests(&self, content: &str) -> bool {
        content.contains("#[test]")
            || content.contains("#[cfg(test)]")
            // Also accept the framework attributes `is_test_attribute` recognises
            // (`#[tokio::test]`, `#[rstest]`, ...), so such files are not skipped.
            || content.lines().any(|line| self.is_test_attribute(line))
    }

    /// Reads `path` and returns its distinct tags in first-seen order.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read.
    fn extract_file_tags(&self, path: &Path) -> Result<Vec<String>> {
        let content = Self::read_source(path)?;
        Ok(self.extract_tags(&content))
    }

    /// Returns the distinct tags found on any line of `content`, in first-seen order.
    fn extract_tags(&self, content: &str) -> Vec<String> {
        let mut tags = Vec::new();
        for tag in content.lines().filter_map(|line| self.parse_tag_line(line)) {
            Self::push_unique(&mut tags, tag);
        }
        tags
    }

    /// Appends `tag` unless already present (tag lists are short, so a linear scan is fine).
    fn push_unique(tags: &mut Vec<String>, tag: String) {
        if !tags.contains(&tag) {
            tags.push(tag);
        }
    }

    /// Reads a source file, attaching the path to any I/O error.
    fn read_source(path: &Path) -> Result<String> {
        fs::read_to_string(path).with_context(|| format!("Failed to read {}", path.display()))
    }

    /// File stem of `path` as UTF-8, or `"unknown"` when it has none or is not valid UTF-8.
    fn file_stem_name(path: &Path) -> String {
        path.file_stem()
            .and_then(|stem| stem.to_str())
            .unwrap_or("unknown")
            .to_string()
    }

    /// Lists the `.rs` files below `dir` (at most `max_depth` levels deep), visiting directory
    /// entries in file-name order so results are reproducible across platforms.
    /// A missing `dir` yields an empty list.
    fn rust_files(dir: &Path, max_depth: usize) -> Vec<PathBuf> {
        if !dir.exists() {
            return Vec::new();
        }
        WalkDir::new(dir)
            .min_depth(1)
            .max_depth(max_depth)
            .sort_by(|a, b| a.file_name().cmp(b.file_name()))
            .into_iter()
            // Best effort: entries that cannot be read (e.g. permission errors) are skipped.
            .filter_map(|entry| entry.ok())
            .map(|entry| entry.path().to_path_buf())
            .filter(|path| path.is_file() && path.extension().is_some_and(|ext| ext == "rs"))
            .collect()
    }

    /// Walks up from the current directory and returns the first directory holding a
    /// `Cargo.toml`.
    ///
    /// # Errors
    ///
    /// Fails if the current directory cannot be determined or no ancestor has a `Cargo.toml`.
    pub fn find_project_root() -> Result<PathBuf> {
        let current_dir = std::env::current_dir().context("Failed to get current directory")?;
        current_dir
            .ancestors()
            .find(|dir| dir.join("Cargo.toml").exists())
            .map(Path::to_path_buf)
            .context("Failed to find Cargo.toml in parent directories")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The helpers exercised here never read `project_root`, so any root will do.
    fn discovery() -> TestDiscovery {
        TestDiscovery::new(PathBuf::from("."))
    }

    #[test]
    fn test_is_test_attribute() {
        let d = discovery();
        for line in ["#[test]", "#[tokio::test]", "#[async_std::test]"].iter() {
            assert!(d.is_test_attribute(line));
        }
        // Plain code and commented-out attributes must not count.
        for line in ["fn test_something()", "// #[test]"].iter() {
            assert!(!d.is_test_attribute(line));
        }
    }

    #[test]
    fn test_parse_tag_line() {
        let d = discovery();
        let cases = [
            ("// @tag: fast", Some("fast")),
            ("//@tag: slow", Some("slow")),
            ("#[test_tag(\"database\")]", Some("database")),
            ("fn test()", None),
        ];
        for &(line, expected) in cases.iter() {
            assert_eq!(d.parse_tag_line(line), expected.map(str::to_string));
        }
    }

    #[test]
    fn test_find_function_name() {
        let d = discovery();
        let cases: [(&[&str], &str); 3] = [
            (
                &["fn test_something() {", " assert!(true);", "}"],
                "test_something",
            ),
            (&["pub fn test_public() {"], "test_public"),
            (&["async fn test_async() {"], "test_async"),
        ];
        for &(lines, expected) in cases.iter() {
            assert_eq!(d.find_function_name(lines, 0), Some(expected.to_string()));
        }
    }
}