//! trueno-gpu 0.4.33
//!
//! Pure Rust PTX generation for NVIDIA CUDA - no LLVM, no nvcc.
//! PTX Module, Kernel, and KernelParam types.
//!
//! Extracted from mod.rs for PMAT File Health compliance.

use std::fmt::Write;

use super::emit::write_instruction;
use crate::error::Result;
use crate::ptx::instructions::PtxInstruction;
use crate::ptx::registers::RegisterAllocator;
use crate::ptx::types::PtxType;
use crate::ptx::{validate_target, validate_version};

use super::KernelBuilder;

/// PTX Module builder
///
/// Collects version/target/address-size directives and a list of kernels,
/// then emits a complete PTX translation unit via [`PtxModule::emit`].
#[derive(Debug, Clone)]
pub struct PtxModule {
    /// PTX ISA version as (major, minor); defaults to (8, 0)
    pub(crate) version: (u32, u32),
    /// Target compute capability (e.g., `sm_70`)
    target: String,
    /// Address size in bits (32 or 64); defaults to 64
    address_size: u32,
    /// Kernels emitted into this module, in insertion order
    kernels: Vec<PtxKernel>,
}

impl PtxModule {
    /// Create a new PTX module with defaults: PTX 8.0, `sm_70`, 64-bit addressing.
    #[must_use]
    pub fn new() -> Self {
        Self { version: (8, 0), target: "sm_70".to_string(), address_size: 64, kernels: Vec::new() }
    }

    /// Set PTX version (builder-style).
    #[must_use]
    pub fn version(mut self, major: u32, minor: u32) -> Self {
        self.version = (major, minor);
        self
    }

    /// Get PTX version as (major, minor).
    #[must_use]
    pub const fn get_version(&self) -> (u32, u32) {
        self.version
    }

    /// Set target compute capability, e.g. `"sm_80"` (builder-style).
    #[must_use]
    pub fn target(mut self, target: impl Into<String>) -> Self {
        self.target = target.into();
        self
    }

    /// Get target compute capability string.
    #[must_use]
    pub fn get_target(&self) -> &str {
        &self.target
    }

    /// Set address size in bits (builder-style; expected values are 32 or 64).
    #[must_use]
    pub const fn address_size(mut self, size: u32) -> Self {
        self.address_size = size;
        self
    }

    /// Get address size in bits.
    #[must_use]
    pub const fn get_address_size(&self) -> u32 {
        self.address_size
    }

    /// Add a kernel to the module (builder-style).
    #[must_use]
    pub fn add_kernel(mut self, kernel: PtxKernel) -> Self {
        self.kernels.push(kernel);
        self
    }

    /// Validate the module configuration
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The PTX version is below the minimum supported (7.0)
    /// - The target compute capability is invalid
    pub fn validate(&self) -> Result<()> {
        validate_version(self.version.0, self.version.1)?;
        validate_target(&self.target)?;
        Ok(())
    }

    /// Emit PTX source code for the whole module: header comment, the
    /// `.version` / `.target` / `.address_size` directives, then each kernel.
    #[must_use]
    pub fn emit(&self) -> String {
        // Pre-allocate, mirroring PtxKernel::emit: ~256 bytes of header and
        // directives plus a rough per-kernel estimate, to avoid repeated growth.
        let mut ptx = String::with_capacity(256 + self.kernels.len() * 512);

        // Header comment
        ptx.push_str("// Generated by trueno-gpu\n");
        ptx.push_str("// Pure Rust PTX generation - no external dependencies\n\n");

        // Version directive
        let _ = writeln!(ptx, ".version {}.{}", self.version.0, self.version.1);

        // Target directive
        let _ = writeln!(ptx, ".target {}", self.target);

        // Address size directive (trailing \n leaves a blank line before kernels)
        let _ = writeln!(ptx, ".address_size {}\n", self.address_size);

        // Emit each kernel, separated by a blank line
        for kernel in &self.kernels {
            ptx.push_str(&kernel.emit());
            ptx.push('\n');
        }

        ptx
    }
}

impl Default for PtxModule {
    fn default() -> Self {
        Self::new()
    }
}

/// Kernel parameter
///
/// One `.param` entry in a kernel's signature; emitted by
/// `PtxKernel::emit` as `.param <type> <name>`.
#[derive(Debug, Clone)]
pub struct KernelParam {
    /// Parameter type (rendered via `PtxType::to_ptx_string`)
    pub ty: PtxType,
    /// Parameter name as it appears in the PTX signature
    pub name: String,
}

