vsec 0.0.1

Detect secrets in Rust codebases
Documentation
// src/analysis/comparison_finder.rs

use std::path::PathBuf;

use syn::spanned::Spanned;
use syn::visit::Visit;
use syn::{BinOp, Expr, ExprBinary, ExprMethodCall, File};

use crate::models::{ComparisonOp, ComparisonSide};

/// A single equality (`==` / `.eq()`) or inequality (`!=`) comparison found
/// while walking a parsed Rust file.
///
/// Both operands are kept in classified form ([`ComparisonSide`]) so later
/// passes can decide whether one side is a hard-coded secret.
#[derive(Debug, Clone)]
pub struct FoundComparison {
    /// Classified left-hand operand (the receiver, for method-call form)
    pub left: ComparisonSide,

    /// Classified right-hand operand (the single argument, for method-call form)
    pub right: ComparisonSide,

    /// Which comparison operator was used
    pub operator: ComparisonOp,

    /// Line of the operator token, taken from its span start
    /// (presumably 1-based, per proc-macro2 `LineColumn` — confirm)
    pub line: u32,

    /// Column of the operator token, taken from its span start
    /// (presumably 0-based, per proc-macro2 `LineColumn` — confirm)
    pub column: u32,

    /// Name of the innermost enclosing function tracked by the finder, if any
    pub in_function: Option<String>,

    /// Whether the finder classified the surrounding code as test code
    pub in_test: bool,
}

/// AST visitor that collects equality/inequality comparisons from one Rust file.
///
/// It tracks the enclosing function name and a "test context" flag while
/// walking, so each recorded comparison carries that context.
pub struct ComparisonFinder {
    /// Comparisons collected so far, in visit order
    comparisons: Vec<FoundComparison>,

    /// Name of the function currently being walked, if any
    current_function: Option<String>,

    /// Whether the code currently being walked is considered test code
    in_test: bool,

    /// Path of the file being analyzed.
    /// NOTE(review): stored at construction but not read anywhere in this file.
    file_path: PathBuf,
}

impl ComparisonFinder {
    pub fn new(file_path: PathBuf) -> Self {
        Self {
            comparisons: Vec::new(),
            current_function: None,
            in_test: false,
            file_path,
        }
    }

    /// Find all comparisons in a file
    pub fn find(file_path: PathBuf, file: &File) -> Vec<FoundComparison> {
        let mut finder = Self::new(file_path);
        finder.visit_file(file);
        finder.comparisons
    }

