Skip to main content

oxicuda_ptx/emit/
validator.rs

1//! Basic PTX text validation.
2//!
3//! Provides lightweight validation checks on generated PTX text to catch
4//! common errors before submitting to the driver's `cuModuleLoadDataEx`.
//! This is not a full PTX parser — it performs heuristic checks for:
//!
//! - Presence of required directives (`.version`, `.target`)
//! - Register declaration vs. usage consistency
//! - Register count against the per-thread register limit
//! - Shared memory size limits for the target architecture
//! - Availability of instructions on the target SM version
10//!
11//! # Example
12//!
13//! ```
14//! use oxicuda_ptx::emit::validator::{validate_ptx, ValidationError};
15//!
16//! let ptx = ".version 8.5\n.target sm_90a\n.address_size 64\n";
17//! let result = validate_ptx(ptx);
18//! assert!(result.errors.is_empty());
19//! ```
20
21use std::collections::HashSet;
22
23use crate::arch::SmVersion;
24use crate::ir::{Instruction, MemorySpace, Operand, WmmaOp};
25
26/// Result of PTX validation containing errors and warnings.
27///
28/// An empty `errors` vector indicates the PTX passed all checks. Warnings
29/// are informational and do not indicate invalid PTX.
30#[derive(Debug, Clone)]
31pub struct ValidationResult {
32    /// Fatal validation errors that likely indicate broken PTX.
33    pub errors: Vec<ValidationError>,
34    /// Non-fatal warnings (informational).
35    pub warnings: Vec<String>,
36}
37
38impl ValidationResult {
39    /// Returns `true` if no errors were found.
40    #[must_use]
41    pub fn is_ok(&self) -> bool {
42        self.errors.is_empty()
43    }
44
45    /// Returns `true` if any errors were found.
46    #[must_use]
47    pub fn has_errors(&self) -> bool {
48        !self.errors.is_empty()
49    }
50}
51
/// A single problem detected while validating PTX text.
///
/// Every variant names one specific kind of finding; the [`Display`]
/// implementation renders it as a short human-readable message.
///
/// [`Display`]: std::fmt::Display
#[derive(Debug, Clone)]
pub enum ValidationError {
    /// No `.version` directive was found.
    MissingVersionDirective,
    /// No `.target` directive was found.
    MissingTargetDirective,
    /// A register was referenced without being declared.
    UndefinedRegister(String),
    /// A (heuristically detected) type mismatch.
    TypeMismatch {
        /// Type that was expected.
        expected: String,
        /// Type that was actually encountered.
        found: String,
    },
    /// Declared shared memory is larger than the architecture allows.
    InvalidSharedMemSize {
        /// Declared shared memory size, in bytes.
        declared: usize,
        /// Largest size permitted for the target architecture.
        max_allowed: usize,
    },
    /// The `.address_size` directive is absent or is not 64.
    InvalidAddressSize(String),
    /// An instruction needs a newer SM version than the one targeted.
    SmIncompatibleInstruction {
        /// Instruction or feature that is unavailable.
        instruction: String,
        /// Minimum SM version needed (e.g. `"sm_80"`).
        required_sm: String,
        /// SM version the PTX targets (e.g. `"sm_75"`).
        found_sm: String,
    },
    /// More registers are used than the per-thread limit (255) allows.
    RegisterPressureExceeded {
        /// Number of distinct registers detected.
        count: usize,
        /// Largest number of registers allowed per thread.
        max_allowed: usize,
    },
    /// Any other validation problem, described by its message.
    Other(String),
}

