use log::debug;
use crate::{
cilassembly::AssemblyChanges,
metadata::{
cilassemblyview::CilAssemblyView,
cilobject::CilObject,
validation::{
config::ValidationConfig,
context::factory as context_factory,
result::{TwoStageValidationResult, ValidationResult},
scanner::{ReferenceScanner, ScannerStatistics},
traits::{OwnedValidator, RawValidator},
validators::{
OwnedAccessibilityValidator, OwnedAssemblyValidator, OwnedAttributeValidator,
OwnedCircularityValidator, OwnedDependencyValidator, OwnedFieldValidator,
OwnedInheritanceValidator, OwnedMethodValidator, OwnedOwnershipValidator,
OwnedSecurityValidator, OwnedSignatureValidator, OwnedTypeCircularityValidator,
OwnedTypeConstraintValidator, OwnedTypeDefinitionValidator,
OwnedTypeDependencyValidator, OwnedTypeOwnershipValidator,
RawChangeIntegrityValidator, RawGenericConstraintValidator, RawHeapValidator,
RawLayoutConstraintValidator, RawOperationValidator, RawSignatureValidator,
RawTableValidator, RawTokenValidator,
},
},
},
Error, Result,
};
use rayon::{prelude::*, ThreadPool, ThreadPoolBuilder};
use std::{sync::OnceLock, time::Instant};
/// Process-wide registry of raw validators, filled on first access by
/// `ValidationEngine::get_raw_validators` using `init_raw_validators`.
static RAW_VALIDATORS: OnceLock<Vec<Box<dyn RawValidator>>> = OnceLock::new();
/// Process-wide registry of owned validators, filled on first access by
/// `ValidationEngine::get_owned_validators` using `init_owned_validators`.
static OWNED_VALIDATORS: OnceLock<Vec<Box<dyn OwnedValidator>>> = OnceLock::new();
/// Builds a fresh list containing one instance of every raw validator.
///
/// The order of the returned list is the registration order; the engine keeps
/// that order when it collects per-validator results.
fn init_raw_validators() -> Vec<Box<dyn RawValidator>> {
    let registry: Vec<Box<dyn RawValidator>> = vec![
        Box::new(RawTokenValidator::new()),
        Box::new(RawTableValidator::new()),
        Box::new(RawHeapValidator::new()),
        Box::new(RawSignatureValidator::new()),
        Box::new(RawGenericConstraintValidator::new()),
        Box::new(RawLayoutConstraintValidator::new()),
        Box::new(RawOperationValidator::new()),
        Box::new(RawChangeIntegrityValidator::new()),
    ];
    registry
}
/// Builds a fresh list containing one instance of every owned validator.
///
/// The order of the returned list is the registration order; the engine keeps
/// that order when it collects per-validator results.
fn init_owned_validators() -> Vec<Box<dyn OwnedValidator>> {
    let registry: Vec<Box<dyn OwnedValidator>> = vec![
        Box::new(OwnedTypeDefinitionValidator::new()),
        Box::new(OwnedTypeConstraintValidator::new()),
        Box::new(OwnedInheritanceValidator::new()),
        Box::new(OwnedTypeCircularityValidator::new()),
        Box::new(OwnedTypeDependencyValidator::new()),
        Box::new(OwnedTypeOwnershipValidator::new()),
        Box::new(OwnedMethodValidator::new()),
        Box::new(OwnedFieldValidator::new()),
        Box::new(OwnedAccessibilityValidator::new()),
        Box::new(OwnedSignatureValidator::new()),
        Box::new(OwnedAttributeValidator::new()),
        Box::new(OwnedCircularityValidator::new()),
        Box::new(OwnedDependencyValidator::new()),
        Box::new(OwnedOwnershipValidator::new()),
        Box::new(OwnedSecurityValidator::new()),
        Box::new(OwnedAssemblyValidator::new()),
    ];
    registry
}
/// Runs the registered raw and owned validators against an assembly.
///
/// An engine bundles the configuration, a reference scanner built from the
/// assembly view, and a dedicated rayon thread pool on which the validators
/// are executed in parallel.
pub struct ValidationEngine {
// Validation settings consulted when deciding which stages run and whether
// diagnostics are attached to results.
config: ValidationConfig,
// Reference scanner built from the assembly view passed to `new`.
scanner: ReferenceScanner,
// Thread pool on which validators are executed.
thread_pool: ThreadPool,
}
impl ValidationEngine {
pub fn new(view: &CilAssemblyView, config: ValidationConfig) -> Result<Self> {
let scanner =
ReferenceScanner::from_view(view).map_err(|e| Error::ValidationEngineInitFailed {
message: format!("Failed to initialize reference scanner: {e}"),
})?;
let thread_count = std::thread::available_parallelism()
.map(std::num::NonZero::get)
.unwrap_or(4);
let thread_pool = ThreadPoolBuilder::new()
.thread_name(|i| format!("validation-engine-{i}"))
.num_threads(thread_count)
.build()
.map_err(|e| Error::ValidationEngineInitFailed {
message: format!("Failed to create thread pool: {e}"),
})?;
debug!(
"Validation starting: {} raw + {} owned validators, lenient={}",
init_raw_validators().len(),
init_owned_validators().len(),
config.lenient
);
Ok(Self {
config,
scanner,
thread_pool,
})
}
pub fn execute_two_stage_validation(
&self,
view: &CilAssemblyView,
changes: Option<&AssemblyChanges>,
object: Option<&CilObject>,
) -> Result<TwoStageValidationResult> {
let mut result = TwoStageValidationResult::new();
if self.config.should_validate_raw() {
let stage1_result = self.execute_stage1_validation(view, changes)?;
let stage1_success = stage1_result.is_success();
result.set_stage1_result(stage1_result);
if !stage1_success {
return Ok(result); }
}
if let Some(obj) = object {
if self.config.should_validate_owned() {
let stage2_result = self.execute_stage2_validation(obj)?;
result.set_stage2_result(stage2_result);
}
}
Ok(result)
}
pub fn execute_stage1_validation(
&self,
view: &CilAssemblyView,
changes: Option<&AssemblyChanges>,
) -> Result<ValidationResult> {
let validators = Self::get_raw_validators();
self.validate_raw_stage(view, changes, validators)
}
pub fn execute_stage2_validation(&self, object: &CilObject) -> Result<ValidationResult> {
let validators = Self::get_owned_validators();
self.validate_owned_stage(object, validators)
}
pub fn validate_raw_stage(
&self,
view: &CilAssemblyView,
changes: Option<&AssemblyChanges>,
validators: &Vec<Box<dyn RawValidator>>,
) -> Result<ValidationResult> {
let start_time = Instant::now();
let context = if let Some(changes) = changes {
context_factory::raw_modification_context(
view,
changes,
&self.scanner,
&self.config,
&self.thread_pool,
)
} else {
context_factory::raw_loading_context(
view,
&self.scanner,
&self.config,
&self.thread_pool,
)
};
let active_validators: Vec<_> = validators
.iter()
.filter(|v| v.should_run(&context))
.collect();
if active_validators.is_empty() {
return Ok(ValidationResult::success());
}
let results: Vec<(&str, Result<()>)> = self.thread_pool.install(|| {
active_validators
.par_iter()
.map(|validator| {
let validator_result =
validator
.validate_raw(&context)
.map_err(|e| Error::ValidationRawFailed {
validator: validator.name().to_string(),
message: e.to_string(),
});
(validator.name(), validator_result)
})
.collect()
});
let duration = start_time.elapsed();
let named_results: Vec<(&str, Result<()>)> = results.into_iter().collect();
let error_count = named_results.iter().filter(|(_, r)| r.is_err()).count();
debug!(
"Raw validation completed in {}ms: {} errors",
duration.as_millis(),
error_count
);
let diagnostics = if self.config.lenient {
Some(view.diagnostics())
} else {
None
};
Ok(ValidationResult::from_named_results(
named_results,
duration,
diagnostics,
))
}
pub fn validate_owned_stage(
&self,
object: &CilObject,
validators: &Vec<Box<dyn OwnedValidator>>,
) -> Result<ValidationResult> {
let start_time = Instant::now();
let context =
context_factory::owned_context(object, &self.scanner, &self.config, &self.thread_pool);
let active_validators: Vec<_> = validators
.iter()
.filter(|v| v.should_run(&context))
.collect();
if active_validators.is_empty() {
return Ok(ValidationResult::success());
}
let results: Vec<(&str, Result<()>)> = self.thread_pool.install(|| {
active_validators
.par_iter()
.map(|validator| {
let validator_result = validator.validate_owned(&context).map_err(|e| {
Error::ValidationOwnedFailed {
validator: validator.name().to_string(),
message: e.to_string(),
}
});
(validator.name(), validator_result)
})
.collect()
});
let duration = start_time.elapsed();
let named_results: Vec<(&str, Result<()>)> = results.into_iter().collect();
let error_count = named_results.iter().filter(|(_, r)| r.is_err()).count();
debug!(
"Owned validation completed in {}ms: {} errors",
duration.as_millis(),
error_count
);
let diagnostics = if self.config.lenient {
Some(object.diagnostics())
} else {
None
};
Ok(ValidationResult::from_named_results(
named_results,
duration,
diagnostics,
))
}
fn get_raw_validators() -> &'static Vec<Box<dyn RawValidator>> {
RAW_VALIDATORS.get_or_init(init_raw_validators)
}
fn get_owned_validators() -> &'static Vec<Box<dyn OwnedValidator>> {
OWNED_VALIDATORS.get_or_init(init_owned_validators)
}
#[must_use]
pub fn config(&self) -> &ValidationConfig {
&self.config
}
#[must_use]
pub fn scanner(&self) -> &ReferenceScanner {
&self.scanner
}
#[must_use]
pub fn statistics(&self) -> EngineStatistics {
EngineStatistics {
scanner_stats: self.scanner.statistics(),
raw_validator_count: Self::get_raw_validators().len(),
owned_validator_count: Self::get_owned_validators().len(),
}
}
}
/// Snapshot of engine state produced by `ValidationEngine::statistics`.
#[derive(Debug, Clone)]
pub struct EngineStatistics {
/// Statistics reported by the engine's reference scanner.
pub scanner_stats: ScannerStatistics,
/// Number of raw validators in the shared registry.
pub raw_validator_count: usize,
/// Number of owned validators in the shared registry.
pub owned_validator_count: usize,
}
impl std::fmt::Display for EngineStatistics {
    /// Writes a one-line summary: the raw and owned validator counts followed
    /// by the scanner statistics' own `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            scanner_stats,
            raw_validator_count,
            owned_validator_count,
        } = self;
        f.write_fmt(format_args!(
            "Engine Statistics: {} raw validators, {} owned validators, {}",
            raw_validator_count, owned_validator_count, scanner_stats
        ))
    }
}
/// Convenience constructors that pair an assembly view with one of the
/// predefined [`ValidationConfig`] presets.
pub mod factory {
    use super::{CilAssemblyView, Result, ValidationConfig, ValidationEngine};

