use std::collections::{HashMap, HashSet};
use wasm_encoder::{Function, Instruction};
use super::super::WasmGcError;
use super::super::lists::{emit_record_eq_inline, emit_sum_eq_inline};
use super::super::types::TypeRegistry;
/// Which structural-equality strategy an eq helper uses for a type name
/// registered in `EqHelperRegistry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum EqKind {
/// A record type with an entry in `TypeRegistry::record_fields`.
Record,
/// A sum type: some entry in `TypeRegistry::variants` names it as `parent`.
Sum,
/// An `Option<...>` carrier, keyed by its canonical type string.
OptionEq,
/// A `Result<..., ...>` carrier, keyed by its canonical type string.
ResultEq,
/// A `Tuple<...>` carrier, keyed by its canonical type string.
TupleEq,
}
/// Collects the types that need a generated structural-equality helper and
/// assigns each one a function index and a type index.
#[derive(Default)]
pub(crate) struct EqHelperRegistry {
// Registered type names in first-registration order; slot assignment and
// type/body emission all walk this list, so it fixes the emitted order.
order: Vec<String>,
// Eq strategy per registered name; always has exactly the keys of `order`.
kinds: HashMap<String, EqKind>,
// (function index, type index) per name, filled by `assign_slots`.
slots: HashMap<String, (u32, u32)>,
}
impl EqHelperRegistry {
    /// Creates an empty registry (same as `Default::default()`).
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Records `type_name` as needing an eq helper of the given `kind`.
    ///
    /// Only the first registration of a name takes effect: a name that is
    /// already present keeps its original kind and its position in the
    /// emission order.
    pub(crate) fn register(&mut self, type_name: &str, kind: EqKind) {
        if !self.kinds.contains_key(type_name) {
            self.order.push(type_name.to_string());
            self.kinds.insert(type_name.to_string(), kind);
        }
    }

    /// Registers `type_name` plus every compound type reachable from its
    /// record fields, variant payloads, or carrier elements.
    ///
    /// For records and sums, nothing is registered when the matching
    /// `lists::*_fields_resolvable` check fails. An already-registered name
    /// returns immediately, which is what stops recursion on
    /// self-referential types.
    pub(crate) fn register_transitive(
        &mut self,
        type_name: &str,
        kind: EqKind,
        registry: &TypeRegistry,
    ) {
        if self.kinds.contains_key(type_name) {
            return;
        }
        let mut seen = HashSet::new();
        let resolvable = match kind {
            EqKind::Record => {
                super::super::lists::record_fields_resolvable(type_name, registry, &mut seen)
            }
            EqKind::Sum => {
                super::super::lists::sum_fields_resolvable(type_name, registry, &mut seen)
            }
            EqKind::OptionEq | EqKind::ResultEq | EqKind::TupleEq => true,
        };
        if !resolvable {
            return;
        }
        // Register before recursing so cycles back to this name terminate.
        self.register(type_name, kind);
        match kind {
            EqKind::Record => {
                if let Some(fields) = registry.record_fields.get(type_name) {
                    for (_, field_ty) in fields {
                        self.register_field_type(field_ty.trim(), registry);
                    }
                }
            }
            EqKind::Sum => {
                // Cloned so `self` can be mutated while walking the variants.
                let variants: Vec<_> = registry
                    .variants
                    .values()
                    .flat_map(|vs| vs.iter())
                    .filter(|v| v.parent == type_name)
                    .cloned()
                    .collect();
                for v in &variants {
                    for field_ty in &v.fields {
                        self.register_field_type(field_ty.trim(), registry);
                    }
                }
            }
            EqKind::OptionEq => {
                if let Some(inner) = type_name
                    .strip_prefix("Option<")
                    .and_then(|s| s.strip_suffix('>'))
                {
                    self.register_field_type(inner.trim(), registry);
                }
            }
            EqKind::ResultEq => {
                if let Some((ok, err)) = parse_result_kv(type_name) {
                    self.register_field_type(ok.trim(), registry);
                    self.register_field_type(err.trim(), registry);
                }
            }
            EqKind::TupleEq => {
                if let Some(elems) = parse_tuple_elems(type_name) {
                    for e in elems {
                        self.register_field_type(e.trim(), registry);
                    }
                }
            }
        }
    }

