// rusty-javac 0.2.2
//
// A Java compiler written in Rust.
// Documentation
use crate::bytecode::codegen::CodegenCtx;
use crate::bytecode::expr_gen::{expr_ty, gen_expr, push_default_value};
use crate::classfile::{Label, MethodWriter};
use crate::hir::{Body, Expr, ExprId, Stmt, StmtId, SwitchCase};
use crate::ty::Ty;
use rust_asm::opcodes;

/// Emits bytecode for a `switch` that is used as an expression.
///
/// The result type is computed by `switch_result_ty` and passed on to
/// `emit_switch` wrapped in `SwitchUse::Expression`.
pub(super) fn emit_switch_expr(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    selector: ExprId,
    cases: &[SwitchCase],
) {
    // Built before the call: `switch_result_ty` borrows `ctx` immutably, and
    // that borrow has to end before `emit_switch` takes it mutably.
    let switch_use = SwitchUse::Expression(switch_result_ty(ctx, body, cases));
    emit_switch(mw, ctx, body, selector, cases, switch_use);
}

/// Emits bytecode for a `switch` that is used as a statement.
///
/// Forwards to `emit_switch` with `SwitchUse::Statement`, so no result type
/// is computed.
pub(crate) fn emit_switch_stmt(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    selector: ExprId,
    cases: &[SwitchCase],
) {
    let switch_use = SwitchUse::Statement;
    emit_switch(mw, ctx, body, selector, cases, switch_use);
}

/// How the switch being emitted is used by the surrounding code.
enum SwitchUse {
    /// The switch is an expression; the payload is the type every case value
    /// is coerced to (see `emit_case_value`).
    Expression(Ty),
    /// The switch is a statement; case bodies are emitted as plain statements.
    Statement,
}

/// Picks the lowering strategy for a switch and delegates to it.
///
/// * A selector whose type satisfies `is_string_ty` is lowered as a text
///   switch of kind `TextSwitchKind::String`.
/// * Otherwise, if any case label looks like an enum constant
///   (`has_enum_case_labels`), it is lowered as a text switch of kind
///   `TextSwitchKind::EnumName`.
/// * Everything else goes through `emit_int_switch`.
fn emit_switch(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    selector: ExprId,
    cases: &[SwitchCase],
    switch_use: SwitchUse,
) {
    let selector_ty = expr_ty(ctx, body, selector);

    // `None` means "not a text switch": fall back to the integer lowering.
    let text_kind = if is_string_ty(&selector_ty) {
        Some(TextSwitchKind::String)
    } else if has_enum_case_labels(body, cases) {
        Some(TextSwitchKind::EnumName)
    } else {
        None
    };

    match text_kind {
        Some(kind) => emit_text_switch(mw, ctx, body, selector, cases, switch_use, kind),
        None => emit_int_switch(mw, ctx, body, selector, cases, switch_use),
    }
}

/// Lowers a switch whose case labels are integer or character literals.
///
/// Emission order: the selector expression, one lookup-switch instruction,
/// every case body in source order, the synthetic default (only for
/// expression switches without a `default` case), and finally the end label.
fn emit_int_switch(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    selector: ExprId,
    cases: &[SwitchCase],
    switch_use: SwitchUse,
) {
    let end_label = Label::new();
    let labels = case_labels(cases);
    let missing_default = missing_default_label(cases, &switch_use);

    // Preference order for the "no key matched" target: the explicit
    // `default` case, then the synthetic default, then the end of the switch.
    let explicit_or_synthetic = match default_index(cases) {
        Some(index) => labels[index].or(missing_default),
        None => missing_default,
    };
    let default_target = explicit_or_synthetic.unwrap_or(end_label);

    // Keys are handed to the writer in ascending order.
    // NOTE(review): presumably because the JVM `lookupswitch` instruction
    // requires sorted match values - confirm against `visit_lookup_switch`.
    let mut sorted_pairs = int_lookup_pairs(body, cases, &labels);
    sorted_pairs.sort_by_key(|&(key, _)| key);

    gen_expr(mw, ctx, body, selector);
    mw.visit_lookup_switch(default_target, &sorted_pairs);
    emit_case_bodies(mw, ctx, body, cases, &labels, end_label, &switch_use);
    emit_missing_default(mw, missing_default, &switch_use);
    mw.visit_label(end_label);
}

