use crate::{
metadata::{tables::TableId, token::Token, validation::scanner::ReferenceScanner},
Error, Result,
};
/// Token-level validation helper backed by a borrowed `ReferenceScanner`.
///
/// The validator owns no metadata itself: bounds checks, row counts and
/// reference look-ups are all answered by the scanner it was created with,
/// so it can live at most as long as that scanner (`'a`).
pub struct TokenValidator<'a> {
/// Scanner that bounds checks and reference queries are forwarded to.
scanner: &'a ReferenceScanner,
}
impl<'a> TokenValidator<'a> {
    /// Creates a validator that answers every query through `scanner`.
    #[must_use]
    pub fn new(scanner: &'a ReferenceScanner) -> Self {
        Self { scanner }
    }

    /// Checks `token` against the bounds known to the underlying scanner.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the scanner's `validate_token_bounds`.
    pub fn validate_token_bounds(&self, token: Token) -> Result<()> {
        self.scanner.validate_token_bounds(token)
    }

    /// Reports whether the underlying scanner considers `token` to exist.
    #[must_use]
    pub fn token_exists(&self, token: Token) -> bool {
        self.scanner.token_exists(token)
    }

    /// Bounds-checks every token yielded by `tokens`, in iteration order.
    ///
    /// An empty collection is valid.
    ///
    /// # Errors
    ///
    /// Stops at the first token that fails [`Self::validate_token_bounds`] and
    /// returns its error; later tokens are not inspected.
    pub fn validate_token_collection<I>(&self, tokens: I) -> Result<()>
    where
        I: IntoIterator<Item = Token>,
    {
        tokens
            .into_iter()
            .try_for_each(|token| self.validate_token_bounds(token))
    }

    /// Checks that `token` belongs to `expected_table` and is within bounds.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidToken`] when the token's table differs from
    /// `expected_table`; otherwise returns the outcome of the bounds check.
    pub fn validate_token_table_type(&self, token: Token, expected_table: TableId) -> Result<()> {
        let token_table_value = token.table();
        let expected_table_value = expected_table.token_type();

        // Matching table: the only remaining question is whether the row is in range.
        if token_table_value == expected_table_value {
            return self.validate_token_bounds(token);
        }

        Err(Error::InvalidToken {
            token,
            message: format!(
                "Token belongs to table {token_table_value:#x}, expected table {expected_table_value:#x}"
            ),
        })
    }

    /// Asks the underlying scanner whether `token` may be deleted.
    #[must_use]
    pub fn can_delete_token(&self, token: Token) -> bool {
        self.scanner.can_delete_token(token)
    }

    /// Returns the scanner's set of tokens referencing `token`, if it has one.
    #[must_use]
    pub fn references_to(&self, token: Token) -> Option<&rustc_hash::FxHashSet<Token>> {
        self.scanner.references_to(token)
    }

    /// Returns the scanner's set of tokens referenced by `token`, if it has one.
    #[must_use]
    pub fn references_from(&self, token: Token) -> Option<&rustc_hash::FxHashSet<Token>> {
        self.scanner.references_from(token)
    }

    /// Reports whether the scanner knows of any reference to `token`.
    #[must_use]
    pub fn has_references_to(&self, token: Token) -> bool {
        self.scanner.has_references_to(token)
    }

    /// Reports whether the scanner knows of any reference made by `token`.
    #[must_use]
    pub fn has_references_from(&self, token: Token) -> bool {
        self.scanner.has_references_from(token)
    }

    /// Validates a token that may legitimately be null.
    ///
    /// A token whose raw value is `0` is treated as null; any other token is
    /// bounds-checked like a regular one.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidRid`] (table `Module`, rid `0`) for a null token
    /// when `nullable` is `false`, or the bounds-check error for a non-null token.
    pub fn validate_null_token(&self, token: Token, nullable: bool) -> Result<()> {
        if token.value() != 0 {
            return self.validate_token_bounds(token);
        }
        if nullable {
            return Ok(());
        }
        Err(Error::InvalidRid {
            table: TableId::Module,
            rid: 0,
        })
    }

