use super::{CompileError, FnCompiler};
use crate::ast::{Expr, Literal, MatchArm, Pattern, Spanned};
use crate::ir::{
BoolCompareOp, BoolMatchShape, BoolSubjectPlan, CallLowerCtx, DispatchArmPlan,
DispatchBindingPlan, DispatchLiteral, DispatchTableShape, MatchDispatchPlan,
SemanticConstructor, SemanticDispatchPattern, WrapperKind, classify_bool_subject_plan,
classify_dispatch_pattern, classify_match_dispatch_plan,
};
use crate::nan_value::NanValue;
use crate::vm::opcode::*;
// ---- Tagged-value bit layout -------------------------------------------
// NOTE(review): these presumably mirror the NaN-boxing layout implemented by
// `NanValue`; they must be kept in sync with that type.

/// Base bit pattern OR-ed into every wrapper tag produced by `wrapper_tag_bits`.
const QNAN: u64 = 0x7FFC_u64 << 48;
/// Bit position at which a wrapper tag number is placed.
const TAG_SHIFT: u32 = 46;
/// Tag number used for `Some(..)` wrappers.
const TAG_SOME: u64 = 4;
/// Tag number used for `Ok(..)` wrappers.
const TAG_OK: u64 = 6;
/// Tag number used for `Err(..)` wrappers.
const TAG_ERR: u64 = 7;

// ---- Dispatch-table entry kinds ------------------------------------------
// Written as the per-entry kind byte of `MATCH_DISPATCH` and
// `MATCH_DISPATCH_CONST` tables.

/// Entry for ints, floats, bools, unit, the empty list and `None`.
const DISPATCH_KIND_EXACT: u8 = 0;
/// Entry for a wrapper tag (`Ok` / `Err` / `Some`).
const DISPATCH_KIND_TAG: u8 = 1;
/// Entry for a string literal.
const DISPATCH_KIND_STRING: u8 = 2;
/// Returns the one-byte operand emitted after `MATCH_UNWRAP` to select the
/// wrapper being unwrapped: `Ok` is 0, `Err` is 1, `Some` is 2.
fn wrapper_tag_kind(kind: WrapperKind) -> u8 {
    match kind {
        WrapperKind::OptionSome => 2,
        WrapperKind::ResultErr => 1,
        WrapperKind::ResultOk => 0,
    }
}
/// Builds the 64-bit value stored in a `DISPATCH_KIND_TAG` dispatch entry for
/// the given wrapper: the wrapper's tag number shifted to `TAG_SHIFT`,
/// combined with the `QNAN` base pattern.
fn wrapper_tag_bits(kind: WrapperKind) -> u64 {
    let tag_number = match kind {
        WrapperKind::OptionSome => TAG_SOME,
        WrapperKind::ResultOk => TAG_OK,
        WrapperKind::ResultErr => TAG_ERR,
    };
    let shifted = tag_number << TAG_SHIFT;
    shifted | QNAN
}
/// A match arm lowered to one entry of a VM dispatch table.
struct DispatchableArm {
    /// One of the `DISPATCH_KIND_*` constants; written as the entry's kind byte.
    kind: u8,
    /// Encoded pattern value written into the entry (e.g. a literal's
    /// `NanValue` bits, or a wrapper's tag bits).
    expected: u64,
    /// Index of the originating arm in the `match` expression's arm list.
    arm_index: usize,
}
/// Read-only adapter that lets the IR-level pattern classifiers query the VM
/// compiler's scope (locals, user types, namespaces) through `CallLowerCtx`.
struct VmPatternCtx<'c, 'a> {
    /// Compiler whose local slots, type table and symbols are consulted.
    compiler: &'c FnCompiler<'a>,
}
impl CallLowerCtx for VmPatternCtx<'_, '_> {
    /// A name is a local value when the compiler has a local slot for it.
    fn is_local_value(&self, name: &str) -> bool {
        let locals = &self.compiler.local_slots;
        locals.contains_key(name)
    }

