use std::collections::{BTreeSet, HashSet};
use crate::{
cilassembly::{
cleanup::{
compaction::mark_unreferenced_heap_entries,
orphans::{self, DeletionContext},
references::{
collect_pre_deletion_references, collect_typedefs_from_field_signatures,
scan_method_body_tokens,
},
utils::{is_cctor_method, list_range, try_remove},
CleanupRequest, CleanupStats,
},
CilAssembly,
},
metadata::{
tables::{
CustomAttributeRaw, FieldRaw, InterfaceImplRaw, MethodDefRaw, MethodImplRaw,
MethodSemanticsRaw, MethodSpecRaw, TableId, TypeDefRaw,
},
token::Token,
},
Result,
};
/// Executes a cleanup request against `assembly` and reports what was removed.
///
/// The work is done in a fixed sequence of phases:
///
/// 1. Remove the requested methods, MethodSpecs and fields (including the
///    members of every requested type), skipping protected tokens.
/// 2. Remove the requested types, except protected ones and ones whose TypeDef
///    token still shows up in the method-body scan.
/// 3. Remove nested types reported as orphaned by
///    `orphans::collect_orphaned_nested_types`, together with their members.
/// 4. Remove the requested custom attributes, AssemblyRefs, ModuleRefs and
///    manifest resources.
/// 5. When `request.remove_orphans()` is set, run a bounded cascade that removes
///    MethodDef/Field tokens recorded in `pre_refs.il_tokens` that are no
///    longer alive.
/// 6. When `request.remove_empty_types()` is set, remove types left without
///    live members and then their dependents.
/// 7. When `request.remove_orphans()` is set, remove parent/child dependents and
///    run the reference cascade.
/// 8. Mark unreferenced heap entries and copy the compaction counters into the
///    returned stats.
///
/// An empty request returns fresh, untouched stats immediately.
///
/// # Errors
///
/// Propagates any error returned by `mark_unreferenced_heap_entries`.
pub fn execute_cleanup(
assembly: &mut CilAssembly,
request: &CleanupRequest,
) -> Result<CleanupStats> {
let mut stats = CleanupStats::new();
// Nothing requested: return the fresh stats without doing any work.
if request.is_empty() {
return Ok(stats);
}
// Removing a type implies removing its members, so merge them into the
// explicitly requested method/field sets. The ordered sets are walked in
// reverse (descending token order) by the removal loops below.
let (type_methods, type_fields) = expand_type_members(assembly, request);
let mut all_methods: BTreeSet<Token> = request.methods().copied().collect();
all_methods.extend(type_methods);
let mut all_fields: BTreeSet<Token> = request.fields().copied().collect();
all_fields.extend(type_fields);
let all_types: BTreeSet<Token> = request.types().copied().collect();
// Gather reference information for the rows about to be removed, before any
// of them is actually removed. The orphan passes further down consume it.
let mut pre_refs =
collect_pre_deletion_references(assembly, &all_methods, &all_fields, &all_types);
// Tokens supplied via `rewrite_orphaned_tokens` are added to the same
// candidate pool.
pre_refs
.il_tokens
.extend(request.rewrite_orphaned_tokens().iter().copied());
// Constructors of the custom attributes slated for removal are candidates
// as well. Scoped so the view borrow ends before the mutations below.
{
let view = assembly.view();
if let Some(tables) = view.tables() {
if let Some(attr_table) = tables.table::<CustomAttributeRaw>() {
for attr_token in request.attributes() {
if let Some(attr) = attr_table.get(attr_token.row()) {
pre_refs.il_tokens.insert(attr.constructor.token);
}
}
}
}
}
// Tokens for which `try_remove` actually succeeded; later phases use these
// to find dependents.
let mut removed_types: HashSet<Token> = HashSet::new();
let mut removed_methods: HashSet<Token> = HashSet::new();
let mut removed_fields: HashSet<Token> = HashSet::new();
// Phase 1a: methods, in descending token order, honoring protection.
for method_token in all_methods.iter().rev() {
if request.is_protected(*method_token) {
continue;
}
if try_remove(assembly, TableId::MethodDef, method_token.row()) {
removed_methods.insert(*method_token);
stats.add(TableId::MethodDef, 1);
}
}
// Phase 1b: MethodSpecs, in the order the request yields them.
for spec_token in request.methodspecs() {
if request.is_protected(*spec_token) {
continue;
}
if try_remove(assembly, TableId::MethodSpec, spec_token.row()) {
stats.add(TableId::MethodSpec, 1);
}
}
// Phase 1c: fields, in descending token order, honoring protection.
for field_token in all_fields.iter().rev() {
if request.is_protected(*field_token) {
continue;
}
if try_remove(assembly, TableId::Field, field_token.row()) {
removed_fields.insert(*field_token);
stats.add(TableId::Field, 1);
}
}
// Phase 2: requested types. The scan runs after the member removals above;
// a requested type whose TypeDef token still appears in the scan result is
// kept. NOTE(review): despite the `sig_` prefix this set is built from
// `scan_method_body_tokens` only — confirm the name is not misleading.
let body_tokens = scan_method_body_tokens(assembly);
let sig_referenced_typedefs: HashSet<u32> = body_tokens
.iter()
.filter(|t| t.is_table(TableId::TypeDef))
.map(|t| t.row())
.collect();
for type_token in request.types() {
if request.is_protected(*type_token) {
continue;
}
if sig_referenced_typedefs.contains(&type_token.row()) {
continue;
}
if try_remove(assembly, TableId::TypeDef, type_token.row()) {
removed_types.insert(*type_token);
stats.add(TableId::TypeDef, 1);
}
}
// Phase 3: nested types that the orphan collector reports for the removals
// made so far.
let orphaned_nested = orphans::collect_orphaned_nested_types(
assembly,
&DeletionContext::new(&removed_types, &removed_methods, &removed_fields),
);
if !orphaned_nested.is_empty() {
// Build a throwaway request holding only the orphaned types so that
// `expand_type_members` can be reused to enumerate their members.
let (nested_methods, nested_fields) = {
let mut nested_request = CleanupRequest::new();
for &t in &orphaned_nested {
nested_request.add_type(t);
}
expand_type_members(assembly, &nested_request)
};
// NOTE(review): unlike phases 1 and 2, the removals in this block do not
// consult `request.is_protected` — confirm that is intended.
// Members first (descending token order), then the types themselves
// (descending row order).
let mut nested_methods_sorted: Vec<_> = nested_methods.into_iter().collect();
nested_methods_sorted.sort_by(|a, b| b.cmp(a));
for method_token in &nested_methods_sorted {
if try_remove(assembly, TableId::MethodDef, method_token.row()) {
removed_methods.insert(*method_token);
stats.add(TableId::MethodDef, 1);
}
}
let mut nested_fields_sorted: Vec<_> = nested_fields.into_iter().collect();
nested_fields_sorted.sort_by(|a, b| b.cmp(a));
for field_token in &nested_fields_sorted {
if try_remove(assembly, TableId::Field, field_token.row()) {
removed_fields.insert(*field_token);
stats.add(TableId::Field, 1);
}
}
let mut sorted_nested: Vec<_> = orphaned_nested.clone();
sorted_nested.sort_by_key(|t| std::cmp::Reverse(t.row()));
for type_token in &sorted_nested {
if try_remove(assembly, TableId::TypeDef, type_token.row()) {
removed_types.insert(*type_token);
stats.add(TableId::TypeDef, 1);
}
}
}
// Phase 4: remaining explicitly requested rows. These loops remove exactly
// what the request lists; no protection check is applied here.
for attr_token in request.attributes() {
if try_remove(assembly, TableId::CustomAttribute, attr_token.row()) {
stats.add(TableId::CustomAttribute, 1);
}
}
for asmref_token in request.assemblyrefs() {
if try_remove(assembly, TableId::AssemblyRef, asmref_token.row()) {
stats.add(TableId::AssemblyRef, 1);
}
}
for modref_token in request.modulerefs() {
if try_remove(assembly, TableId::ModuleRef, modref_token.row()) {
stats.add(TableId::ModuleRef, 1);
}
}
for res_token in request.manifest_resources() {
if try_remove(assembly, TableId::ManifestResource, res_token.row()) {
stats.add(TableId::ManifestResource, 1);
}
}
// Phase 5: orphan cascade. Each round removes candidate methods/fields that
// are no longer alive and feeds the tokens those rows referenced back into
// the candidate pool for the next round. The loop ends when a round finds
// no candidates or after MAX_CASCADE_ROUNDS rounds, whichever comes first.
if request.remove_orphans() {
const MAX_CASCADE_ROUNDS: usize = 10;
for _round in 0..MAX_CASCADE_ROUNDS {
// Liveness is recomputed every round because the previous round's
// removals may have changed it.
let alive_methods = collect_alive_method_tokens(assembly);
let alive_fields = collect_alive_field_tokens(assembly);
// A candidate method must be: a MethodDef token, not alive, not already
// removed by this run, not a deleted row, not reported as a static
// constructor by `is_cctor_method`, and not protected.
let dead_methods: Vec<Token> = pre_refs
.il_tokens
.iter()
.filter(|t| t.is_table(TableId::MethodDef))
.filter(|t| !alive_methods.contains(t))
.filter(|t| !removed_methods.contains(t))
.filter(|t| {
!assembly
.changes()
.is_row_deleted(TableId::MethodDef, t.row())
})
.filter(|t| !is_cctor_method(assembly, t.row()))
.filter(|t| !request.is_protected(**t))
.copied()
.collect();
// Same criteria for fields, minus the static-constructor check.
let dead_fields: Vec<Token> = pre_refs
.il_tokens
.iter()
.filter(|t| t.is_table(TableId::Field))
.filter(|t| !alive_fields.contains(t))
.filter(|t| !removed_fields.contains(t))
.filter(|t| !assembly.changes().is_row_deleted(TableId::Field, t.row()))
.filter(|t| !request.is_protected(**t))
.copied()
.collect();
if dead_methods.is_empty() && dead_fields.is_empty() {
break;
}
// Collect what the newly dead rows reference *before* removing them and
// merge it into `pre_refs`, so later rounds and phase 7 see those
// references too.
let dead_methods_set: BTreeSet<Token> = dead_methods.iter().copied().collect();
let dead_fields_set: BTreeSet<Token> = dead_fields.iter().copied().collect();
let empty_types = BTreeSet::new();
let new_refs = collect_pre_deletion_references(
assembly,
&dead_methods_set,
&dead_fields_set,
&empty_types,
);
pre_refs.il_tokens.extend(new_refs.il_tokens);
pre_refs.typeref_rids.extend(new_refs.typeref_rids);
pre_refs
.standalonesig_rids
.extend(new_refs.standalonesig_rids);
// Remove in reverse of the collected order; only successful removals are
// counted and recorded.
let mut method_count = 0usize;
for token in dead_methods.iter().rev() {
if try_remove(assembly, TableId::MethodDef, token.row()) {
removed_methods.insert(*token);
method_count += 1;
}
}
stats.add(TableId::MethodDef, method_count);
let mut field_count = 0usize;
for token in dead_fields.iter().rev() {
if try_remove(assembly, TableId::Field, token.row()) {
removed_fields.insert(*token);
field_count += 1;
}
}
stats.add(TableId::Field, field_count);
}
}
// Scan again after all removals so far; phases 6 and 7 both use this result.
let body_tokens = scan_method_body_tokens(assembly);
// Phase 6: types that ended up without live members.
if request.remove_empty_types() {
let (empty_removed, empty_type_tokens) =
remove_empty_types(assembly, &body_tokens, request);
stats.add(TableId::TypeDef, empty_removed);
if !empty_type_tokens.is_empty() {
removed_types.extend(empty_type_tokens.iter().copied());
// This context carries only the newly removed empty types; the method
// and field sets are deliberately left empty.
let empty_methods = HashSet::new();
let empty_fields = HashSet::new();
let empty_ctx = DeletionContext::new(&empty_type_tokens, &empty_methods, &empty_fields);
let type_dep_stats = orphans::remove_type_dependents(assembly, &empty_ctx);
stats.merge(&type_dep_stats);
}
}
// Phase 7: dependents of everything removed by this run, followed by the
// reference cascade driven by the accumulated `pre_refs`.
if request.remove_orphans() {
let ctx = DeletionContext::new(&removed_types, &removed_methods, &removed_fields);
let orphan_stats = orphans::remove_parent_child_dependents(assembly, &ctx, &pre_refs);
stats.merge(&orphan_stats);
let cascade_stats = orphans::cascade_reference_cleanup(assembly, &pre_refs, &body_tokens);
stats.merge(&cascade_stats);
}
// Phase 8: heap compaction counters. `sections_excluded` is just the number
// of sections listed in the request; nothing is excluded by this function.
let compaction_stats = mark_unreferenced_heap_entries(assembly)?;
stats.blobs_compacted = compaction_stats.blobs;
stats.guids_compacted = compaction_stats.guids;
stats.strings_compacted = compaction_stats.strings;
stats.sections_excluded = request.excluded_sections().len();
Ok(stats)
}
/// Collects the MethodDef and Field tokens owned by every type in `request`.
///
/// For each requested TypeDef row, the member run starts at the row's own
/// `method_list` / `field_list` index and stops at the end reported by
/// `list_range`. Requested types whose row cannot be read are skipped.
///
/// Returns `(methods, fields)`. Both sets are empty when the assembly exposes
/// no metadata tables or no TypeDef table.
fn expand_type_members(
    assembly: &CilAssembly,
    request: &CleanupRequest,
) -> (HashSet<Token>, HashSet<Token>) {
    let mut member_methods = HashSet::new();
    let mut member_fields = HashSet::new();

    let view = assembly.view();
    let tables = match view.tables() {
        Some(tables) => tables,
        None => return (member_methods, member_fields),
    };
    let typedefs = match tables.table::<TypeDefRaw>() {
        Some(table) => table,
        None => return (member_methods, member_fields),
    };

    // Table sizes bound the member run of the last type in the table.
    let total_methods = tables
        .table::<MethodDefRaw>()
        .map_or(0, |table| table.row_count);
    let total_fields = tables.table::<FieldRaw>().map_or(0, |table| table.row_count);
    let total_types = typedefs.row_count;

    for owner in request.types() {
        let owner_rid = owner.row();
        if let Some(row) = typedefs.get(owner_rid) {
            let method_end = list_range(owner_rid, total_types, total_methods, |rid| {
                typedefs.get(rid).map(|t| t.method_list)
            })
            .end;
            member_methods.extend(
                (row.method_list..method_end)
                    .map(|rid| Token::from_parts(TableId::MethodDef, rid)),
            );

            let field_end = list_range(owner_rid, total_types, total_fields, |rid| {
                typedefs.get(rid).map(|t| t.field_list)
            })
            .end;
            member_fields.extend(
                (row.field_list..field_end).map(|rid| Token::from_parts(TableId::Field, rid)),
            );
        }
    }

    (member_methods, member_fields)
}
/// Removes TypeDef rows that no longer own any live method or field.
///
/// A row is removed only when **all** of the following hold:
///
/// - it is not row 1 (per ECMA-335 the first TypeDef row is the `<Module>`
///   pseudo-type and must stay);
/// - it is not protected by `request`;
/// - its TypeDef token appears neither in `body_tokens` nor in the result of
///   `collect_typedefs_from_field_signatures`;
/// - every method and field in its member run is already marked deleted;
/// - it is not an interface (`flags & 0x20`, the `Interface` bit of
///   `TypeAttributes` in ECMA-335);
/// - no other TypeDef row extends it, other than rows already selected for
///   removal earlier in the same pass;
/// - no InterfaceImpl row names it as the implemented interface.
///
/// # Returns
///
/// `(count, tokens)` for the rows that `try_remove` actually removed. Both are
/// empty/zero when the assembly exposes no metadata tables or no TypeDef table.
fn remove_empty_types(
    assembly: &mut CilAssembly,
    body_tokens: &HashSet<Token>,
    request: &CleanupRequest,
) -> (usize, HashSet<Token>) {
    // TypeDef rows that something still points at must survive even if empty.
    let mut referenced_typedefs: HashSet<u32> = body_tokens
        .iter()
        .filter(|t| t.is_table(TableId::TypeDef))
        .map(|t| t.row())
        .collect();
    let field_sig_refs = collect_typedefs_from_field_signatures(assembly);
    referenced_typedefs.extend(
        field_sig_refs
            .iter()
            .filter(|t| t.is_table(TableId::TypeDef))
            .map(|t| t.row()),
    );

    // Selection happens inside its own scope so the immutable view borrow is
    // released before rows are removed below.
    let empty_types: Vec<u32> = {
        let view = assembly.view();
        let Some(tables) = view.tables() else {
            return (0, HashSet::new());
        };
        let Some(typedef_table) = tables.table::<TypeDefRaw>() else {
            return (0, HashSet::new());
        };
        let methoddef_count = tables.table::<MethodDefRaw>().map_or(0, |t| t.row_count);
        let field_count = tables.table::<FieldRaw>().map_or(0, |t| t.row_count);
        let type_count = typedef_table.row_count;

        // Invariant: rids are pushed in strictly ascending order, so `empty`
        // is always sorted and can be binary-searched.
        let mut empty: Vec<u32> = Vec::new();

        // Row 1 is never a candidate, so start at row 2.
        for type_rid in 2..=type_count {
            let Some(typedef) = typedef_table.get(type_rid) else {
                continue;
            };
            if request.is_protected(Token::from_parts(TableId::TypeDef, type_rid)) {
                continue;
            }
            if referenced_typedefs.contains(&type_rid) {
                continue;
            }

            // Any member that is not marked deleted keeps the type alive.
            let method_range = list_range(type_rid, type_count, methoddef_count, |rid| {
                typedef_table.get(rid).map(|t| t.method_list)
            });
            let has_live_method = (typedef.method_list..method_range.end)
                .any(|rid| !assembly.changes().is_row_deleted(TableId::MethodDef, rid));
            if has_live_method {
                continue;
            }
            let field_range = list_range(type_rid, type_count, field_count, |rid| {
                typedef_table.get(rid).map(|t| t.field_list)
            });
            let has_live_field = (typedef.field_list..field_range.end)
                .any(|rid| !assembly.changes().is_row_deleted(TableId::Field, rid));
            if has_live_field {
                continue;
            }

            // Interfaces are kept even when they declare no members.
            if typedef.flags & 0x20 != 0 {
                continue;
            }

            // Keep types that another surviving type derives from. The cheap
            // field comparisons run first; the membership lookup in `empty`
            // only happens for rows that really extend this type.
            // NOTE(review): derived rows already marked deleted elsewhere still
            // count as users here — confirm that is intended.
            let is_base_class = typedef_table.iter().any(|other| {
                other.rid != type_rid
                    && other.extends.tag == TableId::TypeDef
                    && other.extends.row == type_rid
                    && empty.binary_search(&other.rid).is_err()
            });
            if is_base_class {
                continue;
            }

            if let Some(iface_impl) = tables.table::<InterfaceImplRaw>() {
                let is_implemented = iface_impl.iter().any(|row| {
                    row.interface.tag == TableId::TypeDef && row.interface.row == type_rid
                });
                if is_implemented {
                    continue;
                }
            }

            empty.push(type_rid);
        }
        empty
    };

    // Remove from the highest row downwards and report only what succeeded.
    let mut removed = 0;
    let mut removed_tokens = HashSet::new();
    for rid in empty_types.into_iter().rev() {
        if try_remove(assembly, TableId::TypeDef, rid) {
            removed += 1;
            removed_tokens.insert(Token::from_parts(TableId::TypeDef, rid));
        }
    }
    (removed, removed_tokens)
}
fn collect_alive_method_tokens(assembly: &CilAssembly) -> HashSet<Token> {
let body_tokens = scan_method_body_tokens(assembly);
let mut alive: HashSet<Token> = body_tokens
.iter()
.filter(|t| t.is_table(TableId::MethodDef))
.copied()
.collect();
let view = assembly.view();
let Some(tables) = view.tables() else {
return alive;
};
if let Some(methodspec_table) = tables.table::<MethodSpecRaw>() {
for row in methodspec_table {
if assembly
.changes()
.is_row_deleted(TableId::MethodSpec, row.rid)
{
continue;
}
let spec_token = Token::from_parts(TableId::MethodSpec, row.rid);
if body_tokens.contains(&spec_token) && row.method.token.is_table(TableId::MethodDef) {
alive.insert(row.method.token);
}
}
}
if let Some(attr_table) = tables.table::<CustomAttributeRaw>() {
for row in attr_table {
if assembly
.changes()
.is_row_deleted(TableId::CustomAttribute, row.rid)
{
continue;
}
if row.constructor.token.is_table(TableId::MethodDef) {
alive.insert(row.constructor.token);
}
}
}
if let Some(sem_table) = tables.table::<MethodSemanticsRaw>() {
for row in sem_table {
if assembly
.changes()
.is_row_deleted(TableId::MethodSemantics, row.rid)
{
continue;
}
let method_token = Token::from_parts(TableId::MethodDef, row.method);
alive.insert(method_token);
}
}
if let Some(impl_table) = tables.table::<MethodImplRaw>() {
for row in impl_table {
if assembly
.changes()
.is_row_deleted(TableId::MethodImpl, row.rid)
{
continue;
}
if row.method_body.token.is_table(TableId::MethodDef) {
alive.insert(row.method_body.token);
}
if row.method_declaration.token.is_table(TableId::MethodDef) {
alive.insert(row.method_declaration.token);
}
}
}
alive
}
/// Returns the Field tokens found by `scan_method_body_tokens`.
///
/// Only that scan is consulted; no metadata table is inspected here.
fn collect_alive_field_tokens(assembly: &CilAssembly) -> HashSet<Token> {
    let mut alive = HashSet::new();
    for token in scan_method_body_tokens(assembly) {
        if token.is_table(TableId::Field) {
            alive.insert(token);
        }
    }
    alive
}
#[cfg(test)]
mod tests {
    use crate::{
        cilassembly::cleanup::CleanupRequest,
        metadata::{tables::TableId, token::Token},
    };

    /// A freshly constructed request reports itself as empty.
    #[test]
    fn test_execute_cleanup_empty_request() {
        assert!(CleanupRequest::new().is_empty());
    }

    /// Adding one type and one method makes the request non-empty and is
    /// reflected in the per-kind counters.
    #[test]
    fn test_cleanup_request_with_types() {
        let type_token = Token::from_parts(TableId::TypeDef, 5);
        let method_token = Token::from_parts(TableId::MethodDef, 10);

        let mut request = CleanupRequest::new();
        request.add_type(type_token);
        request.add_method(method_token);

        assert!(!request.is_empty());
        assert_eq!(request.types_len(), 1);
        assert_eq!(request.methods_len(), 1);
    }
}