Skip to main content

kaio_core/ir/
module.rs

1//! PTX module — the top-level IR container.
2
3use std::fmt;
4
5use super::instruction::PtxInstruction;
6use super::kernel::PtxKernel;
7use crate::instr::MemoryOp;
8
9/// A complete PTX module containing version/target metadata and kernels.
10///
11/// Corresponds to a single `.ptx` file with a header and one or more
12/// `.entry` kernel definitions.
13#[derive(Debug, Clone)]
14pub struct PtxModule {
15    /// PTX ISA version (e.g. `"7.8"`).
16    pub version: String,
17    /// Target SM architecture (e.g. `"sm_89"`).
18    pub target: String,
19    /// Address size in bits (32 or 64).
20    pub address_size: u32,
21    /// Kernel definitions in this module.
22    pub kernels: Vec<PtxKernel>,
23}
24
25impl PtxModule {
26    /// Create a new module targeting the given SM architecture.
27    ///
28    /// Defaults: PTX version `8.7` (CUDA 12.8), address size `64`.
29    pub fn new(target: &str) -> Self {
30        Self {
31            version: "8.7".to_string(),
32            target: target.to_string(),
33            address_size: 64,
34            kernels: Vec::new(),
35        }
36    }
37
38    /// Add a kernel to this module.
39    pub fn add_kernel(&mut self, kernel: PtxKernel) {
40        self.kernels.push(kernel);
41    }
42
43    /// Parse the target string (e.g. `"sm_89"`) into a numeric SM
44    /// version (e.g. `89`).
45    ///
46    /// Returns `None` if the target string is not a recognized
47    /// `sm_NN` form (e.g. future targets, virtual architectures).
48    /// [`validate`](Self::validate) tolerates unparseable targets by
49    /// skipping the SM check — we'd rather let unusual targets through
50    /// than block a user experimenting with a custom target string.
51    fn parse_sm_target(&self) -> Option<u32> {
52        self.target.strip_prefix("sm_").and_then(|s| s.parse().ok())
53    }
54
55    /// Validate that this module's target SM is high enough for every
56    /// feature used by its kernels.
57    ///
58    /// Walks all kernel bodies looking for features that carry a minimum
59    /// SM requirement — currently tensor-core operations and `cp.async`
60    /// variants (both Ampere+ / SM 8.0). Returns [`ValidationError::SmTooLow`]
61    /// on the **first** such mismatch with a human-readable description.
62    ///
63    /// This is a narrow **target-capability** check, not a semantic or
64    /// dataflow pass. The goal is to surface clean errors at emit-time
65    /// instead of cryptic ptxas messages downstream.
66    pub fn validate(&self) -> Result<(), ValidationError> {
67        let Some(target_sm) = self.parse_sm_target() else {
68            // Unrecognized target (e.g. custom or virtual arch) — skip.
69            return Ok(());
70        };
71
72        for kernel in &self.kernels {
73            for instr in &kernel.body {
74                if let Some((required, feature)) = instruction_sm_requirement(instr)
75                    && target_sm < required
76                {
77                    return Err(ValidationError::SmTooLow {
78                        required,
79                        actual: target_sm,
80                        feature,
81                    });
82                }
83            }
84        }
85        Ok(())
86    }
87}
88
89/// Return `Some((min_sm, feature_label))` if this instruction carries an SM
90/// requirement, or `None` if it is SM-agnostic.
91fn instruction_sm_requirement(instr: &PtxInstruction) -> Option<(u32, String)> {
92    match instr {
93        PtxInstruction::TensorCore(op) => Some((op.min_sm(), op.feature_label())),
94        PtxInstruction::Memory(
95            MemoryOp::CpAsyncCaSharedGlobal { .. }
96            | MemoryOp::CpAsyncCommitGroup
97            | MemoryOp::CpAsyncWaitGroup { .. },
98        ) => Some((80, "cp.async".to_string())),
99        _ => None,
100    }
101}
102
/// Failure reported by [`PtxModule::validate`].
///
/// Only target-capability problems are represented; the validator does
/// no semantic or dataflow analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// The module declares an SM target below the minimum needed by a
    /// feature that one of its kernels uses.
    ///
    /// For instance, a kernel using `mma.sync.m16n8k16` inside a module
    /// whose target is `sm_70` produces
    /// `required: 80, actual: 70, feature: "mma.sync.m16n8k16"`.
    SmTooLow {
        /// Lowest SM version on which the offending feature is available.
        required: u32,
        /// SM version parsed from the module's target string.
        actual: u32,
        /// Human-readable name of the offending feature.
        feature: String,
    },
}