    /// Registers whatever eq helpers a value of type `field_ty` needs.
    ///
    /// - Primitives need none.
    /// - Records, sums, and `Option`/`Result`/`Tuple` carriers are registered
    ///   transitively.
    /// - `List`/`Vector`/`Map` are not registered themselves; only their
    ///   element types are.
    /// - Anything else that is a newtype is resolved to its underlying type.
    fn register_field_type(&mut self, field_ty: &str, registry: &TypeRegistry) {
        if matches!(
            field_ty,
            "Int" | "Float" | "Bool" | "String" | "Unit" | "Byte" | "Char"
        ) {
            return;
        }
        if registry.record_fields.contains_key(field_ty) {
            self.register_transitive(field_ty, EqKind::Record, registry);
            return;
        }
        if registry
            .variants
            .values()
            .flat_map(|v| v.iter())
            .any(|v| v.parent == field_ty)
        {
            self.register_transitive(field_ty, EqKind::Sum, registry);
            return;
        }
        if let Some(inner) = field_ty
            .strip_prefix("Option<")
            .and_then(|s| s.strip_suffix('>'))
        {
            self.register_transitive(field_ty, EqKind::OptionEq, registry);
            self.register_field_type(inner.trim(), registry);
        } else if field_ty.starts_with("Result<") && field_ty.ends_with('>') {
            self.register_transitive(field_ty, EqKind::ResultEq, registry);
            if let Some((ok, err)) = parse_result_kv(field_ty) {
                self.register_field_type(ok.trim(), registry);
                self.register_field_type(err.trim(), registry);
            }
        } else if field_ty.starts_with("Tuple<") && field_ty.ends_with('>') {
            self.register_transitive(field_ty, EqKind::TupleEq, registry);
            if let Some(elems) = parse_tuple_elems(field_ty) {
                for elem in elems {
                    self.register_field_type(elem.trim(), registry);
                }
            }
        } else if let Some(inner) = field_ty
            .strip_prefix("List<")
            .and_then(|s| s.strip_suffix('>'))
        {
            self.register_field_type(inner.trim(), registry);
        } else if let Some(inner) = field_ty
            .strip_prefix("Vector<")
            .and_then(|s| s.strip_suffix('>'))
        {
            self.register_field_type(inner.trim(), registry);
        } else if let Some(inner) = field_ty
            .strip_prefix("Map<")
            .and_then(|s| s.strip_suffix('>'))
        {
            // Split `K, V` at the first comma not nested in `<...>`/`(...)`.
            let bytes = inner.as_bytes();
            let mut depth: i32 = 0;
            for (idx, b) in bytes.iter().enumerate() {
                match b {
                    b'<' | b'(' => depth += 1,
                    b'>' | b')' => depth -= 1,
                    b',' if depth == 0 => {
                        let k = inner[..idx].trim();
                        let v = inner[idx + 1..].trim();
                        self.register_field_type(k, registry);
                        self.register_field_type(v, registry);
                        return;
                    }
                    _ => {}
                }
            }
        } else if let Some(under) = registry.newtype_underlying(field_ty) {
            // `emit_inner_eq_dispatch` compares a newtype through its
            // underlying type's helper, so that helper must be registered
            // too. NOTE(review): assumes newtype chains are acyclic; only the
            // trivial self-reference is guarded here.
            let underlying = under.to_string();
            let underlying = underlying.trim();
            if underlying != field_ty {
                self.register_field_type(underlying, registry);
            }
        }
    }

