#![cfg(feature = "rust-lang")]
use pyrograph::analyze;
use pyrograph::parse::rust::parse_rust;
use std::path::{Path, PathBuf};
/// Runs a single source sample through the parser and analyzer.
///
/// Returns `(correct, finding_count)`, where `correct` tells whether the
/// presence of findings agrees with `expect_malware`.
///
/// A sample that fails to parse yields `(false, 0)`. An analyzer error is
/// treated the same as an empty set of findings.
fn test_rust_file(code: &str, name: &str, expect_malware: bool) -> (bool, usize) {
    let finding_count = match parse_rust(code, name) {
        Ok(graph) => analyze(&graph).unwrap_or_default().len(),
        Err(_) => return (false, 0),
    };
    // The verdict is right exactly when "something was flagged" matches the expectation.
    let flagged = finding_count != 0;
    (flagged == expect_malware, finding_count)
}
/// Recursively visits every `.rs` file under `dir`, runs it through
/// `test_rust_file`, and updates the confusion-matrix counters.
///
/// * `expect_malware` - whether the samples under `dir` are expected to be flagged.
/// * `tp` / `fn_count` - malware samples that were / were not flagged.
/// * `tn` / `fp` - benign samples that were not / were flagged.
/// * `parse_fail` - samples of either kind that could not be parsed; these are
///   kept out of the four classification counters.
///
/// Unreadable directories and files are skipped silently so that a partially
/// available corpus can still be evaluated.
fn walk_rust_files(
    dir: &Path,
    expect_malware: bool,
    tp: &mut u32,
    fn_count: &mut u32,
    tn: &mut u32,
    fp: &mut u32,
    parse_fail: &mut u32,
) {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            walk_rust_files(&path, expect_malware, tp, fn_count, tn, fp, parse_fail);
            continue;
        }
        if path.extension().map_or(true, |ext| ext != "rs") {
            continue;
        }
        let code = match std::fs::read_to_string(&path) {
            Ok(code) => code,
            Err(_) => continue,
        };
        // Always probe in "benign" mode, whatever corpus is being walked. In that
        // mode the helper's three outcomes are distinguishable:
        //   (true, 0)      parsed, nothing flagged
        //   (false, n > 0) parsed, n findings
        //   (false, 0)     the sample failed to parse
        // This replaces guessing at parse failures from the sample's text.
        let (no_findings, count) = test_rust_file(&code, &path.to_string_lossy(), false);
        if !no_findings && count == 0 {
            *parse_fail += 1;
            eprintln!("Parse failure: {}", path.display());
            continue;
        }
        let flagged = count > 0;
        match (expect_malware, flagged) {
            (true, true) => *tp += 1,
            (true, false) => {
                *fn_count += 1;
                eprintln!("FN (missed malware): {}", path.display());
            }
            (false, false) => *tn += 1,
            (false, true) => {
                *fp += 1;
                eprintln!("FP (false alarm): {}", path.display());
            }
        }
    }
}
fn corpus_root() -> PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../../corpus")
}
/// Walks the `malware` and `benign` corpus directories, prints a
/// confusion-matrix summary with recall and precision, and fails when the
/// engine does not catch more malware samples than it misses.
///
/// A missing or empty malware corpus is reported as such instead of being
/// blamed on the engine.
#[test]
fn validate_rust_corpus_samples() {
    let mut tp = 0u32;
    let mut fn_count = 0u32;
    let mut tn = 0u32;
    let mut fp = 0u32;
    let mut parse_fail = 0u32;
    let corpus_root = corpus_root();
    let malware_dir = corpus_root.join("malware");
    walk_rust_files(
        &malware_dir,
        true,
        &mut tp,
        &mut fn_count,
        &mut tn,
        &mut fp,
        &mut parse_fail,
    );
    // Snapshot before the benign walk so this counts malware samples only.
    let malware_seen = tp + fn_count + parse_fail;
    walk_rust_files(
        &corpus_root.join("benign"),
        false,
        &mut tp,
        &mut fn_count,
        &mut tn,
        &mut fp,
        &mut parse_fail,
    );
    let total = tp + fn_count + tn + fp;
    eprintln!("\n=== Rust Corpus Validation ===");
    eprintln!(
        "TP: {tp}, FN: {fn_count}, TN: {tn}, FP: {fp}, ParseFail: {parse_fail}, Total: {total}"
    );
    if tp + fn_count > 0 {
        let recall = f64::from(tp) / f64::from(tp + fn_count) * 100.0;
        eprintln!("Recall: {:.1}%", recall);
    }
    if tp + fp > 0 {
        let precision = f64::from(tp) / f64::from(tp + fp) * 100.0;
        eprintln!("Precision: {:.1}%", precision);
    }
    // Without any malware samples the comparison below would fail as 0 < 0 and
    // wrongly point at the engine; name the real cause instead.
    assert!(
        malware_seen > 0,
        "No Rust malware samples found under {} — corpus missing or empty",
        malware_dir.display()
    );
    assert!(
        fn_count < tp,
        "More false negatives ({fn_count}) than true positives ({tp}) — engine is broken"
    );
}