cranelift-codegen 0.24.0

Low-level code generator library
Documentation
"""
Generate legalizer transformations.

The transformations defined in the `cranelift.legalize` module are all of the
macro-expansion form where the input pattern is a single instruction. We
generate a Rust function for each `XFormGroup` which takes a `Cursor` pointing
at the instruction to be legalized. The expanded destination pattern replaces
the input instruction.
"""
from __future__ import absolute_import
from srcgen import Formatter
from collections import defaultdict
from base import instructions
from cdsl.ast import Var
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, TypesEqual,\
    InTypeset, WiderOrEq
from unique_table import UniqueTable
from gen_instr import gen_typesets_table
from cdsl.typevar import TypeVar

try:
    from typing import Sequence, List, Dict, Set, DefaultDict # noqa
    from cdsl.isa import TargetISA  # noqa
    from cdsl.ast import Def, VarAtomMap  # noqa
    from cdsl.xform import XForm, XFormGroup  # noqa
    from cdsl.typevar import TypeSet # noqa
    from cdsl.ti import TypeConstraint # noqa
except ImportError:
    pass


def get_runtime_typechecks(xform):
    # type: (XForm) -> List[TypeConstraint]
    """
    Given a XForm build a list of runtime type checks neccessary to determine
    if it applies. We have 2 types of runtime checks:
        1) typevar tv belongs to typeset T - needed for free tvs whose
               typeset is constrainted by their use in the dst pattern

        2) tv1 == tv2 where tv1 and tv2 are derived TVs - caused by unification
                of non-bijective functions
    """
    check_l = []  # type: List[TypeConstraint]

    # 1) Perform ti only on the source RTL. Accumulate any free tvs that have a
    #    different inferred type in src, compared to the type inferred for both
    #    src and dst.
    symtab = {}  # type: VarAtomMap
    src_copy = xform.src.copy(symtab)
    src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv()))

    for v in xform.ti.vars:
        if not v.has_free_typevar():
            continue

        # In rust the local variable containing a free TV associated with var v
        # has name typeof_v. We rely on the python TVs having the same name.
        assert "typeof_{}".format(v) == xform.ti[v].name

        if v not in symtab:
            # We can have singleton vars defined only on dst. Ignore them
            assert v.get_typevar().singleton_type() is not None
            continue

        inner_v = symtab[v]
        assert isinstance(inner_v, Var)
        src_ts = src_typenv[inner_v].get_typeset()
        xform_ts = xform.ti[v].get_typeset()

        assert xform_ts.issubset(src_ts)
        if src_ts != xform_ts:
            check_l.append(InTypeset(xform.ti[v], xform_ts))

    # 2,3) Add any constraints that appear in xform.ti
    check_l.extend(xform.ti.constraints)

    return check_l


def emit_runtime_typecheck(check, fmt, type_sets):
    # type: (TypeConstraint, Formatter, UniqueTable) -> None
    """
    Emit rust code for the given check.

    The emitted code is a statement redefining the `predicate` variable like
    this:

        let predicate = predicate && ...
    """
    def build_derived_expr(tv):
        # type: (TypeVar) -> str
        """
        Build an expression of type Option<Type> corresponding to a concrete
        type transformed by the sequence of derivation functions in tv.

        We are using Option<Type>, as some constraints may cause an
        over/underflow on patterns that do not match them. We want to capture
        this without panicking at runtime.
        """
        if not tv.is_derived:
            assert tv.name.startswith('typeof_')
            return "Some({})".format(tv.name)

        base_exp = build_derived_expr(tv.base)
        if (tv.derived_func == TypeVar.LANEOF):
            return "{}.map(|t: ir::Type| t.lane_type())".format(base_exp)
        elif (tv.derived_func == TypeVar.ASBOOL):
            return "{}.map(|t: ir::Type| t.as_bool())".format(base_exp)
        elif (tv.derived_func == TypeVar.HALFWIDTH):
            return "{}.and_then(|t: ir::Type| t.half_width())".format(base_exp)
        elif (tv.derived_func == TypeVar.DOUBLEWIDTH):
            return "{}.and_then(|t: ir::Type| t.double_width())"\
                .format(base_exp)
        elif (tv.derived_func == TypeVar.HALFVECTOR):
            return "{}.and_then(|t: ir::Type| t.half_vector())"\
                .format(base_exp)
        elif (tv.derived_func == TypeVar.DOUBLEVECTOR):
            return "{}.and_then(|t: ir::Type| t.by(2))".format(base_exp)
        else:
            assert False, "Unknown derived function {}".format(tv.derived_func)

    if (isinstance(check, InTypeset)):
        assert not check.tv.is_derived
        tv = check.tv.name
        if check.ts not in type_sets.index:
            type_sets.add(check.ts)
        ts = type_sets.index[check.ts]
        fmt.comment("{} must belong to {}".format(tv, check.ts))
        fmt.format(
                'let predicate = predicate && TYPE_SETS[{}].contains({});',
                ts, tv)
    elif (isinstance(check, TypesEqual)):
        with fmt.indented(
            'let predicate = predicate && match ({}, {}) {{'
            .format(build_derived_expr(check.tv1),
                    build_derived_expr(check.tv2)), '};'):
            fmt.line('(Some(a), Some(b)) => a == b,')
            fmt.comment('On overflow, constraint doesn\'t appply')
            fmt.line('_ => false,')
    elif (isinstance(check, WiderOrEq)):
        with fmt.indented(
            'let predicate = predicate && match ({}, {}) {{'
            .format(build_derived_expr(check.tv1),
                    build_derived_expr(check.tv2)), '};'):
            fmt.line('(Some(a), Some(b)) => a.wider_or_equal(b),')
            fmt.comment('On overflow, constraint doesn\'t appply')
            fmt.line('_ => false,')
    else:
        assert False, "Unknown check {}".format(check)


