use formalang::ir::{IrExpr, ResolvedType};
use wasm_encoder::{InstructionSink, MemArg, ValType};
use super::aggregate::{allocate_aggregate, primitive_of, store_primitive};
use super::block::{ScratchCounts, bump_count};
use super::{LowerContext, LowerError, lower_expr};
use crate::layout::{FieldLayout, OPTIONAL_TAG_ALIGN, OPTIONAL_TAG_SOME, plan_optional};
use crate::module::MEMORY_INDEX;
/// Decides whether a value of type `value_ty` flowing into a slot of type
/// `target_ty` requires an implicit `Some` wrap.
///
/// Returns the payload type reported by `crate::compound::optional_inner`
/// for `target_ty` when that payload type is equal to `value_ty`. Returns
/// `None` when `optional_inner` yields nothing for `target_ty`, or when the
/// payload type and `value_ty` differ.
#[must_use]
pub(super) fn some_wrap_payload<'a>(
    target_ty: &'a ResolvedType,
    value_ty: &ResolvedType,
    module: &formalang::ir::IrModule,
) -> Option<&'a ResolvedType> {
    crate::compound::optional_inner(target_ty, module).filter(|payload| value_ty == *payload)
}
/// Emits instructions that evaluate `value_expr` and wrap the result in a
/// newly allocated optional cell tagged as `Some`.
///
/// Emission order:
/// 1. evaluate `value_expr` and park it in a scratch local,
/// 2. allocate `layout.size` bytes through `allocate_aggregate`,
/// 3. store `OPTIONAL_TAG_SOME` at the layout's tag offset,
/// 4. store the parked value at the layout's payload offset,
/// 5. push the local returned by `allocate_aggregate` as the expression result.
///
/// # Errors
///
/// Returns `LowerError::NotYetImplemented` when `primitive_of` rejects
/// `payload_ty`, and propagates any error from `ctx.module()`,
/// `plan_optional`, scratch-local allocation, `lower_expr`, or
/// `allocate_aggregate`.
pub(super) fn lower_some_wrap(
    value_expr: &IrExpr,
    payload_ty: &ResolvedType,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    let Ok(payload_prim) = primitive_of(payload_ty) else {
        return Err(LowerError::NotYetImplemented {
            what: format!("Some-wrap of non-primitive payload type {payload_ty:?}"),
        });
    };
    let optional_layout = plan_optional(payload_ty, ctx.module()?)?;

    // The payload is evaluated *before* the cell is allocated, so it has to
    // wait in a scratch local until the destination address is known.
    let scratch_ty = primitive_to_valtype(payload_ty)?;
    let value_local = ctx.next_scratch_local(scratch_ty)?;
    lower_expr(value_expr, sink, ctx)?;
    sink.local_set(value_local);

    let cell_local = allocate_aggregate(optional_layout.size, sink, ctx)?;

    // cell.tag = SOME
    let tag_slot = MemArg {
        offset: u64::from(optional_layout.tag_offset),
        // `MemArg::align` is the log2 of the byte alignment.
        align: OPTIONAL_TAG_ALIGN.trailing_zeros(),
        memory_index: MEMORY_INDEX,
    };
    let some_tag = i32::try_from(OPTIONAL_TAG_SOME).unwrap_or(i32::MAX);
    sink.local_get(cell_local)
        .i32_const(some_tag)
        .i32_store(tag_slot);

    // cell.payload = value
    sink.local_get(cell_local).local_get(value_local);
    store_primitive(
        payload_prim,
        FieldLayout {
            offset: optional_layout.payload_offset,
            size: optional_layout.payload_size,
            align: optional_layout.payload_align,
        },
        sink,
    );

    // The cell's local is the value of the whole expression.
    sink.local_get(cell_local);
    Ok(())
}
/// Maps a primitive payload type to the WebAssembly value type of the scratch
/// local that temporarily holds it during a `Some`-wrap.
///
/// `I32`, `Boolean`, `String`, `Path` and `Regex` all map to `ValType::I32`
/// (NOTE(review): the last three are presumably carried as i32 handles or
/// pointers — confirm against the layout module). `I64`, `F32` and `F64` map
/// to the value type of the same name.
///
/// # Errors
///
/// Returns `LowerError::NotYetImplemented` when `primitive_of` rejects `ty`,
/// or when the primitive is `Never` or any variant not listed above.
fn primitive_to_valtype(ty: &ResolvedType) -> Result<ValType, LowerError> {
    use formalang::ast::PrimitiveType;

    let Ok(prim) = primitive_of(ty) else {
        return Err(LowerError::NotYetImplemented {
            what: format!("Some-wrap scratch slot for non-primitive payload {ty:?}"),
        });
    };
    let scratch_ty = match prim {
        PrimitiveType::I32
        | PrimitiveType::Boolean
        | PrimitiveType::String
        | PrimitiveType::Path
        | PrimitiveType::Regex => ValType::I32,
        PrimitiveType::I64 => ValType::I64,
        PrimitiveType::F32 => ValType::F32,
        PrimitiveType::F64 => ValType::F64,
        PrimitiveType::Never | _ => {
            return Err(LowerError::NotYetImplemented {
                what: format!("Some-wrap scratch slot for {prim:?} payload"),
            });
        }
    };
    Ok(scratch_ty)
}
/// Returns the WebAssembly value type of the scratch local that a `Some`-wrap
/// of a `payload_ty` value uses.
///
/// This is the `pub(super)` entry point for the private
/// `primitive_to_valtype`; it delegates to it unchanged.
///
/// # Errors
///
/// Propagates the `LowerError` produced by `primitive_to_valtype`.
pub(super) fn some_wrap_scratch_valtype(payload_ty: &ResolvedType) -> Result<ValType, LowerError> {
primitive_to_valtype(payload_ty)
}
/// Lowers `value_expr` for use in a slot of type `target_ty`, inserting an
/// implicit coercion when one applies.
///
/// Coercions are tried in this order:
/// 1. `Some`-wrap — when a module is available and `some_wrap_payload`
///    reports that `target_ty` is an optional of exactly the value's type;
/// 2. trait coercion — when `target_ty` is `ResolvedType::Trait`;
/// 3. otherwise the expression is lowered as-is.
///
/// # Errors
///
/// Propagates whatever the selected lowering routine returns.
pub(super) fn lower_coerced(
    value_expr: &IrExpr,
    target_ty: &ResolvedType,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    let wrap_payload = ctx
        .module_opt()
        .and_then(|module| some_wrap_payload(target_ty, value_expr.ty(), module));

    match (wrap_payload, target_ty) {
        (Some(payload_ty), _) => lower_some_wrap(value_expr, payload_ty, sink, ctx),
        (None, ResolvedType::Trait(trait_id)) => {
            lower_trait_coercion(*trait_id, value_expr, sink, ctx)
        }
        (None, _) => lower_expr(value_expr, sink, ctx),
    }
}
/// Resolves the `ImplTarget` that a value of type `ty` dispatches through.
///
/// A struct id (from `crate::compound::struct_id_of`) takes precedence over
/// an enum id (from `crate::compound::enum_id_of`). Returns `None` when `ty`
/// yields neither.
const fn trait_dispatch_target(ty: &ResolvedType) -> Option<formalang::ir::ImplTarget> {
    use formalang::ir::ImplTarget;
    // Written with `match` rather than combinators so the body stays usable
    // in a `const fn`.
    match crate::compound::struct_id_of(ty) {
        Some(sid) => Some(ImplTarget::Struct(sid)),
        None => match crate::compound::enum_id_of(ty) {
            Some(eid) => Some(ImplTarget::Enum(eid)),
            None => None,
        },
    }
}
/// Lowers `value_expr` so that it leaves a trait-object value for `trait_id`
/// on the stack.
///
/// * A value whose type is already `ResolvedType::Trait` is lowered unchanged.
/// * `if`/`else` and block expressions are traversed so the coercion is
///   applied to each branch / to the block's result expression.
/// * Every other expression kind is treated as a leaf: its concrete type is
///   resolved to an `ImplTarget` and a fat-pointer cell is built by
///   `materialize_trait_fat_pointer`.
///
/// # Errors
///
/// Returns `LowerError::NotYetImplemented` for:
/// * an `if` without an `else` branch (see the comment in that arm),
/// * `match` expressions,
/// * leaf values whose type is neither a struct nor an enum.
///
/// Any error from nested lowering is propagated.
fn lower_trait_coercion(
    trait_id: formalang::ir::TraitId,
    value_expr: &IrExpr,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    // Already a trait object: nothing to wrap.
    if matches!(value_expr.ty(), ResolvedType::Trait(_)) {
        return lower_expr(value_expr, sink, ctx);
    }
    match value_expr {
        IrExpr::If {
            condition,
            then_branch,
            else_branch,
            ..
        } => {
            // The emitted `if` is typed `(result i32)`. WebAssembly only
            // validates a result-typed `if` when an `else` arm produces the
            // same result, so an else-less `if` can never yield a valid
            // module here. Reject it up front, before emitting anything,
            // instead of silently producing invalid output.
            let Some(else_b) = else_branch else {
                return Err(LowerError::NotYetImplemented {
                    what: "trait coercion through `if` without an `else` branch".to_owned(),
                });
            };
            lower_expr(condition, sink, ctx)?;
            sink.if_(wasm_encoder::BlockType::Result(ValType::I32));
            lower_trait_coercion(trait_id, then_branch, sink, ctx)?;
            sink.else_();
            lower_trait_coercion(trait_id, else_b, sink, ctx)?;
            sink.end();
            Ok(())
        }
        IrExpr::Match { .. } => Err(LowerError::NotYetImplemented {
            what: "trait coercion through `match` arms (push lower_coerced into each arm)"
                .to_owned(),
        }),
        IrExpr::Block {
            statements, result, ..
        } => {
            // Statements run for their effects; only the trailing result
            // expression carries the value that needs coercing.
            for stmt in statements {
                super::block::lower_block_statement(stmt, sink, ctx)?;
            }
            lower_trait_coercion(trait_id, result, sink, ctx)
        }
        // Leaf expressions. The variants are listed explicitly (no `_`) so
        // that adding a new `IrExpr` variant forces a decision here.
        IrExpr::Literal { .. }
        | IrExpr::StructInst { .. }
        | IrExpr::EnumInst { .. }
        | IrExpr::Array { .. }
        | IrExpr::Tuple { .. }
        | IrExpr::Reference { .. }
        | IrExpr::SelfFieldRef { .. }
        | IrExpr::FieldAccess { .. }
        | IrExpr::LetRef { .. }
        | IrExpr::BinaryOp { .. }
        | IrExpr::UnaryOp { .. }
        | IrExpr::For { .. }
        | IrExpr::FunctionCall { .. }
        | IrExpr::CallClosure { .. }
        | IrExpr::MethodCall { .. }
        | IrExpr::Closure { .. }
        | IrExpr::ClosureRef { .. }
        | IrExpr::DictLiteral { .. }
        | IrExpr::DictAccess { .. } => {
            let target = trait_dispatch_target(value_expr.ty()).ok_or_else(|| {
                LowerError::NotYetImplemented {
                    what: format!(
                        "trait coercion of value with non-aggregate type {:?}",
                        value_expr.ty()
                    ),
                }
            })?;
            materialize_trait_fat_pointer(trait_id, target, value_expr, sink, ctx)
        }
    }
}
/// Emits instructions that build a two-word trait-object cell for
/// `value_expr` and leave the cell's address on the stack.
///
/// Cell contents, as written below:
/// * offset `0`: the vtable offset that `ctx.vtable_offset` returns for
///   (`trait_id`, `target`);
/// * offset `POINTER_SIZE`: the i32 produced by lowering `value_expr`.
///
/// The cell is obtained by calling the function index returned from
/// `ctx.bump_allocator()` with a size of `2 * POINTER_SIZE`.
///
/// # Errors
///
/// Propagates errors from the vtable lookup, scratch-local allocation,
/// `lower_expr`, and the bump-allocator lookup.
fn materialize_trait_fat_pointer(
    trait_id: formalang::ir::TraitId,
    target: formalang::ir::ImplTarget,
    value_expr: &IrExpr,
    sink: &mut InstructionSink<'_>,
    ctx: &LowerContext<'_>,
) -> Result<(), LowerError> {
    use crate::layout::POINTER_SIZE;
    use crate::module::MEMORY_INDEX;
    use wasm_encoder::MemArg;

    let vtable_addr =
        ctx.vtable_offset(trait_id, crate::module_lowering::impl_target_key(target))?;

    // Evaluate the concrete value first and park it in a scratch local.
    let data_local = ctx.next_scratch_local(ValType::I32)?;
    super::lower_expr(value_expr, sink, ctx)?;
    sink.local_set(data_local);

    // Allocate the two-word cell.
    let cell_bytes = i32::try_from(POINTER_SIZE.saturating_mul(2)).unwrap_or(8);
    let alloc_fn = ctx.bump_allocator()?;
    let cell_local = ctx.next_scratch_local(ValType::I32)?;
    sink.i32_const(cell_bytes).call(alloc_fn).local_set(cell_local);

    // `MemArg::align` is the log2 of the byte alignment, so 2 means 4 bytes.
    let vtable_slot = MemArg {
        offset: 0,
        align: 2,
        memory_index: MEMORY_INDEX,
    };
    let data_slot = MemArg {
        offset: u64::from(POINTER_SIZE),
        align: 2,
        memory_index: MEMORY_INDEX,
    };

    // cell[0] = vtable
    let vtable_const = i32::try_from(vtable_addr).unwrap_or(i32::MAX);
    sink.local_get(cell_local)
        .i32_const(vtable_const)
        .i32_store(vtable_slot);

    // cell[1] = data
    sink.local_get(cell_local)
        .local_get(data_local)
        .i32_store(data_slot);

    // Result of the expression: the cell address.
    sink.local_get(cell_local);
    Ok(())
}
/// Records in `out` the scratch locals that coercing `value_expr` to
/// `target_ty` will need.
///
/// * With no module, nothing is recorded.
/// * For a trait target, counting is delegated to
///   `count_trait_coercion_leaves`.
/// * For a `Some`-wrap (as detected by `some_wrap_payload`), one `i32` slot
///   is recorded plus one slot of the payload's scratch value type.
///   NOTE(review): the extra `i32` presumably covers the local returned by
///   `allocate_aggregate` in `lower_some_wrap` — confirm.
/// * Otherwise nothing is recorded.
///
/// # Errors
///
/// Propagates errors from `bump_count`, `some_wrap_scratch_valtype` and
/// `count_trait_coercion_leaves`; returns `LowerError::NotYetImplemented`
/// when the payload's scratch type is `V128` or a reference type.
pub(super) fn coercion_scratch_counts(
    target_ty: &ResolvedType,
    value_expr: &IrExpr,
    out: &mut ScratchCounts,
    module: Option<&formalang::ir::IrModule>,
) -> Result<(), LowerError> {
    let Some(ir_module) = module else {
        return Ok(());
    };
    if let ResolvedType::Trait(_) = target_ty {
        return count_trait_coercion_leaves(value_expr, out);
    }
    let Some(payload_ty) = some_wrap_payload(target_ty, value_expr.ty(), ir_module) else {
        return Ok(());
    };

    bump_count(&mut out.i32)?;

    // Pick the counter matching the payload's scratch type, then bump it once.
    let scratch_vt = some_wrap_scratch_valtype(payload_ty)?;
    let counter = match scratch_vt {
        ValType::I32 => &mut out.i32,
        ValType::I64 => &mut out.i64,
        ValType::F32 => &mut out.f32,
        ValType::F64 => &mut out.f64,
        ValType::V128 | ValType::Ref(_) => {
            return Err(LowerError::NotYetImplemented {
                what: format!("Some-wrap scratch slot of value type {scratch_vt:?}"),
            });
        }
    };
    bump_count(counter)?;
    Ok(())
}
/// Counts the scratch locals needed to trait-coerce `value_expr`, walking the
/// same shape that `lower_trait_coercion` walks.
///
/// * A value that is already of trait type needs nothing.
/// * `if` recurses into its branches; a block recurses into its result.
/// * Every other expression is a leaf and records two `i32` slots, matching
///   the two `i32` scratch locals `materialize_trait_fat_pointer` requests.
///   `match` is counted as a leaf here.
///
/// # Errors
///
/// Propagates errors from `bump_count`.
fn count_trait_coercion_leaves(
    value_expr: &IrExpr,
    out: &mut ScratchCounts,
) -> Result<(), LowerError> {
    if let ResolvedType::Trait(_) = value_expr.ty() {
        return Ok(());
    }
    match value_expr {
        IrExpr::Block { result, .. } => count_trait_coercion_leaves(result, out),
        IrExpr::If {
            then_branch,
            else_branch,
            ..
        } => {
            count_trait_coercion_leaves(then_branch, out)?;
            match else_branch {
                Some(else_b) => count_trait_coercion_leaves(else_b, out),
                None => Ok(()),
            }
        }
        // Leaves are listed explicitly (no `_`) so a new `IrExpr` variant
        // forces a decision here.
        IrExpr::Literal { .. }
        | IrExpr::StructInst { .. }
        | IrExpr::EnumInst { .. }
        | IrExpr::Array { .. }
        | IrExpr::Tuple { .. }
        | IrExpr::Reference { .. }
        | IrExpr::SelfFieldRef { .. }
        | IrExpr::FieldAccess { .. }
        | IrExpr::LetRef { .. }
        | IrExpr::BinaryOp { .. }
        | IrExpr::UnaryOp { .. }
        | IrExpr::For { .. }
        | IrExpr::Match { .. }
        | IrExpr::FunctionCall { .. }
        | IrExpr::CallClosure { .. }
        | IrExpr::MethodCall { .. }
        | IrExpr::Closure { .. }
        | IrExpr::ClosureRef { .. }
        | IrExpr::DictLiteral { .. }
        | IrExpr::DictAccess { .. } => {
            // Two i32 slots per leaf.
            for _ in 0..2 {
                bump_count(&mut out.i32)?;
            }
            Ok(())
        }
    }
}