    /// Shared construction path for all presets.
    fn build(view: &CilAssemblyView, preset: ValidationConfig) -> Result<ValidationEngine> {
        ValidationEngine::new(view, preset)
    }

    /// Creates an engine configured with `ValidationConfig::minimal()`.
    pub fn minimal_engine(view: &CilAssemblyView) -> Result<ValidationEngine> {
        build(view, ValidationConfig::minimal())
    }

    /// Creates an engine configured with `ValidationConfig::production()`.
    pub fn production_engine(view: &CilAssemblyView) -> Result<ValidationEngine> {
        build(view, ValidationConfig::production())
    }

    /// Creates an engine configured with `ValidationConfig::comprehensive()`.
    pub fn comprehensive_engine(view: &CilAssemblyView) -> Result<ValidationEngine> {
        build(view, ValidationConfig::comprehensive())
    }

    /// Creates an engine configured with `ValidationConfig::strict()`.
    pub fn strict_engine(view: &CilAssemblyView) -> Result<ValidationEngine> {
        build(view, ValidationConfig::strict())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        cilassembly::AssemblyChanges,
        metadata::{
            cilassemblyview::CilAssemblyView,
            cilobject::CilObject,
            validation::{
                config::ValidationConfig, context::RawValidationContext, traits::RawValidator,
            },
        },
    };
    use std::path::PathBuf;

