// kaio_core/instr/memory.rs

//! Memory PTX operations.
//!
//! Contains load/store instructions for global and shared memory:
//! [`LdParam`](MemoryOp::LdParam), [`LdGlobal`](MemoryOp::LdGlobal),
//! [`LdGlobalPred`](MemoryOp::LdGlobalPred),
//! [`LdGlobalB128`](MemoryOp::LdGlobalB128),
//! [`StGlobal`](MemoryOp::StGlobal), [`StGlobalPred`](MemoryOp::StGlobalPred),
//! [`LdShared`](MemoryOp::LdShared), [`StShared`](MemoryOp::StShared), and
//! [`CvtaToGlobal`](MemoryOp::CvtaToGlobal), plus the asynchronous-copy
//! family [`CpAsyncCaSharedGlobal`](MemoryOp::CpAsyncCaSharedGlobal),
//! [`CpAsyncCommitGroup`](MemoryOp::CpAsyncCommitGroup), and
//! [`CpAsyncWaitGroup`](MemoryOp::CpAsyncWaitGroup).

use std::fmt;

use crate::emit::{Emit, PtxWriter};
use crate::ir::Register;
use crate::types::PtxType;

/// Memory PTX instruction variants.
///
/// Operand conventions:
/// - All addresses and values are [`Register`]s (not [`Operand`](crate::ir::Operand)).
///   You can't `ld.global` from an immediate address or `st.global` an immediate
///   value in PTX — those go through `mov` first.
/// - [`LdParam`](Self::LdParam) is the exception: it references a kernel parameter
///   by name (a `String`), not by register.
///
/// All variants are rendered to PTX text by the [`Emit`] impl below; the
/// `cp.async` and vectorized variants have validating constructors on this
/// type (`new_cp_async_ca`, `new_ld_global_b128`).
#[derive(Debug, Clone)]
pub enum MemoryOp {
    /// Load kernel parameter: `ld.param{ty} dst, [param_name];`
    ///
    /// References the parameter by name from the kernel signature.
    /// Example: `ld.param.u64 %rd1, [vector_add_param_0];`
    LdParam {
        /// Destination register.
        dst: Register,
        /// Parameter name from the kernel signature.
        param_name: String,
        /// PTX type of the parameter value.
        ty: PtxType,
    },
    /// Load from global memory: `ld.global{ty} dst, [addr];`
    ///
    /// The `addr` register holds the computed memory address.
    /// Example: `ld.global.f32 %f1, [%rd8];`
    LdGlobal {
        /// Destination register.
        dst: Register,
        /// Register holding the memory address.
        addr: Register,
        /// PTX type of the loaded value.
        ty: PtxType,
    },
    /// Predicated load from global memory: `@[!]{pred} ld.global{ty} dst, [addr];`
    ///
    /// Skips the load when the predicate evaluates false (or true when
    /// `negate` is set). Used for edge-tile bounds checking — the OOB
    /// thread's `dst` register is left unchanged, so callers typically
    /// pre-initialize `dst` to zero with `mov.b32 dst, 0` and then
    /// conditionally overwrite with a predicated load.
    ///
    /// Sprint 6.7 (multi-warp matmul_tc edge tiles) is the first user.
    /// Example: `@%p1 ld.global.u32 %r5, [%rd9];`
    LdGlobalPred {
        /// Destination register (unchanged when predicate is false).
        dst: Register,
        /// Register holding the memory address.
        addr: Register,
        /// PTX type of the loaded value.
        ty: PtxType,
        /// Predicate register controlling the load.
        pred: Register,
        /// When `true`, negate the predicate (`@!pred`).
        negate: bool,
    },
    /// 128-bit vectorized load from global memory:
    /// `ld.global.v4.b32 {%r_i, %r_j, %r_k, %r_l}, [addr];`
    ///
    /// Single-instruction 128-bit transfer into 4 independent b32
    /// destination registers. Halves (or more) the global-load
    /// instruction count vs scalar b32 loads for bandwidth-bound
    /// kernels. Requires the `addr` register to hold a **16-byte
    /// aligned global-space address** — unaligned access will fault
    /// at runtime; PTX does not catch this statically.
    ///
    /// Destinations are NOT required to be consecutive registers in
    /// the allocator — PTX `ld.global.v4.b32` accepts any 4 b32 regs
    /// in the vector brace list. In practice, allocating 4 regs in
    /// sequence produces consecutive indices, which is what callers
    /// typically do.
    ///
    /// No predicate variant in Sprint 6.7b — edge tiles stay on the
    /// existing [`LdGlobalPred`](Self::LdGlobalPred) scalar path. A
    /// future `LdGlobalB128Pred` would be additive.
    ///
    /// Sprint 6.7b (multi-warp matmul_tc Tile B fast path) is the
    /// first user. Construct via
    /// [`MemoryOp::new_ld_global_b128`](Self::new_ld_global_b128),
    /// which validates that all 4 destinations are b32-class registers.
    ///
    /// Example: `ld.global.v4.b32 {%r0, %r1, %r2, %r3}, [%rd8];`
    LdGlobalB128 {
        /// Four b32 destination registers — receive bytes 0-3, 4-7,
        /// 8-11, 12-15 of the loaded 128-bit value respectively.
        dsts: [Register; 4],
        /// Register holding a 16-B aligned global-space address.
        addr: Register,
    },
    /// Store to global memory: `st.global{ty} [addr], src;`
    ///
    /// **Operand order is reversed in PTX** — address comes first,
    /// value second. This matches PTX convention but is opposite to
    /// loads and arithmetic where `dst` is first.
    ///
    /// Example: `st.global.f32 [%rd10], %f3;`
    StGlobal {
        /// Register holding the memory address.
        addr: Register,
        /// Source register (value to store).
        src: Register,
        /// PTX type of the stored value.
        ty: PtxType,
    },
    /// Predicated store to global memory: `@[!]{pred} st.global{ty} [addr], src;`
    ///
    /// Skips the store when the predicate evaluates false (or true when
    /// `negate` is set). Used for edge-tile bounds checking on output
    /// writes — out-of-bounds threads simply don't store, leaving the
    /// destination memory untouched.
    ///
    /// Sprint 6.7 (multi-warp matmul_tc edge tiles) is the first user.
    /// Example: `@%p1 st.global.f32 [%rd11], %f4;`
    StGlobalPred {
        /// Register holding the memory address.
        addr: Register,
        /// Source register (value to store).
        src: Register,
        /// PTX type of the stored value.
        ty: PtxType,
        /// Predicate register controlling the store.
        pred: Register,
        /// When `true`, negate the predicate (`@!pred`).
        negate: bool,
    },
    /// Load from shared memory: `ld.shared{ty} dst, [addr];`
    ///
    /// Shared memory is block-scoped SRAM. The `addr` register holds the
    /// offset into the declared shared allocation.
    /// Example: `ld.shared.f32 %f0, [%r0];`
    LdShared {
        /// Destination register.
        dst: Register,
        /// Register holding the shared memory offset.
        addr: Register,
        /// PTX type of the loaded value.
        ty: PtxType,
    },
    /// Store to shared memory: `st.shared{ty} [addr], src;`
    ///
    /// **Operand order is reversed in PTX** — address first, value second
    /// (same convention as [`StGlobal`](Self::StGlobal)).
    /// Example: `st.shared.f32 [%r0], %f1;`
    StShared {
        /// Register holding the shared memory offset.
        addr: Register,
        /// Source register (value to store).
        src: Register,
        /// PTX type of the stored value.
        ty: PtxType,
    },
    /// Convert generic address to global: `cvta.to.global.u64 dst, src;`
    ///
    /// Always `.u64` (64-bit address space, matching `.address_size 64`).
    /// Required because `ld.param` returns generic-space pointers —
    /// `ld.global` needs global-space addresses.
    CvtaToGlobal {
        /// Destination register (global-space address).
        dst: Register,
        /// Source register (generic-space address from `ld.param`).
        src: Register,
    },
    /// Asynchronous global→shared copy, cache-at-all-levels variant:
    /// `cp.async.ca.shared.global [dst_shared], [src_global], size_bytes;`
    ///
    /// Issues a non-blocking transfer from global memory into shared
    /// memory without tying up registers. The copy is in-flight after
    /// this instruction; use [`CpAsyncCommitGroup`](Self::CpAsyncCommitGroup)
    /// to delimit a batch and [`CpAsyncWaitGroup`](Self::CpAsyncWaitGroup)
    /// to synchronize. Requires **SM 8.0+ (Ampere)**.
    ///
    /// `size_bytes` must be one of 4, 8, or 16 (validated at construction
    /// via [`MemoryOp::new_cp_async_ca`](Self::new_cp_async_ca)).
    ///
    /// Example: `cp.async.ca.shared.global [%r0], [%rd3], 16;`
    ///
    /// *Placement note:* cp.async lives in `MemoryOp` for Sprint 6.2
    /// because semantically it is a memory op. The commit/wait variants
    /// are pipeline-state operations and may relocate to a dedicated
    /// `PipelineOp` category in Sprint 6.4 once double-buffering patterns
    /// exercise the state machine.
    CpAsyncCaSharedGlobal {
        /// Register holding the shared-memory destination offset.
        dst_shared: Register,
        /// Register holding the global-memory source address (`.to.global`).
        src_global: Register,
        /// Copy size in bytes: must be 4, 8, or 16.
        size_bytes: u8,
    },
    /// Commit all pending `cp.async` operations into a new async group:
    /// `cp.async.commit_group;`
    ///
    /// Groups are numbered implicitly from 0 (most-recently committed)
    /// upward. Used in conjunction with
    /// [`CpAsyncWaitGroup`](Self::CpAsyncWaitGroup) to block until a
    /// specific group completes. Requires **SM 8.0+**.
    CpAsyncCommitGroup,
    /// Wait until at most `n` async copy groups remain in-flight:
    /// `cp.async.wait_group n;`
    ///
    /// `wait_group 0` waits for all outstanding groups to complete
    /// (the common one-stage-pipeline case). For double-buffered
    /// kernels, `wait_group 1` is used to block on the N-1'th group
    /// while issuing the N'th. Requires **SM 8.0+**.
    CpAsyncWaitGroup {
        /// Number of outstanding groups still permitted after this wait.
        n: u8,
    },
}

