use crate::finding::SeverityClass;
use std::path::Path;
use syn::spanned::Spanned;
/// Syntactic context in which a referenced source line was found.
///
/// Produced by `classify` / `classify_in_file` and mapped to a severity by
/// `refined_severity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReferenceContext {
/// The line is inside a function carrying a recognised test attribute, or
/// inside a function or `impl` block nested in a module recognised as a
/// test module.
TestFn,
/// The line is inside an `impl` block that is not in test code.
ImplBlock,
/// Default: none of the above matched. Also returned by `classify` when
/// the file cannot be read or parsed.
Caller,
}
impl ReferenceContext {
pub fn refined_severity(self) -> SeverityClass {
match self {
Self::TestFn => SeverityClass::Low,
Self::ImplBlock => SeverityClass::High,
Self::Caller => SeverityClass::Medium,
}
}
}
/// Classifies the context of line `line_1based` in the Rust file at `file_abs`.
///
/// The file is read from disk and parsed with `syn::parse_file`. This is a
/// best-effort lookup: if the file cannot be read or does not parse, the
/// function returns [`ReferenceContext::Caller`] instead of an error.
pub fn classify(file_abs: &Path, line_1based: u32) -> ReferenceContext {
    let parsed = std::fs::read_to_string(file_abs)
        .ok()
        .and_then(|text| syn::parse_file(&text).ok());
    match parsed {
        Some(file) => classify_in_file(&file, line_1based),
        // Unreadable or unparsable input falls back to the default context.
        None => ReferenceContext::Caller,
    }
}
/// Classifies `line_1based` against an already-parsed file.
///
/// Walks the file's top-level items with a `Walker` that starts outside any
/// test module with `Caller` as its answer, and returns whatever the walker
/// ends up with.
fn classify_in_file(ast: &syn::File, line_1based: u32) -> ReferenceContext {
    let mut state = Walker {
        best: ReferenceContext::Caller,
        in_test_mod: false,
        line: line_1based,
    };
    for top_level in &ast.items {
        state.visit_item(top_level);
    }
    state.best
}
/// Traversal state used while searching the item tree for the target line.
struct Walker {
/// Line number being classified, compared against span start/end lines.
line: u32,
/// True while visiting the contents of a module recognised by
/// `is_test_module` (or nested inside one).
in_test_mod: bool,
/// Classification found so far; `classify_in_file` initialises it to
/// `ReferenceContext::Caller`.
best: ReferenceContext,
}
impl Walker {
fn visit_items(&mut self, items: &[syn::Item]) {
for item in items {
self.visit_item(item);
}
}
fn visit_item(&mut self, item: &syn::Item) {
let span = item.span();
let start = span.start().line as u32;
let end = span.end().line as u32;
if self.line < start || self.line > end {
return;
}
match item {
syn::Item::Fn(f) if has_test_attr(&f.attrs) || self.in_test_mod => {
self.best = ReferenceContext::TestFn;
}
syn::Item::Fn(_) => {}
syn::Item::Impl(i) => {
if self.in_test_mod || matches!(self.best, ReferenceContext::TestFn) {
self.best = ReferenceContext::TestFn;
} else {
self.best = ReferenceContext::ImplBlock;
}
for ii in &i.items {
if let syn::ImplItem::Fn(f) = ii {
let span = f.span();
let start = span.start().line as u32;
let end = span.end().line as u32;
if self.line >= start && self.line <= end && has_test_attr(&f.attrs) {
self.best = ReferenceContext::TestFn;
}
}
}
}
syn::Item::Mod(m) => {
let was_test = self.in_test_mod;
if is_test_module(m) {
self.in_test_mod = true;
}
if let Some((_, inner)) = &m.content {
self.visit_items(inner);
}
self.in_test_mod = was_test;
}
_ => {}
}
}
}
/// Returns true when any attribute in `attrs` names a recognised test
/// harness.
///
/// Only the last path segment is inspected, so `#[test]` and `#[tokio::test]`
/// both match on `test`. Attributes with an empty path are skipped.
fn has_test_attr(attrs: &[syn::Attribute]) -> bool {
    const TEST_ATTR_NAMES: [&str; 5] = ["test", "tokio_test", "rstest", "proptest", "bench"];
    attrs
        .iter()
        .filter_map(|attr| attr.path().segments.last())
        .any(|segment| TEST_ATTR_NAMES.iter().any(|name| segment.ident == *name))
}
/// Returns true when `m` should be treated as test-only code.
///
/// A module qualifies when it is named `tests`, or when one of its `#[cfg]`
/// attributes can only hold in a test build: `#[cfg(test)]` itself, or an
/// `all(...)` predicate with such a predicate among its operands (for
/// example `#[cfg(all(test, feature = "slow"))]`).
///
/// `any(...)` and `not(...)` are deliberately not treated as test-only,
/// because they may also hold in non-test builds. A `cfg` attribute whose
/// arguments fail to parse is ignored.
fn is_test_module(m: &syn::ItemMod) -> bool {
    // True when the cfg predicate implies `test`.
    fn implies_test(pred: &syn::Meta) -> bool {
        match pred {
            syn::Meta::Path(path) => path.is_ident("test"),
            syn::Meta::List(list) if list.path.is_ident("all") => list
                .parse_args_with(
                    syn::punctuated::Punctuated::<syn::Meta, syn::token::Comma>::parse_terminated,
                )
                .map(|operands| operands.iter().any(implies_test))
                .unwrap_or(false),
            _ => false,
        }
    }

    if m.ident == "tests" {
        return true;
    }
    m.attrs.iter().any(|attr| {
        attr.path().is_ident("cfg")
            && attr
                .parse_args::<syn::Meta>()
                .map(|pred| implies_test(&pred))
                .unwrap_or(false)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as a Rust file and classifies the given line of it.
    fn context_at(src: &str, line: u32) -> ReferenceContext {
        let file: syn::File = syn::parse_str(src).expect("parse");
        classify_in_file(&file, line)
    }

    #[test]
    fn plain_top_level_fn_is_caller() {
        let got = context_at("fn hello() {\n let x = 1;\n foo(x);\n}\n", 3);
        assert_eq!(got, ReferenceContext::Caller);
    }

    #[test]
    fn test_attribute_on_fn_flags_test_fn() {
        let got = context_at("#[test]\nfn t_smoke() {\n foo();\n}\n", 3);
        assert_eq!(got, ReferenceContext::TestFn);
    }

    #[test]
    fn tokio_test_attribute_flags_test_fn() {
        let got = context_at("#[tokio::test]\nasync fn t_async() {\n foo();\n}\n", 3);
        assert_eq!(got, ReferenceContext::TestFn);
    }

    #[test]
    fn rstest_attribute_flags_test_fn() {
        let got = context_at("#[rstest]\nfn t_param() {\n foo();\n}\n", 3);
        assert_eq!(got, ReferenceContext::TestFn);
    }

    #[test]
    fn fn_inside_tests_module_is_test_fn() {
        let got = context_at("mod tests {\n fn helper() {\n foo();\n }\n}\n", 3);
        assert_eq!(got, ReferenceContext::TestFn);
    }

    #[test]
    fn fn_inside_cfg_test_module_is_test_fn() {
        let got = context_at(
            "#[cfg(test)]\nmod inner {\n fn helper() {\n foo();\n }\n}\n",
            4,
        );
        assert_eq!(got, ReferenceContext::TestFn);
    }

    #[test]
    fn impl_block_method_is_impl_block() {
        let got = context_at(
            "struct W;\nimpl W {\n fn hi(&self) {\n foo();\n }\n}\n",
            4,
        );
        assert_eq!(got, ReferenceContext::ImplBlock);
    }

    #[test]
    fn trait_impl_method_is_impl_block() {
        let got = context_at(
            "trait T { fn hi(&self); }\nstruct W;\nimpl T for W {\n fn hi(&self) {\n foo();\n }\n}\n",
            5,
        );
        assert_eq!(got, ReferenceContext::ImplBlock);
    }

    #[test]
    fn test_fn_inside_impl_beats_impl_block() {
        let got = context_at(
            "struct W;\nimpl W {\n #[test]\n fn t_inner() {\n foo();\n }\n}\n",
            5,
        );
        assert_eq!(got, ReferenceContext::TestFn);
    }

    #[test]
    fn nested_cfg_test_mod_propagates_into_inner_impl() {
        let got = context_at(
            "#[cfg(test)]\nmod tests {\n struct W;\n impl W {\n fn helper(&self) {\n foo();\n }\n }\n}\n",
            6,
        );
        assert_eq!(got, ReferenceContext::TestFn);
    }

    #[test]
    fn line_outside_any_container_stays_caller() {
        let got = context_at("fn a() {}\nfn b() {}\n", 100);
        assert_eq!(got, ReferenceContext::Caller);
    }

    #[test]
    fn refined_severity_mapping() {
        let expected = [
            (ReferenceContext::TestFn, SeverityClass::Low),
            (ReferenceContext::ImplBlock, SeverityClass::High),
            (ReferenceContext::Caller, SeverityClass::Medium),
        ];
        for (context, severity) in expected {
            assert_eq!(context.refined_severity(), severity);
        }
    }

    #[test]
    fn classify_on_missing_file_returns_caller() {
        let missing = Path::new("/definitely/does/not/exist-xyz.rs");
        assert_eq!(classify(missing, 1), ReferenceContext::Caller);
    }
}