Skip to main content

kaio_core/instr/
control.rs

//! Control flow and synchronization PTX operations.
//!
//! Contains comparison-to-predicate ([`SetP`](ControlOp::SetP),
//! [`SetPAnd`](ControlOp::SetPAnd)),
//! branching ([`BraPred`](ControlOp::BraPred), [`Bra`](ControlOp::Bra)),
//! [`Ret`](ControlOp::Ret), barrier synchronization
//! ([`BarSync`](ControlOp::BarSync)), and warp shuffle operations
//! ([`ShflSyncDown`](ControlOp::ShflSyncDown),
//! [`ShflSyncUp`](ControlOp::ShflSyncUp),
//! [`ShflSyncBfly`](ControlOp::ShflSyncBfly)).
10
11use std::fmt;
12
13use crate::emit::{Emit, PtxWriter};
14use crate::ir::{Operand, Register};
15use crate::types::PtxType;
16
/// Relational operator that selects which comparison a `setp` instruction performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// `==`, spelled `eq` in PTX.
    Eq,
    /// `!=`, spelled `ne` in PTX.
    Ne,
    /// `<`, spelled `lt` in PTX.
    Lt,
    /// `<=`, spelled `le` in PTX.
    Le,
    /// `>`, spelled `gt` in PTX.
    Gt,
    /// `>=`, spelled `ge` in PTX.
    Ge,
}

