use std::collections::HashMap;
use dashmap::DashMap;
use crate::{
emulation::engine::context::EmulationContext,
metadata::{
method::Method,
signatures::SignatureMethod,
token::Token,
typesystem::{CilType, CilTypeReference},
},
};
/// Precomputed dispatch table for a single runtime type.
///
/// Built by `DispatchResolver::precompute_vtable` and consulted by
/// `DispatchResolver::resolve` before any on-demand lookup is attempted.
struct VTable {
/// Maps a method token as named at a call site (a virtual method, an
/// explicitly overridden method, or an interface method) to the token of
/// the method recorded as its implementation for this type.
slots: HashMap<Token, Token>,
}
/// Resolves virtual and interface method calls to the method that should
/// actually execute for a given runtime type.
///
/// Both maps are `DashMap`s, so every method on the resolver takes `&self`.
pub struct DispatchResolver {
/// Memoized results keyed by `(runtime_type, declared_method)`.
cache: DashMap<(Token, Token), Token>,
/// Optional per-type dispatch tables keyed by the runtime type token,
/// populated by `precompute_vtable`.
vtables: DashMap<Token, VTable>,
}
impl DispatchResolver {
/// Creates a resolver with an empty resolution cache and no precomputed
/// vtables.
#[must_use]
pub fn new() -> Self {
    let cache = DashMap::new();
    let vtables = DashMap::new();
    Self { cache, vtables }
}
/// Resolves the method that a call to `declared_method` dispatches to when
/// the receiver's runtime type is `runtime_type`.
///
/// Lookup order:
/// 1. the memoization cache,
/// 2. a vtable previously built by `precompute_vtable` for `runtime_type`,
/// 3. an on-demand search through the type hierarchy.
///
/// Results from steps 2 and 3 are written to the cache. If the declared
/// method cannot be looked up in `context`, or is not virtual, the declared
/// token is returned unchanged and nothing is cached.
#[must_use]
pub fn resolve(
    &self,
    declared_method: Token,
    runtime_type: Token,
    context: &EmulationContext,
) -> Token {
    let key = (runtime_type, declared_method);

    // Fast path: this (type, method) pair has been resolved before.
    if let Some(hit) = self.cache.get(&key) {
        return *hit;
    }

    // Second chance: a precomputed vtable for the runtime type.
    let from_vtable = self
        .vtables
        .get(&runtime_type)
        .and_then(|vtable| vtable.slots.get(&declared_method).copied());
    if let Some(target) = from_vtable {
        self.cache.insert(key, target);
        return target;
    }

    // Slow path: inspect the declared method and walk the hierarchy.
    let method = match context.get_method(declared_method) {
        Ok(method) => method,
        Err(_) => return declared_method,
    };
    if !method.is_virtual() {
        return declared_method;
    }

    let target = match method.declaring_type_rc() {
        Some(owner) if owner.is_interface() => {
            self.resolve_interface(runtime_type, declared_method, &method, context)
        }
        _ => self.resolve_virtual(runtime_type, &method, context),
    };
    self.cache.insert(key, target);
    target
}
/// Resolves an interface method call for `runtime_type`.
///
/// Returns the token of the implementing method found by
/// `find_interface_impl`, searching from the runtime type upwards. When the
/// runtime type cannot be resolved, or no implementation is found, the
/// interface method's own token is returned. That single fallback covers
/// both an interface method that carries its own body and one with no
/// implementation at all, so no separate check is needed.
///
/// * `runtime_type` - token of the receiver's runtime type.
/// * `interface_method` - token of the interface method being called.
/// * `base_method` - the interface method itself; supplies name and signature.
/// * `context` - used to resolve `runtime_type` to its type information.
fn resolve_interface(
    &self,
    runtime_type: Token,
    interface_method: Token,
    base_method: &Method,
    context: &EmulationContext,
) -> Token {
    let Some(rt) = context.assembly().types().resolve(&runtime_type) else {
        return interface_method;
    };
    Self::find_interface_impl(&rt, interface_method, base_method).unwrap_or(interface_method)
}
/// Searches `type_info` and then its base types for the method implementing
/// `interface_method`.
///
/// For each type, two passes are made over its methods:
/// 1. explicit implementations: a method whose `overrides` list names
///    `interface_method`;
/// 2. implicit implementations: a non-static method with the same name and
///    a matching signature, skipped when its `overrides` list names some
///    other method.
///
/// Returns `None` when no type in the chain provides an implementation.
fn find_interface_impl(
    type_info: &CilType,
    interface_method: Token,
    base_method: &Method,
) -> Option<Token> {
    let wanted = Some(interface_method);

    // Pass 1: explicit implementations take priority over name matching.
    for (_, entry) in type_info.methods.iter() {
        let Some(candidate) = entry.upgrade() else {
            continue;
        };
        let names_wanted_slot = candidate.overrides.iter().any(|(_, target)| {
            let target_token = match target {
                CilTypeReference::MethodDef(weak) => weak.upgrade().map(|m| m.token),
                CilTypeReference::MemberRef(rc) => Some(rc.token),
                _ => None,
            };
            target_token == wanted
        });
        if names_wanted_slot {
            return Some(candidate.token);
        }
    }

    // Pass 2: implicit implementations matched by name and signature.
    for (_, entry) in type_info.methods.iter() {
        let Some(candidate) = entry.upgrade() else {
            continue;
        };
        let shape_matches = candidate.name == base_method.name
            && !candidate.is_static()
            && signatures_match(&candidate.signature, &base_method.signature);
        if !shape_matches {
            continue;
        }
        // A method explicitly bound to a different slot must not be picked
        // up here just because its name and signature happen to match.
        let bound_elsewhere = candidate.overrides.iter().any(|(_, target)| {
            let target_token = match target {
                CilTypeReference::MethodDef(weak) => weak.upgrade().map(|m| m.token),
                CilTypeReference::MemberRef(rc) => Some(rc.token),
                _ => None,
            };
            matches!(target_token, Some(other) if other != interface_method)
        });
        if !bound_elsewhere {
            return Some(candidate.token);
        }
    }

    // Nothing on this type: continue with the base type, if any.
    type_info
        .base()
        .and_then(|parent| Self::find_interface_impl(&parent, interface_method, base_method))
}
/// Resolves a class virtual call for `runtime_type`.
///
/// Returns the token found by `find_virtual_override` when searching from
/// the runtime type upwards, or `base_method.token` when the runtime type
/// cannot be resolved or no matching method is found.
fn resolve_virtual(
    &self,
    runtime_type: Token,
    base_method: &Method,
    context: &EmulationContext,
) -> Token {
    let fallback = base_method.token;
    let Some(rt) = context.assembly().types().resolve(&runtime_type) else {
        return fallback;
    };
    Self::find_virtual_override(&rt, &base_method.name, base_method).unwrap_or(fallback)
}
/// Finds the first virtual method named `method_name` whose signature
/// matches `base_method`, looking at `type_info` first and then at each
/// base type in turn.
///
/// Only virtual-ness, name and signature are compared; returns `None` when
/// no type in the chain declares such a method.
fn find_virtual_override(
    type_info: &CilType,
    method_name: &str,
    base_method: &Method,
) -> Option<Token> {
    for (_, entry) in type_info.methods.iter() {
        let Some(candidate) = entry.upgrade() else {
            continue;
        };
        if !candidate.is_virtual() {
            continue;
        }
        if candidate.name != method_name {
            continue;
        }
        if !signatures_match(&candidate.signature, &base_method.signature) {
            continue;
        }
        return Some(candidate.token);
    }

    // Not declared here: defer to the base type, if there is one.
    type_info
        .base()
        .and_then(|parent| Self::find_virtual_override(&parent, method_name, base_method))
}
/// Builds and stores the dispatch table for `runtime_type` so that later
/// `resolve` calls for that type can be answered from the table.
///
/// Does nothing when a table for the type already exists or when the type
/// cannot be resolved through `context`.
pub fn precompute_vtable(&self, runtime_type: Token, context: &EmulationContext) {
    let already_built = self.vtables.contains_key(&runtime_type);
    if already_built {
        return;
    }
    let Some(rt) = context.assembly().types().resolve(&runtime_type) else {
        return;
    };

    // Virtual slots go first: the interface pass skips any interface
    // method that already has a slot.
    let mut table = VTable {
        slots: HashMap::new(),
    };
    Self::collect_virtual_slots(&rt, &mut table.slots);
    Self::collect_interface_slots(&rt, &mut table.slots);
    self.vtables.insert(runtime_type, table);
}
fn collect_virtual_slots(type_info: &CilType, slots: &mut HashMap<Token, Token>) {
if let Some(base) = type_info.base() {
Self::collect_virtual_slots(&base, slots);
}
for (_, method_ref) in type_info.methods.iter() {
let Some(method) = method_ref.upgrade() else {
continue;
};
if !method.is_virtual() {
continue;
}
slots.insert(method.token, method.token);
for (_, override_ref) in method.overrides.iter() {
let override_token = match override_ref {
CilTypeReference::MethodDef(weak) => weak.upgrade().map(|m| m.token),
CilTypeReference::MemberRef(rc) => Some(rc.token),
_ => None,
};
if let Some(ot) = override_token {
slots.insert(ot, method.token);
}
}
}
}
/// Fills `slots` with interface-method dispatch entries for `type_info`.
///
/// Interfaces listed on `type_info` and on each of its base types are
/// visited, most-derived type first. The implementation search always
/// starts at `type_info` itself, including for interfaces listed on a base
/// type. That matches `resolve_interface`, which searches from the runtime
/// type; starting at the base type instead would miss implementations in
/// more-derived types and make the table disagree with on-demand resolution.
fn collect_interface_slots(type_info: &CilType, slots: &mut HashMap<Token, Token>) {
    Self::collect_declared_interface_slots(type_info, type_info, slots);

    let mut ancestor = type_info.base();
    while let Some(current) = ancestor {
        Self::collect_declared_interface_slots(type_info, &current, slots);
        ancestor = current.base();
    }
}

/// Adds slots for the methods of every interface listed on `declaring`,
/// looking for implementations starting at `search_root`.
///
/// Interface methods that already have a slot are left untouched. When no
/// implementation is found, an interface method that has a body and is not
/// abstract is mapped to itself; otherwise no slot is recorded.
fn collect_declared_interface_slots(
    search_root: &CilType,
    declaring: &CilType,
    slots: &mut HashMap<Token, Token>,
) {
    for (_, iface_entry) in declaring.interfaces.iter() {
        let Some(iface_type) = iface_entry.interface.upgrade() else {
            continue;
        };
        for (_, iface_method_ref) in iface_type.methods.iter() {
            let Some(iface_method) = iface_method_ref.upgrade() else {
                continue;
            };
            if slots.contains_key(&iface_method.token) {
                continue;
            }
            if let Some(impl_token) =
                Self::find_interface_impl(search_root, iface_method.token, &iface_method)
            {
                slots.insert(iface_method.token, impl_token);
            } else if iface_method.has_body() && !iface_method.is_abstract() {
                slots.insert(iface_method.token, iface_method.token);
            }
        }
    }
}
}
impl Default for DispatchResolver {
fn default() -> Self {
Self::new()
}
}
/// Returns `true` when `candidate` and `base` have the same parameter
/// count, the same generic parameter count, and pairwise-equal parameter
/// `base` types.
///
/// Nothing else in the signatures is compared.
fn signatures_match(candidate: &SignatureMethod, base: &SignatureMethod) -> bool {
    let same_shape = candidate.param_count == base.param_count
        && candidate.param_count_generic == base.param_count_generic
        && candidate.params.len() == base.params.len();
    if !same_shape {
        return false;
    }

    let any_param_differs = candidate
        .params
        .iter()
        .zip(base.params.iter())
        .any(|(theirs, ours)| theirs.base != ours.base);
    !any_param_differs
}
#[cfg(test)]
mod tests {
    use crate::emulation::engine::dispatch::DispatchResolver;

    /// A freshly constructed resolver holds no cached resolutions and no
    /// precomputed vtables, whether built through `new` or `Default`.
    #[test]
    fn test_resolver_creation() {
        let resolver = DispatchResolver::new();
        assert!(resolver.cache.is_empty());
        assert!(resolver.vtables.is_empty());

        let defaulted = DispatchResolver::default();
        assert!(defaulted.cache.is_empty());
        assert!(defaulted.vtables.is_empty());
    }
}