Skip to main content

kaio_core/ir/
module.rs

1//! PTX module — the top-level IR container.
2
3use std::fmt;
4
5use super::instruction::PtxInstruction;
6use super::kernel::PtxKernel;
7use crate::instr::{MemoryOp, TensorCoreOp};
8use crate::types::PtxType;
9
10/// A complete PTX module containing version/target metadata and kernels.
11///
12/// Corresponds to a single `.ptx` file with a header and one or more
13/// `.entry` kernel definitions.
14#[derive(Debug, Clone)]
15pub struct PtxModule {
16    /// PTX ISA version (e.g. `"7.8"`).
17    pub version: String,
18    /// Target SM architecture (e.g. `"sm_89"`).
19    pub target: String,
20    /// Address size in bits (32 or 64).
21    pub address_size: u32,
22    /// Kernel definitions in this module.
23    pub kernels: Vec<PtxKernel>,
24}
25
26impl PtxModule {
27    /// Create a new module targeting the given SM architecture.
28    ///
29    /// Defaults: PTX version `8.7` (CUDA 12.8), address size `64`.
30    pub fn new(target: &str) -> Self {
31        Self {
32            version: "8.7".to_string(),
33            target: target.to_string(),
34            address_size: 64,
35            kernels: Vec::new(),
36        }
37    }
38
39    /// Add a kernel to this module.
40    pub fn add_kernel(&mut self, kernel: PtxKernel) {
41        self.kernels.push(kernel);
42    }
43
44    /// Parse the target string (e.g. `"sm_89"`) into a numeric SM
45    /// version (e.g. `89`).
46    ///
47    /// Returns `None` if the target string is not a recognized
48    /// `sm_NN` form (e.g. future targets, virtual architectures).
49    /// [`validate`](Self::validate) tolerates unparseable targets by
50    /// skipping the SM check — we'd rather let unusual targets through
51    /// than block a user experimenting with a custom target string.
52    fn parse_sm_target(&self) -> Option<u32> {
53        self.target.strip_prefix("sm_").and_then(|s| s.parse().ok())
54    }
55
56    /// Validate that this module's target SM is high enough for every
57    /// feature used by its kernels.
58    ///
59    /// Walks all kernel bodies looking for features that carry a minimum
60    /// SM requirement — currently tensor-core operations and `cp.async`
61    /// variants (both Ampere+ / SM 8.0). Returns [`ValidationError::SmTooLow`]
62    /// on the **first** such mismatch with a human-readable description.
63    ///
64    /// This is a narrow **target-capability** check, not a semantic or
65    /// dataflow pass. The goal is to surface clean errors at emit-time
66    /// instead of cryptic ptxas messages downstream.
67    pub fn validate(&self) -> Result<(), ValidationError> {
68        let target_sm = self.parse_sm_target();
69
70        for kernel in &self.kernels {
71            for instr in &kernel.body {
72                // Target-agnostic shape/dtype routing checks.
73                if let PtxInstruction::TensorCore(op) = instr {
74                    validate_tensor_core_op(op)?;
75                }
76
77                // Target-capability check (skipped on unparseable targets).
78                if let Some(target_sm) = target_sm
79                    && let Some((required, feature)) = instruction_sm_requirement(instr)
80                    && target_sm < required
81                {
82                    return Err(ValidationError::SmTooLow {
83                        required,
84                        actual: target_sm,
85                        feature,
86                    });
87                }
88            }
89        }
90        Ok(())
91    }
92}
93
94/// Per-instruction target-agnostic IR validation for tensor-core ops.
95///
96/// Rejects bf16 dtype tags on the generic [`TensorCoreOp::MmaSync`]
97/// variant — bf16 emission must go through [`TensorCoreOp::MmaSyncBf16`] so
98/// the fragment types and instruction dtype stay aligned at the IR boundary.
99///
100/// Also rejects mis-typed registers on [`TensorCoreOp::LdMatrix`]: unlike
101/// the mma variants, whose operands are typed fragment wrappers (allocated
102/// with the correct register class by construction), `LdMatrix` carries raw
103/// [`Register`](crate::ir::Register)s — so the `.b32`-class requirement
104/// (`PtxType::U32`, the `alloc_packed_half2` packed-pair convention) is
105/// enforced here, surfacing a named error at module load instead of a
106/// cryptic ptxas failure at JIT time.
107fn validate_tensor_core_op(op: &TensorCoreOp) -> Result<(), ValidationError> {
108    match op {
109        TensorCoreOp::MmaSync { a_ty, b_ty, .. } => {
110            if *a_ty == PtxType::BF16 {
111                return Err(ValidationError::MmaSyncBf16Rejected { operand: "a_ty" });
112            }
113            if *b_ty == PtxType::BF16 {
114                return Err(ValidationError::MmaSyncBf16Rejected { operand: "b_ty" });
115            }
116        }
117        TensorCoreOp::LdMatrix { dst, addr, .. } => {
118            for reg in dst.regs() {
119                if reg.ptx_type != PtxType::U32 {
120                    return Err(ValidationError::LdMatrixBadRegType {
121                        operand: "dst",
122                        found: reg.ptx_type,
123                    });
124                }
125            }
126            if addr.ptx_type != PtxType::U32 {
127                return Err(ValidationError::LdMatrixBadRegType {
128                    operand: "addr",
129                    found: addr.ptx_type,
130                });
131            }
132        }
133        _ => {}
134    }
135    Ok(())
136}
137
138/// Return `Some((min_sm, feature_label))` if this instruction carries an SM
139/// requirement, or `None` if it is SM-agnostic.
140fn instruction_sm_requirement(instr: &PtxInstruction) -> Option<(u32, String)> {
141    match instr {
142        PtxInstruction::TensorCore(op) => Some((op.min_sm(), op.feature_label())),
143        PtxInstruction::Memory(
144            MemoryOp::CpAsyncCaSharedGlobal { .. }
145            | MemoryOp::CpAsyncCommitGroup
146            | MemoryOp::CpAsyncWaitGroup { .. },
147        ) => Some((80, "cp.async".to_string())),
148        _ => None,
149    }
150}
151
152/// Errors returned by [`PtxModule::validate`].
153///
154/// Scope is intentionally narrow — target-capability checks only, no
155/// semantic analysis.
156#[derive(Debug, Clone, PartialEq, Eq)]
157pub enum ValidationError {
158    /// A feature used by the module requires a higher SM target than
159    /// the module declares.
160    ///
161    /// Example: a kernel containing `mma.sync.m16n8k16` in a module
162    /// with `.target sm_70` would yield
163    /// `required: 80, actual: 70, feature: "mma.sync.m16n8k16"`.
164    SmTooLow {
165        /// Minimum SM version required by the offending feature.
166        required: u32,
167        /// SM version parsed from the module's target string.
168        actual: u32,
169        /// Human-readable name of the offending feature.
170        feature: String,
171    },
172    /// A [`TensorCoreOp::MmaSync`] instruction was constructed with
173    /// `PtxType::BF16` on `a_ty` or `b_ty`. Bf16 emission must use the
174    /// dedicated [`TensorCoreOp::MmaSyncBf16`] variant so fragment types
175    /// and instruction dtype stay aligned at the IR boundary.
176    ///
177    /// Introduced in Sprint 9.1 cleanup; closes the legacy hole where the
178    /// generic `MmaSync` path silently emitted a bf16 instruction from
179    /// `FragmentA_F16` / `FragmentB_F16` operands.
180    MmaSyncBf16Rejected {
181        /// Which operand carried the rejected dtype tag (`"a_ty"` or
182        /// `"b_ty"`).
183        operand: &'static str,
184    },
185    /// A [`TensorCoreOp::LdMatrix`] instruction carries a register whose
186    /// declared type is not `PtxType::U32` (`.b32` class). The mma
187    /// variants get this for free from their typed fragment wrappers;
188    /// `LdMatrix` takes raw registers, so the check lives here.
189    ///
190    /// Destination registers hold packed 16-bit pairs
191    /// ([`alloc_packed_half2`](crate::ir::RegisterAllocator::alloc_packed_half2)
192    /// convention) and the address register is a shared-space `.u32`
193    /// byte address. Introduced in Sprint 9.3.
194    LdMatrixBadRegType {
195        /// Which operand carried the rejected register (`"dst"` or
196        /// `"addr"`).
197        operand: &'static str,
198        /// The register's declared PTX type.
199        found: PtxType,
200    },
201}
202
203impl fmt::Display for ValidationError {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        match self {
206            Self::SmTooLow {
207                required,
208                actual,
209                feature,
210            } => {
211                write!(
212                    f,
213                    "{feature} requires sm_{required}+, target is sm_{actual}"
214                )
215            }
216            Self::MmaSyncBf16Rejected { operand } => {
217                write!(
218                    f,
219                    "TensorCoreOp::MmaSync with PtxType::BF16 on {operand} is rejected; use TensorCoreOp::MmaSyncBf16 for bf16 emission"
220                )
221            }
222            Self::LdMatrixBadRegType { operand, found } => {
223                write!(
224                    f,
225                    "TensorCoreOp::LdMatrix {operand} register must be PtxType::U32 (.b32 packed-pair convention, see alloc_packed_half2), found {found:?}"
226                )
227            }
228        }
229    }
230}
231
232impl std::error::Error for ValidationError {}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{
        alloc_a_M16N8K32, alloc_a_bf16, alloc_a_f16, alloc_b_M16N8K32, alloc_b_bf16, alloc_b_f16,
        alloc_c, alloc_c_M16N8K32,
    };
    use crate::instr::{LdMatrixDst, MemoryOp, MmaShape, TensorCoreOp};
    use crate::ir::{PtxInstruction, PtxKernel, Register, RegisterAllocator};
    use crate::types::{PtxType, RegKind};

    // ---- shared builders -------------------------------------------------

    /// Hand-built register, bypassing the allocator.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Module for `target` containing exactly one kernel.
    fn module_with(target: &str, kernel: PtxKernel) -> PtxModule {
        let mut module = PtxModule::new(target);
        module.add_kernel(kernel);
        module
    }

    /// Kernel named `name` whose body is the single tensor-core `op`.
    fn single_op_kernel(name: &str, op: TensorCoreOp) -> PtxKernel {
        let mut kernel = PtxKernel::new(name);
        kernel.push(PtxInstruction::TensorCore(op));
        kernel
    }

    /// Expected `SmTooLow` value, with the label given as a `&str`.
    fn sm_too_low(required: u32, actual: u32, feature: &str) -> ValidationError {
        ValidationError::SmTooLow {
            required,
            actual,
            feature: feature.to_string(),
        }
    }

    /// Generic `MmaSync` m16n8k16 kernel with f16 fragments and f32
    /// accumulators; only the A/B dtype tags vary between callers.
    fn mma_sync_kernel(name: &str, a_ty: PtxType, b_ty: PtxType) -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a_f16(&mut alloc),
            b: alloc_b_f16(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty,
            b_ty,
            c_ty: PtxType::F32,
        };
        single_op_kernel(name, op)
    }

    fn tc_kernel() -> PtxKernel {
        mma_sync_kernel("has_mma", PtxType::F16, PtxType::F16)
    }

    fn tc_int8_kernel() -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut alloc),
            a: alloc_a_M16N8K32(&mut alloc),
            b: alloc_b_M16N8K32(&mut alloc),
            c: alloc_c_M16N8K32(&mut alloc),
        };
        single_op_kernel("has_mma_int8", op)
    }

    fn ldmatrix_kernel() -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let dst = LdMatrixDst::X4([
            alloc.alloc_packed_half2(),
            alloc.alloc_packed_half2(),
            alloc.alloc_packed_half2(),
            alloc.alloc_packed_half2(),
        ]);
        let addr = alloc.alloc(PtxType::U32);
        single_op_kernel(
            "has_ldmatrix",
            TensorCoreOp::LdMatrix {
                dst,
                addr,
                trans: false,
            },
        )
    }

    fn mma_sync_with_bf16_tags() -> PtxKernel {
        mma_sync_kernel("legacy_bf16_on_mma_sync", PtxType::BF16, PtxType::BF16)
    }

    fn mma_sync_bf16_kernel() -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let op = TensorCoreOp::MmaSyncBf16 {
            d: alloc_c(&mut alloc),
            a: alloc_a_bf16(&mut alloc),
            b: alloc_b_bf16(&mut alloc),
            c: alloc_c(&mut alloc),
        };
        single_op_kernel("native_bf16", op)
    }

    // ---- mma.sync (f16) SM gating ----------------------------------------

    #[test]
    fn validate_rejects_mma_on_sm_70() {
        let err = module_with("sm_70", tc_kernel()).validate().unwrap_err();
        assert_eq!(err, sm_too_low(80, 70, "mma.sync.m16n8k16"));
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k16 requires sm_80+, target is sm_70"
        );
    }

    #[test]
    fn validate_accepts_mma_on_sm_80() {
        assert!(module_with("sm_80", tc_kernel()).validate().is_ok());
    }

    #[test]
    fn validate_accepts_mma_on_sm_89() {
        assert!(module_with("sm_89", tc_kernel()).validate().is_ok());
    }

    // ---- mma.sync (int8) SM gating ---------------------------------------

    #[test]
    fn validate_rejects_mma_int8_on_sm_70() {
        let err = module_with("sm_70", tc_int8_kernel())
            .validate()
            .unwrap_err();
        assert_eq!(err, sm_too_low(80, 70, "mma.sync.m16n8k32.s8.s8.s32"));
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k32.s8.s8.s32 requires sm_80+, target is sm_70"
        );
    }

    #[test]
    fn validate_accepts_mma_int8_on_sm_80() {
        assert!(module_with("sm_80", tc_int8_kernel()).validate().is_ok());
    }

    #[test]
    fn validate_accepts_mma_int8_on_sm_89() {
        assert!(module_with("sm_89", tc_int8_kernel()).validate().is_ok());
    }

    // ---- ldmatrix --------------------------------------------------------
    //
    // ldmatrix is the first tensor-core instruction gated below 80. These
    // tests guard the shared validation path for the 75 tier introduced in
    // Sprint 9.3: sm_75 accepts ldmatrix, sm_70 still rejects it, and mma
    // stays gated at 80 even when the module's ldmatrix is acceptable.

    #[test]
    fn validate_accepts_ldmatrix_on_sm_75() {
        assert!(module_with("sm_75", ldmatrix_kernel()).validate().is_ok());
    }

    #[test]
    fn validate_rejects_ldmatrix_on_sm_70() {
        let err = module_with("sm_70", ldmatrix_kernel())
            .validate()
            .unwrap_err();
        assert_eq!(err, sm_too_low(75, 70, "ldmatrix.m8n8.x4"));
        assert_eq!(
            err.to_string(),
            "ldmatrix.m8n8.x4 requires sm_75+, target is sm_70"
        );
    }

    #[test]
    fn validate_rejects_mma_at_sm_75_even_with_ldmatrix_present() {
        let mut module = module_with("sm_75", ldmatrix_kernel());
        module.add_kernel(tc_kernel());
        assert_eq!(
            module.validate().unwrap_err(),
            sm_too_low(80, 75, "mma.sync.m16n8k16")
        );
    }

    #[test]
    fn validate_rejects_ldmatrix_bad_dst_reg_type() {
        let mut alloc = RegisterAllocator::new();
        // The third destination is an .f32 register, which is outside the
        // .b32 packed-pair class.
        let dst = LdMatrixDst::X4([
            alloc.alloc_packed_half2(),
            alloc.alloc_packed_half2(),
            alloc.alloc(PtxType::F32),
            alloc.alloc_packed_half2(),
        ]);
        let addr = alloc.alloc(PtxType::U32);
        let kernel = single_op_kernel(
            "bad_ldmatrix_dst",
            TensorCoreOp::LdMatrix {
                dst,
                addr,
                trans: false,
            },
        );

        let err = module_with("sm_80", kernel).validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::LdMatrixBadRegType {
                operand: "dst",
                found: PtxType::F32,
            }
        );
    }

    #[test]
    fn validate_rejects_ldmatrix_bad_addr_reg_type() {
        let mut alloc = RegisterAllocator::new();
        // Shared addresses are 32-bit byte offsets in this IR, so a .u64
        // address register indicates a wiring bug.
        let dst = LdMatrixDst::X2([alloc.alloc_packed_half2(), alloc.alloc_packed_half2()]);
        let addr = alloc.alloc(PtxType::U64);
        let kernel = single_op_kernel(
            "bad_ldmatrix_addr",
            TensorCoreOp::LdMatrix {
                dst,
                addr,
                trans: true,
            },
        );

        let err = module_with("sm_80", kernel).validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::LdMatrixBadRegType {
                operand: "addr",
                found: PtxType::U64,
            }
        );
        assert!(err.to_string().contains("alloc_packed_half2"));
    }

    // ---- cp.async and SM-agnostic kernels --------------------------------

    #[test]
    fn validate_rejects_cp_async_on_sm_75() {
        let mut kernel = PtxKernel::new("has_cp_async");
        kernel.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            16,
        )));

        let err = module_with("sm_75", kernel).validate().unwrap_err();
        assert_eq!(err, sm_too_low(80, 75, "cp.async"));
    }

    #[test]
    fn validate_accepts_scalar_kernel_on_sm_70() {
        // No tensor-core or cp.async features are present, so even sm_70
        // must pass.
        let module = module_with("sm_70", PtxKernel::new("scalar_only"));
        assert!(module.validate().is_ok());
    }

    // ---- target-string parsing -------------------------------------------

    #[test]
    fn validate_skips_unparseable_target() {
        // Unusual custom targets must not be blocked.
        assert!(module_with("compute_90a", tc_kernel()).validate().is_ok());
    }

    #[test]
    fn parse_sm_target() {
        let cases: [(&str, Option<u32>); 3] = [
            ("sm_89", Some(89)),
            ("sm_80", Some(80)),
            ("compute_90a", None),
        ];
        for (target, expected) in cases {
            assert_eq!(PtxModule::new(target).parse_sm_target(), expected);
        }
    }

    // ---- bf16 routing ----------------------------------------------------

    #[test]
    fn validate_rejects_mma_sync_bf16_a_ty() {
        let err = module_with("sm_89", mma_sync_with_bf16_tags())
            .validate()
            .unwrap_err();
        assert_eq!(
            err,
            ValidationError::MmaSyncBf16Rejected { operand: "a_ty" }
        );
        assert_eq!(
            err.to_string(),
            "TensorCoreOp::MmaSync with PtxType::BF16 on a_ty is rejected; \
             use TensorCoreOp::MmaSyncBf16 for bf16 emission"
        );
    }

    #[test]
    fn validate_rejects_mma_sync_bf16_b_ty_only() {
        // Mixed tags: F16 on a_ty with BF16 on b_ty is rejected as well.
        let kernel = mma_sync_kernel("mixed_bf16_b_only", PtxType::F16, PtxType::BF16);
        let err = module_with("sm_89", kernel).validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::MmaSyncBf16Rejected { operand: "b_ty" }
        );
    }

    #[test]
    fn validate_rejects_mma_sync_bf16_even_on_unparseable_target() {
        // The dtype-routing check is target-agnostic: it fires whether or
        // not the target string has the `sm_NN` form.
        let module = module_with("compute_90a", mma_sync_with_bf16_tags());
        assert_eq!(
            module.validate().unwrap_err(),
            ValidationError::MmaSyncBf16Rejected { operand: "a_ty" }
        );
    }

    #[test]
    fn validate_accepts_mma_sync_bf16_dedicated_variant() {
        let module = module_with("sm_89", mma_sync_bf16_kernel());
        assert!(module.validate().is_ok());
    }
}