    /// A name is a user type when the compiler can resolve it to a type id.
    fn is_user_type(&self, name: &str) -> bool {
        let type_id = self.compiler.resolve_type_id(name);
        type_id.is_some()
    }

    /// Splits `dotted` into `(namespace_prefix, "Type.Variant")` when it names a
    /// constructor reached through a module path.
    ///
    /// Every `.` in `dotted` is tried as a split point, from left to right. A split
    /// qualifies when:
    /// - the text after the dot is non-empty;
    /// - the text before it resolves as a namespace path;
    /// - the remainder has the form `Type.Variant`, where `Type` resolves either on
    ///   its own or qualified by the prefix.
    ///
    /// When several splits qualify, the last one wins (the longest namespace
    /// prefix). Returns `None` when no split qualifies.
    fn resolve_module_call<'a>(&self, dotted: &'a str) -> Option<(&'a str, &'a str)> {
        dotted
            .match_indices('.')
            .filter_map(|(split_at, _)| {
                let (prefix, dot_and_rest) = dotted.split_at(split_at);
                // Skip the '.' itself (a single byte, so index 1 is a char boundary).
                let suffix = &dot_and_rest[1..];
                if suffix.is_empty() {
                    return None;
                }
                if self.compiler.symbols.resolve_namespace_path(prefix).is_none() {
                    return None;
                }
                let (type_name, _) = suffix.rsplit_once('.')?;
                let names_known_type = self.compiler.resolve_type_id(type_name).is_some()
                    || self
                        .compiler
                        .resolve_type_id(&format!("{prefix}.{type_name}"))
                        .is_some();
                names_known_type.then_some((prefix, suffix))
            })
            .last()
    }
}
impl<'a> FnCompiler<'a> {
/// Emits a `MATCH_UNWRAP` test for the wrapper selected by `kind`, followed by
/// a two-byte failure-offset placeholder.
///
/// When `binding` is given, the top of the stack is duplicated and bound to
/// that local after the test. Returns the code position of the placeholder so
/// the caller can patch it to its failure target.
fn compile_unwrap_pattern(&mut self, kind: u8, binding: Option<&String>) -> Vec<usize> {
    self.emit_op(MATCH_UNWRAP);
    self.emit_u8(kind);

    let fail_offset_at = self.code.len();
    self.emit_i16(0);

    if let Some(payload_name) = binding {
        self.dup_and_bind_top_to_local(payload_name);
    }

    vec![fail_offset_at]
}
/// Tests one component of the value on top of the stack against `pattern`.
///
/// `emit_subject` emits the code that pushes the component (for example an
/// `EXTRACT_TUPLE_ITEM`). On success the component is popped and execution
/// falls through.
///
/// On failure, a short cleanup sequence pops one value (the component) and
/// jumps onward. The position of that single outward jump is returned so the
/// caller can patch it. An irrefutable sub-pattern returns no positions and
/// emits no cleanup.
fn compile_extracted_subpattern<F>(
    &mut self,
    emit_subject: F,
    pattern: &Pattern,
) -> Result<Vec<usize>, CompileError>
where
    F: FnOnce(&mut Self),
{
    emit_subject(self);
    let inner_failures = self.compile_pattern(pattern)?;
    // Success path: discard the extracted component.
    self.emit_op(POP);

    if inner_failures.is_empty() {
        return Ok(Vec::new());
    }

    // Route the success path around the failure cleanup below.
    let skip_cleanup = self.emit_jump(JUMP);

    // Every inner failure lands here; drop the component before escaping.
    let cleanup_start = self.offset();
    for failure in inner_failures {
        self.patch_jump_to(failure, cleanup_start);
    }
    self.emit_op(POP);
    let escape = self.emit_jump(JUMP);

    self.patch_jump(skip_cleanup);
    Ok(vec![escape])
}
fn compile_tuple_pattern(&mut self, patterns: &[Pattern]) -> Result<Vec<usize>, CompileError> {
self.emit_op(MATCH_TUPLE);
self.emit_u8(patterns.len() as u8);
let tuple_fail = self.code.len();
self.emit_i16(0);
let mut fail_patches = vec![tuple_fail];
for (i, pattern) in patterns.iter().enumerate() {
let mut nested = self.compile_extracted_subpattern(
|this| {
this.emit_op(EXTRACT_TUPLE_ITEM);
this.emit_u8(i as u8);
},
pattern,
)?;
fail_patches.append(&mut nested);
}
Ok(fail_patches)
}
/// Lowers `pattern` (belonging to arm `arm_index`) to a dispatch-table entry.
///
/// The pattern is first classified by the IR's `classify_dispatch_pattern`.
/// Literals become `DISPATCH_KIND_EXACT` entries carrying their `NanValue`
/// bits, except strings, which become `DISPATCH_KIND_STRING`. The empty list
/// and `None` are also `EXACT`, and wrapper tags are `DISPATCH_KIND_TAG`.
///
/// Returns `None` if the pattern is not dispatchable or a float literal's
/// text fails to parse as `f64`.
fn classify_dispatchable(
    &mut self,
    pattern: &Pattern,
    arm_index: usize,
) -> Option<DispatchableArm> {
    let lower_ctx = VmPatternCtx { compiler: self };
    let semantic = classify_dispatch_pattern(pattern, &lower_ctx)?;

    let (kind, expected) = match semantic {
        SemanticDispatchPattern::Literal(DispatchLiteral::Int(i)) => {
            (DISPATCH_KIND_EXACT, NanValue::new_int(i, self.arena).bits())
        }
        SemanticDispatchPattern::Literal(DispatchLiteral::Float(text)) => {
            let value: f64 = text.parse().ok()?;
            (DISPATCH_KIND_EXACT, NanValue::new_float(value).bits())
        }
        SemanticDispatchPattern::Literal(DispatchLiteral::Bool(b)) => {
            (DISPATCH_KIND_EXACT, NanValue::new_bool(b).bits())
        }
        SemanticDispatchPattern::Literal(DispatchLiteral::Str(s)) => (
            DISPATCH_KIND_STRING,
            NanValue::new_string_value(&s, self.arena).bits(),
        ),
        SemanticDispatchPattern::Literal(DispatchLiteral::Unit) => {
            (DISPATCH_KIND_EXACT, NanValue::UNIT.bits())
        }
        SemanticDispatchPattern::EmptyList => (DISPATCH_KIND_EXACT, NanValue::EMPTY_LIST.bits()),
        SemanticDispatchPattern::NoneValue => (DISPATCH_KIND_EXACT, NanValue::NONE.bits()),
        SemanticDispatchPattern::WrapperTag(wrapper) => {
            (DISPATCH_KIND_TAG, wrapper_tag_bits(wrapper))
        }
    };

    Some(DispatchableArm {
        kind,
        expected,
        arm_index,
    })
}
/// Emits the per-arm setup code for a `MATCH_DISPATCH` arm.
///
/// Only a wrapper-tag entry that binds its payload needs setup. For it, this
/// emits `MATCH_UNWRAP` with a zero failure offset and binds the top of the
/// stack to the payload name. The zero offset is presumably safe because the
/// dispatch table already selected this tag. Every other entry emits nothing.
fn emit_dispatch_arm_prologue(&mut self, entry: &DispatchArmPlan) {
    let (SemanticDispatchPattern::WrapperTag(kind), DispatchBindingPlan::WrapperPayload(name)) =
        (&entry.pattern, &entry.binding)
    else {
        return;
    };

    self.emit_op(MATCH_UNWRAP);
    self.emit_u8(wrapper_tag_kind(*kind));
    self.emit_i16(0);
    self.dup_and_bind_top_to_local(name);
}
/// Returns the `NanValue` bits of `expr` if it is a literal, else `None`.
///
/// Int and string literals are materialised through `self.arena`, which is
/// why this takes `&mut self`.
fn try_const_expr(&mut self, expr: &Expr) -> Option<u64> {
    let Expr::Literal(lit) = expr else {
        return None;
    };
    let value = match lit {
        Literal::Str(s) => NanValue::new_string_value(s, self.arena),
        Literal::Int(i) => NanValue::new_int(*i, self.arena),
        Literal::Float(f) => NanValue::new_float(*f),
        Literal::Bool(b) => NanValue::new_bool(*b),
        Literal::Unit => NanValue::UNIT,
    };
    Some(value.bits())
}
/// Binds the names of constructor pattern `name(bindings..)` against the value
/// on top of the stack, without testing whether the constructor matches.
///
/// How the names are bound depends on the constructor kind:
/// - A wrapper binds only its first name, to the payload produced by
///   `MATCH_UNWRAP` (emitted with a zero failure offset).
/// - `None` binds nothing.
/// - Type constructors, and constructors that cannot be classified, bind each
///   name to the field at its position via `EXTRACT_FIELD`.
fn emit_constructor_bindings_unconditional(
    &mut self,
    name: &str,
    bindings: &[String],
) -> Result<(), CompileError> {
    match self.classify_constructor_semantics(name) {
        SemanticConstructor::Wrapper(kind) => {
            if let Some(payload_name) = bindings.first() {
                self.emit_op(MATCH_UNWRAP);
                self.emit_u8(wrapper_tag_kind(kind));
                self.emit_i16(0);
                self.dup_and_bind_top_to_local(payload_name);
            }
        }
        SemanticConstructor::NoneValue => {}
        SemanticConstructor::TypeConstructor { .. } | SemanticConstructor::Unknown(_) => {
            for (field_index, binding) in bindings.iter().enumerate() {
                self.emit_op(EXTRACT_FIELD);
                self.emit_u8(field_index as u8);
                self.bind_top_to_local(binding);
            }
        }
    }
    Ok(())
}
pub(super) fn compile_match(
&mut self,
subject: &Spanned<Expr>,
arms: &[MatchArm],
) -> Result<(), CompileError> {
let lower_ctx = VmPatternCtx { compiler: self };
if let Some(plan) = classify_match_dispatch_plan(arms, &lower_ctx) {
match plan {
MatchDispatchPlan::Bool(shape) => {
self.compile_bool_match_with_shape(subject, arms, shape)?;
return Ok(());
}
MatchDispatchPlan::Table(shape) => {
if let Some(result) =
self.try_compile_match_dispatch_with_shape(subject, arms, &shape)?
{
return Ok(result);
}
}
MatchDispatchPlan::List(_) => {}
}
}
self.compile_expr(subject)?;
let mut end_jumps = Vec::new();
for (i, arm) in arms.iter().enumerate() {
let is_last = i == arms.len() - 1;
let fail_patches = if is_last {
if let Pattern::Ident(name) = &arm.pattern {
self.dup_and_bind_top_to_local(name);
} else if let Pattern::Constructor(name, bindings) = &arm.pattern {
self.emit_constructor_bindings_unconditional(name, bindings)?;
} else if let Pattern::Cons(head, tail) = &arm.pattern {
self.emit_op(DUP);
self.emit_op(LIST_HEAD_TAIL);
self.bind_top_to_local(head);
self.bind_top_to_local(tail);
} else if let Pattern::Tuple(patterns) = &arm.pattern {
for (idx, pat) in patterns.iter().enumerate() {
self.emit_op(EXTRACT_TUPLE_ITEM);
self.emit_u8(idx as u8);
if let Pattern::Ident(name) = pat {
self.bind_top_to_local(name);
} else {
self.emit_op(POP);
}
}
}
Vec::new()
} else {
match &arm.pattern {
Pattern::Wildcard => Vec::new(),
Pattern::Ident(name) => {
self.dup_and_bind_top_to_local(name);
Vec::new()
}
pat => self.compile_pattern(pat)?,
}
};
self.emit_op(POP);
self.compile_expr(&arm.body)?;
if !is_last {
end_jumps.push(self.emit_jump(JUMP));
if !fail_patches.is_empty() {
let fail_cleanup = self.offset();
for patch in fail_patches {
self.patch_jump_to(patch, fail_cleanup);
}
}
}
}
for patch in end_jumps {
self.patch_jump(patch);
}
Ok(())
}
/// Lowers a boolean `match`, with one arm for `true` and one for `false`, to a
/// conditional branch.
///
/// When `classify_bool_subject_plan` recognises the subject as a comparison,
/// its operands and comparison opcode are emitted directly. An `invert`ed
/// comparison is handled by swapping which body runs on the fall-through
/// path. Otherwise the subject expression is compiled as-is and branched on.
fn compile_bool_match_with_shape(
    &mut self,
    subject: &Spanned<Expr>,
    arms: &[MatchArm],
    shape: BoolMatchShape,
) -> Result<(), CompileError> {
    let when_true = &arms[shape.true_arm_index].body;
    let when_false = &arms[shape.false_arm_index].body;

    // Push the branch condition, then decide which body runs when it holds
    // (`then_body`, fall-through) and which when it does not (`else_body`).
    let (then_body, else_body) = if let BoolSubjectPlan::Compare {
        lhs,
        rhs,
        op,
        invert,
    } = classify_bool_subject_plan(&subject.node)
    {
        self.compile_expr(lhs)?;
        self.compile_expr(rhs)?;
        let compare_op = match op {
            BoolCompareOp::Eq => EQ,
            BoolCompareOp::Lt => LT,
            BoolCompareOp::Gt => GT,
        };
        self.emit_op(compare_op);
        if invert {
            (when_false, when_true)
        } else {
            (when_true, when_false)
        }
    } else {
        self.compile_expr(subject)?;
        (when_true, when_false)
    };

    let to_else = self.emit_jump(JUMP_IF_FALSE);
    self.compile_expr(then_body)?;
    let to_end = self.emit_jump(JUMP);
    self.patch_jump(to_else);
    self.compile_expr(else_body)?;
    self.patch_jump(to_end);
    Ok(())
}
fn try_compile_match_dispatch_with_shape(
&mut self,
subject: &Spanned<Expr>,
arms: &[MatchArm],
shape: &DispatchTableShape,
) -> Result<Option<()>, CompileError> {
if shape.entries.len() > 255 {
return Ok(None);
}
let mut entries = Vec::new();
for entry in &shape.entries {
if let Some(lowered) =
self.classify_dispatchable(&arms[entry.arm_index].pattern, entry.arm_index)
{
entries.push(lowered);
} else {
return Ok(None);
}
}
let has_default = shape.default_arm.is_some();
let all_const = entries.iter().all(|e| {
let arm = &arms[e.arm_index];
(e.kind == DISPATCH_KIND_EXACT || e.kind == DISPATCH_KIND_STRING)
&& self.try_const_expr(&arm.body.node).is_some()
});
if all_const {
return self.emit_match_dispatch_const(
&entries,
arms,
subject,
shape.default_arm.as_ref().map(|arm| arm.arm_index),
);
}
self.compile_expr(subject)?;
self.emit_op(MATCH_DISPATCH);
self.emit_u8(entries.len() as u8);
let default_offset_patch = self.code.len();
self.emit_i16(0);
let mut entry_offset_patches = Vec::new();
for entry in &entries {
self.emit_u8(entry.kind);
self.emit_u64(entry.expected);
entry_offset_patches.push(self.code.len());
self.emit_i16(0); }
let table_end = self.offset();
let mut end_jumps = Vec::new();
for (table_idx, (entry, plan_entry)) in entries.iter().zip(shape.entries.iter()).enumerate()
{
let arm = &arms[entry.arm_index];
let arm_start = self.offset();
let rel = (arm_start as isize - table_end as isize) as i16;
let bytes = (rel as u16).to_be_bytes();
self.code[entry_offset_patches[table_idx]] = bytes[0];
self.code[entry_offset_patches[table_idx] + 1] = bytes[1];
self.emit_dispatch_arm_prologue(plan_entry);
self.emit_op(POP);
self.compile_expr(&arm.body)?;
end_jumps.push(self.emit_jump(JUMP));
}
let default_start = self.offset();
let default_rel = (default_start as isize - table_end as isize) as i16;
let default_bytes = (default_rel as u16).to_be_bytes();
self.code[default_offset_patch] = default_bytes[0];
self.code[default_offset_patch + 1] = default_bytes[1];
if has_default {
let default_plan = shape.default_arm.as_ref().unwrap();
let default_arm = &arms[default_plan.arm_index];
if let Some(name) = &default_plan.binding_name {
self.dup_and_bind_top_to_local(name);
}
self.emit_op(POP);
self.compile_expr(&default_arm.body)?;
} else {
}
for patch in end_jumps {
self.patch_jump(patch);
}
Ok(Some(()))
}
/// Emits a `MATCH_DISPATCH_CONST` instruction, whose entries carry their arm's
/// constant result directly.
///
/// Layout:
/// - the entry count;
/// - a default offset, relative to the end of the table;
/// - one `(kind, expected, result)` record per entry.
///
/// When a default arm exists, an unconditional jump follows the table to skip
/// the default code, and the default offset points just past that jump. An
/// identifier default pattern is bound to the subject before its body runs.
/// Without a default arm, the default offset is 0.
///
/// The caller must ensure every entry's body is a literal: the table result is
/// taken with `unwrap()` on `try_const_expr`.
fn emit_match_dispatch_const(
    &mut self,
    entries: &[DispatchableArm],
    arms: &[MatchArm],
    subject: &Spanned<Expr>,
    default_arm_index: Option<usize>,
) -> Result<Option<()>, CompileError> {
    self.compile_expr(subject)?;

    self.emit_op(MATCH_DISPATCH_CONST);
    self.emit_u8(entries.len() as u8);
    let default_slot = self.code.len();
    self.emit_i16(0);
    for entry in entries {
        let result_bits = self.try_const_expr(&arms[entry.arm_index].body.node).unwrap();
        self.emit_u8(entry.kind);
        self.emit_u64(entry.expected);
        self.emit_u64(result_bits);
    }
    let table_end = self.offset();

    let skip_default = default_arm_index.is_some().then(|| self.emit_jump(JUMP));

    let default_rel = (self.offset() as isize - table_end as isize) as i16;
    let [hi, lo] = default_rel.to_be_bytes();
    self.code[default_slot] = hi;
    self.code[default_slot + 1] = lo;

    if let Some(index) = default_arm_index {
        let default_arm = &arms[index];
        if let Pattern::Ident(name) = &default_arm.pattern {
            self.dup_and_bind_top_to_local(name);
        }
        self.emit_op(POP);
        self.compile_expr(&default_arm.body)?;
    }
    if let Some(jump) = skip_default {
        self.patch_jump(jump);
    }
    Ok(Some(()))
}
/// Emits a test of the value on top of the stack against `pattern`.
///
/// Any names the pattern introduces are bound on the success path. Returns
/// the code positions of every failure jump or offset placeholder, for the
/// caller to patch. Wildcards and plain identifiers never fail and return an
/// empty list.
fn compile_pattern(&mut self, pattern: &Pattern) -> Result<Vec<usize>, CompileError> {
    let failure_sites = match pattern {
        Pattern::Wildcard => Vec::new(),
        Pattern::Ident(name) => {
            self.dup_and_bind_top_to_local(name);
            Vec::new()
        }
        Pattern::Literal(lit) => {
            // Compare a copy of the subject with the literal.
            self.emit_op(DUP);
            self.compile_literal(lit)?;
            self.emit_op(EQ);
            vec![self.emit_jump(JUMP_IF_FALSE)]
        }
        Pattern::EmptyList => {
            self.emit_op(MATCH_NIL);
            let placeholder = self.code.len();
            self.emit_i16(0);
            vec![placeholder]
        }
        Pattern::Cons(head, tail) => {
            self.emit_op(MATCH_CONS);
            let placeholder = self.code.len();
            self.emit_i16(0);
            self.emit_op(DUP);
            self.emit_op(LIST_HEAD_TAIL);
            self.bind_top_to_local(head);
            self.bind_top_to_local(tail);
            vec![placeholder]
        }
        Pattern::Constructor(name, bindings) => self.compile_constructor_pattern(name, bindings)?,
        Pattern::Tuple(items) => self.compile_tuple_pattern(items)?,
    };
    Ok(failure_sites)
}
/// Emits a test of the value on top of the stack against the constructor
/// pattern `name(bindings..)`, binding the names on success.
///
/// How the test is emitted depends on the constructor kind:
/// - Wrappers use `MATCH_UNWRAP` and bind only their first name.
/// - `None` compares a copy of the subject with the `NONE` constant.
/// - Type constructors resolve to a constructor id, tested with
///   `MATCH_VARIANT`; each name is then bound to the field at its position.
///
/// Returns the failure jump/placeholder positions for the caller to patch.
///
/// # Errors
/// - `unknown constructor pattern` when the constructor or its type/variant
///   cannot be resolved.
/// - `constructor id too large for VM pattern match` when the id exceeds `u16`.
fn compile_constructor_pattern(
    &mut self,
    name: &str,
    bindings: &[String],
) -> Result<Vec<usize>, CompileError> {
    let unknown_ctor = || CompileError {
        msg: format!("unknown constructor pattern: {}", name),
    };

    match self.classify_constructor_semantics(name) {
        SemanticConstructor::Wrapper(kind) => {
            let unwrap_kind = wrapper_tag_kind(kind);
            Ok(self.compile_unwrap_pattern(unwrap_kind, bindings.first()))
        }
        SemanticConstructor::NoneValue => {
            self.emit_op(DUP);
            let none_const = self.add_constant(NanValue::NONE);
            self.emit_op(LOAD_CONST);
            self.emit_u16(none_const);
            self.emit_op(EQ);
            Ok(vec![self.emit_jump(JUMP_IF_FALSE)])
        }
        SemanticConstructor::TypeConstructor {
            qualified_type_name,
            variant_name,
        } => {
            let resolved = if let Some(type_id) = self.resolve_type_id(&qualified_type_name)
                && let Some(variant_id) = self.arena.find_variant_id(type_id, &variant_name)
            {
                self.arena.find_ctor_id(type_id, variant_id)
            } else {
                None
            };
            let Some(ctor_id) = resolved else {
                return Err(unknown_ctor());
            };
            let ctor_id = u16::try_from(ctor_id).map_err(|_| CompileError {
                msg: format!("constructor id too large for VM pattern match: {}", name),
            })?;

            self.emit_op(MATCH_VARIANT);
            self.emit_u16(ctor_id);
            let variant_fail = self.code.len();
            self.emit_i16(0);
            for (field_index, binding) in bindings.iter().enumerate() {
                self.emit_op(EXTRACT_FIELD);
                self.emit_u8(field_index as u8);
                self.bind_top_to_local(binding);
            }
            Ok(vec![variant_fail])
        }
        SemanticConstructor::Unknown(_) => Err(unknown_ctor()),
    }
}
}