use std::path::Path;
use tldr_core::ast::ParserPool;
use tldr_core::Language;
use tree_sitter::{Node, Tree};
/// Outcome of classifying one source file as test or non-test code.
///
/// The `Default` value (`is_test_file: false`, `test_function_count: 0`)
/// represents "not a test file".
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TestFileInfo {
    /// `true` when the file was classified as a test file.
    pub is_test_file: bool,
    /// Number of test functions counted in the file.
    pub test_function_count: u32,
}
/// Classifies a file as test/non-test and counts the test functions in it.
///
/// * Returns `TestFileInfo::default()` when the path does not look like a
///   test file for `language`, or (Rust only) when the source mentions no
///   test attribute at all.
/// * Returns `is_test_file: true` with a count of `0` when the source is
///   blank or cannot be parsed.
/// * Otherwise parses `source` and counts the matching test functions.
pub fn recognize(path: &Path, source: &str, language: Language) -> TestFileInfo {
    if !is_candidate_test_file(path, language) {
        return TestFileInfo::default();
    }
    // Every `.rs` file is a path-level candidate, so Rust needs a cheap
    // textual pre-filter before paying for a parse. It has to accept every
    // attribute the syntax-level recognizer accepts; checking only for the
    // literal `#[test]` would hide files that use e.g. `#[tokio::test]`.
    if matches!(language, Language::Rust) && !rust_source_mentions_test_attribute(source) {
        return TestFileInfo::default();
    }

    let test_file_with = |test_function_count: u32| TestFileInfo {
        is_test_file: true,
        test_function_count,
    };

    if source.trim().is_empty() {
        return test_file_with(0);
    }

    let pool = ParserPool::new();
    let tree = match pool.parse(source, language).ok() {
        Some(tree) => tree,
        // A parse failure does not change the path-based classification.
        None => return test_file_with(0),
    };
    test_file_with(count_test_functions(&tree, source.as_bytes(), language))
}