def unwrap_inst(iref, node, fmt):
    # type: (str, Def, Formatter) -> bool
    """
    Given a `Def` node, emit code that extracts all the instruction fields from
    `pos.func.dfg[iref]`.

    Create local variables named after the `Var` instances in `node`.

    Also create a local variable named `predicate` with the value of the
    evaluated instruction predicate, or `true` if the node has no predicate.

    :param iref: Name of the `Inst` reference to unwrap.
    :param node: `Def` node providing variable names.
    :returns: True if the instruction arguments were not detached, expecting a
              replacement instruction to overwrite the original.
    """
    fmt.comment('Unwrap {}'.format(node))
    expr = node.expr
    iform = expr.inst.format
    nvops = iform.num_value_operands

    # The tuple of locals to extract is the `Var` instances in `expr.args`.
    arg_names = tuple(
            arg.name if isinstance(arg, Var) else '_' for arg in expr.args)
    with fmt.indented(
            'let ({}, predicate) = if let ir::InstructionData::{} {{'
            .format(', '.join(map(str, arg_names)), iform.name), '};'):
        # Fields are encoded directly.
        for f in iform.imm_fields:
            fmt.line('{},'.format(f.member))
        if nvops == 1:
            fmt.line('arg,')
        elif iform.has_value_list or nvops > 1:
            fmt.line('ref args,')
        fmt.line('..')
        fmt.outdented_line('} = pos.func.dfg[inst] {')
        fmt.line('let func = &pos.func;')
        if iform.has_value_list:
            fmt.line('let args = args.as_slice(&func.dfg.value_lists);')
        elif nvops == 1:
            fmt.line('let args = [arg];')
        # Generate the values for the tuple.
        with fmt.indented('(', ')'):
            for opnum, op in enumerate(expr.inst.ins):
                if op.is_immediate():
                    n = expr.inst.imm_opnums.index(opnum)
                    fmt.format('{},', iform.imm_fields[n].member)
                elif op.is_value():
                    n = expr.inst.value_opnums.index(opnum)
                    fmt.format('func.dfg.resolve_aliases(args[{}]),', n)
            # Evaluate the instruction predicate, if any.
            instp = expr.inst_predicate_with_ctrl_typevar()
            fmt.line(instp.rust_predicate(0) if instp else 'true')
        fmt.outdented_line('} else {')
        fmt.line('unreachable!("bad instruction format")')

    # Get the types of any variables where it is needed.
    for opnum in expr.inst.value_opnums:
        v = expr.args[opnum]
        if isinstance(v, Var) and v.has_free_typevar():
            fmt.format('let typeof_{0} = pos.func.dfg.value_type({0});', v)

    # If the node has results, detach the values.
    # Place the values in locals.
    replace_inst = False
    if len(node.defs) > 0:
        if node.defs == node.defs[0].dst_def.defs:
            # Special case: The instruction replacing node defines the exact
            # same values.
            fmt.comment(
                    'Results handled by {}.'
                    .format(node.defs[0].dst_def))
            replace_inst = True
        else:
            # Boring case: Detach the result values, capture them in locals.
            for d in node.defs:
                fmt.line('let {};'.format(d))
            with fmt.indented('{', '}'):
                fmt.line('let r = pos.func.dfg.inst_results(inst);')
                for i in range(len(node.defs)):
                    fmt.line('{} = r[{}];'.format(node.defs[i], i))
            for d in node.defs:
                if d.has_free_typevar():
                    fmt.line(
                            'let typeof_{0} = pos.func.dfg.value_type({0});'
                            .format(d))

    return replace_inst


