1use std::collections::HashSet;
22
23use crate::arch::SmVersion;
24use crate::ir::{Instruction, MemorySpace, Operand, WmmaOp};
25
/// Outcome of a text-level PTX validation pass.
///
/// Returned by [`validate_ptx`] and [`validate_ptx_for_target`].
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Problems recorded as errors; the result counts as "ok" only when this is empty.
    pub errors: Vec<ValidationError>,
    /// Advisory messages; they never affect [`ValidationResult::is_ok`].
    pub warnings: Vec<String>,
}
37
38impl ValidationResult {
39 #[must_use]
41 pub fn is_ok(&self) -> bool {
42 self.errors.is_empty()
43 }
44
45 #[must_use]
47 pub fn has_errors(&self) -> bool {
48 !self.errors.is_empty()
49 }
50}
51
/// Errors that the PTX text validator can report.
///
/// The `Display` implementation in this file turns each variant into a short
/// human-readable message.
#[derive(Debug, Clone)]
pub enum ValidationError {
    /// The PTX text does not contain `.version` anywhere.
    MissingVersionDirective,
    /// The PTX text does not contain `.target` anywhere.
    MissingTargetDirective,
    /// A register name reported as undefined (displayed as `undefined register: <name>`).
    UndefinedRegister(String),
    /// A type disagreement, carried as already-formatted type names.
    TypeMismatch {
        /// The type that was expected.
        expected: String,
        /// The type that was actually found.
        found: String,
    },
    /// Declared shared memory is larger than the limit used for the check.
    InvalidSharedMemSize {
        /// Total size computed from the shared-memory declarations, in bytes.
        declared: usize,
        /// The limit the total was compared against, in bytes.
        max_allowed: usize,
    },
    /// An address-size problem, described by a free-form message.
    InvalidAddressSize(String),
    /// The PTX uses a feature that needs a newer SM version than the target.
    SmIncompatibleInstruction {
        /// Human-readable name of the offending instruction or feature.
        instruction: String,
        /// Minimum SM version required, as a PTX target string.
        required_sm: String,
        /// SM version that was validated against, as a PTX target string.
        found_sm: String,
    },
    /// More distinct register names were counted than the per-thread limit allows.
    RegisterPressureExceeded {
        /// Number of distinct register names counted.
        count: usize,
        /// The per-thread limit the count was compared against.
        max_allowed: usize,
    },
    /// Any other failure, displayed verbatim.
    Other(String),
}
98
99impl std::fmt::Display for ValidationError {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 Self::MissingVersionDirective => write!(f, "missing .version directive"),
103 Self::MissingTargetDirective => write!(f, "missing .target directive"),
104 Self::UndefinedRegister(name) => write!(f, "undefined register: {name}"),
105 Self::TypeMismatch { expected, found } => {
106 write!(f, "type mismatch: expected {expected}, found {found}")
107 }
108 Self::InvalidSharedMemSize {
109 declared,
110 max_allowed,
111 } => {
112 write!(
113 f,
114 "shared memory {declared} bytes exceeds limit of {max_allowed} bytes"
115 )
116 }
117 Self::InvalidAddressSize(msg) => write!(f, "address size issue: {msg}"),
118 Self::SmIncompatibleInstruction {
119 instruction,
120 required_sm,
121 found_sm,
122 } => write!(
123 f,
124 "instruction '{instruction}' requires {required_sm} but target is {found_sm}"
125 ),
126 Self::RegisterPressureExceeded { count, max_allowed } => write!(
127 f,
128 "register count {count} exceeds per-thread limit of {max_allowed}"
129 ),
130 Self::Other(msg) => write!(f, "{msg}"),
131 }
132 }
133}
134
135#[must_use]
151pub fn validate_ptx(ptx: &str) -> ValidationResult {
152 let mut errors = Vec::new();
153 let mut warnings = Vec::new();
154
155 if !ptx.contains(".version") {
157 errors.push(ValidationError::MissingVersionDirective);
158 }
159
160 if !ptx.contains(".target") {
162 errors.push(ValidationError::MissingTargetDirective);
163 }
164
165 let target_sm = extract_target_sm(ptx);
167
168 check_shared_memory(ptx, target_sm, &mut errors, &mut warnings);
170
171 check_register_declarations(ptx, &mut warnings);
173
174 check_register_pressure(ptx, &mut errors, &mut warnings);
176
177 if let Some(sm) = target_sm {
179 check_sm_compatibility(ptx, sm, &mut errors, &mut warnings);
180 }
181
182 check_structure(ptx, &mut warnings);
184
185 ValidationResult { errors, warnings }
186}
187
188#[must_use]
195pub fn validate_ptx_for_target(ptx: &str, target: SmVersion) -> ValidationResult {
196 let mut errors = Vec::new();
197 let mut warnings = Vec::new();
198
199 if !ptx.contains(".version") {
201 errors.push(ValidationError::MissingVersionDirective);
202 }
203
204 if !ptx.contains(".target") {
206 errors.push(ValidationError::MissingTargetDirective);
207 }
208
209 check_shared_memory(ptx, Some(target), &mut errors, &mut warnings);
211
212 check_register_declarations(ptx, &mut warnings);
214
215 check_register_pressure(ptx, &mut errors, &mut warnings);
217
218 check_sm_compatibility(ptx, target, &mut errors, &mut warnings);
220
221 check_structure(ptx, &mut warnings);
223
224 ValidationResult { errors, warnings }
225}
226
227fn extract_target_sm(ptx: &str) -> Option<SmVersion> {
229 for line in ptx.lines() {
230 let trimmed = line.trim();
231 if trimmed.starts_with(".target") {
232 let parts: Vec<&str> = trimmed.split_whitespace().collect();
233 if parts.len() >= 2 {
234 return parse_sm_version(parts[1].trim_end_matches(';'));
235 }
236 }
237 }
238 None
239}
240
241fn parse_sm_version(s: &str) -> Option<SmVersion> {
243 match s {
244 "sm_75" => Some(SmVersion::Sm75),
245 "sm_80" => Some(SmVersion::Sm80),
246 "sm_86" => Some(SmVersion::Sm86),
247 "sm_89" => Some(SmVersion::Sm89),
248 "sm_90" => Some(SmVersion::Sm90),
249 "sm_90a" => Some(SmVersion::Sm90a),
250 "sm_100" => Some(SmVersion::Sm100),
251 "sm_120" => Some(SmVersion::Sm120),
252 _ => None,
253 }
254}
255
256fn check_shared_memory(
258 ptx: &str,
259 target: Option<SmVersion>,
260 errors: &mut Vec<ValidationError>,
261 warnings: &mut Vec<String>,
262) {
263 let max_smem = target.map_or(usize::MAX, |sm| sm.max_shared_mem_per_block() as usize);
264
265 let mut total_smem: usize = 0;
266
267 for line in ptx.lines() {
268 let trimmed = line.trim();
269 if let Some(size) = extract_shared_mem_size(trimmed) {
270 total_smem = total_smem.saturating_add(size);
271 }
272 }
273
274 if total_smem > max_smem {
275 errors.push(ValidationError::InvalidSharedMemSize {
276 declared: total_smem,
277 max_allowed: max_smem,
278 });
279 } else if total_smem > 48 * 1024 && target.is_some() {
280 warnings.push(format!(
281 "shared memory usage ({total_smem} bytes) exceeds default limit (49152); \
282 may require opt-in via cuFuncSetAttribute"
283 ));
284 }
285}
286
/// Returns the size in bytes of a `.shared` array declaration on `line`.
///
/// A line is treated as a declaration only when `.shared` appears as a
/// standalone whitespace-delimited token. Instructions fuse the state space
/// into the mnemonic (`ld.shared.f32`, `st.shared.u32`), so their bracketed
/// address operands are never mistaken for array sizes.
///
/// The size is the product of every bracketed dimension and the width of the
/// element type (`.b8` = 1, `.f32` = 4, `.f64` = 8, ...). When no recognised
/// type token is present the element width defaults to one byte. Overflow
/// saturates at `usize::MAX`.
///
/// Returns `None` when the line is not a `.shared` declaration, has no
/// bracketed dimension, or has a dimension that is not a plain unsigned
/// integer (for example the empty `[]` of an `.extern` array).
fn extract_shared_mem_size(line: &str) -> Option<usize> {
    // Text after `//` is a comment and must not be measured.
    let code = line.split("//").next().unwrap_or(line);

    if !code.split_whitespace().any(|token| token == ".shared") {
        return None;
    }

    // Width of one element, taken from the first fundamental-type token.
    let elem_bytes: usize = code
        .split_whitespace()
        .find_map(|token| match token {
            ".b8" | ".u8" | ".s8" => Some(1),
            ".b16" | ".u16" | ".s16" | ".f16" => Some(2),
            ".b32" | ".u32" | ".s32" | ".f32" | ".f16x2" => Some(4),
            ".b64" | ".u64" | ".s64" | ".f64" => Some(8),
            ".b128" => Some(16),
            _ => None,
        })
        .unwrap_or(1);

    // Multiply through every `[N]` group, left to right.
    let mut total = elem_bytes;
    let mut dims = 0usize;
    let mut rest = code;
    while let Some(open) = rest.find('[') {
        let after_open = &rest[open + 1..];
        let close = after_open.find(']')?;
        let extent = after_open[..close].trim().parse::<usize>().ok()?;
        total = total.saturating_mul(extent);
        dims += 1;
        rest = &after_open[close + 1..];
    }

    if dims == 0 {
        None
    } else {
        Some(total)
    }
}
305
/// Warns when the PTX defines a kernel entry point but declares no registers.
///
/// A line counts as an entry point when it contains `.entry` anywhere, and as
/// a register declaration when it starts with `.reg` after trimming. The
/// warning is emitted at most once.
fn check_register_declarations(ptx: &str, warnings: &mut Vec<String>) {
    let mut has_entry = false;
    let mut has_reg_decl = false;
    for line in ptx.lines() {
        has_entry |= line.contains(".entry");
        has_reg_decl |= line.trim().starts_with(".reg");
    }

    if has_entry && !has_reg_decl {
        warnings.push(String::from(
            "kernel has no .reg declarations; all registers may be declared via raw PTX",
        ));
    }
}
324
/// Warns when the number of `{` and `}` characters in the PTX text differs.
///
/// This is a plain character count: braces inside comments or strings are
/// counted too, and nesting order is not verified.
fn check_structure(ptx: &str, warnings: &mut Vec<String>) {
    let mut open_braces = 0usize;
    let mut close_braces = 0usize;
    // Both braces are ASCII, so scanning bytes finds exactly the same characters.
    for byte in ptx.bytes() {
        match byte {
            b'{' => open_braces += 1,
            b'}' => close_braces += 1,
            _ => {}
        }
    }

    if open_braces == close_braces {
        return;
    }
    warnings.push(format!(
        "mismatched braces: {open_braces} opening vs {close_braces} closing"
    ));
}
336
/// One row of the feature table: a PTX substring and the minimum SM version it needs.
struct SmRequirement {
    /// Substring searched for in the PTX text.
    pattern: &'static str,
    /// Lowest SM version on which the feature is accepted.
    min_sm: SmVersion,
    /// Human-readable feature name used in the error message.
    name: &'static str,
}
350
/// Feature table consulted by `check_sm_compatibility`.
///
/// Each entry is matched by plain substring search over the whole PTX text,
/// so a pattern occurring inside a comment or identifier also counts.
const SM_REQUIREMENTS: &[SmRequirement] = &[
    SmRequirement {
        pattern: "cp.async",
        min_sm: SmVersion::Sm80,
        name: "cp.async",
    },
    SmRequirement {
        pattern: "wgmma",
        min_sm: SmVersion::Sm90,
        name: "wgmma",
    },
    SmRequirement {
        pattern: "mma.sync",
        min_sm: SmVersion::Sm75,
        name: "mma.sync (tensor core)",
    },
    SmRequirement {
        pattern: "ldmatrix",
        min_sm: SmVersion::Sm75,
        name: "ldmatrix",
    },
    SmRequirement {
        pattern: ".e4m3",
        min_sm: SmVersion::Sm89,
        name: "fp8 e4m3 type",
    },
    SmRequirement {
        pattern: ".e5m2",
        min_sm: SmVersion::Sm89,
        name: "fp8 e5m2 type",
    },
    SmRequirement {
        pattern: "tcgen05",
        min_sm: SmVersion::Sm100,
        name: "tcgen05",
    },
];
389
390fn check_sm_compatibility(
395 ptx: &str,
396 sm: SmVersion,
397 errors: &mut Vec<ValidationError>,
398 _warnings: &mut Vec<String>,
399) {
400 let found_sm_str = sm.as_ptx_str();
401 for req in SM_REQUIREMENTS {
402 if ptx.contains(req.pattern) && sm < req.min_sm {
403 errors.push(ValidationError::SmIncompatibleInstruction {
404 instruction: req.name.to_string(),
405 required_sm: req.min_sm.as_ptx_str().to_string(),
406 found_sm: found_sm_str.to_string(),
407 });
408 }
409 }
410}
411
/// Distinct-register count above which `check_register_pressure` reports
/// `ValidationError::RegisterPressureExceeded`.
const MAX_REGISTERS_PER_THREAD: usize = 255;