    /// Sample assembly used by most tests, relative to the crate root.
    const WINDOWS_BASE: &str = "tests/samples/WindowsBase.dll";
    /// Larger sample assembly used by the owned / two-stage tests.
    const MSCORLIB: &str = "tests/samples/mono_4.8/mscorlib.dll";

    /// Resolves a sample path relative to the crate root.
    fn sample_path(relative: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(relative)
    }

    /// Loads a sample as a raw assembly view.
    ///
    /// Panics when the sample is missing so that a broken test setup is
    /// reported instead of letting the test pass without checking anything.
    fn load_view(relative: &str) -> CilAssemblyView {
        CilAssemblyView::from_path(&sample_path(relative)).expect("Failed to load test assembly")
    }

    /// Minimal validator whose outcome is controlled by `should_fail`.
    struct TestRawValidator {
        should_fail: bool,
    }

    impl RawValidator for TestRawValidator {
        fn validate_raw(&self, _context: &RawValidationContext) -> Result<()> {
            if self.should_fail {
                Err(Error::NotSupported)
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &'static str {
            "TestRawValidator"
        }
    }

    #[test]
    fn test_validation_engine_creation() {
        let view = load_view(WINDOWS_BASE);
        let engine = ValidationEngine::new(&view, ValidationConfig::minimal())
            .expect("Engine creation should succeed");

        let stats = engine.statistics();
        assert!(stats.scanner_stats.total_tokens > 0);
    }

