vyre-emit-ptx 0.6.1

PTX text emitter for vyre KernelDescriptor. Produces NVRTC-compatible CUDA assembly.
Documentation
use std::fmt::Write as _;

use smallvec::SmallVec;
use vyre_foundation::ir::BinOp;
use vyre_lower::{KernelBody, KernelOp, KernelOpKind};

use super::facts::EmitFacts;
use super::schedule::{
    is_latency_load, is_schedulable_pure_op, is_scheduling_fence, op_reads_operand,
    operand_is_immediate,
};
use super::BodyCtx;
use crate::reg::PtxType;
use crate::EmitError;

/// Upper bound on how many independent ops `emit_body` hoists into the gap
/// that follows a single latency load or vector-load chain.
const MAX_LOAD_GAP_FILLERS: usize = 3;

impl BodyCtx<'_> {
    /// Emits PTX text for every op in `body`, visiting `body.ops` in index order.
    ///
    /// For each op that has not already been consumed, the first applicable
    /// strategy is used:
    ///
    /// 1. a vector-load chain starting at this op, emitted as a unit and
    ///    followed by hoisted gap fillers;
    /// 2. a vector-store chain starting at this op;
    /// 3. an integer multiply deferred for later `mad` fusion;
    /// 4. an integer add fused with its multiply producer into `mad.lo`;
    /// 5. a plain `emit_op`, followed by hoisted gap fillers when the op is a
    ///    latency load.
    ///
    /// # Errors
    ///
    /// Returns the first `EmitError` reported by any of the emission helpers.
    pub(super) fn emit_body(&mut self, body: &KernelBody) -> Result<(), EmitError> {
        let facts = EmitFacts::new(body);
        // `skip[i]` is set once op `i` has been handled out of order: a
        // non-leading chain member, an op flagged by `mark_dead_vec_index_ops`,
        // a hoisted filler, or a deferred multiply.
        let mut skip = vec![false; body.ops.len()];
        for idx in 0..body.ops.len() {
            if skip[idx] {
                continue;
            }
            if let Some(chain) = self.collect_vec_load_chain(body, &facts, idx)? {
                self.emit_vec_load_chain(body, &chain)?;
                // The chain head is the current op; only the remaining members
                // need to be skipped when the walk reaches them.
                for &op_idx in chain.iter().skip(1) {
                    skip[op_idx] = true;
                }
                self.mark_dead_vec_index_ops(body, &facts, &chain, &mut skip)?;
                // Fillers must not read any value produced by the chain.
                let blocked_results = chain
                    .iter()
                    .filter_map(|op_idx| body.ops.get(*op_idx).and_then(|op| op.result))
                    .collect::<SmallVec<[u32; 4]>>();
                self.hoist_load_gap_fillers(body, idx, &blocked_results, "vector-load", &mut skip)?;
                continue;
            }
            if let Some(chain) = self.collect_vec_store_chain(body, &facts, idx)? {
                self.emit_vec_store_chain(body, &chain)?;
                for &op_idx in chain.iter().skip(1) {
                    skip[op_idx] = true;
                }
                self.mark_dead_vec_index_ops(body, &facts, &chain, &mut skip)?;
                continue;
            }
            if self.should_defer_integer_mul_for_mad(body, &facts, idx) {
                skip[idx] = true;
                continue;
            }
            if self.emit_integer_mad_from_add(body, &facts, &body.ops[idx])? {
                continue;
            }
            self.emit_op(body, &body.ops[idx])?;
            if is_latency_load(&body.ops[idx]) {
                // Fillers must not read the value this load produces.
                let blocked_results = body.ops[idx]
                    .result
                    .into_iter()
                    .collect::<SmallVec<[u32; 1]>>();
                self.hoist_load_gap_fillers(body, idx, &blocked_results, "load-use", &mut skip)?;
            }
        }
        Ok(())
    }

    /// Hoists up to `MAX_LOAD_GAP_FILLERS` independent ops into the gap after
    /// the load anchored at `anchor_idx`.
    ///
    /// Each filler is chosen by `find_latency_filler_avoiding_results`, announced
    /// with a `// schedule:` comment naming `gap_label`, emitted through
    /// `emit_op`, and marked in `skip` so the main walk does not emit it again.
    /// Stops early as soon as no further filler is found.
    fn hoist_load_gap_fillers(
        &mut self,
        body: &KernelBody,
        anchor_idx: usize,
        blocked_results: &[u32],
        gap_label: &str,
        skip: &mut [bool],
    ) -> Result<(), EmitError> {
        for _ in 0..MAX_LOAD_GAP_FILLERS {
            let Some(filler_idx) = self.find_latency_filler_avoiding_results(
                body,
                anchor_idx,
                blocked_results,
                skip,
            ) else {
                break;
            };
            // The write result is deliberately ignored, as for every other
            // write into `self.text` in this impl.
            let _ = writeln!(
                self.text,
                "    // schedule: hoist independent op#{filler_idx} into {gap_label} gap after op#{anchor_idx}"
            );
            self.emit_op(body, &body.ops[filler_idx])?;
            skip[filler_idx] = true;
        }
        Ok(())
    }

    /// Looks for the first op after `anchor_idx` that can be hoisted next to it.
    ///
    /// The search window covers indices `anchor_idx + 1` up to (but excluding)
    /// `anchor_idx + 10`, clamped to the end of `body.ops`. Within the window:
    ///
    /// * ops already marked in `skip` are ignored;
    /// * the search ends at the first non-skipped scheduling fence;
    /// * ops that read any id in `blocked_results` are passed over;
    /// * the first remaining op accepted by `is_ready_pure_op` wins.
    ///
    /// Returns `None` when `blocked_results` is empty or nothing qualifies.
    fn find_latency_filler_avoiding_results(
        &self,
        body: &KernelBody,
        anchor_idx: usize,
        blocked_results: &[u32],
        skip: &[bool],
    ) -> Option<usize> {
        if blocked_results.is_empty() {
            return None;
        }
        let window_end = body.ops.len().min(anchor_idx.saturating_add(10));
        (anchor_idx + 1..window_end)
            // Skipped ops are dropped before the fence test, so a fence that was
            // already consumed does not end the search.
            .filter(|&pos| !skip.get(pos).copied().unwrap_or(false))
            .map(|pos| (pos, &body.ops[pos]))
            .take_while(|&(_, op)| !is_scheduling_fence(op))
            .find(|&(_, op)| {
                let reads_blocked = blocked_results
                    .iter()
                    .any(|&blocked| op_reads_operand(op, blocked));
                !reads_blocked && self.is_ready_pure_op(op)
            })
            .map(|(pos, _)| pos)
    }

    /// Reports whether `op` can be emitted right now as a hoisted filler.
    ///
    /// All three conditions must hold:
    ///
    /// * `is_schedulable_pure_op` accepts the op;
    /// * its result, if any, has no entry in `operand_to_reg` yet;
    /// * every operand is either an immediate or already has an entry in
    ///   `operand_to_reg`.
    fn is_ready_pure_op(&self, op: &KernelOp) -> bool {
        is_schedulable_pure_op(op)
            && !op
                .result
                .map_or(false, |id| self.operand_to_reg.contains_key(&id))
            && op.operands.iter().all(|&id| {
                operand_is_immediate(op, id) || self.operand_to_reg.contains_key(&id)
            })
    }

    /// Decides whether the op at `op_idx` should be held back so a later add
    /// can absorb it into a `mad`.
    ///
    /// Placeholder: every argument is ignored and the answer is always `false`,
    /// so no op is ever deferred through this hook.
    fn should_defer_integer_mul_for_mad(
        &self,
        body: &KernelBody,
        facts: &EmitFacts,
        op_idx: usize,
    ) -> bool {
        // Explicit discards keep the unused parameters warning-free.
        let _ = body;
        let _ = facts;
        let _ = op_idx;
        false
    }

    /// Tries to emit `op` as a single `mad.lo` instruction.
    ///
    /// Applies only to a two-operand `Add` / `WrappingAdd`. Either operand may
    /// be the multiply result; the first ordering that `integer_mad_parts`
    /// accepts is used.
    ///
    /// Returns `Ok(true)` when the `mad.lo` line was written and the op's
    /// result bound to a freshly allocated register, and `Ok(false)` when the
    /// op does not qualify and nothing was emitted.
    ///
    /// # Errors
    ///
    /// Propagates any `EmitError` from `bind_result`.
    fn emit_integer_mad_from_add(
        &mut self,
        body: &KernelBody,
        facts: &EmitFacts,
        op: &KernelOp,
    ) -> Result<bool, EmitError> {
        let is_integer_add = matches!(
            op.kind,
            KernelOpKind::BinOpKind(BinOp::Add | BinOp::WrappingAdd)
        );
        if !is_integer_add || op.operands.len() != 2 {
            return Ok(false);
        }

        let (first, second) = (op.operands[0], op.operands[1]);
        // Try `first` as the multiply result, then the swapped ordering.
        let (a, b, c, ptx_suffix, out_ty) = match self.integer_mad_parts(body, facts, first, second)
        {
            Some(parts) => parts,
            None => match self.integer_mad_parts(body, facts, second, first) {
                Some(parts) => parts,
                None => return Ok(false),
            },
        };

        let out = self.alloc(out_ty);
        let _ = writeln!(
            self.text,
            "    mad.lo.{ptx_suffix}    {out}, {a}, {b}, {c};"
        );
        self.bind_result(op, out)?;
        Ok(true)
    }

    /// Resolves the pieces needed to fuse `mul_result_id + addend_id` into
    /// `mad.lo`.
    ///
    /// Succeeds only when all of the following hold:
    ///
    /// * `mul_result_id` has exactly one use according to `facts`;
    /// * its producer is a two-operand `BinOp::Mul` whose result is
    ///   `mul_result_id`;
    /// * both multiply operands and the addend already have registers;
    /// * those three registers share one type, and that type is `U32` or `I32`.
    ///
    /// On success returns `(a, b, c, suffix, out_ty)`, where `a` and `b` are the
    /// multiply operands, `c` is the addend, `suffix` is `"u32"` or `"s32"`, and
    /// `out_ty` is the matching register type.
    fn integer_mad_parts(
        &self,
        body: &KernelBody,
        facts: &EmitFacts,
        mul_result_id: u32,
        addend_id: u32,
    ) -> Option<(
        crate::reg::Reg,
        crate::reg::Reg,
        crate::reg::Reg,
        &'static str,
        PtxType,
    )> {
        if facts.result_use_count(mul_result_id) != 1 {
            return None;
        }
        let producer = facts
            .producer_idx(mul_result_id)
            .and_then(|producer_idx| body.ops.get(producer_idx))?;
        let is_fusable_mul = matches!(producer.kind, KernelOpKind::BinOpKind(BinOp::Mul))
            && producer.operands.len() == 2
            && producer.result == Some(mul_result_id);
        if !is_fusable_mul {
            return None;
        }

        let reg_for = |id: u32| self.operand_to_reg.get(&id).copied();
        let a = reg_for(producer.operands[0])?;
        let b = reg_for(producer.operands[1])?;
        let c = reg_for(addend_id)?;

        // All three inputs must share one register type.
        let same_type = a.0 == b.0 && a.0 == c.0;
        if !same_type {
            return None;
        }
        let (ptx_suffix, out_ty) = match a.0 {
            PtxType::U32 => ("u32", PtxType::U32),
            PtxType::I32 => ("s32", PtxType::I32),
            _ => return None,
        };
        Some((a, b, c, ptx_suffix, out_ty))
    }
}