use crate::{
metadata::{
cilassemblyview::CilAssemblyView,
validation::{
OwnedValidationContext, RawValidationContext, ReferenceScanner, ValidationConfig,
},
},
Error, Result,
};
use rayon::ThreadPoolBuilder;
use std::path::{Path, PathBuf};
use tempfile::NamedTempFile;
/// Where a test assembly's bytes come from.
#[derive(Debug)]
pub enum TestAssemblySource {
    /// An assembly stored on disk.
    File {
        /// Location of the assembly file.
        path: PathBuf,
        /// Owns the temporary file that backs `path` when the case was built
        /// from a `NamedTempFile` (see `TestAssembly::from_temp_file`).
        /// Holding it here keeps the temp file from being dropped (which, per
        /// `tempfile`'s semantics, deletes it) while the case is still in use.
        /// `None` for ordinary on-disk paths.
        _temp_file: Option<NamedTempFile>,
    },
    /// An assembly image held entirely in memory.
    Memory(Vec<u8>),
}
/// A single validator test case: an assembly plus its expected outcome.
#[derive(Debug)]
pub struct TestAssembly {
    /// Where the assembly is loaded from.
    pub source: TestAssemblySource,
    /// `true` if the validator under test is expected to accept this
    /// assembly; `false` if it is expected to report an error.
    pub should_pass: bool,
    /// For negative cases, an optional substring that must appear in the
    /// (Debug-formatted) error produced by the run.
    pub expected_error_pattern: Option<String>,
}
impl TestAssembly {
    /// Creates a file-backed test case for the assembly at `path`.
    ///
    /// `should_pass` states whether the validator is expected to accept it.
    /// No error pattern is attached.
    pub fn new<P: Into<PathBuf>>(path: P, should_pass: bool) -> Self {
        let source = TestAssemblySource::File {
            path: path.into(),
            _temp_file: None,
        };
        Self {
            source,
            should_pass,
            expected_error_pattern: None,
        }
    }

    /// Creates a file-backed negative case whose error must contain `error_pattern`.
    pub fn failing_with_error<P: Into<PathBuf>>(path: P, error_pattern: &str) -> Self {
        Self::new(path, false).expecting_error(error_pattern)
    }

    /// Creates a case backed by a temporary file.
    ///
    /// The `NamedTempFile` is stored alongside its path so the file stays on
    /// disk for as long as this case lives.
    pub fn from_temp_file(temp_file: NamedTempFile, should_pass: bool) -> Self {
        let path = temp_file.path().to_path_buf();
        let source = TestAssemblySource::File {
            path,
            _temp_file: Some(temp_file),
        };
        Self {
            source,
            should_pass,
            expected_error_pattern: None,
        }
    }

    /// Temporary-file variant of [`TestAssembly::failing_with_error`].
    pub fn from_temp_file_with_error(temp_file: NamedTempFile, error_pattern: &str) -> Self {
        Self::from_temp_file(temp_file, false).expecting_error(error_pattern)
    }

    /// Creates a case whose assembly image lives in memory.
    pub fn from_bytes(data: Vec<u8>, should_pass: bool) -> Self {
        Self {
            source: TestAssemblySource::Memory(data),
            should_pass,
            expected_error_pattern: None,
        }
    }

    /// In-memory variant of [`TestAssembly::failing_with_error`].
    pub fn from_bytes_with_error(data: Vec<u8>, error_pattern: &str) -> Self {
        Self::from_bytes(data, false).expecting_error(error_pattern)
    }

    /// Attaches the substring a negative case's error message must contain.
    fn expecting_error(mut self, error_pattern: &str) -> Self {
        self.expected_error_pattern = Some(error_pattern.to_string());
        self
    }

    /// The on-disk location, or `None` for in-memory cases.
    pub fn path(&self) -> Option<&Path> {
        if let TestAssemblySource::File { path, .. } = &self.source {
            Some(path.as_path())
        } else {
            None
        }
    }

    /// The raw image, or `None` for file-backed cases.
    pub fn bytes(&self) -> Option<&[u8]> {
        if let TestAssemblySource::Memory(data) = &self.source {
            Some(data.as_slice())
        } else {
            None
        }
    }

    /// `true` when the case holds its assembly in memory rather than on disk.
    pub fn is_memory_backed(&self) -> bool {
        self.bytes().is_some()
    }