/// Which textual value of the selector the case keys are compared against.
#[derive(Clone, Copy)]
enum TextSwitchKind {
    /// The selector itself is compared with `String.equals`.
    String,
    /// The result of calling `Enum.name()` on the selector is compared with
    /// `String.equals`.
    EnumName,
}

/// Lowers a switch as a chain of `equals` tests against string keys.
///
/// The selector is evaluated once and stored in a temporary local. One test
/// is then emitted per case that has a textual key; a `GOTO` to the default
/// target follows the last test. After that come the case bodies, the
/// synthetic default (if one was created) and the end label.
fn emit_text_switch(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    selector: ExprId,
    cases: &[SwitchCase],
    switch_use: SwitchUse,
    kind: TextSwitchKind,
) {
    let end_label = Label::new();
    let labels = case_labels(cases);
    let missing_default = missing_default_label(cases, &switch_use);

    // Preference order for the "nothing matched" target: the explicit
    // `default` case, then the synthetic default, then the end of the switch.
    let explicit_or_synthetic = match default_index(cases) {
        Some(index) => labels[index].or(missing_default),
        None => missing_default,
    };
    let default_target = explicit_or_synthetic.unwrap_or(end_label);

    let selector_ty = expr_ty(ctx, body, selector).erasure();
    let selector_slot = ctx.alloc_temp(&selector_ty);

    // Evaluate the selector exactly once and park it in the temporary.
    gen_expr(mw, ctx, body, selector);
    let store = crate::bytecode::local_var::store_opcode(&selector_ty);
    mw.visit_var_insn(store, selector_slot);

    // Cases without a label or without a textual key (e.g. `default`) get no
    // test; they are only reachable through the default jump or fall-through.
    let keyed_cases = cases.iter().zip(&labels).filter_map(|(case, slot)| {
        let label = (*slot)?;
        let key = text_case_key(body, case, kind)?;
        Some((key, label))
    });
    for (key, label) in keyed_cases {
        emit_text_case_test(mw, selector_slot, &key, kind, label);
    }
    mw.visit_jump_insn(opcodes::GOTO, default_target);

    emit_case_bodies(mw, ctx, body, cases, &labels, end_label, &switch_use);
    emit_missing_default(mw, missing_default, &switch_use);
    mw.visit_label(end_label);
}

/// Emits one "does the selector equal `key`?" test that jumps to `label` on
/// a match.
///
/// The selector is reloaded from `selector_slot` with `ALOAD`. For
/// `TextSwitchKind::EnumName` it is first converted with `Enum.name()`. The
/// value is then compared to the constant `key` via `String.equals`, and
/// `IFNE` branches to `label` when the comparison is true.
fn emit_text_case_test(
    mw: &mut MethodWriter,
    selector_slot: u16,
    key: &str,
    kind: TextSwitchKind,
    label: Label,
) {
    mw.visit_var_insn(opcodes::ALOAD, selector_slot);

    match kind {
        // A string selector is compared as-is.
        TextSwitchKind::String => {}
        // An enum selector is compared by its constant name.
        TextSwitchKind::EnumName => {
            mw.visit_method_insn(
                opcodes::INVOKEVIRTUAL,
                "java/lang/Enum",
                "name",
                "()Ljava/lang/String;",
                false,
            );
        }
    }

    mw.visit_ldc_insn_string(key);
    mw.visit_method_insn(
        opcodes::INVOKEVIRTUAL,
        "java/lang/String",
        "equals",
        "(Ljava/lang/Object;)Z",
        false,
    );
    mw.visit_jump_insn(opcodes::IFNE, label);
}

/// Emits every case body in source order, each preceded by its label.
///
/// While the bodies are generated, the value returned by
/// `ctx.control_target(end_label)` sits on top of `ctx.break_labels`; it is
/// popped again before returning.
///
/// `labels` is indexed by case position and must be at least as long as
/// `cases`.
fn emit_case_bodies(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    cases: &[SwitchCase],
    labels: &[Option<Label>],
    end_label: Label,
    switch_use: &SwitchUse,
) {
    let switch_exit = ctx.control_target(end_label);
    ctx.break_labels.push(switch_exit);

    for (position, case) in cases.iter().enumerate() {
        let entry = labels[position];
        if let Some(label) = entry {
            mw.visit_label(label);
        }
        emit_case_body(mw, ctx, body, case, end_label, switch_use);
    }

    ctx.break_labels.pop();
}

