use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use crate::{
cilassembly::cleanup::CleanupRequest,
metadata::{
tables::{CustomAttributeRaw, TableId},
token::Token,
},
CilObject,
};
/// Expands a cleanup request into the full set of tokens affected by deleting
/// the requested types.
///
/// The result starts from `request.all_tokens()` and is extended with, for
/// every requested type that can be resolved in the assembly's type registry:
///
/// - the tokens of its nested types, followed transitively (nested types of
///   nested types are included as well),
/// - the tokens of all fields declared on the type and on every nested type,
/// - the tokens of all methods (whose references can still be upgraded)
///   declared on the type and on every nested type.
///
/// Types that cannot be found in the registry are skipped silently; a nested
/// type's token is still recorded even if the nested type itself cannot be
/// resolved.
///
/// # Arguments
///
/// * `request` - The cleanup request listing the types to delete.
/// * `assembly` - The assembly whose type registry is used for resolution.
///
/// # Returns
///
/// The set of all tokens covered by the request after expansion.
pub fn expand_type_tokens(request: &CleanupRequest, assembly: &CilObject) -> HashSet<Token> {
    let mut tokens = request.all_tokens();
    let registry = assembly.types();

    // Worklist of types whose members still have to be collected. The visited
    // set guarantees termination and avoids re-walking a type that is both
    // requested explicitly and reachable as a nested type.
    let mut pending: Vec<Token> = request.types().copied().collect();
    let mut visited: HashSet<Token> = HashSet::new();

    while let Some(type_token) = pending.pop() {
        if !visited.insert(type_token) {
            continue;
        }
        let Some(cil_type) = registry.get(&type_token) else {
            continue;
        };

        // Removing a type removes everything nested inside it, at any depth,
        // so nested types are queued for expansion themselves.
        for (_, nested_ref) in cil_type.nested_types.iter() {
            if let Some(nested_token) = nested_ref.token() {
                tokens.insert(nested_token);
                pending.push(nested_token);
            }
        }

        for (_, field) in cil_type.fields.iter() {
            tokens.insert(field.token);
        }

        for (_, method_ref) in cil_type.methods.iter() {
            if let Some(method) = method_ref.upgrade() {
                tokens.insert(method.token);
            }
        }
    }

    tokens
}
pub fn find_unreferenced_types(
assembly: &CilObject,
method_call_graph: &BTreeMap<Token, BTreeSet<Token>>,
request: &CleanupRequest,
) -> Vec<Token> {
let method_to_type = build_method_type_map(assembly);
let mut entry_point_types: HashSet<Token> = HashSet::new();
let entry_token_val = assembly.cor20header().entry_point_token;
if entry_token_val != 0 {
let entry_token = Token::new(entry_token_val);
if let Some(&type_token) = method_to_type.get(&entry_token) {
entry_point_types.insert(type_token);
}
}
let registry = assembly.types();
let mut candidates: HashSet<Token> = HashSet::new();
for type_entry in registry.iter() {
let type_token = *type_entry.key();
if type_token.table() != 0x02 {
continue;
}
let cil_type = type_entry.value();
if cil_type.is_module_type() || cil_type.is_public() {
continue;
}
if entry_point_types.contains(&type_token) {
continue;
}
if request.types().any(|t| *t == type_token) {
continue;
}
let type_methods: Vec<Token> = cil_type.methods().map(|m| m.token).collect();
if type_methods.is_empty() {
continue;
}
let has_non_cctor = type_methods.iter().any(|m| {
!cil_type
.methods()
.any(|method| method.token == *m && method.is_cctor())
});
if !has_non_cctor {
continue;
}
candidates.insert(type_token);
}
let deleted_types: HashSet<Token> = request.types().copied().collect();
let mut deleted_methods: HashSet<Token> = request.methods().copied().collect();
for type_token in &deleted_types {
if let Some(cil_type) = registry.get(type_token) {
for method in cil_type.methods() {
deleted_methods.insert(method.token);
}
}
}
let mut has_external_caller: HashSet<Token> = HashSet::new();
for (caller_token, callees) in method_call_graph {
if deleted_methods.contains(caller_token) {
continue;
}
let Some(caller_type) = method_to_type.get(caller_token).copied() else {
continue;
};
if deleted_types.contains(&caller_type) {
continue;
}
let caller_is_candidate = candidates.contains(&caller_type);
for callee_token in callees {
let Some(callee_type) = method_to_type.get(callee_token).copied() else {
continue;
};
if caller_type == callee_type {
continue;
}
if !caller_is_candidate && candidates.contains(&callee_type) {
has_external_caller.insert(callee_type);
}
}
}
if let Some(tables) = assembly.tables() {
if let Some(attr_table) = tables.table::<CustomAttributeRaw>() {
for row in attr_table {
if row.constructor.token.is_table(TableId::MethodDef) {
if let Some(&ctor_type) = method_to_type.get(&row.constructor.token) {
if candidates.contains(&ctor_type) {
has_external_caller.insert(ctor_type);
}
}
}
}
}
}
candidates
.into_iter()
.filter(|t| !has_external_caller.contains(t))
.collect()
}
/// Computes the set of method tokens that are treated as entry points.
///
/// The following methods are always included:
///
/// - the method named by the COR20 header's entry point token, when that
///   token is non-zero,
/// - every static constructor.
///
/// When `aggressive` is `false`, every public method is included as well,
/// except public methods that belong to the module type.
///
/// # Arguments
///
/// * `assembly` - The assembly to inspect.
/// * `aggressive` - When `true`, public methods are not treated as entry
///   points.
///
/// # Returns
///
/// The tokens of all entry point methods.
pub fn compute_entry_points(assembly: &CilObject, aggressive: bool) -> HashSet<Token> {
    let mut entry_points = HashSet::new();

    let entry_token_val = assembly.cor20header().entry_point_token;
    if entry_token_val != 0 {
        entry_points.insert(Token::new(entry_token_val));
    }

    // Tokens of the methods declared on the module type, collected once so the
    // loop below does not rescan that method list for every public method.
    // Aggressive mode never consults it, so the lookup is skipped there.
    let module_methods: HashSet<Token> = if aggressive {
        HashSet::new()
    } else {
        assembly
            .types()
            .module_type()
            .map(|module_type| {
                module_type
                    .methods
                    .iter()
                    .filter_map(|(_, method_ref)| method_ref.upgrade().map(|m| m.token))
                    .collect::<HashSet<Token>>()
            })
            .unwrap_or_default()
    };

    for method_entry in assembly.methods() {
        let method = method_entry.value();
        let is_entry_point = method.is_cctor()
            || (!aggressive && method.is_public() && !module_methods.contains(&method.token));
        if is_entry_point {
            entry_points.insert(method.token);
        }
    }

    entry_points
}
/// Builds a lookup from method token to the token of the type declaring it.
///
/// Only registry entries whose token lives in table `0x02` (the TypeDef table
/// in ECMA-335) contribute; every method yielded by such a type's `methods()`
/// is mapped to that type's token.
///
/// # Arguments
///
/// * `assembly` - The assembly whose type registry is walked.
///
/// # Returns
///
/// A map keyed by method token with the declaring type's token as value.
pub(super) fn build_method_type_map(assembly: &CilObject) -> HashMap<Token, Token> {
    let registry = assembly.types();
    let mut declaring_type = HashMap::new();

    for entry in registry.iter() {
        let owner = *entry.key();
        if owner.table() == 0x02 {
            declaring_type.extend(entry.value().methods().map(|method| (method.token, owner)));
        }
    }

    declaring_type
}
#[cfg(test)]
mod tests {
    use std::collections::{BTreeMap, BTreeSet};

    use crate::{cilassembly::cleanup::CleanupRequest, metadata::token::Token};

    /// Smoke test: a freshly created request reports no tokens.
    ///
    /// This does not call `expand_type_tokens`; it only checks the empty
    /// starting state of `CleanupRequest`.
    #[test]
    fn test_expand_type_tokens_empty_request() {
        let empty_request = CleanupRequest::new();
        assert!(empty_request.all_tokens().is_empty());
    }

    /// Smoke test: an empty call graph of the shape accepted by
    /// `find_unreferenced_types` can be constructed.
    ///
    /// This does not call `find_unreferenced_types`.
    #[test]
    fn test_find_unreferenced_types_empty_graph() {
        let call_graph = BTreeMap::<Token, BTreeSet<Token>>::new();
        assert!(call_graph.is_empty());
    }
}