1use std::fmt;
4
5use super::instruction::PtxInstruction;
6use super::kernel::PtxKernel;
7use crate::instr::{MemoryOp, TensorCoreOp};
8use crate::types::PtxType;
9
/// A PTX module: a version/target/address-size header plus a list of kernels.
///
/// The fields are plain public data; [`PtxModule::validate`] checks the
/// kernels against the `target` string.
#[derive(Debug, Clone)]
pub struct PtxModule {
    /// PTX version string; [`PtxModule::new`] initializes it to `"8.7"`.
    pub version: String,
    /// Target architecture string, e.g. `"sm_89"`.
    pub target: String,
    /// Address size; [`PtxModule::new`] initializes it to `64`
    /// (presumably bits, as in the PTX `.address_size` directive — confirm
    /// against the emitter).
    pub address_size: u32,
    /// Kernels contained in the module, in the order they were added.
    pub kernels: Vec<PtxKernel>,
}
25
26impl PtxModule {
27 pub fn new(target: &str) -> Self {
31 Self {
32 version: "8.7".to_string(),
33 target: target.to_string(),
34 address_size: 64,
35 kernels: Vec::new(),
36 }
37 }
38
39 pub fn add_kernel(&mut self, kernel: PtxKernel) {
41 self.kernels.push(kernel);
42 }
43
44 fn parse_sm_target(&self) -> Option<u32> {
53 self.target.strip_prefix("sm_").and_then(|s| s.parse().ok())
54 }
55
56 pub fn validate(&self) -> Result<(), ValidationError> {
68 let target_sm = self.parse_sm_target();
69
70 for kernel in &self.kernels {
71 for instr in &kernel.body {
72 if let PtxInstruction::TensorCore(op) = instr {
74 validate_tensor_core_op(op)?;
75 }
76
77 if let Some(target_sm) = target_sm
79 && let Some((required, feature)) = instruction_sm_requirement(instr)
80 && target_sm < required
81 {
82 return Err(ValidationError::SmTooLow {
83 required,
84 actual: target_sm,
85 feature,
86 });
87 }
88 }
89 }
90 Ok(())
91 }
92}
93
94fn validate_tensor_core_op(op: &TensorCoreOp) -> Result<(), ValidationError> {
108 match op {
109 TensorCoreOp::MmaSync { a_ty, b_ty, .. } => {
110 if *a_ty == PtxType::BF16 {
111 return Err(ValidationError::MmaSyncBf16Rejected { operand: "a_ty" });
112 }
113 if *b_ty == PtxType::BF16 {
114 return Err(ValidationError::MmaSyncBf16Rejected { operand: "b_ty" });
115 }
116 }
117 TensorCoreOp::LdMatrix { dst, addr, .. } => {
118 for reg in dst.regs() {
119 if reg.ptx_type != PtxType::U32 {
120 return Err(ValidationError::LdMatrixBadRegType {
121 operand: "dst",
122 found: reg.ptx_type,
123 });
124 }
125 }
126 if addr.ptx_type != PtxType::U32 {
127 return Err(ValidationError::LdMatrixBadRegType {
128 operand: "addr",
129 found: addr.ptx_type,
130 });
131 }
132 }
133 _ => {}
134 }
135 Ok(())
136}
137
138fn instruction_sm_requirement(instr: &PtxInstruction) -> Option<(u32, String)> {
141 match instr {
142 PtxInstruction::TensorCore(op) => Some((op.min_sm(), op.feature_label())),
143 PtxInstruction::Memory(
144 MemoryOp::CpAsyncCaSharedGlobal { .. }
145 | MemoryOp::CpAsyncCommitGroup
146 | MemoryOp::CpAsyncWaitGroup { .. },
147 ) => Some((80, "cp.async".to_string())),
148 _ => None,
149 }
150}
151
/// Errors reported by [`PtxModule::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// An instruction needs a higher SM version than the module targets.
    SmTooLow {
        /// Minimum SM version the feature needs.
        required: u32,
        /// SM version parsed from the module's target string.
        actual: u32,
        /// Label of the feature that imposed the requirement, e.g. `"cp.async"`.
        feature: String,
    },
    /// A `TensorCoreOp::MmaSync` carried a `PtxType::BF16` type tag; the
    /// error message directs callers to `TensorCoreOp::MmaSyncBf16` instead.
    MmaSyncBf16Rejected {
        /// Name of the offending type tag: `"a_ty"` or `"b_ty"`.
        operand: &'static str,
    },
    /// A `TensorCoreOp::LdMatrix` register was not of type `PtxType::U32`.
    LdMatrixBadRegType {
        /// Name of the offending operand: `"dst"` or `"addr"`.
        operand: &'static str,
        /// The register type that was found instead of `PtxType::U32`.
        found: PtxType,
    },
}
202
203impl fmt::Display for ValidationError {
204 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 match self {
206 Self::SmTooLow {
207 required,
208 actual,
209 feature,
210 } => {
211 write!(
212 f,
213 "{feature} requires sm_{required}+, target is sm_{actual}"
214 )
215 }
216 Self::MmaSyncBf16Rejected { operand } => {
217 write!(
218 f,
219 "TensorCoreOp::MmaSync with PtxType::BF16 on {operand} is rejected; use TensorCoreOp::MmaSyncBf16 for bf16 emission"
220 )
221 }
222 Self::LdMatrixBadRegType { operand, found } => {
223 write!(
224 f,
225 "TensorCoreOp::LdMatrix {operand} register must be PtxType::U32 (.b32 packed-pair convention, see alloc_packed_half2), found {found:?}"
226 )
227 }
228 }
229 }
230}
231
// Marker impl so `ValidationError` can be used wherever a `std::error::Error`
// is expected; it relies on the `Debug` derive and the `Display` impl.
impl std::error::Error for ValidationError {}
233
// Unit tests for `PtxModule::validate`, `PtxModule::parse_sm_target`, and the
// `ValidationError` messages.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fragment::{alloc_a_f16, alloc_b_f16, alloc_c};
    use crate::instr::{MemoryOp, MmaShape, TensorCoreOp};
    use crate::ir::{PtxInstruction, PtxKernel, Register, RegisterAllocator};
    use crate::types::{PtxType, RegKind};

    /// Builds a `Register` directly from its three fields.
    fn reg(kind: RegKind, index: u32, ptx_type: PtxType) -> Register {
        Register {
            kind,
            index,
            ptx_type,
        }
    }

    /// Kernel containing one `MmaSync` (M16N8K16, F16 inputs, F32 accumulators).
    fn tc_kernel() -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("has_mma");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a_f16(&mut alloc),
            b: alloc_b_f16(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
        }));
        k
    }

    // `MmaSync` on sm_70 must fail with `SmTooLow` (required 80) and render
    // the expected message.
    #[test]
    fn validate_rejects_mma_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(tc_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 70,
                feature: "mma.sync.m16n8k16".to_string(),
            }
        );
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k16 requires sm_80+, target is sm_70"
        );
    }

    // Exactly at the required SM level: accepted.
    #[test]
    fn validate_accepts_mma_on_sm_80() {
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    // Above the required SM level: accepted.
    #[test]
    fn validate_accepts_mma_on_sm_89() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    /// Kernel containing one `MmaSyncInt8` built from the M16N8K32 fragments.
    fn tc_int8_kernel() -> PtxKernel {
        use crate::fragment::{alloc_a_M16N8K32, alloc_b_M16N8K32, alloc_c_M16N8K32};
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("has_mma_int8");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSyncInt8 {
            d: alloc_c_M16N8K32(&mut alloc),
            a: alloc_a_M16N8K32(&mut alloc),
            b: alloc_b_M16N8K32(&mut alloc),
            c: alloc_c_M16N8K32(&mut alloc),
        }));
        k
    }

    // The int8 MMA variant is also rejected on sm_70, with its own label.
    #[test]
    fn validate_rejects_mma_int8_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(tc_int8_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 70,
                feature: "mma.sync.m16n8k32.s8.s8.s32".to_string(),
            }
        );
        assert_eq!(
            err.to_string(),
            "mma.sync.m16n8k32.s8.s8.s32 requires sm_80+, target is sm_70"
        );
    }

    // Int8 MMA exactly at the required SM level: accepted.
    #[test]
    fn validate_accepts_mma_int8_on_sm_80() {
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(tc_int8_kernel());
        assert!(module.validate().is_ok());
    }

    // Int8 MMA above the required SM level: accepted.
    #[test]
    fn validate_accepts_mma_int8_on_sm_89() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(tc_int8_kernel());
        assert!(module.validate().is_ok());
    }

    /// Kernel containing one well-formed `LdMatrix` (x4 destinations from
    /// `alloc_packed_half2`, U32 address register).
    fn ldmatrix_kernel() -> PtxKernel {
        use crate::instr::LdMatrixDst;
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("has_ldmatrix");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::LdMatrix {
            dst: LdMatrixDst::X4([
                alloc.alloc_packed_half2(),
                alloc.alloc_packed_half2(),
                alloc.alloc_packed_half2(),
                alloc.alloc_packed_half2(),
            ]),
            addr: alloc.alloc(PtxType::U32),
            trans: false,
        }));
        k
    }

    // `LdMatrix` is accepted on sm_75, i.e. its requirement is lower than the
    // 80 reported for the MMA variants above.
    #[test]
    fn validate_accepts_ldmatrix_on_sm_75() {
        let mut module = PtxModule::new("sm_75");
        module.add_kernel(ldmatrix_kernel());
        assert!(module.validate().is_ok());
    }

    // `LdMatrix` on sm_70 fails with required = 75.
    #[test]
    fn validate_rejects_ldmatrix_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        module.add_kernel(ldmatrix_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 75,
                actual: 70,
                feature: "ldmatrix.m8n8.x4".to_string(),
            }
        );
        assert_eq!(
            err.to_string(),
            "ldmatrix.m8n8.x4 requires sm_75+, target is sm_70"
        );
    }

    // Requirements are checked per instruction: a passing `LdMatrix` kernel
    // does not mask the failing `MmaSync` kernel that follows it.
    #[test]
    fn validate_rejects_mma_at_sm_75_even_with_ldmatrix_present() {
        let mut module = PtxModule::new("sm_75");
        module.add_kernel(ldmatrix_kernel());
        module.add_kernel(tc_kernel());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 75,
                feature: "mma.sync.m16n8k16".to_string(),
            }
        );
    }

    // A non-U32 destination register (the third one, F32) is reported as a
    // "dst" register-type error.
    #[test]
    fn validate_rejects_ldmatrix_bad_dst_reg_type() {
        use crate::instr::LdMatrixDst;
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("bad_ldmatrix_dst");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::LdMatrix {
            dst: LdMatrixDst::X4([
                alloc.alloc_packed_half2(),
                alloc.alloc_packed_half2(),
                alloc.alloc(PtxType::F32),
                alloc.alloc_packed_half2(),
            ]),
            addr: alloc.alloc(PtxType::U32),
            trans: false,
        }));
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(k);
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::LdMatrixBadRegType {
                operand: "dst",
                found: PtxType::F32,
            }
        );
    }

    // A U64 address register is reported as an "addr" register-type error,
    // and the message mentions `alloc_packed_half2`.
    #[test]
    fn validate_rejects_ldmatrix_bad_addr_reg_type() {
        use crate::instr::LdMatrixDst;
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("bad_ldmatrix_addr");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::LdMatrix {
            dst: LdMatrixDst::X2([alloc.alloc_packed_half2(), alloc.alloc_packed_half2()]),
            addr: alloc.alloc(PtxType::U64),
            trans: true,
        }));
        let mut module = PtxModule::new("sm_80");
        module.add_kernel(k);
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::LdMatrixBadRegType {
                operand: "addr",
                found: PtxType::U64,
            }
        );
        assert!(err.to_string().contains("alloc_packed_half2"));
    }

    // `cp.async` memory operations require SM 80, so sm_75 is rejected.
    #[test]
    fn validate_rejects_cp_async_on_sm_75() {
        let mut module = PtxModule::new("sm_75");
        let mut k = PtxKernel::new("has_cp_async");
        k.push(PtxInstruction::Memory(MemoryOp::new_cp_async_ca(
            reg(RegKind::R, 0, PtxType::U32),
            reg(RegKind::Rd, 0, PtxType::U64),
            16,
        )));
        module.add_kernel(k);
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::SmTooLow {
                required: 80,
                actual: 75,
                feature: "cp.async".to_string(),
            }
        );
    }

    // No instructions are pushed here, so nothing can carry an SM requirement.
    #[test]
    fn validate_accepts_scalar_kernel_on_sm_70() {
        let mut module = PtxModule::new("sm_70");
        let k = PtxKernel::new("scalar_only");
        module.add_kernel(k);
        assert!(module.validate().is_ok());
    }

    // A target without the `sm_` prefix cannot be parsed, so the SM check is
    // skipped and the MMA kernel passes.
    #[test]
    fn validate_skips_unparseable_target() {
        let mut module = PtxModule::new("compute_90a");
        module.add_kernel(tc_kernel());
        assert!(module.validate().is_ok());
    }

    // Direct checks of the private target parser.
    #[test]
    fn parse_sm_target() {
        let m = PtxModule::new("sm_89");
        assert_eq!(m.parse_sm_target(), Some(89));
        let m2 = PtxModule::new("sm_80");
        assert_eq!(m2.parse_sm_target(), Some(80));
        let m3 = PtxModule::new("compute_90a");
        assert_eq!(m3.parse_sm_target(), None);
    }

    /// Kernel containing one `MmaSync` whose `a_ty` and `b_ty` tags are both
    /// BF16 (the combination the validator rejects).
    fn mma_sync_with_bf16_tags() -> PtxKernel {
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("legacy_bf16_on_mma_sync");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a_f16(&mut alloc),
            b: alloc_b_f16(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::BF16,
            b_ty: PtxType::BF16,
            c_ty: PtxType::F32,
        }));
        k
    }

    // With both tags BF16, `a_ty` is the one reported; the message is pinned.
    #[test]
    fn validate_rejects_mma_sync_bf16_a_ty() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(mma_sync_with_bf16_tags());
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::MmaSyncBf16Rejected { operand: "a_ty" }
        );
        assert_eq!(
            err.to_string(),
            "TensorCoreOp::MmaSync with PtxType::BF16 on a_ty is rejected; \
             use TensorCoreOp::MmaSyncBf16 for bf16 emission"
        );
    }

    // Only `b_ty` is BF16 here, so the error must name `b_ty`.
    #[test]
    fn validate_rejects_mma_sync_bf16_b_ty_only() {
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("mixed_bf16_b_only");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSync {
            d: alloc_c(&mut alloc),
            a: alloc_a_f16(&mut alloc),
            b: alloc_b_f16(&mut alloc),
            c: alloc_c(&mut alloc),
            shape: MmaShape::M16N8K16,
            d_ty: PtxType::F32,
            a_ty: PtxType::F16,
            b_ty: PtxType::BF16,
            c_ty: PtxType::F32,
        }));
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(k);
        let err = module.validate().unwrap_err();
        assert_eq!(
            err,
            ValidationError::MmaSyncBf16Rejected { operand: "b_ty" }
        );
    }

    // The BF16 structural check runs even when the target cannot be parsed
    // and the SM comparison is therefore skipped.
    #[test]
    fn validate_rejects_mma_sync_bf16_even_on_unparseable_target() {
        let mut module = PtxModule::new("compute_90a");
        module.add_kernel(mma_sync_with_bf16_tags());
        assert_eq!(
            module.validate().unwrap_err(),
            ValidationError::MmaSyncBf16Rejected { operand: "a_ty" }
        );
    }

    /// Kernel containing one `MmaSyncBf16`, the dedicated bf16 variant.
    fn mma_sync_bf16_kernel() -> PtxKernel {
        use crate::fragment::{alloc_a_bf16, alloc_b_bf16};
        let mut alloc = RegisterAllocator::new();
        let mut k = PtxKernel::new("native_bf16");
        k.push(PtxInstruction::TensorCore(TensorCoreOp::MmaSyncBf16 {
            d: alloc_c(&mut alloc),
            a: alloc_a_bf16(&mut alloc),
            b: alloc_b_bf16(&mut alloc),
            c: alloc_c(&mut alloc),
        }));
        k
    }

    // The dedicated bf16 variant is not subject to the `MmaSync` BF16 rejection.
    #[test]
    fn validate_accepts_mma_sync_bf16_dedicated_variant() {
        let mut module = PtxModule::new("sm_89");
        module.add_kernel(mma_sync_bf16_kernel());
        assert!(module.validate().is_ok());
    }
}