use crate::{
metadata::{
tables::TableId,
validation::{
scanner::{HeapSizes, ReferenceScanner},
ScannerStatistics,
},
},
Error, Result,
};
/// Runs schema-level checks (table presence, RID ranges, coded indices and
/// heap indices) against the data exposed by a [`ReferenceScanner`].
///
/// The validator only borrows the scanner; every check reads row counts or
/// heap sizes from it.
pub struct SchemaValidator<'a> {
/// Scanner queried for table row counts, heap sizes and heap index checks.
scanner: &'a ReferenceScanner,
}
impl<'a> SchemaValidator<'a> {
#[must_use]
pub fn new(scanner: &'a ReferenceScanner) -> Self {
Self { scanner }
}
pub fn validate_basic_structure(&self, tables: &crate::TablesHeader) -> Result<()> {
if self.scanner.table_row_count(TableId::Module) == 0 {
return Err(Error::ValidationOwnedFailed {
validator: "SchemaValidator".to_string(),
message: "Module table is required but empty".to_string(),
});
}
if self.scanner.table_row_count(TableId::Module) > 1 {
return Err(Error::ValidationOwnedFailed {
validator: "SchemaValidator".to_string(),
message: "Module table must contain exactly one row".to_string(),
});
}
self.validate_table_consistency(tables)?;
Ok(())
}
fn validate_table_consistency(&self, _tables: &crate::TablesHeader) -> Result<()> {
let typedef_count = self.scanner.table_row_count(TableId::TypeDef);
if typedef_count > 0 {
let assembly_count = self.scanner.table_row_count(TableId::Assembly);
let assemblyref_count = self.scanner.table_row_count(TableId::AssemblyRef);
if assembly_count == 0 && assemblyref_count == 0 {
return Err(Error::ValidationOwnedFailed {
validator: "SchemaValidator".to_string(),
message: "TypeDef tables require Assembly or AssemblyRef table".to_string(),
});
}
}
self.validate_field_map_consistency()?;
self.validate_method_map_consistency()?;
Ok(())
}
fn validate_field_map_consistency(&self) -> Result<()> {
let typedef_count = self.scanner.table_row_count(TableId::TypeDef);
let field_count = self.scanner.table_row_count(TableId::Field);
if field_count > 0 && typedef_count == 0 {
return Err(Error::ValidationOwnedFailed {
validator: "SchemaValidator".to_string(),
message: "Field table requires TypeDef table".to_string(),
});
}
Ok(())
}
fn validate_method_map_consistency(&self) -> Result<()> {
let typedef_count = self.scanner.table_row_count(TableId::TypeDef);
let methoddef_count = self.scanner.table_row_count(TableId::MethodDef);
if methoddef_count > 0 && typedef_count == 0 {
return Err(Error::ValidationOwnedFailed {
validator: "SchemaValidator".to_string(),
message: "MethodDef table requires TypeDef table".to_string(),
});
}
Ok(())
}
/// Validates a single index into the heap identified by `heap_type`.
///
/// The check is delegated unchanged to the scanner's `validate_heap_index`;
/// which heap names are accepted is decided there.
///
/// # Errors
///
/// Returns whatever error the scanner reports for this heap/index pair.
pub fn validate_heap_reference(&self, heap_type: &str, index: u32) -> Result<()> {
self.scanner.validate_heap_index(heap_type, index)
}
/// Validates every index yielded by `indices` against the heap named `heap_type`.
///
/// Indices are checked in iteration order and checking stops at the first
/// failure. An empty sequence is accepted.
///
/// # Errors
///
/// Returns the error produced by [`Self::validate_heap_reference`] for the
/// first index that fails.
pub fn validate_heap_references<I>(&self, heap_type: &str, indices: I) -> Result<()>
where
    I: IntoIterator<Item = u32>,
{
    indices
        .into_iter()
        .try_for_each(|index| self.validate_heap_reference(heap_type, index))
}
pub fn validate_rid(&self, table_id: TableId, rid: u32) -> Result<()> {
if rid == 0 {
return Err(Error::InvalidRid {
table: table_id,
rid,
});
}
let max_rid = self.scanner.table_row_count(table_id);
if rid > max_rid {
return Err(Error::InvalidRid {
table: table_id,
rid,
});
}
Ok(())
}
pub fn validate_coded_index(&self, coded_index: u32, allowed_tables: &[TableId]) -> Result<()> {
if coded_index == 0 {
return Ok(());
}
let table_bits = allowed_tables.len().next_power_of_two().trailing_zeros();
let table_index = coded_index & ((1 << table_bits) - 1);
let rid = coded_index >> table_bits;
if (table_index as usize) >= allowed_tables.len() {
return Err(Error::InvalidToken {
token: crate::metadata::token::Token::new(coded_index),
message: format!("Table index {table_index} not in allowed range"),
});
}
let table_id = allowed_tables[table_index as usize];
self.validate_rid(table_id, rid)
}
/// Validates `index` against the heap addressed by the name `"strings"`.
///
/// # Errors
///
/// Forwards the error returned by [`Self::validate_heap_reference`].
pub fn validate_string_index(&self, index: u32) -> Result<()> {
self.validate_heap_reference("strings", index)
}
/// Validates `index` against the heap addressed by the name `"blobs"`.
///
/// # Errors
///
/// Forwards the error returned by [`Self::validate_heap_reference`].
pub fn validate_blob_index(&self, index: u32) -> Result<()> {
self.validate_heap_reference("blobs", index)
}
pub fn validate_guid_index(&self, index: u32) -> Result<()> {
if index == 0 {
return Ok(());
}
let guid_heap_size = self.scanner.heap_sizes().guids;
let max_index = guid_heap_size / 16;
if index > max_index {
return Err(Error::HeapBoundsError {
heap: "guids".to_string(),
index,
});
}
Ok(())
}
/// Validates `index` against the heap addressed by the name `"userstrings"`.
///
/// # Errors
///
/// Forwards the error returned by [`Self::validate_heap_reference`].
pub fn validate_user_string_index(&self, index: u32) -> Result<()> {
self.validate_heap_reference("userstrings", index)
}
/// Collects a snapshot of the scanner's counters.
///
/// The snapshot holds the scanner's non-empty table count, total row count,
/// a clone of its heap sizes and its own statistics value.
#[must_use]
pub fn get_validation_statistics(&self) -> SchemaValidationStatistics {
    let scanner = self.scanner;

    let total_tables = scanner.count_non_empty_tables();
    let total_rows = scanner.count_total_rows();
    let heap_sizes = scanner.heap_sizes().clone();
    let scanner_stats = scanner.statistics();

    SchemaValidationStatistics {
        total_tables,
        total_rows,
        heap_sizes,
        scanner_stats,
    }
}
}
/// Snapshot of the counters a [`SchemaValidator`] reads from its scanner.
///
/// Produced by `SchemaValidator::get_validation_statistics`; its `Display`
/// output is a one-line summary of the table, row and heap figures.
#[derive(Debug, Clone)]
pub struct SchemaValidationStatistics {
/// Table count, filled from the scanner's `count_non_empty_tables`.
pub total_tables: usize,
/// Row count, filled from the scanner's `count_total_rows`.
pub total_rows: u32,
/// Copy of the scanner's heap sizes (strings, blobs, guids, userstrings).
pub heap_sizes: HeapSizes,
/// Statistics value reported by the scanner itself.
pub scanner_stats: ScannerStatistics,
}
/// Formats the statistics as a single summary line listing the table count,
/// the total row count and the four heap sizes. `scanner_stats` is not part
/// of the output.
impl std::fmt::Display for SchemaValidationStatistics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let heaps = &self.heap_sizes;
        f.write_fmt(format_args!(
            "Schema Statistics: {} tables, {} total rows, Heaps(strings: {}, blobs: {}, guids: {}, userstrings: {})",
            self.total_tables,
            self.total_rows,
            heaps.strings,
            heaps.blobs,
            heaps.guids,
            heaps.userstrings
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::cilassemblyview::CilAssemblyView;
    use std::path::PathBuf;

    /// Loads the `WindowsBase.dll` sample, builds a scanner for it and hands
    /// both to `check`.
    ///
    /// When the sample cannot be loaded or scanned, `check` is never invoked,
    /// so the calling test passes without asserting anything.
    fn with_sample(check: impl FnOnce(&CilAssemblyView, &ReferenceScanner)) {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        if let Ok(view) = CilAssemblyView::from_path(&path) {
            if let Ok(scanner) = ReferenceScanner::from_view(&view) {
                check(&view, &scanner);
            }
        }
    }

    /// A validator built over the sample reports non-zero table and row counts.
    #[test]
    fn test_schema_validator_creation() {
        with_sample(|_view, scanner| {
            let stats = SchemaValidator::new(scanner).get_validation_statistics();
            assert!(stats.total_tables > 0);
            assert!(stats.total_rows > 0);
        });
    }

    /// The sample passes the basic structure checks when it exposes a tables header.
    #[test]
    fn test_basic_structure_validation() {
        with_sample(|view, scanner| {
            let validator = SchemaValidator::new(scanner);
            if let Some(tables) = view.tables() {
                assert!(validator.validate_basic_structure(tables).is_ok());
            }
        });
    }

    /// RID 0 and `row_count + 1` are rejected, RID 1 is accepted for a non-empty table.
    #[test]
    fn test_rid_validation() {
        with_sample(|_view, scanner| {
            let validator = SchemaValidator::new(scanner);
            assert!(validator.validate_rid(TableId::TypeDef, 0).is_err());

            let max_rid = scanner.table_row_count(TableId::TypeDef);
            if max_rid > 0 {
                assert!(validator.validate_rid(TableId::TypeDef, 1).is_ok());
                assert!(validator
                    .validate_rid(TableId::TypeDef, max_rid + 1)
                    .is_err());
            }
        });
    }

    /// Index 0 is accepted for every heap kind.
    #[test]
    fn test_heap_reference_validation() {
        with_sample(|_view, scanner| {
            let validator = SchemaValidator::new(scanner);
            assert!(validator.validate_string_index(0).is_ok());
            assert!(validator.validate_blob_index(0).is_ok());
            assert!(validator.validate_guid_index(0).is_ok());
            assert!(validator.validate_user_string_index(0).is_ok());
        });
    }

    /// The `Display` output mentions tables, rows and heaps.
    #[test]
    fn test_validation_statistics() {
        with_sample(|_view, scanner| {
            let summary = SchemaValidator::new(scanner)
                .get_validation_statistics()
                .to_string();
            for needle in ["tables", "rows", "Heaps"] {
                assert!(summary.contains(needle));
            }
        });
    }

    /// A coded index of 0 is accepted regardless of the allowed tables.
    #[test]
    fn test_coded_index_validation() {
        with_sample(|_view, scanner| {
            let validator = SchemaValidator::new(scanner);
            let allowed_tables = &[TableId::TypeDef, TableId::TypeRef, TableId::TypeSpec];
            assert!(validator.validate_coded_index(0, allowed_tables).is_ok());
        });
    }
}