impl fmt::Display for ValidationError {
    /// Renders messages such as `cp.async requires sm_80+, target is sm_75`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable; adding
        // a variant turns it into a compile error that forces an update here.
        let Self::SmTooLow {
            required,
            actual,
            feature,
        } = self;
        write!(
            f,
            "{feature} requires sm_{required}+, target is sm_{actual}"
        )
    }
}

/// No underlying cause is recorded, so the default `source()` (`None`) applies.
impl std::error::Error for ValidationError {}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::instr::{MemoryOp, MmaShape, TensorCoreOp};
    use crate::ir::{PtxInstruction, PtxKernel, Register, RegisterAllocator};
    use crate::types::{PtxType, RegKind};

    /// Build a register directly, bypassing any allocator.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Wrap a single kernel in a fresh module for the given target.
    fn module_with(target: &str, kernel: PtxKernel) -> PtxModule {
        let mut module = PtxModule::new(target);
        module.add_kernel(kernel);
        module
    }

    /// A kernel whose only instruction is an f16 `mma.sync.m16n8k16`.
    fn tc_kernel() -> PtxKernel {
        let mut regs = RegisterAllocator::new();
        // Struct-literal fields evaluate in source order: d, a, b, c.
        let mma = TensorCoreOp::MmaSync {
            d: alloc_c(&mut regs),
            a: alloc_a(&mut regs),
            b: alloc_b(&mut regs),
            c: alloc_c(&mut regs),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        };
        let mut kernel = PtxKernel::new("has_mma");
        kernel.push(PtxInstruction::TensorCore(mma));
        kernel
    }

    #[test]
    fn validate_rejects_mma_on_sm_70() {
        let err = module_with("sm_70", tc_kernel()).validate().unwrap_err();
        let expected = ValidationError::SmTooLow {
            required: 80,
            actual: 70,
            feature: "mma.sync.m16n8k16".to_string(),
        };
        assert_eq!(err, expected);
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k16 requires sm_80+, target is sm_70"
        );
    }

    #[test]
    fn validate_accepts_mma_on_sm_80() {
        assert!(module_with("sm_80", tc_kernel()).validate().is_ok());
    }

    #[test]
    fn validate_accepts_mma_on_sm_89() {
        assert!(module_with("sm_89", tc_kernel()).validate().is_ok());
    }

    /// A kernel whose only instruction is an int8 `mma.sync.m16n8k32`.
    fn tc_int8_kernel() -> PtxKernel {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut regs = RegisterAllocator::new();
        let mma = TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut regs),
            a: alloc_a_M16N8K32(&mut regs),
            b: alloc_b_M16N8K32(&mut regs),
            c: alloc_c_M16N8K32(&mut regs),
        };
        let mut kernel = PtxKernel::new("has_mma_int8");
        kernel.push(PtxInstruction::TensorCore(mma));
        kernel
    }

    #[test]
    fn validate_rejects_mma_int8_on_sm_70() {
        let err = module_with("sm_70", tc_int8_kernel())
            .validate()
            .unwrap_err();
        let expected = ValidationError::SmTooLow {
            required: 80,
            actual: 70,
            feature: "mma.sync.m16n8k32.s8.s8.s32".to_string(),
        };
        assert_eq!(err, expected);
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k32.s8.s8.s32 requires sm_80+, target is sm_70"
        );
    }

    #[test]
    fn validate_accepts_mma_int8_on_sm_80() {
        assert!(module_with("sm_80", tc_int8_kernel()).validate().is_ok());
    }

    #[test]
    fn validate_accepts_mma_int8_on_sm_89() {
        assert!(module_with("sm_89", tc_int8_kernel()).validate().is_ok());
    }

    #[test]
    fn validate_rejects_cp_async_on_sm_75() {
        let copy = MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            16,
        );
        let mut kernel = PtxKernel::new("has_cp_async");
        kernel.push(PtxInstruction::Memory(copy));

        let err = module_with("sm_75", kernel).validate().unwrap_err();
        let expected = ValidationError::SmTooLow {
            required: 80,
            actual: 75,
            feature: "cp.async".to_string(),
        };
        assert_eq!(err, expected);
    }

    #[test]
    fn validate_accepts_scalar_kernel_on_sm_70() {
        // Without tensor-core or cp.async instructions there is no SM
        // floor, so even sm_70 passes.
        let module = module_with("sm_70", PtxKernel::new("scalar_only"));
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_skips_unparseable_target() {
        // Custom/virtual target names must not block emission.
        let module = module_with("compute_90a", tc_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn parse_sm_target() {
        let cases: [(&str, Option<u32>); 3] = [
            ("sm_89", Some(89)),
            ("sm_80", Some(80)),
            ("compute_90a", None),
        ];
        for (target, expected) in cases {
            assert_eq!(
                PtxModule::new(target).parse_sm_target(),
                expected,
                "target {target}"
            );
        }
    }
}