Skip to main content

BodyBuilder

Struct BodyBuilder 

Source
pub struct BodyBuilder<'a> { /* private fields */ }
Expand description

Instruction emission API for building the body of a PTX kernel.

BodyBuilder is not constructed directly — it is provided as a mutable reference inside the closure passed to KernelBuilder::body.

Most methods follow a consistent pattern: allocate destination register(s), push the corresponding Instruction variant, and return the destination register so it can be used as an operand to subsequent instructions.

Implementations§

Source§

impl BodyBuilder<'_>

Source

pub fn atom_global_add_f32(&mut self, addr: Register, val: Register) -> Register

Atomic add on global memory (f32): returns old value at [addr], stores old + val.

Source

pub fn atom_global_add_u32(&mut self, addr: Register, val: Register) -> Register

Atomic add on global memory (u32): returns old value at [addr], stores old + val.

Source

pub fn atom_global_add_u64(&mut self, addr: Register, val: Register) -> Register

Atomic add on global memory (u64): returns old value at [addr], stores old + val.

Source

pub fn atom_global_add_f64(&mut self, addr: Register, val: Register) -> Register

Atomic add on global memory (f64): returns old value at [addr], stores old + val.

Source

pub fn atom_global_cas_u32( &mut self, addr: Register, compare: Register, value: Register, ) -> Register

Atomic compare-and-swap on global memory (u32).

If [addr] == compare, stores value. Returns the old value at [addr].

Source

pub fn atom_global_cas_u64( &mut self, addr: Register, compare: Register, value: Register, ) -> Register

Atomic compare-and-swap on global memory (u64).

If [addr] == compare, stores value. Returns the old value at [addr].

Source

pub fn atom_global_exch_u32( &mut self, addr: Register, val: Register, ) -> Register

Atomic exchange on global memory (u32): stores val, returns old value.

Source

pub fn atom_global_min_u32(&mut self, addr: Register, val: Register) -> Register

Atomic min on global memory (u32).

Source

pub fn atom_global_max_u32(&mut self, addr: Register, val: Register) -> Register

Atomic max on global memory (u32).

Source

pub fn atom_global_min_s32(&mut self, addr: Register, val: Register) -> Register

Atomic min on global memory (s32).

Source

pub fn atom_global_max_s32(&mut self, addr: Register, val: Register) -> Register

Atomic max on global memory (s32).

Source

pub fn atom_global_and_b32(&mut self, addr: Register, val: Register) -> Register

Atomic bitwise AND on global memory (b32).

Source

pub fn atom_global_or_b32(&mut self, addr: Register, val: Register) -> Register

Atomic bitwise OR on global memory (b32).

Source

pub fn atom_global_xor_b32(&mut self, addr: Register, val: Register) -> Register

Atomic bitwise XOR on global memory (b32).

Source

pub fn atom_shared_add_f32(&mut self, addr: Register, val: Register) -> Register

Atomic add on shared memory (f32): critical for block-level reductions.

Source

pub fn atom_shared_add_u32(&mut self, addr: Register, val: Register) -> Register

Atomic add on shared memory (u32).

Source

pub fn red_global_add_f32(&mut self, addr: Register, val: Register)

Atomic reduction (fire-and-forget) add on global memory (f32).

Unlike atom, this does not return the old value and may be faster.

Source

pub fn red_global_add_u32(&mut self, addr: Register, val: Register)

Atomic reduction (fire-and-forget) add on global memory (u32).

Source

pub fn tex_1d(&mut self, ty: PtxType, tex_ref: &str, coord: Operand) -> Register

Emits a 1D texture fetch instruction.

Fetches a texel from the named texture reference at the given integer coordinate. Returns the destination register.

Emits: tex.1d.v4.{ty}.s32 dst, [tex_ref, {coord}];

Source

pub fn tex_2d( &mut self, ty: PtxType, tex_ref: &str, coord_x: Operand, coord_y: Operand, ) -> Register

Emits a 2D texture fetch instruction.

Fetches a texel from the named texture reference at the given (x, y) integer coordinates. Returns the destination register.

Emits: tex.2d.v4.{ty}.s32 dst, [tex_ref, {coord_x, coord_y}];

Source

pub fn tex_3d( &mut self, ty: PtxType, tex_ref: &str, coord_x: Operand, coord_y: Operand, coord_z: Operand, ) -> Register

Emits a 3D texture fetch instruction.

Fetches a texel from the named texture reference at the given (x, y, z) integer coordinates. Returns the destination register.

Emits: tex.3d.v4.{ty}.s32 dst, [tex_ref, {coord_x, coord_y, coord_z}];

Source

pub fn surf_load( &mut self, ty: PtxType, surf_ref: &str, coord: Operand, ) -> Register

Emits a 1D surface load instruction.

Loads a value from the named surface reference at the given coordinate. Returns the destination register.

Emits: suld.b.1d.{ty} dst, [surf_ref, {coord}];

Source

pub fn surf_store( &mut self, ty: PtxType, surf_ref: &str, coord: Operand, src: Register, )

Emits a 1D surface store instruction.

Stores a value to the named surface reference at the given coordinate.