    /// Checks that `token` belongs to one of `allowed_tables` and is within bounds.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidToken`] when the token's table matches none of
    /// `allowed_tables` (including when the slice is empty); otherwise returns
    /// the outcome of the bounds check.
    pub fn validate_typed_token(&self, token: Token, allowed_tables: &[TableId]) -> Result<()> {
        let token_table_value = token.table();
        let table_is_allowed = allowed_tables
            .iter()
            .copied()
            .any(|allowed_table| allowed_table.token_type() == token_table_value);

        if table_is_allowed {
            self.validate_token_bounds(token)
        } else {
            Err(Error::InvalidToken {
                token,
                message: format!("Table type {token_table_value:#x} not in allowed tables"),
            })
        }
    }

    /// Returns the row count the underlying scanner reports for `table_id`.
    #[must_use]
    pub fn table_row_count(&self, table_id: TableId) -> u32 {
        self.scanner.table_row_count(table_id)
    }

    /// Wraps the raw `token_value` in a [`Token`] and bounds-checks it.
    ///
    /// # Errors
    ///
    /// Propagates the error of [`Self::validate_token_bounds`].
    pub fn validate_token_value(&self, token_value: u32) -> Result<()> {
        self.validate_token_bounds(Token::new(token_value))
    }

    /// Checks that `rid` is a valid 1-based row index into `table_id`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidRid`] when `rid` is `0` or exceeds the row count
    /// reported for the table.
    pub fn validate_table_row(&self, table_id: TableId, rid: u32) -> Result<()> {
        // `&&` short-circuits, so the row count is only queried for a non-zero rid.
        let rid_in_range = rid != 0 && rid <= self.scanner.table_row_count(table_id);
        if rid_in_range {
            Ok(())
        } else {
            Err(Error::InvalidRid {
                table: table_id,
                rid,
            })
        }
    }

    /// Bounds-checks every raw token value yielded by `token_values`, in order.
    ///
    /// An empty collection is valid.
    ///
    /// # Errors
    ///
    /// Stops at the first value that fails [`Self::validate_token_value`] and
    /// returns its error; later values are not inspected.
    pub fn validate_token_values<I>(&self, token_values: I) -> Result<()>
    where
        I: IntoIterator<Item = u32>,
    {
        token_values
            .into_iter()
            .try_for_each(|token_value| self.validate_token_value(token_value))
    }
}
/// Accumulator for the outcomes of several validation calls.
///
/// Successful results are discarded; errors are kept in the order they were
/// added so that they can be inspected or collapsed into a single `Result`.
#[derive(Debug, Default)]
pub struct TokenValidationResult {
/// Errors recorded so far, in insertion order.
errors: Vec<Error>,
}
impl TokenValidationResult {
    /// Creates an accumulator with no recorded errors.
    pub fn new() -> Self {
        Self { errors: Vec::new() }
    }

    /// Records `result` if it is an error; `Ok` results are ignored.
    pub fn add_result(&mut self, result: Result<()>) {
        // `Result::err` yields `Some(error)` only for failures, so this appends at most one item.
        self.errors.extend(result.err());
    }

    /// Returns `true` when at least one error has been recorded.
    pub fn has_errors(&self) -> bool {
        self.error_count() != 0
    }

    /// Returns the number of recorded errors.
    pub fn error_count(&self) -> usize {
        self.errors.len()
    }

    /// Consumes the accumulator and collapses it into a single `Result`.
    ///
    /// # Errors
    ///
    /// Returns the first recorded error, if any; the remaining errors are dropped.
    pub fn into_result(self) -> Result<()> {
        self.errors.into_iter().next().map_or(Ok(()), Err)
    }