def wrap_tup(seq):
    # type: (Sequence[object]) -> str
    tup = tuple(map(str, seq))
    if len(tup) == 1:
        return tup[0]
    else:
        return '({})'.format(', '.join(tup))


def is_value_split(node):
    # type: (Def) -> bool
    """
    Determine if `node` represents one of the value splitting instructions:
    `isplit` or `vsplit. These instructions are lowered specially by the
    `legalize::split` module.
    """
    if len(node.defs) != 2:
        return False
    return node.expr.inst in (instructions.isplit, instructions.vsplit)


def emit_dst_inst(node, fmt):
    # type: (Def, Formatter) -> None
    replaced_inst = None  # type: str

    if is_value_split(node):
        # Split instructions are not emitted with the builder, but by calling
        # special functions in the `legalizer::split` module. These functions
        # will eliminate concat-split patterns.
        fmt.line('let curpos = pos.position();')
        fmt.line('let srcloc = pos.srcloc();')
        fmt.format(
                'let {} = split::{}(pos.func, cfg, curpos, srcloc, {});',
                wrap_tup(node.defs),
                node.expr.inst.snake_name(),
                node.expr.args[0])
    else:
        if len(node.defs) == 0:
            # This node doesn't define any values, so just insert the new
            # instruction.
            builder = 'pos.ins()'
        else:
            src_def0 = node.defs[0].src_def
            if src_def0 and node.defs == src_def0.defs:
                # The replacement instruction defines the exact same values as
                # the source pattern. Unwrapping would have left the results
                # intact.
                # Replace the whole instruction.
                builder = 'let {} = pos.func.dfg.replace(inst)'.format(
                        wrap_tup(node.defs))
                replaced_inst = 'inst'
            else:
                # Insert a new instruction.
                builder = 'let {} = pos.ins()'.format(wrap_tup(node.defs))
                # We may want to reuse some of the detached output values.
                if len(node.defs) == 1 and node.defs[0].is_output():
                    # Reuse the single source result value.
                    builder += '.with_result({})'.format(node.defs[0])
                elif any(d.is_output() for d in node.defs):
                    # We have some output values to be reused.
                    array = ', '.join(
                            ('Some({})'.format(d) if d.is_output()
                                else 'None')
                            for d in node.defs)
                    builder += '.with_results([{}])'.format(array)

        fmt.line('{}.{};'.format(builder, node.expr.rust_builder(node.defs)))

    # If we just replaced an instruction, we need to bump the cursor so
    # following instructions are inserted *after* the replaced instruction.
    if replaced_inst:
        with fmt.indented(
                'if pos.current_inst() == Some({}) {{'
                .format(replaced_inst), '}'):
            fmt.line('pos.next_inst();')


def gen_xform(xform, fmt, type_sets):
    # type: (XForm, Formatter, UniqueTable) -> None
    """
    Emit code for `xform`, assuming that the opcode of xform's root instruction
    has already been matched.

    `inst: Inst` is the variable to be replaced. It is pointed to by `pos:
    Cursor`.
    `dfg: DataFlowGraph` is available and mutable.
    """
    # Unwrap the source instruction, create local variables for the input
    # variables.
    replace_inst = unwrap_inst('inst', xform.src.rtl[0], fmt)

    # Emit any runtime checks.
    # These will rebind `predicate` emitted by unwrap_inst().
    for check in get_runtime_typechecks(xform):
        emit_runtime_typecheck(check, fmt, type_sets)

    # Guard the actual expansion by `predicate`.
    with fmt.indented('if predicate {', '}'):
        # If we're going to delete `inst`, we need to detach its results first
        # so they can be reattached during pattern expansion.
        if not replace_inst:
            fmt.line('pos.func.dfg.clear_results(inst);')

        # Emit the destination pattern.
        for dst in xform.dst.rtl:
            emit_dst_inst(dst, fmt)

        # Delete the original instruction if we didn't have an opportunity to
        # replace it.
        if not replace_inst:
            fmt.line('let removed = pos.remove_inst();')
            fmt.line('debug_assert_eq!(removed, inst);')
        fmt.line('return true;')


