// vyre-emit-ptx 0.6.1
//
// PTX text emitter for vyre `KernelDescriptor`. Produces NVRTC-compatible CUDA assembly.
// Documentation
use std::fmt::Write as _;

use vyre_lower::{KernelOp, LiteralValue};

use super::BodyCtx;
use crate::reg::{PtxType, Reg};
use crate::EmitError;

impl BodyCtx<'_> {
    /// Appends the kernel epilogue to the emitted text.
    ///
    /// When `pending_cp_async_tags` is non-empty, a drain sequence (an
    /// explanatory comment line, `cp.async.wait_group 0;`, then `membar.cta;`)
    /// is written first and the pending tags are cleared. The `$L_exit` label
    /// and the final `ret;` are appended unconditionally.
    pub(super) fn finish_with_return(&mut self) {
        let has_pending = !self.pending_cp_async_tags.is_empty();
        if has_pending {
            let drain = [
                "    // implicit cp.async drain for descriptors missing AsyncWait",
                "    cp.async.wait_group 0;",
                "    membar.cta;",
            ];
            for line in drain {
                // The formatting result is deliberately discarded.
                let _ = writeln!(self.text, "{line}");
            }
            self.pending_cp_async_tags.clear();
        }

        for line in ["$L_exit:\n", "    ret;\n"] {
            self.text.push_str(line);
        }
    }

    /// Allocates a register for a literal and renders the literal's operand text.
    ///
    /// Returns the freshly allocated register (typed to match the literal's
    /// variant) together with the textual form of the value:
    ///
    /// * `U32` / `I32` — the value's `Display` output.
    /// * `F32` — `0f` followed by the value's `to_bits()` pattern as
    ///   zero-padded, eight-digit uppercase hexadecimal.
    /// * `Bool` — `"1"` for `true`, `"0"` for `false`.
    ///
    /// No instruction is written by this method itself; the caller decides how
    /// the returned text is used.
    pub(super) fn alloc_literal(&mut self, lit: &LiteralValue) -> (Reg, String) {
        // Pick the register type and render the text first, then allocate once.
        let (ty, text) = match lit {
            LiteralValue::U32(v) => (PtxType::U32, v.to_string()),
            LiteralValue::I32(v) => (PtxType::I32, v.to_string()),
            LiteralValue::F32(v) => (PtxType::F32, format!("0f{:08X}", v.to_bits())),
            LiteralValue::Bool(v) => {
                let digit = if *v { "1" } else { "0" };
                (PtxType::Bool, String::from(digit))
            }
        };
        (self.alloc(ty), text)
    }

    /// Looks up the register previously bound to operand id `op_id`.
    ///
    /// # Errors
    ///
    /// Returns [`EmitError::InvalidDescriptor`] when `operand_to_reg` has no
    /// entry for `op_id`.
    pub(super) fn lookup_operand(&self, op_id: u32) -> Result<Reg, EmitError> {
        match self.operand_to_reg.get(&op_id) {
            Some(&reg) => Ok(reg),
            None => Err(EmitError::InvalidDescriptor(format!(
                "operand id {op_id} not yet emitted"
            ))),
        }
    }

    /// Records `reg` as the register holding the result of `op`.
    ///
    /// If `op.result` is `None` there is nothing to bind and the call is a
    /// no-op. An existing binding for the same result id is overwritten.
    /// This method currently always returns `Ok(())`.
    pub(super) fn bind_result(&mut self, op: &KernelOp, reg: Reg) -> Result<(), EmitError> {
        let Some(result_id) = op.result else {
            return Ok(());
        };
        self.operand_to_reg.insert(result_id, reg);
        Ok(())
    }

    pub(super) fn bind_consecutive_results(
        &mut self,
        op: &KernelOp,
        regs: &[Reg],
    ) -> Result<(), EmitError> {
        let Some(base) = op.result else {
            return Err(EmitError::InvalidDescriptor(
                "MatrixMma missing base result id. Fix: set result to the first accumulator id."
                    .into(),
            ));
        };
        for (offset, reg) in regs.iter().copied().enumerate() {
            let id = base.checked_add(offset as u32).ok_or_else(|| {
                EmitError::InvalidDescriptor(
                    "MatrixMma result id range overflows u32. Fix: allocate a lower base id."
                        .into(),
                )
            })?;
            self.operand_to_reg.insert(id, reg);
        }
        Ok(())
    }
}