// oxicuda_ptx/builder/body_builder/mod.rs
1//! Instruction emission API for PTX kernel bodies.
2//!
3//! [`BodyBuilder`] provides the ergonomic instruction-level DSL used inside
4//! [`KernelBuilder::body`] closures. It wraps a [`RegisterAllocator`] and an
5//! instruction vector, exposing methods for parameter loading, thread IDs,
6//! arithmetic, memory ops, control flow, synchronization, type conversions,
7//! and Tensor Core MMA instructions.
8
9use crate::arch::SmVersion;
10use crate::error::PtxGenError;
11use crate::ir::{
12    CacheQualifier, CmpOp, FenceScope, ImmValue, Instruction, MemorySpace, MmaShape, MulMode,
13    Operand, PtxType, Register, RegisterAllocator, RoundingMode, SpecialReg, VectorWidth,
14};
15
/// Instruction emission API for building the body of a PTX kernel.
///
/// `BodyBuilder` is not constructed directly — it is provided as a mutable
/// reference inside the closure passed to [`KernelBuilder::body`].
///
/// Most methods follow a consistent pattern: allocate destination register(s),
/// push the corresponding [`Instruction`] variant, and return the destination
/// register so it can be used as an operand to subsequent instructions.
///
/// [`KernelBuilder::body`]: super::KernelBuilder::body
pub struct BodyBuilder<'a> {
    /// Register allocator shared with the kernel builder, so register
    /// numbering stays consistent across the whole kernel.
    pub(super) regs: &'a mut RegisterAllocator,
    /// Instruction vector that accumulates the kernel body in emission order.
    pub(super) instructions: &'a mut Vec<Instruction>,
    /// Monotonically increasing label counter for generating unique labels.
    label_counter: u32,
    /// Names of the kernel parameters (consulted by the `load_param_*` methods).
    param_names: &'a [String],
    /// Target SM version (for architecture-gated instructions).
    pub(super) target: SmVersion,
}
38
39impl<'a> BodyBuilder<'a> {
    /// Creates a new body builder.
    ///
    /// This is called internally by [`KernelBuilder::build`] — users should
    /// not need to construct this directly.
    ///
    /// The label counter starts at zero; unique labels are generated from it
    /// as control-flow helpers need them.
    ///
    /// [`KernelBuilder::build`]: super::KernelBuilder::build
    pub(crate) const fn new(
        regs: &'a mut RegisterAllocator,
        instructions: &'a mut Vec<Instruction>,
        param_names: &'a [String],
        target: SmVersion,
    ) -> Self {
        Self {
            regs,
            instructions,
            label_counter: 0,
            param_names,
            target,
        }
    }
60
61    // ════════════════════════════════════════════════════════════════════
62    //  Parameter Loading
63    // ════════════════════════════════════════════════════════════════════
64
    /// Loads a `u32` kernel parameter by name.
    ///
    /// Emits a `ld.param.u32` instruction and returns the destination register.
    /// All four `load_param_*` wrappers delegate to the private `load_param`.
    pub fn load_param_u32(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::U32)
    }

    /// Loads a `u64` kernel parameter by name (typically a device pointer).
    ///
    /// Emits a `ld.param.u64` instruction and returns the destination register.
    pub fn load_param_u64(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::U64)
    }

    /// Loads an `f32` kernel parameter by name.
    ///
    /// Emits a `ld.param.f32` instruction and returns the destination register.
    pub fn load_param_f32(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::F32)
    }

    /// Loads an `f64` kernel parameter by name.
    ///
    /// Emits a `ld.param.f64` instruction and returns the destination register.
    pub fn load_param_f64(&mut self, name: &str) -> Register {
        self.load_param(name, PtxType::F64)
    }
92
93    /// Generic parameter load helper.
94    ///
95    /// Emits `ld.param{.ty} dst, [%param_{name}]` using the `LoadParam`
96    /// instruction variant.
97    fn load_param(&mut self, name: &str, ty: PtxType) -> Register {
98        let dst = self.regs.alloc(ty);
99        self.emit(Instruction::LoadParam {
100            ty,
101            dst: dst.clone(),
102            param_name: format!("%param_{name}"),
103        });
104        dst
105    }
106
    // ════════════════════════════════════════════════════════════════════
    //  Thread / Block ID Computation
    // ════════════════════════════════════════════════════════════════════

    /// Computes the global thread ID in the X dimension.
    ///
    /// Equivalent to `blockIdx.x * blockDim.x + threadIdx.x` in CUDA C.
    /// Emits:
    /// ```ptx
    /// mov.u32 %r_tid,  %tid.x;
    /// mov.u32 %r_ntid, %ntid.x;
    /// mov.u32 %r_ctaid, %ctaid.x;
    /// mad.lo.u32 %r_gid, %r_ctaid, %r_ntid, %r_tid;
    /// ```
    pub fn global_thread_id_x(&mut self) -> Register {
        // The read order (tid, ntid, ctaid) fixes both the emitted mov
        // sequence and the register numbering — keep it stable.
        let tid = self.read_special_reg(SpecialReg::TidX);
        let ntid = self.read_special_reg(SpecialReg::NtidX);
        let ctaid = self.read_special_reg(SpecialReg::CtaidX);
        let gid = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Mad {
            ty: PtxType::U32,
            mode: MulMode::Lo,
            dst: gid.clone(),
            a: Operand::Register(ctaid),
            b: Operand::Register(ntid),
            c: Operand::Register(tid),
        });
        gid
    }

    /// Computes the global thread ID in the Y dimension.
    ///
    /// Equivalent to `blockIdx.y * blockDim.y + threadIdx.y` in CUDA C.
    /// Mirrors [`Self::global_thread_id_x`] with the Y-dimension registers.
    pub fn global_thread_id_y(&mut self) -> Register {
        let tid = self.read_special_reg(SpecialReg::TidY);
        let ntid = self.read_special_reg(SpecialReg::NtidY);
        let ctaid = self.read_special_reg(SpecialReg::CtaidY);
        let gid = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Mad {
            ty: PtxType::U32,
            mode: MulMode::Lo,
            dst: gid.clone(),
            a: Operand::Register(ctaid),
            b: Operand::Register(ntid),
            c: Operand::Register(tid),
        });
        gid
    }

    /// Computes both X and Y global thread IDs for 2D kernels.
    ///
    /// Returns `(row, col)` where `row` is the Y global ID and `col` is
    /// the X global ID (following matrix convention). Note the X registers
    /// are emitted first, then the Y registers.
    pub fn global_thread_id_2d(&mut self) -> (Register, Register) {
        let col = self.global_thread_id_x();
        let row = self.global_thread_id_y();
        (row, col)
    }
165
    /// Reads `%tid.x` (thread index within the block, X dimension).
    ///
    /// Each call allocates and fills a fresh register; values are not cached.
    pub fn thread_id_x(&mut self) -> Register {
        self.read_special_reg(SpecialReg::TidX)
    }

    /// Reads `%ctaid.x` (block index within the grid, X dimension).
    pub fn block_id_x(&mut self) -> Register {
        self.read_special_reg(SpecialReg::CtaidX)
    }

    /// Reads `%ntid.x` (number of threads per block, X dimension).
    pub fn block_dim_x(&mut self) -> Register {
        self.read_special_reg(SpecialReg::NtidX)
    }
180
181    /// Reads a special register into a fresh `U32` register using `MovSpecial`.
182    fn read_special_reg(&mut self, sreg: SpecialReg) -> Register {
183        let dst = self.regs.alloc(PtxType::U32);
184        self.emit(Instruction::MovSpecial {
185            dst: dst.clone(),
186            special: sreg,
187        });
188        dst
189    }
190
    // ════════════════════════════════════════════════════════════════════
    //  Integer Arithmetic
    // ════════════════════════════════════════════════════════════════════

    /// Emits `add.u32 dst, a, b`.
    ///
    /// All `add_*` wrappers delegate to the private `add_typed`.
    pub fn add_u32(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::U32, a, b)
    }

    /// Emits `add.u64 dst, a, b`.
    pub fn add_u64(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::U64, a, b)
    }

    /// Emits `add.f32 dst, a, b`.
    pub fn add_f32(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::F32, a, b)
    }

    /// Emits `add.f64 dst, a, b`.
    pub fn add_f64(&mut self, a: Register, b: Register) -> Register {
        self.add_typed(PtxType::F64, a, b)
    }