def gen_xform_group(xgrp, fmt, type_sets):
    # type: (XFormGroup, Formatter, UniqueTable) -> None
    fmt.doc_comment("Legalize `inst`.")
    fmt.line('#[allow(unused_variables,unused_assignments,non_snake_case)]')
    with fmt.indented('pub fn {}('.format(xgrp.name)):
        fmt.line('inst: ir::Inst,')
        fmt.line('func: &mut ir::Function,')
        fmt.line('cfg: &mut ::flowgraph::ControlFlowGraph,')
        fmt.line('isa: &::isa::TargetIsa,')
    with fmt.indented(') -> bool {', '}'):
        fmt.line('use ir::InstBuilder;')
        fmt.line('use cursor::{Cursor, FuncCursor};')
        fmt.line('let mut pos = FuncCursor::new(func).at_inst(inst);')
        fmt.line('pos.use_srcloc(inst);')

        # Group the xforms by opcode so we can generate a big switch.
        # Preserve ordering.
        xforms = defaultdict(list)  # type: DefaultDict[str, List[XForm]]
        for xform in xgrp.xforms:
            inst = xform.src.rtl[0].expr.inst
            xforms[inst.camel_name].append(xform)

        with fmt.indented('{', '}'):
            with fmt.indented('match pos.func.dfg[inst].opcode() {', '}'):
                for camel_name in sorted(xforms.keys()):
                    with fmt.indented(
                            'ir::Opcode::{} => {{'.format(camel_name), '}'):
                        for xform in xforms[camel_name]:
                            gen_xform(xform, fmt, type_sets)

                # Emit the custom transforms. The Rust compiler will complain
                # about any overlap with the normal xforms.
                for inst, funcname in xgrp.custom.items():
                    with fmt.indented(
                            'ir::Opcode::{} => {{'
                            .format(inst.camel_name), '}'):
                        fmt.format('{}(inst, pos.func, cfg, isa);', funcname)
                        fmt.line('return true;')

                # We'll assume there are uncovered opcodes.
                fmt.line('_ => {},')

        # If we fall through, nothing was expanded. Call the chain if any.
        if xgrp.chain:
            fmt.format('{}(inst, pos.func, cfg, isa)', xgrp.chain.rust_name())
        else:
            fmt.line('false')


def gen_isa(isa, fmt, shared_groups):
    # type: (TargetISA, Formatter, Set[XFormGroup]) -> None
    """
    Generate legalization functions for `isa` and add any shared `XFormGroup`s
    encountered to `shared_groups`.

    Generate `TYPE_SETS` and `LEGALIZE_ACTION` tables.
    """
    type_sets = UniqueTable()
    for xgrp in isa.legalize_codes.keys():
        if xgrp.isa is None:
            shared_groups.add(xgrp)
        else:
            assert xgrp.isa == isa
            gen_xform_group(xgrp, fmt, type_sets)

    gen_typesets_table(fmt, type_sets)

    with fmt.indented(
            'pub static LEGALIZE_ACTIONS: [isa::Legalize; {}] = ['
            .format(len(isa.legalize_codes)), '];'):
        for xgrp in isa.legalize_codes.keys():
            fmt.format('{},', xgrp.rust_name())


def generate(isas, out_dir):
    # type: (Sequence[TargetISA], str) -> None
    shared_groups = set()  # type: Set[XFormGroup]

    for isa in isas:
        fmt = Formatter()
        gen_isa(isa, fmt, shared_groups)
        fmt.update_file('legalize-{}.rs'.format(isa.name), out_dir)

    # Shared xform groups.
    fmt = Formatter()
    type_sets = UniqueTable()
    for xgrp in sorted(shared_groups, key=lambda g: g.name):
        gen_xform_group(xgrp, fmt, type_sets)
    gen_typesets_table(fmt, type_sets)
    fmt.update_file('legalizer.rs', out_dir)