Emits: sust.b.1d.{ty} [surf_ref, {coord}], src;

Source

pub fn redux_add_u32(&mut self, src: &str) -> Result<String, PtxGenError>

Warp-level sum reduction on a u32 value (SM >= 80).

Source

pub fn redux_max_u32(&mut self, src: &str) -> Result<String, PtxGenError>

Warp-level max reduction on a u32 value (SM >= 80).

Source

pub fn redux_min_u32(&mut self, src: &str) -> Result<String, PtxGenError>

Warp-level min reduction on a u32 value (SM >= 80).

Source

pub fn stmatrix_m8n8x4( &mut self, addr: &str, src: &str, ) -> Result<(), PtxGenError>

Store matrix m8n8x4 to shared memory (SM >= 90).

Source

pub fn elect_sync(&mut self) -> Result<String, PtxGenError>

Elect a single warp leader (SM >= 90). Returns predicate register name.

Source

pub fn setmaxnreg_inc(&mut self, count: u32) -> Result<(), PtxGenError>

Increase the maximum register count (SM >= 90).

Source

pub fn setmaxnreg_dec(&mut self, count: u32) -> Result<(), PtxGenError>

Decrease the maximum register count (SM >= 90).

Source

pub fn griddepcontrol_launch_dependents(&mut self) -> Result<(), PtxGenError>

Signal that dependent grids may launch (SM >= 90).

Source

pub fn griddepcontrol_wait(&mut self) -> Result<(), PtxGenError>

Wait for grid dependencies to complete (SM >= 90).

Source

pub fn fence_proxy_async(&mut self, scope: &str) -> Result<(), PtxGenError>

Emit a proxy fence for async operations.

Source

pub fn mbarrier_init( &mut self, addr: &str, count: &str, ) -> Result<(), PtxGenError>

Initialize an mbarrier in shared memory (SM >= 90).

Source

pub fn mbarrier_arrive(&mut self, addr: &str) -> Result<(), PtxGenError>

Signal arrival at an mbarrier (SM >= 90).

Source

pub fn mbarrier_wait( &mut self, addr: &str, phase: &str, ) -> Result<(), PtxGenError>

Wait on an mbarrier phase (SM >= 90).

Source

pub fn cvt_f32_to_e2m1( &mut self, src: Register, ) -> Result<Register, PtxGenError>

Convert an f32 register to FP4 E2M1 format (SM >= 100, Blackwell).

Emits cvt.rn.e2m1.f32 dst, src; and returns the destination register. The FP4 result is stored in the low 4 bits of a B32 register container.

Source

pub fn cvt_e2m1_to_f32( &mut self, src: Register, ) -> Result<Register, PtxGenError>

Convert an FP4 E2M1 register to f32 (SM >= 100, Blackwell).

Emits cvt.f32.e2m1 dst, src; and returns the destination register.

Source

pub fn tcgen05_mma_m128n256k256_e2m1( &mut self, a_desc: Register, b_desc: Register, ) -> Result<(), PtxGenError>

Emit a tcgen05.mma.cta_group::1.kind::f32 instruction (SM >= 100).

This is the Blackwell 5th-generation Tensor Core MMA that operates on 128×256×256 E2M1 tiles referenced by descriptors stored in 64-bit registers.

Source

pub fn barrier_cluster(&mut self) -> Result<(), PtxGenError>

Emit barrier.cluster.arrive; — signal cluster barrier (SM >= 90).

All CTAs in the cluster must arrive before any may continue past the corresponding barrier.cluster.wait.

Source

pub fn fence_cluster(&mut self) -> Result<(), PtxGenError>

Emit fence.mbarrier_init.release.cluster; — cluster release fence (SM >= 90).

Ensures that all preceding memory operations (including mbarrier initializations) are visible cluster-wide before the barrier is observed.

Source

pub fn cp_async_bulk_tensor_1d( &mut self, dst_smem: Register, src_gmem: Register, desc: Register, ) -> Result<(), PtxGenError>

Emit a 1-D TMA descriptor-based bulk async copy (SM >= 90).

Emits:

cp.async.bulk.tensor.1d.shared::cluster.global.tile.bulk_group
    [dst_smem], [src_gmem, {desc}];

dst_smem is the destination shared-memory address register, src_gmem is the global-memory base address register, and desc is the coordinate / descriptor register.

Source§

impl BodyBuilder<'_>

Source

pub fn wmma_load_a_f16( &mut self, shape: WmmaShape, layout: WmmaLayout, addr: Operand, stride: Option<Operand>, ) -> Vec<Register>

Emit wmma.load.a.sync.aligned{shape}{layout}.f16 for any WMMA shape.

Returns the 8 allocated A-fragment registers.

§Parameters
  • shape — tile shape (all 3 shapes allocate 8 registers per thread).
  • layout — RowMajor or ColMajor A matrix layout.
  • addr — address operand pointing to the matrix tile in shared/global mem.
  • stride — optional stride operand for non-contiguous storage.
Source

pub fn wmma_load_b_f16( &mut self, shape: WmmaShape, layout: WmmaLayout, addr: Operand, stride: Option<Operand>, ) -> Vec<Register>

