use strum::IntoEnumIterator;
use crate::{
metadata::{tables::TableId, token::Token, validation::scanner::ReferenceScanner},
Error, Result,
};
use rustc_hash::{FxHashMap, FxHashSet};
use std::collections::HashSet;
/// Validates metadata token references by querying a borrowed [`ReferenceScanner`].
///
/// The validator holds no state of its own; every check is answered through the
/// scanner it was constructed with.
pub struct ReferenceValidator<'a> {
/// Scanner consulted for token existence and for incoming/outgoing reference sets.
scanner: &'a ReferenceScanner,
}
impl<'a> ReferenceValidator<'a> {
#[must_use]
pub fn new(scanner: &'a ReferenceScanner) -> Self {
Self { scanner }
}
pub fn validate_token_references<I>(&self, tokens: I) -> Result<()>
where
I: IntoIterator<Item = Token>,
{
for token in tokens {
if !self.scanner.token_exists(token) {
return Err(Error::InvalidRid {
table: TableId::from_token_type(token.table()).unwrap_or(TableId::Module),
rid: token.row(),
});
}
}
Ok(())
}
pub fn validate_token_integrity(&self, token: Token) -> Result<()> {
if !self.scanner.token_exists(token) {
return Err(Error::InvalidRid {
table: TableId::from_token_type(token.table()).unwrap_or(TableId::Module),
rid: token.row(),
});
}
if let Some(outgoing_refs) = self.scanner.references_from(token) {
for &referenced_token in outgoing_refs {
if !self.scanner.token_exists(referenced_token) {
return Err(Error::InvalidRid {
table: TableId::from_token_type(referenced_token.table())
.unwrap_or(TableId::Module),
rid: referenced_token.row(),
});
}
}
}
if let Some(incoming_refs) = self.scanner.references_to(token) {
for &referencing_token in incoming_refs {
if !self.scanner.token_exists(referencing_token) {
return Err(Error::InvalidRid {
table: TableId::from_token_type(referencing_token.table())
.unwrap_or(TableId::Module),
rid: referencing_token.row(),
});
}
}
}
Ok(())
}
#[must_use]
pub fn has_circular_references(&self, start_token: Token) -> bool {
let mut visited = FxHashSet::default();
let mut recursion_stack = FxHashSet::default();
self.detect_cycle_dfs(start_token, &mut visited, &mut recursion_stack)
}
/// Recursive depth-first step used by `has_circular_references`.
///
/// * `visited` - every token the walk has entered so far.
/// * `recursion_stack` - tokens on the path from the walk's start to `token`.
///
/// Returns `true` as soon as an outgoing reference leads to a token that is on
/// `recursion_stack`. When `true` is returned the stack is left as it was at the
/// moment of detection (it is not unwound).
fn detect_cycle_dfs(
    &self,
    token: Token,
    visited: &mut FxHashSet<Token>,
    recursion_stack: &mut FxHashSet<Token>,
) -> bool {
    // Reaching a token that is still on the active path closes a cycle.
    if recursion_stack.contains(&token) {
        return true;
    }

    // `insert` returns `false` when the token was already entered earlier; that
    // earlier visit has already covered everything reachable from here.
    if !visited.insert(token) {
        return false;
    }
    recursion_stack.insert(token);

    let cycle_below = self
        .scanner
        .references_from(token)
        .map_or(false, |targets| {
            targets
                .iter()
                .any(|&next| self.detect_cycle_dfs(next, visited, recursion_stack))
        });
    if cycle_below {
        return true;
    }

    // Backtrack: this token is no longer part of the active path.
    recursion_stack.remove(&token);
    false
}
/// Collects the `(table, row)` pairs of all tokens that reference the given row.
///
/// The target token is assembled from `table_id`'s token type shifted into the
/// top byte and the low 24 bits of `rid`. Referencing tokens for which
/// [`TableId::from_token_type`] yields `None` are left out of the result. An empty
/// set is returned when the scanner has no incoming references for the target.
#[must_use]
pub fn find_references_to_row(&self, table_id: TableId, rid: u32) -> HashSet<(TableId, u32)> {
    let table_bits = u32::from(table_id.token_type()) << 24;
    let row_bits = rid & 0x00FF_FFFF;
    let target = Token::new(table_bits | row_bits);

    let mut referrers = HashSet::new();
    if let Some(sources) = self.scanner.references_to(target) {
        for source in sources {
            if let Some(table) = TableId::from_token_type(source.table()) {
                referrers.insert((table, source.row()));
            }
        }
    }
    referrers
}
pub fn validate_deletion_safety(&self, token: Token) -> Result<()> {
if !self.scanner.can_delete_token(token) {
let ref_count = self
.scanner
.references_to(token)
.map_or(0, std::collections::HashSet::len);
let token_value = token.value();
return Err(Error::CrossReferenceError(format!(
"Cannot delete token {token_value:#x}: {ref_count} references would be broken"
)));
}
Ok(())
}
/// Builds a [`ReferenceAnalysis`] by visiting every row of every table.
///
/// Tables are enumerated with `TableId::iter()`. For each table, rows
/// `1..=table_row_count` are turned into tokens and fed to
/// `analyze_token_references`.
#[must_use]
pub fn analyze_reference_patterns(&self) -> ReferenceAnalysis {
    let mut analysis = ReferenceAnalysis::default();

    TableId::iter()
        .flat_map(|table_id| {
            // Row ids start at 1; a table with zero rows yields an empty range.
            (1..=self.scanner.table_row_count(table_id))
                .map(move |rid| Self::create_token(table_id, rid))
        })
        .for_each(|token| self.analyze_token_references(token, &mut analysis));

    analysis
}
/// Folds the reference data of a single `token` into `analysis`.
///
/// Updates the token and reference totals, records the token as orphaned when it
/// has no incoming references, records it as highly referenced when it has more
/// than 10, and notes it when `has_circular_references` reports a cycle.
fn analyze_token_references(&self, token: Token, analysis: &mut ReferenceAnalysis) {
    let incoming = match self.scanner.references_to(token) {
        Some(sources) => sources.len(),
        None => 0,
    };
    let outgoing = match self.scanner.references_from(token) {
        Some(targets) => targets.len(),
        None => 0,
    };

    analysis.total_tokens += 1;
    // NOTE(review): if `references_to` and `references_from` are two views of the
    // same edges, every edge ends up counted twice in this total — confirm intended.
    analysis.total_references += incoming + outgoing;

    match incoming {
        0 => {
            analysis.orphaned_tokens.insert(token);
        }
        count if count > 10 => {
            analysis.highly_referenced_tokens.insert(token, count);
        }
        _ => {}
    }

    if self.has_circular_references(token) {
        analysis.circular_reference_chains.push(token);
    }
}
/// Builds the token for row `rid` of table `table_id`.
///
/// The table's token type occupies the top byte and the low 24 bits of `rid`
/// occupy the rest.
fn create_token(table_id: TableId, rid: u32) -> Token {
    let table_bits = u32::from(table_id.token_type()) << 24;
    // Mask the row id so an out-of-range value cannot spill into the table byte;
    // this matches the token construction in `find_references_to_row`.
    Token::new(table_bits | (rid & 0x00FF_FFFF))
}
pub fn validate_forward_references(&self, token: Token) -> Result<()> {
if let Some(references) = self.scanner.references_from(token) {
for &referenced_token in references {
if !self.scanner.token_exists(referenced_token) {
let from_token = token.value();
let to_token = referenced_token.value();
return Err(Error::CrossReferenceError(format!(
"Forward reference from {from_token:#x} to non-existent token {to_token:#x}"
)));
}
}
}
Ok(())
}
/// Validates a proposed parent/child pairing of two tokens.
///
/// Checks, in order: the parent exists, the child exists, the two tokens differ,
/// and the child's outgoing references do not contain the parent.
///
/// # Errors
///
/// * [`Error::InvalidRid`] when either token does not exist (parent checked first).
/// * [`Error::CrossReferenceError`] when both tokens are the same, or when the
///   child's outgoing references include the parent.
pub fn validate_parent_child_relationship(
    &self,
    parent_token: Token,
    child_token: Token,
) -> Result<()> {
    for endpoint in [parent_token, child_token] {
        if !self.scanner.token_exists(endpoint) {
            return Err(Error::InvalidRid {
                table: TableId::from_token_type(endpoint.table()).unwrap_or(TableId::Module),
                rid: endpoint.row(),
            });
        }
    }

    if parent_token == child_token {
        let token_value = parent_token.value();
        return Err(Error::CrossReferenceError(format!(
            "Self-referential parent-child relationship detected for token {token_value:#x}"
        )));
    }

    let child_points_at_parent = self
        .scanner
        .references_from(child_token)
        .map_or(false, |targets| targets.contains(&parent_token));
    if child_points_at_parent {
        let parent_value = parent_token.value();
        let child_value = child_token.value();
        return Err(Error::CrossReferenceError(format!(
            "Circular parent-child relationship detected between {parent_value:#x} and {child_value:#x}"
        )));
    }

    Ok(())
}
/// Validates that `nested_token` may be nested inside `enclosing_token`.
///
/// Checks, in order: the enclosing token exists, the nested token exists, the two
/// tokens differ, and the enclosing token is not itself nested within the nested
/// token (which would make the nesting circular).
///
/// # Errors
///
/// * [`Error::InvalidRid`] when either token does not exist (enclosing checked first).
/// * [`Error::CrossReferenceError`] when both tokens are the same, or when the
///   enclosing token is already nested within the nested token.
pub fn validate_nested_class_relationship(
    &self,
    enclosing_token: Token,
    nested_token: Token,
) -> Result<()> {
    if !self.scanner.token_exists(enclosing_token) {
        return Err(Error::InvalidRid {
            table: TableId::from_token_type(enclosing_token.table()).unwrap_or(TableId::Module),
            rid: enclosing_token.row(),
        });
    }

    if !self.scanner.token_exists(nested_token) {
        return Err(Error::InvalidRid {
            table: TableId::from_token_type(nested_token.table()).unwrap_or(TableId::Module),
            rid: nested_token.row(),
        });
    }

    if enclosing_token == nested_token {
        let token_value = enclosing_token.value();
        return Err(Error::CrossReferenceError(format!(
            "Self-referential nested class relationship detected for token {token_value:#x}"
        )));
    }

    // The cycle to reject is "the enclosing class is already nested within the
    // nested class", exactly as the error message below states. Asking whether the
    // nested class sits inside the enclosing class would instead reject the valid,
    // expected relationship and let the real cycle through.
    // NOTE(review): assumes `is_nested_within(inner, outer)` argument order —
    // confirm against the scanner's definition.
    if self.scanner.is_nested_within(enclosing_token, nested_token) {
        let enclosing_value = enclosing_token.value();
        let nested_value = nested_token.value();
        return Err(Error::CrossReferenceError(format!(
            "Circular nested class relationship detected: {enclosing_value:#x} cannot be the \
             enclosing class of {nested_value:#x} because {enclosing_value:#x} is already \
             nested within {nested_value:#x}"
        )));
    }

    Ok(())
}
#[must_use]
pub fn get_reference_statistics(&self) -> ReferenceStatistics {
let analysis = self.analyze_reference_patterns();
ReferenceStatistics {
total_tokens: analysis.total_tokens,
total_references: analysis.total_references,
orphaned_count: analysis.orphaned_tokens.len(),
circular_chains: analysis.circular_reference_chains.len(),
highly_referenced_count: analysis.highly_referenced_tokens.len(),
max_incoming_references: analysis
.highly_referenced_tokens
.values()
.max()
.copied()
.unwrap_or(0),
}
}
}
/// Detailed result of `ReferenceValidator::analyze_reference_patterns`.
///
/// All fields start empty/zero (`Default`) and are filled in one analyzed token
/// at a time.
#[derive(Debug, Default)]
pub struct ReferenceAnalysis {
/// Number of tokens that were analyzed.
pub total_tokens: usize,
/// Sum, over all analyzed tokens, of each token's incoming plus outgoing reference counts.
pub total_references: usize,
/// Tokens that had no incoming references.
pub orphaned_tokens: FxHashSet<Token>,
/// Tokens with more than 10 incoming references, mapped to that incoming count.
pub highly_referenced_tokens: FxHashMap<Token, usize>,
/// Tokens for which `has_circular_references` returned `true`; holds the start
/// tokens themselves, not the full chains.
pub circular_reference_chains: Vec<Token>,
}
/// Counter-only summary of a reference analysis, as produced by
/// `ReferenceValidator::get_reference_statistics`.
#[derive(Debug, Clone)]
pub struct ReferenceStatistics {
    /// Number of tokens that were analyzed.
    pub total_tokens: usize,
    /// Accumulated reference count taken from the analysis.
    pub total_references: usize,
    /// Number of tokens recorded as orphaned.
    pub orphaned_count: usize,
    /// Number of tokens recorded as being able to reach a reference cycle.
    pub circular_chains: usize,
    /// Number of tokens recorded as highly referenced.
    pub highly_referenced_count: usize,
    /// Largest incoming count among the highly referenced tokens; `0` when there are none.
    pub max_incoming_references: usize,
}