impl std::fmt::Display for ValidationError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::MissingVersionDirective => f.write_str("missing .version directive"),
            Self::MissingTargetDirective => f.write_str("missing .target directive"),
            Self::Other(text) => f.write_str(text),

            // Messages that interpolate payload fields.
            Self::UndefinedRegister(register) => write!(f, "undefined register: {register}"),
            Self::InvalidAddressSize(detail) => write!(f, "address size issue: {detail}"),
            Self::TypeMismatch {
                expected: want,
                found: got,
            } => write!(f, "type mismatch: expected {want}, found {got}"),
            Self::InvalidSharedMemSize {
                declared: bytes,
                max_allowed: limit,
            } => write!(f, "shared memory {bytes} bytes exceeds limit of {limit} bytes"),
            Self::SmIncompatibleInstruction {
                instruction: what,
                required_sm: needs,
                found_sm: has,
            } => write!(f, "instruction '{what}' requires {needs} but target is {has}"),
            Self::RegisterPressureExceeded {
                count: used,
                max_allowed: limit,
            } => write!(f, "register count {used} exceeds per-thread limit of {limit}"),
        }
    }
}
134
135/// Validates PTX text for common errors.
136///
137/// Performs the following checks:
138/// 1. `.version` directive is present
139/// 2. `.target` directive is present
140/// 3. Register declarations match usage (heuristic)
141/// 4. Shared memory does not exceed architecture limits
142///
143/// # Arguments
144///
145/// * `ptx` - The PTX text to validate
146///
147/// # Returns
148///
149/// A [`ValidationResult`] containing any errors and warnings found.
150#[must_use]
151pub fn validate_ptx(ptx: &str) -> ValidationResult {
152    let mut errors = Vec::new();
153    let mut warnings = Vec::new();
154
155    // Check for .version directive
156    if !ptx.contains(".version") {
157        errors.push(ValidationError::MissingVersionDirective);
158    }
159
160    // Check for .target directive
161    if !ptx.contains(".target") {
162        errors.push(ValidationError::MissingTargetDirective);
163    }
164
165    // Try to determine target architecture for limit checking
166    let target_sm = extract_target_sm(ptx);
167
168    // Check shared memory limits
169    check_shared_memory(ptx, target_sm, &mut errors, &mut warnings);
170
171    // Check register declarations vs usage (real tracking)
172    check_register_declarations(ptx, &mut warnings);
173
174    // Check register pressure against hardware limits (255 per thread)
175    check_register_pressure(ptx, &mut errors, &mut warnings);
176
177    // Check SM-version-specific instruction availability
178    if let Some(sm) = target_sm {
179        check_sm_compatibility(ptx, sm, &mut errors, &mut warnings);
180    }
181
182    // Check for basic structural issues
183    check_structure(ptx, &mut warnings);
184
185    ValidationResult { errors, warnings }
186}
187
188/// Validates PTX text against a specific target architecture.
189///
190/// This variant allows specifying the target explicitly rather than
191/// extracting it from the PTX text. It runs all checks from [`validate_ptx`]
192/// plus explicit architecture-specific checks (SM compatibility, register
193/// pressure, shared memory limits) against `target`.
194#[must_use]
195pub fn validate_ptx_for_target(ptx: &str, target: SmVersion) -> ValidationResult {
196    let mut errors = Vec::new();
197    let mut warnings = Vec::new();
198
199    // Check for .version directive
200    if !ptx.contains(".version") {
201        errors.push(ValidationError::MissingVersionDirective);
202    }
203
204    // Check for .target directive
205    if !ptx.contains(".target") {
206        errors.push(ValidationError::MissingTargetDirective);
207    }
208
209    // Shared memory against the explicitly-supplied target
210    check_shared_memory(ptx, Some(target), &mut errors, &mut warnings);
211
212    // Register declarations heuristic
213    check_register_declarations(ptx, &mut warnings);
214
215    // Register pressure against hardware limit
216    check_register_pressure(ptx, &mut errors, &mut warnings);
217
218    // SM-version-specific instruction compatibility
219    check_sm_compatibility(ptx, target, &mut errors, &mut warnings);
220
221    // Basic structural checks
222    check_structure(ptx, &mut warnings);
223
224    ValidationResult { errors, warnings }
225}
226
227/// Extracts the target SM version from PTX text, if present.
228fn extract_target_sm(ptx: &str) -> Option<SmVersion> {
229    for line in ptx.lines() {
230        let trimmed = line.trim();
231        if trimmed.starts_with(".target") {
232            let parts: Vec<&str> = trimmed.split_whitespace().collect();
233            if parts.len() >= 2 {
234                return parse_sm_version(parts[1].trim_end_matches(';'));
235            }
236        }
237    }
238    None
239}
240
241/// Parses a target string like `sm_80` into an `SmVersion`.
242fn parse_sm_version(s: &str) -> Option<SmVersion> {
243    match s {
244        "sm_75" => Some(SmVersion::Sm75),
245        "sm_80" => Some(SmVersion::Sm80),
246        "sm_86" => Some(SmVersion::Sm86),
247        "sm_89" => Some(SmVersion::Sm89),
248        "sm_90" => Some(SmVersion::Sm90),
249        "sm_90a" => Some(SmVersion::Sm90a),
250        "sm_100" => Some(SmVersion::Sm100),
251        "sm_120" => Some(SmVersion::Sm120),
252        _ => None,
253    }
254}
255
256/// Checks shared memory declarations against architecture limits.
257fn check_shared_memory(
258    ptx: &str,
259    target: Option<SmVersion>,
260    errors: &mut Vec<ValidationError>,
261    warnings: &mut Vec<String>,
262) {
263    let max_smem = target.map_or(usize::MAX, |sm| sm.max_shared_mem_per_block() as usize);
264
265    let mut total_smem: usize = 0;
266
267    for line in ptx.lines() {
268        let trimmed = line.trim();
269        if let Some(size) = extract_shared_mem_size(trimmed) {
270            total_smem = total_smem.saturating_add(size);
271        }
272    }
273
274    if total_smem > max_smem {
275        errors.push(ValidationError::InvalidSharedMemSize {
276            declared: total_smem,
277            max_allowed: max_smem,
278        });
279    } else if total_smem > 48 * 1024 && target.is_some() {
280        warnings.push(format!(
281            "shared memory usage ({total_smem} bytes) exceeds default limit (49152); \
282             may require opt-in via cuFuncSetAttribute"
283        ));
284    }
285}
286
/// Extracts the size in bytes of a `.shared` array declaration line.
///
/// Handles patterns like `.shared .align 4 .b8 smem[1024];` and returns
/// `element count × element width`, where:
///
/// - the element count is the product of every `[n]` dimension on the line
///   (so `tile[16][32]` counts 512 elements);
/// - the element width comes from the fundamental type token (`.b8`/`.u8`/
///   `.s8` = 1, `.b16`/`.u16`/`.s16`/`.f16` = 2, `.b32`/`.u32`/`.s32`/`.f32`/
///   `.f16x2` = 4, `.b64`/`.u64`/`.s64`/`.f64` = 8, `.b128` = 16), multiplied
///   by 2 or 4 when a `.v2`/`.v4` vector qualifier is present. A line with no
///   recognized type token is counted at one byte per element.
///
/// Returns `None` when the line is not a declaration mentioning `.shared`
/// (instructions such as `st.shared.u32 [128], %r1;` do not start with a
/// directive and are rejected), when it has no bracketed dimension, or when
/// a dimension is not a plain unsigned integer (e.g. the unsized `dyn[]`).
/// Text after `//` is ignored. Arithmetic saturates instead of overflowing.
fn extract_shared_mem_size(line: &str) -> Option<usize> {
    // Ignore trailing comments so brackets inside them are never parsed.
    let code = line.split("//").next().unwrap_or(line).trim();

    // Declarations begin with a directive (`.shared`, `.extern`, ...);
    // instructions begin with an opcode or a predicate guard.
    if !code.starts_with('.') || !code.contains(".shared") {
        return None;
    }

    // Element width and vector lane count from the type tokens.
    let mut element_bytes: usize = 1;
    let mut lanes: usize = 1;
    for token in code.split_whitespace() {
        match token {
            ".v2" => lanes = 2,
            ".v4" => lanes = 4,
            ".b8" | ".u8" | ".s8" => element_bytes = 1,
            ".b16" | ".u16" | ".s16" | ".f16" => element_bytes = 2,
            ".b32" | ".u32" | ".s32" | ".f32" | ".f16x2" => element_bytes = 4,
            ".b64" | ".u64" | ".s64" | ".f64" => element_bytes = 8,
            ".b128" => element_bytes = 16,
            _ => {}
        }
    }

    // Multiply together every `[n]` dimension.
    let mut elements: usize = 1;
    let mut saw_dimension = false;
    let mut rest = code;
    while let Some(open) = rest.find('[') {
        let after_open = &rest[open + 1..];
        let close = after_open.find(']')?;
        let dimension = after_open[..close].trim().parse::<usize>().ok()?;
        elements = elements.saturating_mul(dimension);
        saw_dimension = true;
        rest = &after_open[close + 1..];
    }
    if !saw_dimension {
        return None;
    }

    Some(
        elements
            .saturating_mul(element_bytes)
            .saturating_mul(lanes),
    )
}
305
/// Warns when a module defines a kernel but declares no registers.
///
/// A warning is pushed when at least one line contains `.entry` and no line
/// (after trimming) starts with `.reg`. Nothing else is inspected.
fn check_register_declarations(ptx: &str, warnings: &mut Vec<String>) {
    let has_entry_point = ptx.lines().any(|line| line.contains(".entry"));
    if !has_entry_point {
        return;
    }

    let has_reg_declaration = ptx.lines().any(|line| line.trim().starts_with(".reg"));
    if has_reg_declaration {
        return;
    }

    warnings.push(String::from(
        "kernel has no .reg declarations; all registers may be declared via raw PTX",
    ));
}
324
/// Performs basic structural sanity checks on the PTX text.
///
/// Currently this only compares the number of `{` characters with the number
/// of `}` characters anywhere in the text and pushes a warning when they
/// differ.
fn check_structure(ptx: &str, warnings: &mut Vec<String>) {
    let opening = ptx.matches('{').count();
    let closing = ptx.matches('}').count();

    if opening == closing {
        return;
    }

    warnings.push(format!(
        "mismatched braces: {opening} opening vs {closing} closing"
    ));
}
336
337// ===========================================================================
338// SM version compatibility checks
339// ===========================================================================
340
341/// Describes an instruction/feature that requires a minimum SM version.
342struct SmRequirement {
343    /// Substring to search for in the PTX text.
344    pattern: &'static str,
345    /// The minimum SM version required.
346    min_sm: SmVersion,
347    /// Human-readable name for error messages.
348    name: &'static str,
349}
350
351/// Table of instructions and the minimum SM version that supports them.
352const SM_REQUIREMENTS: &[SmRequirement] = &[
353    SmRequirement {
354        pattern: "cp.async",
355        min_sm: SmVersion::Sm80,
356        name: "cp.async",
357    },
358    SmRequirement {
359        pattern: "wgmma",
360        min_sm: SmVersion::Sm90,
361        name: "wgmma",
362    },
363    SmRequirement {
364        pattern: "mma.sync",
365        min_sm: SmVersion::Sm75,
366        name: "mma.sync (tensor core)",
367    },
368    SmRequirement {
369        pattern: "ldmatrix",
370        min_sm: SmVersion::Sm75,
371        name: "ldmatrix",
372    },
373    SmRequirement {
374        pattern: ".e4m3",
375        min_sm: SmVersion::Sm89,
376        name: "fp8 e4m3 type",
377    },
378    SmRequirement {
379        pattern: ".e5m2",
380        min_sm: SmVersion::Sm89,
381        name: "fp8 e5m2 type",
382    },
383    SmRequirement {
384        pattern: "tcgen05",
385        min_sm: SmVersion::Sm100,
386        name: "tcgen05",
387    },
388];
389
390/// Checks that instructions present in the PTX are supported by the target SM.
391///
392/// Scans for known instruction patterns and emits an error for each that
393/// requires a newer SM than the target.
394fn check_sm_compatibility(
395    ptx: &str,
396    sm: SmVersion,
397    errors: &mut Vec<ValidationError>,
398    _warnings: &mut Vec<String>,
399) {
400    let found_sm_str = sm.as_ptx_str();
401    for req in SM_REQUIREMENTS {
402        if ptx.contains(req.pattern) && sm < req.min_sm {
403            errors.push(ValidationError::SmIncompatibleInstruction {
404                instruction: req.name.to_string(),
405                required_sm: req.min_sm.as_ptx_str().to_string(),
406                found_sm: found_sm_str.to_string(),
407            });
408        }
409    }
410}
411
412// ===========================================================================
413// Register pressure check
414// ===========================================================================
415
416/// Maximum number of registers per thread allowed by all current NVIDIA GPUs.
417const MAX_REGISTERS_PER_THREAD: usize = 255;
418
419/// Register count at which a warning is emitted (approaching the limit).
420const REGISTER_PRESSURE_WARNING_THRESHOLD: usize = 200;
421
422/// Checks whether the number of distinct register names used in the PTX
423/// text exceeds or approaches the per-thread hardware limit.
424///
425/// Scans for PTX register naming conventions:
426/// - `%r\d+`  — u32/s32 registers
427/// - `%f\d+`  — f32 registers
428/// - `%rd\d+` — u64/s64 registers
429/// - `%fd\d+` — f64 registers
430/// - `%p\d+`  — predicate registers
431/// - `%b\d+`  — b32/b64 registers
432/// - `%h\d+`  — f16 registers
433fn check_register_pressure(
434    ptx: &str,
435    errors: &mut Vec<ValidationError>,
436    warnings: &mut Vec<String>,
437) {
438    use std::collections::HashSet;
439
440    let mut seen: HashSet<&str> = HashSet::new();
441
442    // We walk the PTX character by character looking for '%' followed by a
443    // register name (letters then digits).  This avoids bringing in a regex
444    // dependency while still being precise enough for generated PTX.
445    let bytes = ptx.as_bytes();
446    let len = bytes.len();
447    let mut i = 0;
448    while i < len {
449        if bytes[i] == b'%' {
450            let start = i;
451            i += 1;
452            // Consume the letter prefix (e.g. "rd", "fd", "r", "f", "p", ...)
453            while i < len && bytes[i].is_ascii_alphabetic() {
454                i += 1;
455            }
456            // Must be followed by at least one digit to be a register reference
457            if i < len && bytes[i].is_ascii_digit() {
458                while i < len && bytes[i].is_ascii_digit() {
459                    i += 1;
460                }
461                // Only count concrete register references (not %tid.x, %ctaid.x, etc.)
462                let token = &ptx[start..i];
463                // Special registers contain '.' after the name — skip them.
464                // We already stopped at non-digit so the token is clean.
465                // Exclude PTX special registers that start with known prefixes.
466                let name_part = &token[1..]; // strip leading '%'
467                let is_special = name_part.starts_with("tid")
468                    || name_part.starts_with("ntid")
469                    || name_part.starts_with("ctaid")
470                    || name_part.starts_with("nctaid")
471                    || name_part.starts_with("laneid")
472                    || name_part.starts_with("warpid")
473                    || name_part.starts_with("smid")
474                    || name_part.starts_with("pm")
475                    || name_part.starts_with("envreg")
476                    || name_part.starts_with("globaltimer")
477                    || name_part.starts_with("param_");
478                if !is_special {
479                    seen.insert(token);
480                }
481            }
482        } else {
483            i += 1;
484        }
485    }
486
487    let count = seen.len();
488    if count > MAX_REGISTERS_PER_THREAD {
489        errors.push(ValidationError::RegisterPressureExceeded {
490            count,
491            max_allowed: MAX_REGISTERS_PER_THREAD,
492        });
493    } else if count > REGISTER_PRESSURE_WARNING_THRESHOLD {
494        warnings.push(format!(
495            "register count ({count}) is approaching the per-thread limit of \
496             {MAX_REGISTERS_PER_THREAD}; consider reducing register pressure"
497        ));
498    }
499}
500
501// ===========================================================================
502// IR-level validation
503// ===========================================================================
504
505/// Result of IR-level instruction validation.
506#[derive(Debug, Clone)]
507pub struct IrValidationResult {
508    /// Fatal validation errors.
509    pub errors: Vec<IrValidationError>,
510    /// Non-fatal warnings.
511    pub warnings: Vec<IrValidationWarning>,
512}
513
514impl IrValidationResult {
515    /// Returns `true` if no errors were found.
516    #[must_use]
517    pub fn is_ok(&self) -> bool {
518        self.errors.is_empty()
519    }
520
521    /// Returns `true` if any errors were found.
522    #[must_use]
523    pub fn has_errors(&self) -> bool {
524        !self.errors.is_empty()
525    }
526
527    /// Merge another result into this one.
528    fn merge(&mut self, other: &Self) {
529        self.errors.extend(other.errors.iter().cloned());
530        self.warnings.extend(other.warnings.iter().cloned());
531    }
532}
533
534impl std::fmt::Display for IrValidationResult {
535    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536        if self.errors.is_empty() && self.warnings.is_empty() {
537            return write!(f, "IR validation passed: no errors, no warnings");
538        }
539        if !self.errors.is_empty() {
540            writeln!(f, "Errors ({}):", self.errors.len())?;
541            for err in &self.errors {
542                writeln!(
543                    f,
544                    "  [{:>3}] {}: {}",
545                    err.instruction_index, err.kind, err.message
546                )?;
547            }
548        }
549        if !self.warnings.is_empty() {
550            writeln!(f, "Warnings ({}):", self.warnings.len())?;
551            for warn in &self.warnings {
552                writeln!(f, "  [{:>3}] {}", warn.instruction_index, warn.message)?;
553            }
554        }
555        Ok(())
556    }
557}
558
559/// An IR-level validation error tied to a specific instruction.
560#[derive(Debug, Clone)]
561pub struct IrValidationError {
562    /// Index of the offending instruction in the sequence.
563    pub instruction_index: usize,
564    /// The kind of error detected.
565    pub kind: IrErrorKind,
566    /// Human-readable description.
567    pub message: String,
568}
569
/// The category an [`IrValidationError`] belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IrErrorKind {
    /// Operand types do not agree.
    TypeMismatch,
    /// A register is read before anything defines it.
    UseBeforeDef,
    /// The instruction is used with a memory space it does not accept.
    InvalidMemorySpace,
    /// The instruction is given an operand of a kind it does not accept.
    InvalidOperand,
    /// A barrier appears inside divergent control flow.
    BarrierInDivergent,
    /// A problem with a register's lifetime.
    RegisterLifetime,
}