/// Distinct-register count above which `check_register_pressure` emits a
/// warning (when the hard limit has not been exceeded).
const REGISTER_PRESSURE_WARNING_THRESHOLD: usize = 200;
421
422fn check_register_pressure(
434 ptx: &str,
435 errors: &mut Vec<ValidationError>,
436 warnings: &mut Vec<String>,
437) {
438 use std::collections::HashSet;
439
440 let mut seen: HashSet<&str> = HashSet::new();
441
442 let bytes = ptx.as_bytes();
446 let len = bytes.len();
447 let mut i = 0;
448 while i < len {
449 if bytes[i] == b'%' {
450 let start = i;
451 i += 1;
452 while i < len && bytes[i].is_ascii_alphabetic() {
454 i += 1;
455 }
456 if i < len && bytes[i].is_ascii_digit() {
458 while i < len && bytes[i].is_ascii_digit() {
459 i += 1;
460 }
461 let token = &ptx[start..i];
463 let name_part = &token[1..]; let is_special = name_part.starts_with("tid")
468 || name_part.starts_with("ntid")
469 || name_part.starts_with("ctaid")
470 || name_part.starts_with("nctaid")
471 || name_part.starts_with("laneid")
472 || name_part.starts_with("warpid")
473 || name_part.starts_with("smid")
474 || name_part.starts_with("pm")
475 || name_part.starts_with("envreg")
476 || name_part.starts_with("globaltimer")
477 || name_part.starts_with("param_");
478 if !is_special {
479 seen.insert(token);
480 }
481 }
482 } else {
483 i += 1;
484 }
485 }
486
487 let count = seen.len();
488 if count > MAX_REGISTERS_PER_THREAD {
489 errors.push(ValidationError::RegisterPressureExceeded {
490 count,
491 max_allowed: MAX_REGISTERS_PER_THREAD,
492 });
493 } else if count > REGISTER_PRESSURE_WARNING_THRESHOLD {
494 warnings.push(format!(
495 "register count ({count}) is approaching the per-thread limit of \
496 {MAX_REGISTERS_PER_THREAD}; consider reducing register pressure"
497 ));
498 }
499}
500
/// Outcome of validating a sequence of IR instructions.
#[derive(Debug, Clone)]
pub struct IrValidationResult {
    /// Errors found, each tagged with the index of the offending instruction.
    pub errors: Vec<IrValidationError>,
    /// Warnings found, each tagged with the index of the related instruction.
    pub warnings: Vec<IrValidationWarning>,
}
513
514impl IrValidationResult {
515 #[must_use]
517 pub fn is_ok(&self) -> bool {
518 self.errors.is_empty()
519 }
520
521 #[must_use]
523 pub fn has_errors(&self) -> bool {
524 !self.errors.is_empty()
525 }
526
527 fn merge(&mut self, other: &Self) {
529 self.errors.extend(other.errors.iter().cloned());
530 self.warnings.extend(other.warnings.iter().cloned());
531 }
532}
533
534impl std::fmt::Display for IrValidationResult {
535 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 if self.errors.is_empty() && self.warnings.is_empty() {
537 return write!(f, "IR validation passed: no errors, no warnings");
538 }
539 if !self.errors.is_empty() {
540 writeln!(f, "Errors ({}):", self.errors.len())?;
541 for err in &self.errors {
542 writeln!(
543 f,
544 " [{:>3}] {}: {}",
545 err.instruction_index, err.kind, err.message
546 )?;
547 }
548 }
549 if !self.warnings.is_empty() {
550 writeln!(f, "Warnings ({}):", self.warnings.len())?;
551 for warn in &self.warnings {
552 writeln!(f, " [{:>3}] {}", warn.instruction_index, warn.message)?;
553 }
554 }
555 Ok(())
556 }
557}
558
/// A single error found while validating IR instructions.
#[derive(Debug, Clone)]
pub struct IrValidationError {
    /// Position of the offending instruction in the validated slice.
    pub instruction_index: usize,
    /// Category of the error.
    pub kind: IrErrorKind,
    /// Human-readable description of the problem.
    pub message: String,
}
569
/// Category of an [`IrValidationError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IrErrorKind {
    /// A register's type disagrees with the instruction's declared type.
    TypeMismatch,
    /// A register is read before any earlier instruction defines it.
    UseBeforeDef,
    /// A memory access uses an address form that is invalid for its memory space.
    InvalidMemorySpace,
    /// An operand is missing or has a disallowed form.
    InvalidOperand,
    /// NOTE(review): presumably a barrier inside divergent control flow; the
    /// barrier check in this file reports that case as a warning and does not
    /// construct this kind — confirm intended use.
    BarrierInDivergent,
    /// NOTE(review): presumably a register-lifetime problem other than
    /// use-before-def; not constructed in the code visible here — confirm.
    RegisterLifetime,
}
586
587impl std::fmt::Display for IrErrorKind {
588 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
589 match self {
590 Self::TypeMismatch => write!(f, "TypeMismatch"),
591 Self::UseBeforeDef => write!(f, "UseBeforeDef"),
592 Self::InvalidMemorySpace => write!(f, "InvalidMemorySpace"),
593 Self::InvalidOperand => write!(f, "InvalidOperand"),
594 Self::BarrierInDivergent => write!(f, "BarrierInDivergent"),
595 Self::RegisterLifetime => write!(f, "RegisterLifetime"),
596 }
597 }
598}
599
/// A single non-fatal finding from IR validation.
#[derive(Debug, Clone)]
pub struct IrValidationWarning {
    /// Position of the related instruction in the validated slice.
    pub instruction_index: usize,
    /// Human-readable description of the concern.
    pub message: String,
}
608
609fn push_operand_names(op: &Operand, names: &mut Vec<String>) {
615 if let Operand::Register(r) = op {
616 names.push(r.name.clone());
617 }
618 if let Operand::Address { base, .. } = op {
619 names.push(base.name.clone());
620 }
621}
622
/// Collects the names of the registers that `inst` reads.
///
/// Operand-typed fields go through `push_operand_names`, so only register
/// operands and the base register of address operands contribute a name.
/// Register-typed fields contribute their name directly. Names are returned
/// in field order and may contain duplicates.
///
/// The match is exhaustive on purpose: a new `Instruction` variant will not
/// compile until it is classified here.
// The function is one flat match over every instruction variant; splitting it
// would only scatter the classification.
#[allow(clippy::too_many_lines)]
fn collect_src_register_names(inst: &Instruction) -> Vec<String> {
    let mut names = Vec::new();

    match inst {
        // Two-operand arithmetic, logic and comparison.
        Instruction::Add { a, b, .. }
        | Instruction::Sub { a, b, .. }
        | Instruction::Mul { a, b, .. }
        | Instruction::Min { a, b, .. }
        | Instruction::Max { a, b, .. }
        | Instruction::Div { a, b, .. }
        | Instruction::Rem { a, b, .. }
        | Instruction::And { a, b, .. }
        | Instruction::Or { a, b, .. }
        | Instruction::Xor { a, b, .. }
        | Instruction::SetP { a, b, .. } => {
            push_operand_names(a, &mut names);
            push_operand_names(b, &mut names);
        }
        // Three-operand multiply-add style instructions.
        Instruction::Mad { a, b, c, .. }
        | Instruction::MadLo { a, b, c, .. }
        | Instruction::MadHi { a, b, c, .. }
        | Instruction::MadWide { a, b, c, .. }
        | Instruction::Fma { a, b, c, .. }
        | Instruction::Dp4a { a, b, c, .. }
        | Instruction::Dp2a { a, b, c, .. } => {
            push_operand_names(a, &mut names);
            push_operand_names(b, &mut names);
            push_operand_names(c, &mut names);
        }
        // Single-source instructions.
        Instruction::Neg { src, .. }
        | Instruction::Abs { src, .. }
        | Instruction::Brev { src, .. }
        | Instruction::Clz { src, .. }
        | Instruction::Popc { src, .. }
        | Instruction::Bfind { src, .. }
        | Instruction::Rcp { src, .. }
        | Instruction::Rsqrt { src, .. }
        | Instruction::Sqrt { src, .. }
        | Instruction::Ex2 { src, .. }
        | Instruction::Lg2 { src, .. }
        | Instruction::Sin { src, .. }
        | Instruction::Cos { src, .. }
        | Instruction::Cvt { src, .. }
        | Instruction::Redux { src, .. } => {
            push_operand_names(src, &mut names);
        }
        Instruction::Bfe {
            src, start, len, ..
        } => {
            push_operand_names(src, &mut names);
            push_operand_names(start, &mut names);
            push_operand_names(len, &mut names);
        }
        Instruction::Bfi {
            insert,
            base,
            start,
            len,
            ..
        } => {
            push_operand_names(insert, &mut names);
            push_operand_names(base, &mut names);
            push_operand_names(start, &mut names);
            push_operand_names(len, &mut names);
        }
        Instruction::Shl { src, amount, .. } | Instruction::Shr { src, amount, .. } => {
            push_operand_names(src, &mut names);
            push_operand_names(amount, &mut names);
        }
        // Memory accesses: the address operand is read.
        Instruction::Load { addr, .. } | Instruction::MbarrierArrive { addr } => {
            push_operand_names(addr, &mut names);
        }
        // `src` is a register field here, so its name is pushed directly.
        Instruction::Store { addr, src, .. } => {
            push_operand_names(addr, &mut names);
            names.push(src.name.clone());
        }
        // Both address operands are treated as reads.
        Instruction::CpAsync {
            dst_shared,
            src_global,
            ..
        } => {
            push_operand_names(dst_shared, &mut names);
            push_operand_names(src_global, &mut names);
        }
        // Only a predicated branch reads a register.
        Instruction::Branch { predicate, .. } => {
            if let Some((r, _)) = predicate {
                names.push(r.name.clone());
            }
        }
        Instruction::Atom { addr, src, .. } | Instruction::Red { addr, src, .. } => {
            push_operand_names(addr, &mut names);
            push_operand_names(src, &mut names);
        }
        Instruction::AtomCas {
            addr,
            compare,
            value,
            ..
        } => {
            push_operand_names(addr, &mut names);
            push_operand_names(compare, &mut names);
            push_operand_names(value, &mut names);
        }
        // Texture and surface accesses read their coordinates.
        Instruction::Tex1d { coord, .. } | Instruction::SurfLoad { coord, .. } => {
            push_operand_names(coord, &mut names);
        }
        Instruction::Tex2d {
            coord_x, coord_y, ..
        } => {
            push_operand_names(coord_x, &mut names);
            push_operand_names(coord_y, &mut names);
        }
        Instruction::Tex3d {
            coord_x,
            coord_y,
            coord_z,
            ..
        } => {
            push_operand_names(coord_x, &mut names);
            push_operand_names(coord_y, &mut names);
            push_operand_names(coord_z, &mut names);
        }
        Instruction::SurfStore { coord, src, .. } => {
            push_operand_names(coord, &mut names);
            names.push(src.name.clone());
        }
        // Every fragment is listed regardless of the wmma operation, so the
        // fragments of a load appear here as well.
        Instruction::Wmma {
            fragments,
            addr,
            stride,
            ..
        } => {
            for frag in fragments {
                names.push(frag.name.clone());
            }
            if let Some(a) = addr {
                push_operand_names(a, &mut names);
            }
            if let Some(s) = stride {
                push_operand_names(s, &mut names);
            }
        }
        // The A, B and C fragments are read; `d_regs` is not listed.
        Instruction::Mma {
            a_regs,
            b_regs,
            c_regs,
            ..
        } => {
            for r in a_regs.iter().chain(b_regs).chain(c_regs) {
                names.push(r.name.clone());
            }
        }
        Instruction::Wgmma { desc_a, desc_b, .. } => {
            names.push(desc_a.name.clone());
            names.push(desc_b.name.clone());
        }
        Instruction::TmaLoad {
            desc,
            coords,
            barrier,
            dst_shared,
            ..
        } => {
            names.push(desc.name.clone());
            for c in coords {
                names.push(c.name.clone());
            }
            names.push(barrier.name.clone());
            push_operand_names(dst_shared, &mut names);
        }
        Instruction::Stmatrix { dst_addr, src, .. } => {
            push_operand_names(dst_addr, &mut names);
            names.push(src.name.clone());
        }
        Instruction::MbarrierInit { addr, count } => {
            push_operand_names(addr, &mut names);
            push_operand_names(count, &mut names);
        }
        Instruction::MbarrierWait { addr, phase } => {
            push_operand_names(addr, &mut names);
            push_operand_names(phase, &mut names);
        }
        // Variants from which no source register names are collected.
        Instruction::MovSpecial { .. }
        | Instruction::LoadParam { .. }
        | Instruction::Label(_)
        | Instruction::Return
        | Instruction::Comment(_)
        | Instruction::Raw(_)
        | Instruction::Pragma(_)
        | Instruction::BarSync { .. }
        | Instruction::BarArrive { .. }
        | Instruction::FenceAcqRel { .. }
        | Instruction::FenceProxy { .. }
        | Instruction::CpAsyncCommit
        | Instruction::CpAsyncWait { .. }
        | Instruction::ElectSync { .. }
        | Instruction::Setmaxnreg { .. }
        | Instruction::Griddepcontrol { .. }
        | Instruction::BarrierCluster
        | Instruction::FenceCluster => {}

        Instruction::Tcgen05Mma { a_desc, b_desc } => {
            names.push(a_desc.name.clone());
            names.push(b_desc.name.clone());
        }
        Instruction::CpAsyncBulk {
            dst_smem,
            src_gmem,
            desc,
        } => {
            names.push(dst_smem.name.clone());
            names.push(src_gmem.name.clone());
            names.push(desc.name.clone());
        }
        Instruction::Ldmatrix { src_addr, .. } => {
            push_operand_names(src_addr, &mut names);
        }
    }
    names
}
845
846fn dst_register_name(inst: &Instruction) -> Option<String> {
848 match inst {
849 Instruction::Add { dst, .. }
850 | Instruction::Sub { dst, .. }
851 | Instruction::Mul { dst, .. }
852 | Instruction::Min { dst, .. }
853 | Instruction::Max { dst, .. }
854 | Instruction::Div { dst, .. }
855 | Instruction::Rem { dst, .. }
856 | Instruction::And { dst, .. }
857 | Instruction::Or { dst, .. }
858 | Instruction::Xor { dst, .. }
859 | Instruction::SetP { dst, .. }
860 | Instruction::Mad { dst, .. }
861 | Instruction::MadLo { dst, .. }
862 | Instruction::MadHi { dst, .. }
863 | Instruction::MadWide { dst, .. }
864 | Instruction::Fma { dst, .. }
865 | Instruction::Neg { dst, .. }
866 | Instruction::Abs { dst, .. }
867 | Instruction::Brev { dst, .. }
868 | Instruction::Clz { dst, .. }
869 | Instruction::Popc { dst, .. }
870 | Instruction::Bfind { dst, .. }
871 | Instruction::Bfe { dst, .. }
872 | Instruction::Bfi { dst, .. }
873 | Instruction::Rcp { dst, .. }
874 | Instruction::Rsqrt { dst, .. }
875 | Instruction::Sqrt { dst, .. }
876 | Instruction::Ex2 { dst, .. }
877 | Instruction::Lg2 { dst, .. }
878 | Instruction::Sin { dst, .. }
879 | Instruction::Cos { dst, .. }
880 | Instruction::Shl { dst, .. }
881 | Instruction::Shr { dst, .. }
882 | Instruction::Load { dst, .. }
883 | Instruction::Cvt { dst, .. }
884 | Instruction::Atom { dst, .. }
885 | Instruction::AtomCas { dst, .. }
886 | Instruction::MovSpecial { dst, .. }
887 | Instruction::LoadParam { dst, .. }
888 | Instruction::Dp4a { dst, .. }
889 | Instruction::Dp2a { dst, .. }
890 | Instruction::Tex1d { dst, .. }
891 | Instruction::Tex2d { dst, .. }
892 | Instruction::Tex3d { dst, .. }
893 | Instruction::SurfLoad { dst, .. }
894 | Instruction::Redux { dst, .. }
895 | Instruction::ElectSync { dst, .. } => Some(dst.name.clone()),
896 Instruction::Mma { d_regs, .. } => d_regs.first().map(|r| r.name.clone()),
897 Instruction::Wgmma { d_regs, .. } => d_regs.first().map(|r| r.name.clone()),
898 _ => None,
899 }
900}
901
902fn operand_type_compatible(op: &Operand, expected_ty: crate::ir::PtxType) -> bool {
909 match op {
910 Operand::Register(r) => r.ty == expected_ty,
911 Operand::Immediate(_) | Operand::Symbol(_) | Operand::Address { .. } => true,
913 }
914}
915
916#[must_use]
928pub fn validate_ir_instructions(instructions: &[Instruction]) -> IrValidationResult {
929 let mut result = IrValidationResult {
930 errors: Vec::new(),
931 warnings: Vec::new(),
932 };
933
934 let lifetime_result = validate_register_lifetimes(instructions);
936 result.merge(&lifetime_result);
937
938 let consistency_result = validate_memory_consistency(instructions);
939 result.merge(&consistency_result);
940
941 for (idx, inst) in instructions.iter().enumerate() {
943 validate_type_safety(inst, idx, &mut result);
944 validate_memory_spaces(inst, idx, &mut result);
945 validate_tensor_core_operands(inst, idx, &mut result);
946 }
947
948 result
949}
950
951#[must_use]
957pub fn validate_register_lifetimes(instructions: &[Instruction]) -> IrValidationResult {
958 let mut result = IrValidationResult {
959 errors: Vec::new(),
960 warnings: Vec::new(),
961 };
962
963 let mut defined: HashSet<String> = HashSet::new();
964
965 for (idx, inst) in instructions.iter().enumerate() {
966 let src_names = collect_src_register_names(inst);
968 for name in &src_names {
969 if !defined.contains(name) {
970 result.errors.push(IrValidationError {
971 instruction_index: idx,
972 kind: IrErrorKind::UseBeforeDef,
973 message: format!("register {name} used before definition"),
974 });
975 }
976 }
977
978 if let Some(dst_name) = dst_register_name(inst) {
980 defined.insert(dst_name);
981 }
982
983 match inst {
985 Instruction::Mma { d_regs, .. } | Instruction::Wgmma { d_regs, .. } => {
986 for r in d_regs {
987 defined.insert(r.name.clone());
988 }
989 }
990 Instruction::Wmma { op, fragments, .. } => {
991 if matches!(op, WmmaOp::LoadA | WmmaOp::LoadB) {
993 for frag in fragments {
994 defined.insert(frag.name.clone());
995 }
996 }
997 }
998 _ => {}
999 }
1000 }
1001
1002 result
1003}
1004
1005#[must_use]
1011pub fn validate_memory_consistency(instructions: &[Instruction]) -> IrValidationResult {
1012 let mut result = IrValidationResult {
1013 errors: Vec::new(),
1014 warnings: Vec::new(),
1015 };
1016
1017 check_barrier_divergence(instructions, &mut result);
1019
1020 check_shared_memory_races(instructions, &mut result);
1022
1023 result
1024}
1025
1026fn validate_type_safety(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
1032 match inst {
1033 Instruction::Add { ty, dst, a, b }
1034 | Instruction::Sub { ty, dst, a, b }
1035 | Instruction::Min { ty, dst, a, b }
1036 | Instruction::Max { ty, dst, a, b } => {
1037 if dst.ty != *ty {
1038 result.errors.push(IrValidationError {
1039 instruction_index: idx,
1040 kind: IrErrorKind::TypeMismatch,
1041 message: format!(
1042 "dst register {} has type {:?} but instruction type is {:?}",
1043 dst.name, dst.ty, ty
1044 ),
1045 });
1046 }
1047 if !operand_type_compatible(a, *ty) {
1048 result.errors.push(IrValidationError {
1049 instruction_index: idx,
1050 kind: IrErrorKind::TypeMismatch,
1051 message: format!("operand a type mismatch with instruction type {ty:?}"),
1052 });
1053 }
1054 if !operand_type_compatible(b, *ty) {
1055 result.errors.push(IrValidationError {
1056 instruction_index: idx,
1057 kind: IrErrorKind::TypeMismatch,
1058 message: format!("operand b type mismatch with instruction type {ty:?}"),
1059 });
1060 }
1061 }
1062 Instruction::Mul { ty, dst, a, b, .. } => {
1063 if !operand_type_compatible(a, *ty) {
1066 result.errors.push(IrValidationError {
1067 instruction_index: idx,
1068 kind: IrErrorKind::TypeMismatch,
1069 message: format!("mul operand a type mismatch with instruction type {ty:?}"),
1070 });
1071 }
1072 if !operand_type_compatible(b, *ty) {
1073 result.errors.push(IrValidationError {
1074 instruction_index: idx,
1075 kind: IrErrorKind::TypeMismatch,
1076 message: format!("mul operand b type mismatch with instruction type {ty:?}"),
1077 });
1078 }
1079 if dst.ty != *ty {
1081 result.warnings.push(IrValidationWarning {
1082 instruction_index: idx,
1083 message: format!(
1084 "mul dst register {} type {:?} differs from instruction type {:?}",
1085 dst.name, dst.ty, ty
1086 ),
1087 });
1088 }
1089 }
1090 _ => {}
1091 }
1092}
1093
1094fn validate_memory_spaces(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
1096 if let Instruction::CpAsync {
1097 dst_shared: Operand::Register(r),
1098 ..
1099 } = inst
1100 {
1101 result.warnings.push(IrValidationWarning {
1103 instruction_index: idx,
1104 message: format!(
1105 "cp.async dst_shared uses register {} directly; expected a shared memory address",
1106 r.name
1107 ),
1108 });
1109 }
1110
1111 match inst {
1113 Instruction::Load {
1114 space,
1115 addr: Operand::Immediate(_),
1116 ..
1117 } if *space == MemorySpace::Shared => {
1118 result.errors.push(IrValidationError {
1119 instruction_index: idx,
1120 kind: IrErrorKind::InvalidMemorySpace,
1121 message: "shared memory load with immediate address is invalid".to_string(),
1122 });
1123 }
1124 Instruction::Store {
1125 space,
1126 addr: Operand::Immediate(_),
1127 ..
1128 } if *space == MemorySpace::Shared => {
1129 result.errors.push(IrValidationError {
1130 instruction_index: idx,
1131 kind: IrErrorKind::InvalidMemorySpace,
1132 message: "shared memory store with immediate address is invalid".to_string(),
1133 });
1134 }
1135 _ => {}
1136 }
1137}
1138
1139fn validate_tensor_core_operands(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
1141 match inst {
1142 Instruction::Wmma { addr, stride, .. } => {
1143 if let Some(Operand::Immediate(_)) = addr.as_ref() {
1145 result.errors.push(IrValidationError {
1146 instruction_index: idx,
1147 kind: IrErrorKind::InvalidOperand,
1148 message: "wmma address operand must not be an immediate value".to_string(),
1149 });
1150 }
1151 if let Some(Operand::Immediate(_)) = stride.as_ref() {
1152 result.errors.push(IrValidationError {
1153 instruction_index: idx,
1154 kind: IrErrorKind::InvalidOperand,
1155 message: "wmma stride operand must not be an immediate value".to_string(),
1156 });
1157 }
1158 }
1159 Instruction::Mma {
1160 a_regs,
1161 b_regs,
1162 c_regs,
1163 d_regs,
1164 ..
1165 }
1166 if (a_regs.is_empty() || b_regs.is_empty() || c_regs.is_empty() || d_regs.is_empty()) => {
1168 result.errors.push(IrValidationError {
1169 instruction_index: idx,
1170 kind: IrErrorKind::InvalidOperand,
1171 message: "mma instruction requires non-empty register fragments".to_string(),
1172 });
1173 }
1174 Instruction::Wgmma { d_regs, .. }
1175 if d_regs.is_empty() => {
1176 result.errors.push(IrValidationError {
1177 instruction_index: idx,
1178 kind: IrErrorKind::InvalidOperand,
1179 message: "wgmma instruction requires non-empty destination registers".to_string(),
1180 });
1181 }
1182 _ => {}
1183 }
1184}
1185
1186fn check_barrier_divergence(instructions: &[Instruction], result: &mut IrValidationResult) {
1188 let all_labels: HashSet<&str> = instructions
1190 .iter()
1191 .filter_map(|inst| {
1192 if let Instruction::Label(name) = inst {
1193 Some(name.as_str())
1194 } else {
1195 None
1196 }
1197 })
1198 .collect();
1199
1200 let mut in_conditional_region = false;
1201 let mut conditional_branch_idx = 0;
1202
1203 for (idx, inst) in instructions.iter().enumerate() {
1204 match inst {
1205 Instruction::Branch {
1206 predicate: Some(_),
1207 target,
1208 ..
1209 }
1210 if all_labels.contains(target.as_str()) => {
1213 in_conditional_region = true;
1214 conditional_branch_idx = idx;
1215 }
1216 Instruction::Label(_) => {
1217 in_conditional_region = false;
1219 }
1220 Instruction::BarSync { .. }
1221 if in_conditional_region => {
1222 result.warnings.push(IrValidationWarning {
1223 instruction_index: idx,
1224 message: format!(
1225 "bar.sync inside potentially divergent control flow \
1226 (conditional branch at instruction {conditional_branch_idx}); \
1227 this may cause deadlock if not all threads reach the barrier"
1228 ),
1229 });
1230 }
1231 _ => {}
1232 }
1233 }
1234}
1235
1236fn check_shared_memory_races(instructions: &[Instruction], result: &mut IrValidationResult) {
1241 let mut pending_shared_store: Option<usize> = None;
1242
1243 for (idx, inst) in instructions.iter().enumerate() {
1244 match inst {
1245 Instruction::Store {
1246 space: MemorySpace::Shared,
1247 ..
1248 } => {
1249 pending_shared_store = Some(idx);
1250 }
1251 Instruction::BarSync { .. } => {
1252 pending_shared_store = None;
1254 }
1255 Instruction::Load {
1256 space: MemorySpace::Shared,
1257 ..
1258 } => {
1259 if let Some(store_idx) = pending_shared_store {
1260 result.warnings.push(IrValidationWarning {
1261 instruction_index: idx,
1262 message: format!(
1263 "shared memory load without bar.sync after shared memory \
1264 store at instruction {store_idx}; potential race condition"
1265 ),
1266 });
1267 }
1268 }
1269 _ => {}
1270 }
1271 }
1272}
1273
1274#[cfg(test)]
1275mod tests {
1276 use super::*;
1277 use crate::ir::{
1278 CacheQualifier, ImmValue, Instruction, MemorySpace, Operand, PtxType, Register, SpecialReg,
1279 VectorWidth, WmmaLayout, WmmaOp, WmmaShape,
1280 };
1281
1282 #[test]
1283 fn valid_minimal_ptx() {
1284 let ptx = ".version 8.5\n.target sm_90a\n.address_size 64\n";
1285 let result = validate_ptx(ptx);
1286 assert!(result.is_ok());
1287 assert!(result.errors.is_empty());
1288 }
1289
1290 #[test]
1291 fn missing_version() {
1292 let ptx = ".target sm_80\n.address_size 64\n";
1293 let result = validate_ptx(ptx);
1294 assert!(result.has_errors());
1295 assert!(
1296 result
1297 .errors
1298 .iter()
1299 .any(|e| matches!(e, ValidationError::MissingVersionDirective))
1300 );
1301 }
1302
1303 #[test]
1304 fn missing_target() {
1305 let ptx = ".version 8.5\n.address_size 64\n";
1306 let result = validate_ptx(ptx);
1307 assert!(result.has_errors());
1308 assert!(
1309 result
1310 .errors
1311 .iter()
1312 .any(|e| matches!(e, ValidationError::MissingTargetDirective))
1313 );
1314 }
1315
1316 #[test]
1317 fn shared_memory_within_limits() {
1318 let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n\
1319 .shared .align 4 .b8 smem[4096];\n";
1320 let result = validate_ptx(ptx);
1321 assert!(result.is_ok());
1322 }
1323
1324 #[test]
1325 fn shared_memory_exceeds_limits() {
1326 let ptx = ".version 6.4\n.target sm_75\n.address_size 64\n\
1328 .shared .align 4 .b8 smem[100000];\n";
1329 let result = validate_ptx(ptx);
1330 assert!(result.has_errors());
1331 assert!(
1332 result
1333 .errors
1334 .iter()
1335 .any(|e| matches!(e, ValidationError::InvalidSharedMemSize { .. }))
1336 );
1337 }
1338
1339 #[test]
1340 fn validate_for_specific_target() {
1341 let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n\
1342 .shared .align 4 .b8 smem[200000];\n";
1343 let result = validate_ptx_for_target(ptx, SmVersion::Sm80);
1344 assert!(result.has_errors());
1346 }
1347
1348 #[test]
1349 fn extract_shared_mem_size_fn() {
1350 assert_eq!(
1351 extract_shared_mem_size(" .shared .align 4 .b8 smem[4096];"),
1352 Some(4096)
1353 );
1354 assert_eq!(
1355 extract_shared_mem_size(" .shared .align 16 .b8 tile[65536];"),
1356 Some(65536)
1357 );
1358 assert_eq!(extract_shared_mem_size(" mov.u32 %r0, 0;"), None);
1359 }
1360
1361 #[test]
1362 fn parse_sm_version_fn() {
1363 assert_eq!(parse_sm_version("sm_80"), Some(SmVersion::Sm80));
1364 assert_eq!(parse_sm_version("sm_90a"), Some(SmVersion::Sm90a));
1365 assert_eq!(parse_sm_version("sm_100"), Some(SmVersion::Sm100));
1366 assert_eq!(parse_sm_version("sm_999"), None);
1367 }
1368
1369 #[test]
1370 fn mismatched_braces_warning() {
1371 let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n{\n";
1372 let result = validate_ptx(ptx);
1373 assert!(!result.warnings.is_empty());
1374 }
1375
1376 #[test]
1377 fn validation_error_display() {
1378 let err = ValidationError::MissingVersionDirective;
1379 assert_eq!(format!("{err}"), "missing .version directive");
1380
1381 let err = ValidationError::InvalidSharedMemSize {
1382 declared: 100_000,
1383 max_allowed: 65536,
1384 };
1385 assert!(format!("{err}").contains("100000"));
1386 }
1387
1388 fn reg(name: &str, ty: PtxType) -> Register {
1393 Register {
1394 name: name.to_string(),
1395 ty,
1396 }
1397 }
1398
1399 fn reg_op(name: &str, ty: PtxType) -> Operand {
1400 Operand::Register(reg(name, ty))
1401 }
1402
1403 #[test]
1404 fn ir_type_compatible_arithmetic_passes() {
1405 let instructions = vec![
1406 Instruction::LoadParam {
1407 ty: PtxType::F32,
1408 dst: reg("%f0", PtxType::F32),
1409 param_name: "a".to_string(),
1410 },
1411 Instruction::LoadParam {
1412 ty: PtxType::F32,
1413 dst: reg("%f1", PtxType::F32),
1414 param_name: "b".to_string(),
1415 },
1416 Instruction::Add {
1417 ty: PtxType::F32,
1418 dst: reg("%f2", PtxType::F32),
1419 a: reg_op("%f0", PtxType::F32),
1420 b: reg_op("%f1", PtxType::F32),
1421 },
1422 ];
1423 let result = validate_ir_instructions(&instructions);
1424 assert!(
1425 result.errors.is_empty(),
1426 "expected no errors, got: {:?}",
1427 result.errors
1428 );
1429 }
1430
1431 #[test]
1432 fn ir_type_mismatched_arithmetic_fails() {
1433 let instructions = vec![
1434 Instruction::LoadParam {
1435 ty: PtxType::F32,
1436 dst: reg("%f0", PtxType::F32),
1437 param_name: "a".to_string(),
1438 },
1439 Instruction::LoadParam {
1440 ty: PtxType::U32,
1441 dst: reg("%r0", PtxType::U32),
1442 param_name: "b".to_string(),
1443 },
1444 Instruction::Add {
1445 ty: PtxType::F32,
1446 dst: reg("%f1", PtxType::F32),
1447 a: reg_op("%f0", PtxType::F32),
1448 b: reg_op("%r0", PtxType::U32), },
1450 ];
1451 let result = validate_ir_instructions(&instructions);
1452 assert!(result.has_errors());
1453 assert!(
1454 result
1455 .errors
1456 .iter()
1457 .any(|e| e.kind == IrErrorKind::TypeMismatch)
1458 );
1459 }
1460
1461 #[test]
1462 fn ir_use_before_def_detection() {
1463 let instructions = vec![Instruction::Add {
1464 ty: PtxType::F32,
1465 dst: reg("%f2", PtxType::F32),
1466 a: reg_op("%f0", PtxType::F32), b: reg_op("%f1", PtxType::F32), }];
1469 let result = validate_ir_instructions(&instructions);
1470 assert!(result.has_errors());
1471 let ubd_count = result
1472 .errors
1473 .iter()
1474 .filter(|e| e.kind == IrErrorKind::UseBeforeDef)
1475 .count();
1476 assert!(ubd_count >= 2, "expected at least 2 use-before-def errors");
1477 }
1478
1479 #[test]
1480 fn ir_load_param_counted_as_definition() {
1481 let instructions = vec![
1482 Instruction::LoadParam {
1483 ty: PtxType::U64,
1484 dst: reg("%rd0", PtxType::U64),
1485 param_name: "ptr".to_string(),
1486 },
1487 Instruction::Load {
1488 space: MemorySpace::Global,
1489 qualifier: CacheQualifier::None,
1490 vec: VectorWidth::V1,
1491 ty: PtxType::F32,
1492 dst: reg("%f0", PtxType::F32),
1493 addr: Operand::Address {
1494 base: reg("%rd0", PtxType::U64),
1495 offset: None,
1496 },
1497 },
1498 ];
1499 let result = validate_register_lifetimes(&instructions);
1500 assert!(
1501 result.errors.is_empty(),
1502 "LoadParam should count as definition: {:?}",
1503 result.errors
1504 );
1505 }
1506
1507 #[test]
1508 fn ir_mov_special_counted_as_definition() {
1509 let instructions = vec![
1510 Instruction::MovSpecial {
1511 dst: reg("%r0", PtxType::U32),
1512 special: SpecialReg::TidX,
1513 },
1514 Instruction::Add {
1515 ty: PtxType::U32,
1516 dst: reg("%r1", PtxType::U32),
1517 a: reg_op("%r0", PtxType::U32),
1518 b: Operand::Immediate(ImmValue::U32(1)),
1519 },
1520 ];
1521 let result = validate_register_lifetimes(&instructions);
1522 assert!(
1523 result.errors.is_empty(),
1524 "MovSpecial should count as definition: {:?}",
1525 result.errors
1526 );
1527 }
1528
1529 #[test]
1530 fn ir_shared_store_without_barrier_warns() {
1531 let addr_reg = reg("%rd0", PtxType::U64);
1532 let instructions = vec![
1533 Instruction::LoadParam {
1534 ty: PtxType::U64,
1535 dst: addr_reg.clone(),
1536 param_name: "addr".to_string(),
1537 },
1538 Instruction::LoadParam {
1539 ty: PtxType::F32,
1540 dst: reg("%f0", PtxType::F32),
1541 param_name: "val".to_string(),
1542 },
1543 Instruction::Store {
1544 space: MemorySpace::Shared,
1545 qualifier: CacheQualifier::None,
1546 vec: VectorWidth::V1,
1547 ty: PtxType::F32,
1548 addr: Operand::Address {
1549 base: addr_reg.clone(),
1550 offset: None,
1551 },
1552 src: reg("%f0", PtxType::F32),
1553 },
1554 Instruction::Load {
1556 space: MemorySpace::Shared,
1557 qualifier: CacheQualifier::None,
1558 vec: VectorWidth::V1,
1559 ty: PtxType::F32,
1560 dst: reg("%f1", PtxType::F32),
1561 addr: Operand::Address {
1562 base: addr_reg,
1563 offset: Some(4),
1564 },
1565 },
1566 ];
1567 let result = validate_memory_consistency(&instructions);
1568 assert!(
1569 !result.warnings.is_empty(),
1570 "expected race condition warning"
1571 );
1572 assert!(
1573 result.warnings[0].message.contains("race condition"),
1574 "warning should mention race condition"
1575 );
1576 }
1577
1578 #[test]
1579 fn ir_barrier_after_shared_store_no_warning() {
1580 let addr_reg = reg("%rd0", PtxType::U64);
1581 let instructions = vec![
1582 Instruction::Store {
1583 space: MemorySpace::Shared,
1584 qualifier: CacheQualifier::None,
1585 vec: VectorWidth::V1,
1586 ty: PtxType::F32,
1587 addr: Operand::Address {
1588 base: addr_reg.clone(),
1589 offset: None,
1590 },
1591 src: reg("%f0", PtxType::F32),
1592 },
1593 Instruction::BarSync { id: 0 },
1594 Instruction::Load {
1595 space: MemorySpace::Shared,
1596 qualifier: CacheQualifier::None,
1597 vec: VectorWidth::V1,
1598 ty: PtxType::F32,
1599 dst: reg("%f1", PtxType::F32),
1600 addr: Operand::Address {
1601 base: addr_reg,
1602 offset: Some(4),
1603 },
1604 },
1605 ];
1606 let result = validate_memory_consistency(&instructions);
1607 assert!(
1608 result.warnings.is_empty(),
1609 "expected no warnings when barrier separates store/load"
1610 );
1611 }
1612
1613 #[test]
1614 fn ir_empty_instruction_list_no_errors() {
1615 let result = validate_ir_instructions(&[]);
1616 assert!(result.is_ok());
1617 assert!(result.warnings.is_empty());
1618 }
1619
1620 #[test]
1621 fn ir_complex_sequence_multiple_issues() {
1622 let instructions = vec![
1623 Instruction::Add {
1625 ty: PtxType::F32,
1626 dst: reg("%f1", PtxType::F32),
1627 a: reg_op("%f0", PtxType::F32),
1628 b: Operand::Immediate(ImmValue::F32(1.0)),
1629 },
1630 Instruction::Sub {
1632 ty: PtxType::F32,
1633 dst: reg("%r0", PtxType::U32),
1634 a: reg_op("%f1", PtxType::F32),
1635 b: Operand::Immediate(ImmValue::F32(2.0)),
1636 },
1637 ];
1638 let result = validate_ir_instructions(&instructions);
1639 assert!(result.has_errors());
1640
1641 let has_ubd = result
1642 .errors
1643 .iter()
1644 .any(|e| e.kind == IrErrorKind::UseBeforeDef);
1645 let has_type_mismatch = result
1646 .errors
1647 .iter()
1648 .any(|e| e.kind == IrErrorKind::TypeMismatch);
1649 assert!(has_ubd, "expected use-before-def error");
1650 assert!(has_type_mismatch, "expected type mismatch error");
1651 }
1652
1653 #[test]
1654 fn ir_validate_register_lifetimes_standalone() {
1655 let instructions = vec![
1656 Instruction::LoadParam {
1657 ty: PtxType::F32,
1658 dst: reg("%f0", PtxType::F32),
1659 param_name: "x".to_string(),
1660 },
1661 Instruction::Neg {
1662 ty: PtxType::F32,
1663 dst: reg("%f1", PtxType::F32),
1664 src: reg_op("%f0", PtxType::F32),
1665 },
1666 Instruction::Add {
1668 ty: PtxType::F32,
1669 dst: reg("%f2", PtxType::F32),
1670 a: reg_op("%f1", PtxType::F32),
1671 b: reg_op("%f99", PtxType::F32),
1672 },
1673 ];
1674 let result = validate_register_lifetimes(&instructions);
1675 assert!(result.has_errors());
1676 assert_eq!(result.errors.len(), 1);
1677 assert!(result.errors[0].message.contains("%f99"));
1678 }
1679
1680 #[test]
1681 fn ir_validate_memory_consistency_standalone() {
1682 let instructions = vec![
1684 Instruction::LoadParam {
1685 ty: PtxType::U32,
1686 dst: reg("%p0", PtxType::Pred),
1687 param_name: "pred".to_string(),
1688 },
1689 Instruction::Branch {
1690 target: "skip".to_string(),
1691 predicate: Some((reg("%p0", PtxType::Pred), false)),
1692 },
1693 Instruction::BarSync { id: 0 },
1694 Instruction::Label("skip".to_string()),
1695 ];
1696 let result = validate_memory_consistency(&instructions);
1697 assert!(!result.warnings.is_empty(), "expected divergence warning");
1698 assert!(result.warnings[0].message.contains("divergent"));
1699 }
1700
1701 #[test]
1702 fn ir_validation_result_display() {
1703 let result = IrValidationResult {
1704 errors: vec![IrValidationError {
1705 instruction_index: 3,
1706 kind: IrErrorKind::TypeMismatch,
1707 message: "dst type does not match".to_string(),
1708 }],
1709 warnings: vec![IrValidationWarning {
1710 instruction_index: 7,
1711 message: "possible race".to_string(),
1712 }],
1713 };
1714 let display = format!("{result}");
1715 assert!(display.contains("Errors (1)"));
1716 assert!(display.contains("TypeMismatch"));
1717 assert!(display.contains("Warnings (1)"));
1718 assert!(display.contains("possible race"));
1719
1720 let ok_result = IrValidationResult {
1722 errors: Vec::new(),
1723 warnings: Vec::new(),
1724 };
1725 let ok_display = format!("{ok_result}");
1726 assert!(ok_display.contains("passed"));
1727 }
1728
1729 #[test]
1730 fn ir_wmma_with_immediate_operand_flagged() {
1731 let instructions = vec![Instruction::Wmma {
1732 op: WmmaOp::LoadA,
1733 shape: WmmaShape::M16N16K16,
1734 layout: WmmaLayout::RowMajor,
1735 ty: PtxType::F16,
1736 fragments: vec![reg("%f0", PtxType::F16)],
1737 addr: Some(Operand::Immediate(ImmValue::U32(0))), stride: Some(Operand::Immediate(ImmValue::U32(16))), }];
1740 let result = validate_ir_instructions(&instructions);
1741 let invalid_operand_errors: Vec<_> = result
1742 .errors
1743 .iter()
1744 .filter(|e| e.kind == IrErrorKind::InvalidOperand)
1745 .collect();
1746 assert!(
1747 invalid_operand_errors.len() >= 2,
1748 "expected at least 2 InvalidOperand errors for wmma immediates, got {}",
1749 invalid_operand_errors.len()
1750 );
1751 }
1752
1753 #[test]
1754 fn ir_mixed_valid_and_invalid_instructions() {
1755 let instructions = vec![
1756 Instruction::LoadParam {
1758 ty: PtxType::F32,
1759 dst: reg("%f0", PtxType::F32),
1760 param_name: "x".to_string(),
1761 },
1762 Instruction::MovSpecial {
1764 dst: reg("%r0", PtxType::U32),
1765 special: SpecialReg::TidX,
1766 },
1767 Instruction::Add {
1769 ty: PtxType::F32,
1770 dst: reg("%f1", PtxType::F32),
1771 a: reg_op("%f0", PtxType::F32),
1772 b: Operand::Immediate(ImmValue::F32(1.0)),
1773 },
1774 Instruction::Sub {
1776 ty: PtxType::F32,
1777 dst: reg("%bad", PtxType::U32), a: reg_op("%f1", PtxType::F32),
1779 b: Operand::Immediate(ImmValue::F32(0.5)),
1780 },
1781 Instruction::Comment("test".to_string()),
1783 Instruction::Return,
1785 ];
1786 let result = validate_ir_instructions(&instructions);
1787 let type_errors: Vec<_> = result
1789 .errors
1790 .iter()
1791 .filter(|e| e.kind == IrErrorKind::TypeMismatch)
1792 .collect();
1793 assert_eq!(
1794 type_errors.len(),
1795 1,
1796 "expected exactly 1 type mismatch, got {}: {:?}",
1797 type_errors.len(),
1798 type_errors
1799 );
1800 let ubd_errors: Vec<_> = result
1802 .errors
1803 .iter()
1804 .filter(|e| e.kind == IrErrorKind::UseBeforeDef)
1805 .collect();
1806 assert!(
1807 ubd_errors.is_empty(),
1808 "expected no use-before-def errors: {ubd_errors:?}",
1809 );
1810 }
1811}
1812
// Additional test module whose body lives in the separate file `validator_tests.rs`
// (selected via `#[path]`); it is compiled only for test builds.
// NOTE(review): the name suggests SM-version-specific tests — confirm against that file.
#[cfg(test)]
#[path = "validator_tests.rs"]
mod sm_tests;