use std::sync::atomic::{AtomicU32, Ordering};
use dashmap::DashMap;
use log::warn;
use crate::{
emulation::tokens,
metadata::{tables::GenericParamAttributes, token::Token},
};
/// Interning table for generic instantiations.
///
/// Maps an open generic type or method token plus its type arguments to a
/// synthetic token (`GENERIC_INSTANTIATION_BASE | id`), and keeps a reverse
/// index from each synthetic token back to its `(open, args)` pair.
pub struct GenericRegistry {
/// `(open type token, type arguments)` -> synthetic instantiation token.
type_instantiations: DashMap<(Token, Vec<Token>), Token>,
/// `(open method token, method type arguments)` -> synthetic instantiation token.
method_instantiations: DashMap<(Token, Vec<Token>), Token>,
/// Synthetic token -> `(open token, arguments)`; shared by type and method
/// instantiations.
reverse_lookup: DashMap<Token, (Token, Vec<Token>)>,
/// Next id to combine with `GENERIC_INSTANTIATION_BASE`; `new` starts it at 1.
/// Shared by type and method instantiations, so their tokens never coincide.
next_token: AtomicU32,
}
impl GenericRegistry {
#[must_use]
pub fn new() -> Self {
Self {
type_instantiations: DashMap::new(),
method_instantiations: DashMap::new(),
reverse_lookup: DashMap::new(),
next_token: AtomicU32::new(1),
}
}
pub fn get_or_create_type(&self, open_type: Token, args: Vec<Token>) -> Token {
let key = (open_type, args.clone());
if let Some(existing) = self.type_instantiations.get(&key) {
return *existing;
}
let id = self.next_token.fetch_add(1, Ordering::SeqCst);
let token = Token::new(tokens::ranges::GENERIC_INSTANTIATION_BASE | id);
self.type_instantiations.insert(key, token);
self.reverse_lookup.insert(token, (open_type, args));
token
}
pub fn get_or_create_method(&self, open_method: Token, args: Vec<Token>) -> Token {
let key = (open_method, args.clone());
if let Some(existing) = self.method_instantiations.get(&key) {
return *existing;
}
let id = self.next_token.fetch_add(1, Ordering::SeqCst);
let token = Token::new(tokens::ranges::GENERIC_INSTANTIATION_BASE | id);
self.method_instantiations.insert(key, token);
self.reverse_lookup.insert(token, (open_method, args));
token
}
#[must_use]
pub fn lookup(&self, instantiated: Token) -> Option<(Token, Vec<Token>)> {
self.reverse_lookup
.get(&instantiated)
.map(|entry| entry.value().clone())
}
#[must_use]
pub fn is_instantiation(&self, token: Token) -> bool {
token.value() & tokens::ranges::GENERIC_INSTANTIATION_MASK
== tokens::ranges::GENERIC_INSTANTIATION_BASE
}
}
/// Logs a warning for each type argument that breaks a declared
/// reference-type or non-nullable value-type constraint.
///
/// * `param_name` - label of the generic definition, used only in log output.
/// * `generic_params` - declared generic parameters, matched to `type_args`
///   by position; the `u32` component is not consulted.
/// * `type_args` - the supplied type arguments.
/// * `is_value_type` - classifier returning `Some(true)` for value types,
///   `Some(false)` for reference types, and `None` when unknown. Unknown
///   arguments are never reported.
///
/// Violations are only logged; nothing is returned or rejected.
pub fn validate_constraints<F>(
    param_name: &str,
    generic_params: &[(u32, GenericParamAttributes, String)],
    type_args: &[Token],
    is_value_type: F,
) where
    F: Fn(Token) -> Option<bool>,
{
    // `zip` stops at the shorter slice, so arguments without a declared
    // parameter are skipped, matching a positional lookup that misses.
    let paired = type_args.iter().copied().zip(generic_params);
    for (i, (arg, (_, attrs, name))) in paired.enumerate() {
        // The classifier is only invoked when the corresponding flag is set.
        let wants_reference = attrs.contains(GenericParamAttributes::REFERENCE_TYPE_CONSTRAINT);
        if wants_reference && is_value_type(arg) == Some(true) {
            warn!(
                "Generic constraint violation in {param_name}: type arg {name} (!!{i}) \
                 requires reference type but got value type 0x{:08X}",
                arg.value()
            );
        }

        let wants_value =
            attrs.contains(GenericParamAttributes::NOT_NULLABLE_VALUE_TYPE_CONSTRAINT);
        if wants_value && is_value_type(arg) == Some(false) {
            warn!(
                "Generic constraint violation in {param_name}: type arg {name} (!!{i}) \
                 requires value type but got reference type 0x{:08X}",
                arg.value()
            );
        }
    }
}
/// Decides whether an instantiation with `source_args` may be used where one
/// with `target_args` is expected, according to each parameter's variance.
///
/// * Argument lists of different lengths are never compatible.
/// * Identical arguments at a position are always accepted.
/// * Positions with no entry in `generic_params` are not checked (accepted).
/// * Covariant positions require `is_assignable(source, target)`.
/// * Contravariant positions require `is_assignable(target, source)`.
/// * Any other variance value is treated as invariant, so differing
///   arguments are rejected.
///
/// Returns `false` at the first failing position and logs a warning
/// describing it; `is_assignable` is not called for later positions.
pub fn check_variance_compatibility<F>(
    generic_params: &[(u32, GenericParamAttributes, String)],
    source_args: &[Token],
    target_args: &[Token],
    is_assignable: F,
) -> bool
where
    F: Fn(Token, Token) -> bool,
{
    if source_args.len() != target_args.len() {
        return false;
    }

    let variance_mask = GenericParamAttributes::VARIANCE_MASK.bits();
    let covariant = GenericParamAttributes::COVARIANT.bits();
    let contravariant = GenericParamAttributes::CONTRAVARIANT.bits();

    // `all` short-circuits on the first `false`, mirroring an early return.
    source_args
        .iter()
        .copied()
        .zip(target_args.iter().copied())
        .enumerate()
        .all(|(i, (src, tgt))| {
            if src == tgt {
                return true;
            }
            let Some((_, flags, _)) = generic_params.get(i) else {
                return true;
            };

            let variance = flags.bits() & variance_mask;
            if variance == covariant {
                let compatible = is_assignable(src, tgt);
                if !compatible {
                    warn!(
                        "Generic variance mismatch at position {i}: covariant parameter \
                         requires 0x{:08X} assignable to 0x{:08X}",
                        src.value(),
                        tgt.value()
                    );
                }
                compatible
            } else if variance == contravariant {
                let compatible = is_assignable(tgt, src);
                if !compatible {
                    warn!(
                        "Generic variance mismatch at position {i}: contravariant parameter \
                         requires 0x{:08X} assignable to 0x{:08X}",
                        tgt.value(),
                        src.value()
                    );
                }
                compatible
            } else {
                warn!(
                    "Generic variance mismatch at position {i}: invariant parameter \
                     requires exact match but got 0x{:08X} vs 0x{:08X}",
                    src.value(),
                    tgt.value()
                );
                false
            }
        })
}
impl Default for GenericRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use crate::{emulation::engine::generics::GenericRegistry, metadata::token::Token};

    #[test]
    fn test_type_identity() {
        let registry = GenericRegistry::new();
        let open = Token::new(0x0200_0001);
        let arg_int = Token::new(0x0100_0010);
        let arg_str = Token::new(0x0100_0011);
        let list_int_1 = registry.get_or_create_type(open, vec![arg_int]);
        let list_int_2 = registry.get_or_create_type(open, vec![arg_int]);
        let list_str = registry.get_or_create_type(open, vec![arg_str]);
        assert_eq!(list_int_1, list_int_2);
        assert_ne!(list_int_1, list_str);
    }

    #[test]
    fn test_reverse_lookup() {
        let registry = GenericRegistry::new();
        let open = Token::new(0x0200_0001);
        let args = vec![Token::new(0x0100_0010)];
        let instantiated = registry.get_or_create_type(open, args.clone());
        let (found_open, found_args) = registry.lookup(instantiated).unwrap();
        assert_eq!(found_open, open);
        assert_eq!(found_args, args);
    }

    #[test]
    fn test_method_instantiation() {
        let registry = GenericRegistry::new();
        let open = Token::new(0x0600_0001);
        let args = vec![Token::new(0x0100_0010)];
        let inst1 = registry.get_or_create_method(open, args.clone());
        let inst2 = registry.get_or_create_method(open, args.clone());
        assert_eq!(inst1, inst2);
        let (found_open, found_args) = registry.lookup(inst1).unwrap();
        assert_eq!(found_open, open);
        assert_eq!(found_args, args);
    }

    #[test]
    fn test_is_instantiation() {
        let registry = GenericRegistry::new();
        let open = Token::new(0x0200_0001);
        let inst = registry.get_or_create_type(open, vec![Token::new(0x0100_0010)]);
        assert!(registry.is_instantiation(inst));
        assert!(!registry.is_instantiation(open));
    }

    #[test]
    fn test_unknown_lookup_returns_none() {
        let registry = GenericRegistry::new();
        assert!(registry.lookup(Token::new(0xF100_9999)).is_none());
    }

    /// Concurrent first-time requests for the same instantiation must all
    /// observe a single token, and that token must resolve back to its key.
    #[test]
    fn test_concurrent_type_identity() {
        let registry = GenericRegistry::new();
        let open = Token::new(0x0200_0001);
        let arg = Token::new(0x0100_0010);

        let results: Vec<Token> = std::thread::scope(|scope| {
            let handles: Vec<_> = (0..8)
                .map(|_| scope.spawn(|| registry.get_or_create_type(open, vec![arg])))
                .collect();
            handles
                .into_iter()
                .map(|handle| handle.join().expect("worker thread panicked"))
                .collect()
        });

        let first = results[0];
        assert!(results.iter().all(|token| *token == first));
        let (found_open, found_args) = registry.lookup(first).unwrap();
        assert_eq!(found_open, open);
        assert_eq!(found_args, vec![arg]);
    }

    /// Type and method tables share one id counter, so the same key in both
    /// tables must still yield distinct tokens.
    #[test]
    fn test_type_and_method_tokens_do_not_alias() {
        let registry = GenericRegistry::new();
        let open = Token::new(0x0200_0001);
        let args = vec![Token::new(0x0100_0010)];
        let ty = registry.get_or_create_type(open, args.clone());
        let method = registry.get_or_create_method(open, args);
        assert_ne!(ty, method);
        assert!(registry.is_instantiation(ty));
        assert!(registry.is_instantiation(method));
    }
}