    /// Extract comparison side from an expression
    fn extract_side(expr: &Expr) -> ComparisonSide {
        match expr {
            Expr::Path(path) => {
                let name = path
                    .path
                    .segments
                    .last()
                    .map(|s| s.ident.to_string())
                    .unwrap_or_default();

                // Check if it looks like a constant (all uppercase)
                if name.chars().all(|c| c.is_uppercase() || c == '_') && !name.is_empty() {
                    ComparisonSide::ConstantRef {
                        name,
                        resolved_value: None,
                    }
                } else {
                    ComparisonSide::Variable { name, source: None }
                }
            }
            Expr::Lit(lit) => {
                if let syn::Lit::Str(s) = &lit.lit {
                    ComparisonSide::StringLiteral(s.value())
                } else {
                    ComparisonSide::Other(quote::quote!(#expr).to_string())
                }
            }
            Expr::Field(field) => {
                let base = quote::quote!(#field.base).to_string();
                let field_name = match &field.member {
                    syn::Member::Named(ident) => ident.to_string(),
                    syn::Member::Unnamed(index) => index.index.to_string(),
                };
                ComparisonSide::FieldAccess {
                    base,
                    field: field_name,
                }
            }
            Expr::MethodCall(call) => {
                let receiver = quote::quote!(#call.receiver).to_string();
                let method = call.method.to_string();
                let args: Vec<String> = call
                    .args
                    .iter()
                    .map(|a| quote::quote!(#a).to_string())
                    .collect();
                ComparisonSide::MethodCall {
                    receiver,
                    method,
                    args,
                }
            }
            Expr::Call(call) => {
                let path = quote::quote!(#call.func).to_string();
                let args: Vec<String> = call
                    .args
                    .iter()
                    .map(|a| quote::quote!(#a).to_string())
                    .collect();
                ComparisonSide::FunctionCall { path, args }
            }
            Expr::Reference(r) => Self::extract_side(&r.expr),
            Expr::Paren(p) => Self::extract_side(&p.expr),
            Expr::Group(g) => Self::extract_side(&g.expr),
            _ => ComparisonSide::Other(quote::quote!(#expr).to_string()),
        }
    }

    /// Get the collected comparisons
    pub fn comparisons(&self) -> &[FoundComparison] {
        &self.comparisons
    }

    /// Take ownership of results
    pub fn into_comparisons(self) -> Vec<FoundComparison> {
        self.comparisons
    }
}

impl<'ast> Visit<'ast> for ComparisonFinder {
    fn visit_expr_binary(&mut self, node: &'ast ExprBinary) {
        // Only care about equality comparisons
        let operator = match &node.op {
            BinOp::Eq(_) => ComparisonOp::Eq,
            BinOp::Ne(_) => ComparisonOp::Ne,
            _ => {
                syn::visit::visit_expr_binary(self, node);
                return;
            }
        };

        let left = Self::extract_side(&node.left);
        let right = Self::extract_side(&node.right);

        // Get span from the operator
        let span = match &node.op {
            BinOp::Eq(token) => token.span(),
            BinOp::Ne(token) => token.spans[0],
            _ => node.span(),
        };

        self.comparisons.push(FoundComparison {
            left,
            right,
            operator,
            line: span.start().line as u32,
            column: span.start().column as u32,
            in_function: self.current_function.clone(),
            in_test: self.in_test,
        });

        syn::visit::visit_expr_binary(self, node);
    }

    fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
        // Handle .eq() method calls
        if node.method == "eq" && node.args.len() == 1 {
            let left = Self::extract_side(&node.receiver);
            let right = Self::extract_side(&node.args.first().unwrap());

            self.comparisons.push(FoundComparison {
                left,
                right,
                operator: ComparisonOp::Eq,
                line: node.method.span().start().line as u32,
                column: node.method.span().start().column as u32,
                in_function: self.current_function.clone(),
                in_test: self.in_test,
            });
        }

        syn::visit::visit_expr_method_call(self, node);
    }

    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
        let old_function = self.current_function.take();
        let old_test = self.in_test;

        self.current_function = Some(node.sig.ident.to_string());

        // Check if this is a test function
        let is_test = node.attrs.iter().any(|attr| {
            attr.path().is_ident("test")
                || attr
                    .path()
                    .segments
                    .last()
                    .map(|s| s.ident == "test")
                    .unwrap_or(false)
        });

        if is_test {
            self.in_test = true;
        }

        syn::visit::visit_item_fn(self, node);

        self.current_function = old_function;
        self.in_test = old_test;
    }

    fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) {
        let old_test = self.in_test;

        // Check if this is a test module
        let is_test_mod = node.ident == "tests"
            || node.ident == "test"
            || node.attrs.iter().any(|attr| {
                if attr.path().is_ident("cfg") {
                    if let Ok(meta) = attr.meta.require_list() {
                        return meta.tokens.to_string().contains("test");
                    }
                }
                false
            });

        if is_test_mod {
            self.in_test = true;
        }

        syn::visit::visit_item_mod(self, node);

        self.in_test = old_test;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as a Rust file and runs the finder over it.
    fn find_in(code: &str) -> Vec<FoundComparison> {
        let file: syn::File = syn::parse_str(code).expect("test snippet must parse");
        ComparisonFinder::find(PathBuf::from("test.rs"), &file)
    }

    #[test]
    fn test_find_eq_comparison() {
        let comparisons = find_in(
            r#"
            fn check(input: &str) -> bool {
                input == "secret"
            }
        "#,
        );

        assert_eq!(comparisons.len(), 1);
        assert!(matches!(comparisons[0].operator, ComparisonOp::Eq));
        assert!(matches!(
            &comparisons[0].left,
            ComparisonSide::Variable { name, .. } if name == "input"
        ));
        assert!(matches!(
            &comparisons[0].right,
            ComparisonSide::StringLiteral(s) if s == "secret"
        ));
        assert!(!comparisons[0].in_test);
    }

    #[test]
    fn test_find_ne_comparison() {
        let comparisons = find_in(
            r#"
            fn check(input: &str) -> bool {
                input != "invalid"
            }
        "#,
        );

        assert_eq!(comparisons.len(), 1);
        assert!(matches!(comparisons[0].operator, ComparisonOp::Ne));
        assert!(matches!(
            &comparisons[0].right,
            ComparisonSide::StringLiteral(s) if s == "invalid"
        ));
    }

    #[test]
    fn test_find_constant_comparison() {
        let comparisons = find_in(
            r#"
            const TOKEN: &str = "abc";
            fn check(input: &str) -> bool {
                input == TOKEN
            }
        "#,
        );

        assert_eq!(comparisons.len(), 1);
        assert!(matches!(
            &comparisons[0].left,
            ComparisonSide::Variable { name, .. } if name == "input"
        ));
        assert!(matches!(
            &comparisons[0].right,
            ComparisonSide::ConstantRef { name, .. } if name == "TOKEN"
        ));
    }

    #[test]
    fn test_find_eq_method() {
        let comparisons = find_in(
            r#"
            fn check(input: &str) -> bool {
                input.eq("secret")
            }
        "#,
        );

        assert_eq!(comparisons.len(), 1);
        assert!(matches!(comparisons[0].operator, ComparisonOp::Eq));
        assert!(matches!(
            &comparisons[0].left,
            ComparisonSide::Variable { name, .. } if name == "input"
        ));
        assert!(matches!(
            &comparisons[0].right,
            ComparisonSide::StringLiteral(s) if s == "secret"
        ));
    }

    #[test]
    fn test_non_equality_operators_ignored() {
        let comparisons = find_in(
            r#"
            fn check(a: i32, b: i32) -> bool {
                a == b && a < b
            }
        "#,
        );

        // Only the `==` is recorded; `&&` and `<` are traversed but ignored.
        assert_eq!(comparisons.len(), 1);
        assert!(matches!(comparisons[0].operator, ComparisonOp::Eq));
    }

    #[test]
    fn test_nested_comparisons_found() {
        let comparisons = find_in(
            r#"
            fn check(a: i32, b: i32, c: bool) -> bool {
                (a == b) != c
            }
        "#,
        );

        assert_eq!(comparisons.len(), 2);
    }

    #[test]
    fn test_test_context() {
        let comparisons = find_in(
            r#"
            #[test]
            fn test_something() {
                assert_eq!(x, "value");
                if x == "test" {}
            }
        "#,
        );

        // Guard against a vacuous pass: `all` is true on an empty iterator.
        assert_eq!(comparisons.len(), 1);
        assert!(comparisons.iter().all(|c| c.in_test));
    }

    #[test]
    fn test_test_context_restored_after_test_fn() {
        let comparisons = find_in(
            r#"
            #[test]
            fn test_something() {
                if a == "x" {}
            }
            fn real(a: &str) -> bool {
                a == "y"
            }
        "#,
        );

        assert_eq!(comparisons.len(), 2);
        assert!(comparisons[0].in_test);
        assert!(!comparisons[1].in_test);
    }

    #[test]
    fn test_cfg_test_module_context() {
        let comparisons = find_in(
            r#"
            #[cfg(test)]
            mod checks {
                fn helper(a: &str) -> bool {
                    a == "fixture"
                }
            }
        "#,
        );

        assert_eq!(comparisons.len(), 1);
        assert!(comparisons[0].in_test);
    }

    #[test]
    fn test_function_context() {
        let comparisons = find_in(
            r#"
            fn authenticate(token: &str) -> bool {
                token == SECRET
            }
        "#,
        );

        assert_eq!(comparisons.len(), 1);
        assert_eq!(
            comparisons[0].in_function,
            Some("authenticate".to_string())
        );
        assert!(matches!(
            &comparisons[0].right,
            ComparisonSide::ConstantRef { name, .. } if name == "SECRET"
        ));
    }
}