use std::collections::HashSet;
use crate::{
assembly::{decode_blocks, Operand},
cilassembly::{modifications::TableModifications, operation::Operation, CilAssembly},
metadata::{
method::MethodBody,
signatures::{
parse_field_signature, parse_local_var_signature, parse_method_signature,
parse_property_signature, parse_type_spec_signature, CustomModifier,
SignatureLocalVariable, SignatureParameter, TypeSignature,
},
streams::{Blob, TablesHeader},
tables::{
CustomAttributeRaw, FieldRaw, GenericParamConstraintRaw, InterfaceImplRaw,
MemberRefRaw, MethodDefRaw, MethodSpecRaw, PropertyRaw, StandAloneSigRaw,
TableDataOwned, TableId, TypeDefRaw, TypeRefRaw, TypeSpecRaw,
},
token::Token,
},
};
/// Boundary between RVAs that are resolved against the original file and RVAs that are
/// treated as placeholders.
///
/// Within this module, an effective method RVA at or above this value is looked up via
/// `assembly.changes().get_method_body(rva)`; anything below it is translated with
/// `file.rva_to_offset` and read from the original image bytes.
const PLACEHOLDER_RVA_THRESHOLD: u32 = 0xF000_0000;
/// Reports whether row `rid` of `table_id` is gone according to the pending changes.
///
/// * Sparse modifications: the most recently recorded operation for `rid` decides; the
///   row counts as deleted only when that operation is a delete. If no operation
///   mentions `rid`, the row is not deleted.
/// * Replaced table: every `rid` greater than the number of replacement rows counts as
///   deleted.
/// * No modifications recorded for the table: the row is not deleted.
fn is_row_deleted(assembly: &CilAssembly, table_id: TableId, rid: u32) -> bool {
    match assembly.changes().table_changes.get(&table_id) {
        Some(TableModifications::Sparse { operations, .. }) => operations
            .iter()
            .rev()
            // Walking backwards makes the first hit the latest operation for this row.
            .find(|op| op.get_rid() == rid)
            .is_some_and(|op| matches!(op.operation, Operation::Delete(_))),
        Some(TableModifications::Replaced(rows)) => rid as usize > rows.len(),
        None => false,
    }
}
/// Resolves the RVA that currently applies to MethodDef row `rid`.
///
/// The pending MethodDef changes are consulted first:
///
/// * Sparse modifications: operations are walked newest-first. The first `Update`
///   carrying MethodDef data for `rid` supplies the RVA; the first `Delete` for `rid`
///   yields `0`. Any other operation for `rid` is skipped and the walk continues.
/// * Replaced table: the RVA of replacement row `rid - 1` is used when that row exists
///   and holds MethodDef data.
///
/// When none of the above produces a value, `original_rva` is returned unchanged.
///
/// A `rid` of `0` is not a valid 1-based row index; it is handled by falling back to
/// `original_rva` instead of underflowing the `rid - 1` subtraction.
fn get_effective_method_rva(assembly: &CilAssembly, rid: u32, original_rva: u32) -> u32 {
    match assembly.changes().table_changes.get(&TableId::MethodDef) {
        Some(TableModifications::Sparse { operations, .. }) => {
            for op in operations.iter().rev() {
                if op.get_rid() != rid {
                    continue;
                }
                match &op.operation {
                    Operation::Update(_, TableDataOwned::MethodDef(updated)) => {
                        return updated.rva;
                    }
                    Operation::Delete(_) => return 0,
                    // Not conclusive for the RVA; keep looking at older operations.
                    _ => {}
                }
            }
        }
        Some(TableModifications::Replaced(rows)) => {
            // `checked_sub` guards against `rid == 0`, which would otherwise panic in
            // builds with overflow checks enabled.
            let replacement = rid
                .checked_sub(1)
                .and_then(|index| rows.get(index as usize));
            if let Some(TableDataOwned::MethodDef(row)) = replacement {
                return row.rva;
            }
        }
        None => {}
    }
    original_rva
}
/// Parses a method body from the start of `data`, decodes its code bytes, and adds
/// every token operand found in the decoded instructions to `referenced`.
///
/// `base_rva` is the RVA of the first byte of `data`; the code RVA handed to the
/// decoder is `base_rva` plus the parsed header size.
///
/// The function is best-effort: if the body cannot be parsed, the declared code range
/// extends past `data`, or decoding fails, it returns without touching `referenced`.
fn scan_method_body_bytes(data: &[u8], base_rva: usize, referenced: &mut HashSet<Token>) {
    let body = match MethodBody::from(data) {
        Ok(parsed) => parsed,
        Err(_) => return,
    };

    let header_len = body.size_header;
    let code_end = header_len + body.size_code;
    if code_end > data.len() {
        return;
    }

    let code = &data[header_len..code_end];
    let Ok(blocks) = decode_blocks(code, 0, base_rva + header_len, None) else {
        return;
    };

    for instr in blocks.iter().flat_map(|block| &block.instructions) {
        if let Operand::Token(token) = instr.operand {
            referenced.insert(token);
        }
    }
}
fn collect_referenced_standalonesig_rids(assembly: &CilAssembly) -> HashSet<u32> {
let mut referenced = HashSet::new();
let view = assembly.view();
let Some(tables) = view.tables() else {
return referenced;
};
let Some(methoddef_table) = tables.table::<MethodDefRaw>() else {
return referenced;
};
let file = view.file();
let original_data = file.data();
for methoddef in methoddef_table {
let effective_rva = get_effective_method_rva(assembly, methoddef.rid, methoddef.rva);
if effective_rva == 0 {
continue;
}
if effective_rva >= PLACEHOLDER_RVA_THRESHOLD {
if let Some(body_bytes) = assembly.changes().get_method_body(effective_rva) {
if let Ok(body) = MethodBody::from(body_bytes.as_slice()) {
if body.local_var_sig_token != 0 {
let sig_token = Token::new(body.local_var_sig_token);
if sig_token.is_table(TableId::StandAloneSig) {
referenced.insert(sig_token.row());
}
}
}
}
} else {
let Ok(offset) = file.rva_to_offset(effective_rva as usize) else {
continue;
};
if offset < original_data.len() {
if let Ok(body) = MethodBody::from(&original_data[offset..]) {
if body.local_var_sig_token != 0 {
let sig_token = Token::new(body.local_var_sig_token);
if sig_token.is_table(TableId::StandAloneSig) {
referenced.insert(sig_token.row());
}
}
}
}
}
}
referenced
}
pub fn scan_method_body_tokens(assembly: &CilAssembly) -> HashSet<Token> {
let mut referenced = HashSet::new();
let view = assembly.view();
let Some(tables) = view.tables() else {
return referenced;
};
let Some(methoddef_table) = tables.table::<MethodDefRaw>() else {
return referenced;
};
let file = view.file();
let original_data = file.data();
for methoddef in methoddef_table {
let effective_rva = get_effective_method_rva(assembly, methoddef.rid, methoddef.rva);
if effective_rva == 0 {
continue;
}
if effective_rva >= PLACEHOLDER_RVA_THRESHOLD {
if let Some(body_bytes) = assembly.changes().get_method_body(effective_rva) {
scan_method_body_bytes(body_bytes, 0, &mut referenced);
}
} else {
let Ok(offset) = file.rva_to_offset(effective_rva as usize) else {
continue;
};
if offset >= original_data.len() {
continue;
}
scan_method_body_bytes(
&original_data[offset..],
effective_rva as usize,
&mut referenced,
);
}
}
referenced
}
/// Collects TypeRef row ids that are referenced directly from metadata tables.
///
/// Sources inspected, in order:
/// 1. `TypeDef.extends`
/// 2. `InterfaceImpl.interface`
/// 3. `MemberRef.class` (rows reported as deleted are ignored)
/// 4. `GenericParamConstraint.constraint`
/// 5. `CustomAttribute.constructor` when it names a non-deleted MemberRef whose
///    `class` is a TypeRef
///
/// Returns an empty set when the tables header is absent; missing individual tables
/// simply contribute nothing.
pub fn scan_typeref_metadata_refs(assembly: &CilAssembly) -> HashSet<u32> {
    let mut referenced_rids = HashSet::new();
    let view = assembly.view();
    let Some(tables) = view.tables() else {
        return referenced_rids;
    };

    if let Some(typedef_table) = tables.table::<TypeDefRaw>() {
        referenced_rids.extend(
            typedef_table
                .into_iter()
                .filter(|row| row.extends.token.is_table(TableId::TypeRef))
                .map(|row| row.extends.token.row()),
        );
    }

    if let Some(interfaceimpl_table) = tables.table::<InterfaceImplRaw>() {
        referenced_rids.extend(
            interfaceimpl_table
                .into_iter()
                .filter(|row| row.interface.token.is_table(TableId::TypeRef))
                .map(|row| row.interface.token.row()),
        );
    }

    if let Some(memberref_table) = tables.table::<MemberRefRaw>() {
        referenced_rids.extend(
            memberref_table
                .into_iter()
                .filter(|row| {
                    !is_row_deleted(assembly, TableId::MemberRef, row.rid)
                        && row.class.token.is_table(TableId::TypeRef)
                })
                .map(|row| row.class.token.row()),
        );
    }

    if let Some(constraint_table) = tables.table::<GenericParamConstraintRaw>() {
        referenced_rids.extend(
            constraint_table
                .into_iter()
                .filter(|row| row.constraint.token.is_table(TableId::TypeRef))
                .map(|row| row.constraint.token.row()),
        );
    }

    // Attribute constructors reach a TypeRef indirectly, through the MemberRef they name.
    if let (Some(attr_table), Some(memberref_table)) = (
        tables.table::<CustomAttributeRaw>(),
        tables.table::<MemberRefRaw>(),
    ) {
        for attr in attr_table {
            let ctor = &attr.constructor.token;
            if !ctor.is_table(TableId::MemberRef) {
                continue;
            }
            let memberref_rid = ctor.row();
            if is_row_deleted(assembly, TableId::MemberRef, memberref_rid) {
                continue;
            }
            let Some(memberref) = memberref_table.get(memberref_rid) else {
                continue;
            };
            if memberref.class.token.is_table(TableId::TypeRef) {
                referenced_rids.insert(memberref.class.token.row());
            }
        }
    }

    referenced_rids
}
/// Collects MemberRef row ids that are referenced directly from metadata tables.
///
/// Two sources are inspected: `CustomAttribute.constructor` and `MethodSpec.method`.
/// Returns an empty set when the tables header is absent.
pub fn scan_memberref_metadata_refs(assembly: &CilAssembly) -> HashSet<u32> {
    let mut referenced_rids = HashSet::new();
    let view = assembly.view();
    let Some(tables) = view.tables() else {
        return referenced_rids;
    };

    if let Some(attr_table) = tables.table::<CustomAttributeRaw>() {
        referenced_rids.extend(
            attr_table
                .into_iter()
                .filter(|row| row.constructor.token.is_table(TableId::MemberRef))
                .map(|row| row.constructor.token.row()),
        );
    }

    if let Some(methodspec_table) = tables.table::<MethodSpecRaw>() {
        referenced_rids.extend(
            methodspec_table
                .into_iter()
                .filter(|row| row.method.token.is_table(TableId::MemberRef))
                .map(|row| row.method.token.row()),
        );
    }

    referenced_rids
}
/// Collects TypeSpec row ids that are referenced directly from metadata tables.
///
/// Sources inspected, in order: `MemberRef.class`, `InterfaceImpl.interface`,
/// `GenericParamConstraint.constraint` and `TypeDef.extends`.
/// Returns an empty set when the tables header is absent.
pub fn scan_typespec_metadata_refs(assembly: &CilAssembly) -> HashSet<u32> {
    let mut referenced_rids = HashSet::new();
    let view = assembly.view();
    let Some(tables) = view.tables() else {
        return referenced_rids;
    };

    if let Some(memberref_table) = tables.table::<MemberRefRaw>() {
        referenced_rids.extend(
            memberref_table
                .into_iter()
                .filter(|row| row.class.token.is_table(TableId::TypeSpec))
                .map(|row| row.class.token.row()),
        );
    }

    if let Some(interfaceimpl_table) = tables.table::<InterfaceImplRaw>() {
        referenced_rids.extend(
            interfaceimpl_table
                .into_iter()
                .filter(|row| row.interface.token.is_table(TableId::TypeSpec))
                .map(|row| row.interface.token.row()),
        );
    }

    if let Some(constraint_table) = tables.table::<GenericParamConstraintRaw>() {
        referenced_rids.extend(
            constraint_table
                .into_iter()
                .filter(|row| row.constraint.token.is_table(TableId::TypeSpec))
                .map(|row| row.constraint.token.row()),
        );
    }

    if let Some(typedef_table) = tables.table::<TypeDefRaw>() {
        referenced_rids.extend(
            typedef_table
                .into_iter()
                .filter(|row| row.extends.token.is_table(TableId::TypeSpec))
                .map(|row| row.extends.token.row()),
        );
    }

    referenced_rids
}
/// Recursively walks a type signature and records the row id of every TypeRef token
/// it reaches in `referenced`.
///
/// The match is deliberately exhaustive: variants that this function does not descend
/// into are listed explicitly in the final no-op arm.
fn collect_typerefs_from_type(sig: &TypeSignature, referenced: &mut HashSet<u32>) {
    match sig {
        // Leaves that carry a type token directly.
        TypeSignature::Class(type_token) | TypeSignature::ValueType(type_token) => {
            if type_token.is_table(TableId::TypeRef) {
                referenced.insert(type_token.row());
            }
        }

        // Wrappers around a single inner type.
        TypeSignature::ByRef(element) | TypeSignature::Pinned(element) => {
            collect_typerefs_from_type(element, referenced);
        }
        TypeSignature::Array(array) => {
            collect_typerefs_from_type(&array.base, referenced);
        }

        // Wrappers that also carry custom modifiers.
        TypeSignature::SzArray(vector) => {
            for custom_mod in &vector.modifiers {
                collect_typerefs_from_modifier(custom_mod, referenced);
            }
            collect_typerefs_from_type(&vector.base, referenced);
        }
        TypeSignature::Ptr(pointer) => {
            for custom_mod in &pointer.modifiers {
                collect_typerefs_from_modifier(custom_mod, referenced);
            }
            collect_typerefs_from_type(&pointer.base, referenced);
        }

        // Bare modifier lists.
        TypeSignature::ModifiedRequired(custom_mods)
        | TypeSignature::ModifiedOptional(custom_mods) => {
            for custom_mod in custom_mods {
                collect_typerefs_from_modifier(custom_mod, referenced);
            }
        }

        // Generic instantiation: the generic type itself, then each type argument.
        TypeSignature::GenericInst(generic_type, type_args) => {
            collect_typerefs_from_type(generic_type, referenced);
            for type_arg in type_args {
                collect_typerefs_from_type(type_arg, referenced);
            }
        }

        // Function pointer: return type first, then every parameter.
        TypeSignature::FnPtr(callee) => {
            collect_typerefs_from_parameter(&callee.return_type, referenced);
            for callee_param in &callee.params {
                collect_typerefs_from_parameter(callee_param, referenced);
            }
        }

        // Nothing is collected from these variants.
        TypeSignature::Void
        | TypeSignature::Boolean
        | TypeSignature::Char
        | TypeSignature::I1
        | TypeSignature::U1
        | TypeSignature::I2
        | TypeSignature::U2
        | TypeSignature::I4
        | TypeSignature::U4
        | TypeSignature::I8
        | TypeSignature::U8
        | TypeSignature::R4
        | TypeSignature::R8
        | TypeSignature::I
        | TypeSignature::U
        | TypeSignature::String
        | TypeSignature::Object
        | TypeSignature::TypedByRef
        | TypeSignature::GenericParamType(_)
        | TypeSignature::GenericParamMethod(_)
        | TypeSignature::Sentinel
        | TypeSignature::Internal
        | TypeSignature::Unknown
        | TypeSignature::Type
        | TypeSignature::Boxed
        | TypeSignature::Field
        | TypeSignature::Modifier
        | TypeSignature::Reserved => {}
    }
}
/// Records the row id of a custom modifier's type in `referenced` when that type
/// token points into the TypeRef table; otherwise does nothing.
fn collect_typerefs_from_modifier(modifier: &CustomModifier, referenced: &mut HashSet<u32>) {
    let modifier_token = &modifier.modifier_type;
    if !modifier_token.is_table(TableId::TypeRef) {
        return;
    }
    referenced.insert(modifier_token.row());
}
/// Records every TypeRef row id reachable from a signature parameter: first from each
/// of its custom modifiers, then from its base type.
fn collect_typerefs_from_parameter(param: &SignatureParameter, referenced: &mut HashSet<u32>) {
    for modifier in &param.modifiers {
        collect_typerefs_from_modifier(modifier, referenced);
    }
    collect_typerefs_from_type(&param.base, referenced);
}
/// Records every TypeRef row id reachable from a local-variable signature entry:
/// its custom modifiers are visited first, followed by its base type.
fn collect_typerefs_from_local(local: &SignatureLocalVariable, referenced: &mut HashSet<u32>) {
    for custom_mod in &local.modifiers {
        collect_typerefs_from_modifier(custom_mod, referenced);
    }

    let local_type = &local.base;
    collect_typerefs_from_type(local_type, referenced);
}
/// Collects TypeRef row ids that are referenced from signature blobs.
///
/// Signatures inspected, in order:
/// 1. MethodDef signatures
/// 2. Field signatures
/// 3. MemberRef signatures of non-deleted rows, tried as a method signature first and
///    as a field signature only if that fails
/// 4. StandAloneSig rows that some method body names as its local-variable signature
/// 5. TypeSpec signatures
/// 6. Property signatures
///
/// Returns an empty set when either the tables header or the blob heap is absent.
pub fn scan_signature_typeref_refs(assembly: &CilAssembly) -> HashSet<u32> {
    let mut referenced_rids = HashSet::new();
    let view = assembly.view();
    let (Some(tables), Some(blob_heap)) = (view.tables(), view.blobs()) else {
        return referenced_rids;
    };

    if let Some(methoddef_table) = tables.table::<MethodDefRaw>() {
        for row in methoddef_table {
            scan_method_signature_blob(blob_heap, row.signature, &mut referenced_rids);
        }
    }

    if let Some(field_table) = tables.table::<FieldRaw>() {
        for row in field_table {
            scan_field_signature_blob(blob_heap, row.signature, &mut referenced_rids);
        }
    }

    if let Some(memberref_table) = tables.table::<MemberRefRaw>() {
        for row in memberref_table {
            if is_row_deleted(assembly, TableId::MemberRef, row.rid) {
                continue;
            }
            // A MemberRef blob is either a method or a field signature; fall back to
            // the field parser only when the method parser rejects the blob.
            let parsed_as_method =
                scan_method_signature_blob(blob_heap, row.signature, &mut referenced_rids);
            if !parsed_as_method {
                scan_field_signature_blob(blob_heap, row.signature, &mut referenced_rids);
            }
        }
    }

    let live_local_sigs = collect_referenced_standalonesig_rids(assembly);
    if let Some(standalonesig_table) = tables.table::<StandAloneSigRaw>() {
        let live_rows = standalonesig_table
            .into_iter()
            .filter(|row| live_local_sigs.contains(&row.rid));
        for row in live_rows {
            scan_local_var_signature_blob(blob_heap, row.signature, &mut referenced_rids);
        }
    }

    if let Some(typespec_table) = tables.table::<TypeSpecRaw>() {
        for row in typespec_table {
            scan_typespec_signature_blob(blob_heap, row.signature, &mut referenced_rids);
        }
    }

    if let Some(property_table) = tables.table::<PropertyRaw>() {
        for row in property_table {
            scan_property_signature_blob(blob_heap, row.signature, &mut referenced_rids);
        }
    }

    referenced_rids
}
/// Parses the blob at `blob_index` as a method signature and records the TypeRef row
/// ids found in its return type and parameters.
///
/// Returns `true` when the blob was found and parsed as a method signature, and
/// `false` (leaving `referenced` untouched) otherwise.
fn scan_method_signature_blob(
    blob_heap: &Blob<'_>,
    blob_index: u32,
    referenced: &mut HashSet<u32>,
) -> bool {
    let parsed = blob_heap
        .get(blob_index as usize)
        .ok()
        .and_then(|bytes| parse_method_signature(bytes).ok());

    match parsed {
        Some(sig) => {
            // Return type first, then the parameters in declaration order.
            for entry in std::iter::once(&sig.return_type).chain(&sig.params) {
                collect_typerefs_from_parameter(entry, referenced);
            }
            true
        }
        None => false,
    }
}
/// Parses the blob at `blob_index` as a field signature and records the TypeRef row
/// ids found in its custom modifiers and base type.
///
/// Returns `true` when the blob was found and parsed as a field signature, and
/// `false` (leaving `referenced` untouched) otherwise.
fn scan_field_signature_blob(
    blob_heap: &Blob<'_>,
    blob_index: u32,
    referenced: &mut HashSet<u32>,
) -> bool {
    let parsed = blob_heap
        .get(blob_index as usize)
        .ok()
        .and_then(|bytes| parse_field_signature(bytes).ok());

    match parsed {
        Some(sig) => {
            for custom_mod in &sig.modifiers {
                collect_typerefs_from_modifier(custom_mod, referenced);
            }
            collect_typerefs_from_type(&sig.base, referenced);
            true
        }
        None => false,
    }
}
/// Parses the blob at `blob_index` as a local-variable signature and records the
/// TypeRef row ids found in each local. Does nothing when the blob is missing or
/// cannot be parsed.
fn scan_local_var_signature_blob(
    blob_heap: &Blob<'_>,
    blob_index: u32,
    referenced: &mut HashSet<u32>,
) {
    let parsed = blob_heap
        .get(blob_index as usize)
        .ok()
        .and_then(|bytes| parse_local_var_signature(bytes).ok());

    if let Some(sig) = parsed {
        for local in &sig.locals {
            collect_typerefs_from_local(local, referenced);
        }
    }
}
/// Parses the blob at `blob_index` as a TypeSpec signature and records the TypeRef
/// row ids reachable from its base type. Does nothing when the blob is missing or
/// cannot be parsed.
fn scan_typespec_signature_blob(
    blob_heap: &Blob<'_>,
    blob_index: u32,
    referenced: &mut HashSet<u32>,
) {
    let parsed = blob_heap
        .get(blob_index as usize)
        .ok()
        .and_then(|bytes| parse_type_spec_signature(bytes).ok());

    if let Some(sig) = parsed {
        collect_typerefs_from_type(&sig.base, referenced);
    }
}
/// Parses the blob at `blob_index` as a property signature and records the TypeRef
/// row ids found in its custom modifiers, its base type and its parameters, in that
/// order. Does nothing when the blob is missing or cannot be parsed.
fn scan_property_signature_blob(
    blob_heap: &Blob<'_>,
    blob_index: u32,
    referenced: &mut HashSet<u32>,
) {
    let parsed = blob_heap
        .get(blob_index as usize)
        .ok()
        .and_then(|bytes| parse_property_signature(bytes).ok());

    if let Some(sig) = parsed {
        for custom_mod in &sig.modifiers {
            collect_typerefs_from_modifier(custom_mod, referenced);
        }
        collect_typerefs_from_type(&sig.base, referenced);
        for index_param in &sig.params {
            collect_typerefs_from_parameter(index_param, referenced);
        }
    }
}
/// Removes every TypeRef row that none of the scanned sources reference and returns
/// the number of rows successfully removed.
///
/// A TypeRef is kept when it is referenced by:
/// * a token operand in a method body,
/// * one of the metadata tables covered by `scan_typeref_metadata_refs`,
/// * one of the signature blobs covered by `scan_signature_typeref_refs`, or
/// * the `class` of a MemberRef that a method body references.
///
/// Rows are visited from the highest rid down to 1; a failed removal is ignored and
/// not counted.
///
/// NOTE(review): only the sources listed above are consulted. Other places where a
/// TypeRef token can legally appear (for example a nested TypeRef's resolution scope)
/// are not visibly scanned in this file — confirm they are covered elsewhere.
pub fn remove_unreferenced_typerefs(assembly: &mut CilAssembly) -> usize {
    let body_tokens = scan_method_body_tokens(assembly);

    let mut referenced_rids: HashSet<u32> = body_tokens
        .iter()
        .filter(|token| token.is_table(TableId::TypeRef))
        .map(Token::row)
        .collect();
    referenced_rids.extend(scan_typeref_metadata_refs(assembly));
    referenced_rids.extend(scan_signature_typeref_refs(assembly));

    // MemberRefs used by code keep their declaring TypeRef alive.
    let memberref_body_rids: HashSet<u32> = body_tokens
        .iter()
        .filter(|token| token.is_table(TableId::MemberRef))
        .map(Token::row)
        .collect();
    {
        let view = assembly.view();
        if let Some(memberref_table) = view
            .tables()
            .and_then(TablesHeader::table::<MemberRefRaw>)
        {
            referenced_rids.extend(
                memberref_table
                    .into_iter()
                    .filter(|row| {
                        memberref_body_rids.contains(&row.rid)
                            && row.class.token.is_table(TableId::TypeRef)
                    })
                    .map(|row| row.class.token.row()),
            );
        }
    }

    let typeref_count = {
        let view = assembly.view();
        view.tables()
            .and_then(TablesHeader::table::<TypeRefRaw>)
            .map_or(0, |table| table.row_count)
    };

    let mut removed = 0;
    for rid in (1..=typeref_count).rev() {
        if referenced_rids.contains(&rid) {
            continue;
        }
        if assembly.table_row_remove(TableId::TypeRef, rid).is_ok() {
            removed += 1;
        }
    }
    removed
}
/// Removes every MemberRef row that none of the scanned sources reference and returns
/// the number of rows successfully removed.
///
/// A MemberRef is kept when it is referenced by:
/// * a token operand in a method body,
/// * one of the metadata tables covered by `scan_memberref_metadata_refs`, or
/// * the `method` of a MethodSpec that a method body references.
///
/// Rows are visited from the highest rid down to 1; a failed removal is ignored and
/// not counted.
pub fn remove_unreferenced_memberrefs(assembly: &mut CilAssembly) -> usize {
    let body_tokens = scan_method_body_tokens(assembly);

    let mut referenced_rids: HashSet<u32> = body_tokens
        .iter()
        .filter(|token| token.is_table(TableId::MemberRef))
        .map(Token::row)
        .collect();
    referenced_rids.extend(scan_memberref_metadata_refs(assembly));

    // MethodSpecs used by code keep the MemberRef they instantiate alive.
    let methodspec_body_rids: HashSet<u32> = body_tokens
        .iter()
        .filter(|token| token.is_table(TableId::MethodSpec))
        .map(Token::row)
        .collect();
    {
        let view = assembly.view();
        if let Some(methodspec_table) = view
            .tables()
            .and_then(TablesHeader::table::<MethodSpecRaw>)
        {
            referenced_rids.extend(
                methodspec_table
                    .into_iter()
                    .filter(|row| {
                        methodspec_body_rids.contains(&row.rid)
                            && row.method.token.is_table(TableId::MemberRef)
                    })
                    .map(|row| row.method.token.row()),
            );
        }
    }

    let memberref_count = {
        let view = assembly.view();
        view.tables()
            .and_then(TablesHeader::table::<MemberRefRaw>)
            .map_or(0, |table| table.row_count)
    };

    let mut removed = 0;
    for rid in (1..=memberref_count).rev() {
        if referenced_rids.contains(&rid) {
            continue;
        }
        if assembly.table_row_remove(TableId::MemberRef, rid).is_ok() {
            removed += 1;
        }
    }
    removed
}
pub fn remove_unreferenced_typespecs(assembly: &mut CilAssembly) -> usize {
let mut referenced_rids = HashSet::new();
let body_tokens = scan_method_body_tokens(assembly);
for token in &body_tokens {
if token.is_table(TableId::TypeSpec) {
referenced_rids.insert(token.row());
}
}
let metadata_refs = scan_typespec_metadata_refs(assembly);
referenced_rids.extend(metadata_refs);
let typespec_count = {
let view = assembly.view();
view.tables()
.and_then(TablesHeader::table::<TypeSpecRaw>)
.map_or(0, |t| t.row_count)
};
let mut removed = 0;
for rid in (1..=typespec_count).rev() {
if !referenced_rids.contains(&rid)
&& assembly.table_row_remove(TableId::TypeSpec, rid).is_ok()
{
removed += 1;
}
}
removed
}
#[cfg(test)]
mod tests {
    use super::PLACEHOLDER_RVA_THRESHOLD;
    use crate::metadata::{tables::TableId, token::Token};

    /// Compile-time checks that the threshold separates placeholder RVAs from
    /// ordinary ones at the expected boundary.
    #[test]
    fn test_placeholder_rva_detection() {
        // At or above the threshold.
        const { assert!(PLACEHOLDER_RVA_THRESHOLD <= 0xF000_0000) };
        const { assert!(PLACEHOLDER_RVA_THRESHOLD <= 0xF000_0001) };
        // Strictly below the threshold.
        const { assert!(PLACEHOLDER_RVA_THRESHOLD > 0x2000) };
        const { assert!(PLACEHOLDER_RVA_THRESHOLD > 0x0000_FFFF) };
    }

    /// The table of a token is identified correctly for the three tables this
    /// module cleans up.
    #[test]
    fn test_token_table_detection() {
        assert!(Token::new(0x0100_0005).is_table(TableId::TypeRef));
        assert!(Token::new(0x0A00_0010).is_table(TableId::MemberRef));
        assert!(Token::new(0x1B00_0003).is_table(TableId::TypeSpec));
    }
}