extern crate rustc_ast;
use rustc_ast::Item;
use rustc_lint::{EarlyContext, EarlyLintPass, LintContext};
use std::cell::RefCell;
use std::collections::HashSet;
thread_local! {
static SCANNED_FILES: RefCell<HashSet<String>> = RefCell::new(HashSet::new());
}
const MODULE_PREFIXES: &[&str] = &[
"libs/",
"modules/system/",
"modules/",
"examples/",
"apps/",
"plugins/",
];
const DEFAULT_MAX_INLINE_TEST_LINES: usize = 100;
#[derive(Default, serde::Deserialize)]
struct Config {
#[serde(default)]
excluded_paths: Vec<String>,
#[serde(default)]
max_inline_test_lines: Option<usize>,
}
pub(crate) struct De1101TestsInSeparateFiles {
excluded_set: HashSet<String>,
max_inline_test_lines: usize,
}
impl De1101TestsInSeparateFiles {
pub fn new() -> Self {
let config: Config = dylint_linting::config_or_default(crate::LIBRARY_NAME);
Self {
excluded_set: config.excluded_paths.into_iter().collect(),
max_inline_test_lines: config
.max_inline_test_lines
.unwrap_or(DEFAULT_MAX_INLINE_TEST_LINES),
}
}
fn is_in_scope(&self, normalized_path: &str) -> bool {
for prefix in MODULE_PREFIXES {
if let Some(pos) = normalized_path.find(prefix) {
if pos > 0 && normalized_path.as_bytes()[pos - 1] != b'/' {
continue;
}
let rest = &normalized_path[pos + prefix.len()..];
let seg_end = rest.find('/').unwrap_or(rest.len());
let key = &normalized_path[pos..pos + prefix.len() + seg_end];
return !self.excluded_set.contains(key);
}
}
true
}
}
dylint_linting::impl_pre_expansion_lint! {
#[doc = include_str!("../../docs/de11_testing/de1101_tests_in_separate_files/README.md")]
pub DE1101_TESTS_IN_SEPARATE_FILES,
Deny,
"tests must live in separate files, not inline in production files (DE1101)",
De1101TestsInSeparateFiles::new()
}
enum TestViolation {
InlineTestCode,
InlineTestCodeWithCompanion,
WrongPathAttr { expected: String, actual: String },
}
impl EarlyLintPass for De1101TestsInSeparateFiles {
fn check_crate_post(&mut self, _cx: &EarlyContext<'_>, _krate: &rustc_ast::Crate) {
SCANNED_FILES.with(|files| files.borrow_mut().clear());
}
fn check_item(&mut self, cx: &EarlyContext<'_>, item: &Item) {
let Some(path) = crate::lint_utils::filename_str(cx.sess().source_map(), item.span) else {
return;
};
let normalized = path.replace('\\', "/");
if !self.is_in_scope(&normalized) {
return;
}
let should_scan = SCANNED_FILES.with(|files| files.borrow_mut().insert(normalized.clone()));
if !should_scan {
return;
}
let Ok(source) = std::fs::read_to_string(&path) else {
return;
};
if is_allowed_test_file(&normalized) {
return;
}
let source_stem = file_stem(&normalized);
let has_companion = has_companion_test_file(&path, source_stem.as_deref());
let violations = find_test_violations(
&source,
source_stem.as_deref(),
has_companion,
self.max_inline_test_lines,
);
for violation in violations {
match violation {
TestViolation::InlineTestCode => {
cx.span_lint(DE1101_TESTS_IN_SEPARATE_FILES, item.span, |diag| {
diag.primary_message(
"test code must be moved to a separate test file (DE1101)",
);
diag.help(format!(
"move the test into `tests/*.rs` or an out-of-line `*_tests.rs` module (inline test block exceeds {} lines)",
self.max_inline_test_lines,
));
});
}
TestViolation::InlineTestCodeWithCompanion => {
cx.span_lint(DE1101_TESTS_IN_SEPARATE_FILES, item.span, |diag| {
diag.primary_message(
"test code must not be added back to a file that already has a companion test file (DE1101)",
);
diag.help(
"a `*_tests.rs` companion file already exists; add tests there instead",
);
});
}
TestViolation::WrongPathAttr { expected, actual } => {
cx.span_lint(DE1101_TESTS_IN_SEPARATE_FILES, item.span, |diag| {
diag.primary_message(format!(
"test module path `{actual}.rs` must reference `{expected}.rs` to match the source file (DE1101)",
));
diag.help(format!(
"use `#[path = \"{expected}.rs\"]` or remove `#[path]`"
));
});
}
}
}
}
}
fn is_allowed_test_file(path: &str) -> bool {
let file_name = path.rsplit('/').next().unwrap_or(path);
path.contains("/tests/") || file_name.ends_with("_tests.rs")
}
fn file_stem(path: &str) -> Option<String> {
let file_name = path.rsplit('/').next().unwrap_or(path);
let stem = file_name.strip_suffix(".rs")?;
match stem {
"lib" | "main" | "mod" | "tests" | "test" => None,
_ => Some(stem.to_string()),
}
}
fn has_companion_test_file(source_path: &str, source_stem: Option<&str>) -> bool {
let Some(stem) = source_stem else {
return false;
};
let parent = match source_path.rfind('/').or_else(|| source_path.rfind('\\')) {
Some(pos) => &source_path[..=pos],
None => "",
};
let companion = format!("{parent}{stem}_tests.rs");
std::path::Path::new(&companion).exists()
}
fn count_inline_test_block_lines(lines: &[&str], open_brace_line: usize) -> usize {
let mut depth = 0usize;
let mut count = 0usize;
for line in &lines[open_brace_line..] {
count += 1;
for ch in line.chars() {
if ch == '{' {
depth += 1;
} else if ch == '}' {
depth = depth.saturating_sub(1);
if depth == 0 {
return count;
}
}
}
}
count
}
fn find_test_violations(
source: &str,
source_stem: Option<&str>,
has_companion: bool,
max_inline_lines: usize,
) -> Vec<TestViolation> {
let lines: Vec<&str> = source.lines().collect();
let mut violations = Vec::new();
let mut reported_inline = false;
let mut reported_naming = false;
for (index, line) in lines.iter().enumerate() {
if is_comment_or_blank_line(line) {
continue;
}
let compact_line = compact(line);
if !reported_inline && is_direct_test_attr(&compact_line) {
if has_companion {
violations.push(TestViolation::InlineTestCodeWithCompanion);
} else {
violations.push(TestViolation::InlineTestCode);
}
reported_inline = true;
continue;
}
if !is_cfg_test_attr(&compact_line) {
continue;
}
let mut next = index + 1;
let mut path_attr_value: Option<String> = None;
while let Some(candidate) = lines.get(next) {
let trimmed = candidate.trim();
let candidate_compact = compact(candidate);
if is_comment_or_blank_line(candidate) {
next += 1;
continue;
}
if candidate_compact.starts_with("#[") {
if is_path_attr(&candidate_compact) {
path_attr_value = extract_path_attr_value(trimmed);
}
next += 1;
continue;
}
if is_out_of_line_mod_decl(trimmed) {
if let (Some(stem), Some(pv)) = (source_stem, &path_attr_value)
&& !reported_naming
{
let expected = format!("{stem}_tests");
let filename = pv.rsplit('/').next().unwrap_or(pv);
let actual = filename.strip_suffix(".rs").unwrap_or(filename);
if actual != expected {
violations.push(TestViolation::WrongPathAttr {
expected,
actual: actual.to_string(),
});
reported_naming = true;
}
}
break;
}
if is_extern_crate_alias(trimmed) {
break;
}
if !reported_inline {
if has_companion {
violations.push(TestViolation::InlineTestCodeWithCompanion);
reported_inline = true;
} else {
let block_lines = count_inline_test_block_lines(&lines, next);
let total_test_lines = (next - index) + block_lines;
if total_test_lines > max_inline_lines {
violations.push(TestViolation::InlineTestCode);
}
reported_inline = true;
}
}
break;
}
}
violations
}
fn is_comment_or_blank_line(line: &str) -> bool {
let trimmed = line.trim();
trimmed.is_empty() || trimmed.starts_with("//")
}
fn compact(line: &str) -> String {
let without_comment = match line.find("//") {
Some(pos) => &line[..pos],
None => line,
};
without_comment
.chars()
.filter(|ch| !ch.is_whitespace())
.collect()
}
fn is_direct_test_attr(line: &str) -> bool {
let trimmed = line.trim_start();
let is_attr = trimmed.starts_with("#[");
trimmed.starts_with("#[test")
|| trimmed.starts_with("#[tokio::test")
|| (is_attr && trimmed.contains("::test]"))
|| (is_attr && trimmed.contains("::test("))
}
fn is_cfg_test_attr(line: &str) -> bool {
let Some(inner) = line
.strip_prefix("#[cfg(")
.and_then(|rest| rest.strip_suffix(")]"))
else {
return false;
};
contains_test_cfg_operand(inner)
}
fn contains_test_cfg_operand(input: &str) -> bool {
split_top_level_args(input).into_iter().any(|arg| {
if arg == "test" {
return true;
}
if let Some(inner) = arg
.strip_prefix("all(")
.and_then(|rest| rest.strip_suffix(')'))
{
return contains_test_cfg_operand(inner);
}
if let Some(inner) = arg
.strip_prefix("any(")
.and_then(|rest| rest.strip_suffix(')'))
{
return contains_test_cfg_operand(inner);
}
false
})
}
fn split_top_level_args(input: &str) -> Vec<&str> {
let mut args = Vec::new();
let mut depth = 0usize;
let mut start = 0usize;
for (index, ch) in input.char_indices() {
match ch {
'(' => depth += 1,
')' => depth = depth.saturating_sub(1),
',' if depth == 0 => {
args.push(input[start..index].trim());
start = index + ch.len_utf8();
}
_ => {}
}
}
args.push(input[start..].trim());
args
}
fn is_path_attr(compact_line: &str) -> bool {
compact_line.starts_with("#[path=")
}
fn extract_path_attr_value(line: &str) -> Option<String> {
let trimmed = line.trim();
let rest = trimmed.strip_prefix("#[path")?;
let rest = rest.trim_start();
let rest = rest.strip_prefix('=')?;
let rest = rest.trim_start();
let rest = rest.strip_prefix('"')?;
let end = rest.find('"')?;
Some(rest[..end].to_string())
}
fn strip_visibility(line: &str) -> &str {
if let Some(rest) = line.strip_prefix("pub(in ")
&& let Some(close) = rest.find(')')
{
let after = &rest[close + 1..];
return after.strip_prefix(' ').unwrap_or(after);
}
line.strip_prefix("pub(crate) ")
.or_else(|| line.strip_prefix("pub(super) "))
.or_else(|| line.strip_prefix("pub(self) "))
.or_else(|| line.strip_prefix("pub "))
.unwrap_or(line)
}
#[cfg(test)]
fn extract_mod_name(line: &str) -> Option<&str> {
if !line.ends_with(';') {
return None;
}
let name_with_semi = strip_visibility(line).strip_prefix("mod ")?;
Some(name_with_semi.trim_end_matches(';').trim())
}
fn is_out_of_line_mod_decl(line: &str) -> bool {
line.ends_with(';') && strip_visibility(line).starts_with("mod ")
}
fn is_extern_crate_alias(line: &str) -> bool {
line.starts_with("extern crate ") && line.ends_with(';')
}
#[cfg(test)]
mod tests {
use super::{
count_inline_test_block_lines, extract_mod_name, extract_path_attr_value,
find_test_violations, is_cfg_test_attr, is_path_attr,
};
#[test]
fn test_is_cfg_test_attr_matches_supported_forms() {
assert!(is_cfg_test_attr("#[cfg(test)]"));
assert!(is_cfg_test_attr("#[cfg(test,feature=\"foo\")]"));
assert!(is_cfg_test_attr("#[cfg(all(test,feature=\"foo\"))]"));
assert!(is_cfg_test_attr("#[cfg(any(feature=\"foo\",test))]"));
}
#[test]
fn test_is_cfg_test_attr_rejects_unsupported_forms() {
assert!(!is_cfg_test_attr("#[cfg(not(test))]"));
assert!(!is_cfg_test_attr("#[cfg(feature=\"test\")]"));
assert!(!is_cfg_test_attr("#[cfg(any(feature=\"test\",unix))]"));
}
#[test]
fn test_is_path_attr() {
assert!(is_path_attr("#[path=\"foo.rs\"]"));
assert!(is_path_attr("#[path=\"some/path.rs\"]"));
assert!(!is_path_attr("#[cfg(test)]"));
assert!(!is_path_attr("#[derive(Debug)]"));
}
#[test]
fn test_extract_path_attr_value() {
assert_eq!(
extract_path_attr_value(r#"#[path = "foo_tests.rs"]"#),
Some("foo_tests.rs".to_string())
);
assert_eq!(
extract_path_attr_value(r#"#[path="bar.rs"]"#),
Some("bar.rs".to_string())
);
assert_eq!(extract_path_attr_value(r#"#[cfg(test)]"#), None);
}
#[test]
fn test_extract_mod_name() {
assert_eq!(extract_mod_name("mod foo_tests;"), Some("foo_tests"));
assert_eq!(extract_mod_name("pub mod foo_tests;"), Some("foo_tests"));
assert_eq!(
extract_mod_name("pub(crate) mod foo_tests;"),
Some("foo_tests")
);
assert_eq!(extract_mod_name("mod tests;"), Some("tests"));
assert_eq!(extract_mod_name("mod tests { }"), None);
assert_eq!(extract_mod_name("fn main() {}"), None);
}
#[test]
fn test_find_violations_correct_name_no_issues() {
let source = r#"
#[cfg(test)]
mod handler_tests;
fn main() {}
"#;
let violations = find_test_violations(source, Some("handler"), false, 100);
assert!(violations.is_empty(), "expected no violations");
}
#[test]
fn test_find_violations_any_mod_name_without_path_ok() {
let source = r#"
#[cfg(test)]
mod tests;
fn main() {}
"#;
let violations = find_test_violations(source, Some("handler"), false, 100);
assert!(
violations.is_empty(),
"any mod name should be accepted without #[path]"
);
}
#[test]
fn test_find_violations_path_attr_wrong_value() {
let source = r#"
#[cfg(test)]
#[path = "dto_tests.rs"]
mod tests;
fn main() {}
"#;
let violations = find_test_violations(source, Some("handler"), false, 100);
assert_eq!(violations.len(), 1);
assert!(matches!(
&violations[0],
super::TestViolation::WrongPathAttr { expected, actual }
if expected == "handler_tests" && actual == "dto_tests"
));
}
#[test]
fn test_find_violations_path_attr_correct_value() {
let source = r#"
#[cfg(test)]
#[path = "handler_tests.rs"]
mod tests;
fn main() {}
"#;
let violations = find_test_violations(source, Some("handler"), false, 100);
assert!(violations.is_empty(), "expected no violations");
}
#[test]
fn test_find_violations_inline_code_over_threshold() {
let source = r#"
#[cfg(test)]
mod tests {
#[test]
fn foo() {}
}
fn main() {}
"#;
let violations = find_test_violations(source, Some("handler"), false, 3);
assert!(
violations
.iter()
.any(|v| matches!(v, super::TestViolation::InlineTestCode))
);
}
#[test]
fn test_find_violations_inline_code_under_threshold() {
let source = r#"
#[cfg(test)]
mod tests {
#[test]
fn foo() {}
}
fn main() {}
"#;
let violations = find_test_violations(source, Some("handler"), false, 100);
assert!(
violations.is_empty(),
"expected no violations under threshold"
);
}
#[test]
fn test_find_violations_inline_code_with_companion() {
let source = r#"
#[cfg(test)]
mod tests {
#[test]
fn foo() {}
}
fn main() {}
"#;
let violations = find_test_violations(source, Some("handler"), true, 100);
assert!(
violations
.iter()
.any(|v| matches!(v, super::TestViolation::InlineTestCodeWithCompanion))
);
}
#[test]
fn test_count_inline_test_block_lines() {
let source = "mod tests {\n #[test]\n fn foo() {}\n}\n";
let lines: Vec<&str> = source.lines().collect();
assert_eq!(count_inline_test_block_lines(&lines, 0), 4);
}
#[test]
fn test_count_inline_test_block_lines_nested() {
let source = "mod tests {\n fn foo() {\n if true {\n }\n }\n}\n";
let lines: Vec<&str> = source.lines().collect();
assert_eq!(count_inline_test_block_lines(&lines, 0), 6);
}
}