impl std::fmt::Display for IrErrorKind {
    /// Writes the variant name verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::TypeMismatch => "TypeMismatch",
            Self::UseBeforeDef => "UseBeforeDef",
            Self::InvalidMemorySpace => "InvalidMemorySpace",
            Self::InvalidOperand => "InvalidOperand",
            Self::BarrierInDivergent => "BarrierInDivergent",
            Self::RegisterLifetime => "RegisterLifetime",
        };
        f.write_str(label)
    }
}
599
/// A non-fatal IR validation finding attached to one instruction.
#[derive(Debug, Clone)]
pub struct IrValidationWarning {
    /// Position of the instruction the warning refers to.
    pub instruction_index: usize,
    /// Explanation intended for humans.
    pub message: String,
}
608
609// ---------------------------------------------------------------------------
610// Helper: extract register names from operands used as sources
611// ---------------------------------------------------------------------------
612
613/// Extracts register name(s) from an operand used as a source.
614fn push_operand_names(op: &Operand, names: &mut Vec<String>) {
615    if let Operand::Register(r) = op {
616        names.push(r.name.clone());
617    }
618    if let Operand::Address { base, .. } = op {
619        names.push(base.name.clone());
620    }
621}
622
623/// Collects the register names used as source operands in an instruction.
624#[allow(clippy::too_many_lines)]
625fn collect_src_register_names(inst: &Instruction) -> Vec<String> {
626    let mut names = Vec::new();
627
628    match inst {
629        Instruction::Add { a, b, .. }
630        | Instruction::Sub { a, b, .. }
631        | Instruction::Mul { a, b, .. }
632        | Instruction::Min { a, b, .. }
633        | Instruction::Max { a, b, .. }
634        | Instruction::Div { a, b, .. }
635        | Instruction::Rem { a, b, .. }
636        | Instruction::And { a, b, .. }
637        | Instruction::Or { a, b, .. }
638        | Instruction::Xor { a, b, .. }
639        | Instruction::SetP { a, b, .. } => {
640            push_operand_names(a, &mut names);
641            push_operand_names(b, &mut names);
642        }
643        Instruction::Mad { a, b, c, .. }
644        | Instruction::MadLo { a, b, c, .. }
645        | Instruction::MadHi { a, b, c, .. }
646        | Instruction::MadWide { a, b, c, .. }
647        | Instruction::Fma { a, b, c, .. }
648        | Instruction::Dp4a { a, b, c, .. }
649        | Instruction::Dp2a { a, b, c, .. } => {
650            push_operand_names(a, &mut names);
651            push_operand_names(b, &mut names);
652            push_operand_names(c, &mut names);
653        }
654        Instruction::Neg { src, .. }
655        | Instruction::Abs { src, .. }
656        | Instruction::Brev { src, .. }
657        | Instruction::Clz { src, .. }
658        | Instruction::Popc { src, .. }
659        | Instruction::Bfind { src, .. }
660        | Instruction::Rcp { src, .. }
661        | Instruction::Rsqrt { src, .. }
662        | Instruction::Sqrt { src, .. }
663        | Instruction::Ex2 { src, .. }
664        | Instruction::Lg2 { src, .. }
665        | Instruction::Sin { src, .. }
666        | Instruction::Cos { src, .. }
667        | Instruction::Cvt { src, .. }
668        | Instruction::Redux { src, .. } => {
669            push_operand_names(src, &mut names);
670        }
671        Instruction::Bfe {
672            src, start, len, ..
673        } => {
674            push_operand_names(src, &mut names);
675            push_operand_names(start, &mut names);
676            push_operand_names(len, &mut names);
677        }
678        Instruction::Bfi {
679            insert,
680            base,
681            start,
682            len,
683            ..
684        } => {
685            push_operand_names(insert, &mut names);
686            push_operand_names(base, &mut names);
687            push_operand_names(start, &mut names);
688            push_operand_names(len, &mut names);
689        }
690        Instruction::Shl { src, amount, .. } | Instruction::Shr { src, amount, .. } => {
691            push_operand_names(src, &mut names);
692            push_operand_names(amount, &mut names);
693        }
694        Instruction::Load { addr, .. } | Instruction::MbarrierArrive { addr } => {
695            push_operand_names(addr, &mut names);
696        }
697        Instruction::Store { addr, src, .. } => {
698            push_operand_names(addr, &mut names);
699            names.push(src.name.clone());
700        }
701        Instruction::CpAsync {
702            dst_shared,
703            src_global,
704            ..
705        } => {
706            push_operand_names(dst_shared, &mut names);
707            push_operand_names(src_global, &mut names);
708        }
709        Instruction::Branch { predicate, .. } => {
710            if let Some((r, _)) = predicate {
711                names.push(r.name.clone());
712            }
713        }
714        Instruction::Atom { addr, src, .. } | Instruction::Red { addr, src, .. } => {
715            push_operand_names(addr, &mut names);
716            push_operand_names(src, &mut names);
717        }
718        Instruction::AtomCas {
719            addr,
720            compare,
721            value,
722            ..
723        } => {
724            push_operand_names(addr, &mut names);
725            push_operand_names(compare, &mut names);
726            push_operand_names(value, &mut names);
727        }
728        Instruction::Tex1d { coord, .. } | Instruction::SurfLoad { coord, .. } => {
729            push_operand_names(coord, &mut names);
730        }
731        Instruction::Tex2d {
732            coord_x, coord_y, ..
733        } => {
734            push_operand_names(coord_x, &mut names);
735            push_operand_names(coord_y, &mut names);
736        }
737        Instruction::Tex3d {
738            coord_x,
739            coord_y,
740            coord_z,
741            ..
742        } => {
743            push_operand_names(coord_x, &mut names);
744            push_operand_names(coord_y, &mut names);
745            push_operand_names(coord_z, &mut names);
746        }
747        Instruction::SurfStore { coord, src, .. } => {
748            push_operand_names(coord, &mut names);
749            names.push(src.name.clone());
750        }
751        Instruction::Wmma {
752            fragments,
753            addr,
754            stride,
755            ..
756        } => {
757            for frag in fragments {
758                names.push(frag.name.clone());
759            }
760            if let Some(a) = addr {
761                push_operand_names(a, &mut names);
762            }
763            if let Some(s) = stride {
764                push_operand_names(s, &mut names);
765            }
766        }
767        Instruction::Mma {
768            a_regs,
769            b_regs,
770            c_regs,
771            ..
772        } => {
773            for r in a_regs.iter().chain(b_regs).chain(c_regs) {
774                names.push(r.name.clone());
775            }
776        }
777        Instruction::Wgmma { desc_a, desc_b, .. } => {
778            names.push(desc_a.name.clone());
779            names.push(desc_b.name.clone());
780        }
781        Instruction::TmaLoad {
782            desc,
783            coords,
784            barrier,
785            dst_shared,
786            ..
787        } => {
788            names.push(desc.name.clone());
789            for c in coords {
790                names.push(c.name.clone());
791            }
792            names.push(barrier.name.clone());
793            push_operand_names(dst_shared, &mut names);
794        }
795        Instruction::Stmatrix { dst_addr, src, .. } => {
796            push_operand_names(dst_addr, &mut names);
797            names.push(src.name.clone());
798        }
799        Instruction::MbarrierInit { addr, count } => {
800            push_operand_names(addr, &mut names);
801            push_operand_names(count, &mut names);
802        }
803        Instruction::MbarrierWait { addr, phase } => {
804            push_operand_names(addr, &mut names);
805            push_operand_names(phase, &mut names);
806        }
807        Instruction::MovSpecial { .. }
808        | Instruction::LoadParam { .. }
809        | Instruction::Label(_)
810        | Instruction::Return
811        | Instruction::Comment(_)
812        | Instruction::Raw(_)
813        | Instruction::Pragma(_)
814        | Instruction::BarSync { .. }
815        | Instruction::BarArrive { .. }
816        | Instruction::FenceAcqRel { .. }
817        | Instruction::FenceProxy { .. }
818        | Instruction::CpAsyncCommit
819        | Instruction::CpAsyncWait { .. }
820        | Instruction::ElectSync { .. }
821        | Instruction::Setmaxnreg { .. }
822        | Instruction::Griddepcontrol { .. }
823        | Instruction::BarrierCluster
824        | Instruction::FenceCluster => {}
825
826        Instruction::Tcgen05Mma { a_desc, b_desc } => {
827            names.push(a_desc.name.clone());
828            names.push(b_desc.name.clone());
829        }
830        Instruction::CpAsyncBulk {
831            dst_smem,
832            src_gmem,
833            desc,
834        } => {
835            names.push(dst_smem.name.clone());
836            names.push(src_gmem.name.clone());
837            names.push(desc.name.clone());
838        }
839        Instruction::Ldmatrix { src_addr, .. } => {
840            push_operand_names(src_addr, &mut names);
841        }
842    }
843    names
844}
845
846/// Returns the destination register name defined by an instruction, if any.
847fn dst_register_name(inst: &Instruction) -> Option<String> {
848    match inst {
849        Instruction::Add { dst, .. }
850        | Instruction::Sub { dst, .. }
851        | Instruction::Mul { dst, .. }
852        | Instruction::Min { dst, .. }
853        | Instruction::Max { dst, .. }
854        | Instruction::Div { dst, .. }
855        | Instruction::Rem { dst, .. }
856        | Instruction::And { dst, .. }
857        | Instruction::Or { dst, .. }
858        | Instruction::Xor { dst, .. }
859        | Instruction::SetP { dst, .. }
860        | Instruction::Mad { dst, .. }
861        | Instruction::MadLo { dst, .. }
862        | Instruction::MadHi { dst, .. }
863        | Instruction::MadWide { dst, .. }
864        | Instruction::Fma { dst, .. }
865        | Instruction::Neg { dst, .. }
866        | Instruction::Abs { dst, .. }
867        | Instruction::Brev { dst, .. }
868        | Instruction::Clz { dst, .. }
869        | Instruction::Popc { dst, .. }
870        | Instruction::Bfind { dst, .. }
871        | Instruction::Bfe { dst, .. }
872        | Instruction::Bfi { dst, .. }
873        | Instruction::Rcp { dst, .. }
874        | Instruction::Rsqrt { dst, .. }
875        | Instruction::Sqrt { dst, .. }
876        | Instruction::Ex2 { dst, .. }
877        | Instruction::Lg2 { dst, .. }
878        | Instruction::Sin { dst, .. }
879        | Instruction::Cos { dst, .. }
880        | Instruction::Shl { dst, .. }
881        | Instruction::Shr { dst, .. }
882        | Instruction::Load { dst, .. }
883        | Instruction::Cvt { dst, .. }
884        | Instruction::Atom { dst, .. }
885        | Instruction::AtomCas { dst, .. }
886        | Instruction::MovSpecial { dst, .. }
887        | Instruction::LoadParam { dst, .. }
888        | Instruction::Dp4a { dst, .. }
889        | Instruction::Dp2a { dst, .. }
890        | Instruction::Tex1d { dst, .. }
891        | Instruction::Tex2d { dst, .. }
892        | Instruction::Tex3d { dst, .. }
893        | Instruction::SurfLoad { dst, .. }
894        | Instruction::Redux { dst, .. }
895        | Instruction::ElectSync { dst, .. } => Some(dst.name.clone()),
896        Instruction::Mma { d_regs, .. } => d_regs.first().map(|r| r.name.clone()),
897        Instruction::Wgmma { d_regs, .. } => d_regs.first().map(|r| r.name.clone()),
898        _ => None,
899    }
900}
901
902// ---------------------------------------------------------------------------
903// Type compatibility check
904// ---------------------------------------------------------------------------
905
906/// Returns `true` if an operand's register type is compatible with the
907/// instruction type for simple arithmetic.
908fn operand_type_compatible(op: &Operand, expected_ty: crate::ir::PtxType) -> bool {
909    match op {
910        Operand::Register(r) => r.ty == expected_ty,
911        // Immediates, symbols, and address operands are always considered compatible at IR level
912        Operand::Immediate(_) | Operand::Symbol(_) | Operand::Address { .. } => true,
913    }
914}
915
916// ---------------------------------------------------------------------------
917// Public IR validation functions
918// ---------------------------------------------------------------------------
919
920/// Validate an IR instruction sequence for type safety and correctness.
921///
922/// Performs:
923/// 1. Type checking on arithmetic instructions
924/// 2. Use-before-def analysis on registers
925/// 3. Memory space validation for load/store/cp.async
926/// 4. Operand validation for tensor core instructions
927#[must_use]
928pub fn validate_ir_instructions(instructions: &[Instruction]) -> IrValidationResult {
929    let mut result = IrValidationResult {
930        errors: Vec::new(),
931        warnings: Vec::new(),
932    };
933
934    // Run sub-validations and merge
935    let lifetime_result = validate_register_lifetimes(instructions);
936    result.merge(&lifetime_result);
937
938    let consistency_result = validate_memory_consistency(instructions);
939    result.merge(&consistency_result);
940
941    // Type checking for arithmetic instructions
942    for (idx, inst) in instructions.iter().enumerate() {
943        validate_type_safety(inst, idx, &mut result);
944        validate_memory_spaces(inst, idx, &mut result);
945        validate_tensor_core_operands(inst, idx, &mut result);
946    }
947
948    result
949}
950
951/// Validate register lifetimes: no use-before-def.
952///
953/// Tracks which registers have been written to (appear as `dst`) and flags
954/// any register used as a source before it has been defined. Registers
955/// defined by `LoadParam` and `MovSpecial` are counted as definitions.
956#[must_use]
957pub fn validate_register_lifetimes(instructions: &[Instruction]) -> IrValidationResult {
958    let mut result = IrValidationResult {
959        errors: Vec::new(),
960        warnings: Vec::new(),
961    };
962
963    let mut defined: HashSet<String> = HashSet::new();
964
965    for (idx, inst) in instructions.iter().enumerate() {
966        // First, check sources for use-before-def
967        let src_names = collect_src_register_names(inst);
968        for name in &src_names {
969            if !defined.contains(name) {
970                result.errors.push(IrValidationError {
971                    instruction_index: idx,
972                    kind: IrErrorKind::UseBeforeDef,
973                    message: format!("register {name} used before definition"),
974                });
975            }
976        }
977
978        // Then, record definitions
979        if let Some(dst_name) = dst_register_name(inst) {
980            defined.insert(dst_name);
981        }
982
983        // Multi-register definitions for tensor core
984        match inst {
985            Instruction::Mma { d_regs, .. } | Instruction::Wgmma { d_regs, .. } => {
986                for r in d_regs {
987                    defined.insert(r.name.clone());
988                }
989            }
990            Instruction::Wmma { op, fragments, .. } => {
991                // For WMMA Mma/StoreD, fragments are sources; for Load, they are defs
992                if matches!(op, WmmaOp::LoadA | WmmaOp::LoadB) {
993                    for frag in fragments {
994                        defined.insert(frag.name.clone());
995                    }
996                }
997            }
998            _ => {}
999        }
1000    }
1001
1002    result
1003}
1004
1005/// Validate fence/barrier placement and memory consistency.
1006///
1007/// Checks:
1008/// 1. Barriers potentially inside divergent control flow
1009/// 2. Shared memory stores without a subsequent barrier before shared loads
1010#[must_use]
1011pub fn validate_memory_consistency(instructions: &[Instruction]) -> IrValidationResult {
1012    let mut result = IrValidationResult {
1013        errors: Vec::new(),
1014        warnings: Vec::new(),
1015    };
1016
1017    // Check for barriers after conditional branches (potential divergence)
1018    check_barrier_divergence(instructions, &mut result);
1019
1020    // Check for shared memory race conditions
1021    check_shared_memory_races(instructions, &mut result);
1022
1023    result
1024}
1025
1026// ---------------------------------------------------------------------------
1027// Internal validation helpers
1028// ---------------------------------------------------------------------------
1029
1030/// Type safety checks for arithmetic instructions.
1031fn validate_type_safety(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
1032    match inst {
1033        Instruction::Add { ty, dst, a, b }
1034        | Instruction::Sub { ty, dst, a, b }
1035        | Instruction::Min { ty, dst, a, b }
1036        | Instruction::Max { ty, dst, a, b } => {
1037            if dst.ty != *ty {
1038                result.errors.push(IrValidationError {
1039                    instruction_index: idx,
1040                    kind: IrErrorKind::TypeMismatch,
1041                    message: format!(
1042                        "dst register {} has type {:?} but instruction type is {:?}",
1043                        dst.name, dst.ty, ty
1044                    ),
1045                });
1046            }
1047            if !operand_type_compatible(a, *ty) {
1048                result.errors.push(IrValidationError {
1049                    instruction_index: idx,
1050                    kind: IrErrorKind::TypeMismatch,
1051                    message: format!("operand a type mismatch with instruction type {ty:?}"),
1052                });
1053            }
1054            if !operand_type_compatible(b, *ty) {
1055                result.errors.push(IrValidationError {
1056                    instruction_index: idx,
1057                    kind: IrErrorKind::TypeMismatch,
1058                    message: format!("operand b type mismatch with instruction type {ty:?}"),
1059                });
1060            }
1061        }
1062        Instruction::Mul { ty, dst, a, b, .. } => {
1063            // For mul, dst type depends on mode (wide produces double width)
1064            // but the source operands should match the instruction type
1065            if !operand_type_compatible(a, *ty) {
1066                result.errors.push(IrValidationError {
1067                    instruction_index: idx,
1068                    kind: IrErrorKind::TypeMismatch,
1069                    message: format!("mul operand a type mismatch with instruction type {ty:?}"),
1070                });
1071            }
1072            if !operand_type_compatible(b, *ty) {
1073                result.errors.push(IrValidationError {
1074                    instruction_index: idx,
1075                    kind: IrErrorKind::TypeMismatch,
1076                    message: format!("mul operand b type mismatch with instruction type {ty:?}"),
1077                });
1078            }
1079            // For non-wide modes, dst should match
1080            if dst.ty != *ty {
1081                result.warnings.push(IrValidationWarning {
1082                    instruction_index: idx,
1083                    message: format!(
1084                        "mul dst register {} type {:?} differs from instruction type {:?}",
1085                        dst.name, dst.ty, ty
1086                    ),
1087                });
1088            }
1089        }
1090        _ => {}
1091    }
1092}
1093
1094/// Validate memory spaces for load/store/cp.async instructions.
1095fn validate_memory_spaces(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
1096    if let Instruction::CpAsync {
1097        dst_shared: Operand::Register(r),
1098        ..
1099    } = inst
1100    {
1101        // If someone passes a raw register instead of an address, warn
1102        result.warnings.push(IrValidationWarning {
1103            instruction_index: idx,
1104            message: format!(
1105                "cp.async dst_shared uses register {} directly; expected a shared memory address",
1106                r.name
1107            ),
1108        });
1109    }
1110
1111    // Validate that Load/Store with shared space uses address operands
1112    match inst {
1113        Instruction::Load {
1114            space,
1115            addr: Operand::Immediate(_),
1116            ..
1117        } if *space == MemorySpace::Shared => {
1118            result.errors.push(IrValidationError {
1119                instruction_index: idx,
1120                kind: IrErrorKind::InvalidMemorySpace,
1121                message: "shared memory load with immediate address is invalid".to_string(),
1122            });
1123        }
1124        Instruction::Store {
1125            space,
1126            addr: Operand::Immediate(_),
1127            ..
1128        } if *space == MemorySpace::Shared => {
1129            result.errors.push(IrValidationError {
1130                instruction_index: idx,
1131                kind: IrErrorKind::InvalidMemorySpace,
1132                message: "shared memory store with immediate address is invalid".to_string(),
1133            });
1134        }
1135        _ => {}
1136    }
1137}
1138
1139/// Validate tensor core instruction operands.
1140fn validate_tensor_core_operands(inst: &Instruction, idx: usize, result: &mut IrValidationResult) {
1141    match inst {
1142        Instruction::Wmma { addr, stride, .. } => {
1143            // Check that addr/stride are not immediates (should be registers/addresses)
1144            if let Some(Operand::Immediate(_)) = addr.as_ref() {
1145                result.errors.push(IrValidationError {
1146                    instruction_index: idx,
1147                    kind: IrErrorKind::InvalidOperand,
1148                    message: "wmma address operand must not be an immediate value".to_string(),
1149                });
1150            }
1151            if let Some(Operand::Immediate(_)) = stride.as_ref() {
1152                result.errors.push(IrValidationError {
1153                    instruction_index: idx,
1154                    kind: IrErrorKind::InvalidOperand,
1155                    message: "wmma stride operand must not be an immediate value".to_string(),
1156                });
1157            }
1158        }
1159        Instruction::Mma {
1160            a_regs,
1161            b_regs,
1162            c_regs,
1163            d_regs,
1164            ..
1165        }
1166            // All registers should be non-empty
1167            if (a_regs.is_empty() || b_regs.is_empty() || c_regs.is_empty() || d_regs.is_empty()) => {
1168                result.errors.push(IrValidationError {
1169                    instruction_index: idx,
1170                    kind: IrErrorKind::InvalidOperand,
1171                    message: "mma instruction requires non-empty register fragments".to_string(),
1172                });
1173            }
1174        Instruction::Wgmma { d_regs, .. }
1175            if d_regs.is_empty() => {
1176                result.errors.push(IrValidationError {
1177                    instruction_index: idx,
1178                    kind: IrErrorKind::InvalidOperand,
1179                    message: "wgmma instruction requires non-empty destination registers".to_string(),
1180                });
1181            }
1182        _ => {}
1183    }
1184}
1185
1186/// Check if barrier instructions might be inside divergent control flow.
1187fn check_barrier_divergence(instructions: &[Instruction], result: &mut IrValidationResult) {
1188    // Collect all labels in the program
1189    let all_labels: HashSet<&str> = instructions
1190        .iter()
1191        .filter_map(|inst| {
1192            if let Instruction::Label(name) = inst {
1193                Some(name.as_str())
1194            } else {
1195                None
1196            }
1197        })
1198        .collect();
1199
1200    let mut in_conditional_region = false;
1201    let mut conditional_branch_idx = 0;
1202
1203    for (idx, inst) in instructions.iter().enumerate() {
1204        match inst {
1205            Instruction::Branch {
1206                predicate: Some(_),
1207                target,
1208                ..
1209            }
1210                // A conditional branch that targets a label creates a divergent
1211                // region until we reach that label
1212                if all_labels.contains(target.as_str()) => {
1213                    in_conditional_region = true;
1214                    conditional_branch_idx = idx;
1215                }
1216            Instruction::Label(_) => {
1217                // Reaching a label ends the conditional region
1218                in_conditional_region = false;
1219            }
1220            Instruction::BarSync { .. }
1221                if in_conditional_region => {
1222                    result.warnings.push(IrValidationWarning {
1223                        instruction_index: idx,
1224                        message: format!(
1225                            "bar.sync inside potentially divergent control flow \
1226                             (conditional branch at instruction {conditional_branch_idx}); \
1227                             this may cause deadlock if not all threads reach the barrier"
1228                        ),
1229                    });
1230                }
1231            _ => {}
1232        }
1233    }
1234}
1235
1236/// Check for potential shared memory race conditions.
1237///
1238/// Warns if there are shared memory stores without a subsequent `bar.sync`
1239/// before the next shared memory load.
1240fn check_shared_memory_races(instructions: &[Instruction], result: &mut IrValidationResult) {
1241    let mut pending_shared_store: Option<usize> = None;
1242
1243    for (idx, inst) in instructions.iter().enumerate() {
1244        match inst {
1245            Instruction::Store {
1246                space: MemorySpace::Shared,
1247                ..
1248            } => {
1249                pending_shared_store = Some(idx);
1250            }
1251            Instruction::BarSync { .. } => {
1252                // Barrier clears the pending shared store
1253                pending_shared_store = None;
1254            }
1255            Instruction::Load {
1256                space: MemorySpace::Shared,
1257                ..
1258            } => {
1259                if let Some(store_idx) = pending_shared_store {
1260                    result.warnings.push(IrValidationWarning {
1261                        instruction_index: idx,
1262                        message: format!(
1263                            "shared memory load without bar.sync after shared memory \
1264                             store at instruction {store_idx}; potential race condition"
1265                        ),
1266                    });
1267                }
1268            }
1269            _ => {}
1270        }
1271    }
1272}
1273
1274#[cfg(test)]
1275mod tests {
1276    use super::*;
1277    use crate::ir::{
1278        CacheQualifier, ImmValue, Instruction, MemorySpace, Operand, PtxType, Register, SpecialReg,
1279        VectorWidth, WmmaLayout, WmmaOp, WmmaShape,
1280    };
1281
1282    #[test]
1283    fn valid_minimal_ptx() {
1284        let ptx = ".version 8.5\n.target sm_90a\n.address_size 64\n";
1285        let result = validate_ptx(ptx);
1286        assert!(result.is_ok());
1287        assert!(result.errors.is_empty());
1288    }
1289
1290    #[test]
1291    fn missing_version() {
1292        let ptx = ".target sm_80\n.address_size 64\n";
1293        let result = validate_ptx(ptx);
1294        assert!(result.has_errors());
1295        assert!(
1296            result
1297                .errors
1298                .iter()
1299                .any(|e| matches!(e, ValidationError::MissingVersionDirective))
1300        );
1301    }
1302
1303    #[test]
1304    fn missing_target() {
1305        let ptx = ".version 8.5\n.address_size 64\n";
1306        let result = validate_ptx(ptx);
1307        assert!(result.has_errors());
1308        assert!(
1309            result
1310                .errors
1311                .iter()
1312                .any(|e| matches!(e, ValidationError::MissingTargetDirective))
1313        );
1314    }
1315
1316    #[test]
1317    fn shared_memory_within_limits() {
1318        let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n\
1319                    .shared .align 4 .b8 smem[4096];\n";
1320        let result = validate_ptx(ptx);
1321        assert!(result.is_ok());
1322    }
1323
1324    #[test]
1325    fn shared_memory_exceeds_limits() {
1326        // sm_75 max is 65536
1327        let ptx = ".version 6.4\n.target sm_75\n.address_size 64\n\
1328                    .shared .align 4 .b8 smem[100000];\n";
1329        let result = validate_ptx(ptx);
1330        assert!(result.has_errors());
1331        assert!(
1332            result
1333                .errors
1334                .iter()
1335                .any(|e| matches!(e, ValidationError::InvalidSharedMemSize { .. }))
1336        );
1337    }
1338
1339    #[test]
1340    fn validate_for_specific_target() {
1341        let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n\
1342                    .shared .align 4 .b8 smem[200000];\n";
1343        let result = validate_ptx_for_target(ptx, SmVersion::Sm80);
1344        // 200000 > 163840 (sm_80 limit)
1345        assert!(result.has_errors());
1346    }
1347
1348    #[test]
1349    fn extract_shared_mem_size_fn() {
1350        assert_eq!(
1351            extract_shared_mem_size("    .shared .align 4 .b8 smem[4096];"),
1352            Some(4096)
1353        );
1354        assert_eq!(
1355            extract_shared_mem_size("    .shared .align 16 .b8 tile[65536];"),
1356            Some(65536)
1357        );
1358        assert_eq!(extract_shared_mem_size("    mov.u32 %r0, 0;"), None);
1359    }
1360
1361    #[test]
1362    fn parse_sm_version_fn() {
1363        assert_eq!(parse_sm_version("sm_80"), Some(SmVersion::Sm80));
1364        assert_eq!(parse_sm_version("sm_90a"), Some(SmVersion::Sm90a));
1365        assert_eq!(parse_sm_version("sm_100"), Some(SmVersion::Sm100));
1366        assert_eq!(parse_sm_version("sm_999"), None);
1367    }
1368
1369    #[test]
1370    fn mismatched_braces_warning() {
1371        let ptx = ".version 8.5\n.target sm_80\n.address_size 64\n{\n";
1372        let result = validate_ptx(ptx);
1373        assert!(!result.warnings.is_empty());
1374    }
1375
1376    #[test]
1377    fn validation_error_display() {
1378        let err = ValidationError::MissingVersionDirective;
1379        assert_eq!(format!("{err}"), "missing .version directive");
1380
1381        let err = ValidationError::InvalidSharedMemSize {
1382            declared: 100_000,
1383            max_allowed: 65536,
1384        };
1385        assert!(format!("{err}").contains("100000"));
1386    }
1387
1388    // -----------------------------------------------------------------------
1389    // IR-level validation tests
1390    // -----------------------------------------------------------------------
1391
1392    fn reg(name: &str, ty: PtxType) -> Register {
1393        Register {
1394            name: name.to_string(),
1395            ty,
1396        }
1397    }
1398
1399    fn reg_op(name: &str, ty: PtxType) -> Operand {
1400        Operand::Register(reg(name, ty))
1401    }
1402
1403    #[test]
1404    fn ir_type_compatible_arithmetic_passes() {
1405        let instructions = vec![
1406            Instruction::LoadParam {
1407                ty: PtxType::F32,
1408                dst: reg("%f0", PtxType::F32),
1409                param_name: "a".to_string(),
1410            },
1411            Instruction::LoadParam {
1412                ty: PtxType::F32,
1413                dst: reg("%f1", PtxType::F32),
1414                param_name: "b".to_string(),
1415            },
1416            Instruction::Add {
1417                ty: PtxType::F32,
1418                dst: reg("%f2", PtxType::F32),
1419                a: reg_op("%f0", PtxType::F32),
1420                b: reg_op("%f1", PtxType::F32),
1421            },
1422        ];
1423        let result = validate_ir_instructions(&instructions);
1424        assert!(
1425            result.errors.is_empty(),
1426            "expected no errors, got: {:?}",
1427            result.errors
1428        );
1429    }
1430
1431    #[test]
1432    fn ir_type_mismatched_arithmetic_fails() {
1433        let instructions = vec![
1434            Instruction::LoadParam {
1435                ty: PtxType::F32,
1436                dst: reg("%f0", PtxType::F32),
1437                param_name: "a".to_string(),
1438            },
1439            Instruction::LoadParam {
1440                ty: PtxType::U32,
1441                dst: reg("%r0", PtxType::U32),
1442                param_name: "b".to_string(),
1443            },
1444            Instruction::Add {
1445                ty: PtxType::F32,
1446                dst: reg("%f1", PtxType::F32),
1447                a: reg_op("%f0", PtxType::F32),
1448                b: reg_op("%r0", PtxType::U32), // mismatch
1449            },
1450        ];
1451        let result = validate_ir_instructions(&instructions);
1452        assert!(result.has_errors());
1453        assert!(
1454            result
1455                .errors
1456                .iter()
1457                .any(|e| e.kind == IrErrorKind::TypeMismatch)
1458        );
1459    }
1460
1461    #[test]
1462    fn ir_use_before_def_detection() {
1463        let instructions = vec![Instruction::Add {
1464            ty: PtxType::F32,
1465            dst: reg("%f2", PtxType::F32),
1466            a: reg_op("%f0", PtxType::F32), // never defined
1467            b: reg_op("%f1", PtxType::F32), // never defined
1468        }];
1469        let result = validate_ir_instructions(&instructions);
1470        assert!(result.has_errors());
1471        let ubd_count = result
1472            .errors
1473            .iter()
1474            .filter(|e| e.kind == IrErrorKind::UseBeforeDef)
1475            .count();
1476        assert!(ubd_count >= 2, "expected at least 2 use-before-def errors");
1477    }
1478
1479    #[test]
1480    fn ir_load_param_counted_as_definition() {
1481        let instructions = vec![
1482            Instruction::LoadParam {
1483                ty: PtxType::U64,
1484                dst: reg("%rd0", PtxType::U64),
1485                param_name: "ptr".to_string(),
1486            },
1487            Instruction::Load {
1488                space: MemorySpace::Global,
1489                qualifier: CacheQualifier::None,
1490                vec: VectorWidth::V1,
1491                ty: PtxType::F32,
1492                dst: reg("%f0", PtxType::F32),
1493                addr: Operand::Address {
1494                    base: reg("%rd0", PtxType::U64),
1495                    offset: None,
1496                },
1497            },
1498        ];
1499        let result = validate_register_lifetimes(&instructions);
1500        assert!(
1501            result.errors.is_empty(),
1502            "LoadParam should count as definition: {:?}",
1503            result.errors
1504        );
1505    }
1506
1507    #[test]
1508    fn ir_mov_special_counted_as_definition() {
1509        let instructions = vec![
1510            Instruction::MovSpecial {
1511                dst: reg("%r0", PtxType::U32),
1512                special: SpecialReg::TidX,
1513            },
1514            Instruction::Add {
1515                ty: PtxType::U32,
1516                dst: reg("%r1", PtxType::U32),
1517                a: reg_op("%r0", PtxType::U32),
1518                b: Operand::Immediate(ImmValue::U32(1)),
1519            },
1520        ];
1521        let result = validate_register_lifetimes(&instructions);
1522        assert!(
1523            result.errors.is_empty(),
1524            "MovSpecial should count as definition: {:?}",
1525            result.errors
1526        );
1527    }
1528
1529    #[test]
1530    fn ir_shared_store_without_barrier_warns() {
1531        let addr_reg = reg("%rd0", PtxType::U64);
1532        let instructions = vec![
1533            Instruction::LoadParam {
1534                ty: PtxType::U64,
1535                dst: addr_reg.clone(),
1536                param_name: "addr".to_string(),
1537            },
1538            Instruction::LoadParam {
1539                ty: PtxType::F32,
1540                dst: reg("%f0", PtxType::F32),
1541                param_name: "val".to_string(),
1542            },
1543            Instruction::Store {
1544                space: MemorySpace::Shared,
1545                qualifier: CacheQualifier::None,
1546                vec: VectorWidth::V1,
1547                ty: PtxType::F32,
1548                addr: Operand::Address {
1549                    base: addr_reg.clone(),
1550                    offset: None,
1551                },
1552                src: reg("%f0", PtxType::F32),
1553            },
1554            // No bar.sync here!
1555            Instruction::Load {
1556                space: MemorySpace::Shared,
1557                qualifier: CacheQualifier::None,
1558                vec: VectorWidth::V1,
1559                ty: PtxType::F32,
1560                dst: reg("%f1", PtxType::F32),
1561                addr: Operand::Address {
1562                    base: addr_reg,
1563                    offset: Some(4),
1564                },
1565            },
1566        ];
1567        let result = validate_memory_consistency(&instructions);
1568        assert!(
1569            !result.warnings.is_empty(),
1570            "expected race condition warning"
1571        );
1572        assert!(
1573            result.warnings[0].message.contains("race condition"),
1574            "warning should mention race condition"
1575        );
1576    }
1577
1578    #[test]
1579    fn ir_barrier_after_shared_store_no_warning() {
1580        let addr_reg = reg("%rd0", PtxType::U64);
1581        let instructions = vec![
1582            Instruction::Store {
1583                space: MemorySpace::Shared,
1584                qualifier: CacheQualifier::None,
1585                vec: VectorWidth::V1,
1586                ty: PtxType::F32,
1587                addr: Operand::Address {
1588                    base: addr_reg.clone(),
1589                    offset: None,
1590                },
1591                src: reg("%f0", PtxType::F32),
1592            },
1593            Instruction::BarSync { id: 0 },
1594            Instruction::Load {
1595                space: MemorySpace::Shared,
1596                qualifier: CacheQualifier::None,
1597                vec: VectorWidth::V1,
1598                ty: PtxType::F32,
1599                dst: reg("%f1", PtxType::F32),
1600                addr: Operand::Address {
1601                    base: addr_reg,
1602                    offset: Some(4),
1603                },
1604            },
1605        ];
1606        let result = validate_memory_consistency(&instructions);
1607        assert!(
1608            result.warnings.is_empty(),
1609            "expected no warnings when barrier separates store/load"
1610        );
1611    }
1612
1613    #[test]
1614    fn ir_empty_instruction_list_no_errors() {
1615        let result = validate_ir_instructions(&[]);
1616        assert!(result.is_ok());
1617        assert!(result.warnings.is_empty());
1618    }
1619
1620    #[test]
1621    fn ir_complex_sequence_multiple_issues() {
1622        let instructions = vec![
1623            // Use-before-def: %f0 never defined
1624            Instruction::Add {
1625                ty: PtxType::F32,
1626                dst: reg("%f1", PtxType::F32),
1627                a: reg_op("%f0", PtxType::F32),
1628                b: Operand::Immediate(ImmValue::F32(1.0)),
1629            },
1630            // Type mismatch: dst is U32 but instruction type is F32
1631            Instruction::Sub {
1632                ty: PtxType::F32,
1633                dst: reg("%r0", PtxType::U32),
1634                a: reg_op("%f1", PtxType::F32),
1635                b: Operand::Immediate(ImmValue::F32(2.0)),
1636            },
1637        ];
1638        let result = validate_ir_instructions(&instructions);
1639        assert!(result.has_errors());
1640
1641        let has_ubd = result
1642            .errors
1643            .iter()
1644            .any(|e| e.kind == IrErrorKind::UseBeforeDef);
1645        let has_type_mismatch = result
1646            .errors
1647            .iter()
1648            .any(|e| e.kind == IrErrorKind::TypeMismatch);
1649        assert!(has_ubd, "expected use-before-def error");
1650        assert!(has_type_mismatch, "expected type mismatch error");
1651    }
1652
1653    #[test]
1654    fn ir_validate_register_lifetimes_standalone() {
1655        let instructions = vec![
1656            Instruction::LoadParam {
1657                ty: PtxType::F32,
1658                dst: reg("%f0", PtxType::F32),
1659                param_name: "x".to_string(),
1660            },
1661            Instruction::Neg {
1662                ty: PtxType::F32,
1663                dst: reg("%f1", PtxType::F32),
1664                src: reg_op("%f0", PtxType::F32),
1665            },
1666            // %f99 never defined
1667            Instruction::Add {
1668                ty: PtxType::F32,
1669                dst: reg("%f2", PtxType::F32),
1670                a: reg_op("%f1", PtxType::F32),
1671                b: reg_op("%f99", PtxType::F32),
1672            },
1673        ];
1674        let result = validate_register_lifetimes(&instructions);
1675        assert!(result.has_errors());
1676        assert_eq!(result.errors.len(), 1);
1677        assert!(result.errors[0].message.contains("%f99"));
1678    }
1679
/// Runs `validate_memory_consistency` by itself on a predicated branch that
/// is immediately followed by a `BarSync`, and expects the first warning to
/// mention divergence.
#[test]
fn ir_validate_memory_consistency_standalone() {
    // NOTE(review): the load's `ty` is U32 while `dst` is a predicate
    // register — presumably irrelevant to this check; confirm.
    let load_pred = Instruction::LoadParam {
        ty: PtxType::U32,
        dst: reg("%p0", PtxType::Pred),
        param_name: "pred".to_string(),
    };
    // Conditional jump to `skip`, guarded by %p0.
    let guarded_branch = Instruction::Branch {
        target: "skip".to_string(),
        predicate: Some((reg("%p0", PtxType::Pred), false)),
    };
    // The barrier sits between the conditional branch and its target label.
    let barrier = Instruction::BarSync { id: 0 };
    let skip_label = Instruction::Label("skip".to_string());
    let program = vec![load_pred, guarded_branch, barrier, skip_label];

    let report = validate_memory_consistency(&program);

    assert!(!report.warnings.is_empty(), "expected divergence warning");
    let first_warning = &report.warnings[0];
    assert!(first_warning.message.contains("divergent"));
}
1700
/// Checks the `Display` output of `IrValidationResult`: a populated result
/// must show the error and warning sections with their counts and contents,
/// and an empty result must render text containing "passed".
#[test]
fn ir_validation_result_display() {
    let type_error = IrValidationError {
        instruction_index: 3,
        kind: IrErrorKind::TypeMismatch,
        message: "dst type does not match".to_string(),
    };
    let race_warning = IrValidationWarning {
        instruction_index: 7,
        message: "possible race".to_string(),
    };
    let populated = IrValidationResult {
        errors: vec![type_error],
        warnings: vec![race_warning],
    };

    let rendered = populated.to_string();
    assert!(rendered.contains("Errors (1)"));
    assert!(rendered.contains("TypeMismatch"));
    assert!(rendered.contains("Warnings (1)"));
    assert!(rendered.contains("possible race"));

    // The all-clear case: nothing to report.
    let empty = IrValidationResult {
        errors: vec![],
        warnings: vec![],
    };
    let rendered_empty = empty.to_string();
    assert!(rendered_empty.contains("passed"));
}
1728
/// A `Wmma` instruction whose `addr` and `stride` are both immediates must
/// produce at least two `InvalidOperand` errors from
/// `validate_ir_instructions`.
#[test]
fn ir_wmma_with_immediate_operand_flagged() {
    // Both optional operands are deliberately immediates, which is invalid.
    let bad_addr = Operand::Immediate(ImmValue::U32(0));
    let bad_stride = Operand::Immediate(ImmValue::U32(16));
    let wmma_load = Instruction::Wmma {
        op: WmmaOp::LoadA,
        shape: WmmaShape::M16N16K16,
        layout: WmmaLayout::RowMajor,
        ty: PtxType::F16,
        fragments: vec![reg("%f0", PtxType::F16)],
        addr: Some(bad_addr),
        stride: Some(bad_stride),
    };
    let program = vec![wmma_load];

    let report = validate_ir_instructions(&program);
    let invalid_operand_count = report
        .errors
        .iter()
        .filter(|err| err.kind == IrErrorKind::InvalidOperand)
        .count();

    assert!(
        invalid_operand_count >= 2,
        "expected at least 2 InvalidOperand errors for wmma immediates, got {}",
        invalid_operand_count
    );
}
1752
/// Feeds `validate_ir_instructions` a sequence in which only the `Sub`
/// instruction is malformed (its destination register type differs from the
/// instruction type) and expects exactly one `TypeMismatch` error and no
/// `UseBeforeDef` errors.
#[test]
fn ir_mixed_valid_and_invalid_instructions() {
    // Valid: parameter load defining %f0.
    let load_x = Instruction::LoadParam {
        ty: PtxType::F32,
        dst: reg("%f0", PtxType::F32),
        param_name: "x".to_string(),
    };
    // Valid: special-register move defining %r0.
    let read_tid = Instruction::MovSpecial {
        dst: reg("%r0", PtxType::U32),
        special: SpecialReg::TidX,
    };
    // Valid: add whose operand and destination types agree.
    let add_one = Instruction::Add {
        ty: PtxType::F32,
        dst: reg("%f1", PtxType::F32),
        a: reg_op("%f0", PtxType::F32),
        b: Operand::Immediate(ImmValue::F32(1.0)),
    };
    // Invalid: F32 subtraction writing into a U32 destination register.
    let sub_bad_dst = Instruction::Sub {
        ty: PtxType::F32,
        dst: reg("%bad", PtxType::U32),
        a: reg_op("%f1", PtxType::F32),
        b: Operand::Immediate(ImmValue::F32(0.5)),
    };
    // Valid: a comment and a return carry nothing to validate.
    let note = Instruction::Comment("test".to_string());
    let ret = Instruction::Return;

    let program = vec![load_x, read_tid, add_one, sub_bad_dst, note, ret];
    let report = validate_ir_instructions(&program);

    // Collects references to every reported error of the requested kind.
    let errors_of_kind = |kind: IrErrorKind| {
        report
            .errors
            .iter()
            .filter(|err| err.kind == kind)
            .collect::<Vec<_>>()
    };

    // Only the Sub destination mismatch should be reported as a type error.
    let type_errors = errors_of_kind(IrErrorKind::TypeMismatch);
    assert_eq!(
        type_errors.len(),
        1,
        "expected exactly 1 type mismatch, got {}: {:?}",
        type_errors.len(),
        type_errors
    );

    // Every register that is read was defined earlier in the sequence.
    let ubd_errors = errors_of_kind(IrErrorKind::UseBeforeDef);
    assert!(
        ubd_errors.is_empty(),
        "expected no use-before-def errors: {ubd_errors:?}",
    );
}
1811}
1812
/// SM compatibility, register pressure, and related tests live in
/// `validator_tests.rs` to keep this file under the 2,000-line limit.
1815#[cfg(test)]
1816#[path = "validator_tests.rs"]
1817mod sm_tests;