use std::sync::{Arc, OnceLock};
use rustc_hash::FxHashMap;
use crate::metadata::token::Token;
use crate::{
cilassembly::AssemblyChanges,
metadata::{
cilassemblyview::CilAssemblyView,
cilobject::CilObject,
method::Method,
typesystem::{CilTypeRc, TypeSource},
validation::{config::ValidationConfig, scanner::ReferenceScanner},
},
};
use rayon::ThreadPool;
/// Identifies which stage of validation a [`ValidationContext`] serves.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValidationStage {
    /// Validation performed against a `CilAssemblyView` (see `RawValidationContext`).
    Raw,
    /// Validation performed against a `CilObject` (see `OwnedValidationContext`).
    Owned,
}
/// Behaviour shared by every validation context, regardless of stage.
///
/// Validators can use this trait to reach the common services without caring
/// whether they run against a raw view or a fully built object.
pub trait ValidationContext {
    /// Reports which validation stage this context belongs to.
    fn validation_stage(&self) -> ValidationStage;

    /// Returns the reference scanner shared with validators.
    fn reference_scanner(&self) -> &ReferenceScanner;

    /// Returns the validation configuration in effect for this run.
    fn config(&self) -> &ValidationConfig;
}
/// Validation context for the [`ValidationStage::Raw`] stage, which operates on a
/// [`CilAssemblyView`].
///
/// The context serves one of two purposes:
/// - load-time validation (built with [`RawValidationContext::new_for_loading`]),
///   where no pending changes exist;
/// - modification validation (built with [`RawValidationContext::new_for_modification`]),
///   where a set of [`AssemblyChanges`] is carried alongside the view.
pub struct RawValidationContext<'a> {
    /// The assembly view being validated.
    view: &'a CilAssemblyView,
    /// Pending changes; `Some` only for modification validation.
    changes: Option<&'a AssemblyChanges>,
    /// Reference scanner exposed through [`ValidationContext::reference_scanner`].
    scanner: &'a ReferenceScanner,
    /// Configuration exposed through [`ValidationContext::config`].
    config: &'a ValidationConfig,
    /// Thread pool exposed through [`RawValidationContext::thread_pool`].
    thread_pool: &'a ThreadPool,
}

impl<'a> RawValidationContext<'a> {
    /// Creates a context for validating an assembly as it is loaded.
    ///
    /// The resulting context has no pending changes: [`Self::changes`] returns
    /// `None` and [`Self::is_loading_validation`] returns `true`.
    #[must_use]
    pub fn new_for_loading(
        view: &'a CilAssemblyView,
        scanner: &'a ReferenceScanner,
        config: &'a ValidationConfig,
        thread_pool: &'a ThreadPool,
    ) -> Self {
        Self {
            view,
            changes: None,
            scanner,
            config,
            thread_pool,
        }
    }

    /// Creates a context for validating `changes` applied to `view`.
    ///
    /// The resulting context reports [`Self::is_modification_validation`] as
    /// `true` and exposes `changes` through [`Self::changes`].
    // `#[must_use]` for consistency with `new_for_loading`: discarding a freshly
    // built context is always a mistake.
    #[must_use]
    pub fn new_for_modification(
        view: &'a CilAssemblyView,
        changes: &'a AssemblyChanges,
        scanner: &'a ReferenceScanner,
        config: &'a ValidationConfig,
        thread_pool: &'a ThreadPool,
    ) -> Self {
        Self {
            view,
            changes: Some(changes),
            scanner,
            config,
            thread_pool,
        }
    }

    /// Returns the pending changes, or `None` for a load-time context.
    #[must_use]
    pub fn changes(&self) -> Option<&AssemblyChanges> {
        self.changes
    }

    /// Returns `true` when this context was built for modification validation.
    #[must_use]
    pub fn is_modification_validation(&self) -> bool {
        self.changes.is_some()
    }

    /// Returns `true` when this context was built for load-time validation.
    #[must_use]
    pub fn is_loading_validation(&self) -> bool {
        self.changes.is_none()
    }

    /// Returns the assembly view under validation.
    #[must_use]
    pub fn assembly_view(&self) -> &CilAssemblyView {
        self.view
    }

    /// Returns the thread pool supplied at construction.
    #[must_use]
    pub fn thread_pool(&self) -> &ThreadPool {
        self.thread_pool
    }
}