214
215    /// Generic typed addition helper.
216    fn add_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
217        let dst = self.regs.alloc(ty);
218        self.emit(Instruction::Add {
219            ty,
220            dst: dst.clone(),
221            a: Operand::Register(a),
222            b: Operand::Register(b),
223        });
224        dst
225    }
226
    /// Emits `sub.f32 dst, a, b` — computes `a - b` in single precision.
    pub fn sub_f32(&mut self, a: Register, b: Register) -> Register {
        self.sub_typed(PtxType::F32, a, b)
    }

    /// Emits `sub.f64 dst, a, b` — computes `a - b` in double precision.
    pub fn sub_f64(&mut self, a: Register, b: Register) -> Register {
        self.sub_typed(PtxType::F64, a, b)
    }
236
237    /// Generic typed subtraction helper.
238    fn sub_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
239        let dst = self.regs.alloc(ty);
240        self.emit(Instruction::Sub {
241            ty,
242            dst: dst.clone(),
243            a: Operand::Register(a),
244            b: Operand::Register(b),
245        });
246        dst
247    }
248
249    /// Emits `mul.lo.u32 dst, a, b` — low 32 bits of a u32 multiplication.
250    pub fn mul_lo_u32(&mut self, a: Register, b: Register) -> Register {
251        let dst = self.regs.alloc(PtxType::U32);
252        self.emit(Instruction::Mul {
253            ty: PtxType::U32,
254            mode: MulMode::Lo,
255            dst: dst.clone(),
256            a: Operand::Register(a),
257            b: Operand::Register(b),
258        });
259        dst
260    }
261
262    /// Emits `mul.wide.u32 dst, a, b` — widens two u32 operands to produce
263    /// a u64 result.
264    pub fn mul_wide_u32_to_u64(&mut self, a: Register, b: Register) -> Register {
265        let dst = self.regs.alloc(PtxType::U64);
266        self.emit(Instruction::Mul {
267            ty: PtxType::U32,
268            mode: MulMode::Wide,
269            dst: dst.clone(),
270            a: Operand::Register(a),
271            b: Operand::Register(b),
272        });
273        dst
274    }
275
    // ════════════════════════════════════════════════════════════════════
    //  Integer Multiply-Add (mad.lo / mad.hi / mad.wide)
    // ════════════════════════════════════════════════════════════════════

    /// Emits `mad.lo.s32 dst, a, b, c` — low 32 bits of `a*b+c` (signed).
    ///
    /// All `mad_lo_*` wrappers delegate to the private `mad_lo_typed`.
    pub fn mad_lo_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::S32, a, b, c)
    }

    /// Emits `mad.lo.u32 dst, a, b, c` — low 32 bits of `a*b+c` (unsigned).
    pub fn mad_lo_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::U32, a, b, c)
    }

    /// Emits `mad.lo.s64 dst, a, b, c` — low 64 bits of `a*b+c` (signed).
    pub fn mad_lo_s64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::S64, a, b, c)
    }

    /// Emits `mad.lo.u64 dst, a, b, c` — low 64 bits of `a*b+c` (unsigned).
    pub fn mad_lo_u64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_lo_typed(PtxType::U64, a, b, c)
    }
299
300    /// Generic typed `mad.lo` helper.
301    fn mad_lo_typed(&mut self, typ: PtxType, a: Register, b: Register, c: Register) -> Register {
302        let dst = self.regs.alloc(typ);
303        self.emit(Instruction::MadLo {
304            typ,
305            dst: dst.clone(),
306            a: Operand::Register(a),
307            b: Operand::Register(b),
308            c: Operand::Register(c),
309        });
310        dst
311    }
312
    /// Emits `mad.hi.s32 dst, a, b, c` — high 32 bits of `a*b+c` (signed).
    ///
    /// All `mad_hi_*` wrappers delegate to the private `mad_hi_typed`.
    pub fn mad_hi_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::S32, a, b, c)
    }

    /// Emits `mad.hi.u32 dst, a, b, c` — high 32 bits of `a*b+c` (unsigned).
    pub fn mad_hi_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::U32, a, b, c)
    }

    /// Emits `mad.hi.s64 dst, a, b, c` — high 64 bits of `a*b+c` (signed).
    pub fn mad_hi_s64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::S64, a, b, c)
    }

    /// Emits `mad.hi.u64 dst, a, b, c` — high 64 bits of `a*b+c` (unsigned).
    pub fn mad_hi_u64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.mad_hi_typed(PtxType::U64, a, b, c)
    }
332
333    /// Generic typed `mad.hi` helper.
334    fn mad_hi_typed(&mut self, typ: PtxType, a: Register, b: Register, c: Register) -> Register {
335        let dst = self.regs.alloc(typ);
336        self.emit(Instruction::MadHi {
337            typ,
338            dst: dst.clone(),
339            a: Operand::Register(a),
340            b: Operand::Register(b),
341            c: Operand::Register(c),
342        });
343        dst
344    }
345
346    /// Emits `mad.wide.s16 dst, a, b, c` — widening multiply-add, s16 -> s32.
347    pub fn mad_wide_s16(&mut self, a: Register, b: Register, c: Register) -> Register {
348        let dst = self.regs.alloc(PtxType::S32);
349        self.emit(Instruction::MadWide {
350            src_typ: PtxType::S16,
351            dst: dst.clone(),
352            a: Operand::Register(a),
353            b: Operand::Register(b),
354            c: Operand::Register(c),
355        });
356        dst
357    }
358
359    /// Emits `mad.wide.u16 dst, a, b, c` — widening multiply-add, u16 -> u32.
360    pub fn mad_wide_u16(&mut self, a: Register, b: Register, c: Register) -> Register {
361        let dst = self.regs.alloc(PtxType::U32);
362        self.emit(Instruction::MadWide {
363            src_typ: PtxType::U16,
364            dst: dst.clone(),
365            a: Operand::Register(a),
366            b: Operand::Register(b),
367            c: Operand::Register(c),
368        });
369        dst
370    }
371
372    /// Emits `mad.wide.s32 dst, a, b, c` — widening multiply-add, s32 -> s64.
373    pub fn mad_wide_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
374        let dst = self.regs.alloc(PtxType::S64);
375        self.emit(Instruction::MadWide {
376            src_typ: PtxType::S32,
377            dst: dst.clone(),
378            a: Operand::Register(a),
379            b: Operand::Register(b),
380            c: Operand::Register(c),
381        });
382        dst
383    }
384
385    /// Emits `mad.wide.u32 dst, a, b, c` — widening multiply-add, u32 -> u64.
386    pub fn mad_wide_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
387        let dst = self.regs.alloc(PtxType::U64);
388        self.emit(Instruction::MadWide {
389            src_typ: PtxType::U32,
390            dst: dst.clone(),
391            a: Operand::Register(a),
392            b: Operand::Register(b),
393            c: Operand::Register(c),
394        });
395        dst
396    }
397
    // ════════════════════════════════════════════════════════════════════
    //  Floating-Point Arithmetic
    // ════════════════════════════════════════════════════════════════════

    /// Emits `fma.rn.f32 dst, a, b, c` — fused multiply-add, single precision,
    /// round-to-nearest-even.
    pub fn fma_f32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.fma_typed(PtxType::F32, a, b, c)
    }

    /// Emits `fma.rn.f64 dst, a, b, c` — fused multiply-add, double precision,
    /// round-to-nearest-even.
    pub fn fma_f64(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.fma_typed(PtxType::F64, a, b, c)
    }
