use std::path::PathBuf;
use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{BinOp, Expr, ExprBinary, ExprMethodCall, File};
use crate::models::{ComparisonOp, ComparisonSide};
/// One equality/inequality comparison discovered while walking a syntax tree.
#[derive(Clone, Debug)]
pub struct FoundComparison {
    /// Left-hand operand (the receiver, for method-call comparisons).
    pub left: ComparisonSide,
    /// Right-hand operand (the single argument, for method-call comparisons).
    pub right: ComparisonSide,
    /// Which comparison was performed.
    pub operator: ComparisonOp,
    /// Line where the comparison was reported, taken from the span's
    /// `start().line` (1-based per `proc_macro2::LineColumn`).
    pub line: u32,
    /// Column where the comparison was reported, taken from the span's
    /// `start().column` (0-based per `proc_macro2::LineColumn`).
    pub column: u32,
    /// Name of the enclosing function tracked by the visitor, if any.
    pub in_function: Option<String>,
    /// Whether the comparison sits in code the visitor classified as test code.
    pub in_test: bool,
}
/// Syntax-tree visitor that collects equality/inequality comparisons from a
/// parsed Rust file, remembering the function and test context of each one.
pub struct ComparisonFinder {
    /// File the visited syntax tree was parsed from.
    file_path: PathBuf,
    /// Name of the function currently being walked, if any.
    current_function: Option<String>,
    /// Whether the visitor is currently inside code classified as test code.
    in_test: bool,
    /// Findings accumulated so far, in visitation order.
    comparisons: Vec<FoundComparison>,
}
impl ComparisonFinder {
    /// Creates a finder for `file_path` with no findings and no function or
    /// test context. Drive it with [`Visit::visit_file`] (or call
    /// [`ComparisonFinder::find`]) to collect comparisons.
    pub fn new(file_path: PathBuf) -> Self {
        Self {
            comparisons: Vec::new(),
            current_function: None,
            in_test: false,
            file_path,
        }
    }

    /// Walks `file` and returns every comparison found, in visitation order.
    pub fn find(file_path: PathBuf, file: &File) -> Vec<FoundComparison> {
        let mut finder = Self::new(file_path);
        finder.visit_file(file);
        finder.into_comparisons()
    }

    /// Path of the file this finder was created for.
    pub fn file_path(&self) -> &std::path::Path {
        &self.file_path
    }

    /// Classifies one operand of a comparison into a [`ComparisonSide`].
    ///
    /// References (`&x`), parentheses and invisible groups are peeled off
    /// recursively, so `&(TOKEN)` is classified the same way as `TOKEN`.
    /// Expression shapes without a dedicated variant are kept as their token
    /// text in [`ComparisonSide::Other`].
    fn extract_side(expr: &Expr) -> ComparisonSide {
        match expr {
            Expr::Path(path) => {
                // Only the last segment is kept: `crate::auth::TOKEN` -> `TOKEN`.
                let name = path
                    .path
                    .segments
                    .last()
                    .map(|s| s.ident.to_string())
                    .unwrap_or_default();
                if Self::is_constant_name(&name) {
                    ComparisonSide::ConstantRef {
                        name,
                        resolved_value: None,
                    }
                } else {
                    ComparisonSide::Variable { name, source: None }
                }
            }
            Expr::Lit(lit) => match &lit.lit {
                syn::Lit::Str(s) => ComparisonSide::StringLiteral(s.value()),
                _ => ComparisonSide::Other(Self::tokens_to_string(expr)),
            },
            Expr::Field(field) => {
                let field_name = match &field.member {
                    syn::Member::Named(ident) => ident.to_string(),
                    syn::Member::Unnamed(index) => index.index.to_string(),
                };
                ComparisonSide::FieldAccess {
                    // `quote!(#field.base)` would interpolate the whole field
                    // expression followed by a literal `. base`; render the
                    // base sub-expression itself instead.
                    base: Self::tokens_to_string(&field.base),
                    field: field_name,
                }
            }
            Expr::MethodCall(call) => ComparisonSide::MethodCall {
                receiver: Self::tokens_to_string(&call.receiver),
                method: call.method.to_string(),
                args: call.args.iter().map(Self::tokens_to_string).collect(),
            },
            Expr::Call(call) => ComparisonSide::FunctionCall {
                path: Self::tokens_to_string(&call.func),
                args: call.args.iter().map(Self::tokens_to_string).collect(),
            },
            Expr::Reference(r) => Self::extract_side(&r.expr),
            Expr::Paren(p) => Self::extract_side(&p.expr),
            Expr::Group(g) => Self::extract_side(&g.expr),
            _ => ComparisonSide::Other(Self::tokens_to_string(expr)),
        }
    }