impl ValidationContext for RawValidationContext<'_> {
    /// Always [`ValidationStage::Raw`].
    fn validation_stage(&self) -> ValidationStage {
        ValidationStage::Raw
    }

    fn reference_scanner(&self) -> &ReferenceScanner {
        self.scanner
    }

    fn config(&self) -> &ValidationConfig {
        self.config
    }
}
/// Precomputed index between methods and the types that declare them.
///
/// Both types and methods are identified by the address of their `Arc`
/// allocation (`Arc::as_ptr(..) as usize`). This makes identity checks cheap
/// integer comparisons instead of structural comparisons.
///
/// The mapping holds a strong reference to every indexed method, so method
/// addresses stay valid for the mapping's lifetime. Type addresses are only
/// recorded, not retained.
// NOTE(review): type-address keys assume the owning registry keeps the types
// alive while this mapping is in use — confirm with callers.
pub struct MethodTypeMapping {
    /// Method address -> address of the type whose `methods` listed it.
    method_to_type: FxHashMap<usize, usize>,
    /// Type address -> addresses of that type's live methods, in iteration order.
    type_to_methods: FxHashMap<usize, Vec<usize>>,
    /// Method address -> strong reference to the method.
    address_to_method: FxHashMap<usize, Arc<Method>>,
}

impl MethodTypeMapping {
    /// Builds the index from `all_types`.
    ///
    /// Method references that can no longer be upgraded are skipped. Types
    /// without any live method get no entry in the type -> methods map.
    #[must_use]
    pub fn new(all_types: Vec<CilTypeRc>) -> Self {
        let mut method_to_type = FxHashMap::default();
        let mut type_to_methods: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
        let mut address_to_method = FxHashMap::default();

        for type_entry in all_types {
            let type_address = Arc::as_ptr(&type_entry) as usize;
            let mut type_methods = Vec::new();

            for (_, method_ref) in type_entry.methods.iter() {
                // Types refer to their methods weakly; ignore ones already dropped.
                let Some(method_rc) = method_ref.upgrade() else {
                    continue;
                };
                let method_address = Arc::as_ptr(&method_rc) as usize;
                method_to_type.insert(method_address, type_address);
                type_methods.push(method_address);
                // Move the upgraded strong reference in directly rather than
                // cloning it and dropping the original.
                address_to_method.insert(method_address, method_rc);
            }

            if !type_methods.is_empty() {
                type_to_methods.insert(type_address, type_methods);
            }
        }

        Self {
            method_to_type,
            type_to_methods,
            address_to_method,
        }
    }

    /// Returns `true` if the method at `method_address` was indexed under the
    /// type at `type_address`.
    #[must_use]
    pub fn method_belongs_to_type(&self, method_address: usize, type_address: usize) -> bool {
        self.method_to_type.get(&method_address) == Some(&type_address)
    }

    /// Returns the method addresses indexed for `type_address`.
    ///
    /// Returns an empty slice for an unknown type or one with no live methods.
    #[must_use]
    pub fn get_type_methods(&self, type_address: usize) -> &[usize] {
        self.type_to_methods
            .get(&type_address)
            .map_or(&[], Vec::as_slice)
    }

    /// Looks up the method whose `Arc` allocation lives at `method_address`.
    #[must_use]
    pub fn get_method(&self, method_address: usize) -> Option<&Arc<Method>> {
        self.address_to_method.get(&method_address)
    }

    /// Number of distinct methods in the index.
    #[must_use]
    pub fn method_count(&self) -> usize {
        self.address_to_method.len()
    }

    /// Number of types that own at least one live method.
    #[must_use]
    pub fn type_count(&self) -> usize {
        self.type_to_methods.len()
    }
}
/// Lazily populated store of data derived from the object being validated.
///
/// Every slot is a [`OnceLock`] that starts empty. It is filled the first time
/// the corresponding accessor on `OwnedValidationContext` is called, and the
/// stored value is returned on every later call.
#[derive(Default)]
pub struct ValidationCache {
    /// Types whose source is the validated assembly itself.
    target_types: OnceLock<Vec<CilTypeRc>>,
    /// Every type known to the object.
    all_types: OnceLock<Vec<CilTypeRc>>,
    /// Method/type address index built from `all_types`.
    method_type_mapping: OnceLock<MethodTypeMapping>,
    /// Type address -> addresses of the interfaces it lists.
    interface_relationships: OnceLock<FxHashMap<usize, Vec<usize>>>,
    /// Type token -> tokens of its nested types.
    nested_relationships: OnceLock<FxHashMap<Token, Vec<Token>>>,
}