/// Emits the body of a single case.
///
/// * Statement switch: every statement of the case is generated. An arrow
///   case whose last statement is not a `return`, `throw` or `break` gets a
///   trailing `GOTO end_label`; other cases get no jump and run on into
///   whatever is emitted next.
/// * Expression switch: the case value is pushed (see `emit_case_value`) and
///   control always jumps to `end_label`.
fn emit_case_body(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    case: &SwitchCase,
    end_label: Label,
    switch_use: &SwitchUse,
) {
    match switch_use {
        SwitchUse::Statement => {
            for &stmt in case_stmts(case) {
                crate::bytecode::stmt_gen::gen_stmt(mw, ctx, body, stmt);
            }
            // Non-arrow cases never get an implicit jump.
            if !case_is_arrow(case) {
                return;
            }
            // The body already leaves on its own; a jump would be dead code.
            if case_definitely_exits(body, case) {
                return;
            }
            mw.visit_jump_insn(opcodes::GOTO, end_label);
        }
        SwitchUse::Expression(result_ty) => {
            emit_case_value(mw, ctx, body, case, result_ty);
            mw.visit_jump_insn(opcodes::GOTO, end_label);
        }
    }
}

/// Creates a label for a synthetic default branch when one is needed.
///
/// Returns a fresh label only for an expression switch that has no `default`
/// case; statement switches and switches with a `default` yield `None`.
fn missing_default_label(cases: &[SwitchCase], switch_use: &SwitchUse) -> Option<Label> {
    match switch_use {
        SwitchUse::Expression(_) if default_index(cases).is_none() => Some(Label::new()),
        SwitchUse::Expression(_) | SwitchUse::Statement => None,
    }
}

/// Emits the synthetic default branch of an expression switch, if any.
///
/// When `missing_default` holds a label and the switch is an expression, the
/// label is placed and the default value of the result type is pushed.
/// Nothing is emitted otherwise.
fn emit_missing_default(
    mw: &mut MethodWriter,
    missing_default: Option<Label>,
    switch_use: &SwitchUse,
) {
    let Some(label) = missing_default else {
        return;
    };
    let SwitchUse::Expression(result_ty) = switch_use else {
        return;
    };
    mw.visit_label(label);
    push_default_value(mw, result_ty);
}

/// Allocates one fresh label per case, in case order.
///
/// Every entry is `Some`; the `Option` wrapper matches what the consumers of
/// the returned vector expect.
fn case_labels(cases: &[SwitchCase]) -> Vec<Option<Label>> {
    std::iter::repeat_with(|| Some(Label::new()))
        .take(cases.len())
        .collect()
}

/// Collects `(key, label)` pairs for every case with an integer-like key.
///
/// `default` cases, cases without a label and cases whose pattern is not
/// accepted by `int_case_key` are skipped. The pairs come back in case order;
/// the caller is responsible for sorting them.
///
/// `labels` is indexed by case position.
fn int_lookup_pairs(
    body: &Body,
    cases: &[SwitchCase],
    labels: &[Option<Label>],
) -> Vec<(i32, Label)> {
    let mut pairs = Vec::new();
    for (position, case) in cases.iter().enumerate() {
        match case {
            SwitchCase::Default { .. } => {}
            SwitchCase::Case { pattern, .. } => {
                let Some(label) = labels[position] else {
                    continue;
                };
                if let Some(key) = int_case_key(body, *pattern) {
                    pairs.push((key, label));
                }
            }
        }
    }
    pairs
}

/// Extracts the `i32` lookup key from a case pattern.
///
/// Integer literals are accepted only when their value converts to `i32`
/// without loss; character literals are cast to `i32`. Any other expression
/// yields `None`.
fn int_case_key(body: &Body, pattern: ExprId) -> Option<i32> {
    match &body.exprs[pattern] {
        Expr::CharLiteral(value) => Some(*value as i32),
        Expr::IntLiteral(value) => i32::try_from(*value).ok(),
        _ => None,
    }
}

/// Returns the string a case is compared against in a text switch.
///
/// `default` cases have no key. For `TextSwitchKind::String` only string
/// literal patterns produce a key; for `TextSwitchKind::EnumName` the key is
/// whatever `enum_case_key` extracts from the pattern.
fn text_case_key(body: &Body, case: &SwitchCase, kind: TextSwitchKind) -> Option<String> {
    let pattern = match case {
        SwitchCase::Case { pattern, .. } => *pattern,
        SwitchCase::Default { .. } => return None,
    };

    match kind {
        TextSwitchKind::EnumName => enum_case_key(body, pattern),
        TextSwitchKind::String => {
            if let Expr::StringLiteral(value) = &body.exprs[pattern] {
                Some(value.to_string())
            } else {
                None
            }
        }
    }
}