    /// Iterates `(type name, kind)` pairs in registration order.
    pub(crate) fn iter(&self) -> impl Iterator<Item = (&str, EqKind)> + '_ {
        self.order.iter().map(|n| (n.as_str(), self.kinds[n]))
    }

    /// Hands out one consecutive function index and one consecutive type
    /// index per registered helper, in registration order, advancing both
    /// counters past the allocated range.
    pub(crate) fn assign_slots(&mut self, next_fn_idx: &mut u32, next_type_idx: &mut u32) {
        for name in &self.order {
            self.slots
                .insert(name.clone(), (*next_fn_idx, *next_type_idx));
            *next_fn_idx += 1;
            *next_type_idx += 1;
        }
    }

    /// Function index assigned to `type_name`'s helper, if a slot exists.
    pub(crate) fn lookup_fn_idx(&self, type_name: &str) -> Option<u32> {
        self.slots.get(type_name).map(|(f, _)| *f)
    }

    /// Type index assigned to `type_name`'s helper, if a slot exists.
    pub(crate) fn lookup_type_idx(&self, type_name: &str) -> Option<u32> {
        self.slots.get(type_name).map(|(_, t)| *t)
    }

    /// Appends one `(eqref, eqref) -> i32` function type per registered
    /// helper, in the same order `assign_slots` hands out type indices.
    pub(crate) fn emit_helper_types(&self, types: &mut wasm_encoder::TypeSection) {
        let eq_ref = wasm_encoder::ValType::Ref(wasm_encoder::RefType {
            nullable: true,
            heap_type: wasm_encoder::HeapType::Abstract {
                shared: false,
                ty: wasm_encoder::AbstractHeapType::Eq,
            },
        });
        for _ in &self.order {
            types
                .ty()
                .function([eq_ref, eq_ref], [wasm_encoder::ValType::I32]);
        }
    }

    /// Emits one code body per registered helper, in registration order.
    ///
    /// Bodies can call any helper in this registry's slots, overlaid with
    /// `compound_lookup`; entries from `compound_lookup` win on a name clash.
    /// Presumably `assign_slots` must already have run, since without it
    /// no helper indices are known.
    ///
    /// # Errors
    /// `WasmGcError::Validation` when a record is not registered in
    /// `registry`, or when any nested eq emission fails.
    pub(crate) fn emit_helper_bodies(
        &self,
        codes: &mut wasm_encoder::CodeSection,
        registry: &TypeRegistry,
        string_eq_fn_idx: Option<u32>,
        compound_lookup: &HashMap<String, u32>,
    ) -> Result<(), WasmGcError> {
        let mut helper_idx_map: HashMap<String, u32> = self
            .slots
            .iter()
            .map(|(n, (fn_idx, _))| (n.clone(), *fn_idx))
            .collect();
        for (canonical, fn_idx) in compound_lookup {
            helper_idx_map.insert(canonical.clone(), *fn_idx);
        }
        for name in &self.order {
            let kind = self.kinds[name];
            let self_fn_idx = self.slots.get(name).map(|(f, _)| *f);
            match kind {
                EqKind::Sum => {
                    // The sum emitter works directly on params 0 and 1.
                    let mut f = Function::new(Vec::new());
                    emit_sum_eq_inline(
                        &mut f,
                        name,
                        registry,
                        0,
                        1,
                        string_eq_fn_idx,
                        &helper_idx_map,
                        self_fn_idx,
                    )?;
                    f.instruction(&Instruction::End);
                    codes.function(&f);
                }
                EqKind::Record => {
                    let r_idx = registry.record_type_idx(name).ok_or_else(|| {
                        WasmGcError::Validation(format!(
                            "eq helper for record `{name}`: record not registered"
                        ))
                    })?;
                    let r_ref = wasm_encoder::ValType::Ref(wasm_encoder::RefType {
                        nullable: true,
                        heap_type: wasm_encoder::HeapType::Concrete(r_idx),
                    });
                    // Locals 2 and 3 hold the params downcast to the record type.
                    let mut f = Function::new(vec![(2, r_ref)]);
                    let r_heap = wasm_encoder::HeapType::Concrete(r_idx);
                    f.instruction(&Instruction::LocalGet(0));
                    f.instruction(&Instruction::RefCastNonNull(r_heap));
                    f.instruction(&Instruction::LocalSet(2));
                    f.instruction(&Instruction::LocalGet(1));
                    f.instruction(&Instruction::RefCastNonNull(r_heap));
                    f.instruction(&Instruction::LocalSet(3));
                    emit_record_eq_inline(
                        &mut f,
                        name,
                        registry,
                        2,
                        3,
                        string_eq_fn_idx,
                        &helper_idx_map,
                        self_fn_idx,
                    )?;
                    f.instruction(&Instruction::End);
                    codes.function(&f);
                }
                EqKind::OptionEq => {
                    let f = emit_option_eq_body(name, registry, string_eq_fn_idx, &helper_idx_map)?;
                    codes.function(&f);
                }
                EqKind::ResultEq => {
                    let f = emit_result_eq_body(name, registry, string_eq_fn_idx, &helper_idx_map)?;
                    codes.function(&f);
                }
                EqKind::TupleEq => {
                    let f = emit_tuple_eq_body(name, registry, string_eq_fn_idx, &helper_idx_map)?;
                    codes.function(&f);
                }
            }
        }
        Ok(())
    }

    /// Whether any registered helper's type has a `String` component,
    /// as judged by `type_has_string_field`.
    pub(crate) fn needs_string_eq(&self, registry: &TypeRegistry) -> bool {
        // Shared across the loop, so each name is examined at most once.
        let mut visiting: HashSet<String> = HashSet::new();
        for name in &self.order {
            if type_has_string_field(name, self.kinds[name], registry, &mut visiting) {
                return true;
            }
        }
        false
    }
}
/// Splits a canonical `Result<Ok, Err>` type string into its two trimmed
/// halves.
///
/// The split happens at the first comma not nested inside `<...>` or `(...)`.
/// Everything after that comma becomes the error side. Returns `None` when
/// the trimmed input is not wrapped in `Result<` ... `>`, or when it has no
/// top-level comma.
fn parse_result_kv(canonical: &str) -> Option<(&str, &str)> {
    let body = canonical.trim().strip_prefix("Result<")?.strip_suffix('>')?;
    let mut nesting: i32 = 0;
    let comma_at = body.bytes().position(|ch| {
        match ch {
            b'<' | b'(' => nesting += 1,
            b'>' | b')' => nesting -= 1,
            b',' => return nesting == 0,
            _ => {}
        }
        false
    })?;
    let (ok_side, rest) = body.split_at(comma_at);
    // `rest` starts with the one-byte comma itself.
    Some((ok_side.trim(), rest[1..].trim()))
}
/// Splits the element list of a canonical `Tuple<...>` type string.
///
/// Commas nested inside `<...>` or `(...)` do not split, and each element is
/// trimmed. Returns `None` unless the trimmed input is wrapped in `Tuple<`
/// ... `>`. An empty list (`Tuple<>`) yields a single empty element.
fn parse_tuple_elems(canonical: &str) -> Option<Vec<&str>> {
    let body = canonical.trim().strip_prefix("Tuple<")?.strip_suffix('>')?;
    let mut elems = Vec::new();
    let mut nesting = 0i32;
    let mut segment_begin = 0usize;
    for (pos, byte) in body.bytes().enumerate() {
        if byte == b'<' || byte == b'(' {
            nesting += 1;
        } else if byte == b'>' || byte == b')' {
            nesting -= 1;
        } else if byte == b',' && nesting == 0 {
            elems.push(body[segment_begin..pos].trim());
            segment_begin = pos + 1;
        }
    }
    // The final segment has no trailing comma.
    elems.push(body[segment_begin..].trim());
    Some(elems)
}
/// Builds the complete body of the eq helper for the `Option<...>` carrier
/// named `canonical`.
///
/// The helper's two `eqref` params are cast to the option struct type
/// (`ref.cast` traps if a value is null or of another type) and stored in
/// locals 2 and 3. Then:
/// - if field 0 (the tag) differs, the helper returns 0;
/// - if the shared tag is 0 (presumably `None`), it returns 1;
/// - otherwise it compares the two field-1 payloads via
///   `emit_inner_eq_dispatch`.
///
/// # Errors
/// `WasmGcError::Validation` if the option type is not registered, its
/// element type cannot be parsed, or the element type has no eq dispatch.
fn emit_option_eq_body(
    canonical: &str,
    registry: &TypeRegistry,
    string_eq_fn_idx: Option<u32>,
    helper_idx_map: &HashMap<String, u32>,
) -> Result<Function, WasmGcError> {
    // `ok_or_else`: only build the error message on the failure path.
    let opt_idx = registry.option_type_idx(canonical).ok_or_else(|| {
        WasmGcError::Validation(format!(
            "eq helper for `{canonical}`: option not registered"
        ))
    })?;
    let inner = TypeRegistry::option_element_type(canonical).ok_or_else(|| {
        WasmGcError::Validation(format!("eq helper for `{canonical}`: can't parse inner"))
    })?;
    let opt_ref = wasm_encoder::ValType::Ref(wasm_encoder::RefType {
        nullable: true,
        heap_type: wasm_encoder::HeapType::Concrete(opt_idx),
    });
    let mut f = Function::new(vec![(2, opt_ref)]);
    let opt_heap = wasm_encoder::HeapType::Concrete(opt_idx);
    f.instruction(&Instruction::LocalGet(0));
    f.instruction(&Instruction::RefCastNonNull(opt_heap));
    f.instruction(&Instruction::LocalSet(2));
    f.instruction(&Instruction::LocalGet(1));
    f.instruction(&Instruction::RefCastNonNull(opt_heap));
    f.instruction(&Instruction::LocalSet(3));
    // Different tags: not equal.
    f.instruction(&Instruction::LocalGet(2));
    f.instruction(&Instruction::StructGet {
        struct_type_index: opt_idx,
        field_index: 0,
    });
    f.instruction(&Instruction::LocalGet(3));
    f.instruction(&Instruction::StructGet {
        struct_type_index: opt_idx,
        field_index: 0,
    });
    f.instruction(&Instruction::I32Ne);
    f.instruction(&Instruction::If(wasm_encoder::BlockType::Empty));
    f.instruction(&Instruction::I32Const(0));
    f.instruction(&Instruction::Return);
    f.instruction(&Instruction::End);
    // Same tag and it is 0: no payload to compare, so equal.
    f.instruction(&Instruction::LocalGet(2));
    f.instruction(&Instruction::StructGet {
        struct_type_index: opt_idx,
        field_index: 0,
    });
    f.instruction(&Instruction::I32Eqz);
    f.instruction(&Instruction::If(wasm_encoder::BlockType::Empty));
    f.instruction(&Instruction::I32Const(1));
    f.instruction(&Instruction::Return);
    f.instruction(&Instruction::End);
    // Both tags non-zero: the result is the payload comparison.
    f.instruction(&Instruction::LocalGet(2));
    f.instruction(&Instruction::StructGet {
        struct_type_index: opt_idx,
        field_index: 1,
    });
    f.instruction(&Instruction::LocalGet(3));
    f.instruction(&Instruction::StructGet {
        struct_type_index: opt_idx,
        field_index: 1,
    });
    emit_inner_eq_dispatch(
        &mut f,
        inner.trim(),
        registry,
        string_eq_fn_idx,
        helper_idx_map,
    )?;
    f.instruction(&Instruction::End);
    Ok(f)
}
/// Builds the complete body of the eq helper for the `Result<..., ...>`
/// carrier named `canonical`.
///
/// Both `eqref` params are cast to the result struct type and stored in
/// locals 2 and 3. Then:
/// - if field 0 (the tag) differs, the helper returns 0;
/// - if the tag is non-zero, it compares the field-1 payloads using the
///   first type argument (presumably the `Ok` side);
/// - otherwise it compares the field-2 payloads using the second type
///   argument.
///
/// # Errors
/// `WasmGcError::Validation` if the result type is not registered, its type
/// arguments cannot be parsed, or either argument has no eq dispatch.
fn emit_result_eq_body(
    canonical: &str,
    registry: &TypeRegistry,
    string_eq_fn_idx: Option<u32>,
    helper_idx_map: &HashMap<String, u32>,
) -> Result<Function, WasmGcError> {
    // `ok_or_else`: only build the error message on the failure path.
    let res_idx = registry.result_type_idx(canonical).ok_or_else(|| {
        WasmGcError::Validation(format!(
            "eq helper for `{canonical}`: result not registered"
        ))
    })?;
    let (ok_inner, err_inner) = parse_result_kv(canonical).ok_or_else(|| {
        WasmGcError::Validation(format!("eq helper for `{canonical}`: can't parse inner"))
    })?;
    let res_ref = wasm_encoder::ValType::Ref(wasm_encoder::RefType {
        nullable: true,
        heap_type: wasm_encoder::HeapType::Concrete(res_idx),
    });
    let mut f = Function::new(vec![(2, res_ref)]);
    let res_heap = wasm_encoder::HeapType::Concrete(res_idx);
    f.instruction(&Instruction::LocalGet(0));
    f.instruction(&Instruction::RefCastNonNull(res_heap));
    f.instruction(&Instruction::LocalSet(2));
    f.instruction(&Instruction::LocalGet(1));
    f.instruction(&Instruction::RefCastNonNull(res_heap));
    f.instruction(&Instruction::LocalSet(3));
    // Different tags: not equal.
    f.instruction(&Instruction::LocalGet(2));
    f.instruction(&Instruction::StructGet {
        struct_type_index: res_idx,
        field_index: 0,
    });
    f.instruction(&Instruction::LocalGet(3));
    f.instruction(&Instruction::StructGet {
        struct_type_index: res_idx,
        field_index: 0,
    });
    f.instruction(&Instruction::I32Ne);
    f.instruction(&Instruction::If(wasm_encoder::BlockType::Empty));
    f.instruction(&Instruction::I32Const(0));
    f.instruction(&Instruction::Return);
    f.instruction(&Instruction::End);
    // Same tag: choose which payload field to compare based on its value.
    f.instruction(&Instruction::LocalGet(2));
    f.instruction(&Instruction::StructGet {
        struct_type_index: res_idx,
        field_index: 0,
    });
    f.instruction(&Instruction::If(wasm_encoder::BlockType::Result(
        wasm_encoder::ValType::I32,
    )));
    f.instruction(&Instruction::LocalGet(2));
    f.instruction(&Instruction::StructGet {
        struct_type_index: res_idx,
        field_index: 1,
    });
    f.instruction(&Instruction::LocalGet(3));
    f.instruction(&Instruction::StructGet {
        struct_type_index: res_idx,
        field_index: 1,
    });
    emit_inner_eq_dispatch(
        &mut f,
        ok_inner.trim(),
        registry,
        string_eq_fn_idx,
        helper_idx_map,
    )?;
    f.instruction(&Instruction::Else);
    f.instruction(&Instruction::LocalGet(2));
    f.instruction(&Instruction::StructGet {
        struct_type_index: res_idx,
        field_index: 2,
    });
    f.instruction(&Instruction::LocalGet(3));
    f.instruction(&Instruction::StructGet {
        struct_type_index: res_idx,
        field_index: 2,
    });
    emit_inner_eq_dispatch(
        &mut f,
        err_inner.trim(),
        registry,
        string_eq_fn_idx,
        helper_idx_map,
    )?;
    f.instruction(&Instruction::End);
    f.instruction(&Instruction::End);
    Ok(f)
}
/// Builds the complete body of the eq helper for the `Tuple<...>` carrier
/// named `canonical`.
///
/// Both `eqref` params are cast to the tuple struct type and stored in
/// locals 2 and 3. Each field `i` is compared with the dispatch for the
/// i-th element type, and the per-field results are combined with `i32.and`.
/// Every field is compared; there is no short-circuit.
///
/// # Errors
/// `WasmGcError::Validation` if the tuple type is not registered, its
/// elements cannot be parsed, or an element type has no eq dispatch.
fn emit_tuple_eq_body(
    canonical: &str,
    registry: &TypeRegistry,
    string_eq_fn_idx: Option<u32>,
    helper_idx_map: &HashMap<String, u32>,
) -> Result<Function, WasmGcError> {
    // `ok_or_else`: only build the error message on the failure path.
    let tup_idx = registry.tuple_type_idx(canonical).ok_or_else(|| {
        WasmGcError::Validation(format!(
            "eq helper for `{canonical}`: tuple not registered"
        ))
    })?;
    let elems = parse_tuple_elems(canonical).ok_or_else(|| {
        WasmGcError::Validation(format!(
            "eq helper for `{canonical}`: can't parse elements"
        ))
    })?;
    let tup_ref = wasm_encoder::ValType::Ref(wasm_encoder::RefType {
        nullable: true,
        heap_type: wasm_encoder::HeapType::Concrete(tup_idx),
    });
    let mut f = Function::new(vec![(2, tup_ref)]);
    let tup_heap = wasm_encoder::HeapType::Concrete(tup_idx);
    f.instruction(&Instruction::LocalGet(0));
    f.instruction(&Instruction::RefCastNonNull(tup_heap));
    f.instruction(&Instruction::LocalSet(2));
    f.instruction(&Instruction::LocalGet(1));
    f.instruction(&Instruction::RefCastNonNull(tup_heap));
    f.instruction(&Instruction::LocalSet(3));
    for (i, elem) in elems.iter().enumerate() {
        f.instruction(&Instruction::LocalGet(2));
        f.instruction(&Instruction::StructGet {
            struct_type_index: tup_idx,
            field_index: i as u32,
        });
        f.instruction(&Instruction::LocalGet(3));
        f.instruction(&Instruction::StructGet {
            struct_type_index: tup_idx,
            field_index: i as u32,
        });
        emit_inner_eq_dispatch(
            &mut f,
            elem.trim(),
            registry,
            string_eq_fn_idx,
            helper_idx_map,
        )?;
        // Fold this element's result into the running conjunction.
        if i > 0 {
            f.instruction(&Instruction::I32And);
        }
    }
    f.instruction(&Instruction::End);
    Ok(f)
}
/// Emits the instruction(s) that compare two values of type `inner`.
///
/// Expects both values to already be on the stack and leaves one `i32`
/// result. A newtype is compared as its underlying type.
/// - `Int` uses `i64.eq`, `Bool` uses `i32.eq`, `Float` uses `f64.eq`.
/// - `String` calls `string_eq_fn_idx`.
/// - Any other type calls its entry in `helper_idx_map`.
///
/// # Errors
/// `WasmGcError::Validation` when `inner` is `String` but no string-eq
/// function index was provided, or when a non-primitive type has no helper.
fn emit_inner_eq_dispatch(
    f: &mut Function,
    inner: &str,
    registry: &TypeRegistry,
    string_eq_fn_idx: Option<u32>,
    helper_idx_map: &HashMap<String, u32>,
) -> Result<(), WasmGcError> {
    // Only allocate when a newtype actually needs resolving.
    let underlying = registry.newtype_underlying(inner).map(|u| u.to_string());
    let resolved: &str = underlying.as_deref().unwrap_or(inner);
    match resolved {
        "Int" => {
            f.instruction(&Instruction::I64Eq);
        }
        "Bool" => {
            f.instruction(&Instruction::I32Eq);
        }
        "Float" => {
            f.instruction(&Instruction::F64Eq);
        }
        "String" => {
            let eq_fn = string_eq_fn_idx.ok_or_else(|| {
                WasmGcError::Validation(
                    "carrier eq with String inner needs __wasmgc_string_eq".into(),
                )
            })?;
            f.instruction(&Instruction::Call(eq_fn));
        }
        other => {
            // A single lookup replaces `contains_key` followed by an indexing
            // operation that could panic.
            let helper = helper_idx_map.get(other).copied().ok_or_else(|| {
                WasmGcError::Validation(format!(
                    "carrier eq inner type `{other}` has no eq dispatch"
                ))
            })?;
            f.instruction(&Instruction::Call(helper));
        }
    }
    Ok(())
}
fn type_has_string_field(
name: &str,
kind: EqKind,
registry: &TypeRegistry,
visiting: &mut HashSet<String>,
) -> bool {
if !visiting.insert(name.to_string()) {
return false;
}
match kind {
EqKind::Record => registry
.record_fields
.get(name)
.map(|fs| fs.iter().any(|(_, t)| t.trim() == "String"))
.unwrap_or(false),
EqKind::Sum => registry
.variants
.values()
.flat_map(|vs| vs.iter())
.filter(|v| v.parent == name)
.any(|v| v.fields.iter().any(|t| t.trim() == "String")),
EqKind::OptionEq | EqKind::ResultEq | EqKind::TupleEq => name.contains("String"),
}
}