impl ValidationCache {
    /// Creates a cache with every slot still uncomputed.
    #[must_use]
    pub fn new() -> Self {
        // The derived `Default` initialises each field with `OnceLock::default()`,
        // which is the same as `OnceLock::new()`.
        Self::default()
    }
}
/// Validation context for the [`ValidationStage::Owned`] stage, which operates on a
/// [`CilObject`].
///
/// In addition to the shared services, it owns a [`ValidationCache`] so that
/// expensive derived views of the object are computed at most once per context.
pub struct OwnedValidationContext<'a> {
    /// The object being validated.
    object: &'a CilObject,
    /// Reference scanner exposed through [`ValidationContext::reference_scanner`].
    scanner: &'a ReferenceScanner,
    /// Configuration exposed through [`ValidationContext::config`].
    config: &'a ValidationConfig,
    /// Lazily computed derived data; starts empty for each new context.
    cache: ValidationCache,
    /// Thread pool exposed through [`OwnedValidationContext::thread_pool`].
    thread_pool: &'a ThreadPool,
}

impl<'a> OwnedValidationContext<'a> {
    /// Creates a context for `object` with an empty cache.
    // `#[must_use]` for consistency with `RawValidationContext`'s constructors.
    #[must_use]
    pub fn new(
        object: &'a CilObject,
        scanner: &'a ReferenceScanner,
        config: &'a ValidationConfig,
        thread_pool: &'a ThreadPool,
    ) -> Self {
        Self {
            object,
            scanner,
            config,
            cache: ValidationCache::new(),
            thread_pool,
        }
    }

    /// Returns the object under validation.
    #[must_use]
    pub fn object(&self) -> &CilObject {
        self.object
    }

    /// Returns the thread pool supplied at construction.
    #[must_use]
    pub fn thread_pool(&self) -> &ThreadPool {
        self.thread_pool
    }
}
impl OwnedValidationContext<'_> {
    /// Returns the types whose source is this object's own assembly identity.
    ///
    /// If the object has no identity, an empty list is cached and returned.
    /// The result is computed on the first call and reused afterwards.
    #[must_use]
    pub fn target_assembly_types(&self) -> &Vec<CilTypeRc> {
        self.cache.target_types.get_or_init(|| {
            self.object
                .identity()
                .map_or_else(Vec::new, |assembly_identity| {
                    self.object
                        .types()
                        .types_from_source(&TypeSource::Assembly(assembly_identity))
                })
        })
    }

    /// Returns every type registered in the object, cached after the first call.
    #[must_use]
    pub fn all_types(&self) -> &Vec<CilTypeRc> {
        self.cache
            .all_types
            .get_or_init(|| self.object.types().all_types())
    }

    /// Returns the method/type address index built from [`Self::all_types`],
    /// cached after the first call.
    #[must_use]
    pub fn method_type_mapping(&self) -> &MethodTypeMapping {
        self.cache
            .method_type_mapping
            .get_or_init(|| MethodTypeMapping::new(self.all_types().clone()))
    }

    /// Maps each type's `Arc` address to the addresses of the interfaces listed
    /// in its `interfaces` collection.
    ///
    /// Interface references that can no longer be upgraded are skipped. Types
    /// with no live interfaces are omitted. Cached after the first call.
    #[must_use]
    pub fn interface_relationships(&self) -> &FxHashMap<usize, Vec<usize>> {
        self.cache.interface_relationships.get_or_init(|| {
            self.all_types()
                .iter()
                .filter_map(|type_entry| {
                    let implemented: Vec<usize> = type_entry
                        .interfaces
                        .iter()
                        .filter_map(|(_, interface_ref)| interface_ref.upgrade())
                        .map(|interface_type| Arc::as_ptr(&interface_type) as usize)
                        .collect();
                    (!implemented.is_empty())
                        .then_some((Arc::as_ptr(type_entry) as usize, implemented))
                })
                .collect()
        })
    }

    /// Maps each type's token to the tokens of its nested types.
    ///
    /// Nested-type references that can no longer be upgraded are skipped. Types
    /// with no live nested types are omitted. Cached after the first call.
    #[must_use]
    pub fn nested_relationships(&self) -> &FxHashMap<Token, Vec<Token>> {
        self.cache.nested_relationships.get_or_init(|| {
            self.all_types()
                .iter()
                .filter_map(|type_entry| {
                    let nested: Vec<Token> = type_entry
                        .nested_types
                        .iter()
                        .filter_map(|(_, nested_ref)| nested_ref.upgrade())
                        .map(|nested_type| nested_type.token)
                        .collect();
                    (!nested.is_empty()).then_some((type_entry.token, nested))
                })
                .collect()
        })
    }
}
impl ValidationContext for OwnedValidationContext<'_> {
    /// Configuration supplied when the context was created.
    fn config(&self) -> &ValidationConfig {
        self.config
    }

    /// Reference scanner supplied when the context was created.
    fn reference_scanner(&self) -> &ReferenceScanner {
        self.scanner
    }

    /// An owned context always reports [`ValidationStage::Owned`].
    fn validation_stage(&self) -> ValidationStage {
        ValidationStage::Owned
    }
}
/// Free-function constructors for the validation context types.
pub mod factory {
    use super::{
        AssemblyChanges, CilAssemblyView, CilObject, OwnedValidationContext, RawValidationContext,
        ReferenceScanner, ValidationConfig,
    };
    use rayon::ThreadPool;