/// Renders all six counters as a single-line, human-readable summary.
impl std::fmt::Display for ReferenceStatistics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_tokens,
            total_references,
            orphaned_count,
            circular_chains,
            highly_referenced_count,
            max_incoming_references,
        } = self;

        f.write_fmt(format_args!(
            "Reference Statistics: {} tokens, {} references, {} orphaned, {} circular chains, {} highly referenced (max: {})",
            total_tokens,
            total_references,
            orphaned_count,
            circular_chains,
            highly_referenced_count,
            max_incoming_references
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::cilassemblyview::CilAssemblyView;
    use std::path::PathBuf;

    /// Loads the `WindowsBase.dll` sample, builds a scanner and a validator from
    /// it, and hands both to `check`.
    ///
    /// When the sample cannot be loaded or scanned, `check` is not invoked and the
    /// calling test passes without asserting anything.
    fn with_sample_validator(check: impl FnOnce(&ReferenceScanner, &ReferenceValidator<'_>)) {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        let view = match CilAssemblyView::from_path(&path) {
            Ok(view) => view,
            Err(_) => return,
        };
        let scanner = match ReferenceScanner::from_view(&view) {
            Ok(scanner) => scanner,
            Err(_) => return,
        };
        let validator = ReferenceValidator::new(&scanner);
        check(&scanner, &validator);
    }

    #[test]
    fn test_reference_validator_creation() {
        with_sample_validator(|_, validator| {
            let stats = validator.get_reference_statistics();
            assert!(stats.total_tokens > 0);
        });
    }

    #[test]
    fn test_token_reference_validation() {
        with_sample_validator(|scanner, validator| {
            if scanner.table_row_count(TableId::TypeDef) > 0 {
                let tokens = vec![Token::new(0x02000001)];
                assert!(validator.validate_token_references(tokens).is_ok());
            }

            // Row 0 is expected to be rejected regardless of the table's size.
            let invalid_tokens = vec![Token::new(0x02000000)];
            assert!(validator.validate_token_references(invalid_tokens).is_err());
        });
    }

    #[test]
    fn test_deletion_safety_validation() {
        with_sample_validator(|scanner, validator| {
            if scanner.table_row_count(TableId::TypeDef) > 0 {
                // Only exercises the call; either outcome is acceptable here.
                let _ = validator.validate_deletion_safety(Token::new(0x02000001));
            }
        });
    }

    #[test]
    fn test_circular_reference_detection() {
        with_sample_validator(|scanner, validator| {
            if scanner.table_row_count(TableId::TypeDef) > 0 {
                // Only exercises the call; either outcome is acceptable here.
                let _ = validator.has_circular_references(Token::new(0x02000001));
            }
        });
    }

    #[test]
    fn test_parent_child_relationship_validation() {
        with_sample_validator(|scanner, validator| {
            if scanner.table_row_count(TableId::TypeDef) >= 2 {
                let parent_token = Token::new(0x02000001);
                let child_token = Token::new(0x02000002);

                let distinct = validator.validate_parent_child_relationship(parent_token, child_token);
                assert!(distinct.is_ok());

                let self_ref = validator.validate_parent_child_relationship(parent_token, parent_token);
                assert!(self_ref.is_err());
            }
        });
    }

    #[test]
    fn test_reference_analysis() {
        with_sample_validator(|_, validator| {
            let analysis = validator.analyze_reference_patterns();
            assert!(analysis.total_tokens > 0);

            let stats_string = validator.get_reference_statistics().to_string();
            assert!(stats_string.contains("tokens"));
            assert!(stats_string.contains("references"));
        });
    }

    #[test]
    fn test_forward_reference_validation() {
        with_sample_validator(|scanner, validator| {
            if scanner.table_row_count(TableId::TypeDef) > 0 {
                let outcome = validator.validate_forward_references(Token::new(0x02000001));
                assert!(outcome.is_ok());
            }
        });
    }
}