    #[test]
    fn test_two_stage_validation_early_termination() {
        let view = load_view(WINDOWS_BASE);
        let mut config = ValidationConfig::comprehensive();
        config.enable_raw_validation = true;
        config.enable_owned_validation = true;
        let engine =
            ValidationEngine::new(&view, config).expect("Failed to create validation engine");

        // No CilObject is supplied, so only stage 1 can produce a result.
        let result = engine
            .execute_two_stage_validation(&view, None, None)
            .expect("Two-stage validation should complete");
        assert!(result.stage1_result().is_some());
        assert!(result.stage2_result().is_none());
    }

    #[test]
    fn test_raw_validation_with_changes() {
        let view = load_view(WINDOWS_BASE);
        let engine = ValidationEngine::new(&view, ValidationConfig::minimal())
            .expect("Failed to create validation engine");

        let changes = AssemblyChanges::empty();
        let result = engine.execute_stage1_validation(&view, Some(&changes));
        assert!(result.is_ok());

        let result = engine.execute_stage1_validation(&view, None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_custom_raw_validators_run_through_engine() {
        let view = load_view(WINDOWS_BASE);
        let engine = ValidationEngine::new(&view, ValidationConfig::minimal())
            .expect("Failed to create validation engine");

        let validators: Vec<Box<dyn RawValidator>> = vec![
            Box::new(TestRawValidator { should_fail: false }),
            Box::new(TestRawValidator { should_fail: true }),
        ];

        // A failing validator is recorded inside the ValidationResult; the
        // stage call itself still returns Ok.
        let result = engine.validate_raw_stage(&view, None, &validators);
        assert!(
            result.is_ok(),
            "Stage execution should return Ok even when a validator fails"
        );
    }

    #[test]
    fn test_factory_functions() {
        let view = load_view(WINDOWS_BASE);
        assert!(factory::minimal_engine(&view).is_ok());
        assert!(factory::production_engine(&view).is_ok());
        assert!(factory::comprehensive_engine(&view).is_ok());
        assert!(factory::strict_engine(&view).is_ok());
    }

    #[test]
    fn test_engine_statistics() {
        let view = load_view(WINDOWS_BASE);
        let engine = ValidationEngine::new(&view, ValidationConfig::minimal())
            .expect("Failed to create validation engine");

        let stats_string = engine.statistics().to_string();
        assert!(stats_string.contains("validators"));
        assert!(stats_string.contains("tokens"));
    }

    #[test]
    fn test_all_validators_registered() {
        let view = load_view(WINDOWS_BASE);
        let engine = ValidationEngine::new(&view, ValidationConfig::comprehensive())
            .expect("Failed to create validation engine");

        let stats = engine.statistics();
        assert!(
            stats.raw_validator_count >= 7,
            "Expected at least 7 raw validators, got {}",
            stats.raw_validator_count
        );
        assert!(
            stats.owned_validator_count >= 15,
            "Expected at least 15 owned validators, got {}",
            stats.owned_validator_count
        );
    }

    #[test]
    fn test_raw_validators_creation_and_uniqueness() {
        let view = load_view(WINDOWS_BASE);
        let engine = ValidationEngine::new(&view, ValidationConfig::comprehensive())
            .expect("Failed to create validation engine");

        // Validator failures are reported inside the ValidationResult, so the
        // stage call itself is expected to return Ok.
        let result = engine.execute_stage1_validation(&view, None);
        assert!(
            result.is_ok(),
            "Raw validation should complete and return a result"
        );
    }

    #[test]
    fn test_owned_validators_creation_and_uniqueness() {
        let object =
            CilObject::from_path(&sample_path(MSCORLIB)).expect("Failed to load mscorlib.dll");
        let view = load_view(MSCORLIB);
        let engine = ValidationEngine::new(&view, ValidationConfig::comprehensive())
            .expect("Failed to create validation engine");

        // Validator failures are reported inside the ValidationResult, so the
        // stage call itself is expected to return Ok.
        let validation_result = engine.execute_stage2_validation(&object);
        assert!(
            validation_result.is_ok(),
            "Owned validation should complete and return a result"
        );
    }

    #[test]
    fn test_complete_two_stage_validation() {
        let view = load_view(MSCORLIB);
        let object =
            CilObject::from_path(&sample_path(MSCORLIB)).expect("Failed to load mscorlib.dll");

        let configs = vec![
            ("minimal", ValidationConfig::minimal()),
            ("production", ValidationConfig::production()),
            ("comprehensive", ValidationConfig::comprehensive()),
        ];

        configs.into_par_iter().for_each(|(name, config)| {
            let engine =
                ValidationEngine::new(&view, config).expect("Failed to create validation engine");

            let two_stage_result = engine
                .execute_two_stage_validation(&view, None, Some(&object))
                .unwrap_or_else(|e| {
                    panic!("Two-stage validation should complete for {name} config: {e:?}")
                });

            assert!(
                two_stage_result.stage1_result().is_some()
                    || two_stage_result.stage2_result().is_some(),
                "At least one validation stage should have run for {name} config"
            );
        });
    }

    #[test]
    fn test_validation_engine_factories() {
        let view = load_view(WINDOWS_BASE);

        let engines = vec![
            ("minimal", factory::minimal_engine(&view)),
            ("production", factory::production_engine(&view)),
            ("comprehensive", factory::comprehensive_engine(&view)),
            ("strict", factory::strict_engine(&view)),
        ];

        for (name, engine_result) in engines {
            let engine = engine_result
                .unwrap_or_else(|e| panic!("Failed to create {name} engine: {e:?}"));

            let stats = engine.statistics();
            assert!(
                stats.raw_validator_count > 0,
                "{name} engine should have raw validators"
            );
            assert!(
                stats.owned_validator_count > 0,
                "{name} engine should have owned validators"
            );
        }
    }

    #[test]
    fn test_validator_name_uniqueness() {
        let view = load_view(WINDOWS_BASE);
        let engine = ValidationEngine::new(&view, ValidationConfig::comprehensive())
            .expect("Failed to create validation engine");

        let stats = engine.statistics();

        // The reported counts must reflect exactly what the registries build.
        assert_eq!(
            stats.raw_validator_count,
            init_raw_validators().len(),
            "Raw validator count should match the raw registry"
        );
        assert_eq!(
            stats.owned_validator_count,
            init_owned_validators().len(),
            "Owned validator count should match the owned registry"
        );

        let total_validators = stats.raw_validator_count + stats.owned_validator_count;
        assert!(
            total_validators >= 22,
            "Expected at least 22 total validators, got {total_validators}"
        );
        // NOTE(review): despite the test name, the validators' `name()` values
        // are not compared for uniqueness here — confirm whether that check
        // should be added.
    }
}