vyre-emit-ptx 0.6.1

PTX text emitter for vyre KernelDescriptor. Produces NVRTC-compatible CUDA assembly.
Documentation
use std::fmt::Write as _;

use vyre_lower::{KernelOp, MatrixMmaElement, MatrixMmaLayout, MatrixMmaShape};

use super::BodyCtx;
use crate::reg::{PtxType, Reg};
use crate::EmitError;

impl BodyCtx<'_> {
    pub(super) fn emit_matrix_mma(
        &mut self,
        op: &KernelOp,
        shape: MatrixMmaShape,
        a_layout: MatrixMmaLayout,
        b_layout: MatrixMmaLayout,
        a_type: MatrixMmaElement,
        b_type: MatrixMmaElement,
        accum_type: MatrixMmaElement,
    ) -> Result<[Reg; 4], EmitError> {
        if op.operands.len() < 10 {
            return Err(EmitError::InvalidDescriptor(format!(
                "MatrixMma requires 10 operands but got {}. Fix: pass four A fragment regs, two B fragment regs, and four accumulator regs.",
                op.operands.len()
            )));
        }
        if shape != MatrixMmaShape::M16N8K16
            || a_layout != MatrixMmaLayout::RowMajor
            || b_layout != MatrixMmaLayout::ColMajor
            || a_type != MatrixMmaElement::F16
            || b_type != MatrixMmaElement::F16
            || accum_type != MatrixMmaElement::F32
        {
            return Err(EmitError::UnsupportedOp(KernelOp {
                kind: op.kind.clone(),
                operands: op.operands.clone(),
                result: op.result,
            }));
        }
        if !self.options.target.supports_wmma_f16() {
            return Err(EmitError::UnsupportedOp(KernelOp {
                kind: op.kind.clone(),
                operands: op.operands.clone(),
                result: op.result,
            }));
        }

        let a = [
            self.lookup_operand(op.operands[0])?,
            self.lookup_operand(op.operands[1])?,
            self.lookup_operand(op.operands[2])?,
            self.lookup_operand(op.operands[3])?,
        ];
        let b = [
            self.lookup_operand(op.operands[4])?,
            self.lookup_operand(op.operands[5])?,
        ];
        let c = [
            self.lookup_operand(op.operands[6])?,
            self.lookup_operand(op.operands[7])?,
            self.lookup_operand(op.operands[8])?,
            self.lookup_operand(op.operands[9])?,
        ];
        for reg in a.iter().chain(b.iter()) {
            if reg.0 != PtxType::U32 {
                return Err(EmitError::InvalidDescriptor(format!(
                    "MatrixMma f16 fragments must be packed u32 registers; got {reg}. Fix: pack two f16 lanes per u32 fragment operand."
                )));
            }
        }
        for reg in &c {
            if reg.0 != PtxType::F32 {
                return Err(EmitError::InvalidDescriptor(format!(
                    "MatrixMma f32 accumulators must be f32 registers; got {reg}. Fix: pass f32 accumulator operands."
                )));
            }
        }

        let d = [
            self.alloc(PtxType::F32),
            self.alloc(PtxType::F32),
            self.alloc(PtxType::F32),
            self.alloc(PtxType::F32),
        ];
        let _ = writeln!(
            self.text,
            "    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32    {{{}, {}, {}, {}}}, {{{}, {}, {}, {}}}, {{{}, {}}}, {{{}, {}, {}, {}}};",
            d[0],
            d[1],
            d[2],
            d[3],
            a[0],
            a[1],
            a[2],
            a[3],
            b[0],
            b[1],
            c[0],
            c[1],
            c[2],
            c[3],
        );
        Ok(d)
    }
}