Emit wmma.load.b.sync.aligned{shape}{layout}.f16 for any WMMA shape.

Returns the 8 allocated B-fragment registers.

Source

pub fn wmma_store_d_f32( &mut self, shape: WmmaShape, layout: WmmaLayout, addr: Operand, regs: Vec<Register>, stride: Option<Operand>, )

Emit wmma.store.d.sync.aligned{shape}{layout}.f32 for any WMMA shape.

Writes the 8 F32 accumulator registers back to memory.

§Parameters
  • addr — destination address in shared/global memory.
  • regs — the accumulator fragment registers to store.
  • stride — optional store stride.
Source

pub fn wmma_store_d_f16( &mut self, shape: WmmaShape, layout: WmmaLayout, addr: Operand, regs: Vec<Register>, stride: Option<Operand>, )

Emit wmma.store.d.sync.aligned{shape}{layout}.f16 for any WMMA shape.

Source

pub fn wmma_mma_sync_f16_f32( &mut self, shape: WmmaShape, layout: WmmaLayout, a_regs: Vec<Register>, b_regs: Vec<Register>, c_regs: Vec<Register>, ) -> Vec<Register>

Emit wmma.mma.sync.aligned{shape}{layout}.f32.f16.f16.f32 — F16 inputs with F32 accumulation (most common WMMA usage).

Returns the 8 F32 accumulator destination registers.

§Parameters
  • a_regs — 8 fragment registers loaded by wmma_load_a_f16.
  • b_regs — 8 fragment registers loaded by wmma_load_b_f16.
  • c_regs — 8 F32 accumulator input registers.
Source

pub fn wmma_mma_sync_f16_f16( &mut self, shape: WmmaShape, layout: WmmaLayout, a_regs: Vec<Register>, b_regs: Vec<Register>, c_regs: Vec<Register>, ) -> Vec<Register>

Emit wmma.mma.sync.aligned{shape}{layout}.f16.f16.f16.f16 — full F16 WMMA with F16 accumulation.

Source

pub fn mma_m16n8k8_tf32_f32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 4], PtxGenError>

Emit mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32.

TF32 inputs with F32 accumulation. Requires Ampere (sm_80+). Operand register counts per thread: A=2, B=1, C/D=4.

§Errors

Returns an error if the target SM is below Ampere.

Source

pub fn mma_m16n8k16_bf16_f32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 4], PtxGenError>

Emit mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32.

BF16 inputs with F32 accumulation. Requires Ampere (sm_80+). Register counts per thread: A=4, B=2, C/D=4.

§Errors

Returns an error if the target SM is below Ampere.

Source

pub fn mma_m16n8k32_e4m3_f32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 4], PtxGenError>

Emit mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.

FP8 E4M3 inputs with F32 accumulation. Requires Hopper (sm_90+). Register counts per thread: A=8, B=4, C/D=4.

§Errors

Returns an error if the target SM is below Hopper.

Source

pub fn mma_m16n8k32_e5m2_f32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 4], PtxGenError>

Emit mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32.

FP8 E5M2 inputs with F32 accumulation. Requires Hopper (sm_90+).

§Errors

Returns an error if the target SM is below Hopper.

Source

pub fn mma_m16n8k32_f16_f32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 4], PtxGenError>

Emit mma.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 — the K=32 FP16 variant for Hopper.

§Errors

Returns an error if the target SM is below Hopper.

Source

pub fn mma_m8n8k16_s8_s32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 2], PtxGenError>

Emit mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 — INT8 IMMA for Turing/Ampere.

Register counts per thread: A=1, B=1, C/D=2 (S32).

§Errors

Returns an error if the target SM is below Turing (sm_75).

Source

pub fn mma_m8n8k16_u8_s32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 2], PtxGenError>

Emit mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 — unsigned INT8 IMMA for Turing/Ampere.

§Errors

Returns an error if the target SM is below Turing.

Source

pub fn mma_m16n8k16_s8_s32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> Result<[Register; 4], PtxGenError>

Emit mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 — INT8 IMMA in the larger Ampere shape (16-wide M).

Register counts per thread: A=4, B=2, C/D=4 (S32).

§Errors

Returns an error if the target SM is below Ampere.

Source

pub fn wgmma_mma_async( &mut self, shape: WgmmaShape, a_ty: PtxType, b_ty: PtxType, desc_a: Register, desc_b: Register, scale_d: i32, imm_scale_a: i32, imm_scale_b: i32, trans_a: i32, trans_b: i32, ) -> Result<Vec<Register>, PtxGenError>

Emit wgmma.mma_async.sync.aligned{shape}.f32.{a_ty}.{b_ty} using the structured IR.

This is the full parameterised WGMMA builder. It validates the target architecture, allocates the correct number of accumulator registers, and emits the structured Instruction::Wgmma with all required fields.

§Parameters
  • shape — one of the six M64×N×K16 shapes.
  • a_ty — A/B element type: F16, BF16, E4M3, or E5M2.
  • b_ty — must equal a_ty per WGMMA PTX ISA.
  • desc_a — shared-memory descriptor register for A.
  • desc_b — shared-memory descriptor register for B.
  • scale_d — 1 to accumulate into D, 0 to zero-init D before writing.
  • imm_scale_a — immediate scale for A (typically 1).
  • imm_scale_b — immediate scale for B (typically 1).
  • trans_a — 0 = no transpose, 1 = transpose A from col to row layout.
  • trans_b — 0 = no transpose, 1 = transpose B from row to col layout.

