1use std::fmt;
4
5use super::instruction::PtxInstruction;
6use super::kernel::PtxKernel;
7use crate::instr::MemoryOp;
8
9#[derive(Debug, Clone)]
14pub struct PtxModule {
15 pub version: String,
17 pub target: String,
19 pub address_size: u32,
21 pub kernels: Vec<PtxKernel>,
23}
24
25impl PtxModule {
26 pub fn new(target: &str) -> Self {
30 Self {
31 version: "8.7".to_string(),
32 target: target.to_string(),
33 address_size: 64,
34 kernels: Vec::new(),
35 }
36 }
37
38 pub fn add_kernel(&mut self, kernel: PtxKernel) {
40 self.kernels.push(kernel);
41 }
42
43 fn parse_sm_target(&self) -> Option<u32> {
52 self.target.strip_prefix("sm_").and_then(|s| s.parse().ok())
53 }
54
55 pub fn validate(&self) -> Result<(), ValidationError> {
67 let Some(target_sm) = self.parse_sm_target() else {
68 return Ok(());
70 };
71
72 for kernel in &self.kernels {
73 for instr in &kernel.body {
74 if let Some((required, feature)) = instruction_sm_requirement(instr)
75 && target_sm < required
76 {
77 return Err(ValidationError::SmTooLow {
78 required,
79 actual: target_sm,
80 feature,
81 });
82 }
83 }
84 }
85 Ok(())
86 }
87}
88
89fn instruction_sm_requirement(instr: &PtxInstruction) -> Option<(u32, String)> {
92 match instr {
93 PtxInstruction::TensorCore(op) => Some((op.min_sm(), op.feature_label())),
94 PtxInstruction::Memory(
95 MemoryOp::CpAsyncCaSharedGlobal { .. }
96 | MemoryOp::CpAsyncCommitGroup
97 | MemoryOp::CpAsyncWaitGroup { .. },
98 ) => Some((80, "cp.async".to_string())),
99 _ => None,
100 }
101}
102
/// Error returned by `PtxModule::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// An instruction needs a newer SM architecture than the module targets.
    SmTooLow {
        /// Minimum SM version the instruction needs.
        required: u32,
        /// SM version parsed from the module's target string.
        actual: u32,
        /// Label of the offending feature, e.g. `"cp.async"`.
        feature: String,
    },
}

impl fmt::Display for ValidationError {
    /// Renders e.g. `cp.async requires sm_80+, target is sm_75`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let ValidationError::SmTooLow {
            required,
            actual,
            feature,
        } = self;
        write!(f, "{feature} requires sm_{required}+, target is sm_{actual}")
    }
}

impl std::error::Error for ValidationError {}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a, alloc_b, alloc_c};
    use crate::instr::{MemoryOp, MmaShape, TensorCoreOp};
    use crate::ir::{PtxInstruction, PtxKernel, Register, RegisterAllocator};
    use crate::types::{PtxType, RegKind};

    /// Builds a register by hand, bypassing the allocator.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Kernel containing a single f16 `mma.sync.m16n8k16` with f32 accumulators.
    fn tc_kernel() -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("has_mma");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a(&mut alloc),
            b: alloc_b(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        }));
        k
    }

    /// Kernel containing a single 16-byte cp.async shared <- global copy.
    fn cp_async_kernel() -> PtxKernel {
        let mut k = PtxKernel::new("has_cp_async");
        k.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            16,
        )));
        k
    }

    #[test]
    fn new_sets_default_header_fields() {
        let module = PtxModule::new("sm_80");
        assert_eq!(module.version, "8.7");
        assert_eq!(module.target, "sm_80");
        assert_eq!(module.address_size, 64);
        assert!(module.kernels.is_empty());
    }

    #[test]
    fn validate_rejects_mma_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(tc_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 70,
                feature: "mma.sync.m16n8k16".to_string(),
            }
        );
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k16 requires sm_80+, target is sm_70"
        );
    }

    #[test]
    fn validate_accepts_mma_on_sm_80() {
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_accepts_mma_on_sm_89() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    /// Kernel containing a single int8 `mma.sync.m16n8k32`.
    fn tc_int8_kernel() -> PtxKernel {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("has_mma_int8");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut alloc),
            a: alloc_a_M16N8K32(&mut alloc),
            b: alloc_b_M16N8K32(&mut alloc),
            c: alloc_c_M16N8K32(&mut alloc),
        }));
        k
    }

    #[test]
    fn validate_rejects_mma_int8_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(tc_int8_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 70,
                feature: "mma.sync.m16n8k32.s8.s8.s32".to_string(),
            }
        );
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k32.s8.s8.s32 requires sm_80+, target is sm_70"
        );
    }

    #[test]
    fn validate_accepts_mma_int8_on_sm_80() {
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(tc_int8_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_accepts_mma_int8_on_sm_89() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(tc_int8_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_rejects_cp_async_on_sm_75() {
        let mut module = PtxModule::new("sm_75");
        module.add_kernel(cp_async_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 75,
                feature: "cp.async".to_string(),
            }
        );
    }

    #[test]
    fn validate_accepts_cp_async_on_sm_80() {
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(cp_async_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_rejects_cp_async_commit_group_on_sm_75() {
        let mut module = PtxModule::new("sm_75");
        let mut k = PtxKernel::new("has_commit_group");
        k.push(PtxInstruction::Memory(MemoryOp::CpAsyncCommitGroup));
        module.add_kernel(k);
        assert_eq!(
            module.validate().unwrap_err(),
            ValidationError::SmTooLow {
                required: 80,
                actual: 75,
                feature: "cp.async".to_string(),
            }
        );
    }

    #[test]
    fn validate_reports_violation_in_later_kernel() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(PtxKernel::new("empty_first"));
        module.add_kernel(tc_kernel());
        assert!(matches!(
            module.validate(),
            Err(ValidationError::SmTooLow {
                required: 80,
                actual: 70,
                ..
            })
        ));
    }

    #[test]
    fn validate_accepts_scalar_kernel_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        // An empty kernel has no instructions with SM requirements.
        let k = PtxKernel::new("scalar_only");
        module.add_kernel(k);
        assert!(module.validate().is_ok());
    }

    #[test]
    fn validate_skips_unparseable_target() {
        let mut module = PtxModule::new("compute_90a");
        // A non-`sm_` target disables validation even for sm_80+ features.
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    #[test]
    fn parse_sm_target() {
        let m = PtxModule::new("sm_89");
        assert_eq!(m.parse_sm_target(), Some(89));
        let m2 = PtxModule::new("sm_80");
        assert_eq!(m2.parse_sm_target(), Some(80));
        let m3 = PtxModule::new("compute_90a");
        assert_eq!(m3.parse_sm_target(), None);
    }
}