impl CmpOp {
    /// Returns the two-letter PTX token for this comparison — the part that
    /// follows `setp.` in the mnemonic (for example `"lt"` or `"ge"`).
    pub fn ptx_str(&self) -> &'static str {
        match *self {
            CmpOp::Eq => "eq",
            CmpOp::Ne => "ne",
            CmpOp::Lt => "lt",
            CmpOp::Le => "le",
            CmpOp::Gt => "gt",
            CmpOp::Ge => "ge",
        }
    }
}
47
48/// Control flow PTX instruction variants.
49#[derive(Debug, Clone)]
50pub enum ControlOp {
51    /// Set predicate from comparison: `setp.{cmp_op}{ty} pred, lhs, rhs;`
52    ///
53    /// Compares `lhs` and `rhs` and writes the result to a predicate register.
54    /// Example: `setp.ge.u32 %p1, %r1, %r2;`
55    SetP {
56        /// Destination predicate register.
57        dst: Register,
58        /// Comparison operation.
59        cmp_op: CmpOp,
60        /// Left-hand operand (register or immediate).
61        lhs: Operand,
62        /// Right-hand operand (register or immediate).
63        rhs: Operand,
64        /// PTX type for the comparison.
65        ty: PtxType,
66    },
67    /// Set predicate from comparison ANDed with a source predicate:
68    /// `setp.{cmp_op}.and{ty} pred, lhs, rhs, src_pred;`
69    ///
70    /// Computes `pred = (lhs CmpOp rhs) AND src_pred` in one instruction.
71    /// Used for compact edge-tile bounds checking — combines a row check
72    /// with an existing col-check predicate without a separate `and.pred`.
73    /// Sprint 6.7 (multi-warp matmul_tc edge tiles) is the first user.
74    /// Example: `setp.lt.and.u32 %p3, %r5, %r10, %p2;`
75    SetPAnd {
76        /// Destination predicate register.
77        dst: Register,
78        /// Comparison operation applied to `lhs`/`rhs`.
79        cmp_op: CmpOp,
80        /// Left-hand operand of the comparison.
81        lhs: Operand,
82        /// Right-hand operand of the comparison.
83        rhs: Operand,
84        /// PTX type for the comparison.
85        ty: PtxType,
86        /// Source predicate AND'd with the comparison result.
87        src_pred: Register,
88    },
89    /// Predicated branch: `@{pred} bra {target};` or `@!{pred} bra {target};`
90    ///
91    /// Branches to `target` label if `pred` is true (or false when negated).
92    /// Uses `PtxWriter::line()` instead of `instruction()` because the
93    /// `@pred mnemonic target;` format doesn't fit the comma-separated
94    /// operand pattern.
95    ///
96    /// Examples:
97    /// - `@%p1 bra $L__BB0_2;` — branch if pred is true
98    /// - `@!%p1 bra IF_END_0;` — branch if pred is false (Phase 2 if/else)
99    BraPred {
100        /// Predicate register to test.
101        pred: Register,
102        /// Label name to branch to.
103        target: String,
104        /// When `true`, negate the predicate (`@!pred`). Deferred from
105        /// Sprint 1.4, needed for Phase 2 if/else lowering where `setp`
106        /// matches the source comparison and `@!pred bra` skips the
107        /// then-block when the condition is false.
108        negate: bool,
109    },
110    /// Unconditional branch: `bra {target};`
111    ///
112    /// Not used in `vector_add` but included for Phase 3 loop support.
113    Bra {
114        /// Label name to branch to.
115        target: String,
116    },
117    /// Return from kernel: `ret;`
118    Ret,
119    /// Block-level barrier synchronization: `bar.sync {barrier_id};`
120    ///
121    /// All threads in the block must reach this instruction before any
122    /// can proceed. Barrier 0 is the conventional default.
123    /// Example: `bar.sync 0;`
124    BarSync {
125        /// Barrier identifier (0 is conventional for single-barrier use).
126        barrier_id: u32,
127    },
128    /// Warp shuffle down: `shfl.sync.down.b32 dst, src, delta, c, membermask;`
129    ///
130    /// Each thread reads from the thread `delta` lanes below it within
131    /// the warp. The `c` operand packs clamp width (see PTX ISA 8.7 S9.7.8).
132    /// Example: `shfl.sync.down.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;`
133    ShflSyncDown {
134        /// Destination register.
135        dst: Register,
136        /// Source register (value to share).
137        src: Register,
138        /// Delta (offset) — how many lanes down.
139        delta: Operand,
140        /// Pre-packed clamp/width value (encoding is caller's responsibility).
141        c: u32,
142        /// Member mask (0xFFFFFFFF = full warp).
143        mask: u32,
144    },
145    /// Warp shuffle up: `shfl.sync.up.b32 dst, src, delta, c, membermask;`
146    ///
147    /// Each thread reads from the thread `delta` lanes above it.
148    ShflSyncUp {
149        /// Destination register.
150        dst: Register,
151        /// Source register.
152        src: Register,
153        /// Delta (offset) — how many lanes up.
154        delta: Operand,
155        /// Pre-packed clamp/width value.
156        c: u32,
157        /// Member mask.
158        mask: u32,
159    },
160    /// Warp shuffle butterfly (XOR): `shfl.sync.bfly.b32 dst, src, lane_mask, c, membermask;`
161    ///
162    /// Each thread reads from the thread at `lane XOR lane_mask`.
163    ShflSyncBfly {
164        /// Destination register.
165        dst: Register,
166        /// Source register.
167        src: Register,
168        /// Lane mask for XOR operation.
169        lane_mask: Operand,
170        /// Pre-packed clamp/width value.
171        c: u32,
172        /// Member mask.
173        mask: u32,
174    },
175}
176
177impl Emit for ControlOp {
178    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
179        match self {
180            ControlOp::SetP {
181                dst,
182                cmp_op,
183                lhs,
184                rhs,
185                ty,
186            } => {
187                let mnemonic = format!("setp.{}{}", cmp_op.ptx_str(), ty.ptx_suffix());
188                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, lhs, rhs])
189            }
190            ControlOp::SetPAnd {
191                dst,
192                cmp_op,
193                lhs,
194                rhs,
195                ty,
196                src_pred,
197            } => {
198                let mnemonic = format!("setp.{}.and{}", cmp_op.ptx_str(), ty.ptx_suffix());
199                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, lhs, rhs, src_pred])
200            }
201            ControlOp::BraPred {
202                pred,
203                target,
204                negate,
205            } => {
206                let neg = if *negate { "!" } else { "" };
207                w.line(&format!("@{neg}{pred} bra {target};"))
208            }
209            ControlOp::Bra { target } => w.instruction("bra", &[&target as &dyn fmt::Display]),
210            ControlOp::Ret => w.instruction("ret", &[]),
211            ControlOp::BarSync { barrier_id } => w.line(&format!("bar.sync {barrier_id};")),
212            ControlOp::ShflSyncDown {
213                dst,
214                src,
215                delta,
216                c,
217                mask,
218            } => w.line(&format!(
219                "shfl.sync.down.b32 {dst}, {src}, {delta}, {c}, 0x{mask:08X};"
220            )),
221            ControlOp::ShflSyncUp {
222                dst,
223                src,
224                delta,
225                c,
226                mask,
227            } => w.line(&format!(
228                "shfl.sync.up.b32 {dst}, {src}, {delta}, {c}, 0x{mask:08X};"
229            )),
230            ControlOp::ShflSyncBfly {
231                dst,
232                src,
233                lane_mask,
234                c,
235                mask,
236            } => w.line(&format!(
237                "shfl.sync.bfly.b32 {dst}, {src}, {lane_mask}, {c}, 0x{mask:08X};"
238            )),
239        }
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RegKind;

    /// Assembles a [`Register`] from its three fields.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Shorthand for predicate register `%p{index}`.
    fn pred(index: u32) -> Register {
        reg(RegKind::P, index, PtxType::Pred)
    }

    /// Shorthand for 32-bit unsigned register `%r{index}`.
    fn r32(index: u32) -> Register {
        reg(RegKind::R, index, PtxType::U32)
    }

    /// Emits `op` into a fresh writer at one level of indentation and checks
    /// the finished text against `expected`.
    #[track_caller]
    fn assert_emits(op: &ControlOp, expected: &str) {
        let mut w = PtxWriter::new();
        w.indent();
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), expected);
    }

    // --- setp / branch / ret emission (several mirror nvcc output) ---

    #[test]
    fn emit_setp_and_lt_u32() {
        // Sprint 6.7 edge-tile: setp.lt.and.u32 %p3, %r5, %r10, %p2
        assert_emits(
            &ControlOp::SetPAnd {
                dst: pred(3),
                cmp_op: CmpOp::Lt,
                lhs: Operand::Reg(r32(5)),
                rhs: Operand::Reg(r32(10)),
                ty: PtxType::U32,
                src_pred: pred(2),
            },
            "    setp.lt.and.u32 %p3, %r5, %r10, %p2;\n",
        );
    }

    #[test]
    fn emit_setp_ge_u32() {
        // nvcc line 36: setp.ge.u32 %p1, %r1, %r2
        assert_emits(
            &ControlOp::SetP {
                dst: pred(1),
                cmp_op: CmpOp::Ge,
                lhs: Operand::Reg(r32(1)),
                rhs: Operand::Reg(r32(2)),
                ty: PtxType::U32,
            },
            "    setp.ge.u32 %p1, %r1, %r2;\n",
        );
    }

    #[test]
    fn emit_bra_pred() {
        // nvcc line 37: @%p1 bra $L__BB0_2
        // nvcc separates with a tab, we use a space; both are valid PTX.
        assert_emits(
            &ControlOp::BraPred {
                pred: pred(1),
                target: "$L__BB0_2".to_string(),
                negate: false,
            },
            "    @%p1 bra $L__BB0_2;\n",
        );
    }

    #[test]
    fn emit_bra_pred_negated() {
        // Phase 2 if/else: @!%p1 bra IF_END_0 skips the then-block when false.
        assert_emits(
            &ControlOp::BraPred {
                pred: pred(1),
                target: "IF_END_0".to_string(),
                negate: true,
            },
            "    @!%p1 bra IF_END_0;\n",
        );
    }

    #[test]
    fn emit_ret() {
        // nvcc line 52: ret
        assert_emits(&ControlOp::Ret, "    ret;\n");
    }

    #[test]
    fn emit_bra_unconditional() {
        assert_emits(
            &ControlOp::Bra {
                target: "LOOP".to_string(),
            },
            "    bra LOOP;\n",
        );
    }

    // --- Dispatch through the instruction wrapper ---

    #[test]
    fn control_via_ptx_instruction() {
        use crate::ir::PtxInstruction;

        // Kept explicit: this goes through `PtxInstruction::emit`, not
        // `ControlOp::emit`, so the shared helper does not apply.
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Control(ControlOp::Ret);
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    ret;\n");
    }

    // --- Phase 3: barrier and warp shuffles ---

    #[test]
    fn emit_bar_sync() {
        assert_emits(&ControlOp::BarSync { barrier_id: 0 }, "    bar.sync 0;\n");
    }

    #[test]
    fn emit_shfl_sync_down() {
        assert_emits(
            &ControlOp::ShflSyncDown {
                dst: r32(2),
                src: r32(1),
                delta: Operand::ImmU32(1),
                c: 31,
                mask: 0xFFFFFFFF,
            },
            "    shfl.sync.down.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;\n",
        );
    }

    #[test]
    fn emit_shfl_sync_up() {
        assert_emits(
            &ControlOp::ShflSyncUp {
                dst: r32(2),
                src: r32(1),
                delta: Operand::ImmU32(1),
                c: 0,
                mask: 0xFFFFFFFF,
            },
            "    shfl.sync.up.b32 %r2, %r1, 1, 0, 0xFFFFFFFF;\n",
        );
    }

    #[test]
    fn emit_shfl_sync_bfly() {
        assert_emits(
            &ControlOp::ShflSyncBfly {
                dst: r32(2),
                src: r32(1),
                lane_mask: Operand::ImmU32(1),
                c: 31,
                mask: 0xFFFFFFFF,
            },
            "    shfl.sync.bfly.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;\n",
        );
    }

    #[test]
    fn shfl_sync_down_with_register_delta() {
        assert_emits(
            &ControlOp::ShflSyncDown {
                dst: r32(3),
                src: r32(0),
                delta: Operand::Reg(r32(4)),
                c: 31,
                mask: 0xFFFFFFFF,
            },
            "    shfl.sync.down.b32 %r3, %r0, %r4, 31, 0xFFFFFFFF;\n",
        );
    }

    #[test]
    fn cmp_op_all_variants() {
        let expected = [
            (CmpOp::Eq, "eq"),
            (CmpOp::Ne, "ne"),
            (CmpOp::Lt, "lt"),
            (CmpOp::Le, "le"),
            (CmpOp::Gt, "gt"),
            (CmpOp::Ge, "ge"),
        ];
        for (op, text) in expected {
            assert_eq!(op.ptx_str(), text);
        }
    }
}