225impl MemoryOp {
226    /// Construct a [`CpAsyncCaSharedGlobal`](Self::CpAsyncCaSharedGlobal),
227    /// validating the size byte count.
228    ///
229    /// # Panics
230    ///
231    /// Panics if `size_bytes` is not one of `4`, `8`, or `16` — the
232    /// only sizes PTX accepts for `cp.async.ca`. PTX won't catch this
233    /// until ptxas runs, and the error there is cryptic, so we fail
234    /// loudly at construction time.
235    pub fn new_cp_async_ca(dst_shared: Register, src_global: Register, size_bytes: u8) -> Self {
236        assert!(
237            matches!(size_bytes, 4 | 8 | 16),
238            "cp.async.ca size must be 4, 8, or 16 bytes (got {size_bytes})"
239        );
240        Self::CpAsyncCaSharedGlobal {
241            dst_shared,
242            src_global,
243            size_bytes,
244        }
245    }
246
247    /// Construct an [`LdGlobalB128`](Self::LdGlobalB128), validating
248    /// that all 4 destinations are b32-class registers.
249    ///
250    /// # Panics
251    ///
252    /// Panics if any destination register is not [`crate::types::RegKind::R`] (b32).
253    /// `ld.global.v4.b32` requires 4× 32-bit-wide integer-class
254    /// destinations; `.f` / `.rd` / `.h` / `.hb` / `.p` registers are
255    /// invalid and ptxas's error message is cryptic. Fail loudly at
256    /// construction.
257    pub fn new_ld_global_b128(dsts: [Register; 4], addr: Register) -> Self {
258        use crate::types::RegKind;
259        for (i, d) in dsts.iter().enumerate() {
260            assert!(
261                d.kind == RegKind::R,
262                "ld.global.v4.b32 destination {i} must be a b32 register (RegKind::R); got {:?}",
263                d.kind
264            );
265        }
266        Self::LdGlobalB128 { dsts, addr }
267    }
268}
269
impl Emit for MemoryOp {
    /// Render this memory op as a single PTX source line via `w`.
    ///
    /// Simple ops go through `PtxWriter::instruction` (mnemonic + operand
    /// list). The predicated and vectorized forms pre-format the whole
    /// line and use `PtxWriter::line`, because the `@[!]pred` guard prefix
    /// and the `{...}` vector destination list don't fit the plain
    /// mnemonic/operand shape.
    fn emit(&self, w: &mut PtxWriter) -> fmt::Result {
        match self {
            MemoryOp::LdParam {
                dst,
                param_name,
                ty,
            } => {
                // `ty.ptx_memory_suffix()` supplies the type suffix
                // (e.g. ".u64"); the parameter operand is the bracketed name.
                let mnemonic = format!("ld.param{}", ty.ptx_memory_suffix());
                let addr = format!("[{param_name}]");
                // `dst as &dyn fmt::Display` fixes the slice element type so
                // the `&String` operand coerces alongside the Register.
                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr])
            }
            MemoryOp::LdGlobal { dst, addr, ty } => {
                let mnemonic = format!("ld.global{}", ty.ptx_memory_suffix());
                let addr_str = format!("[{addr}]");
                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr_str])
            }
            MemoryOp::LdGlobalPred {
                dst,
                addr,
                ty,
                pred,
                negate,
            } => {
                // `@pred` executes the load; `@!pred` executes when pred is false.
                let neg = if *negate { "!" } else { "" };
                w.line(&format!(
                    "@{neg}{pred} ld.global{} {dst}, [{addr}];",
                    ty.ptx_memory_suffix()
                ))
            }
            MemoryOp::LdGlobalB128 { dsts, addr } => {
                // ld.global.v4.b32 {d0, d1, d2, d3}, [addr];
                w.line(&format!(
                    "ld.global.v4.b32 {{{}, {}, {}, {}}}, [{addr}];",
                    dsts[0], dsts[1], dsts[2], dsts[3]
                ))
            }
            MemoryOp::StGlobal { addr, src, ty } => {
                let mnemonic = format!("st.global{}", ty.ptx_memory_suffix());
                let addr_str = format!("[{addr}]");
                // PTX store order: [address], source (reversed from load)
                w.instruction(&mnemonic, &[&addr_str as &dyn fmt::Display, src])
            }
            MemoryOp::StGlobalPred {
                addr,
                src,
                ty,
                pred,
                negate,
            } => {
                let neg = if *negate { "!" } else { "" };
                w.line(&format!(
                    "@{neg}{pred} st.global{} [{addr}], {src};",
                    ty.ptx_memory_suffix()
                ))
            }
            MemoryOp::LdShared { dst, addr, ty } => {
                let mnemonic = format!("ld.shared{}", ty.ptx_memory_suffix());
                let addr_str = format!("[{addr}]");
                w.instruction(&mnemonic, &[dst as &dyn fmt::Display, &addr_str])
            }
            MemoryOp::StShared { addr, src, ty } => {
                // Same reversed [address], source order as StGlobal.
                let mnemonic = format!("st.shared{}", ty.ptx_memory_suffix());
                let addr_str = format!("[{addr}]");
                w.instruction(&mnemonic, &[&addr_str as &dyn fmt::Display, src])
            }
            MemoryOp::CvtaToGlobal { dst, src } => {
                // Fixed mnemonic — always 64-bit per `.address_size 64`.
                w.instruction("cvta.to.global.u64", &[dst as &dyn fmt::Display, src])
            }
            MemoryOp::CpAsyncCaSharedGlobal {
                dst_shared,
                src_global,
                size_bytes,
            } => {
                // cp.async.ca.shared.global [dst_shared], [src_global], size;
                let dst_str = format!("[{dst_shared}]");
                let src_str = format!("[{src_global}]");
                // Widen to u32 so the size displays as a plain integer operand.
                let sz = *size_bytes as u32;
                w.instruction(
                    "cp.async.ca.shared.global",
                    &[&dst_str as &dyn fmt::Display, &src_str, &sz],
                )
            }
            MemoryOp::CpAsyncCommitGroup => w.instruction("cp.async.commit_group", &[]),
            MemoryOp::CpAsyncWaitGroup { n } => {
                // Widened like `size_bytes` above for Display as an operand.
                let n = *n as u32;
                w.instruction("cp.async.wait_group", &[&n as &dyn fmt::Display])
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RegKind;

    /// Helper to make a register without going through the allocator.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    // --- nvcc golden comparisons (byte-for-byte match against nvcc --ptx -arch=sm_89) ---

    #[test]
    fn emit_ld_param_u64() {
        // nvcc line 28: ld.param.u64 %rd1, [vector_add_param_0]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdParam {
            dst: reg(RegKind::Rd, 1, PtxType::U64),
            param_name: "vector_add_param_0".to_string(),
            ty: PtxType::U64,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    ld.param.u64 %rd1, [vector_add_param_0];\n");
    }

    #[test]
    fn emit_ld_param_u32() {
        // nvcc line 31: ld.param.u32 %r2, [vector_add_param_3]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdParam {
            dst: reg(RegKind::R, 2, PtxType::U32),
            param_name: "vector_add_param_3".to_string(),
            ty: PtxType::U32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    ld.param.u32 %r2, [vector_add_param_3];\n");
    }

    #[test]
    fn emit_cvta_to_global() {
        // nvcc line 39: cvta.to.global.u64 %rd4, %rd1
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::CvtaToGlobal {
            dst: reg(RegKind::Rd, 4, PtxType::U64),
            src: reg(RegKind::Rd, 1, PtxType::U64),
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cvta.to.global.u64 %rd4, %rd1;\n");
    }

    #[test]
    fn emit_ld_global_f32() {
        // nvcc line 44: ld.global.f32 %f1, [%rd8]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 1, PtxType::F32),
            addr: reg(RegKind::Rd, 8, PtxType::U64),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    ld.global.f32 %f1, [%rd8];\n");
    }

    #[test]
    fn emit_ld_global_pred_b32() {
        // Sprint 6.7 edge-tile predicated load.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobalPred {
            dst: reg(RegKind::R, 5, PtxType::U32),
            addr: reg(RegKind::Rd, 9, PtxType::U64),
            ty: PtxType::U32,
            pred: reg(RegKind::P, 1, PtxType::Pred),
            negate: false,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    @%p1 ld.global.u32 %r5, [%rd9];\n");
    }

    #[test]
    fn emit_ld_global_pred_negated_b32() {
        // Negated predicate renders as `@!%p2`.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobalPred {
            dst: reg(RegKind::R, 5, PtxType::U32),
            addr: reg(RegKind::Rd, 9, PtxType::U64),
            ty: PtxType::U32,
            pred: reg(RegKind::P, 2, PtxType::Pred),
            negate: true,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    @!%p2 ld.global.u32 %r5, [%rd9];\n");
    }

    #[test]
    fn emit_ld_global_b128() {
        // Sprint 6.7b vectorized load — 128-bit global load into 4 b32 regs.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 8, PtxType::U64),
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            "    ld.global.v4.b32 {%r0, %r1, %r2, %r3}, [%rd8];\n"
        );
    }

    #[test]
    fn emit_ld_global_b128_non_consecutive_regs() {
        // PTX accepts any 4 b32 regs in the brace list — not required
        // to be consecutive. Validates that emit just writes what it's given.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 5, PtxType::U32),
                reg(RegKind::R, 9, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 14, PtxType::U32),
            ],
            reg(RegKind::Rd, 3, PtxType::U64),
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            "    ld.global.v4.b32 {%r5, %r9, %r2, %r14}, [%rd3];\n"
        );
    }

    #[test]
    fn ld_global_b128_emits_single_instruction() {
        // D1 promise: LDG.128 emits ONE PTX instruction, not four.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 0, PtxType::U64),
        );
        op.emit(&mut w).unwrap();
        let out = w.finish();
        // Exactly one `ld.global.v4.b32` + exactly one newline = one line.
        assert_eq!(out.matches("ld.global").count(), 1);
        assert_eq!(out.matches('\n').count(), 1);
    }

    #[test]
    #[should_panic(expected = "ld.global.v4.b32 destination 0 must be a b32 register")]
    fn ld_global_b128_rejects_f32_destination() {
        MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::F, 0, PtxType::F32), // wrong kind
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 0, PtxType::U64),
        );
    }

    #[test]
    #[should_panic(expected = "ld.global.v4.b32 destination 2 must be a b32 register")]
    fn ld_global_b128_rejects_h_destination() {
        MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::H, 0, PtxType::F16), // wrong kind (fp16)
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 0, PtxType::U64),
        );
    }

    #[test]
    fn ld_global_b128_via_ptx_instruction() {
        // Same op routed through the PtxInstruction wrapper enum.
        use crate::ir::PtxInstruction;
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Memory(MemoryOp::new_ld_global_b128(
            [
                reg(RegKind::R, 0, PtxType::U32),
                reg(RegKind::R, 1, PtxType::U32),
                reg(RegKind::R, 2, PtxType::U32),
                reg(RegKind::R, 3, PtxType::U32),
            ],
            reg(RegKind::Rd, 5, PtxType::U64),
        ));
        instr.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            "    ld.global.v4.b32 {%r0, %r1, %r2, %r3}, [%rd5];\n"
        );
    }

    #[test]
    fn emit_st_global_pred_f32() {
        // Sprint 6.7 edge-tile predicated store.
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobalPred {
            addr: reg(RegKind::Rd, 11, PtxType::U64),
            src: reg(RegKind::F, 4, PtxType::F32),
            ty: PtxType::F32,
            pred: reg(RegKind::P, 3, PtxType::Pred),
            negate: false,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    @%p3 st.global.f32 [%rd11], %f4;\n");
    }

    #[test]
    fn emit_st_global_f32() {
        // nvcc line 49: st.global.f32 [%rd10], %f3
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 10, PtxType::U64),
            src: reg(RegKind::F, 3, PtxType::F32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    st.global.f32 [%rd10], %f3;\n");
    }

    // --- Dispatch and ordering validation ---

    #[test]
    fn memory_via_ptx_instruction() {
        use crate::ir::PtxInstruction;

        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            ty: PtxType::F32,
        });
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    ld.global.f32 %f0, [%rd0];\n");
    }

    // --- Shared memory ops ---
    // Note: shared-memory addresses in these tests are b32 (%r) offsets,
    // unlike the 64-bit (%rd) global addresses above.

    #[test]
    fn emit_ld_shared_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdShared {
            dst: reg(RegKind::F, 0, PtxType::F32),
            addr: reg(RegKind::R, 0, PtxType::U32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    ld.shared.f32 %f0, [%r0];\n");
    }

    #[test]
    fn emit_st_shared_f32() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StShared {
            addr: reg(RegKind::R, 0, PtxType::U32),
            src: reg(RegKind::F, 1, PtxType::F32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    st.shared.f32 [%r0], %f1;\n");
    }

    // --- Half-precision load/store (Sprint 6.1) ---

    #[test]
    fn emit_ld_global_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdGlobal {
            dst: reg(RegKind::H, 0, PtxType::F16),
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            ty: PtxType::F16,
        };
        op.emit(&mut w).unwrap();
        // PTX ISA §8.7.9: `ld`'s valid type set excludes f16/bf16 — must
        // use `.b16` for 16-bit loads into `.f16` registers.
        assert_eq!(w.finish(), "    ld.global.b16 %h0, [%rd0];\n");
    }

    #[test]
    fn emit_st_global_f16() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            src: reg(RegKind::H, 0, PtxType::F16),
            ty: PtxType::F16,
        };
        op.emit(&mut w).unwrap();
        // See ld.global counterpart — `.b16` is the memory-op form.
        assert_eq!(w.finish(), "    st.global.b16 [%rd0], %h0;\n");
    }

    #[test]
    fn emit_ld_shared_bf16() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::LdShared {
            dst: reg(RegKind::Hb, 0, PtxType::BF16),
            addr: reg(RegKind::R, 0, PtxType::U32),
            ty: PtxType::BF16,
        };
        op.emit(&mut w).unwrap();
        // Both f16 and bf16 use `.b16` in memory ops per PTX ISA.
        assert_eq!(w.finish(), "    ld.shared.b16 %hb0, [%r0];\n");
    }

    // --- cp.async (Sprint 6.2) ---

    #[test]
    fn emit_cp_async_ca_shared_global_16b() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),  // shared offset
            reg(RegKind::Rd, 3, PtxType::U64), // global addr
            16,
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            "    cp.async.ca.shared.global [%r0], [%rd3], 16;\n"
        );
    }

    #[test]
    fn emit_cp_async_ca_size_4() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 1, PtxType::U32),
            reg(RegKind::Rd, 4, PtxType::U64),
            4,
        );
        op.emit(&mut w).unwrap();
        assert_eq!(
            w.finish(),
            "    cp.async.ca.shared.global [%r1], [%rd4], 4;\n"
        );
    }

    #[test]
    fn emit_cp_async_ca_size_8() {
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 2, PtxType::U32),
            reg(RegKind::Rd, 5, PtxType::U64),
            8,
        );
        op.emit(&mut w).unwrap();
        assert!(w.finish().ends_with("8;\n"));
    }

    #[test]
    #[should_panic(expected = "cp.async.ca size must be 4, 8, or 16 bytes")]
    fn cp_async_ca_rejects_bad_size() {
        // 12 is not a valid size — construction should panic.
        MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            12,
        );
    }

    #[test]
    fn emit_cp_async_commit_group() {
        let mut w = PtxWriter::new();
        w.indent();
        MemoryOp::CpAsyncCommitGroup.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cp.async.commit_group;\n");
    }

    #[test]
    fn emit_cp_async_wait_group_zero() {
        let mut w = PtxWriter::new();
        w.indent();
        MemoryOp::CpAsyncWaitGroup { n: 0 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cp.async.wait_group 0;\n");
    }

    #[test]
    fn emit_cp_async_wait_group_n() {
        let mut w = PtxWriter::new();
        w.indent();
        MemoryOp::CpAsyncWaitGroup { n: 3 }.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cp.async.wait_group 3;\n");
    }

    #[test]
    fn cp_async_via_ptx_instruction() {
        use crate::ir::PtxInstruction;
        let mut w = PtxWriter::new();
        w.indent();
        let instr = PtxInstruction::Memory(MemoryOp::CpAsyncCommitGroup);
        instr.emit(&mut w).unwrap();
        assert_eq!(w.finish(), "    cp.async.commit_group;\n");
    }

    #[test]
    fn st_global_operand_order() {
        // Verify store has [addr], src order — NOT src, [addr]
        let mut w = PtxWriter::new();
        w.indent();
        let op = MemoryOp::StGlobal {
            addr: reg(RegKind::Rd, 0, PtxType::U64),
            src: reg(RegKind::F, 0, PtxType::F32),
            ty: PtxType::F32,
        };
        op.emit(&mut w).unwrap();
        let output = w.finish();
        // [%rd0] must appear BEFORE %f0
        let addr_pos = output.find("[%rd0]").expect("address not found");
        let src_pos = output.find("%f0").expect("source not found");
        assert!(
            addr_pos < src_pos,
            "store operand order wrong: address must come before source in PTX"
        );
    }
}