use std::collections::{HashMap, HashSet};
use crate::{
metadata::{token::Token, typesystem::TypeRegistry},
CilObject,
};
/// Resolves candidate targets of virtual calls using an assembly's type hierarchy.
///
/// Construct it with [`CallResolver::new`] to index an assembly, or with
/// [`CallResolver::empty`] / [`Default::default`] for a resolver that knows nothing.
#[derive(Debug, Default)]
pub struct CallResolver {
    /// Virtual method token -> candidate call targets: the method itself first,
    /// then same-named virtual methods found in subtypes of its declaring type.
    virtual_dispatch_table: HashMap<Token, Vec<Token>>,
    /// Method token -> token of the type whose method list contains it.
    method_to_type: HashMap<Token, Token>,
    /// Base type token -> tokens of the types that name it as their direct base.
    type_subtypes: HashMap<Token, Vec<Token>>,
    /// Tokens of types reported as interfaces.
    interfaces: HashSet<Token>,
    /// Tokens of types reported as sealed.
    sealed_types: HashSet<Token>,
}
impl CallResolver {
    /// Creates a resolver that holds no type, method, or dispatch information.
    ///
    /// Every query on it returns its "unknown" answer.
    /// [`resolve_virtual`](Self::resolve_virtual) returns only the queried token,
    /// the predicates return `false`, and the subtype queries return empty vectors.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            virtual_dispatch_table: HashMap::new(),
            method_to_type: HashMap::new(),
            type_subtypes: HashMap::new(),
            interfaces: HashSet::new(),
            sealed_types: HashSet::new(),
        }
    }

    /// Builds a resolver by indexing the types and methods of `assembly`.
    ///
    /// The type hierarchy is built first because the virtual dispatch table is
    /// derived from it (declaring types and subtype lists).
    #[must_use]
    pub fn new(assembly: &CilObject) -> Self {
        let mut resolver = Self::empty();
        resolver.build_type_hierarchy(assembly);
        resolver.build_virtual_dispatch_table(assembly);
        resolver
    }

    /// Indexes the types of `assembly` in a single pass.
    ///
    /// For each type it records:
    /// - interface and sealed flags;
    /// - the declaring type of each of its methods;
    /// - a base -> subtype edge when the type has a base.
    fn build_type_hierarchy(&mut self, assembly: &CilObject) {
        let types = assembly.types();
        for type_info in types.all_types() {
            let token = type_info.token;
            if type_info.is_interface() {
                self.interfaces.insert(token);
            }
            if type_info.is_sealed() {
                self.sealed_types.insert(token);
            }
            for method in &type_info.query_methods() {
                self.method_to_type.insert(method.token, token);
            }
            if let Some(base) = type_info.base() {
                self.type_subtypes
                    .entry(base.token)
                    .or_default()
                    .push(token);
            }
        }
    }

    /// Records the candidate targets of every virtual method whose declaring type is known.
    ///
    /// All virtual methods get an entry, including those with only one target.
    /// This keeps `stats().virtual_methods` distinct from `polymorphic_methods`.
    /// Lookups are unaffected, because a missing entry and a single-target
    /// entry resolve identically.
    fn build_virtual_dispatch_table(&mut self, assembly: &CilObject) {
        let methods = assembly.methods();
        let types = assembly.types();
        for entry in methods {
            let method = entry.value();
            if !method.is_virtual() {
                continue;
            }
            let method_token = method.token;
            let Some(&declaring_type) = self.method_to_type.get(&method_token) else {
                continue;
            };
            let mut overriders = vec![method_token];
            self.find_overriders(&types, declaring_type, &method.name, &mut overriders);
            self.virtual_dispatch_table.insert(method_token, overriders);
        }
    }

    /// Appends to `overriders` matching virtual methods from the transitive subtypes of `type_token`.
    ///
    /// A method matches when it is virtual and its name equals `method_name`.
    /// Tokens already present in `overriders` are not added again, and at most
    /// one method is taken per subtype.
    ///
    /// Subtypes are walked depth-first in pre-order. The walk uses an explicit
    /// stack and a visited set, so cyclic or extremely deep inheritance in
    /// malformed metadata cannot recurse forever or overflow the call stack.
    ///
    /// NOTE(review): matching is by name only. Overloads and `newslot` methods
    /// that hide rather than override are also reported, so the result may
    /// over-approximate the true targets.
    fn find_overriders(
        &self,
        types: &TypeRegistry,
        type_token: Token,
        method_name: &str,
        overriders: &mut Vec<Token>,
    ) {
        let mut visited = HashSet::new();
        visited.insert(type_token);

        let mut stack: Vec<Token> = Vec::new();
        if let Some(children) = self.type_subtypes.get(&type_token) {
            // Push in reverse so pops yield children in stored order.
            stack.extend(children.iter().rev().copied());
        }

        while let Some(current) = stack.pop() {
            if !visited.insert(current) {
                continue;
            }
            if let Some(subtype) = types.get(&current) {
                for (_, method_ref) in subtype.methods.iter() {
                    if let Some(method) = method_ref.upgrade() {
                        if method.name == method_name && method.is_virtual() {
                            if !overriders.contains(&method.token) {
                                overriders.push(method.token);
                            }
                            break;
                        }
                    }
                }
            }
            if let Some(children) = self.type_subtypes.get(&current) {
                stack.extend(children.iter().rev().copied());
            }
        }
    }

    /// Returns the possible targets of a virtual call to `method_token`.
    ///
    /// When nothing is recorded for the method, the result is just `[method_token]`.
    #[must_use]
    pub fn resolve_virtual(&self, method_token: Token) -> Vec<Token> {
        self.virtual_dispatch_table
            .get(&method_token)
            .cloned()
            .unwrap_or_else(|| vec![method_token])
    }

    /// Returns `true` when more than one target is recorded for `method_token`.
    #[must_use]
    pub fn is_polymorphic(&self, method_token: Token) -> bool {
        self.virtual_dispatch_table
            .get(&method_token)
            .is_some_and(|targets| targets.len() > 1)
    }

    /// Returns the type whose method list contains `method_token`, if known.
    #[must_use]
    pub fn declaring_type(&self, method_token: Token) -> Option<Token> {
        self.method_to_type.get(&method_token).copied()
    }

    /// Returns `true` if `type_token` was reported as an interface.
    #[must_use]
    pub fn is_interface(&self, type_token: Token) -> bool {
        self.interfaces.contains(&type_token)
    }

    /// Returns `true` if `type_token` was reported as sealed.
    #[must_use]
    pub fn is_sealed(&self, type_token: Token) -> bool {
        self.sealed_types.contains(&type_token)
    }

    /// Returns the direct subtypes of `type_token` (empty if none are known).
    #[must_use]
    pub fn subtypes(&self, type_token: Token) -> Vec<Token> {
        self.type_subtypes
            .get(&type_token)
            .cloned()
            .unwrap_or_default()
    }

    /// Returns all transitive subtypes of `type_token`, each at most once.
    ///
    /// `type_token` itself is never included, even if malformed metadata
    /// makes it reachable from its own subtypes.
    #[must_use]
    pub fn all_subtypes(&self, type_token: Token) -> Vec<Token> {
        let mut result = Vec::new();
        let mut visited = HashSet::new();
        visited.insert(type_token);
        let mut worklist = vec![type_token];
        while let Some(current) = worklist.pop() {
            let Some(children) = self.type_subtypes.get(&current) else {
                continue;
            };
            for &child in children {
                if visited.insert(child) {
                    result.push(child);
                    worklist.push(child);
                }
            }
        }
        result
    }

    /// Summarises the contents of the resolver.
    ///
    /// `total_types` counts the distinct type tokens the resolver has seen.
    /// These come from four sources:
    /// - declaring types of methods;
    /// - both ends of base -> subtype edges (so base types outside the
    ///   assembly may be included);
    /// - interfaces;
    /// - sealed types.
    ///
    /// Types absent from all of these are not counted.
    #[must_use]
    pub fn stats(&self) -> ResolverStats {
        let polymorphic_methods = self
            .virtual_dispatch_table
            .values()
            .filter(|targets| targets.len() > 1)
            .count();
        let max_targets = self
            .virtual_dispatch_table
            .values()
            .map(Vec::len)
            .max()
            .unwrap_or(0);
        let known_types: HashSet<Token> = self
            .type_subtypes
            .iter()
            .flat_map(|(base, subs)| std::iter::once(*base).chain(subs.iter().copied()))
            .chain(self.method_to_type.values().copied())
            .chain(self.interfaces.iter().copied())
            .chain(self.sealed_types.iter().copied())
            .collect();
        ResolverStats {
            total_methods: self.method_to_type.len(),
            virtual_methods: self.virtual_dispatch_table.len(),
            polymorphic_methods,
            max_targets,
            total_types: known_types.len(),
            interface_types: self.interfaces.len(),
            sealed_types: self.sealed_types.len(),
        }
    }
}
/// Summary counts describing the contents of a `CallResolver`.
#[derive(Debug, Clone, Default)]
pub struct ResolverStats {
    /// Number of methods mapped to a declaring type.
    pub total_methods: usize,
    /// Number of methods with an entry in the virtual dispatch table.
    pub virtual_methods: usize,
    /// Number of methods with more than one recorded dispatch target.
    pub polymorphic_methods: usize,
    /// Largest number of targets recorded for any single method (0 when none are recorded).
    pub max_targets: usize,
    /// Approximate number of types known to the resolver; see `CallResolver::stats`.
    pub total_types: usize,
    /// Number of types reported as interfaces.
    pub interface_types: usize,
    /// Number of types reported as sealed.
    pub sealed_types: usize,
}
#[cfg(test)]
mod tests {
    use super::{CallResolver, ResolverStats};

    /// Asserts that every counter in `stats` is zero.
    fn assert_all_zero(stats: &ResolverStats) {
        assert_eq!(stats.total_methods, 0);
        assert_eq!(stats.virtual_methods, 0);
        assert_eq!(stats.polymorphic_methods, 0);
        assert_eq!(stats.max_targets, 0);
        assert_eq!(stats.total_types, 0);
        assert_eq!(stats.interface_types, 0);
        assert_eq!(stats.sealed_types, 0);
    }

    #[test]
    fn test_resolver_stats_default() {
        assert_all_zero(&ResolverStats::default());
    }

    #[test]
    fn test_empty_resolver_stats() {
        // An empty resolver has recorded nothing, so every counter must be zero.
        assert_all_zero(&CallResolver::empty().stats());
    }
}