/// Extracts the constant name from an enum-style case pattern.
///
/// A bare identifier yields its own name and a field access yields the name
/// of the accessed field. Every other expression yields `None`.
fn enum_case_key(body: &Body, pattern: ExprId) -> Option<String> {
    let constant_name = match &body.exprs[pattern] {
        Expr::Ident(name) => name.to_string(),
        Expr::FieldAccess { field, .. } => field.to_string(),
        _ => return None,
    };
    Some(constant_name)
}

/// Reports whether at least one non-default case has a pattern that
/// `enum_case_key` recognises (an identifier or a field access).
fn has_enum_case_labels(body: &Body, cases: &[SwitchCase]) -> bool {
    cases.iter().any(|case| match case {
        SwitchCase::Case { pattern, .. } => enum_case_key(body, *pattern).is_some(),
        SwitchCase::Default { .. } => false,
    })
}

/// Returns the position of the first `default` case, or `None` when the
/// switch has no `default`.
fn default_index(cases: &[SwitchCase]) -> Option<usize> {
    cases
        .iter()
        .enumerate()
        .find_map(|(position, case)| match case {
            SwitchCase::Default { .. } => Some(position),
            SwitchCase::Case { .. } => None,
        })
}

/// Pushes the value an expression-switch case evaluates to.
///
/// If `case_value` finds an expression for the case, it is generated and then
/// coerced from its own type to `switch_ty`. A case without such an
/// expression pushes the default value of `switch_ty` instead.
fn emit_case_value(
    mw: &mut MethodWriter,
    ctx: &mut CodegenCtx,
    body: &Body,
    case: &SwitchCase,
    switch_ty: &Ty,
) {
    let Some(value_expr) = case_value(case, body) else {
        push_default_value(mw, switch_ty);
        return;
    };

    gen_expr(mw, ctx, body, value_expr);
    let value_ty = expr_ty(ctx, body, value_expr);
    crate::bytecode::expr_gen::coerce(mw, &value_ty, switch_ty);
}

/// Determines the result type of an expression switch.
///
/// The type of the first case value found (scanning cases in order) is used;
/// when no case provides a value the result is `Ty::object()`.
fn switch_result_ty(ctx: &CodegenCtx, body: &Body, cases: &[SwitchCase]) -> Ty {
    let first_value = cases.iter().find_map(|case| case_value(case, body));
    match first_value {
        Some(expr) => expr_ty(ctx, body, expr),
        None => Ty::object(),
    }
}

fn case_value(case: &SwitchCase, body: &Body) -> Option<ExprId> {
    case_stmts(case)
        .iter()
        .find_map(|stmt| match &body.stmts[*stmt] {
            Stmt::Yield(expr) | Stmt::Return(Some(expr)) | Stmt::Expr(expr) => Some(*expr),
            Stmt::Block(block) => block
                .stmts
                .iter()
                .find_map(|stmt| match &body.stmts[*stmt] {
                    Stmt::Yield(expr) | Stmt::Return(Some(expr)) | Stmt::Expr(expr) => Some(*expr),
                    _ => None,
                }),
            _ => None,
        })
}

/// Returns the statements that make up a case's body, for both regular and
/// `default` cases.
fn case_stmts(case: &SwitchCase) -> &[StmtId] {
    let (SwitchCase::Case { body, .. } | SwitchCase::Default { body, .. }) = case;
    body
}

/// Returns the `is_arrow` flag of a case, for both regular and `default`
/// cases.
fn case_is_arrow(case: &SwitchCase) -> bool {
    let (SwitchCase::Case { is_arrow, .. } | SwitchCase::Default { is_arrow, .. }) = case;
    *is_arrow
}

/// Reports whether a case body ends by leaving the case on its own.
///
/// Only the last statement is inspected: `true` when it is a `return`,
/// `throw` or `break`. An empty body yields `false`.
fn case_definitely_exits(body: &Body, case: &SwitchCase) -> bool {
    let Some(&last) = case_stmts(case).last() else {
        return false;
    };
    matches!(
        body.stmts[last],
        Stmt::Return(_) | Stmt::Throw(_) | Stmt::Break(_)
    )
}

/// Returns whether `ty` is a string type, as reported by `Ty::is_string`.
fn is_string_ty(ty: &Ty) -> bool {
    ty.is_string()
}