411
412    /// Generic typed FMA helper with round-to-nearest-even.
413    fn fma_typed(&mut self, ty: PtxType, a: Register, b: Register, c: Register) -> Register {
414        let dst = self.regs.alloc(ty);
415        self.emit(Instruction::Fma {
416            rnd: RoundingMode::Rn,
417            ty,
418            dst: dst.clone(),
419            a: Operand::Register(a),
420            b: Operand::Register(b),
421            c: Operand::Register(c),
422        });
423        dst
424    }
425
426    /// Emits `neg.f32 dst, src`.
427    pub fn neg_f32(&mut self, src: Register) -> Register {
428        let dst = self.regs.alloc(PtxType::F32);
429        self.emit(Instruction::Neg {
430            ty: PtxType::F32,
431            dst: dst.clone(),
432            src: Operand::Register(src),
433        });
434        dst
435    }
436
437    /// Emits `abs.f32 dst, src`.
438    pub fn abs_f32(&mut self, src: Register) -> Register {
439        let dst = self.regs.alloc(PtxType::F32);
440        self.emit(Instruction::Abs {
441            ty: PtxType::F32,
442            dst: dst.clone(),
443            src: Operand::Register(src),
444        });
445        dst
446    }
447
    /// Emits `min.f32 dst, a, b`.
    ///
    /// The `min_*` / `max_*` wrappers delegate to the private
    /// `min_typed` / `max_typed` helpers.
    pub fn min_f32(&mut self, a: Register, b: Register) -> Register {
        self.min_typed(PtxType::F32, a, b)
    }

    /// Emits `max.f32 dst, a, b`.
    pub fn max_f32(&mut self, a: Register, b: Register) -> Register {
        self.max_typed(PtxType::F32, a, b)
    }

    /// Emits `min.u32 dst, a, b`.
    pub fn min_u32(&mut self, a: Register, b: Register) -> Register {
        self.min_typed(PtxType::U32, a, b)
    }

    /// Emits `max.u32 dst, a, b`.
    pub fn max_u32(&mut self, a: Register, b: Register) -> Register {
        self.max_typed(PtxType::U32, a, b)
    }