    /// Heuristic for `SCREAMING_SNAKE_CASE` identifiers (the Rust naming
    /// convention for `const`/`static` items): at least one uppercase letter,
    /// and otherwise only uppercase letters, ASCII digits and underscores, so
    /// names such as `API_KEY_V2` are recognised as constants.
    fn is_constant_name(name: &str) -> bool {
        name.chars().any(char::is_uppercase)
            && name
                .chars()
                .all(|c| c.is_uppercase() || c.is_ascii_digit() || c == '_')
    }

    /// Renders a syntax node back to text the way `quote` prints tokens
    /// (tokens separated by spaces, e.g. `self . token`).
    fn tokens_to_string(node: &impl quote::ToTokens) -> String {
        quote::quote!(#node).to_string()
    }

    /// Comparisons collected so far.
    pub fn comparisons(&self) -> &[FoundComparison] {
        &self.comparisons
    }

    /// Consumes the finder and returns its collected comparisons.
    pub fn into_comparisons(self) -> Vec<FoundComparison> {
        self.comparisons
    }
}
impl<'ast> Visit<'ast> for ComparisonFinder {
fn visit_expr_binary(&mut self, node: &'ast ExprBinary) {
let operator = match &node.op {
BinOp::Eq(_) => ComparisonOp::Eq,
BinOp::Ne(_) => ComparisonOp::Ne,
_ => {
syn::visit::visit_expr_binary(self, node);
return;
}
};
let left = Self::extract_side(&node.left);
let right = Self::extract_side(&node.right);
let span = match &node.op {
BinOp::Eq(token) => token.span(),
BinOp::Ne(token) => token.spans[0],
_ => node.span(),
};
self.comparisons.push(FoundComparison {
left,
right,
operator,
line: span.start().line as u32,
column: span.start().column as u32,
in_function: self.current_function.clone(),
in_test: self.in_test,
});
syn::visit::visit_expr_binary(self, node);
}
fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
if node.method == "eq" && node.args.len() == 1 {
let left = Self::extract_side(&node.receiver);
let right = Self::extract_side(&node.args.first().unwrap());
self.comparisons.push(FoundComparison {
left,
right,
operator: ComparisonOp::Eq,
line: node.method.span().start().line as u32,
column: node.method.span().start().column as u32,
in_function: self.current_function.clone(),
in_test: self.in_test,
});
}
syn::visit::visit_expr_method_call(self, node);
}
fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
let old_function = self.current_function.take();
let old_test = self.in_test;
self.current_function = Some(node.sig.ident.to_string());
let is_test = node.attrs.iter().any(|attr| {
attr.path().is_ident("test")
|| attr
.path()
.segments
.last()
.map(|s| s.ident == "test")
.unwrap_or(false)
});
if is_test {
self.in_test = true;
}
syn::visit::visit_item_fn(self, node);
self.current_function = old_function;
self.in_test = old_test;
}
fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) {
let old_test = self.in_test;
let is_test_mod = node.ident == "tests"
|| node.ident == "test"
|| node.attrs.iter().any(|attr| {
if attr.path().is_ident("cfg") {
if let Ok(meta) = attr.meta.require_list() {
return meta.tokens.to_string().contains("test");
}
}
false
});
if is_test_mod {
self.in_test = true;
}
syn::visit::visit_item_mod(self, node);
self.in_test = old_test;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as a Rust source file and runs the finder over it.
    fn find_in(code: &str) -> Vec<FoundComparison> {
        let file: syn::File = syn::parse_str(code).expect("test fixture must be valid Rust");
        ComparisonFinder::find(PathBuf::from("test.rs"), &file)
    }

    #[test]
    fn test_find_eq_comparison() {
        let comparisons = find_in(
            r#"
            fn check(input: &str) -> bool {
                input == "secret"
            }
            "#,
        );
        assert_eq!(comparisons.len(), 1);
        assert!(matches!(comparisons[0].operator, ComparisonOp::Eq));
        assert!(matches!(
            &comparisons[0].left,
            ComparisonSide::Variable { name, .. } if name == "input"
        ));
        assert!(matches!(
            &comparisons[0].right,
            ComparisonSide::StringLiteral(s) if s == "secret"
        ));
    }

    #[test]
    fn test_find_ne_comparison() {
        let comparisons = find_in(
            r#"
            fn check(input: &str) -> bool {
                input != "invalid"
            }
            "#,
        );
        assert_eq!(comparisons.len(), 1);
        assert!(matches!(comparisons[0].operator, ComparisonOp::Ne));
    }

    #[test]
    fn test_find_constant_comparison() {
        let comparisons = find_in(
            r#"
            const TOKEN: &str = "abc";
            fn check(input: &str) -> bool {
                input == TOKEN
            }
            "#,
        );
        assert_eq!(comparisons.len(), 1);
        assert!(matches!(
            &comparisons[0].right,
            ComparisonSide::ConstantRef { name, .. } if name == "TOKEN"
        ));
    }

    #[test]
    fn test_find_eq_method() {
        let comparisons = find_in(
            r#"
            fn check(input: &str) -> bool {
                input.eq("secret")
            }
            "#,
        );
        assert_eq!(comparisons.len(), 1);
        assert!(matches!(comparisons[0].operator, ComparisonOp::Eq));
    }

    #[test]
    fn test_test_context() {
        let comparisons = find_in(
            r#"
            #[test]
            fn test_something() {
                assert_eq!(x, "value");
                if x == "test" {}
            }
            "#,
        );
        // Macro bodies are opaque to the visitor, so only the `if` condition
        // is found. Pin the count so the `all` check below cannot pass
        // vacuously on an empty result.
        assert_eq!(comparisons.len(), 1);
        assert!(comparisons.iter().all(|c| c.in_test));
    }

    #[test]
    fn test_context_restored_after_test_fn() {
        let comparisons = find_in(
            r#"
            #[test]
            fn test_first() {
                if a == "x" {}
            }
            fn second(b: &str) -> bool {
                b == "y"
            }
            "#,
        );
        assert_eq!(comparisons.len(), 2);
        assert!(comparisons[0].in_test);
        assert!(!comparisons[1].in_test);
        assert_eq!(comparisons[1].in_function, Some("second".to_string()));
    }

    #[test]
    fn test_cfg_test_module_context() {
        let comparisons = find_in(
            r#"
            fn production(a: &str) -> bool {
                a == "x"
            }
            #[cfg(test)]
            mod checks {
                fn helper(b: &str) -> bool {
                    b == "y"
                }
            }
            "#,
        );
        assert_eq!(comparisons.len(), 2);
        assert!(!comparisons[0].in_test);
        assert!(comparisons[1].in_test);
    }

    #[test]
    fn test_function_context() {
        let comparisons = find_in(
            r#"
            fn authenticate(token: &str) -> bool {
                token == SECRET
            }
            "#,
        );
        assert_eq!(comparisons.len(), 1);
        assert_eq!(
            comparisons[0].in_function,
            Some("authenticate".to_string())
        );
        assert!(!comparisons[0].in_test);
    }
}