// kaio_core/instr/memory.rs
//! Memory PTX operations.
//!
//! Contains load/store instructions for global and shared memory:
//! [`LdParam`](MemoryOp::LdParam), [`LdGlobal`](MemoryOp::LdGlobal),
//! [`LdGlobalPred`](MemoryOp::LdGlobalPred), [`LdGlobalB128`](MemoryOp::LdGlobalB128),
//! [`StGlobal`](MemoryOp::StGlobal), [`StGlobalPred`](MemoryOp::StGlobalPred),
//! [`LdShared`](MemoryOp::LdShared), [`StShared`](MemoryOp::StShared), and
//! [`CvtaToGlobal`](MemoryOp::CvtaToGlobal), plus the asynchronous-copy ops
//! [`CpAsyncCaSharedGlobal`](MemoryOp::CpAsyncCaSharedGlobal),
//! [`CpAsyncCommitGroup`](MemoryOp::CpAsyncCommitGroup), and
//! [`CpAsyncWaitGroup`](MemoryOp::CpAsyncWaitGroup).
8
9use std::fmt;
10
11use crate::emit::{Emit, PtxWriter};
12use crate::ir::Register;
13use crate::types::PtxType;
14
/// Memory PTX instruction variants.
///
/// Operand conventions:
/// - All addresses and values are [`Register`]s (not [`Operand`](crate::ir::Operand)).
///   You can't `ld.global` from an immediate address or `st.global` an immediate
///   value in PTX — those go through `mov` first.
/// - [`LdParam`](Self::LdParam) is the exception: it references a kernel parameter
///   by name (a `String`), not by register.
/// - Stores (`St*`) put the address first and the value second — the PTX
///   convention, opposite to loads and arithmetic where `dst` is first.
#[derive(Debug, Clone)]
pub enum MemoryOp {
    /// Load kernel parameter: `ld.param{ty} dst, [param_name];`
    ///
    /// References the parameter by name from the kernel signature.
    /// Example: `ld.param.u64 %rd1, [vector_add_param_0];`
    LdParam {
        /// Destination register.
        dst: Register,
        /// Parameter name from the kernel signature.
        param_name: String,
        /// PTX type of the parameter value.
        ty: PtxType,
    },
    /// Load from global memory: `ld.global{ty} dst, [addr];`
    ///
    /// The `addr` register holds the computed memory address.
    /// Example: `ld.global.f32 %f1, [%rd8];`
    LdGlobal {
        /// Destination register.
        dst: Register,
        /// Register holding the memory address.
        addr: Register,
        /// PTX type of the loaded value.
        ty: PtxType,
    },
    /// Predicated load from global memory: `@[!]{pred} ld.global{ty} dst, [addr];`
    ///
    /// Skips the load when the predicate evaluates false (or true when
    /// `negate` is set). Used for edge-tile bounds checking — the OOB
    /// thread's `dst` register is left unchanged, so callers typically
    /// pre-initialize `dst` to zero with `mov.b32 dst, 0` and then
    /// conditionally overwrite with a predicated load.
    ///
    /// Sprint 6.7 (multi-warp matmul_tc edge tiles) is the first user.
    /// Example: `@%p1 ld.global.u32 %r5, [%rd9];`
    LdGlobalPred {
        /// Destination register (unchanged when predicate is false).
        dst: Register,
        /// Register holding the memory address.
        addr: Register,
        /// PTX type of the loaded value.
        ty: PtxType,
        /// Predicate register controlling the load.
        pred: Register,
        /// When `true`, negate the predicate (`@!pred`).
        negate: bool,
    },
    /// 128-bit vectorized load from global memory:
    /// `ld.global.v4.b32 {%r_i, %r_j, %r_k, %r_l}, [addr];`
    ///
    /// Single-instruction 128-bit transfer into 4 independent b32
    /// destination registers. Halves (or more) the global-load
    /// instruction count vs scalar b32 loads for bandwidth-bound
    /// kernels. Requires the `addr` register to hold a **16-byte
    /// aligned global-space address** — unaligned access will fault
    /// at runtime; PTX does not catch this statically.
    ///
    /// Destinations are NOT required to be consecutive registers in
    /// the allocator — PTX `ld.global.v4.b32` accepts any 4 b32 regs
    /// in the vector brace list. In practice, allocating 4 regs in
    /// sequence produces consecutive indices, which is what callers
    /// typically do.
    ///
    /// No predicate variant in Sprint 6.7b — edge tiles stay on the
    /// existing [`LdGlobalPred`](Self::LdGlobalPred) scalar path. A
    /// future `LdGlobalB128Pred` would be additive.
    ///
    /// Sprint 6.7b (multi-warp matmul_tc Tile B fast path) is the
    /// first user. Construct via
    /// [`MemoryOp::new_ld_global_b128`](Self::new_ld_global_b128),
    /// which validates that all 4 destinations are b32-class registers.
    ///
    /// Note: unlike the scalar ops, this variant carries no `ty` field —
    /// the emitted type is always `.v4.b32`.
    ///
    /// Example: `ld.global.v4.b32 {%r0, %r1, %r2, %r3}, [%rd8];`
    LdGlobalB128 {
        /// Four b32 destination registers — receive bytes 0-3, 4-7,
        /// 8-11, 12-15 of the loaded 128-bit value respectively.
        dsts: [Register; 4],
        /// Register holding a 16-B aligned global-space address.
        addr: Register,
    },
    /// Store to global memory: `st.global{ty} [addr], src;`
    ///
    /// **Operand order is reversed in PTX** — address comes first,
    /// value second. This matches PTX convention but is opposite to
    /// loads and arithmetic where `dst` is first.
    ///
    /// Example: `st.global.f32 [%rd10], %f3;`
    StGlobal {
        /// Register holding the memory address.
        addr: Register,
        /// Source register (value to store).
        src: Register,
        /// PTX type of the stored value.
        ty: PtxType,
    },
    /// Predicated store to global memory: `@[!]{pred} st.global{ty} [addr], src;`
    ///
    /// Skips the store when the predicate evaluates false (or true when
    /// `negate` is set). Used for edge-tile bounds checking on output
    /// writes — out-of-bounds threads simply don't store, leaving the
    /// destination memory untouched.
    ///
    /// Sprint 6.7 (multi-warp matmul_tc edge tiles) is the first user.
    /// Example: `@%p1 st.global.f32 [%rd11], %f4;`
    StGlobalPred {
        /// Register holding the memory address.
        addr: Register,
        /// Source register (value to store).
        src: Register,
        /// PTX type of the stored value.
        ty: PtxType,
        /// Predicate register controlling the store.
        pred: Register,
        /// When `true`, negate the predicate (`@!pred`).
        negate: bool,
    },
    /// Load from shared memory: `ld.shared{ty} dst, [addr];`
    ///
    /// Shared memory is block-scoped SRAM. The `addr` register holds the
    /// offset into the declared shared allocation.
    /// Example: `ld.shared.f32 %f0, [%r0];`
    LdShared {
        /// Destination register.
        dst: Register,
        /// Register holding the shared memory offset.
        addr: Register,
        /// PTX type of the loaded value.
        ty: PtxType,
    },
    /// Store to shared memory: `st.shared{ty} [addr], src;`
    ///
    /// **Operand order is reversed in PTX** — address first, value second
    /// (same convention as [`StGlobal`](Self::StGlobal)).
    /// Example: `st.shared.f32 [%r0], %f1;`
    StShared {
        /// Register holding the shared memory offset.
        addr: Register,
        /// Source register (value to store).
        src: Register,
        /// PTX type of the stored value.
        ty: PtxType,
    },
    /// Convert generic address to global: `cvta.to.global.u64 dst, src;`
    ///
    /// Always `.u64` (64-bit address space, matching `.address_size 64`).
    /// Required because `ld.param` returns generic-space pointers —
    /// `ld.global` needs global-space addresses.
    CvtaToGlobal {
        /// Destination register (global-space address).
        dst: Register,
        /// Source register (generic-space address from `ld.param`).
        src: Register,
    },
    /// Asynchronous global→shared copy, cache-at-all-levels variant:
    /// `cp.async.ca.shared.global [dst_shared], [src_global], size_bytes;`
    ///
    /// Issues a non-blocking transfer from global memory into shared
    /// memory without tying up registers. The copy is in-flight after
    /// this instruction; use [`CpAsyncCommitGroup`](Self::CpAsyncCommitGroup)
    /// to delimit a batch and [`CpAsyncWaitGroup`](Self::CpAsyncWaitGroup)
    /// to synchronize. Requires **SM 8.0+ (Ampere)**.
    ///
    /// `size_bytes` must be one of 4, 8, or 16 (validated at construction
    /// via [`MemoryOp::new_cp_async_ca`](Self::new_cp_async_ca)).
    ///
    /// Example: `cp.async.ca.shared.global [%r0], [%rd3], 16;`
    ///
    /// *Placement note:* cp.async lives in `MemoryOp` for Sprint 6.2
    /// because semantically it is a memory op. The commit/wait variants
    /// are pipeline-state operations and may relocate to a dedicated
    /// `PipelineOp` category in Sprint 6.4 once double-buffering patterns
    /// exercise the state machine.
    CpAsyncCaSharedGlobal {
        /// Register holding the shared-memory destination offset.
        dst_shared: Register,
        /// Register holding the global-memory source address (`.to.global`).
        src_global: Register,
        /// Copy size in bytes: must be 4, 8, or 16.
        size_bytes: u8,
    },
    /// Commit all pending `cp.async` operations into a new async group:
    /// `cp.async.commit_group;`
    ///
    /// Groups are numbered implicitly from 0 (most-recently committed)
    /// upward. Used in conjunction with
    /// [`CpAsyncWaitGroup`](Self::CpAsyncWaitGroup) to block until a
    /// specific group completes. Requires **SM 8.0+**.
    CpAsyncCommitGroup,
    /// Wait until at most `n` async copy groups remain in-flight:
    /// `cp.async.wait_group n;`
    ///
    /// `wait_group 0` waits for all outstanding groups to complete
    /// (the common one-stage-pipeline case). For double-buffered
    /// kernels, `wait_group 1` is used to block on the N-1'th group
    /// while issuing the N'th. Requires **SM 8.0+**.
    CpAsyncWaitGroup {
        /// Number of outstanding groups still permitted after this wait.
        n: u8,
    },
}
224
225impl MemoryOp {
226 /// Construct a [`CpAsyncCaSharedGlobal`](Self::CpAsyncCaSharedGlobal),
227 /// validating the size byte count.
228 ///
229 /// # Panics
230 ///
231 /// Panics if `size_bytes` is not one of `4`, `8`, or `16` — the
232 /// only sizes PTX accepts for `cp.async.ca`. PTX won't catch this
233 /// until ptxas runs, and the error there is cryptic, so we fail
234 /// loudly at construction time.
235 pub fn new_cp_async_ca(dst_shared: Register, src_global: Register, size_bytes: u8) -> Self {
236 assert!(
237 matches!(size_bytes, 4 | 8 | 16),
238 "cp.async.ca size must be 4, 8, or 16 bytes (got {size_bytes})"
239 );
240 Self::CpAsyncCaSharedGlobal {
241 dst_shared,
242 src_global,
243 size_bytes,
244 }
245 }
246
247 /// Construct an [`LdGlobalB128`](Self::LdGlobalB128), validating
248 /// that all 4 destinations are b32-class registers.
249 ///
250 /// # Panics
251 ///
252 /// Panics if any destination register is not [`crate::types::RegKind::R`] (b32).
253 /// `ld.global.v4.b32` requires 4× 32-bit-wide integer-class
254 /// destinations; `.f` / `.rd` / `.h` / `.hb` / `.p` registers are
255 /// invalid and ptxas's error message is cryptic. Fail loudly at
256 /// construction.
257 pub fn new_ld_global_b128(dsts: [Register; 4], addr: Register) -> Self {
258 use crate::types::RegKind;
259 for (i, d) in dsts.iter().enumerate() {
260 assert!(
261 d.kind == RegKind::R,
262 "ld.global.v4.b32 destination {i} must be a b32 register (RegKind::R); got {:?}",
263 d.kind
264 );
265 }
266 Self::LdGlobalB128 { dsts, addr }
267 }
268}
269
270impl Emit for MemoryOp {
271 fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
272 match self {
273 MemoryOp::LdParam {
274 dst,
275 param_name,
276 ty,
277 } => {
278 let mnemonic = format!("ld.param{}", ty.ptx_memory_suffix());
279 let addr = format!("[{param_name}]");
280 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr])
281 }
282 MemoryOp::LdGlobal { dst, addr, ty } => {
283 let mnemonic = format!("ld.global{}", ty.ptx_memory_suffix());
284 let addr_str = format!("[{addr}]");
285 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr_str])
286 }
287 MemoryOp::LdGlobalPred {
288 dst,
289 addr,
290 ty,
291 pred,
292 negate,
293 } => {
294 let neg = if *negate { "!" } else { "" };
295 w.line(&format!(
296 "@{neg}{pred} ld.global{} {dst}, [{addr}];",
297 ty.ptx_memory_suffix()
298 ))
299 }
300 MemoryOp::LdGlobalB128 { dsts, addr } => {
301 // ld.global.v4.b32 {d0, d1, d2, d3}, [addr];
302 w.line(&format!(
303 "ld.global.v4.b32 {{{}, {}, {}, {}}}, [{addr}];",
304 dsts[0], dsts[1], dsts[2], dsts[3]
305 ))
306 }
307 MemoryOp::StGlobal { addr, src, ty } => {
308 let mnemonic = format!("st.global{}", ty.ptx_memory_suffix());
309 let addr_str = format!("[{addr}]");
310 // PTX store order: [address], source (reversed from load)
311 w.instruction(&mnemonic, &[&addr_str as &dyn fmt::Display, src])
312 }
313 MemoryOp::StGlobalPred {
314 addr,
315 src,
316 ty,
317 pred,
318 negate,
319 } => {
320 let neg = if *negate { "!" } else { "" };
321 w.line(&format!(
322 "@{neg}{pred} st.global{} [{addr}], {src};",
323 ty.ptx_memory_suffix()
324 ))
325 }
326 MemoryOp::LdShared { dst, addr, ty } => {
327 let mnemonic = format!("ld.shared{}", ty.ptx_memory_suffix());
328 let addr_str = format!("[{addr}]");
329 w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr_str])
330 }
331 MemoryOp::StShared { addr, src, ty } => {
332 let mnemonic = format!("st.shared{}", ty.ptx_memory_suffix());
333 let addr_str = format!("[{addr}]");
334 w.instruction(&mnemonic, &[&addr_str as &dyn fmt::Display, src])
335 }
336 MemoryOp::CvtaToGlobal { dst, src } => {
337 w.instruction("cvta.to.global.u64", &[dst as &dyn fmt::Display, src])
338 }
339 MemoryOp::CpAsyncCaSharedGlobal {
340 dst_shared,
341 src_global,
342 size_bytes,
343 } => {
344 // cp.async.ca.shared.global [dst_shared], [src_global], size;
345 let dst_str = format!("[{dst_shared}]");
346 let src_str = format!("[{src_global}]");
347 let sz = *size_bytes as u32;
348 w.instruction(
349 "cp.async.ca.shared.global",
350 &[&dst_str as &dyn fmt::Display, &src_str, &sz],
351 )
352 }
353 MemoryOp::CpAsyncCommitGroup => w.instruction("cp.async.commit_group", &[]),
354 MemoryOp::CpAsyncWaitGroup { n } => {
355 let n = *n as u32;
356 w.instruction("cp.async.wait_group", &[&n as &dyn fmt::Display])
357 }
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RegKind;

    /// Helper to make a register without going through the allocator.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    // --- nvcc golden comparisons (byte-for-byte match against nvcc --ptx -arch=sm_89) ---

    #[test]
    fn emit_ld_param_u64() {
        // nvcc line 28: ld.param.u64 %rd1, [vector_add_param_0]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdParam {
            dst: reg(RegKind::Rd, 1, PtxType::U64),
            param_name: "vector_add_param_0".to_string(),
            ty: PtxType::U64,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ld.param.u64 %rd1, [vector_add_param_0];\n");
    }

    #[test]
    fn emit_ld_param_u32() {
        // nvcc line 31: ld.param.u32 %r2, [vector_add_param_3]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdParam {
            dst: reg(RegKind::R, 2, PtxType::U32),
            param_name: "vector_add_param_3".to_string(),
            ty: PtxType::U32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ld.param.u32 %r2, [vector_add_param_3];\n");
    }

    #[test]
    fn emit_cvta_to_global() {
        // nvcc line 39: cvta.to.global.u64 %rd4, %rd1
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::CvtaToGlobal {
            dst: reg(RegKind::Rd, 4, PtxType::U64),
            src: reg(RegKind::Rd, 1, PtxType::U64),
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cvta.to.global.u64 %rd4, %rd1;\n");
    }

    #[test]
    fn emit_ld_global_f32() {
        // nvcc line 44: ld.global.f32 %f1, [%rd8]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 1, PtxType::F32),
            addr: reg(RegKind::Rd, 8, PtxType::U64),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ld.global.f32 %f1, [%rd8];\n");
    }

    #[test]
    fn emit_ld_global_pred_b32() {
        // Sprint 6.7 edge-tile predicated load.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobalPred {
            dst: reg(RegKind::R, 5, PtxType::U32),
            addr: reg(RegKind::Rd, 9, PtxType::U64),
            ty: PtxType::U32,
            pred: reg(RegKind::P, 1, PtxType::Pred),
            negate: false,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " @%p1 ld.global.u32 %r5, [%rd9];\n");
    }

    #[test]
    fn emit_ld_global_pred_negated_b32() {
        // `negate: true` flips the guard to `@!%p2`.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobalPred {
            dst: reg(RegKind::R, 5, PtxType::U32),
            addr: reg(RegKind::Rd, 9, PtxType::U64),
            ty: PtxType::U32,
            pred: reg(RegKind::P, 2, PtxType::Pred),
            negate: true,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " @!%p2 ld.global.u32 %r5, [%rd9];\n");
    }

    #[test]
    fn emit_ld_global_b128() {
        // Sprint 6.7b vectorized load — 128-bit global load into 4 b32 regs.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 8, PtxType::U64),
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " ld.global.v4.b32 {%r0, %r1, %r2, %r3}, [%rd8];\n"
        );
    }

    #[test]
    fn emit_ld_global_b128_non_consecutive_regs() {
        // PTX accepts any 4 b32 regs in the brace list — not required
        // to be consecutive. Validates that emit just writes what it's given.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 5, PtxType::U32),
                reg(RegKind::R, 9, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 14, PtxType::U32),
            ],
            reg(RegKind::Rd, 3, PtxType::U64),
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " ld.global.v4.b32 {%r5, %r9, %r2, %r14}, [%rd3];\n"
        );
    }

    #[test]
    fn ld_global_b128_emits_single_instruction() {
        // D1 promise: LDG.128 emits ONE PTX instruction, not four.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 0, PtxType::U64),
        );
        op.emit(&mut w).unwrap();
        let out = w.finish();
        // Exactly one `ld.global.v4.b32` + exactly one newline = one line.
        assert_eq!(out.matches("ld.global").count(), 1);
        assert_eq!(out.matches('\n').count(), 1);
    }

    #[test]
    #[should_panic(expected = "ld.global.v4.b32 destination 0 must be a b32 register")]
    fn ld_global_b128_rejects_f32_destination() {
        // Constructor must reject non-R destination in slot 0.
        MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::F, 0, PtxType::F32), // wrong kind
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 0, PtxType::U64),
        );
    }

    #[test]
    #[should_panic(expected = "ld.global.v4.b32 destination 2 must be a b32 register")]
    fn ld_global_b128_rejects_h_destination() {
        // Panic message names the offending slot index (here: 2).
        MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::H, 0, PtxType::F16), // wrong kind (fp16)
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 0, PtxType::U64),
        );
    }

    #[test]
    fn ld_global_b128_via_ptx_instruction() {
        // Same output when dispatched through the PtxInstruction wrapper.
        use crate::ir::PtxInstruction;
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Memory(MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 5, PtxType::U64),
        ));
        instr.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " ld.global.v4.b32 {%r0, %r1, %r2, %r3}, [%rd5];\n"
        );
    }

    #[test]
    fn emit_st_global_pred_f32() {
        // Sprint 6.7 edge-tile predicated store.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobalPred {
            addr: reg(RegKind::Rd, 11, PtxType::U64),
            src: reg(RegKind::F, 4, PtxType::F32),
            ty: PtxType::F32,
            pred: reg(RegKind::P, 3, PtxType::Pred),
            negate: false,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " @%p3 st.global.f32 [%rd11], %f4;\n");
    }

    #[test]
    fn emit_st_global_f32() {
        // nvcc line 49: st.global.f32 [%rd10], %f3
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 10, PtxType::U64),
            src: reg(RegKind::F, 3, PtxType::F32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " st.global.f32 [%rd10], %f3;\n");
    }

    // --- Dispatch and ordering validation ---

    #[test]
    fn memory_via_ptx_instruction() {
        // Scalar load dispatched through the PtxInstruction wrapper.
        use crate::ir::PtxInstruction;

        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            ty: PtxType::F32,
        });
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ld.global.f32 %f0, [%rd0];\n");
    }

    // --- Shared memory ops ---

    #[test]
    fn emit_ld_shared_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdShared {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::R, 0, PtxType::U32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " ld.shared.f32 %f0, [%r0];\n");
    }

    #[test]
    fn emit_st_shared_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StShared {
            addr: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 1, PtxType::F32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " st.shared.f32 [%r0], %f1;\n");
    }

    // --- Half-precision load/store (Sprint 6.1) ---

    #[test]
    fn emit_ld_global_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobal {
            dst: reg(RegKind::H, 0, PtxType::F16),
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            ty: PtxType::F16,
        };
        op.emit(&mut w).unwrap();
        // PTX ISA §8.7.9: `ld`'s valid type set excludes f16/bf16 — must
        // use `.b16` for 16-bit loads into `.f16` registers.
        assert_eq!(w.finish(), " ld.global.b16 %h0, [%rd0];\n");
    }

    #[test]
    fn emit_st_global_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            src: reg(RegKind::H, 0, PtxType::F16),
            ty: PtxType::F16,
        };
        op.emit(&mut w).unwrap();
        // See ld.global counterpart — `.b16` is the memory-op form.
        assert_eq!(w.finish(), " st.global.b16 [%rd0], %h0;\n");
    }

    #[test]
    fn emit_ld_shared_bf16() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdShared {
            dst: reg(RegKind::Hb, 0, PtxType::BF16),
            addr: reg(RegKind::R, 0, PtxType::U32),
            ty: PtxType::BF16,
        };
        op.emit(&mut w).unwrap();
        // Both f16 and bf16 use `.b16` in memory ops per PTX ISA.
        assert_eq!(w.finish(), " ld.shared.b16 %hb0, [%r0];\n");
    }

    // --- cp.async (Sprint 6.2) ---

    #[test]
    fn emit_cp_async_ca_shared_global_16b() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32), // shared offset
            reg(RegKind::Rd, 3, PtxType::U64), // global addr
            16,
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " cp.async.ca.shared.global [%r0], [%rd3], 16;\n"
        );
    }

    #[test]
    fn emit_cp_async_ca_size_4() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 1, PtxType::U32),
            reg(RegKind::Rd, 4, PtxType::U64),
            4,
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            " cp.async.ca.shared.global [%r1], [%rd4], 4;\n"
        );
    }

    #[test]
    fn emit_cp_async_ca_size_8() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 2, PtxType::U32),
            reg(RegKind::Rd, 5, PtxType::U64),
            8,
        );
        op.emit(&mut w).unwrap();
        assert!(w.finish().ends_with("8;\n"));
    }

    #[test]
    #[should_panic(expected = "cp.async.ca size must be 4, 8, or 16 bytes")]
    fn cp_async_ca_rejects_bad_size() {
        // 12 is not a valid size — construction should panic.
        MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            12,
        );
    }

    #[test]
    fn emit_cp_async_commit_group() {
        // No-operand instruction form.
        let mut w = PtxWriter::new();
        w.indent();
        MemoryOp::CpAsyncCommitGroup.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cp.async.commit_group;\n");
    }

    #[test]
    fn emit_cp_async_wait_group_zero() {
        // wait_group 0 = block until all outstanding groups complete.
        let mut w = PtxWriter::new();
        w.indent();
        MemoryOp::CpAsyncWaitGroup { n: 0 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cp.async.wait_group 0;\n");
    }

    #[test]
    fn emit_cp_async_wait_group_n() {
        let mut w = PtxWriter::new();
        w.indent();
        MemoryOp::CpAsyncWaitGroup { n: 3 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cp.async.wait_group 3;\n");
    }

    #[test]
    fn cp_async_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Memory(MemoryOp::CpAsyncCommitGroup);
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), " cp.async.commit_group;\n");
    }

    #[test]
    fn st_global_operand_order() {
        // Verify store has [addr], src order — NOT src, [addr]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            src: reg(RegKind::F, 0, PtxType::F32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        let output = w.finish();
        // [%rd0] must appear BEFORE %f0
        let addr_pos = output.find("[%rd0]").expect("address not found");
        let src_pos = output.find("%f0").expect("source not found");
        assert!(
            addr_pos < src_pos,
            "store operand order wrong: address must come before source in PTX"
        );
    }
}
809}