Returns the allocated accumulator registers (count = M×N/128).

§Errors

Returns an error if the target SM is below Hopper (sm_90).

Source

pub fn wgmma_mma_async_f16( &mut self, shape: WgmmaShape, desc_a: Register, desc_b: Register, ) -> Result<Vec<Register>, PtxGenError>

Convenience wrapper for wgmma.mma_async with F16 inputs (most common).

Equivalent to calling wgmma_mma_async with a_ty = b_ty = F16 and standard scale/transpose flags (scale_d=1, scales=1, no transpose).

§Errors

Returns an error if the target SM is below Hopper.

Source

pub fn wgmma_mma_async_bf16( &mut self, shape: WgmmaShape, desc_a: Register, desc_b: Register, ) -> Result<Vec<Register>, PtxGenError>

Convenience wrapper for WGMMA with BF16 inputs.

§Errors

Returns an error if the target SM is below Hopper.

Source

pub fn wgmma_mma_async_e4m3( &mut self, shape: WgmmaShape, desc_a: Register, desc_b: Register, ) -> Result<Vec<Register>, PtxGenError>

Convenience wrapper for WGMMA with FP8 E4M3 inputs.

§Errors

Returns an error if the target SM is below Hopper.

Source

pub fn wgmma_mma_async_e5m2( &mut self, shape: WgmmaShape, desc_a: Register, desc_b: Register, ) -> Result<Vec<Register>, PtxGenError>

Convenience wrapper for WGMMA with FP8 E5M2 inputs.

§Errors

Returns an error if the target SM is below Hopper.

Source§

impl<'a> BodyBuilder<'a>

Source

pub fn load_param_u32(&mut self, name: &str) -> Register

Loads a u32 kernel parameter by name.

Emits a ld.param.u32 instruction and returns the destination register.

Source

pub fn load_param_u64(&mut self, name: &str) -> Register

Loads a u64 kernel parameter by name (typically a device pointer).

Emits a ld.param.u64 instruction and returns the destination register.

Source

pub fn load_param_f32(&mut self, name: &str) -> Register

Loads an f32 kernel parameter by name.

Emits a ld.param.f32 instruction and returns the destination register.

Source

pub fn load_param_f64(&mut self, name: &str) -> Register

Loads an f64 kernel parameter by name.

Emits a ld.param.f64 instruction and returns the destination register.

Source

pub fn global_thread_id_x(&mut self) -> Register

Computes the global thread ID in the X dimension.

Equivalent to blockIdx.x * blockDim.x + threadIdx.x in CUDA C. Emits:

mov.u32 %r_tid,  %tid.x;
mov.u32 %r_ntid, %ntid.x;
mov.u32 %r_ctaid, %ctaid.x;
mad.lo.u32 %r_gid, %r_ctaid, %r_ntid, %r_tid;
Source

pub fn global_thread_id_y(&mut self) -> Register

Computes the global thread ID in the Y dimension.

Equivalent to blockIdx.y * blockDim.y + threadIdx.y in CUDA C.

Source

pub fn global_thread_id_2d(&mut self) -> (Register, Register)

Computes both X and Y global thread IDs for 2D kernels.

Returns (row, col) where row is the Y global ID and col is the X global ID (following matrix convention).

Source

pub fn thread_id_x(&mut self) -> Register

Reads %tid.x (thread index within the block, X dimension).

Source

pub fn block_id_x(&mut self) -> Register

Reads %ctaid.x (block index within the grid, X dimension).

Source

pub fn block_dim_x(&mut self) -> Register

Reads %ntid.x (number of threads per block, X dimension).

Source

pub fn add_u32(&mut self, a: Register, b: Register) -> Register

Emits add.u32 dst, a, b.

Source

pub fn add_u64(&mut self, a: Register, b: Register) -> Register

Emits add.u64 dst, a, b.

Source

pub fn add_f32(&mut self, a: Register, b: Register) -> Register

Emits add.f32 dst, a, b.

Source

pub fn add_f64(&mut self, a: Register, b: Register) -> Register

Emits add.f64 dst, a, b.

Source

pub fn sub_f32(&mut self, a: Register, b: Register) -> Register

Emits sub.f32 dst, a, b.

Source

pub fn sub_f64(&mut self, a: Register, b: Register) -> Register

Emits sub.f64 dst, a, b.

Source

pub fn mul_lo_u32(&mut self, a: Register, b: Register) -> Register

Emits mul.lo.u32 dst, a, b — low 32 bits of a u32 multiplication.

Source

pub fn mul_wide_u32_to_u64(&mut self, a: Register, b: Register) -> Register

Emits mul.wide.u32 dst, a, b — widens two u32 operands to produce a u64 result.

Source

