vyre-emit-ptx 0.6.1

PTX text emitter for vyre KernelDescriptor. Produces NVRTC-compatible CUDA assembly.
Documentation
use std::fmt::Write as _;

use rustc_hash::FxHashSet;
use vyre_lower::{KernelDescriptor, MemoryClass, TRAP_SIDECAR_NAME};

use super::names::sanitize_param_name;
use super::sizing::{body_op_count_recursive, estimate_body_text_capacity};
use super::BodyCtx;
use crate::{EmitError, PtxEmitOptions};

/// Accumulates the text of a single PTX module.
///
/// Output is appended to an internal `String`; `finish` consumes the
/// builder and hands that buffer back to the caller.
pub(super) struct ModuleBuilder {
    /// Emit options. `write_preamble` reads `target` from here and
    /// `write_entry_point` forwards the whole value to `BodyCtx::new`.
    options: PtxEmitOptions,
    /// Module text emitted so far.
    text: String,
}

impl ModuleBuilder {
    /// Creates an empty builder whose output buffer is pre-allocated with
    /// room for `text_capacity` bytes.
    pub(super) fn new(options: PtxEmitOptions, text_capacity: usize) -> Self {
        let text = String::with_capacity(text_capacity);
        Self { options, text }
    }

    /// Appends the module header to the output buffer: a "generated by"
    /// banner followed by the `.version`, `.target` and `.address_size 64`
    /// directives and one blank line.
    ///
    /// The `.version` value is derived from `self.options.target` so that
    /// the emitted PTX ISA version is high enough to name that target.
    pub(super) fn write_preamble(&mut self) {
        let target = self.options.target;
        // PTX `.version` must rise alongside `.target` because each
        // new sm_NN ISA bumps the minimum PTX ISA version that can
        // name it. Driver/JIT rejects mismatched pairs with
        // CUDA_ERROR_INVALID_PTX. Minimum versions per the NVIDIA PTX
        // ISA release notes (NOTE(review): re-check against the current
        // release notes whenever a target is added):
        //   sm_121                      → PTX 8.8
        //   sm_120 (Blackwell-2)        → PTX 8.7
        //   sm_110                      → PTX 9.0
        //   sm_103                      → PTX 8.8
        //   sm_100 / sm_101 (Blackwell) → PTX 8.6
        //   sm_90  (Hopper)             → PTX 8.0 is sufficient
        //   sm_88                       → PTX 9.0
        //   other sm_70..sm_89          → PTX 6.0+ (8.5 covers all)
        // The table is deliberately not monotonic in the sm number:
        // targets were not introduced in numeric order. A code that is
        // not listed inherits the version of the nearest lower listed
        // target covered by the same arm.
        let ptx_version = match target.major * 10 + target.minor {
            121.. => "8.8",
            120 => "8.7",
            110..=119 => "9.0",
            103..=109 => "8.8",
            100..=102 => "8.6",
            90..=99 => "8.0",
            88 => "9.0",
            _ => "8.5",
        };
        // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
        let _ = writeln!(
            self.text,
            "//\n// Generated by vyre-emit-ptx (target sm_{}{})\n//\n.version {ptx_version}\n.target sm_{}{}\n.address_size 64\n",
            target.major, target.minor, target.major, target.minor
        );
    }

    /// Emits the complete `main` entry for `desc`: the parameter list, the
    /// register declarations and the kernel body, closed by `}`.
    ///
    /// One `.param .u64 _arg_<name>` is emitted per binding slot, skipping
    /// slots whose memory class is `MemoryClass::Shared` and the slot named
    /// `TRAP_SIDECAR_NAME`. A final `params_buf` parameter is always emitted.
    ///
    /// # Errors
    ///
    /// Propagates any `EmitError` returned by `BodyCtx::preload_bindings` or
    /// `BodyCtx::emit_body`. The parameter list already appended to the
    /// buffer is left in place when that happens.
    pub(super) fn write_entry_point(&mut self, desc: &KernelDescriptor) -> Result<(), EmitError> {
        self.text.push_str(".visible .entry main(\n");
        // `params_buf` always closes the list, so every binding parameter
        // is followed by another one and can carry its own trailing comma.
        for binding in &desc.bindings.slots {
            let is_shared = matches!(binding.memory_class, MemoryClass::Shared);
            if is_shared || binding.name == TRAP_SIDECAR_NAME {
                continue;
            }
            let _ = writeln!(
                self.text,
                "    .param .u64 _arg_{},",
                sanitize_param_name(&binding.name, binding.slot)
            );
        }
        self.text.push_str("    .param .u64 params_buf\n) {\n");

        // Binding slots reported as candidates by the texture-promote analysis.
        let promote_candidates = vyre_lower::analyses::analyze_texture_promote(desc).candidates;
        let read_only_cache_slots: FxHashSet<_> = promote_candidates
            .into_iter()
            .map(|candidate| candidate.binding_slot)
            .collect();

        let body_capacity = estimate_body_text_capacity(&desc.body, &desc.bindings);
        let op_count = body_op_count_recursive(&desc.body);
        let mut body_ctx = BodyCtx::new(
            &desc.bindings,
            self.options,
            read_only_cache_slots,
            body_capacity,
            op_count,
        );
        body_ctx.preload_bindings(desc)?;
        body_ctx.emit_body(&desc.body)?;
        body_ctx.finish_with_return();

        // The register counters are read only after the body has been
        // emitted; the declarations are written here so they land ahead
        // of the body text. Register classes with a zero count are omitted.
        self.text.push_str("    // register declarations\n");
        if body_ctx.next_pred > 0 {
            let _ = writeln!(self.text, "    .reg .pred  %p<{}>;", body_ctx.next_pred);
        }
        if body_ctx.next_b16 > 0 {
            let _ = writeln!(self.text, "    .reg .b16   %h<{}>;", body_ctx.next_b16);
        }
        if body_ctx.next_u32 > 0 {
            let _ = writeln!(self.text, "    .reg .u32   %r<{}>;", body_ctx.next_u32);
        }
        if body_ctx.next_i32 > 0 {
            let _ = writeln!(self.text, "    .reg .s32   %s<{}>;", body_ctx.next_i32);
        }
        if body_ctx.next_f32 > 0 {
            let _ = writeln!(self.text, "    .reg .f32   %f<{}>;", body_ctx.next_f32);
        }
        if body_ctx.next_u64 > 0 {
            let _ = writeln!(self.text, "    .reg .u64   %rd<{}>;", body_ctx.next_u64);
        }
        self.text.push('\n');

        self.text.push_str(&body_ctx.text);
        self.text.push_str("}\n");
        Ok(())
    }

    pub(super) fn finish(self) -> String {
        self.text
    }
}