    /// Human-readable label used in failure messages.
    pub fn description(&self) -> String {
        match &self.source {
            TestAssemblySource::Memory(data) => format!("<in-memory {} bytes>", data.len()),
            TestAssemblySource::File { path, .. } => path.display().to_string(),
        }
    }

    /// Loads the assembly with validation disabled, so the validator under
    /// test is the only thing that judges it.
    ///
    /// In-memory images are copied, because the loader takes ownership of its buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error from `CilAssemblyView` loading.
    pub fn load(&self) -> Result<CilAssemblyView> {
        match &self.source {
            TestAssemblySource::Memory(data) => {
                let owned_image = data.clone();
                CilAssemblyView::from_mem_with_validation(owned_image, ValidationConfig::disabled())
            }
            TestAssemblySource::File { path, .. } => {
                CilAssemblyView::from_path_with_validation(path, ValidationConfig::disabled())
            }
        }
    }
}
/// The recorded outcome of running a validator against one [`TestAssembly`].
#[derive(Debug)]
pub struct ValidationTestResult {
    /// The test case that was run.
    pub assembly: TestAssembly,
    /// `true` if the run (loading the assembly and invoking the validator) returned `Ok(())`.
    pub validation_succeeded: bool,
    /// Debug rendering of the error when the run failed; `None` on success.
    pub error_message: Option<String>,
    /// `true` if the outcome matched the case's expectations: `should_pass`,
    /// and for negative cases the optional `expected_error_pattern`.
    pub test_passed: bool,
}
/// Produces the list of test assemblies that a validator test iterates over.
///
/// [`validator_test`] and [`owned_validator_test`] return an error if the list is empty.
pub type FileFactory = fn() -> Result<Vec<TestAssembly>>;
fn file_verify(
results: &[ValidationTestResult],
validator_name: &str,
expected_error_type: &str,
) -> Result<()> {
if results.is_empty() {
return Err(Error::Other(
"No test assemblies were processed".to_string(),
));
}
let mut positive_tests = 0;
let mut negative_tests = 0;
for result in results {
if result.assembly.should_pass {
positive_tests += 1;
if !result.test_passed {
return Err(Error::Other(format!(
"Positive test failed for {}: validation should have passed but got error: {:?}",
result.assembly.description(),
result.error_message
)));
}
if !result.validation_succeeded {
return Err(Error::Other(format!(
"Clean assembly {} failed {} validation unexpectedly",
result.assembly.description(),
validator_name
)));
}
} else {
negative_tests += 1;
if !result.test_passed {
return Err(Error::Other(format!(
"Negative test failed for {}: expected validation failure with pattern '{:?}' but got: validation_succeeded={}, error={:?}",
result.assembly.description(),
result.assembly.expected_error_pattern,
result.validation_succeeded,
result.error_message
)));
}
if result.validation_succeeded {
return Err(Error::Other(format!(
"Modified assembly {} passed validation but should have failed",
result.assembly.description()
)));
}
if let Some(expected_pattern) = &result.assembly.expected_error_pattern {
if let Some(error_msg) = &result.error_message {
if !error_msg.contains(expected_pattern) {
return Err(Error::Other(format!(
"Error message '{error_msg}' does not contain expected pattern '{expected_pattern}'"
)));
}
if !expected_error_type.is_empty() && !error_msg.contains(expected_error_type) {
return Err(Error::Other(format!(
"Expected {expected_error_type} but got: {error_msg}"
)));
}
}
}
}
}
if positive_tests < 1 {
return Err(Error::Other("No positive test cases found".to_string()));
}
if results.len() > 1 && negative_tests < 1 {
return Err(Error::Other(format!(
"Expected negative tests for validation rules, got {negative_tests}"
)));
}
Ok(())
}
pub fn validator_test<F>(
file_factory: FileFactory,
validator_name: &str,
expected_error_type: &str,
validation_config: ValidationConfig,
run_validator: F,
) -> Result<()>
where
F: Fn(&RawValidationContext) -> Result<()>,
{
let test_assemblies = file_factory()?;
if test_assemblies.is_empty() {
return Err(Error::Other("No test-assembly found!".to_string()));
}
let mut test_results = Vec::new();
for assembly in test_assemblies {
let validation_result = run_validation_test(&assembly, &validation_config, &run_validator);
let test_result = match validation_result {
Ok(()) => ValidationTestResult {
test_passed: assembly.should_pass,
validation_succeeded: true,
error_message: None,
assembly,
},
Err(error) => {
let error_msg = format!("{error:?}");
let test_passed = if assembly.should_pass {
false
} else if let Some(expected_pattern) = &assembly.expected_error_pattern {
error_msg.contains(expected_pattern)
} else {
true
};
ValidationTestResult {
test_passed,
validation_succeeded: false,
error_message: Some(error_msg),
assembly,
}
}
};
test_results.push(test_result);
}
file_verify(&test_results, validator_name, expected_error_type)
}
pub fn owned_validator_test<F>(
file_factory: FileFactory,
validator_name: &str,
expected_error_type: &str,
validation_config: ValidationConfig,
run_validator: F,
) -> Result<()>
where
F: Fn(&OwnedValidationContext) -> Result<()>,
{
let test_assemblies = file_factory()?;
if test_assemblies.is_empty() {
return Err(Error::Other("No test-assembly found!".to_string()));
}
let mut test_results = Vec::new();
for assembly in test_assemblies {
let validation_result =
run_owned_validation_test(&assembly, &validation_config, &run_validator);
let test_result = match validation_result {
Ok(()) => ValidationTestResult {
test_passed: assembly.should_pass,
validation_succeeded: true,
error_message: None,
assembly,
},
Err(error) => {
let error_msg = format!("{error:?}");
let test_passed = if assembly.should_pass {
false
} else if let Some(expected_pattern) = &assembly.expected_error_pattern {
error_msg.contains(expected_pattern)
} else {
true
};
ValidationTestResult {
test_passed,
validation_succeeded: false,
error_message: Some(error_msg),
assembly,
}
}
};
test_results.push(test_result);
}
file_verify(&test_results, validator_name, expected_error_type)
}
fn run_validation_test<F>(
assembly: &TestAssembly,
config: &ValidationConfig,
run_validator: &F,
) -> Result<()>
where
F: Fn(&RawValidationContext) -> Result<()>,
{
let assembly_view = assembly.load()?;
let scanner = ReferenceScanner::from_view(&assembly_view)?;
let thread_count = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(4);
let thread_pool = ThreadPoolBuilder::new()
.num_threads(thread_count)
.build()
.unwrap();
let context =
RawValidationContext::new_for_loading(&assembly_view, &scanner, config, &thread_pool);
run_validator(&context)
}
fn run_owned_validation_test<F>(
assembly: &TestAssembly,
config: &ValidationConfig,
run_validator: &F,
) -> Result<()>
where
F: Fn(&OwnedValidationContext) -> Result<()>,
{
use std::io::Write;
let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
let mono_deps_path = Path::new(&manifest_dir).join("tests/samples/mono_4.8");
let assembly_view = assembly.load()?;
let _temp_file: Option<NamedTempFile>;
let primary_path: PathBuf = match &assembly.source {
TestAssemblySource::File { path, .. } => path.clone(),
TestAssemblySource::Memory(data) => {
let mut temp = NamedTempFile::new().map_err(|e| {
Error::Other(format!(
"Failed to create temp file for owned validation: {e}"
))
})?;
temp.write_all(data).map_err(|e| {
Error::Other(format!(
"Failed to write temp file for owned validation: {e}"
))
})?;
let path = temp.path().to_path_buf();
_temp_file = Some(temp);
path
}
};
let project_result = crate::project::ProjectLoader::new()
.primary_file(&primary_path)?
.with_search_path(&mono_deps_path)?
.auto_discover(true)
.strict_mode(true)
.with_validation(ValidationConfig::disabled())
.build()?;
let object = project_result
.project
.get_primary()
.ok_or_else(|| Error::Other("Failed to get primary assembly from project".to_string()))?;
let scanner = ReferenceScanner::from_view(&assembly_view)?;
let thread_count = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(4);
let thread_pool = ThreadPoolBuilder::new()
.num_threads(thread_count)
.build()
.unwrap();
let context = OwnedValidationContext::new(object.as_ref(), &scanner, config, &thread_pool);
run_validator(&context)
}
pub fn get_testfile_wb() -> Option<PathBuf> {
let windowsbase_path =
Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
if windowsbase_path.exists() {
Some(windowsbase_path)
} else {
None
}
}
pub fn get_testfile_crafted2() -> Option<PathBuf> {
let crafted_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/samples/crafted_2.exe");
if crafted_path.exists() {
Some(crafted_path)
} else {
None
}
}
pub fn get_testfile_mscorlib() -> Option<PathBuf> {
let mscorlib_path =
Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/samples/mono_4.8/mscorlib.dll");
if mscorlib_path.exists() {
Some(mscorlib_path)
} else {
None
}
}