vyre-emit-ptx 0.6.1

PTX text emitter for vyre KernelDescriptor. Produces NVRTC-compatible CUDA assembly.
Documentation
use std::fmt::Write as _;

use vyre_lower::{KernelBody, KernelOp, KernelOpKind};

use super::schedule::is_schedulable_pure_op;
use super::BodyCtx;
use crate::reg::PtxType;
use crate::EmitError;

/// Largest number of ops (inclusive) a structured-if child body may contain for the
/// emitter to attempt the predicated-store lowering (`emit_predicated_store_body`)
/// instead of branching around the body. For if/else, both child bodies must fit.
const MAX_PREDICATED_BODY_OPS: usize = 4;

impl BodyCtx<'_> {
    /// Emits a `// region: <generator>` marker comment, then the child body
    /// referenced by the op's first operand.
    ///
    /// If the op has no operands, or the operand is not a valid index into
    /// `body.child_bodies`, only the marker comment is emitted.
    pub(super) fn emit_region(
        &mut self,
        body: &KernelBody,
        op: &KernelOp,
        generator: &str,
    ) -> Result<(), EmitError> {
        let _ = writeln!(self.text, "    // region: {generator}");
        self.emit_child_of_first_operand(body, op)
    }

    /// Emits the child body referenced by the op's first operand inline,
    /// without any surrounding control flow.
    ///
    /// A missing operand or an out-of-range child index emits nothing.
    pub(super) fn emit_structured_block(
        &mut self,
        body: &KernelBody,
        op: &KernelOp,
    ) -> Result<(), EmitError> {
        self.emit_child_of_first_operand(body, op)
    }

    /// Emits `body.child_bodies[op.operands[0]]` when both the operand and the
    /// child body exist; otherwise does nothing.
    fn emit_child_of_first_operand(
        &mut self,
        body: &KernelBody,
        op: &KernelOp,
    ) -> Result<(), EmitError> {
        match op
            .operands
            .first()
            .and_then(|&child_id| body.child_bodies.get(child_id as usize))
        {
            Some(child) => self.emit_body(child),
            None => Ok(()),
        }
    }

    /// Emits `if (cond) { body }`.
    ///
    /// Operands: `[cond, body_index]`.
    ///
    /// A small body is first tried as predicated stores, with no branch. It
    /// qualifies when it has at most `MAX_PREDICATED_BODY_OPS` ops and passes
    /// `predicated_store_body_supported`. Otherwise, or if that attempt reports
    /// failure, a `@!p bra if_end` is emitted around the body.
    ///
    /// # Errors
    /// Returns `EmitError::InvalidDescriptor` if either operand is missing,
    /// and propagates errors from operand lookup and child emission.
    pub(super) fn emit_structured_if_then(
        &mut self,
        body: &KernelBody,
        op: &KernelOp,
    ) -> Result<(), EmitError> {
        let cond_id = *op
            .operands
            .first()
            .ok_or_else(|| EmitError::InvalidDescriptor("StructuredIfThen missing cond".into()))?;
        let body_id = *op.operands.get(1).ok_or_else(|| {
            EmitError::InvalidDescriptor("StructuredIfThen missing body index".into())
        })?;
        let cond_reg = self.lookup_operand(cond_id)?;
        let child = body.child_bodies.get(body_id as usize);
        if let Some(child) = child {
            // Check eligibility *before* materialising a predicate, so an
            // unsupported body does not leave behind a dead predicate conversion
            // in front of the branch fallback.
            if child.ops.len() <= MAX_PREDICATED_BODY_OPS && predicated_store_body_supported(child)
            {
                let pred = self.pred_from_boolish(cond_reg);
                if self.emit_predicated_store_body(child, pred, false)? {
                    return Ok(());
                }
            }
        }
        let branch_pred = self.pred_from_boolish(cond_reg);
        let end_label = self.alloc_label("if_end");
        let _ = writeln!(self.text, "    @!{branch_pred} bra {end_label};");
        if let Some(child) = child {
            self.emit_body(child)?;
        }
        let _ = writeln!(self.text, "{end_label}:");
        Ok(())
    }

    /// Emits `if (cond) { then } else { else }`.
    ///
    /// Operands: `[cond, then_index, else_index]`.
    ///
    /// When both bodies are small and eligible, each is lowered to predicated
    /// stores: `then` guarded by `pred`, `else` by `!pred`. If the `then` side
    /// reports failure, the whole construct falls back to branches. If only the
    /// `else` side reports failure, the already-predicated `then` side is kept.
    /// Only the `else` body is then emitted behind a `@p bra`. This avoids
    /// emitting the `then` stores a second time.
    ///
    /// # Errors
    /// Returns `EmitError::InvalidDescriptor` if any operand is missing, and
    /// propagates errors from operand lookup and child emission.
    pub(super) fn emit_structured_if_then_else(
        &mut self,
        body: &KernelBody,
        op: &KernelOp,
    ) -> Result<(), EmitError> {
        let cond_id = *op.operands.first().ok_or_else(|| {
            EmitError::InvalidDescriptor("StructuredIfThenElse missing cond".into())
        })?;
        let then_id = *op.operands.get(1).ok_or_else(|| {
            EmitError::InvalidDescriptor("StructuredIfThenElse missing then index".into())
        })?;
        let else_id = *op.operands.get(2).ok_or_else(|| {
            EmitError::InvalidDescriptor("StructuredIfThenElse missing else index".into())
        })?;
        let cond_reg = self.lookup_operand(cond_id)?;
        let then_child = body.child_bodies.get(then_id as usize);
        let else_child = body.child_bodies.get(else_id as usize);
        if let (Some(then_body), Some(else_body)) = (then_child, else_child) {
            if then_body.ops.len() <= MAX_PREDICATED_BODY_OPS
                && else_body.ops.len() <= MAX_PREDICATED_BODY_OPS
                && predicated_store_body_supported(then_body)
                && predicated_store_body_supported(else_body)
            {
                let pred = self.pred_from_boolish(cond_reg);
                // Only attempt the else side once the then side succeeded;
                // otherwise the full branch fallback below handles both.
                if self.emit_predicated_store_body(then_body, pred, false)? {
                    if self.emit_predicated_store_body(else_body, pred, true)? {
                        return Ok(());
                    }
                    // The then side is already emitted under `pred`. Re-emitting
                    // it in the full fallback would duplicate its stores, so only
                    // the else side is emitted, skipped when `pred` holds.
                    let end_label = self.alloc_label("if_end");
                    let _ = writeln!(self.text, "    @{pred} bra {end_label};");
                    self.emit_body(else_body)?;
                    let _ = writeln!(self.text, "{end_label}:");
                    return Ok(());
                }
            }
        }
        let branch_pred = self.pred_from_boolish(cond_reg);
        let else_label = self.alloc_label("if_else");
        let end_label = self.alloc_label("if_end");
        let _ = writeln!(self.text, "    @!{branch_pred} bra {else_label};");
        if let Some(child) = then_child {
            self.emit_body(child)?;
        }
        let _ = writeln!(self.text, "    bra {end_label};");
        let _ = writeln!(self.text, "{else_label}:");
        if let Some(child) = else_child {
            self.emit_body(child)?;
        }
        let _ = writeln!(self.text, "{end_label}:");
        Ok(())
    }

    /// Emits `child` with its `StoreGlobal`/`StoreShared` ops routed through
    /// `emit_predicated_store`, guarded by `pred`, or by its negation when
    /// `negate` is true. All other ops go through `emit_op` unguarded.
    ///
    /// Returns `Ok(false)` without emitting anything if
    /// `predicated_store_body_supported(child)` is false. Otherwise, returns
    /// whether at least one `emit_predicated_store` call returned `true`.
    ///
    /// NOTE(review): the per-store results are OR-ed together. If
    /// `emit_predicated_store` returning `false` means "store not emitted", a
    /// body mixing successes and failures would lose the failed stores. Confirm
    /// that function's contract.
    fn emit_predicated_store_body(
        &mut self,
        child: &KernelBody,
        pred: crate::reg::Reg,
        negate: bool,
    ) -> Result<bool, EmitError> {
        if !predicated_store_body_supported(child) {
            return Ok(false);
        }
        let mut emitted_store = false;
        for op in &child.ops {
            if matches!(
                op.kind,
                KernelOpKind::StoreGlobal | KernelOpKind::StoreShared
            ) {
                emitted_store |= self.emit_predicated_store(op, pred, negate)?;
            } else {
                self.emit_op(child, op)?;
            }
        }
        Ok(emitted_store)
    }

    /// Emits a counted loop `for loop_var in [lo, hi)` with a `u32` induction
    /// variable, stepping by one and comparing with unsigned `setp.ge.u32`.
    ///
    /// Operands: `[lo, hi, body_index]`.
    ///
    /// While the body is emitted, `loop_var` is bound to the induction register
    /// in `self.loop_indices`, so `LoopIndex` ops can resolve it. Any binding
    /// the name had before the loop, such as an outer loop using the same
    /// variable name, is restored afterwards rather than deleted.
    ///
    /// # Errors
    /// Returns `EmitError::InvalidDescriptor` if any operand is missing, and
    /// propagates errors from operand lookup and body emission.
    pub(super) fn emit_structured_for_loop(
        &mut self,
        body: &KernelBody,
        op: &KernelOp,
        loop_var: &str,
    ) -> Result<(), EmitError> {
        let lo_id = *op
            .operands
            .first()
            .ok_or_else(|| EmitError::InvalidDescriptor("StructuredForLoop missing lo".into()))?;
        let hi_id = *op
            .operands
            .get(1)
            .ok_or_else(|| EmitError::InvalidDescriptor("StructuredForLoop missing hi".into()))?;
        let body_id = *op.operands.get(2).ok_or_else(|| {
            EmitError::InvalidDescriptor("StructuredForLoop missing body index".into())
        })?;
        let lo_reg = self.lookup_operand(lo_id)?;
        let hi_reg = self.lookup_operand(hi_id)?;
        let var_reg = self.alloc(PtxType::U32);
        let cond_reg = self.alloc(PtxType::Bool);
        let one_reg = self.alloc(PtxType::U32);
        let head = self.alloc_label("for_head");
        let exit = self.alloc_label("for_exit");
        let _ = writeln!(self.text, "    // for {loop_var} in [{lo_reg}, {hi_reg})");
        let _ = writeln!(self.text, "    mov.u32    {var_reg}, {lo_reg};");
        let _ = writeln!(self.text, "    mov.u32    {one_reg}, 1;");
        let _ = writeln!(self.text, "{head}:");
        let _ = writeln!(
            self.text,
            "    setp.ge.u32 {cond_reg}, {var_reg}, {hi_reg};"
        );
        let _ = writeln!(self.text, "    @{cond_reg} bra {exit};");
        // Remember any enclosing binding of the same name so nested loops that
        // shadow `loop_var` do not clobber the outer loop's index.
        let outer_binding = self.loop_indices.insert(loop_var.into(), var_reg);
        if let Some(child) = body.child_bodies.get(body_id as usize) {
            self.emit_body(child)?;
        }
        match outer_binding {
            Some(outer_reg) => {
                self.loop_indices.insert(loop_var.into(), outer_reg);
            }
            None => {
                self.loop_indices.remove(loop_var);
            }
        }
        let _ = writeln!(self.text, "    add.u32    {var_reg}, {var_reg}, {one_reg};");
        let _ = writeln!(self.text, "    bra {head};");
        let _ = writeln!(self.text, "{exit}:");
        Ok(())
    }

    /// Binds `op`'s result to the induction register of the innermost active
    /// `StructuredForLoop` named `loop_var`.
    ///
    /// # Errors
    /// Returns `EmitError::InvalidDescriptor` if no loop with that name is
    /// currently being emitted.
    pub(super) fn emit_loop_index(
        &mut self,
        op: &KernelOp,
        loop_var: &str,
    ) -> Result<(), EmitError> {
        let reg = *self.loop_indices.get(loop_var).ok_or_else(|| {
            EmitError::InvalidDescriptor(format!(
                "LoopIndex `{loop_var}` appeared outside its StructuredForLoop"
            ))
        })?;
        self.bind_result(op, reg)
    }
}

/// Decides whether `body` may be lowered to predicated stores.
///
/// Returns `true` when the body has at least one `StoreGlobal`/`StoreShared`
/// op and every other op satisfies `is_schedulable_pure_op`. The scan stops at
/// the first ineligible op.
fn predicated_store_body_supported(body: &KernelBody) -> bool {
    // Accumulator: "seen a store so far". `None` aborts the scan on the first
    // op that is neither a store nor schedulable-pure.
    body.ops
        .iter()
        .try_fold(false, |seen_store, op| {
            let is_store = matches!(
                op.kind,
                KernelOpKind::StoreGlobal | KernelOpKind::StoreShared
            );
            if is_store {
                Some(true)
            } else if is_schedulable_pure_op(op) {
                Some(seen_store)
            } else {
                None
            }
        })
        .unwrap_or(false)
}