    /// Builds a [`RawValidationContext`] for load-time validation.
    ///
    /// Equivalent to [`RawValidationContext::new_for_loading`].
    #[must_use]
    pub fn raw_loading_context<'a>(
        view: &'a CilAssemblyView,
        scanner: &'a ReferenceScanner,
        config: &'a ValidationConfig,
        thread_pool: &'a ThreadPool,
    ) -> RawValidationContext<'a> {
        RawValidationContext::new_for_loading(view, scanner, config, thread_pool)
    }

    /// Builds a [`RawValidationContext`] that validates `changes` against `view`.
    ///
    /// Equivalent to [`RawValidationContext::new_for_modification`].
    #[must_use]
    pub fn raw_modification_context<'a>(
        view: &'a CilAssemblyView,
        changes: &'a AssemblyChanges,
        scanner: &'a ReferenceScanner,
        config: &'a ValidationConfig,
        thread_pool: &'a ThreadPool,
    ) -> RawValidationContext<'a> {
        RawValidationContext::new_for_modification(view, changes, scanner, config, thread_pool)
    }

    /// Builds an [`OwnedValidationContext`] for `object` with an empty cache.
    ///
    /// Equivalent to [`OwnedValidationContext::new`].
    #[must_use]
    pub fn owned_context<'a>(
        object: &'a CilObject,
        scanner: &'a ReferenceScanner,
        config: &'a ValidationConfig,
        thread_pool: &'a ThreadPool,
    ) -> OwnedValidationContext<'a> {
        OwnedValidationContext::new(object, scanner, config, thread_pool)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::validation::config::ValidationConfig;
    use rayon::ThreadPoolBuilder;
    use std::path::PathBuf;

    // NOTE(review): every test below returns early (and therefore passes) when
    // the sample assembly cannot be loaded, so a missing fixture is not reported.

    /// Location of the sample assembly loaded by each test.
    fn sample_assembly_path() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll")
    }

    /// Four-worker pool handed to every context under test.
    fn four_thread_pool() -> rayon::ThreadPool {
        ThreadPoolBuilder::new().num_threads(4).build().unwrap()
    }

    #[test]
    fn test_raw_loading_context() {
        let sample = sample_assembly_path();
        let Ok(view) = CilAssemblyView::from_path(&sample) else {
            return;
        };
        let scanner = ReferenceScanner::from_view(&view).unwrap();
        let config = ValidationConfig::minimal();
        let pool = four_thread_pool();

        let ctx = RawValidationContext::new_for_loading(&view, &scanner, &config, &pool);

        assert_eq!(ctx.validation_stage(), ValidationStage::Raw);
        assert!(ctx.is_loading_validation());
        assert!(!ctx.is_modification_validation());
        assert!(ctx.changes().is_none());
    }

    #[test]
    fn test_raw_modification_context() {
        let sample = sample_assembly_path();
        let Ok(view) = CilAssemblyView::from_path(&sample) else {
            return;
        };
        let scanner = ReferenceScanner::from_view(&view).unwrap();
        let config = ValidationConfig::minimal();
        let changes = AssemblyChanges::new();
        let pool = four_thread_pool();

        let ctx =
            RawValidationContext::new_for_modification(&view, &changes, &scanner, &config, &pool);

        assert_eq!(ctx.validation_stage(), ValidationStage::Raw);
        assert!(!ctx.is_loading_validation());
        assert!(ctx.is_modification_validation());
        assert!(ctx.changes().is_some());
    }

    #[test]
    fn test_factory_functions() {
        let sample = sample_assembly_path();
        let Ok(view) = CilAssemblyView::from_path(&sample) else {
            return;
        };
        let scanner = ReferenceScanner::from_view(&view).unwrap();
        let config = ValidationConfig::minimal();
        let changes = AssemblyChanges::new();
        let pool = four_thread_pool();

        let loading = factory::raw_loading_context(&view, &scanner, &config, &pool);
        assert_eq!(loading.validation_stage(), ValidationStage::Raw);

        let modification =
            factory::raw_modification_context(&view, &changes, &scanner, &config, &pool);
        assert_eq!(modification.validation_stage(), ValidationStage::Raw);
    }
}