/// PTX Kernel builder
///
/// Accumulates a kernel's signature, directives, and instruction body;
/// rendered to a `.visible .entry` definition by [`PtxKernel::emit`].
#[derive(Debug, Clone)]
pub struct PtxKernel {
    /// Kernel name (used in the `.visible .entry` line)
    name: String,
    /// Parameters, emitted in order as `.param` declarations
    pub(crate) params: Vec<KernelParam>,
    /// Shared memory size in bytes; 0 means no `.shared` declaration is emitted
    shared_memory: usize,
    /// Maximum registers per thread (`.maxnreg` directive)
    ///
    /// When set, ptxas allocates up to this many registers per thread,
    /// potentially limiting occupancy but maximizing ILP.
    /// llama.cpp uses `__launch_bounds__(128, 1)` which achieves the same
    /// effect — 1 block/SM means max registers available.
    max_regs: Option<u32>,
    /// Instruction body, collected via `build`/`build_optimized`
    instructions: Vec<PtxInstruction>,
    /// Register allocator; its declarations are emitted before the body
    registers: RegisterAllocator,
    /// Labels collected from the builder closure
    labels: Vec<String>,
}

impl PtxKernel {
    /// Create a new kernel
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            params: Vec::new(),
            shared_memory: 0,
            max_regs: None,
            instructions: Vec::new(),
            registers: RegisterAllocator::new(),
            labels: Vec::new(),
        }
    }

    /// Add a parameter
    #[must_use]
    pub fn param(mut self, ty: PtxType, name: impl Into<String>) -> Self {
        self.params.push(KernelParam { ty, name: name.into() });
        self
    }

    /// Set shared memory size
    #[must_use]
    pub const fn shared_memory(mut self, bytes: usize) -> Self {
        self.shared_memory = bytes;
        self
    }

    /// Get shared memory size
    #[must_use]
    pub const fn shared_memory_bytes(&self) -> usize {
        self.shared_memory
    }

    /// Set maximum registers per thread (`.maxnreg` directive)
    ///
    /// Tells ptxas to use up to `n` registers per thread. Higher values
    /// reduce occupancy but increase ILP and register availability.
    /// Use 255 to match llama.cpp's `__launch_bounds__(N, 1)` strategy.
    #[must_use]
    pub const fn max_regs(mut self, n: u32) -> Self {
        self.max_regs = Some(n);
        self
    }

    /// Build kernel body with a closure
    #[must_use]
    pub fn build<F>(mut self, builder_fn: F) -> Self
    where
        F: FnOnce(&mut KernelBuilder<'_>),
    {
        let mut builder = KernelBuilder::new(&mut self.registers);
        builder_fn(&mut builder);
        self.instructions = builder.instructions;
        self.labels = builder.labels;
        self
    }

    /// Build kernel body with optimization passes (Issue #72, #73)
    ///
    /// Applies FMA fusion and tile validation passes to the instruction sequence.
    ///
    /// # Arguments
    ///
    /// * `builder_fn` - Closure that builds the kernel body
    ///
    /// # Returns
    ///
    /// Result containing the kernel or an error if tile validation fails
    ///
    /// # cuda-tile-behavior.md References
    ///
    /// - Section 3.5: FMA Fusion Detection
    /// - Section 3.4: Tile Dimension Constraints
    pub fn build_optimized<F>(mut self, builder_fn: F) -> crate::error::Result<Self>
    where
        F: FnOnce(&mut KernelBuilder<'_>),
    {
        let mut builder = KernelBuilder::new(&mut self.registers);
        builder_fn(&mut builder);

        // Apply optimization passes (FMA fusion + tile validation)
        self.instructions = crate::ptx::optimize::optimize(builder.instructions)?;
        self.labels = builder.labels;
        Ok(self)
    }

    /// Emit kernel PTX
    #[must_use]
    pub fn emit(&self) -> String {
        use std::fmt::Write;
        // Pre-allocate with estimated size: ~100 bytes per instruction + header overhead
        let estimated_size = 512 + self.instructions.len() * 100;
        let mut ptx = String::with_capacity(estimated_size);

        // Kernel entry point
        let _ = writeln!(ptx, ".visible .entry {}(", self.name);

        // Parameters
        for (i, param) in self.params.iter().enumerate() {
            let comma = if i < self.params.len() - 1 { "," } else { "" };
            let _ =
                writeln!(ptx, "    .param {} {}{}", param.ty.to_ptx_string(), param.name, comma);
        }

        // Performance directives go between closing paren and opening brace
        if let Some(max) = self.max_regs {
            let _ = writeln!(ptx, ") .maxnreg {} {{", max);
        } else {
            ptx.push_str(") {\n");
        }

        // Register declarations
        ptx.push_str(&self.registers.emit_declarations());

        // Shared memory declaration (if any)
        if self.shared_memory > 0 {
            let _ = writeln!(ptx, "    .shared .align 16 .b8 smem[{}];", self.shared_memory);
        }

        ptx.push('\n');

        // Instructions - write directly to ptx buffer to avoid allocations
        for instr in &self.instructions {
            write_instruction(instr, &mut ptx);
        }

        ptx.push_str("}\n");
        ptx
    }
}