use crate::types::{AssertionStrength, AuditResult, TestId};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use syn::{Expr, ExprCall, ExprMacro, Item, ItemFn};
use walkdir::WalkDir;
/// Scans Rust test sources and grades the assertions used by each test function.
#[derive(Debug)]
pub struct AssertionAnalyzer {
    /// Root directories walked recursively by
    /// [`analyze_all_tests`](AssertionAnalyzer::analyze_all_tests).
    test_dirs: Vec<PathBuf>,
}
impl AssertionAnalyzer {
pub fn new(test_dirs: Vec<PathBuf>) -> Self {
Self { test_dirs }
}
pub fn analyze_test_file(&self, file_path: &Path) -> AuditResult<Vec<TestAssertion>> {
let content = std::fs::read_to_string(file_path)?;
let syntax_tree = syn::parse_file(&content).map_err(|e| {
crate::types::AuditError::AssertionParseError(format!(
"Failed to parse {}: {}",
file_path.display(),
e
))
})?;
let mut assertions = Vec::new();
for item in &syntax_tree.items {
if let Item::Fn(func) = item {
if self.is_test_function(func) {
let test_id = TestId::new(func.sig.ident.to_string())?;
let strength = self.analyze_function_assertions(func);
assertions.push(TestAssertion {
test_id,
file_path: file_path.to_path_buf(),
assertion_strength: strength,
assertion_count: self.count_assertions(func),
});
}
}
}
Ok(assertions)
}
fn is_test_function(&self, func: &ItemFn) -> bool {
func.attrs.iter().any(|attr| attr.path().is_ident("test"))
}
fn analyze_function_assertions(&self, func: &ItemFn) -> AssertionStrength {
let mut strongest = AssertionStrength::Weak;
for stmt in &func.block.stmts {
let expr = match stmt {
syn::Stmt::Expr(e, _) => Some(e),
_ => None,
};
if let Some(expr) = expr {
let strength = self.classify_expression(expr);
if strength > strongest {
strongest = strength;
}
}
}
strongest
}
fn classify_expression(&self, expr: &Expr) -> AssertionStrength {
match expr {
Expr::Macro(ExprMacro { mac, .. }) => self.classify_macro_assertion(mac),
Expr::Call(ExprCall { func, .. }) => self.classify_method_call_assertion(func),
_ => AssertionStrength::Weak,
}
}
fn classify_macro_assertion(&self, mac: &syn::Macro) -> AssertionStrength {
let path_str = mac
.path
.segments
.last()
.map(|seg| seg.ident.to_string())
.unwrap_or_default();
self.classify_assertion(&path_str)
}
fn classify_method_call_assertion(&self, _func: &Expr) -> AssertionStrength {
AssertionStrength::Weak
}
pub fn classify_assertion(&self, assertion_name: &str) -> AssertionStrength {
match assertion_name {
"assert_eq" | "assert_ne" => AssertionStrength::Strong,
"assert" | "is_err" => AssertionStrength::Medium,
"is_ok" | "is_some" | "is_none" => AssertionStrength::Weak,
_ => AssertionStrength::Weak,
}
}
#[must_use]
pub fn score_assertion(&self, strength: AssertionStrength) -> f64 {
strength.to_score()
}
fn count_assertions(&self, func: &ItemFn) -> usize {
let mut count = 0;
for stmt in &func.block.stmts {
let expr = match stmt {
syn::Stmt::Expr(e, _) => Some(e),
_ => None,
};
if let Some(expr) = expr {
if self.is_assertion_expr(expr) {
count += 1;
}
}
}
count
}
fn is_assertion_expr(&self, expr: &Expr) -> bool {
match expr {
Expr::Macro(ExprMacro { mac, .. }) => {
let path_str = mac
.path
.segments
.last()
.map(|seg| seg.ident.to_string())
.unwrap_or_default();
path_str.starts_with("assert") || path_str.starts_with("is_")
}
_ => false,
}
}
pub fn analyze_all_tests(&self) -> AuditResult<Vec<TestAssertion>> {
let mut all_assertions = Vec::new();
for test_dir in &self.test_dirs {
for entry in WalkDir::new(test_dir)
.into_iter()
.filter_map(Result::ok)
.filter(|e| e.path().extension().and_then(|s| s.to_str()) == Some("rs"))
{
let assertions = self.analyze_test_file(entry.path())?;
all_assertions.extend(assertions);
}
}
Ok(all_assertions)
}
}
/// Assertion statistics for a single test function, as produced by [`AssertionAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestAssertion {
    /// Identifier built from the test function's name.
    pub test_id: TestId,
    /// Source file in which the test function was found.
    pub file_path: PathBuf,
    /// Strongest assertion classification found in the test body.
    pub assertion_strength: AssertionStrength,
    /// Number of assertion-like macro invocations counted in the test body.
    pub assertion_count: usize,
}