/// Textual pre-filter: does `source` contain a `#[...]` attribute whose last
/// path segment is one of the test attribute names?
///
/// Uses the same name list and the same "split at `(`/whitespace, keep the
/// last `::` segment" rule as `rust_attribute_is_test`, applied to raw text
/// instead of syntax nodes. It may therefore also fire on attributes inside
/// comments or string literals, exactly like a plain substring search would.
fn rust_source_mentions_test_attribute(source: &str) -> bool {
    source.match_indices("#[").any(|(start, marker)| {
        // `start` and `marker.len()` are byte offsets of ASCII text, so the
        // slice below always begins on a char boundary.
        let body = source[start + marker.len()..].trim_start();
        let head = body
            .split(|c: char| c == '(' || c == ']' || c.is_whitespace())
            .next()
            .unwrap_or("");
        let tail = head.rsplit("::").next().unwrap_or("");
        matches!(
            tail,
            "test" | "tokio_test" | "async_test" | "rstest" | "test_case"
        )
    })
}
/// Decides, from the path alone, whether a file should be treated as a test
/// file for `language`.
///
/// Each language has its own naming rules (file-name prefixes/suffixes, stem
/// suffixes, and for some languages the presence of a test directory among
/// the path components). Returns `false` when the file name is missing or is
/// not valid UTF-8.
fn is_candidate_test_file(path: &Path, language: Language) -> bool {
    /// `true` when `text` ends with at least one of `suffixes`.
    fn ends_with_any(text: &str, suffixes: &[&str]) -> bool {
        suffixes.iter().any(|suffix| text.ends_with(*suffix))
    }

    /// `true` when some component of `path` equals one of `names` exactly
    /// (case-sensitive).
    fn has_component(path: &Path, names: &[&str]) -> bool {
        path.components()
            .any(|component| names.iter().any(|name| component.as_os_str() == *name))
    }

    let file_name = match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => return false,
    };
    let stem = path
        .file_stem()
        .and_then(|raw| raw.to_str())
        .unwrap_or(file_name);
    // Extension checks are case-insensitive; name/stem markers are not.
    let lower = file_name.to_ascii_lowercase();
    let test_prefixed = file_name.starts_with("test_");

    match language {
        Language::Python => {
            (test_prefixed && file_name.ends_with(".py")) || file_name.ends_with("_test.py")
        }
        Language::JavaScript | Language::TypeScript => {
            let in_tests_dir = has_component(path, &["__tests__", "test", "tests", "spec"]);
            let has_test_marker = ends_with_any(
                stem,
                &[".test", ".spec", "_test", "_spec", "Test", "Spec"],
            );
            let is_script = ends_with_any(
                &lower,
                &[".js", ".jsx", ".mjs", ".cjs", ".ts", ".tsx"],
            );
            (has_test_marker || in_tests_dir) && is_script
        }
        Language::Java => {
            lower.ends_with(".java")
                && (ends_with_any(stem, &["Test", "Tests", "IT", "ITCase"])
                    || has_component(path, &["test"]))
        }
        Language::Kotlin => {
            ends_with_any(&lower, &[".kt", ".kts"])
                && (ends_with_any(stem, &["Test", "Tests"]) || has_component(path, &["test"]))
        }
        Language::Php => lower.ends_with(".php") && ends_with_any(stem, &["Test", "Tests"]),
        Language::Swift => {
            lower.ends_with(".swift") && ends_with_any(stem, &["Tests", "Test", "Spec"])
        }
        Language::Ruby => {
            lower.ends_with(".rb") && (test_prefixed || ends_with_any(stem, &["_test", "_spec"]))
        }
        Language::Go => lower.ends_with("_test.go"),
        Language::Scala => {
            lower.ends_with(".scala")
                && ends_with_any(stem, &["Test", "Tests", "Spec", "Suite"])
        }
        Language::Elixir => {
            ends_with_any(&lower, &[".exs", ".ex"]) && (stem.ends_with("_test") || test_prefixed)
        }
        Language::Lua | Language::Luau => {
            ends_with_any(&lower, &[".lua", ".luau"])
                && (ends_with_any(stem, &["_spec", "_test"]) || test_prefixed)
        }
        // Any Rust source file is a candidate; the path carries no signal.
        Language::Rust => lower.ends_with(".rs"),
        Language::CSharp => {
            // Unlike the other languages, the directory check here is
            // case-insensitive.
            let in_tests_dir = || {
                path.components().any(|component| {
                    let dir = component.as_os_str().to_string_lossy().to_ascii_lowercase();
                    matches!(dir.as_str(), "test" | "tests")
                })
            };
            lower.ends_with(".cs") && (ends_with_any(stem, &["Test", "Tests"]) || in_tests_dir())
        }
        Language::C | Language::Cpp | Language::Ocaml => {
            lower.contains("test") || lower.contains("spec")
        }
    }
}
/// Counts the test functions found in `tree`, starting from its root node.
///
/// `source` must be the bytes the tree was parsed from; it is used to read
/// identifier and attribute text.
fn count_test_functions(tree: &Tree, source: &[u8], language: Language) -> u32 {
    let mut total: u32 = 0;
    walk_count(&tree.root_node(), source, language, &mut total);
    total
}
/// Adds to `count` the number of test functions in the subtree rooted at
/// `node`.
///
/// A node recognized as a test function is counted once and its children are
/// not visited, so tests nested inside another test are not counted.
///
/// The traversal uses an explicit work stack instead of recursion so that a
/// very deeply nested syntax tree cannot overflow the call stack. Visit order
/// differs from a recursive walk, but the resulting count does not depend on
/// order.
fn walk_count(node: &Node, source: &[u8], language: Language, count: &mut u32) {
    let mut pending = vec![*node];
    while let Some(current) = pending.pop() {
        if matches_test_function(&current, source, language) {
            *count += 1;
            continue;
        }
        let mut cursor = current.walk();
        pending.extend(current.children(&mut cursor));
    }
}
/// Public entry point for checking a single syntax node.
///
/// Returns `true` when `node` is recognized as a test function for
/// `language`; delegates directly to `matches_test_function`. `source` is
/// the byte buffer used to read node text.
pub fn is_test_function_node(node: &Node, source: &[u8], language: Language) -> bool {
matches_test_function(node, source, language)
}
fn matches_test_function(node: &Node, source: &[u8], language: Language) -> bool {
match language {
Language::Python => python_is_test_function(node, source),
Language::JavaScript | Language::TypeScript => js_is_test_call(node, source),
Language::Java | Language::Kotlin => jvm_has_test_annotation(node, source),
Language::Php => php_is_test_method(node, source),
Language::Swift => swift_is_test_method(node, source),
Language::Ruby => ruby_is_test_def_or_block(node, source),
Language::Go => go_is_top_level_test_function(node, source),
Language::Scala => scala_is_test_call(node, source),
Language::Elixir => elixir_is_test_macro(node, source),
Language::Lua | Language::Luau => lua_is_test_call(node, source),
Language::Rust => rust_is_test_function(node, source),
Language::CSharp => csharp_has_test_attribute(node, source),
Language::C | Language::Cpp | Language::Ocaml => false,
}
}
fn rust_is_test_function(node: &Node, source: &[u8]) -> bool {
if node.kind() != "function_item" {
return false;
}
let mut prev = node.prev_sibling();
while let Some(p) = prev {
match p.kind() {
"attribute_item" => {
if rust_attribute_is_test(&p, source) {
return true;
}
prev = p.prev_sibling();
}
"line_comment" | "block_comment" => {
prev = p.prev_sibling();
}
_ => break,
}
}
false
}
/// Returns `true` when an `attribute_item` node holds a test attribute.
///
/// For each `attribute` child, the text up to the first `(` or whitespace is
/// taken as the attribute path, and its last `::` segment is compared with
/// the known test attribute names.
fn rust_attribute_is_test(attr_item: &Node, source: &[u8]) -> bool {
    let mut cursor = attr_item.walk();
    // Bound to a local so the child iterator is dropped before `cursor`.
    let is_test = attr_item
        .children(&mut cursor)
        .filter(|child| child.kind() == "attribute")
        .any(|attribute| {
            let text = node_text(attribute, source);
            let path = text
                .split(|c: char| c == '(' || c.is_whitespace())
                .next()
                .unwrap_or("");
            let name = path.rsplit("::").next().unwrap_or("");
            matches!(
                name,
                "test" | "tokio_test" | "async_test" | "rstest" | "test_case"
            )
        });
    is_test
}
/// Returns `true` when `node` is a C# `method_declaration` with at least one
/// test attribute in any of its `attribute_list` children.
fn csharp_has_test_attribute(node: &Node, source: &[u8]) -> bool {
    if node.kind() != "method_declaration" {
        return false;
    }
    let mut list_cursor = node.walk();
    // Results are bound to locals so each child iterator is dropped before
    // the cursor it borrows.
    let has_test_attribute = node
        .children(&mut list_cursor)
        .filter(|child| child.kind() == "attribute_list")
        .any(|list| {
            let mut attr_cursor = list.walk();
            let found = list
                .children(&mut attr_cursor)
                .any(|attr| attr.kind() == "attribute" && csharp_attribute_is_test(&attr, source));
            found
        });
    has_test_attribute
}
/// Returns `true` when a C# `attribute` node names a known test attribute.
///
/// The attribute text is cut at the first `(` or whitespace, any namespace
/// qualification up to the last `.` is dropped, and the remaining simple name
/// is looked up in a fixed list (each name with and without the `Attribute`
/// suffix).
fn csharp_attribute_is_test(attribute: &Node, source: &[u8]) -> bool {
    const TEST_ATTRIBUTE_NAMES: [&str; 12] = [
        "Test",
        "TestAttribute",
        "Fact",
        "FactAttribute",
        "Theory",
        "TheoryAttribute",
        "TestMethod",
        "TestMethodAttribute",
        "TestCase",
        "TestCaseAttribute",
        "DataTestMethod",
        "DataTestMethodAttribute",
    ];

    let text = node_text(*attribute, source);
    let qualified = match text.find(|c: char| c == '(' || c.is_whitespace()) {
        Some(end) => &text[..end],
        None => text.as_str(),
    };
    let simple = match qualified.rfind('.') {
        Some(dot) => &qualified[dot + 1..],
        None => qualified,
    };
    TEST_ATTRIBUTE_NAMES.contains(&simple)
}
/// Returns `true` when `node` is a Python `function_definition` whose name
/// starts with `test_`.
fn python_is_test_function(node: &Node, source: &[u8]) -> bool {
    node.kind() == "function_definition"
        && node
            .child_by_field_name("name")
            .map_or(false, |name| node_text(name, source).starts_with("test_"))
}
/// Returns `true` when `node` is a JavaScript/TypeScript `call_expression`
/// whose callee is the bare identifier `it`, `test`, `fit`, `xit` or `xtest`.
///
/// Callees that are not plain identifiers (for example member expressions)
/// do not match.
fn js_is_test_call(node: &Node, source: &[u8]) -> bool {
    if node.kind() != "call_expression" {
        return false;
    }
    match node.child_by_field_name("function") {
        Some(callee) if callee.kind() == "identifier" => {
            let name = node_text(callee, source);
            matches!(name.as_str(), "it" | "test" | "fit" | "xit" | "xtest")
        }
        _ => false,
    }
}
/// Returns `true` when `node` is a `method_declaration` or
/// `function_declaration` annotated with `Test`.
///
/// The annotation may be a direct child of the declaration or sit anywhere
/// inside a `modifiers` child.
fn jvm_has_test_annotation(node: &Node, source: &[u8]) -> bool {
    if !matches!(node.kind(), "method_declaration" | "function_declaration") {
        return false;
    }
    let mut cursor = node.walk();
    // Bound to a local so the child iterator is dropped before `cursor`.
    let annotated = node.children(&mut cursor).any(|child| match child.kind() {
        "modifiers" => subtree_contains_annotation_named(&child, source, "Test"),
        "annotation" | "marker_annotation" => annotation_has_name(&child, source, "Test"),
        _ => false,
    });
    annotated
}
/// Searches the descendants of `node` (not `node` itself) depth-first for an
/// `annotation` or `marker_annotation` whose simple name equals `target`.
fn subtree_contains_annotation_named(node: &Node, source: &[u8], target: &str) -> bool {
    let mut cursor = node.walk();
    // Bound to a local so the child iterator is dropped before `cursor`.
    let found = node.children(&mut cursor).any(|child| {
        let is_annotation = matches!(child.kind(), "annotation" | "marker_annotation");
        (is_annotation && annotation_has_name(&child, source, target))
            || subtree_contains_annotation_named(&child, source, target)
    });
    found
}
/// Returns `true` when the annotation's simple name equals `target`.
///
/// Leading `@` characters are stripped, the text is cut at the first `(` or
/// whitespace, and only the part after the last `.` is compared, so
/// qualified and unqualified spellings are treated alike.
fn annotation_has_name(annotation_node: &Node, source: &[u8], target: &str) -> bool {
    let text = node_text(*annotation_node, source);
    let without_at = text.trim_start_matches('@');
    let qualified = match without_at.find(|c: char| c == '(' || c.is_whitespace()) {
        Some(end) => &without_at[..end],
        None => without_at,
    };
    let simple = match qualified.rfind('.') {
        Some(dot) => &qualified[dot + 1..],
        None => qualified,
    };
    simple == target
}
/// Returns `true` when `node` is a PHP `method_declaration` whose name starts
/// with `test`.
fn php_is_test_method(node: &Node, source: &[u8]) -> bool {
    if node.kind() != "method_declaration" {
        return false;
    }
    match node.child_by_field_name("name") {
        Some(name) => node_text(name, source).starts_with("test"),
        None => false,
    }
}
/// Returns `true` when `node` is a Swift `function_declaration` or
/// `protocol_function_declaration` whose first `simple_identifier` child
/// starts with `test`.
fn swift_is_test_method(node: &Node, source: &[u8]) -> bool {
    if !matches!(
        node.kind(),
        "function_declaration" | "protocol_function_declaration"
    ) {
        return false;
    }
    let mut cursor = node.walk();
    // Only the first identifier child is considered.
    let first_identifier = node
        .children(&mut cursor)
        .find(|child| child.kind() == "simple_identifier");
    first_identifier.map_or(false, |ident| node_text(ident, source).starts_with("test"))
}
/// Returns `true` for a Ruby `method` whose name starts with `test_`, or a
/// `call` whose method name is `it` or `specify`.
fn ruby_is_test_def_or_block(node: &Node, source: &[u8]) -> bool {
    // Text of the named field, or an empty string when the field is absent.
    let field_text = |field: &str| {
        node.child_by_field_name(field)
            .map(|child| node_text(child, source))
            .unwrap_or_default()
    };
    match node.kind() {
        "method" => field_text("name").starts_with("test_"),
        "call" => matches!(field_text("method").as_str(), "it" | "specify"),
        _ => false,
    }
}
/// Returns `true` when `node` is a Go `function_declaration` whose name
/// follows the `TestXxx` convention of the `testing` package.
///
/// The name must start with `Test`, and whatever follows must not begin with
/// a lowercase letter. Bare `Test`, `Test_foo` and `Test1` are therefore
/// tests, while `Testify` is not. Requiring an uppercase letter instead would
/// wrongly reject the first three forms.
fn go_is_top_level_test_function(node: &Node, source: &[u8]) -> bool {
    if node.kind() != "function_declaration" {
        return false;
    }
    let name = node
        .child_by_field_name("name")
        .map(|n| node_text(n, source))
        .unwrap_or_default();
    match name.strip_prefix("Test") {
        // An empty suffix (the name is exactly `Test`) is accepted.
        Some(suffix) => !suffix.chars().next().map_or(false, char::is_lowercase),
        None => false,
    }
}
/// Returns `true` when `node` is a Scala `call_expression` whose `function`
/// field reads exactly `test`.
fn scala_is_test_call(node: &Node, source: &[u8]) -> bool {
    node.kind() == "call_expression"
        && node
            .child_by_field_name("function")
            .map_or(false, |callee| node_text(callee, source) == "test")
}
/// Returns `true` when `node` is an Elixir `call` whose `target` field reads
/// exactly `test`.
fn elixir_is_test_macro(node: &Node, source: &[u8]) -> bool {
    if node.kind() != "call" {
        return false;
    }
    let target = node
        .child_by_field_name("target")
        .map(|target| node_text(target, source));
    target.as_deref() == Some("test")
}
/// Returns `true` when `node` is a Lua `function_call` or
/// `function_call_statement` whose first `identifier` child is `it` or
/// `test`.
fn lua_is_test_call(node: &Node, source: &[u8]) -> bool {
    if !matches!(node.kind(), "function_call" | "function_call_statement") {
        return false;
    }
    let mut cursor = node.walk();
    // Only the first identifier child decides the outcome.
    let callee = node
        .children(&mut cursor)
        .find(|child| child.kind() == "identifier");
    match callee {
        Some(ident) => {
            let name = node_text(ident, source);
            name == "it" || name == "test"
        }
        None => false,
    }
}
/// Copies the source text covered by `node` into an owned `String`.
///
/// Returns an empty string when the node's end offset lies beyond `source`
/// or when the covered bytes are not valid UTF-8.
fn node_text(node: Node, source: &[u8]) -> String {
    let span = node.start_byte()..node.end_byte();
    if span.end > source.len() {
        return String::new();
    }
    match std::str::from_utf8(&source[span]) {
        Ok(text) => text.to_owned(),
        Err(_) => String::new(),
    }
}
/// Determines the language of the file at `path`.
///
/// Thin wrapper around `Language::from_path`; returns `None` whenever that
/// function does.
pub fn detect_language(path: &Path) -> Option<Language> {
Language::from_path(path)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Writes `body` to `dir/name`, creating parent directories as needed,
    /// and returns the full path.
    fn write(dir: &Path, name: &str, body: &str) -> std::path::PathBuf {
        let p = dir.join(name);
        if let Some(parent) = p.parent() {
            fs::create_dir_all(parent).ok();
        }
        fs::write(&p, body).unwrap();
        p
    }

    /// Materializes a fixture file in a fresh temp directory, reads it back,
    /// and runs `recognize` on it.
    fn recognize_fixture(file_name: &str, body: &str, language: Language) -> TestFileInfo {
        let tmp = tempdir().unwrap();
        let path = write(tmp.path(), file_name, body);
        let source = fs::read_to_string(&path).unwrap();
        recognize(&path, &source, language)
    }

    #[test]
    fn python_test_function_counted() {
        let info = recognize_fixture(
            "test_x.py",
            "def test_one():\n pass\n\ndef test_two():\n pass\n",
            Language::Python,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn javascript_describe_it_counted() {
        let info = recognize_fixture(
            "foo.test.js",
            "describe('s', () => { it('a', () => {}); it('b', () => {}); });",
            Language::JavaScript,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn java_test_annotation_counted() {
        let info = recognize_fixture(
            "FooTest.java",
            "import org.junit.Test;\nclass FooTest {\n @Test public void shouldFoo() {}\n @Test public void shouldBar() {}\n}\n",
            Language::Java,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn php_phpunit_counted() {
        let info = recognize_fixture(
            "FooTest.php",
            "<?php\nclass FooTest {\n public function testBar() {}\n public function testBaz() {}\n}\n",
            Language::Php,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn swift_xctest_counted() {
        let info = recognize_fixture(
            "FooTests.swift",
            "import XCTest\nclass FooTests: XCTestCase {\n func testBar() {}\n func testBaz() {}\n}\n",
            Language::Swift,
        );
        assert!(info.is_test_file);
        // Lower bound only, unlike the exact counts asserted elsewhere.
        assert!(info.test_function_count >= 2);
    }

    #[test]
    fn go_testing_counted() {
        let info = recognize_fixture(
            "foo_test.go",
            "package foo\nimport \"testing\"\nfunc TestFoo(t *testing.T) {}\nfunc TestBar(t *testing.T) {}\nfunc helper() {}\n",
            Language::Go,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }

    #[test]
    fn ruby_minitest_counted() {
        let info = recognize_fixture(
            "foo_test.rb",
            "class FooTest\n def test_one; end\n def test_two; end\nend\n",
            Language::Ruby,
        );
        assert!(info.is_test_file);
        assert_eq!(info.test_function_count, 2);
    }
}