    /// Returns all recorded errors in the order they were added.
    pub fn errors(&self) -> &[Error] {
        self.errors.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::cilassemblyview::CilAssemblyView;
    use std::path::PathBuf;

    /// Raw value of the token used as "first TypeDef row" throughout these tests.
    const FIRST_TYPEDEF: u32 = 0x0200_0001;

    /// Runs `check` against a validator built from the `WindowsBase.dll` sample.
    ///
    /// If the sample cannot be loaded or scanned, `check` is not invoked and the
    /// calling test passes without asserting anything.
    fn with_validator(check: impl FnOnce(&TokenValidator<'_>)) {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll");
        let view = match CilAssemblyView::from_path(&path) {
            Ok(view) => view,
            Err(_) => return,
        };
        let scanner = match ReferenceScanner::from_view(&view) {
            Ok(scanner) => scanner,
            Err(_) => return,
        };
        check(&TokenValidator::new(&scanner));
    }

    /// Runs `check` only when the sample has at least one TypeDef row.
    fn with_typedef_validator(check: impl FnOnce(&TokenValidator<'_>)) {
        with_validator(|validator| {
            if validator.table_row_count(TableId::TypeDef) > 0 {
                check(validator);
            }
        });
    }

    #[test]
    fn test_token_validator_creation() {
        with_validator(|validator| {
            // Smoke test: construction and a simple query must not panic.
            let _count = validator.table_row_count(TableId::TypeDef);
        });
    }

    #[test]
    fn test_token_bounds_validation() {
        with_validator(|validator| {
            // Same table byte as `FIRST_TYPEDEF`, but row 0.
            let invalid_token = Token::new(0x0200_0000);
            assert!(validator.validate_token_bounds(invalid_token).is_err());
        });
        with_typedef_validator(|validator| {
            let valid_token = Token::new(FIRST_TYPEDEF);
            assert!(validator.validate_token_bounds(valid_token).is_ok());
        });
    }

    #[test]
    fn test_token_table_type_validation() {
        with_typedef_validator(|validator| {
            let typedef_token = Token::new(FIRST_TYPEDEF);
            let matching = validator.validate_token_table_type(typedef_token, TableId::TypeDef);
            assert!(matching.is_ok());
            let mismatching = validator.validate_token_table_type(typedef_token, TableId::MethodDef);
            assert!(mismatching.is_err());
        });
    }

    #[test]
    fn test_token_validation_result() {
        // A fresh accumulator reports success.
        let empty = TokenValidationResult::new();
        assert!(!empty.has_errors());
        assert_eq!(empty.error_count(), 0);
        assert!(empty.into_result().is_ok());

        // Only the `Err` entry is recorded.
        let mut collected = TokenValidationResult::new();
        collected.add_result(Ok(()));
        collected.add_result(Err(Error::InvalidRid {
            table: TableId::TypeDef,
            rid: 0,
        }));
        assert!(collected.has_errors());
        assert_eq!(collected.error_count(), 1);
        assert!(collected.into_result().is_err());
    }

    #[test]
    fn test_null_token_validation() {
        with_validator(|validator| {
            let null_token = Token::new(0);
            assert!(validator.validate_null_token(null_token, true).is_ok());
            assert!(validator.validate_null_token(null_token, false).is_err());
        });
    }

    #[test]
    fn test_typed_token_validation() {
        with_typedef_validator(|validator| {
            let typedef_token = Token::new(FIRST_TYPEDEF);

            let allowed_tables = &[TableId::TypeDef, TableId::TypeRef, TableId::TypeSpec];
            let accepted = validator.validate_typed_token(typedef_token, allowed_tables);
            assert!(accepted.is_ok());

            let not_allowed = &[TableId::MethodDef, TableId::Field];
            let rejected = validator.validate_typed_token(typedef_token, not_allowed);
            assert!(rejected.is_err());
        });
    }
}