use formalang::ast::PrimitiveType;
use formalang::ir::{BindingId, IrBlockStatement, IrExpr, IrModule, ResolvedType, StructId};
use wasm_encoder::{Function, InstructionSink, ValType};
use super::{
BindingMap, ClosureCallContext, FunctionMap, LowerContext, LowerError, MethodMap,
ScratchAllocator, lower_expr,
};
use crate::types::body_value_type;
/// Lowers an [`IrExpr::Block`] into `sink`.
///
/// Every statement of the block is lowered in order through
/// [`lower_block_statement`], after which the block's result expression is
/// lowered with [`lower_expr`].
///
/// # Errors
///
/// Returns [`LowerError::NotYetImplemented`] when `expr` is not a block, and
/// forwards any error raised while lowering a statement or the result.
pub fn lower_block(
    expr: &IrExpr,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    let IrExpr::Block {
        statements, result, ..
    } = expr
    else {
        return Err(LowerError::NotYetImplemented {
            what: "lower_block called with non-Block expression".to_owned(),
        });
    };

    // Statements first, stopping at the first one that fails to lower.
    statements
        .iter()
        .try_for_each(|statement| lower_block_statement(statement, sink, ctx))?;

    // The trailing expression is lowered last.
    lower_expr(result, sink, ctx)
}
/// Lowers one statement of a block into `sink`.
///
/// * `Let` — resolves the local index mapped to the binding, lowers the
///   initializer (through `optional::lower_coerced` when the binding carries a
///   declared type, otherwise through [`lower_expr`]) and emits `local.set`.
/// * `Expr` — lowers the expression and emits `drop`, except when the
///   expression's type is the `Never` primitive.
/// * `Assign` — delegates to [`lower_assign`].
///
/// # Errors
///
/// Returns [`LowerError::UnknownBinding`] when a `Let` binding has no entry in
/// `ctx.bindings`, and forwards errors from the nested lowerings.
pub(super) fn lower_block_statement(
    stmt: &IrBlockStatement,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    match stmt {
        IrBlockStatement::Let {
            binding_id,
            ty,
            value,
            ..
        } => {
            // Resolve the destination before emitting anything for the value.
            let Some(local) = ctx.bindings.get(*binding_id) else {
                return Err(LowerError::UnknownBinding(*binding_id));
            };
            match ty.as_ref() {
                Some(declared) => {
                    super::optional::lower_coerced(value, declared, sink, ctx)?;
                }
                None => {
                    lower_expr(value, sink, ctx)?;
                }
            }
            sink.local_set(local);
            Ok(())
        }
        IrBlockStatement::Expr(e) => {
            // A `Never`-typed expression gets no `drop` after it.
            let diverges = matches!(e.ty(), ResolvedType::Primitive(PrimitiveType::Never));
            lower_expr(e, sink, ctx)?;
            if !diverges {
                sink.drop();
            }
            Ok(())
        }
        IrBlockStatement::Assign { target, value, .. } => lower_assign(target, value, sink, ctx),
    }
}
fn lower_assign(
target: &IrExpr,
value: &IrExpr,
sink: &mut InstructionSink<'_>,
ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
use super::aggregate::{
layout_for_aggregate, lookup_field_by_name, lookup_field_by_name_with_meta, primitive_of,
store_primitive,
};
use crate::layout::plan_struct;
match target {
IrExpr::SelfFieldRef { field, .. } => {
let struct_id = ctx.self_struct_id.ok_or(LowerError::MissingSelfStruct)?;
let module = ctx.module()?;
let s = module
.structs
.get(struct_id.0 as usize)
.ok_or(LowerError::UnknownStruct(struct_id))?;
let layout = plan_struct(s, module)?;
let (field_layout, field_def) = lookup_field_by_name(s, &layout.fields, field)?;
let primitive = primitive_of(&field_def.ty)?;
sink.local_get(0);
lower_expr(value, sink, ctx)?;
store_primitive(primitive, *field_layout, sink);
Ok(())
}
IrExpr::FieldAccess { object, field, .. } => {
let module = ctx.module()?;
let (layout, fields_meta) = layout_for_aggregate(object.ty(), module)?;
let (field_layout, field_def) =
lookup_field_by_name_with_meta(&fields_meta, &layout.fields, field, "<aggregate>")?;
let primitive = primitive_of(&field_def.ty)?;
lower_expr(object, sink, ctx)?;
lower_expr(value, sink, ctx)?;
store_primitive(primitive, *field_layout, sink);
Ok(())
}
IrExpr::LetRef { binding_id, .. } => {
let local_idx = ctx
.bindings
.get(*binding_id)
.ok_or(LowerError::UnknownBinding(*binding_id))?;
lower_expr(value, sink, ctx)?;
sink.local_set(local_idx);
Ok(())
}
IrExpr::Reference {
target: ref_target, ..
} => {
let binding_id = match ref_target {
formalang::ir::ReferenceTarget::Param(b)
| formalang::ir::ReferenceTarget::Local(b) => Some(*b),
formalang::ir::ReferenceTarget::Function(_)
| formalang::ir::ReferenceTarget::Struct(_)
| formalang::ir::ReferenceTarget::Enum(_)
| formalang::ir::ReferenceTarget::Trait(_)
| formalang::ir::ReferenceTarget::ModuleLet(_)
| formalang::ir::ReferenceTarget::External { .. }
| formalang::ir::ReferenceTarget::Unresolved => None,
};
if let Some(binding_id) = binding_id {
let local_idx = ctx
.bindings
.get(binding_id)
.ok_or(LowerError::UnknownBinding(binding_id))?;
lower_expr(value, sink, ctx)?;
sink.local_set(local_idx);
return Ok(());
}
Err(LowerError::NotYetImplemented {
what: format!("IrBlockStatement::Assign to Reference target {ref_target:?}"),
})
}
IrExpr::Literal { .. }
| IrExpr::StructInst { .. }
| IrExpr::EnumInst { .. }
| IrExpr::Array { .. }
| IrExpr::Tuple { .. }
| IrExpr::BinaryOp { .. }
| IrExpr::UnaryOp { .. }
| IrExpr::If { .. }
| IrExpr::For { .. }
| IrExpr::Match { .. }
| IrExpr::FunctionCall { .. }
| IrExpr::CallClosure { .. }
| IrExpr::MethodCall { .. }
| IrExpr::Closure { .. }
| IrExpr::ClosureRef { .. }
| IrExpr::DictLiteral { .. }
| IrExpr::DictAccess { .. }
| IrExpr::Block { .. } => Err(LowerError::NotYetImplemented {
what: "IrBlockStatement::Assign target shape (only field writes supported in mc9)"
.to_owned(),
}),
}
}
/// Gathers every binding that [`walk_for_locals`] discovers inside `expr`,
/// paired with its wasm value type, in discovery order.
///
/// # Errors
///
/// Forwards any error produced by the walk.
fn collect_local_bindings(expr: &IrExpr) -> Result<Vec<(BindingId, ValType)>, LowerError> {
    let mut discovered = Vec::new();
    walk_for_locals(expr, &mut discovered).map(|()| discovered)
}
/// Records the bindings introduced by a block's statements into `out`.
///
/// A `Let` contributes its own binding first (typed from the declared type
/// when present, otherwise from the initializer's type) and then whatever its
/// initializer introduces. `Assign` and `Expr` statements contribute only what
/// their sub-expressions introduce.
///
/// # Errors
///
/// Returns [`LowerError::ZeroSizedLetBinding`] when `body_value_type` yields
/// no value type for a `Let`, and forwards errors from `body_value_type` and
/// from the nested walks.
fn walk_block_statements(
    statements: &[IrBlockStatement],
    out: &mut Vec<(BindingId, ValType)>,
) -> Result<(), LowerError> {
    for statement in statements {
        match statement {
            IrBlockStatement::Expr(inner) => walk_for_locals(inner, out)?,
            IrBlockStatement::Assign { target, value, .. } => {
                walk_for_locals(target, out)?;
                walk_for_locals(value, out)?;
            }
            IrBlockStatement::Let {
                binding_id,
                name,
                value,
                ty,
                ..
            } => {
                // Declared type wins; fall back to the initializer's type.
                let binding_ty = match ty.as_ref() {
                    Some(declared) => declared,
                    None => value.ty(),
                };
                let Some(val_type) = body_value_type(binding_ty)? else {
                    return Err(LowerError::ZeroSizedLetBinding {
                        name: name.clone(),
                        ty: binding_ty.clone(),
                    });
                };
                // The binding is pushed before descending into its
                // initializer, so it precedes anything the initializer adds.
                out.push((*binding_id, val_type));
                walk_for_locals(value, out)?;
            }
        }
    }
    Ok(())
}
/// Walks `expr` and appends to `out` every binding that needs its own wasm
/// local, together with that local's value type.
///
/// Bindings come from three places: `Let` statements inside blocks (handled by
/// [`walk_block_statements`]), `Match` arm bindings, and the loop variable of a
/// `For`. All other variants merely forward the walk to their sub-expressions.
///
/// Array elements and dictionary-literal keys/values are descended into as
/// well: they are ordinary expressions and may contain blocks, matches or
/// loops that introduce bindings. Skipping them would leave those bindings
/// without a local and make lowering fail with an unknown-binding error.
///
/// `Closure` is deliberately a leaf here (presumably closure bodies are
/// lowered as separate functions — confirm against the closure lowering).
///
/// # Errors
///
/// Returns [`LowerError::ZeroSizedLetBinding`] when `body_value_type` yields
/// no value type for a match-arm binding or a loop variable, and forwards
/// errors from `body_value_type` and nested walks.
#[expect(
    clippy::too_many_lines,
    reason = "exhaustive walk over every IrExpr variant; splitting hides which variants introduce new bindings"
)]
fn walk_for_locals(expr: &IrExpr, out: &mut Vec<(BindingId, ValType)>) -> Result<(), LowerError> {
    match expr {
        IrExpr::Block {
            statements, result, ..
        } => {
            walk_block_statements(statements, out)?;
            walk_for_locals(result, out)
        }
        IrExpr::BinaryOp { left, right, .. } => {
            walk_for_locals(left, out)?;
            walk_for_locals(right, out)
        }
        IrExpr::UnaryOp { operand, .. } => walk_for_locals(operand, out),
        IrExpr::If {
            condition,
            then_branch,
            else_branch,
            ..
        } => {
            walk_for_locals(condition, out)?;
            walk_for_locals(then_branch, out)?;
            if let Some(else_branch) = else_branch {
                walk_for_locals(else_branch, out)?;
            }
            Ok(())
        }
        IrExpr::FunctionCall { args, .. } => {
            for (_, arg) in args {
                walk_for_locals(arg, out)?;
            }
            Ok(())
        }
        IrExpr::CallClosure { closure, args, .. } => {
            walk_for_locals(closure, out)?;
            for (_, arg) in args {
                walk_for_locals(arg, out)?;
            }
            Ok(())
        }
        IrExpr::MethodCall { receiver, args, .. } => {
            walk_for_locals(receiver, out)?;
            for (_, arg) in args {
                walk_for_locals(arg, out)?;
            }
            Ok(())
        }
        IrExpr::FieldAccess { object, .. } => walk_for_locals(object, out),
        IrExpr::DictAccess { dict, key, .. } => {
            walk_for_locals(dict, out)?;
            walk_for_locals(key, out)
        }
        IrExpr::ClosureRef { env_struct, .. } => walk_for_locals(env_struct, out),
        IrExpr::StructInst { fields, .. } | IrExpr::EnumInst { fields, .. } => {
            for (_, _, value) in fields {
                walk_for_locals(value, out)?;
            }
            Ok(())
        }
        IrExpr::Tuple { fields, .. } => {
            for (_, value) in fields {
                walk_for_locals(value, out)?;
            }
            Ok(())
        }
        IrExpr::Array { elements, .. } => {
            // Elements are full expressions and can introduce bindings.
            for element in elements {
                walk_for_locals(element, out)?;
            }
            Ok(())
        }
        IrExpr::DictLiteral { entries, .. } => {
            // Both keys and values are full expressions.
            for (key, value) in entries {
                walk_for_locals(key, out)?;
                walk_for_locals(value, out)?;
            }
            Ok(())
        }
        IrExpr::Match {
            scrutinee, arms, ..
        } => {
            walk_for_locals(scrutinee, out)?;
            for arm in arms {
                // Each pattern binding of the arm gets its own local.
                for (name, binding_id, ty) in &arm.bindings {
                    let vt =
                        body_value_type(ty)?.ok_or_else(|| LowerError::ZeroSizedLetBinding {
                            name: name.clone(),
                            ty: ty.clone(),
                        })?;
                    out.push((*binding_id, vt));
                }
                walk_for_locals(&arm.body, out)?;
            }
            Ok(())
        }
        IrExpr::For {
            var,
            var_ty,
            var_binding_id,
            collection,
            body,
            ..
        } => {
            // The loop variable gets a local ahead of anything introduced by
            // the collection expression or the loop body.
            let vt = body_value_type(var_ty)?.ok_or_else(|| LowerError::ZeroSizedLetBinding {
                name: var.clone(),
                ty: var_ty.clone(),
            })?;
            out.push((*var_binding_id, vt));
            walk_for_locals(collection, out)?;
            walk_for_locals(body, out)
        }
        IrExpr::Literal { .. }
        | IrExpr::Reference { .. }
        | IrExpr::LetRef { .. }
        | IrExpr::SelfFieldRef { .. }
        | IrExpr::Closure { .. } => Ok(()),
    }
}
pub fn lower_function_body(
body: &IrExpr,
param_bindings: &[(BindingId, ValType)],
functions: &FunctionMap,
) -> Result<Function, LowerError> {
let plan = plan_function_locals(body, param_bindings)?;
let ctx = LowerContext::new(&plan.bindings, functions);
finish_function_body(body, None, plan.locals, &ctx)
}
#[expect(
clippy::too_many_arguments,
reason = "module-aware body lowering needs every map and table-context input the called expression lowerings can possibly read; bundling into a struct hides the contract"
)]
#[expect(
clippy::implicit_hasher,
reason = "string_pool comes from the module-lowering pass and always uses the default hasher"
)]
pub fn lower_function_body_in_module(
body: &IrExpr,
return_ty: Option<&ResolvedType>,
param_bindings: &[(BindingId, ValType)],
functions: &FunctionMap,
methods: &MethodMap,
module: &IrModule,
bump_allocator: u32,
self_struct_id: Option<StructId>,
closure_ctx: Option<&ClosureCallContext<'_>>,
vtable_ctx: Option<&super::VTableContext<'_>>,
string_pool: &std::collections::HashMap<String, u32>,
str_eq: u32,
str_concat: u32,
) -> Result<Function, LowerError> {
let plan = plan_function_locals(body, param_bindings)?;
let counts = count_scratch_locals(body, return_ty, Some(module))?;
let scratch_offset = scratch_locals_offset(param_bindings.len(), plan.locals.len())?;
let mut locals = plan.locals;
let mut running = scratch_offset;
let mut next_region = |count: u32, ty: ValType| -> Result<u32, LowerError> {
let base = running;
if count > 0 {
locals.push((count, ty));
running = running
.checked_add(count)
.ok_or_else(|| LowerError::NotYetImplemented {
what: "scratch-local layout overflows u32".to_owned(),
})?;
}
Ok(base)
};
let i32_base = next_region(counts.i32, ValType::I32)?;
let i64_base = next_region(counts.i64, ValType::I64)?;
let f32_base = next_region(counts.f32, ValType::F32)?;
let f64_base = next_region(counts.f64, ValType::F64)?;
let allocator = ScratchAllocator::new(super::ScratchRegions {
i32: (i32_base, counts.i32),
i64: (i64_base, counts.i64),
f32: (f32_base, counts.f32),
f64: (f64_base, counts.f64),
});
let mut ctx = LowerContext::new(&plan.bindings, functions)
.with_methods(methods)
.with_module(module)
.with_bump_allocator(bump_allocator)
.with_scratch_locals(&allocator)
.with_string_pool(string_pool)
.with_str_eq(str_eq)
.with_str_concat(str_concat);
if let Some(id) = self_struct_id {
ctx = ctx.with_self_struct_id(id);
}
if let Some(closure) = closure_ctx {
ctx = ctx
.with_closure_table(closure.table_idx)
.with_closure_funcref_indices(closure.funcref_indices)
.with_closure_type_indices(closure.type_indices);
}
if let Some(vt) = vtable_ctx {
ctx = ctx
.with_method_table(vt.table_idx)
.with_vtable_offsets(vt.vtable_offsets)
.with_virtual_call_type_indices(vt.call_type_indices);
}
finish_function_body(body, return_ty, locals, &ctx)
}
/// Computes the index of the first scratch local: the number of parameters
/// plus the number of `let`-style binding locals.
///
/// # Errors
///
/// Returns [`LowerError::NotYetImplemented`] when either count does not fit in
/// a `u32` or when their sum overflows `u32`.
fn scratch_locals_offset(params: usize, lets: usize) -> Result<u32, LowerError> {
    let unsupported = |what: &str| LowerError::NotYetImplemented {
        what: what.to_owned(),
    };

    let param_count = u32::try_from(params)
        .map_err(|_| unsupported("more than u32::MAX parameters in a single function"))?;
    let let_count = u32::try_from(lets)
        .map_err(|_| unsupported("more than u32::MAX `let` bindings in a single function"))?;

    match param_count.checked_add(let_count) {
        Some(offset) => Ok(offset),
        None => Err(unsupported(
            "params + lets overflow u32 in a single function",
        )),
    }
}
/// Per-value-type tally of the scratch locals a function body needs.
///
/// The counts are accumulated by `walk_count` (via `bump_count`) and then
/// consumed by `lower_function_body_in_module`, which appends one run of
/// locals per non-zero field and records the `(base, count)` pair of each
/// region in the scratch allocator.
#[derive(Debug, Default, Clone, Copy)]
pub(super) struct ScratchCounts {
/// Number of `i32` scratch locals to reserve.
pub i32: u32,
/// Number of `i64` scratch locals to reserve.
pub i64: u32,
/// Number of `f32` scratch locals to reserve.
pub f32: u32,
/// Number of `f64` scratch locals to reserve.
pub f64: u32,
}
/// Totals the scratch locals required to lower `expr`.
///
/// When a return type is supplied, the slots needed to coerce `expr` to that
/// type are counted first (via `optional::coercion_scratch_counts`); the
/// expression tree itself is then counted by [`walk_count`].
///
/// # Errors
///
/// Forwards errors from the coercion count and from the walk.
fn count_scratch_locals(
    expr: &IrExpr,
    return_ty: Option<&ResolvedType>,
    module: Option<&IrModule>,
) -> Result<ScratchCounts, LowerError> {
    let mut totals = ScratchCounts::default();
    if let Some(declared_return) = return_ty {
        super::optional::coercion_scratch_counts(declared_return, expr, &mut totals, module)?;
    }
    walk_count(expr, module, &mut totals).map(|()| totals)
}
/// Increments a scratch-slot counter by one.
///
/// # Errors
///
/// Returns [`LowerError::NotYetImplemented`] when the counter is already at
/// `u32::MAX`; the counter is left unchanged in that case.
pub(super) fn bump_count(field: &mut u32) -> Result<(), LowerError> {
    match field.checked_add(1) {
        Some(incremented) => {
            *field = incremented;
            Ok(())
        }
        None => Err(LowerError::NotYetImplemented {
            what: "more than u32::MAX scratch slots of one type in a single function".to_owned(),
        }),
    }
}
/// Adds to `out` the scratch locals needed by one block statement.
///
/// A `Let` with a declared type first counts the slots needed to coerce its
/// initializer to that type, then counts the initializer itself. `Assign`
/// counts its target and then its value; `Expr` counts the wrapped expression.
///
/// # Errors
///
/// Forwards errors from the coercion count and from [`walk_count`].
fn walk_count_block_statement(
    stmt: &IrBlockStatement,
    module: Option<&IrModule>,
    out: &mut ScratchCounts,
) -> Result<(), LowerError> {
    match stmt {
        IrBlockStatement::Expr(inner) => walk_count(inner, module, out),
        IrBlockStatement::Assign { target, value, .. } => {
            walk_count(target, module, out)?;
            walk_count(value, module, out)
        }
        IrBlockStatement::Let { ty, value, .. } => {
            if let Some(declared) = ty {
                super::optional::coercion_scratch_counts(declared, value, out, module)?;
            }
            walk_count(value, module, out)
        }
    }
}
#[expect(
clippy::too_many_lines,
reason = "exhaustive walk over every IrExpr variant; splitting hides which variants reserve which scratch slots"
)]
fn walk_count(
expr: &IrExpr,
module: Option<&IrModule>,
out: &mut ScratchCounts,
) -> Result<(), LowerError> {
match expr {
IrExpr::StructInst {
struct_id, fields, ..
} => {
bump_count(&mut out.i32)?;
if let Some(id) = struct_id
&& let Some(m) = module
&& let Some(s) = m.structs.get(id.0 as usize)
{
for (name, _idx, e) in fields {
if let Some(decl) = s.fields.iter().find(|f| f.name == *name) {
super::optional::coercion_scratch_counts(&decl.ty, e, out, module)?;
}
walk_count(e, module, out)?;
}
} else {
for (_, _, e) in fields {
walk_count(e, module, out)?;
}
}
}
IrExpr::EnumInst {
enum_id,
variant_idx,
fields,
..
} => {
bump_count(&mut out.i32)?;
if let Some(id) = enum_id
&& let Some(m) = module
&& let Some(e) = m.enums.get(id.0 as usize)
&& let Some(v) = e.variants.get(variant_idx.0 as usize)
{
for (name, _idx, value) in fields {
if let Some(decl) = v.fields.iter().find(|f| f.name == *name) {
super::optional::coercion_scratch_counts(&decl.ty, value, out, module)?;
}
walk_count(value, module, out)?;
}
} else {
for (_, _, value) in fields {
walk_count(value, module, out)?;
}
}
}
IrExpr::Tuple { fields, ty, .. } => {
bump_count(&mut out.i32)?;
let target_fields: Option<&Vec<(String, ResolvedType)>> =
if let ResolvedType::Tuple(ts) = ty {
Some(ts)
} else {
None
};
for (name, e) in fields {
if let Some(targets) = target_fields
&& let Some((_, t)) = targets.iter().find(|(n, _)| n == name)
{
super::optional::coercion_scratch_counts(t, e, out, module)?;
}
walk_count(e, module, out)?;
}
}
IrExpr::Block {
statements, result, ..
} => {
for stmt in statements {
walk_count_block_statement(stmt, module, out)?;
}
walk_count(result, module, out)?;
}
IrExpr::BinaryOp {
left, right, op, ..
} => {
if matches!(op, formalang::ast::BinaryOperator::Range) {
bump_count(&mut out.i32)?;
}
walk_count(left, module, out)?;
walk_count(right, module, out)?;
}
IrExpr::UnaryOp { operand, .. } => walk_count(operand, module, out)?,
IrExpr::If {
condition,
then_branch,
else_branch,
ty,
..
} => {
walk_count(condition, module, out)?;
super::optional::coercion_scratch_counts(ty, then_branch, out, module)?;
walk_count(then_branch, module, out)?;
if let Some(else_branch) = else_branch {
super::optional::coercion_scratch_counts(ty, else_branch, out, module)?;
walk_count(else_branch, module, out)?;
}
}
IrExpr::FunctionCall {
function_id, args, ..
} => {
if let Some(id) = function_id
&& let Some(m) = module
&& let Some(f) = m.functions.get(id.0 as usize)
{
for (param_name, arg) in args {
let target = param_name.as_ref().and_then(|n| {
f.params
.iter()
.find(|p| p.name == *n)
.and_then(|p| p.ty.as_ref())
});
if let Some(t) = target {
super::optional::coercion_scratch_counts(t, arg, out, module)?;
}
walk_count(arg, module, out)?;
}
} else {
for (_, arg) in args {
walk_count(arg, module, out)?;
}
}
}
IrExpr::CallClosure { closure, args, .. } => {
bump_count(&mut out.i32)?;
walk_count(closure, module, out)?;
for (_, arg) in args {
walk_count(arg, module, out)?;
}
}
IrExpr::MethodCall {
receiver,
method_idx,
args,
dispatch,
..
} => {
if matches!(dispatch, formalang::ir::DispatchKind::Virtual { .. })
&& matches!(receiver.ty(), formalang::ir::ResolvedType::Trait(_))
{
bump_count(&mut out.i32)?;
}
walk_count(receiver, module, out)?;
let method_params: Option<&[formalang::ir::IrFunctionParam]> = match dispatch {
formalang::ir::DispatchKind::Static { impl_id } => module
.and_then(|m| m.impls.get(impl_id.0 as usize))
.and_then(|i| i.functions.get(method_idx.0 as usize))
.map(|f| f.params.as_slice()),
formalang::ir::DispatchKind::Virtual { trait_id, .. } => module
.and_then(|m| m.traits.get(trait_id.0 as usize))
.and_then(|t| t.methods.get(method_idx.0 as usize))
.map(|sig| sig.params.as_slice()),
};
for (param_name, arg) in args {
let target = method_params.and_then(|params| {
param_name.as_ref().and_then(|n| {
params
.iter()
.find(|p| p.name == *n)
.and_then(|p| p.ty.as_ref())
})
});
if let Some(t) = target {
super::optional::coercion_scratch_counts(t, arg, out, module)?;
}
walk_count(arg, module, out)?;
}
}
IrExpr::FieldAccess { object, .. } => walk_count(object, module, out)?,
IrExpr::DictAccess { dict, key, .. } => {
if module.is_some_and(|m| {
matches!(
crate::compound::Compound::of(dict.ty(), m),
crate::compound::Compound::Dictionary { .. }
)
}) {
bump_count(&mut out.i32)?;
bump_count(&mut out.i32)?;
bump_count(&mut out.i32)?;
bump_count(&mut out.i32)?;
bump_count(&mut out.i32)?;
bump_count(&mut out.i32)?;
}
walk_count(dict, module, out)?;
walk_count(key, module, out)?;
}
IrExpr::Match {
scrutinee,
arms,
ty,
..
} => {
bump_count(&mut out.i32)?;
walk_count(scrutinee, module, out)?;
for arm in arms {
super::optional::coercion_scratch_counts(ty, &arm.body, out, module)?;
walk_count(&arm.body, module, out)?;
}
}
IrExpr::ClosureRef { env_struct, .. } => {
bump_count(&mut out.i32)?;
walk_count(env_struct, module, out)?;
}
IrExpr::Array { elements, ty, .. } => {
bump_count(&mut out.i32)?;
bump_count(&mut out.i32)?;
let elem_ty: Option<&ResolvedType> =
module.and_then(|m| crate::compound::array_elem(ty, m));
for e in elements {
if let Some(t) = elem_ty {
super::optional::coercion_scratch_counts(t, e, out, module)?;
}
walk_count(e, module, out)?;
}
}
IrExpr::For {
collection, body, ..
} => {
super::control::for_scratch_counts(collection.ty(), out, module)?;
walk_count(collection, module, out)?;
walk_count(body, module, out)?;
}
IrExpr::Literal { value, .. } => {
if matches!(value, formalang::ast::Literal::Nil) {
bump_count(&mut out.i32)?;
}
}
IrExpr::DictLiteral { entries, .. } => {
bump_count(&mut out.i32)?;
bump_count(&mut out.i32)?;
for (k, v) in entries {
bump_count(&mut out.i32)?;
walk_count(k, module, out)?;
walk_count(v, module, out)?;
}
}
IrExpr::Reference { .. }
| IrExpr::LetRef { .. }
| IrExpr::SelfFieldRef { .. }
| IrExpr::Closure { .. } => {}
}
Ok(())
}
/// Result of planning a function's locals, produced by `plan_function_locals`.
struct FunctionPlan {
/// Maps every parameter and body-introduced binding to its wasm local index.
bindings: BindingMap,
/// Local declarations for the body-introduced bindings, one `(1, type)` entry
/// per binding; handed to `Function::new` (parameters are not included).
locals: Vec<(u32, ValType)>,
}
fn plan_function_locals(
body: &IrExpr,
param_bindings: &[(BindingId, ValType)],
) -> Result<FunctionPlan, LowerError> {
let local_bindings = collect_local_bindings(body)?;
let mut binding_map = BindingMap::new();
for (i, (id, _)) in param_bindings.iter().enumerate() {
binding_map.insert(*id, index_of(i)?);
}
let local_offset = u32::try_from(param_bindings.len()).unwrap_or(u32::MAX);
for (i, (id, _)) in local_bindings.iter().enumerate() {
let idx = local_offset.saturating_add(index_of(i)?);
binding_map.insert(*id, idx);
}
let locals: Vec<(u32, ValType)> = local_bindings.iter().map(|(_, vt)| (1, *vt)).collect();
Ok(FunctionPlan {
bindings: binding_map,
locals,
})
}
/// Builds the wasm [`Function`] for `body` with the given local declarations.
///
/// The body is lowered through `optional::lower_coerced` when a return type is
/// supplied and through [`lower_expr`] otherwise; a closing `end` instruction
/// is emitted afterwards.
///
/// # Errors
///
/// Forwards any error raised while lowering the body.
fn finish_function_body(
    body: &IrExpr,
    return_ty: Option<&ResolvedType>,
    locals: Vec<(u32, ValType)>,
    ctx: &LowerContext<'_>,
) -> Result<Function, LowerError> {
    let mut function = Function::new(locals);
    {
        // Scope the instruction sink so its borrow of `function` ends before
        // the function is returned.
        let mut instructions = function.instructions();
        match return_ty {
            Some(declared_return) => {
                super::optional::lower_coerced(body, declared_return, &mut instructions, ctx)?;
            }
            None => {
                lower_expr(body, &mut instructions, ctx)?;
            }
        }
        instructions.end();
    }
    Ok(function)
}
/// Converts a `usize` position into a `u32` wasm local index.
///
/// # Errors
///
/// Returns [`LowerError::NotYetImplemented`] when `i` does not fit in `u32`.
fn index_of(i: usize) -> Result<u32, LowerError> {
    match u32::try_from(i) {
        Ok(index) => Ok(index),
        Err(_) => Err(LowerError::NotYetImplemented {
            what: "more than u32::MAX locals in a single function".to_owned(),
        }),
    }
}