pub fn mad_lo_s32(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.lo.s32 dst, a, b, c — low 32 bits of a*b+c (signed).

Source

pub fn mad_lo_u32(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.lo.u32 dst, a, b, c — low 32 bits of a*b+c (unsigned).

Source

pub fn mad_lo_s64(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.lo.s64 dst, a, b, c — low 64 bits of a*b+c (signed).

Source

pub fn mad_lo_u64(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.lo.u64 dst, a, b, c — low 64 bits of a*b+c (unsigned).

Source

pub fn mad_hi_s32(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.hi.s32 dst, a, b, c — high 32 bits of a*b+c (signed).

Source

pub fn mad_hi_u32(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.hi.u32 dst, a, b, c — high 32 bits of a*b+c (unsigned).

Source

pub fn mad_hi_s64(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.hi.s64 dst, a, b, c — high 64 bits of a*b+c (signed).

Source

pub fn mad_hi_u64(&mut self, a: Register, b: Register, c: Register) -> Register

Emits mad.hi.u64 dst, a, b, c — high 64 bits of a*b+c (unsigned).

Source

pub fn mad_wide_s16( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits mad.wide.s16 dst, a, b, c — widening multiply-add, s16 -> s32.

Source

pub fn mad_wide_u16( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits mad.wide.u16 dst, a, b, c — widening multiply-add, u16 -> u32.

Source

pub fn mad_wide_s32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits mad.wide.s32 dst, a, b, c — widening multiply-add, s32 -> s64.

Source

pub fn mad_wide_u32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits mad.wide.u32 dst, a, b, c — widening multiply-add, u32 -> u64.

Source

pub fn fma_f32(&mut self, a: Register, b: Register, c: Register) -> Register

Emits fma.rn.f32 dst, a, b, c — fused multiply-add, single precision.

Source

pub fn fma_f64(&mut self, a: Register, b: Register, c: Register) -> Register

Emits fma.rn.f64 dst, a, b, c — fused multiply-add, double precision.

Source

pub fn neg_f32(&mut self, src: Register) -> Register

Emits neg.f32 dst, src.

Source

pub fn abs_f32(&mut self, src: Register) -> Register

Emits abs.f32 dst, src.

Source

pub fn min_f32(&mut self, a: Register, b: Register) -> Register

Emits min.f32 dst, a, b.

Source

pub fn max_f32(&mut self, a: Register, b: Register) -> Register

Emits max.f32 dst, a, b.

Source

pub fn min_u32(&mut self, a: Register, b: Register) -> Register

Emits min.u32 dst, a, b.

Source

pub fn max_u32(&mut self, a: Register, b: Register) -> Register

Emits max.u32 dst, a, b.

Source

pub fn brev_b32(&mut self, src: Register) -> Register

Emits brev.b32 dst, src — reverse the bits of a 32-bit value.

Source

pub fn brev_b64(&mut self, src: Register) -> Register

Emits brev.b64 dst, src — reverse the bits of a 64-bit value.

Source

pub fn clz_b32(&mut self, src: Register) -> Register

Emits clz.b32 dst, src — count leading zeros (result is U32).

Source

pub fn popc_b32(&mut self, src: Register) -> Register

Emits popc.b32 dst, src — population count of 32-bit value (result is U32).

Source

pub fn popc_b64(&mut self, src: Register) -> Register

Emits popc.b64 dst, src — population count of 64-bit value (result is U32).

Source

pub fn bfind_u32(&mut self, src: Register) -> Register

Emits bfind.u32 dst, src — find most significant bit (unsigned, result is U32).

Source

pub fn bfind_s32(&mut self, src: Register) -> Register

Emits bfind.s32 dst, src — find most significant non-sign bit (signed, result is U32).

Source

pub fn bfe_u32( &mut self, src: Register, start: Register, len: Register, ) -> Register

Emits bfe.u32 dst, src, start, len — extract a bit field (unsigned).

Source

pub fn bfe_s32( &mut self, src: Register, start: Register, len: Register, ) -> Register

Emits bfe.s32 dst, src, start, len — extract a bit field (signed).

Source

pub fn bfi_b32( &mut self, insert: Register, base: Register, start: Register, len: Register, ) -> Register

Emits bfi.b32 dst, insert, base, start, len — insert a bit field.

Source

pub fn shl_b32(&mut self, src: Register, amount: Register) -> Register

Emits shl.b32 dst, src, amount — left shift, 32-bit.

Source

pub fn shl_b64(&mut self, src: Register, amount: Register) -> Register

Emits shl.b64 dst, src, amount — left shift, 64-bit.

Source

pub fn shr_b32(&mut self, src: Register, amount: Register) -> Register

Emits shr.b32 dst, src, amount — logical right shift, 32-bit.

Source

pub fn shr_b64(&mut self, src: Register, amount: Register) -> Register

Emits shr.b64 dst, src, amount — logical right shift, 64-bit.

Source

pub fn shr_u32(&mut self, src: Register, amount: Register) -> Register

Emits shr.u32 dst, src, amount — logical right shift for unsigned 32-bit.

Source

pub fn shr_s32(&mut self, src: Register, amount: Register) -> Register

Emits shr.s32 dst, src, amount — arithmetic right shift for signed 32-bit.

Source

pub fn rcp_f32(&mut self, src: Register) -> Register

Emits rcp.rn.f32 dst, src — reciprocal, single precision.

Source

pub fn rcp_f64(&mut self, src: Register) -> Register

Emits rcp.rn.f64 dst, src — reciprocal, double precision.

Source

pub fn rcp_approx_f32(&mut self, src: Register) -> Register

Emits rcp.approx.ftz.f32 dst, src — fast approximate reciprocal.

Uses rnd=None to signal approx mode (no IEEE rounding).

Source

pub fn rsqrt_approx_f32(&mut self, src: Register) -> Register

Emits rsqrt.approx.f32 dst, src — approximate reciprocal square root.

Source

pub fn rsqrt_approx_f64(&mut self, src: Register) -> Register

Emits rsqrt.approx.f64 dst, src — approximate reciprocal square root, double precision.

Source

pub fn sqrt_rn_f32(&mut self, src: Register) -> Register

Emits sqrt.rn.f32 dst, src — square root, single precision.

Source

pub fn sqrt_rn_f64(&mut self, src: Register) -> Register

Emits sqrt.rn.f64 dst, src — square root, double precision.

Source

pub fn ex2_approx_f32(&mut self, src: Register) -> Register

Emits ex2.approx.f32 dst, src — base-2 exponential, approximate.

Source

pub fn lg2_approx_f32(&mut self, src: Register) -> Register

Emits lg2.approx.f32 dst, src — base-2 logarithm, approximate.

Source

pub fn sin_approx_f32(&mut self, src: Register) -> Register

Emits sin.approx.f32 dst, src — sine, approximate.

Source

pub fn cos_approx_f32(&mut self, src: Register) -> Register

Emits cos.approx.f32 dst, src — cosine, approximate.

Source

pub fn load_global_f32(&mut self, addr: Register) -> Register

Loads a single f32 from global memory.

addr should be a U64 register containing the global device pointer. Emits ld.global.f32 dst, [addr].

Source

pub fn load_global_f64(&mut self, addr: Register) -> Register

Loads a single f64 from global memory.

Source

pub fn load_global_i32(&mut self, addr: Register) -> Register

Loads a single signed 32-bit integer from global memory.

Emits ld.global.s32 dst, [addr].

Source

pub fn load_global_u32(&mut self, addr: Register) -> Register

Loads a single unsigned 32-bit integer from global memory.

Emits ld.global.u32 dst, [addr].

Source

pub fn load_global_f32x4(&mut self, addr: &Register) -> [Register; 4]

Loads four f32 values from global memory as a vectorized .v4 load.

Returns an array of 4 registers containing the loaded values. addr must be 16-byte aligned for correctness.

Since the IR Load instruction uses a single destination register, this method emits raw PTX for the vectorized load and individual mov instructions to extract each element.

Source

pub fn store_global_f32(&mut self, addr: Register, val: Register)

Stores a single f32 to global memory.

addr should be a U64 register containing the global device pointer.

Source

pub fn store_global_f64(&mut self, addr: Register, val: Register)

Stores a single f64 to global memory.

Source

pub fn store_global_i32(&mut self, addr: Register, val: Register)

Stores a single signed 32-bit integer to global memory.

Emits st.global.s32 [addr], val.

Source

pub fn store_global_u32(&mut self, addr: Register, val: Register)

Stores a single unsigned 32-bit integer to global memory.

Emits st.global.u32 [addr], val.

Source

pub fn load_shared_f32(&mut self, addr: Register) -> Register

Loads a single f32 from shared memory.

addr should be a register containing an address in shared memory space.

Source

pub fn store_shared_f32(&mut self, addr: Register, val: Register)

Stores a single f32 to shared memory.

Source

pub fn cp_async_32bit(&mut self, dst_shared: Register, src_global: Register)

Emits a 32-bit (4-byte) asynchronous copy from global to shared memory.

Emits: cp.async.ca.shared.global [dst], [src], 4; Requires sm_80+.

Source

pub fn cp_async_64bit(&mut self, dst_shared: Register, src_global: Register)

Emits a 64-bit (8-byte) asynchronous copy from global to shared memory.

Emits: cp.async.ca.shared.global [dst], [src], 8; Requires sm_80+.

Source

pub fn cp_async_128bit(&mut self, dst_shared: Register, src_global: Register)

Emits a 128-bit (16-byte) asynchronous copy from global to shared memory.

This is the most common cp.async variant, used for double-buffered data loading in high-performance kernels. Requires sm_80+.

Source

pub fn cp_async_commit(&mut self)

Emits cp.async.commit_group to commit all pending async copies.

Source

pub fn cp_async_wait(&mut self, n: u32)

Emits cp.async.wait_group N to wait until at most n copy groups are still pending.

Pass 0 to wait for all pending copies to complete.

Source

pub fn ldmatrix_x4( &mut self, src_addr: Register, ) -> Result<[Register; 4], PtxGenError>

Emits a ldmatrix.sync.aligned.m8n8.x4.shared.b16 instruction (SM >= 75).

Loads 4 warp-cooperative 8×8 B16 matrix fragments from shared memory. Each of the 32 threads contributes to loading 8 bytes (one row) of the tile. Returns the four destination registers.

§Errors

Returns PtxGenError if the target architecture does not support ldmatrix (requires SM >= 75).

Source

pub fn if_lt_u32<F>(&mut self, a: Register, b: Register, body: F)
where F: FnOnce(&mut BodyBuilder<'_>),

Emits a conditional block that executes body when a < b (unsigned 32-bit).

Generates a setp.lo.u32 comparison, a negated conditional branch over the body, and a skip label.

§Example
b.if_lt_u32(tid, n, |b| {
    // Only threads with tid < n execute this
});
Source

pub fn if_ge_u32<F>(&mut self, a: Register, b: Register, body: F)
where F: FnOnce(&mut BodyBuilder<'_>),

Emits a conditional block that executes body when a >= b (unsigned 32-bit).

Source

pub fn unroll<F>(&mut self, count: u32, body: F)
where F: FnMut(&mut BodyBuilder<'_>, u32),

Compile-time loop unrolling.

Invokes body with the builder and each index i in 0..count, emitting all iterations inline. This is equivalent to #pragma unroll in CUDA C.

Each iteration gets its own comment indicating the unroll index.

Source

pub fn pragma_unroll(&mut self, factor: Option<u32>)

Emits a .pragma "unroll N" or .pragma "nounroll" directive hint.

When factor is Some(n), emits .pragma "unroll N"; to hint the PTX assembler to unroll the following loop by factor n. When factor is None, emits .pragma "nounroll"; to suppress unrolling.

Source

pub fn label(&mut self, name: &str)

Emits a label pseudo-instruction.

Labels are branch targets. They appear at the start of a line without indentation in the generated PTX.

Source

pub fn branch(&mut self, target: &str)

Emits an unconditional branch to the given label.

Source

pub fn branch_if(&mut self, pred: Register, target: &str)

Emits a conditional branch: @pred bra target.

Source

pub fn ret(&mut self)

Emits a ret instruction to return from the kernel.

Source

pub fn bar_sync(&mut self, id: u32)

Emits bar.sync id — block-level barrier synchronization.

All threads in the block must reach this barrier before any can proceed. id is typically 0.

Source

pub fn fence_acq_rel(&mut self, scope: FenceScope)

Emits a memory fence with acquire-release semantics at the given scope.

Source

pub fn cvt_u32_to_u64(&mut self, src: Register) -> Register

Converts a u32 register to u64 (zero-extension).

Emits cvt.u64.u32 dst, src.

Source

pub fn cvt_f32_to_f64(&mut self, src: Register) -> Register

Converts an f32 register to f64 (widening).

Emits cvt.f64.f32 dst, src.

Source

pub fn cvt_f64_to_f32(&mut self, src: Register) -> Register

Converts an f64 register to f32 (narrowing, round-to-nearest-even).

Emits cvt.rn.f32.f64 dst, src.

Source

pub fn cvt_f16_to_f32(&mut self, src: Register) -> Register

Converts an f16 register to f32 (widening).

Emits cvt.f32.f16 dst, src.

Source

pub fn cvt_f32_to_f16(&mut self, src: Register) -> Register

Converts an f32 register to f16 (narrowing, round-to-nearest-even).

Emits cvt.rn.f16.f32 dst, src.

Source

pub fn cvt_bf16_to_f32(&mut self, src: Register) -> Register

Converts a bf16 register to f32 (widening).

Emits cvt.f32.bf16 dst, src.

Source

pub fn cvt_f32_to_bf16(&mut self, src: Register) -> Register

Converts an f32 register to bf16 (narrowing, round-to-nearest-even).

Emits cvt.rn.bf16.f32 dst, src.

Source

pub fn cvt_f32_to_e4m3(&mut self, src: Register) -> Register

Converts an f32 register to FP8 E4M3 format (sm_89+, Ada/Hopper).

Emits cvt.rn.satfinite.e4m3x2.f32 dst, src_hi, src_lo. Note: PTX packs two FP8 values per register (e4m3x2).

Source

pub fn cvt_e4m3_to_f32(&mut self, src: Register) -> Register

Converts an FP8 E4M3 register to f32 (sm_89+).

Emits cvt.f32.e4m3 dst, src.

Source

pub fn cvt_f32_to_e5m2(&mut self, src: Register) -> Register

Converts an f32 register to FP8 E5M2 format (sm_89+).

Emits cvt.rn.e5m2.f32 dst, src.

Source

pub fn cvt_e5m2_to_f32(&mut self, src: Register) -> Register

Converts an FP8 E5M2 register to f32 (sm_89+).

Emits cvt.f32.e5m2 dst, src.

Source

pub fn mma_m16n8k16_f16_f32( &mut self, a_regs: &[Register], b_regs: &[Register], c_regs: &[Register], ) -> [Register; 4]

Emits an mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 instruction.

This is the standard Ampere tensor core MMA operation:

  • Shape: 16x8x16 tile
  • A fragment: registers holding f16 matrix A data
  • B fragment: registers holding f16 matrix B data
  • C/D accumulator: 4 f32 registers for input/output accumulator

Returns the 4 destination accumulator registers.

§Arguments
  • a_regs — Registers holding the A matrix fragment (f16)
  • b_regs — Registers holding the B matrix fragment (f16)
  • c_regs — Registers holding the C accumulator input (f32)
Source

pub fn wgmma_mma_async_m64n128k16_f16( &mut self, a_desc: &str, b_desc: &str, ) -> Result<(), PtxGenError>

Emits wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 — a warpgroup MMA async instruction for Hopper (sm_90+) computing a 64×128 tile.

This operates on warpgroup-level fragments:

  • a_desc: A operand descriptor (shared memory descriptor string)
  • b_desc: B operand descriptor
  • Accumulator: 64 f32 registers (managed by the warpgroup implicitly)

Emits raw PTX via raw_ptx since wgmma is not yet in the structured IR.

§Errors

Returns PtxGenError when the target SM is below 90 (Hopper).

Source

pub fn dp4a_u32_u32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp4a.u32.u32 dst, a, b, c — unsigned 4-way byte dot product.

Source

pub fn dp4a_s32_s32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp4a.s32.s32 dst, a, b, c — signed 4-way byte dot product.

Source

pub fn dp4a_s32_u32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp4a.s32.u32 dst, a, b, c — mixed signed/unsigned 4-way byte dot product.

Source

pub fn dp4a_u32_s32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp4a.u32.s32 dst, a, b, c — mixed unsigned/signed 4-way byte dot product.

Source

pub fn dp2a_lo_u32_u32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp2a.lo.u32.u32 dst, a, b, c — unsigned 2-way dot product, low half.

Source

pub fn dp2a_hi_u32_u32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp2a.hi.u32.u32 dst, a, b, c — unsigned 2-way dot product, high half.

Source

pub fn dp2a_lo_s32_s32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp2a.lo.s32.s32 dst, a, b, c — signed 2-way dot product, low half.

Source

pub fn dp2a_hi_s32_s32( &mut self, a: Register, b: Register, c: Register, ) -> Register

Emits dp2a.hi.s32.s32 dst, a, b, c — signed 2-way dot product, high half.

Source

pub const fn imm_u32(&self, val: u32) -> Operand

Creates an unsigned 32-bit immediate operand.

Source

pub fn mov_imm_u32(&mut self, val: u32) -> Register

Loads an unsigned 32-bit immediate into a new register via add.u32 dst, 0, val.

Source

pub const fn imm_u64(&self, val: u64) -> Operand

Creates an unsigned 64-bit immediate operand.

Source

pub const fn imm_f32(&self, val: f32) -> Operand

Creates a 32-bit floating-point immediate operand.

Source

pub const fn imm_f64(&self, val: f64) -> Operand

Creates a 64-bit floating-point immediate operand.

Source

pub fn comment(&mut self, text: &str)

Emits a comment in the PTX output (for debugging / readability).

Source

pub fn raw_ptx(&mut self, text: &str)

Emits raw PTX text verbatim. Use as an escape hatch for instructions not yet modeled in the IR.

Named registers (e.g., %f_x, %rd_off, %p_ge) found in the text are automatically declared based on their prefix:

  • %f_* → .reg .f32
  • %rd_* → .reg .b64
  • %r_* → .reg .b32
  • %p_* → .reg .pred
Source

pub fn byte_offset_addr( &mut self, base: Register, index: Register, stride_bytes: u32, ) -> Register

Computes a byte offset address: base + index * stride.

Useful for computing element addresses in arrays. The index is zero-extended from u32 to u64 before the multiplication.

Returns a U64 register containing the computed address.

Source

pub fn f32_elem_addr(&mut self, base: Register, index: Register) -> Register

Computes an element address for an f32 array: base + index * 4.

Source

pub fn f64_elem_addr(&mut self, base: Register, index: Register) -> Register

Computes an element address for an f64 array: base + index * 8.

Source

pub fn alloc_reg(&mut self, ty: PtxType) -> Register

Allocates a fresh register of the given type.

This is a lower-level API — most users should prefer the typed instruction methods which allocate destination registers automatically.

Source

pub fn declare_named_reg(&mut self, name: &str, ty: PtxType)

Declares a named register for use in raw_ptx blocks.

Named registers (e.g., %f_x, %rd_off) are not created by the automatic allocator, so they must be declared explicitly before use.

Source

pub fn fresh_label(&mut self, prefix: &str) -> String

Generates a unique label name with the given prefix.

Labels are formatted as L__{prefix}_{counter} to avoid collisions with user-defined labels and other generated labels.

Source

pub const fn target_sm(&self) -> SmVersion

Returns the target SM version for this kernel.

Useful for architecture-gated code paths within body closures.

Source

pub fn has_param(&self, name: &str) -> bool

Returns true if the given parameter name was declared on the kernel.

Auto Trait Implementations§

§

impl<'a> Freeze for BodyBuilder<'a>

§

impl<'a> RefUnwindSafe for BodyBuilder<'a>

§

impl<'a> Send for BodyBuilder<'a>

§

impl<'a> Sync for BodyBuilder<'a>

§

impl<'a> Unpin for BodyBuilder<'a>

§

impl<'a> UnsafeUnpin for BodyBuilder<'a>

§

impl<'a> !UnwindSafe for BodyBuilder<'a>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.