use formalang::ir::{DispatchKind, ImplTarget, IrExpr, ResolvedType};
use wasm_encoder::{InstructionSink, MemArg};
use super::{LowerContext, LowerError, lower_expr};
use crate::layout::VTABLE_SLOT_SIZE;
use crate::module::MEMORY_INDEX;
use crate::module_lowering::impl_target_key;
use crate::types::{CLOSURE_ENV_OFFSET, CLOSURE_FUNCREF_OFFSET};
pub fn lower_function_call(
expr: &IrExpr,
sink: &mut InstructionSink<'_>,
ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
let IrExpr::FunctionCall {
path,
function_id,
args,
..
} = expr
else {
return Err(LowerError::NotYetImplemented {
what: "lower_function_call called with non-FunctionCall expression".to_owned(),
});
};
let module = ctx.module().ok();
let id = if let Some(m) = module {
let path_last = path.last().map(String::as_str);
let arg_labels: Vec<&str> = args
.iter()
.filter_map(|(label, _)| label.as_deref())
.collect();
let matches_signature = |f: &formalang::ir::IrFunction| -> bool {
Some(f.name.as_str()) == path_last
&& f.params.len() == args.len()
&& arg_labels
.iter()
.all(|label| f.params.iter().any(|p| p.name == *label))
};
let id_valid = function_id.and_then(|id| {
let f = m.functions.get(id.0 as usize)?;
if matches_signature(f) { Some(id) } else { None }
});
let resolved_id = id_valid.or_else(|| {
m.functions.iter().enumerate().find_map(|(i, f)| {
if matches_signature(f) {
u32::try_from(i).ok().map(formalang::ir::FunctionId)
} else {
None
}
})
});
resolved_id
.or(*function_id)
.ok_or_else(|| LowerError::UnresolvedFunctionCall { path: path.clone() })?
} else {
function_id.ok_or_else(|| LowerError::UnresolvedFunctionCall { path: path.clone() })?
};
let wasm_idx = ctx
.functions
.get(id)
.ok_or(LowerError::UnknownFunction(id))?;
let callee = module.and_then(|m| m.functions.get(id.0 as usize));
for (param_name, arg) in args {
let target = callee.and_then(|f| {
param_name.as_ref().and_then(|n| {
f.params
.iter()
.find(|p| p.name == *n)
.and_then(|p| p.ty.as_ref())
})
});
if let Some(t) = target {
super::optional::lower_coerced(arg, t, sink, ctx)?;
} else {
lower_expr(arg, sink, ctx)?;
}
}
sink.call(wasm_idx);
Ok(())
}
/// Lowers an `IrExpr::CallClosure` to a `call_indirect` through the closure table.
///
/// The closure value is evaluated once and kept in a scratch `i32` local. The
/// wasm operands are then pushed in this order:
/// 1. the word at `CLOSURE_ENV_OFFSET`, as the leading wasm argument;
/// 2. the call's arguments, in order;
/// 3. the word at `CLOSURE_FUNCREF_OFFSET`, as the table index consumed by
///    `call_indirect`.
///
/// The call's type index is derived from the closure's type.
///
/// # Errors
///
/// Returns `LowerError::NotYetImplemented` if `expr` is not a `CallClosure` or
/// the callee is not closure-typed. Also propagates failures from:
/// - the context lookups;
/// - scratch-local allocation;
/// - sub-expression lowering.
pub fn lower_call_closure(
    expr: &IrExpr,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    use formalang::ir::ResolvedType;
    use wasm_encoder::{MemArg, ValType};

    let IrExpr::CallClosure { closure, args, .. } = expr else {
        return Err(LowerError::NotYetImplemented {
            what: "lower_call_closure called with non-CallClosure expression".to_owned(),
        });
    };

    let closure_ty = closure.ty().clone();
    if !matches!(&closure_ty, ResolvedType::Closure { .. }) {
        return Err(LowerError::NotYetImplemented {
            what: format!("CallClosure on non-closure-typed value {closure_ty:?}"),
        });
    }

    let table_idx = ctx.closure_table_index()?;
    let type_idx = ctx.closure_type_index(&closure_ty)?;
    let closure_ptr = ctx.next_scratch_local(ValType::I32)?;
    // `align: 2` is log2 of the alignment, i.e. 4-byte aligned word loads.
    let word_at = |offset: u64| MemArg {
        offset,
        align: 2,
        memory_index: MEMORY_INDEX,
    };

    lower_expr(closure, sink, ctx)?;
    sink.local_set(closure_ptr)
        .local_get(closure_ptr)
        .i32_load(word_at(u64::from(CLOSURE_ENV_OFFSET)));

    for (_, arg) in args {
        lower_expr(arg, sink, ctx)?;
    }

    // The table index must be the last operand on the stack for `call_indirect`.
    sink.local_get(closure_ptr)
        .i32_load(word_at(u64::from(CLOSURE_FUNCREF_OFFSET)))
        .call_indirect(table_idx, type_idx);
    Ok(())
}
/// Lowers an `IrExpr::MethodCall` according to its `DispatchKind`.
///
/// - `DispatchKind::Static` calls are emitted by `lower_static_method_call`, a
///   direct `call` to the implementation's method.
/// - `DispatchKind::Virtual` calls are emitted by `lower_virtual_method_call`, a
///   vtable lookup followed by `call_indirect`.
///
/// The virtual `method_name` is not consulted here; the method is identified by
/// `method_idx`.
///
/// # Errors
///
/// Returns `LowerError::NotYetImplemented` if `expr` is not a `MethodCall`.
/// Otherwise returns whatever the selected lowering routine returns.
pub fn lower_method_call(
    expr: &IrExpr,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    let IrExpr::MethodCall {
        receiver,
        method_idx,
        args,
        dispatch,
        ..
    } = expr
    else {
        return Err(LowerError::NotYetImplemented {
            what: "lower_method_call called with non-MethodCall expression".to_owned(),
        });
    };
    let method_idx = *method_idx;
    match dispatch {
        DispatchKind::Static { impl_id } => {
            lower_static_method_call(*impl_id, method_idx, receiver, args, sink, ctx)
        }
        DispatchKind::Virtual { trait_id, .. } => {
            lower_virtual_method_call(*trait_id, method_idx, receiver, args, sink, ctx)
        }
    }
}
/// Emits a direct `call` to the wasm function registered for
/// `(impl_id, method_idx)`.
///
/// The receiver is pushed first, followed by the arguments in call-site order.
/// A labelled argument whose matching parameter in the IR method signature
/// (looked up through `module().impls[impl_id].functions[method_idx]`) declares
/// a type is lowered via `optional::lower_coerced`. Everything else goes
/// through `lower_expr`.
///
/// # Errors
///
/// - `LowerError::MissingContext` if the context carries no method table.
/// - `LowerError::UnknownMethod` if `(impl_id, method_idx)` is not registered.
/// - Any error from lowering the receiver or an argument.
fn lower_static_method_call(
    impl_id: formalang::ir::ImplId,
    method_idx: formalang::ir::MethodIdx,
    receiver: &IrExpr,
    args: &[(Option<String>, IrExpr)],
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    let Some(methods) = ctx.methods else {
        return Err(LowerError::MissingContext { what: "methods" });
    };
    let Some(wasm_idx) = methods.get((impl_id, method_idx)) else {
        return Err(LowerError::UnknownMethod {
            impl_id,
            method_idx,
        });
    };

    lower_expr(receiver, sink, ctx)?;

    let params = ctx
        .module()
        .ok()
        .and_then(|m| m.impls.get(impl_id.0 as usize))
        .and_then(|imp| imp.functions.get(method_idx.0 as usize))
        .map(|f| &f.params);

    for (label, arg) in args {
        let coerce_to = label.as_deref().zip(params).and_then(|(name, ps)| {
            ps.iter()
                .find(|p| p.name == name)
                .and_then(|p| p.ty.as_ref())
        });
        match coerce_to {
            Some(ty) => {
                super::optional::lower_coerced(arg, ty, sink, ctx)?;
            }
            None => {
                lower_expr(arg, sink, ctx)?;
            }
        }
    }

    sink.call(wasm_idx);
    Ok(())
}
/// Emits a dynamic trait-method call through a trait-typed receiver.
///
/// The receiver is evaluated once into a scratch `i32` local and read as a
/// two-word record:
/// - the word at offset `POINTER_SIZE` is pushed as the method's first wasm
///   argument;
/// - the word at offset 0 is used as the base address of a vtable.
///
/// After the remaining arguments are pushed, the slot at
/// `base + method_idx * VTABLE_SLOT_SIZE` is loaded. Its value is the table
/// index for a `call_indirect` through the method table.
///
/// A labelled argument whose matching trait-method parameter declares a type is
/// lowered via `optional::lower_coerced`; the others go through `lower_expr`.
///
/// # Errors
///
/// - `LowerError::NotYetImplemented` if the slot's byte offset does not fit in
///   the `i32.const` used to address it. Previously it was silently clamped to
///   `i32::MAX`, producing a bogus address.
/// - Any failure obtaining the method table, call type, or scratch local.
/// - Any error from lowering the receiver or an argument.
fn lower_trait_typed_dispatch(
    method_idx: formalang::ir::MethodIdx,
    receiver: &IrExpr,
    args: &[(Option<String>, IrExpr)],
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
    trait_id: formalang::ir::TraitId,
) -> Result<(), LowerError> {
    use crate::layout::{POINTER_SIZE, VTABLE_SLOT_SIZE};
    use crate::module::MEMORY_INDEX;
    use wasm_encoder::{MemArg, ValType};
    let table_idx = ctx.method_table_index()?;
    let type_idx = ctx.virtual_call_type_index(trait_id, method_idx)?;
    // Byte offset of this method's slot within the vtable. It is emitted as an
    // `i32.const`, so it must fit in i32; validate it before emitting anything.
    let slot_byte_offset = u64::from(method_idx.0)
        .checked_mul(u64::from(VTABLE_SLOT_SIZE))
        .and_then(|offset| i32::try_from(offset).ok())
        .ok_or_else(|| LowerError::NotYetImplemented {
            what: "vtable slot byte offset overflow".to_owned(),
        })?;
    let cell_scratch = ctx.next_scratch_local(ValType::I32)?;
    let mem_arg = |off: u64| MemArg {
        offset: off,
        align: 2,
        memory_index: MEMORY_INDEX,
    };

    lower_expr(receiver, sink, ctx)?;
    sink.local_set(cell_scratch);
    // First wasm argument: the word stored after the vtable pointer.
    sink.local_get(cell_scratch);
    sink.i32_load(mem_arg(u64::from(POINTER_SIZE)));

    let trait_method_sig = ctx
        .module()
        .ok()
        .and_then(|m| m.traits.get(trait_id.0 as usize))
        .and_then(|t| t.methods.get(method_idx.0 as usize));
    for (param_name, arg) in args {
        let target_ty = trait_method_sig.and_then(|sig| {
            param_name.as_ref().and_then(|n| {
                sig.params
                    .iter()
                    .find(|p| p.name == *n)
                    .and_then(|p| p.ty.as_ref())
            })
        });
        if let Some(t) = target_ty {
            super::optional::lower_coerced(arg, t, sink, ctx)?;
        } else {
            lower_expr(arg, sink, ctx)?;
        }
    }

    // Table index: load the vtable pointer, step to this method's slot, and
    // load the slot contents. It must be the last operand for `call_indirect`.
    sink.local_get(cell_scratch);
    sink.i32_load(mem_arg(0));
    sink.i32_const(slot_byte_offset);
    sink.i32_add();
    sink.i32_load(mem_arg(0));
    sink.call_indirect(table_idx, type_idx);
    Ok(())
}
/// Lowers a virtually-dispatched trait method call.
///
/// Receivers typed as `ResolvedType::Trait` are delegated to
/// `lower_trait_typed_dispatch`, which reads the vtable pointer at run time.
///
/// For struct- and enum-typed receivers the implementing type is known here.
/// The vtable slot address is therefore a constant:
/// `vtable_offset(trait_id, target) + method_idx * VTABLE_SLOT_SIZE`.
/// The receiver and arguments are pushed, then the slot is loaded from address
/// `0 + slot_offset`, with the address carried in the load's immediate offset.
/// The loaded value is the table index of a `call_indirect` through the method
/// table.
///
/// A labelled argument whose matching trait-method parameter declares a type is
/// lowered via `optional::lower_coerced`; the others go through `lower_expr`.
///
/// # Errors
///
/// - `LowerError::UnsupportedVirtualReceiver` if the receiver is not trait-,
///   struct-, or enum-typed.
/// - `LowerError::NotYetImplemented` if the slot address does not fit in a
///   32-bit `memarg` offset. Loads here are addressed with an `i32` base, so a
///   larger immediate would produce an invalid module.
/// - Any error from the context lookups or from lowering the receiver or
///   arguments.
fn lower_virtual_method_call(
    trait_id: formalang::ir::TraitId,
    method_idx: formalang::ir::MethodIdx,
    receiver: &IrExpr,
    args: &[(Option<String>, IrExpr)],
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    if let ResolvedType::Trait(_) = receiver.ty() {
        return lower_trait_typed_dispatch(method_idx, receiver, args, sink, ctx, trait_id);
    }
    let target = if let Some(sid) = crate::compound::struct_id_of(receiver.ty()) {
        ImplTarget::Struct(sid)
    } else if let Some(eid) = crate::compound::enum_id_of(receiver.ty()) {
        ImplTarget::Enum(eid)
    } else {
        return Err(LowerError::UnsupportedVirtualReceiver {
            ty: receiver.ty().clone(),
        });
    };
    let table_idx = ctx.method_table_index()?;
    let type_idx = ctx.virtual_call_type_index(trait_id, method_idx)?;
    let vtable_base = ctx.vtable_offset(trait_id, impl_target_key(target))?;
    // Checked arithmetic guards against u64 overflow. The u32 bound is the real
    // limit, because that is all a memory32 `memarg` offset can encode.
    let slot_offset = u64::from(method_idx.0)
        .checked_mul(u64::from(VTABLE_SLOT_SIZE))
        .and_then(|relative| relative.checked_add(u64::from(vtable_base)))
        .filter(|absolute| u32::try_from(*absolute).is_ok())
        .ok_or_else(|| LowerError::NotYetImplemented {
            what: "vtable slot offset overflow".to_owned(),
        })?;
    lower_expr(receiver, sink, ctx)?;
    let trait_method_sig = ctx
        .module()
        .ok()
        .and_then(|m| m.traits.get(trait_id.0 as usize))
        .and_then(|t| t.methods.get(method_idx.0 as usize));
    for (param_name, arg) in args {
        let target_ty = trait_method_sig.and_then(|sig| {
            param_name.as_ref().and_then(|n| {
                sig.params
                    .iter()
                    .find(|p| p.name == *n)
                    .and_then(|p| p.ty.as_ref())
            })
        });
        if let Some(t) = target_ty {
            super::optional::lower_coerced(arg, t, sink, ctx)?;
        } else {
            lower_expr(arg, sink, ctx)?;
        }
    }
    // Load the table index from the constant slot address; it must be the last
    // operand for `call_indirect`.
    sink.i32_const(0)
        .i32_load(MemArg {
            offset: slot_offset,
            align: 2,
            memory_index: MEMORY_INDEX,
        })
        .call_indirect(table_idx, type_idx);
    Ok(())
}