kaio_core/instr/control.rs
//! Control flow and synchronization PTX operations.
//!
//! Contains comparison-to-predicate ([`SetP`](ControlOp::SetP),
//! [`SetPAnd`](ControlOp::SetPAnd)), branching
//! ([`BraPred`](ControlOp::BraPred), [`Bra`](ControlOp::Bra)),
//! [`Ret`](ControlOp::Ret), barrier synchronization
//! ([`BarSync`](ControlOp::BarSync)), and warp shuffle operations
//! ([`ShflSyncDown`](ControlOp::ShflSyncDown),
//! [`ShflSyncUp`](ControlOp::ShflSyncUp),
//! [`ShflSyncBfly`](ControlOp::ShflSyncBfly)).
10
11use std::fmt;
12
13use crate::emit::{Emit, PtxWriter};
14use crate::ir::{Operand, Register};
15use crate::types::PtxType;
16
/// Relational operator used by a `setp` instruction to produce a predicate.
///
/// Each variant corresponds to one PTX comparison keyword; see
/// [`CmpOp::ptx_str`] for the exact spelling.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CmpOp {
    /// `==` (PTX `eq`)
    Eq,
    /// `!=` (PTX `ne`)
    Ne,
    /// `<` (PTX `lt`)
    Lt,
    /// `<=` (PTX `le`)
    Le,
    /// `>` (PTX `gt`)
    Gt,
    /// `>=` (PTX `ge`)
    Ge,
}
33
34impl CmpOp {
35 /// PTX comparison operator string (e.g. `"ge"`, `"lt"`).
36 pub fn ptx_str(&self) -> &'static str {
37 match self {
38 Self::Eq => "eq",
39 Self::Ne => "ne",
40 Self::Lt => "lt",
41 Self::Le => "le",
42 Self::Gt => "gt",
43 Self::Ge => "ge",
44 }
45 }
46}
47
48/// Control flow PTX instruction variants.
49#[derive(Debug, Clone)]
50pub enum ControlOp {
51 /// Set predicate from comparison: `setp.{cmp_op}{ty} pred, lhs, rhs;`
52 ///
53 /// Compares `lhs` and `rhs` and writes the result to a predicate register.
54 /// Example: `setp.ge.u32 %p1, %r1, %r2;`
55 SetP {
56 /// Destination predicate register.
57 dst: Register,
58 /// Comparison operation.
59 cmp_op: CmpOp,
60 /// Left-hand operand (register or immediate).
61 lhs: Operand,
62 /// Right-hand operand (register or immediate).
63 rhs: Operand,
64 /// PTX type for the comparison.
65 ty: PtxType,
66 },
67 /// Set predicate from comparison ANDed with a source predicate:
68 /// `setp.{cmp_op}.and{ty} pred, lhs, rhs, src_pred;`
69 ///
70 /// Computes `pred = (lhs CmpOp rhs) AND src_pred` in one instruction.
71 /// Used for compact edge-tile bounds checking — combines a row check
72 /// with an existing col-check predicate without a separate `and.pred`.
73 /// Sprint 6.7 (multi-warp matmul_tc edge tiles) is the first user.
74 /// Example: `setp.lt.and.u32 %p3, %r5, %r10, %p2;`
75 SetPAnd {
76 /// Destination predicate register.
77 dst: Register,
78 /// Comparison operation applied to `lhs`/`rhs`.
79 cmp_op: CmpOp,
80 /// Left-hand operand of the comparison.
81 lhs: Operand,
82 /// Right-hand operand of the comparison.
83 rhs: Operand,
84 /// PTX type for the comparison.
85 ty: PtxType,
86 /// Source predicate AND'd with the comparison result.
87 src_pred: Register,
88 },
89 /// Predicated branch: `@{pred} bra {target};` or `@!{pred} bra {target};`
90 ///
91 /// Branches to `target` label if `pred` is true (or false when negated).
92 /// Uses `PtxWriter::line()` instead of `instruction()` because the
93 /// `@pred mnemonic target;` format doesn't fit the comma-separated
94 /// operand pattern.
95 ///
96 /// Examples:
97 /// - `@%p1 bra $L__BB0_2;` — branch if pred is true
98 /// - `@!%p1 bra IF_END_0;` — branch if pred is false (Phase 2 if/else)
99 BraPred {
100 /// Predicate register to test.
101 pred: Register,
102 /// Label name to branch to.
103 target: String,
104 /// When `true`, negate the predicate (`@!pred`). Deferred from
105 /// Sprint 1.4, needed for Phase 2 if/else lowering where `setp`
106 /// matches the source comparison and `@!pred bra` skips the
107 /// then-block when the condition is false.
108 negate: bool,
109 },
110 /// Unconditional branch: `bra {target};`
111 ///
112 /// Not used in `vector_add` but included for Phase 3 loop support.
113 Bra {
114 /// Label name to branch to.
115 target: String,
116 },
117 /// Return from kernel: `ret;`
118 Ret,
119 /// Block-level barrier synchronization: `bar.sync {barrier_id};`
120 ///
121 /// All threads in the block must reach this instruction before any
122 /// can proceed. Barrier 0 is the conventional default.
123 /// Example: `bar.sync 0;`
124 BarSync {
125 /// Barrier identifier (0 is conventional for single-barrier use).
126 barrier_id: u32,
127 },
128 /// Warp shuffle down: `shfl.sync.down.b32 dst, src, delta, c, membermask;`
129 ///
130 /// Each thread reads from the thread `delta` lanes below it within
131 /// the warp. The `c` operand packs clamp width (see PTX ISA 8.7 S9.7.8).
132 /// Example: `shfl.sync.down.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;`
133 ShflSyncDown {
134 /// Destination register.
135 dst: Register,
136 /// Source register (value to share).
137 src: Register,
138 /// Delta (offset) — how many lanes down.
139 delta: Operand,
140 /// Pre-packed clamp/width value (encoding is caller's responsibility).
141 c: u32,
142 /// Member mask (0xFFFFFFFF = full warp).
143 mask: u32,
144 },
145 /// Warp shuffle up: `shfl.sync.up.b32 dst, src, delta, c, membermask;`
146 ///
147 /// Each thread reads from the thread `delta` lanes above it.
148 ShflSyncUp {
149 /// Destination register.
150 dst: Register,
151 /// Source register.
152 src: Register,
153 /// Delta (offset) — how many lanes up.
154 delta: Operand,
155 /// Pre-packed clamp/width value.
156 c: u32,
157 /// Member mask.
158 mask: u32,
159 },
160 /// Warp shuffle butterfly (XOR): `shfl.sync.bfly.b32 dst, src, lane_mask, c, membermask;`
161 ///
162 /// Each thread reads from the thread at `lane XOR lane_mask`.
163 ShflSyncBfly {
164 /// Destination register.
165 dst: Register,
166 /// Source register.
167 src: Register,
168 /// Lane mask for XOR operation.
169 lane_mask: Operand,
170 /// Pre-packed clamp/width value.
171 c: u32,
172 /// Member mask.
173 mask: u32,
174 },
175}
176
177impl Emit for ControlOp {
178 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
179 match self {
180 ControlOp::SetP {
181 dst,
182 cmp_op,
183 lhs,
184 rhs,
185 ty,
186 } => {
187 let mnemonic = format!("setp.{}{}", cmp_op.ptx_str(), ty.ptx_suffix());
188 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, lhs, rhs])
189 }
190 ControlOp::SetPAnd {
191 dst,
192 cmp_op,
193 lhs,
194 rhs,
195 ty,
196 src_pred,
197 } => {
198 let mnemonic = format!("setp.{}.and{}", cmp_op.ptx_str(), ty.ptx_suffix());
199 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, lhs, rhs, src_pred])
200 }
201 ControlOp::BraPred {
202 pred,
203 target,
204 negate,
205 } => {
206 let neg = if *negate { "!" } else { "" };
207 w.line(&format!("@{neg}{pred} bra {target};"))
208 }
209 ControlOp::Bra { target } => w.instruction("bra", &[&target as &dyn fmt::Display]),
210 ControlOp::Ret => w.instruction("ret", &[]),
211 ControlOp::BarSync { barrier_id } => w.line(&format!("bar.sync {barrier_id};")),
212 ControlOp::ShflSyncDown {
213 dst,
214 src,
215 delta,
216 c,
217 mask,
218 } => w.line(&format!(
219 "shfl.sync.down.b32 {dst}, {src}, {delta}, {c}, 0x{mask:08X};"
220 )),
221 ControlOp::ShflSyncUp {
222 dst,
223 src,
224 delta,
225 c,
226 mask,
227 } => w.line(&format!(
228 "shfl.sync.up.b32 {dst}, {src}, {delta}, {c}, 0x{mask:08X};"
229 )),
230 ControlOp::ShflSyncBfly {
231 dst,
232 src,
233 lane_mask,
234 c,
235 mask,
236 } => w.line(&format!(
237 "shfl.sync.bfly.b32 {dst}, {src}, {lane_mask}, {c}, 0x{mask:08X};"
238 )),
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RegKind;

    /// Builds a register fixture for the emission tests.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    // --- nvcc golden comparisons ---

    #[test]
    fn emit_setp_and_lt_u32() {
        // Sprint 6.7 edge-tile: setp.lt.and.u32 %p3, %r5, %r10, %p2
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::SetPAnd {
            dst: reg(RegKind::P, 3, PtxType::Pred),
            cmp_op: CmpOp::Lt,
            lhs: Operand::Reg(reg(RegKind::R, 5, PtxType::U32)),
            rhs: Operand::Reg(reg(RegKind::R, 10, PtxType::U32)),
            ty: PtxType::U32,
            src_pred: reg(RegKind::P, 2, PtxType::Pred),
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " setp.lt.and.u32 %p3, %r5, %r10, %p2;\n");
    }

    #[test]
    fn emit_setp_ge_u32() {
        // nvcc line 36: setp.ge.u32 %p1, %r1, %r2
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::SetP {
            dst: reg(RegKind::P, 1, PtxType::Pred),
            cmp_op: CmpOp::Ge,
            lhs: Operand::Reg(reg(RegKind::R, 1, PtxType::U32)),
            rhs: Operand::Reg(reg(RegKind::R, 2, PtxType::U32)),
            ty: PtxType::U32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " setp.ge.u32 %p1, %r1, %r2;\n");
    }

    #[test]
    fn emit_setp_eq_with_immediate_rhs() {
        // Immediate operands render as bare decimals inside `setp`.
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::SetP {
            dst: reg(RegKind::P, 2, PtxType::Pred),
            cmp_op: CmpOp::Eq,
            lhs: Operand::Reg(reg(RegKind::R, 3, PtxType::U32)),
            rhs: Operand::ImmU32(0),
            ty: PtxType::U32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " setp.eq.u32 %p2, %r3, 0;\n");
    }

    #[test]
    fn emit_bra_pred() {
        // nvcc line 37: @%p1 bra $L__BB0_2
        // nvcc uses tab whitespace; we use space — both valid PTX
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::BraPred {
            pred: reg(RegKind::P, 1, PtxType::Pred),
            target: "$L__BB0_2".to_string(),
            negate: false,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " @%p1 bra $L__BB0_2;\n");
    }

    #[test]
    fn emit_bra_pred_negated() {
        // Phase 2 if/else: @!%p1 bra IF_END_0 — skip then-block when false
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::BraPred {
            pred: reg(RegKind::P, 1, PtxType::Pred),
            target: "IF_END_0".to_string(),
            negate: true,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " @!%p1 bra IF_END_0;\n");
    }

    #[test]
    fn emit_ret() {
        // nvcc line 52: ret
        let mut w = PtxWriter::new();
        w.indent();
        ControlOp::Ret.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ret;\n");
    }

    #[test]
    fn emit_bra_unconditional() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::Bra {
            target: "LOOP".to_string(),
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " bra LOOP;\n");
    }

    // --- Dispatch and unit tests ---

    #[test]
    fn control_via_ptx_instruction() {
        use crate::ir::PtxInstruction;

        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Control(ControlOp::Ret);
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ret;\n");
    }

    // --- Phase 3: Barrier + Shuffle ---

    #[test]
    fn emit_bar_sync() {
        let mut w = PtxWriter::new();
        w.indent();
        ControlOp::BarSync { barrier_id: 0 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " bar.sync 0;\n");
    }

    #[test]
    fn emit_bar_sync_nonzero_barrier() {
        // Named barriers other than 0 are emitted verbatim.
        let mut w = PtxWriter::new();
        w.indent();
        ControlOp::BarSync { barrier_id: 3 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " bar.sync 3;\n");
    }

    #[test]
    fn emit_shfl_sync_down() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncDown {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            delta: Operand::ImmU32(1),
            c: 31,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.down.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn emit_shfl_sync_up() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncUp {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            delta: Operand::ImmU32(1),
            c: 0,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.up.b32 %r2, %r1, 1, 0, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn emit_shfl_sync_bfly() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncBfly {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            lane_mask: Operand::ImmU32(1),
            c: 31,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.bfly.b32 %r2, %r1, 1, 31, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn shfl_sync_down_with_register_delta() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncDown {
            dst: reg(RegKind::R, 3, PtxType::U32),
            src: reg(RegKind::R, 0, PtxType::U32),
            delta: Operand::Reg(reg(RegKind::R, 4, PtxType::U32)),
            c: 31,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.down.b32 %r3, %r0, %r4, 31, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn shfl_sync_bfly_with_register_lane_mask() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncBfly {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            lane_mask: Operand::Reg(reg(RegKind::R, 5, PtxType::U32)),
            c: 31,
            mask: 0xFFFFFFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.bfly.b32 %r2, %r1, %r5, 31, 0xFFFFFFFF;\n"
        );
    }

    #[test]
    fn shfl_sync_partial_mask_is_zero_padded() {
        // A partial member mask must keep all eight hex digits.
        let mut w = PtxWriter::new();
        w.indent();
        let op = ControlOp::ShflSyncDown {
            dst: reg(RegKind::R, 2, PtxType::U32),
            src: reg(RegKind::R, 1, PtxType::U32),
            delta: Operand::ImmU32(2),
            c: 31,
            mask: 0x0000_FFFF,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " shfl.sync.down.b32 %r2, %r1, 2, 31, 0x0000FFFF;\n"
        );
    }

    #[test]
    fn cmp_op_all_variants() {
        assert_eq!(CmpOp::Eq.ptx_str(), "eq");
        assert_eq!(CmpOp::Ne.ptx_str(), "ne");
        assert_eq!(CmpOp::Lt.ptx_str(), "lt");
        assert_eq!(CmpOp::Le.ptx_str(), "le");
        assert_eq!(CmpOp::Gt.ptx_str(), "gt");
        assert_eq!(CmpOp::Ge.ptx_str(), "ge");
    }
}