use super::Severity;
use crate::parser::SourceLocation;
/// Known classes of PTX / `ptxas` bug patterns that can be reported.
///
/// Each variant carries fixed metadata exposed through [`BugClass::severity`],
/// [`BugClass::description`] and [`BugClass::mitigation`]; its `Display` form
/// is the bare variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BugClass {
/// `cvta.shared` creates a 64-bit generic address that SASS clobbers.
GenericAddressCorruption,
/// Shared-memory addresses held in `u64` instead of a 32-bit offset.
SharedMemU64Addressing,
/// Generic `ld`/`st` used where `ld.shared`/`st.shared` should be.
MissingDirectShared,
/// No `bar.sync` between shared-memory writes and the reads that follow.
MissingBarrierSync,
/// Register type does not match the operation (e.g. `f32` vs `u32`).
RegisterTypeInvariant,
/// Global/shared memory access that is not aligned.
UnalignedMemoryAccess,
/// Store of a value derived from `ld.shared`, which crashes.
DataDependentStore,
/// Store to an address computed from an `ld.shared` value, which crashes.
ComputedAddrFromLoaded,
/// Adding a single instruction triggers a crash (ptxas JIT bug).
SequentialCodeSensitivity,
/// `cvta.shared` inside a loop, causing register-pressure issues.
LoopCvtaShared,
/// Mismatched address-space qualifiers.
IncompatibleAddressSpace,
}
impl std::fmt::Display for BugClass {
    /// Writes the variant name (the same spelling as the derived `Debug`).
    ///
    /// The name goes through [`std::fmt::Formatter::pad`] rather than
    /// `write!`, so width, fill and alignment flags (e.g. `{:<28}`) are
    /// honoured when bug classes are laid out in report columns. Without
    /// flags the output is exactly the variant name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            BugClass::GenericAddressCorruption => "GenericAddressCorruption",
            BugClass::SharedMemU64Addressing => "SharedMemU64Addressing",
            BugClass::MissingDirectShared => "MissingDirectShared",
            BugClass::MissingBarrierSync => "MissingBarrierSync",
            BugClass::RegisterTypeInvariant => "RegisterTypeInvariant",
            BugClass::UnalignedMemoryAccess => "UnalignedMemoryAccess",
            BugClass::DataDependentStore => "DataDependentStore",
            BugClass::ComputedAddrFromLoaded => "ComputedAddrFromLoaded",
            BugClass::SequentialCodeSensitivity => "SequentialCodeSensitivity",
            BugClass::LoopCvtaShared => "LoopCvtaShared",
            BugClass::IncompatibleAddressSpace => "IncompatibleAddressSpace",
        };
        f.pad(name)
    }
}
impl BugClass {
pub fn severity(&self) -> Severity {
match self {
BugClass::GenericAddressCorruption => Severity::Critical,
BugClass::SharedMemU64Addressing => Severity::High,
BugClass::MissingDirectShared => Severity::High,
BugClass::MissingBarrierSync => Severity::High,
BugClass::RegisterTypeInvariant => Severity::Medium,
BugClass::UnalignedMemoryAccess => Severity::High,
BugClass::DataDependentStore => Severity::Critical,
BugClass::ComputedAddrFromLoaded => Severity::Critical,
BugClass::SequentialCodeSensitivity => Severity::High,
BugClass::LoopCvtaShared => Severity::High,
BugClass::IncompatibleAddressSpace => Severity::Medium,
}
}
pub fn description(&self) -> &'static str {
match self {
BugClass::GenericAddressCorruption => {
"cvta.shared creates 64-bit generic address that SASS clobbers"
}
BugClass::SharedMemU64Addressing => {
"Using u64 for shared memory addresses (should use 32-bit offset)"
}
BugClass::MissingDirectShared => "Using generic ld/st instead of ld.shared/st.shared",
BugClass::MissingBarrierSync => {
"Missing bar.sync between shared memory writes and reads"
}
BugClass::RegisterTypeInvariant => {
"Wrong register type for operation (e.g., f32 vs u32)"
}
BugClass::UnalignedMemoryAccess => "Non-aligned global/shared memory access",
BugClass::DataDependentStore => "Store using value derived from ld.shared crashes",
BugClass::ComputedAddrFromLoaded => {
"Address computed from ld.shared value causes store crash"
}
BugClass::SequentialCodeSensitivity => {
"Adding one instruction causes crash (ptxas JIT bug)"
}
BugClass::LoopCvtaShared => "cvta.shared inside loop causes register pressure issues",
BugClass::IncompatibleAddressSpace => "Mismatched address space qualifiers",
}
}
pub fn mitigation(&self) -> &'static str {
match self {
BugClass::GenericAddressCorruption => {
"Use direct shared memory addressing with 32-bit offsets"
}
BugClass::SharedMemU64Addressing => "Use 32-bit offset for shared memory addressing",
BugClass::MissingDirectShared => "Use ld.shared/st.shared with 32-bit offset instead",
BugClass::MissingBarrierSync => "Add bar.sync between shared memory write and read",
BugClass::RegisterTypeInvariant => {
"Ensure register type matches instruction type modifier"
}
BugClass::UnalignedMemoryAccess => "Ensure memory addresses are aligned to access size",
BugClass::DataDependentStore => "Use constant value or pre-computed address",
BugClass::ComputedAddrFromLoaded => {
"Use constant-only address computation, try membar.cta (partial), or Kernel Fission"
}
BugClass::SequentialCodeSensitivity => {
"Split kernel into simpler kernels (Kernel Fission)"
}
BugClass::LoopCvtaShared => "Move cvta.shared outside loop",
BugClass::IncompatibleAddressSpace => {
"Use explicit address space qualifiers consistently"
}
}
}
}
/// Pattern-specific evidence attached to a [`Bug`], recording the related
/// source locations for the variants that have them.
#[derive(Debug, Clone)]
pub enum BugPattern {
/// Shared memory accessed through a generic address (inferred from the
/// name — confirm against the detector). Carries no location data.
GenericSharedAccess,
/// A write followed by a read with no barrier between them.
MissingBarrier {
/// Location of the write.
write_loc: SourceLocation,
/// Location of the read.
read_loc: SourceLocation,
},
/// A store whose value comes from an earlier load (paired with
/// `BugClass::DataDependentStore` in this file's tests).
LoadedValueStore {
/// Location of the load that produced the value.
load_loc: SourceLocation,
/// Location of the store that uses it.
store_loc: SourceLocation,
},
/// An address computed from a previously loaded value.
ComputedAddrFromLoaded {
/// Location of the load.
load_loc: SourceLocation,
/// Location where the address is computed.
compute_loc: SourceLocation,
},
/// Any other pattern, described as free-form text.
Other(String),
}
/// A single reported bug occurrence.
#[derive(Debug, Clone)]
pub struct Bug {
/// Category of the bug; its severity is `class.severity()`.
pub class: BugClass,
/// Primary source location reported for the bug.
pub location: SourceLocation,
/// Pattern-specific detail (e.g. the related load/store locations).
pub pattern: BugPattern,
/// Free-form message describing this occurrence.
pub message: String,
}
/// Accumulates reported [`Bug`]s in insertion order.
#[derive(Debug, Default)]
pub struct BugRegistry {
/// Bugs in the order they were passed to `add`.
bugs: Vec<Bug>,
}
impl BugRegistry {
pub fn new() -> Self {
Self { bugs: Vec::new() }
}
pub fn add(&mut self, bug: Bug) {
self.bugs.push(bug);
}
pub fn bugs(&self) -> &[Bug] {
&self.bugs
}
pub fn critical_bugs(&self) -> Vec<&Bug> {
self.bugs
.iter()
.filter(|b| b.class.severity() == Severity::Critical)
.collect()
}
pub fn has_critical_bugs(&self) -> bool {
self.bugs
.iter()
.any(|b| b.class.severity() == Severity::Critical)
}
pub fn clear(&mut self) {
self.bugs.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `BugClass` variant, for tests that must cover the whole enum.
    const ALL_CLASSES: [BugClass; 11] = [
        BugClass::GenericAddressCorruption,
        BugClass::SharedMemU64Addressing,
        BugClass::MissingDirectShared,
        BugClass::MissingBarrierSync,
        BugClass::RegisterTypeInvariant,
        BugClass::UnalignedMemoryAccess,
        BugClass::DataDependentStore,
        BugClass::ComputedAddrFromLoaded,
        BugClass::SequentialCodeSensitivity,
        BugClass::LoopCvtaShared,
        BugClass::IncompatibleAddressSpace,
    ];

    /// Builds a bug of `class` with placeholder location and pattern data.
    fn bug_of(class: BugClass) -> Bug {
        Bug {
            class,
            location: SourceLocation::default(),
            pattern: BugPattern::Other("test".into()),
            message: class.to_string(),
        }
    }

    #[test]
    fn test_bug_class_severity() {
        assert_eq!(
            BugClass::GenericAddressCorruption.severity(),
            Severity::Critical
        );
        assert_eq!(BugClass::DataDependentStore.severity(), Severity::Critical);
        assert_eq!(
            BugClass::ComputedAddrFromLoaded.severity(),
            Severity::Critical
        );
        assert_eq!(BugClass::MissingBarrierSync.severity(), Severity::High);
    }

    #[test]
    fn test_display_matches_variant_name() {
        for class in ALL_CLASSES.iter() {
            assert_eq!(class.to_string(), format!("{:?}", class));
        }
    }

    #[test]
    fn test_every_class_has_description_and_mitigation() {
        for class in ALL_CLASSES.iter() {
            assert!(
                !class.description().is_empty(),
                "{:?} has an empty description",
                class
            );
            assert!(
                !class.mitigation().is_empty(),
                "{:?} has an empty mitigation",
                class
            );
        }
    }

    #[test]
    fn test_bug_registry() {
        let mut registry = BugRegistry::new();
        assert!(registry.bugs().is_empty());
        registry.add(Bug {
            class: BugClass::DataDependentStore,
            location: SourceLocation::default(),
            pattern: BugPattern::LoadedValueStore {
                load_loc: SourceLocation::default(),
                store_loc: SourceLocation::default(),
            },
            message: "Test bug".into(),
        });
        assert_eq!(registry.bugs().len(), 1);
        assert!(registry.has_critical_bugs());
    }

    #[test]
    fn test_critical_bugs_filters_by_severity() {
        let mut registry = BugRegistry::new();
        registry.add(bug_of(BugClass::MissingBarrierSync));
        registry.add(bug_of(BugClass::GenericAddressCorruption));
        registry.add(bug_of(BugClass::RegisterTypeInvariant));
        let critical = registry.critical_bugs();
        assert_eq!(critical.len(), 1);
        assert_eq!(critical[0].class, BugClass::GenericAddressCorruption);
        assert!(registry.has_critical_bugs());
    }

    #[test]
    fn test_registry_without_critical_bugs() {
        let mut registry = BugRegistry::new();
        registry.add(bug_of(BugClass::MissingBarrierSync));
        registry.add(bug_of(BugClass::IncompatibleAddressSpace));
        assert!(registry.critical_bugs().is_empty());
        assert!(!registry.has_critical_bugs());
    }

    #[test]
    fn test_clear_empties_registry() {
        let mut registry = BugRegistry::new();
        registry.add(bug_of(BugClass::DataDependentStore));
        assert!(registry.has_critical_bugs());
        registry.clear();
        assert!(registry.bugs().is_empty());
        assert!(!registry.has_critical_bugs());
    }
}