467
468    /// Generic typed `min` helper.
469    fn min_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
470        let dst = self.regs.alloc(ty);
471        self.emit(Instruction::Min {
472            ty,
473            dst: dst.clone(),
474            a: Operand::Register(a),
475            b: Operand::Register(b),
476        });
477        dst
478    }
479
480    /// Generic typed `max` helper.
481    fn max_typed(&mut self, ty: PtxType, a: Register, b: Register) -> Register {
482        let dst = self.regs.alloc(ty);
483        self.emit(Instruction::Max {
484            ty,
485            dst: dst.clone(),
486            a: Operand::Register(a),
487            b: Operand::Register(b),
488        });
489        dst
490    }
491
    // ════════════════════════════════════════════════════════════════════
    //  Bit Manipulation
    // ════════════════════════════════════════════════════════════════════

    /// Emits `brev.b32 dst, src` — reverse the bits of a 32-bit value.
    ///
    /// The destination register type matches the operand width (`B32`).
    pub fn brev_b32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B32);
        self.emit(Instruction::Brev {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `brev.b64 dst, src` — reverse the bits of a 64-bit value.
    pub fn brev_b64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::B64);
        self.emit(Instruction::Brev {
            ty: PtxType::B64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
517
    /// Emits `clz.b32 dst, src` — count leading zeros (result is U32).
    pub fn clz_b32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Clz {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `popc.b32 dst, src` — population count of 32-bit value (result is U32).
    pub fn popc_b32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Popc {
            ty: PtxType::B32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `popc.b64 dst, src` — population count of 64-bit value.
    ///
    /// The destination is allocated as `U32` even for the 64-bit source,
    /// since the count itself fits in 32 bits.
    pub fn popc_b64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Popc {
            ty: PtxType::B64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
550
    /// Emits `bfind.u32 dst, src` — find most significant bit (unsigned, result is U32).
    pub fn bfind_u32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Bfind {
            ty: PtxType::U32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `bfind.s32 dst, src` — find most significant non-sign bit
    /// (signed variant; result register is still U32).
    pub fn bfind_s32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Bfind {
            ty: PtxType::S32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
572
    /// Emits `bfe.u32 dst, src, start, len` — extract a bit field (unsigned).
    ///
    /// The destination register type follows the instruction type: `U32`
    /// here, `S32` for the signed variant below.
    pub fn bfe_u32(&mut self, src: Register, start: Register, len: Register) -> Register {
        let dst = self.regs.alloc(PtxType::U32);
        self.emit(Instruction::Bfe {
            ty: PtxType::U32,
            dst: dst.clone(),
            src: Operand::Register(src),
            start: Operand::Register(start),
            len: Operand::Register(len),
        });
        dst
    }

    /// Emits `bfe.s32 dst, src, start, len` — extract a bit field (signed).
    pub fn bfe_s32(&mut self, src: Register, start: Register, len: Register) -> Register {
        let dst = self.regs.alloc(PtxType::S32);
        self.emit(Instruction::Bfe {
            ty: PtxType::S32,
            dst: dst.clone(),
            src: Operand::Register(src),
            start: Operand::Register(start),
            len: Operand::Register(len),
        });
        dst
    }
598
599    /// Emits `bfi.b32 dst, insert, base, start, len` — insert a bit field.
600    pub fn bfi_b32(
601        &mut self,
602        insert: Register,
603        base: Register,
604        start: Register,
605        len: Register,
606    ) -> Register {
607        let dst = self.regs.alloc(PtxType::B32);
608        self.emit(Instruction::Bfi {
609            ty: PtxType::B32,
610            dst: dst.clone(),
611            insert: Operand::Register(insert),
612            base: Operand::Register(base),
613            start: Operand::Register(start),
614            len: Operand::Register(len),
615        });
616        dst
617    }
618
619    // ════════════════════════════════════════════════════════════════════
620    //  Shift Operations
621    // ════════════════════════════════════════════════════════════════════
622
623    /// Emits `shl.b32 dst, src, amount` — left shift, 32-bit.
624    pub fn shl_b32(&mut self, src: Register, amount: Register) -> Register {
625        let dst = self.regs.alloc(PtxType::B32);
626        self.emit(Instruction::Shl {
627            ty: PtxType::B32,
628            dst: dst.clone(),
629            src: Operand::Register(src),
630            amount: Operand::Register(amount),
631        });
632        dst
633    }
634
635    /// Emits `shl.b64 dst, src, amount` — left shift, 64-bit.
636    pub fn shl_b64(&mut self, src: Register, amount: Register) -> Register {
637        let dst = self.regs.alloc(PtxType::B64);
638        self.emit(Instruction::Shl {
639            ty: PtxType::B64,
640            dst: dst.clone(),
641            src: Operand::Register(src),
642            amount: Operand::Register(amount),
643        });
644        dst
645    }
646
647    /// Emits `shr.b32 dst, src, amount` — logical right shift, 32-bit.
648    pub fn shr_b32(&mut self, src: Register, amount: Register) -> Register {
649        let dst = self.regs.alloc(PtxType::B32);
650        self.emit(Instruction::Shr {
651            ty: PtxType::B32,
652            dst: dst.clone(),
653            src: Operand::Register(src),
654            amount: Operand::Register(amount),
655        });
656        dst
657    }
658
659    /// Emits `shr.b64 dst, src, amount` — logical right shift, 64-bit.
660    pub fn shr_b64(&mut self, src: Register, amount: Register) -> Register {
661        let dst = self.regs.alloc(PtxType::B64);
662        self.emit(Instruction::Shr {
663            ty: PtxType::B64,
664            dst: dst.clone(),
665            src: Operand::Register(src),
666            amount: Operand::Register(amount),
667        });
668        dst
669    }
670
671    /// Emits `shr.u32 dst, src, amount` — logical right shift for unsigned 32-bit.
672    pub fn shr_u32(&mut self, src: Register, amount: Register) -> Register {
673        let dst = self.regs.alloc(PtxType::U32);
674        self.emit(Instruction::Shr {
675            ty: PtxType::U32,
676            dst: dst.clone(),
677            src: Operand::Register(src),
678            amount: Operand::Register(amount),
679        });
680        dst
681    }
682
683    /// Emits `shr.s32 dst, src, amount` — arithmetic right shift for signed 32-bit.
684    pub fn shr_s32(&mut self, src: Register, amount: Register) -> Register {
685        let dst = self.regs.alloc(PtxType::S32);
686        self.emit(Instruction::Shr {
687            ty: PtxType::S32,
688            dst: dst.clone(),
689            src: Operand::Register(src),
690            amount: Operand::Register(amount),
691        });
692        dst
693    }
694
    // ════════════════════════════════════════════════════════════════════
    //  Special Math Functions
    // ════════════════════════════════════════════════════════════════════

    /// Emits `rcp.rn.f32 dst, src` — reciprocal, single precision,
    /// round-to-nearest-even.
    pub fn rcp_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Rcp {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `rcp.rn.f64 dst, src` — reciprocal, double precision.
    pub fn rcp_f64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F64);
        self.emit(Instruction::Rcp {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `rcp.approx.ftz.f32 dst, src` — fast approximate reciprocal.
    ///
    /// `rnd: None` is the IR's encoding for approx mode (no IEEE rounding
    /// mode applies) — contrast with the `Some(Rn)` variants above.
    pub fn rcp_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Rcp {
            rnd: None,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
736
    /// Emits `rsqrt.approx.f32 dst, src` — approximate reciprocal square root.
    pub fn rsqrt_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Rsqrt {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `rsqrt.approx.f64 dst, src` — approximate reciprocal square root,
    /// double precision.
    pub fn rsqrt_approx_f64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F64);
        self.emit(Instruction::Rsqrt {
            approx: true,
            ty: PtxType::F64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
760
    /// Emits `sqrt.rn.f32 dst, src` — square root, single precision,
    /// round-to-nearest-even.
    pub fn sqrt_rn_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Sqrt {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `sqrt.rn.f64 dst, src` — square root, double precision.
    pub fn sqrt_rn_f64(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F64);
        self.emit(Instruction::Sqrt {
            rnd: Some(RoundingMode::Rn),
            ty: PtxType::F64,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
784
    /// Emits `ex2.approx.f32 dst, src` — base-2 exponential, approximate.
    ///
    /// The four approximate transcendentals below share the same shape:
    /// allocate an `F32` destination, emit with `approx: true`, return it.
    pub fn ex2_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Ex2 {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `lg2.approx.f32 dst, src` — base-2 logarithm, approximate.
    pub fn lg2_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Lg2 {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `sin.approx.f32 dst, src` — sine, approximate.
    pub fn sin_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Sin {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }

    /// Emits `cos.approx.f32 dst, src` — cosine, approximate.
    pub fn cos_approx_f32(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::F32);
        self.emit(Instruction::Cos {
            approx: true,
            ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
832
833    // ════════════════════════════════════════════════════════════════════
834    //  Memory Operations — Global
835    // ════════════════════════════════════════════════════════════════════
836
837    /// Loads a single `f32` from global memory.
838    ///
839    /// `addr` should be a `U64` register containing the global device pointer.
840    /// Emits `ld.global.f32 dst, [addr]`.
841    pub fn load_global_f32(&mut self, addr: Register) -> Register {
842        self.load_global_scalar(PtxType::F32, addr)
843    }
844
845    /// Loads a single `f64` from global memory.
846    pub fn load_global_f64(&mut self, addr: Register) -> Register {
847        self.load_global_scalar(PtxType::F64, addr)
848    }
849
    /// Loads a single signed 32-bit integer from global memory.
    ///
    /// Emits `ld.global.s32 dst, [addr]`; `addr` is a `U64` device pointer.
    pub fn load_global_i32(&mut self, addr: Register) -> Register {
        self.load_global_scalar(PtxType::S32, addr)
    }
856
    /// Loads a single unsigned 32-bit integer from global memory.
    ///
    /// Emits `ld.global.u32 dst, [addr]`; `addr` is a `U64` device pointer.
    pub fn load_global_u32(&mut self, addr: Register) -> Register {
        self.load_global_scalar(PtxType::U32, addr)
    }
863
864    /// Loads a single scalar value from global memory.
865    fn load_global_scalar(&mut self, ty: PtxType, addr: Register) -> Register {
866        let dst = self.regs.alloc(ty);
867        self.emit(Instruction::Load {
868            space: MemorySpace::Global,
869            qualifier: CacheQualifier::None,
870            vec: VectorWidth::V1,
871            ty,
872            dst: dst.clone(),
873            addr: Operand::Address {
874                base: addr,
875                offset: None,
876            },
877        });
878        dst
879    }
880
881    /// Loads four `f32` values from global memory as a vectorized `.v4` load.
882    ///
883    /// Returns an array of 4 registers containing the loaded values.
884    /// `addr` must be 16-byte aligned for correctness.
885    ///
886    /// Since the IR `Load` instruction uses a single destination register,
887    /// this method emits raw PTX for the vectorized load and individual
888    /// `mov` instructions to extract each element.
889    pub fn load_global_f32x4(&mut self, addr: &Register) -> [Register; 4] {
890        let r0 = self.regs.alloc(PtxType::F32);
891        let r1 = self.regs.alloc(PtxType::F32);
892        let r2 = self.regs.alloc(PtxType::F32);
893        let r3 = self.regs.alloc(PtxType::F32);
894        self.emit(Instruction::Raw(format!(
895            "ld.global.v4.f32 {{{r0}, {r1}, {r2}, {r3}}}, [{addr}];"
896        )));
897        [r0, r1, r2, r3]
898    }
899
    /// Stores a single `f32` to global memory.
    ///
    /// `addr` should be a `U64` register containing the global device pointer.
    /// Emits `st.global.f32 [addr], val` via [`Self::store_global_scalar`].
    pub fn store_global_f32(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::F32, addr, val);
    }
906
    /// Stores a single `f64` to global memory.
    ///
    /// Emits `st.global.f64 [addr], val`; `addr` is a `U64` device pointer.
    pub fn store_global_f64(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::F64, addr, val);
    }
911
    /// Stores a single signed 32-bit integer to global memory.
    ///
    /// Emits `st.global.s32 [addr], val`; `addr` is a `U64` device pointer.
    pub fn store_global_i32(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::S32, addr, val);
    }
918
    /// Stores a single unsigned 32-bit integer to global memory.
    ///
    /// Emits `st.global.u32 [addr], val`; `addr` is a `U64` device pointer.
    pub fn store_global_u32(&mut self, addr: Register, val: Register) {
        self.store_global_scalar(PtxType::U32, addr, val);
    }
925
926    /// Stores a single scalar to global memory.
927    fn store_global_scalar(&mut self, ty: PtxType, addr: Register, val: Register) {
928        self.emit(Instruction::Store {
929            space: MemorySpace::Global,
930            qualifier: CacheQualifier::None,
931            vec: VectorWidth::V1,
932            ty,
933            addr: Operand::Address {
934                base: addr,
935                offset: None,
936            },
937            src: val,
938        });
939    }
940
941    // ════════════════════════════════════════════════════════════════════
942    //  Memory Operations — Shared
943    // ════════════════════════════════════════════════════════════════════
944
945    /// Loads a single `f32` from shared memory.
946    ///
947    /// `addr` should be a register containing an address in shared memory space.
948    pub fn load_shared_f32(&mut self, addr: Register) -> Register {
949        let dst = self.regs.alloc(PtxType::F32);
950        self.emit(Instruction::Load {
951            space: MemorySpace::Shared,
952            qualifier: CacheQualifier::None,
953            vec: VectorWidth::V1,
954            ty: PtxType::F32,
955            dst: dst.clone(),
956            addr: Operand::Address {
957                base: addr,
958                offset: None,
959            },
960        });
961        dst
962    }
963
964    /// Stores a single `f32` to shared memory.
965    pub fn store_shared_f32(&mut self, addr: Register, val: Register) {
966        self.emit(Instruction::Store {
967            space: MemorySpace::Shared,
968            qualifier: CacheQualifier::None,
969            vec: VectorWidth::V1,
970            ty: PtxType::F32,
971            addr: Operand::Address {
972                base: addr,
973                offset: None,
974            },
975            src: val,
976        });
977    }
978
979    // ════════════════════════════════════════════════════════════════════
980    //  Asynchronous Copy (cp.async, Ampere+)
981    // ════════════════════════════════════════════════════════════════════
982
983    /// Emits a 32-bit (4-byte) asynchronous copy from global to shared memory.
984    ///
985    /// Emits: `cp.async.ca.shared.global [dst], [src], 4;`
986    /// Requires `sm_80`+.
987    pub fn cp_async_32bit(&mut self, dst_shared: Register, src_global: Register) {
988        self.emit(Instruction::CpAsync {
989            bytes: 4,
990            dst_shared: Operand::Register(dst_shared),
991            src_global: Operand::Register(src_global),
992        });
993    }
994
995    /// Emits a 64-bit (8-byte) asynchronous copy from global to shared memory.
996    ///
997    /// Emits: `cp.async.ca.shared.global [dst], [src], 8;`
998    /// Requires `sm_80`+.
999    pub fn cp_async_64bit(&mut self, dst_shared: Register, src_global: Register) {
1000        self.emit(Instruction::CpAsync {
1001            bytes: 8,
1002            dst_shared: Operand::Register(dst_shared),
1003            src_global: Operand::Register(src_global),
1004        });
1005    }
1006
1007    /// Emits a 128-bit (16-byte) asynchronous copy from global to shared memory.
1008    ///
1009    /// This is the most common `cp.async` variant, used for double-buffered
1010    /// data loading in high-performance kernels. Requires `sm_80`+.
1011    pub fn cp_async_128bit(&mut self, dst_shared: Register, src_global: Register) {
1012        self.emit(Instruction::CpAsync {
1013            bytes: 16,
1014            dst_shared: Operand::Register(dst_shared),
1015            src_global: Operand::Register(src_global),
1016        });
1017    }
1018
    /// Emits `cp.async.commit_group`, committing all pending async copies
    /// issued so far into one group (waitable via [`Self::cp_async_wait`]).
    pub fn cp_async_commit(&mut self) {
        self.emit(Instruction::CpAsyncCommit);
    }
1023
    /// Emits `cp.async.wait_group N` to wait until at most `n` committed
    /// copy groups are still pending.
    ///
    /// Pass `0` to wait for all pending copies to complete.
    pub fn cp_async_wait(&mut self, n: u32) {
        self.emit(Instruction::CpAsyncWait { n });
    }
1031
1032    /// Emits a `ldmatrix.sync.aligned.m8n8.x4.shared.b16` instruction (SM >= 75).
1033    ///
1034    /// Loads 4 warp-cooperative 8×8 B16 matrix fragments from shared memory.
1035    /// Each of the 32 threads contributes to loading 8 bytes (one row) of
1036    /// the tile. Returns the four destination registers.
1037    ///
1038    /// # Errors
1039    ///
1040    /// Returns [`PtxGenError`] if the target architecture does not support
1041    /// `ldmatrix` (requires SM >= 75).
1042    pub fn ldmatrix_x4(&mut self, src_addr: Register) -> Result<[Register; 4], PtxGenError> {
1043        use crate::ir::Instruction as I;
1044        if !self.target.capabilities().has_ldmatrix {
1045            return Err(PtxGenError::UnsupportedFeature {
1046                arch: self.target.as_ptx_str().to_string(),
1047                feature: "ldmatrix (SM >= 75)".to_string(),
1048            });
1049        }
1050        let r0 = self.regs.alloc(PtxType::B32);
1051        let r1 = self.regs.alloc(PtxType::B32);
1052        let r2 = self.regs.alloc(PtxType::B32);
1053        let r3 = self.regs.alloc(PtxType::B32);
1054        self.emit(I::Ldmatrix {
1055            num_fragments: 4,
1056            trans: false,
1057            dst_regs: vec![r0.clone(), r1.clone(), r2.clone(), r3.clone()],
1058            src_addr: Operand::Register(src_addr),
1059        });
1060        Ok([r0, r1, r2, r3])
1061    }
1062
1063    // ════════════════════════════════════════════════════════════════════
1064    //  Control Flow
1065    // ════════════════════════════════════════════════════════════════════
1066
1067    /// Emits a conditional block that executes `body` when `a < b` (unsigned 32-bit).
1068    ///
1069    /// Generates a `setp.lo.u32` comparison, a negated conditional branch
1070    /// over the body, and a skip label.
1071    ///
1072    /// # Example
1073    ///
1074    /// ```ignore
1075    /// b.if_lt_u32(tid, n, |b| {
1076    ///     // Only threads with tid < n execute this
1077    /// });
1078    /// ```
1079    pub fn if_lt_u32<F>(&mut self, a: Register, b: Register, body: F)
1080    where
1081        F: FnOnce(&mut BodyBuilder<'_>),
1082    {
1083        let pred = self.regs.alloc(PtxType::Pred);
1084        self.emit(Instruction::SetP {
1085            cmp: CmpOp::Lo,
1086            ty: PtxType::U32,
1087            dst: pred.clone(),
1088            a: Operand::Register(a),
1089            b: Operand::Register(b),
1090        });
1091        let skip_label = self.fresh_label("skip");
1092        // Branch to skip when predicate is false (negate = true).
1093        self.emit(Instruction::Branch {
1094            target: skip_label.clone(),
1095            predicate: Some((pred, true)),
1096        });
1097        body(self);
1098        self.emit(Instruction::Label(skip_label));
1099    }
1100
1101    /// Emits a conditional block that executes `body` when `a >= b` (unsigned 32-bit).
1102    pub fn if_ge_u32<F>(&mut self, a: Register, b: Register, body: F)
1103    where
1104        F: FnOnce(&mut BodyBuilder<'_>),
1105    {
1106        let pred = self.regs.alloc(PtxType::Pred);
1107        self.emit(Instruction::SetP {
1108            cmp: CmpOp::Hs,
1109            ty: PtxType::U32,
1110            dst: pred.clone(),
1111            a: Operand::Register(a),
1112            b: Operand::Register(b),
1113        });
1114        let skip_label = self.fresh_label("skip");
1115        self.emit(Instruction::Branch {
1116            target: skip_label.clone(),
1117            predicate: Some((pred, true)),
1118        });
1119        body(self);
1120        self.emit(Instruction::Label(skip_label));
1121    }
1122
1123    /// Compile-time loop unrolling.
1124    ///
1125    /// Calls `body(i)` for `i` in `0..count`, emitting all iterations
1126    /// inline. This is equivalent to `#pragma unroll` in CUDA C.
1127    ///
1128    /// Each iteration gets its own comment indicating the unroll index.
1129    pub fn unroll<F>(&mut self, count: u32, mut body: F)
1130    where
1131        F: FnMut(&mut BodyBuilder<'_>, u32),
1132    {
1133        for i in 0..count {
1134            self.comment(&format!("unroll iteration {i}/{count}"));
1135            body(self, i);
1136        }
1137    }
1138
1139    /// Emits a `.pragma "unroll N"` or `.pragma "nounroll"` directive hint.
1140    ///
1141    /// When `factor` is `Some(n)`, emits `.pragma "unroll N";` to hint the
1142    /// PTX assembler to unroll the following loop by factor `n`.
1143    /// When `factor` is `None`, emits `.pragma "nounroll";` to suppress
1144    /// unrolling.
1145    pub fn pragma_unroll(&mut self, factor: Option<u32>) {
1146        let text = factor.map_or_else(|| "nounroll".to_string(), |n| format!("unroll {n}"));
1147        self.emit(Instruction::Pragma(text));
1148    }
1149
    /// Emits a label pseudo-instruction.
    ///
    /// Labels are branch targets; in the generated PTX they appear at the
    /// start of a line without indentation. Prefer [`Self::fresh_label`]
    /// for generated names to avoid collisions.
    pub fn label(&mut self, name: &str) {
        self.emit(Instruction::Label(name.to_string()));
    }
1157
1158    /// Emits an unconditional branch to the given label.
1159    pub fn branch(&mut self, target: &str) {
1160        self.emit(Instruction::Branch {
1161            target: target.to_string(),
1162            predicate: None,
1163        });
1164    }
1165
1166    /// Emits a conditional branch: `@pred bra target`.
1167    pub fn branch_if(&mut self, pred: Register, target: &str) {
1168        self.emit(Instruction::Branch {
1169            target: target.to_string(),
1170            predicate: Some((pred, false)),
1171        });
1172    }
1173
    /// Emits a `ret` instruction to return from the kernel.
    pub fn ret(&mut self) {
        self.emit(Instruction::Return);
    }
1178
1179    // ════════════════════════════════════════════════════════════════════
1180    //  Synchronization
1181    // ════════════════════════════════════════════════════════════════════
1182
    /// Emits `bar.sync id` — block-level barrier synchronization.
    ///
    /// All threads in the block must reach this barrier before any can
    /// proceed. `id` selects the hardware barrier and is typically 0.
    pub fn bar_sync(&mut self, id: u32) {
        self.emit(Instruction::BarSync { id });
    }
1190
    /// Emits a memory fence with acquire-release semantics at the given scope.
    ///
    /// - [`FenceScope::Cta`]: visibility within the block
    /// - [`FenceScope::Gpu`]: visibility across the entire GPU
    /// - [`FenceScope::Sys`]: visibility across GPU and host
    pub fn fence_acq_rel(&mut self, scope: FenceScope) {
        self.emit(Instruction::FenceAcqRel { scope });
    }
1199
1200    // ════════════════════════════════════════════════════════════════════
1201    //  Type Conversion
1202    // ════════════════════════════════════════════════════════════════════
1203
1204    /// Converts a `u32` register to `u64` (zero-extension).
1205    ///
1206    /// Emits `cvt.u64.u32 dst, src`.
1207    pub fn cvt_u32_to_u64(&mut self, src: Register) -> Register {
1208        let dst = self.regs.alloc(PtxType::U64);
1209        self.emit(Instruction::Cvt {
1210            rnd: None,
1211            dst_ty: PtxType::U64,
1212            src_ty: PtxType::U32,
1213            dst: dst.clone(),
1214            src: Operand::Register(src),
1215        });
1216        dst
1217    }
1218
1219    /// Converts an `f32` register to `f64` (widening).
1220    ///
1221    /// Emits `cvt.f64.f32 dst, src`.
1222    pub fn cvt_f32_to_f64(&mut self, src: Register) -> Register {
1223        let dst = self.regs.alloc(PtxType::F64);
1224        self.emit(Instruction::Cvt {
1225            rnd: None,
1226            dst_ty: PtxType::F64,
1227            src_ty: PtxType::F32,
1228            dst: dst.clone(),
1229            src: Operand::Register(src),
1230        });
1231        dst
1232    }
1233
1234    /// Converts an `f64` register to `f32` (narrowing, round-to-nearest-even).
1235    ///
1236    /// Emits `cvt.rn.f32.f64 dst, src`.
1237    pub fn cvt_f64_to_f32(&mut self, src: Register) -> Register {
1238        let dst = self.regs.alloc(PtxType::F32);
1239        self.emit(Instruction::Cvt {
1240            rnd: Some(RoundingMode::Rn),
1241            dst_ty: PtxType::F32,
1242            src_ty: PtxType::F64,
1243            dst: dst.clone(),
1244            src: Operand::Register(src),
1245        });
1246        dst
1247    }
1248
1249    /// Converts an `f16` register to `f32` (widening).
1250    ///
1251    /// Emits `cvt.f32.f16 dst, src`.
1252    pub fn cvt_f16_to_f32(&mut self, src: Register) -> Register {
1253        let dst = self.regs.alloc(PtxType::F32);
1254        self.emit(Instruction::Cvt {
1255            rnd: None,
1256            dst_ty: PtxType::F32,
1257            src_ty: PtxType::F16,
1258            dst: dst.clone(),
1259            src: Operand::Register(src),
1260        });
1261        dst
1262    }
1263
1264    /// Converts an `f32` register to `f16` (narrowing, round-to-nearest-even).
1265    ///
1266    /// Emits `cvt.rn.f16.f32 dst, src`.
1267    pub fn cvt_f32_to_f16(&mut self, src: Register) -> Register {
1268        let dst = self.regs.alloc(PtxType::F16);
1269        self.emit(Instruction::Cvt {
1270            rnd: Some(RoundingMode::Rn),
1271            dst_ty: PtxType::F16,
1272            src_ty: PtxType::F32,
1273            dst: dst.clone(),
1274            src: Operand::Register(src),
1275        });
1276        dst
1277    }
1278
1279    /// Converts a `bf16` register to `f32` (widening).
1280    ///
1281    /// Emits `cvt.f32.bf16 dst, src`.
1282    pub fn cvt_bf16_to_f32(&mut self, src: Register) -> Register {
1283        let dst = self.regs.alloc(PtxType::F32);
1284        self.emit(Instruction::Cvt {
1285            rnd: None,
1286            dst_ty: PtxType::F32,
1287            src_ty: PtxType::BF16,
1288            dst: dst.clone(),
1289            src: Operand::Register(src),
1290        });
1291        dst
1292    }
1293
1294    /// Converts an `f32` register to `bf16` (narrowing, round-to-nearest-even).
1295    ///
1296    /// Emits `cvt.rn.bf16.f32 dst, src`.
1297    pub fn cvt_f32_to_bf16(&mut self, src: Register) -> Register {
1298        let dst = self.regs.alloc(PtxType::BF16);
1299        self.emit(Instruction::Cvt {
1300            rnd: Some(RoundingMode::Rn),
1301            dst_ty: PtxType::BF16,
1302            src_ty: PtxType::F32,
1303            dst: dst.clone(),
1304            src: Operand::Register(src),
1305        });
1306        dst
1307    }
1308
    /// Converts an `f32` register to FP8 `E4M3` format (`sm_89+`, Ada/Hopper).
    ///
    /// Builds a structured `Cvt` with `rnd = .rn`, `dst_ty = e4m3`,
    /// `src_ty = f32` — a single-source conversion as modeled by the IR.
    ///
    /// NOTE(review): hardware PTX defines FP8 conversion as the packed
    /// two-source form `cvt.rn.satfinite.e4m3x2.f32 d, a_hi, a_lo` (two FP8
    /// values per register) — confirm how the emitter lowers this
    /// single-source `Cvt` before relying on the exact output text.
    pub fn cvt_f32_to_e4m3(&mut self, src: Register) -> Register {
        let dst = self.regs.alloc(PtxType::E4M3);
        self.emit(Instruction::Cvt {
            rnd: Some(RoundingMode::Rn),
            dst_ty: PtxType::E4M3,
            src_ty: PtxType::F32,
            dst: dst.clone(),
            src: Operand::Register(src),
        });
        dst
    }
1324
1325    /// Converts an FP8 `E4M3` register to `f32` (`sm_89+`).
1326    ///
1327    /// Emits `cvt.f32.e4m3 dst, src`.
1328    pub fn cvt_e4m3_to_f32(&mut self, src: Register) -> Register {
1329        let dst = self.regs.alloc(PtxType::F32);
1330        self.emit(Instruction::Cvt {
1331            rnd: None,
1332            dst_ty: PtxType::F32,
1333            src_ty: PtxType::E4M3,
1334            dst: dst.clone(),
1335            src: Operand::Register(src),
1336        });
1337        dst
1338    }
1339
1340    /// Converts an `f32` register to FP8 `E5M2` format (`sm_89+`).
1341    ///
1342    /// Emits `cvt.rn.e5m2.f32 dst, src`.
1343    pub fn cvt_f32_to_e5m2(&mut self, src: Register) -> Register {
1344        let dst = self.regs.alloc(PtxType::E5M2);
1345        self.emit(Instruction::Cvt {
1346            rnd: Some(RoundingMode::Rn),
1347            dst_ty: PtxType::E5M2,
1348            src_ty: PtxType::F32,
1349            dst: dst.clone(),
1350            src: Operand::Register(src),
1351        });
1352        dst
1353    }
1354
1355    /// Converts an FP8 `E5M2` register to `f32` (`sm_89+`).
1356    ///
1357    /// Emits `cvt.f32.e5m2 dst, src`.
1358    pub fn cvt_e5m2_to_f32(&mut self, src: Register) -> Register {
1359        let dst = self.regs.alloc(PtxType::F32);
1360        self.emit(Instruction::Cvt {
1361            rnd: None,
1362            dst_ty: PtxType::F32,
1363            src_ty: PtxType::E5M2,
1364            dst: dst.clone(),
1365            src: Operand::Register(src),
1366        });
1367        dst
1368    }
1369
1370    // ════════════════════════════════════════════════════════════════════
1371    //  Tensor Core (Ampere+ MMA)
1372    // ════════════════════════════════════════════════════════════════════
1373
1374    /// Emits an `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32` instruction.
1375    ///
1376    /// This is the standard Ampere tensor core MMA operation:
1377    /// - **Shape**: 16x8x16 tile
1378    /// - **A fragment**: registers holding f16 matrix A data
1379    /// - **B fragment**: registers holding f16 matrix B data
1380    /// - **C/D accumulator**: 4 f32 registers for input/output accumulator
1381    ///
1382    /// Returns the 4 destination accumulator registers.
1383    ///
1384    /// # Arguments
1385    ///
1386    /// * `a_regs` — Registers holding the A matrix fragment (f16)
1387    /// * `b_regs` — Registers holding the B matrix fragment (f16)
1388    /// * `c_regs` — Registers holding the C accumulator input (f32)
1389    pub fn mma_m16n8k16_f16_f32(
1390        &mut self,
1391        a_regs: &[Register],
1392        b_regs: &[Register],
1393        c_regs: &[Register],
1394    ) -> [Register; 4] {
1395        let dst = self.regs.alloc_group(PtxType::F32, 4);
1396        self.emit(Instruction::Mma {
1397            shape: MmaShape::M16N8K16,
1398            a_ty: PtxType::F16,
1399            b_ty: PtxType::F16,
1400            c_ty: PtxType::F32,
1401            d_ty: PtxType::F32,
1402            d_regs: dst.clone(),
1403            a_regs: a_regs.to_vec(),
1404            b_regs: b_regs.to_vec(),
1405            c_regs: c_regs.to_vec(),
1406        });
1407        [
1408            dst[0].clone(),
1409            dst[1].clone(),
1410            dst[2].clone(),
1411            dst[3].clone(),
1412        ]
1413    }
1414
1415    /// Emits `wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16`
1416    /// Warpgroup MMA async for Hopper (`sm_90+`) computing a 64×128 tile.
1417    ///
1418    /// This operates on warpgroup-level fragments:
1419    /// - `a_desc`: A operand descriptor (shared memory descriptor string)
1420    /// - `b_desc`: B operand descriptor
1421    /// - Accumulator: 64 f32 registers (managed by the warpgroup implicitly)
1422    ///
1423    /// Emits raw PTX via `raw_ptx` since wgmma is not yet in the structured IR.
1424    ///
1425    /// # Errors
1426    ///
1427    /// Returns `PtxGenError` when the target SM is below 90 (Hopper).
1428    pub fn wgmma_mma_async_m64n128k16_f16(
1429        &mut self,
1430        a_desc: &str,
1431        b_desc: &str,
1432    ) -> Result<(), PtxGenError> {
1433        if !self.target.capabilities().has_wgmma {
1434            return Err(PtxGenError::GenerationFailed(format!(
1435                "wgmma.mma_async requires SM >= 90 (Hopper), target is {}",
1436                self.target
1437            )));
1438        }
1439        self.raw_ptx(&format!(
1440            "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {{...}}, {a_desc}, {b_desc}, 1, 1, 1, 0, 0;"
1441        ));
1442        Ok(())
1443    }
1444
1445    // ════════════════════════════════════════════════════════════════════
1446    //  Video Instructions (dp4a / dp2a)
1447    // ════════════════════════════════════════════════════════════════════
1448
    /// Emits `dp4a.u32.u32 dst, a, b, c` — 4-way byte dot product with all
    /// packed bytes treated as unsigned; delegates to `dp4a_typed`.
    pub fn dp4a_u32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, false, false)
    }
1453
    /// Emits `dp4a.s32.s32 dst, a, b, c` — 4-way byte dot product with all
    /// packed bytes treated as signed; delegates to `dp4a_typed`.
    pub fn dp4a_s32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, true, true)
    }
1458
    /// Emits `dp4a.s32.u32 dst, a, b, c` — mixed 4-way byte dot product:
    /// `a` bytes signed, `b` bytes unsigned; delegates to `dp4a_typed`.
    pub fn dp4a_s32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, true, false)
    }
1463
    /// Emits `dp4a.u32.s32 dst, a, b, c` — mixed 4-way byte dot product:
    /// `a` bytes unsigned, `b` bytes signed; delegates to `dp4a_typed`.
    pub fn dp4a_u32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp4a_typed(a, b, c, false, true)
    }
1468
1469    /// Generic dp4a helper.
1470    fn dp4a_typed(
1471        &mut self,
1472        a: Register,
1473        b: Register,
1474        c: Register,
1475        signed_a: bool,
1476        signed_b: bool,
1477    ) -> Register {
1478        let dst = self.regs.alloc(PtxType::S32);
1479        self.emit(Instruction::Dp4a {
1480            dst: dst.clone(),
1481            a: Operand::Register(a),
1482            b: Operand::Register(b),
1483            c: Operand::Register(c),
1484            signed_a,
1485            signed_b,
1486        });
1487        dst
1488    }
1489
    /// Emits `dp2a.lo.u32.u32 dst, a, b, c` — unsigned 2-way dot product
    /// using the low half of `b`; delegates to `dp2a_typed`.
    pub fn dp2a_lo_u32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, false, false, true)
    }
1494
    /// Emits `dp2a.hi.u32.u32 dst, a, b, c` — unsigned 2-way dot product
    /// using the high half of `b`; delegates to `dp2a_typed`.
    pub fn dp2a_hi_u32_u32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, false, false, false)
    }
1499
    /// Emits `dp2a.lo.s32.s32 dst, a, b, c` — signed 2-way dot product
    /// using the low half of `b`; delegates to `dp2a_typed`.
    pub fn dp2a_lo_s32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, true, true, true)
    }
1504
    /// Emits `dp2a.hi.s32.s32 dst, a, b, c` — signed 2-way dot product
    /// using the high half of `b`; delegates to `dp2a_typed`.
    pub fn dp2a_hi_s32_s32(&mut self, a: Register, b: Register, c: Register) -> Register {
        self.dp2a_typed(a, b, c, true, true, false)
    }
1509
1510    /// Generic dp2a helper.
1511    fn dp2a_typed(
1512        &mut self,
1513        a: Register,
1514        b: Register,
1515        c: Register,
1516        signed_a: bool,
1517        signed_b: bool,
1518        lo: bool,
1519    ) -> Register {
1520        let dst = self.regs.alloc(PtxType::S32);
1521        self.emit(Instruction::Dp2a {
1522            dst: dst.clone(),
1523            a: Operand::Register(a),
1524            b: Operand::Register(b),
1525            c: Operand::Register(c),
1526            signed_a,
1527            signed_b,
1528            lo,
1529        });
1530        dst
1531    }
1532
1533    // ════════════════════════════════════════════════════════════════════
1534    //  Immediate Value Helpers
1535    // ════════════════════════════════════════════════════════════════════
1536
    /// Creates an unsigned 32-bit immediate operand (no instruction emitted).
    #[must_use]
    pub const fn imm_u32(&self, val: u32) -> Operand {
        Operand::Immediate(ImmValue::U32(val))
    }
1542
1543    /// Loads an unsigned 32-bit immediate into a new register via `add.u32 dst, 0, val`.
1544    pub fn mov_imm_u32(&mut self, val: u32) -> Register {
1545        let dst = self.regs.alloc(PtxType::U32);
1546        self.emit(Instruction::Add {
1547            ty: PtxType::U32,
1548            dst: dst.clone(),
1549            a: Operand::Immediate(ImmValue::U32(0)),
1550            b: Operand::Immediate(ImmValue::U32(val)),
1551        });
1552        dst
1553    }
1554
    /// Creates an unsigned 64-bit immediate operand (no instruction emitted).
    #[must_use]
    pub const fn imm_u64(&self, val: u64) -> Operand {
        Operand::Immediate(ImmValue::U64(val))
    }
1560
    /// Creates a 32-bit floating-point immediate operand (no instruction emitted).
    #[must_use]
    pub const fn imm_f32(&self, val: f32) -> Operand {
        Operand::Immediate(ImmValue::F32(val))
    }
1566
    /// Creates a 64-bit floating-point immediate operand (no instruction emitted).
    #[must_use]
    pub const fn imm_f64(&self, val: f64) -> Operand {
        Operand::Immediate(ImmValue::F64(val))
    }
1572
1573    // ════════════════════════════════════════════════════════════════════
1574    //  Miscellaneous / Escape Hatches
1575    // ════════════════════════════════════════════════════════════════════
1576
    /// Emits a comment into the PTX output (for debugging / readability).
    pub fn comment(&mut self, text: &str) {
        self.emit(Instruction::Comment(text.to_string()));
    }
1581
    /// Emits raw PTX text verbatim — an escape hatch for instructions not
    /// yet modeled in the IR.
    ///
    /// Before emitting, scans `text` for named registers and auto-declares
    /// them based on their prefix:
    /// - `%f_*`  → `.reg .f32`
    /// - `%rd_*` → `.reg .b64`
    /// - `%r_*`  → `.reg .b32`
    /// - `%p_*`  → `.reg .pred`
    ///
    /// Names without an underscore (e.g. allocator-generated `%f3`) are
    /// left alone, as are underscore names with an unrecognized prefix.
    pub fn raw_ptx(&mut self, text: &str) {
        // Auto-declare named registers found in the raw text.
        let mut i = 0;
        let bytes = text.as_bytes();
        while i < bytes.len() {
            if bytes[i] == b'%' {
                let start = i;
                i += 1;
                // Consume the identifier: ASCII alphanumerics + underscore.
                // Only ASCII bytes are consumed, so `text[start..i]` below
                // is always a valid char-boundary slice.
                while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                    i += 1;
                }
                let name = &text[start..i];
                // Only declare names containing an underscore (custom names);
                // plain `%r1`-style names come from the allocator.
                if name.contains('_') {
                    let ty = if name.starts_with("%rd_") {
                        PtxType::B64
                    } else if name.starts_with("%f_") {
                        PtxType::F32
                    } else if name.starts_with("%p_") {
                        PtxType::Pred
                    } else if name.starts_with("%r_") {
                        PtxType::B32
                    } else {
                        // Unknown prefix: skip it (i is already past the name).
                        continue;
                    };
                    self.regs.declare_named(name, ty);
                }
            } else {
                i += 1;
            }
        }
        self.emit(Instruction::Raw(text.to_string()));
    }
1625
1626    // ════════════════════════════════════════════════════════════════════
1627    //  Address Computation Helpers
1628    // ════════════════════════════════════════════════════════════════════
1629
1630    /// Computes a byte offset address: `base + index * stride`.
1631    ///
1632    /// Useful for computing element addresses in arrays. The index is
1633    /// zero-extended from `u32` to `u64` before the multiplication.
1634    ///
1635    /// Returns a `U64` register containing the computed address.
1636    pub fn byte_offset_addr(
1637        &mut self,
1638        base: Register,
1639        index: Register,
1640        stride_bytes: u32,
1641    ) -> Register {
1642        let idx64 = self.cvt_u32_to_u64(index);
1643        // Use mad.wide.u32 to compute idx * stride + base... but we need
1644        // mul then add since mad with mixed types isn't straightforward.
1645        let stride_reg = self.regs.alloc(PtxType::U64);
1646        self.emit(Instruction::Raw(format!(
1647            "mov.u64 {}, {};",
1648            stride_reg,
1649            u64::from(stride_bytes)
1650        )));
1651        let offset = self.regs.alloc(PtxType::U64);
1652        self.emit(Instruction::Mul {
1653            ty: PtxType::U64,
1654            mode: MulMode::Lo,
1655            dst: offset.clone(),
1656            a: Operand::Register(idx64),
1657            b: Operand::Register(stride_reg),
1658        });
1659        self.add_u64(base, offset)
1660    }
1661
    /// Computes an element address for an `f32` array: `base + index * 4`.
    pub fn f32_elem_addr(&mut self, base: Register, index: Register) -> Register {
        self.byte_offset_addr(base, index, 4)
    }
1666
    /// Computes an element address for an `f64` array: `base + index * 8`.
    pub fn f64_elem_addr(&mut self, base: Register, index: Register) -> Register {
        self.byte_offset_addr(base, index, 8)
    }
1671
1672    // ════════════════════════════════════════════════════════════════════
1673    //  Register Allocation (Direct Access)
1674    // ════════════════════════════════════════════════════════════════════
1675
1676    /// Allocates a fresh register of the given type.
1677    ///
1678    /// This is a lower-level API — most users should prefer the typed
1679    /// instruction methods which allocate destination registers automatically.
1680    pub fn alloc_reg(&mut self, ty: PtxType) -> Register {
1681        self.regs.alloc(ty)
1682    }
1683
1684    /// Declares a named register for use in [`raw_ptx`](Self::raw_ptx) blocks.
1685    ///
1686    /// Named registers (e.g., `%f_x`, `%rd_off`) are not created by the
1687    /// automatic allocator, so they must be declared explicitly before use.
1688    pub fn declare_named_reg(&mut self, name: &str, ty: PtxType) {
1689        self.regs.declare_named(name, ty);
1690    }
1691
1692    // ════════════════════════════════════════════════════════════════════
1693    //  Internal Helpers
1694    // ════════════════════════════════════════════════════════════════════
1695
1696    /// Appends an instruction to the body.
1697    fn emit(&mut self, inst: Instruction) {
1698        self.instructions.push(inst);
1699    }
1700
1701    /// Generates a unique label name with the given prefix.
1702    ///
1703    /// Labels are formatted as `L__{prefix}_{counter}` to avoid
1704    /// collisions with user-defined labels and other generated labels.
1705    pub fn fresh_label(&mut self, prefix: &str) -> String {
1706        let id = self.label_counter;
1707        self.label_counter += 1;
1708        format!("L__{prefix}_{id}")
1709    }
1710
1711    /// Returns the target SM version for this kernel.
1712    ///
1713    /// Useful for architecture-gated code paths within body closures.
1714    #[must_use]
1715    pub const fn target_sm(&self) -> SmVersion {
1716        self.target
1717    }
1718
1719    /// Returns `true` if the given parameter name was declared on the kernel.
1720    #[must_use]
1721    pub fn has_param(&self, name: &str) -> bool {
1722        self.param_names.iter().any(|p| p == name)
1723    }
1724}
1725
// Atomic/reduce, texture/surface, warp-level primitives, and barrier methods
// live in a sibling module to keep this file under the 2000-line policy.
pub(super) mod body_builder_ext;

// Extended tensor core builder: WMMA, MMA (TF32/BF16/FP8/INT8), WGMMA.
pub(super) mod tensor_core_ops;

// Unit tests live in a separate file (same line-count policy); the
// `#[path]` attribute points the module at it explicitly.
#[cfg(test)]
#[path = "body_builder_tests.rs"]
mod tests;