use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
use super::super::sys::CUdeviceptr;
/// Metadata for one registered buffer, used to map a raw address back to a named
/// buffer, a byte offset, and an element index.
#[derive(Debug, Clone)]
pub struct BufferInfo {
    /// Human-readable name printed by `AddressRegistry::format_address`.
    pub name: String,
    /// Base address (`CUdeviceptr`); the first byte covered by the buffer.
    pub ptr: CUdeviceptr,
    /// Length of the buffer in bytes; the buffer covers `[ptr, ptr + size)`.
    pub size: usize,
    /// Element type name as given to `AddressRegistry::register`; not read by the
    /// lookup/formatting code in this file.
    pub type_name: String,
    /// Size in bytes of one element; byte offsets are divided by this to get an index.
    pub element_size: usize,
}
impl BufferInfo {
    /// Returns `true` if `addr` lies in the half-open byte range `[ptr, ptr + size)`.
    ///
    /// The check is done on the distance from `ptr` instead of computing `ptr + size`,
    /// so a buffer whose end would exceed `u64::MAX` cannot overflow (which would panic
    /// in debug builds and silently wrap in release builds).
    pub fn contains(&self, addr: u64) -> bool {
        addr.checked_sub(self.ptr)
            .is_some_and(|offset| offset < self.size as u64)
    }

    /// Returns the byte offset of `addr` from the start of the buffer, or `None` if
    /// `addr` is outside the buffer.
    pub fn offset_of(&self, addr: u64) -> Option<usize> {
        // `contains` guarantees `ptr <= addr` and `addr - ptr < size`, so the
        // subtraction cannot underflow and the result fits in `usize`.
        self.contains(addr).then(|| (addr - self.ptr) as usize)
    }

    /// Returns the index of the element that contains `addr`.
    ///
    /// Returns `None` if `addr` is outside the buffer, or if `element_size` is zero
    /// (the index is undefined then, and dividing would panic).
    pub fn element_index_of(&self, addr: u64) -> Option<usize> {
        self.offset_of(addr)?.checked_div(self.element_size)
    }
}
/// Table of known buffers used to turn raw addresses into readable
/// `name[index]` descriptions in diagnostics.
pub struct AddressRegistry {
    /// Registered buffers, keyed by each buffer's base pointer (`BufferInfo::ptr`).
    pub(super) buffers: HashMap<CUdeviceptr, BufferInfo>,
}
impl AddressRegistry {
    /// Creates an empty registry.
    pub(super) fn new() -> Self {
        Self { buffers: HashMap::new() }
    }

    /// Returns the process-wide registry, lazily creating an empty one on first call.
    pub fn global() -> &'static Mutex<AddressRegistry> {
        static REGISTRY: OnceLock<Mutex<AddressRegistry>> = OnceLock::new();
        REGISTRY.get_or_init(|| Mutex::new(AddressRegistry::new()))
    }

    /// Records a buffer starting at `ptr` spanning `size` bytes.
    ///
    /// Registering the same `ptr` again replaces the previous entry.
    pub fn register(
        &mut self,
        name: impl Into<String>,
        ptr: CUdeviceptr,
        size: usize,
        type_name: impl Into<String>,
        element_size: usize,
    ) {
        let info =
            BufferInfo { name: name.into(), ptr, size, type_name: type_name.into(), element_size };
        self.buffers.insert(ptr, info);
    }

    /// Forgets the buffer whose base address is exactly `ptr`; does nothing if none is
    /// registered there.
    pub fn unregister(&mut self, ptr: CUdeviceptr) {
        self.buffers.remove(&ptr);
    }

    /// Finds the registered buffer that `addr` belongs to.
    ///
    /// An exact base-address match is tried first. This is O(1), and it also finds
    /// zero-sized buffers, which `BufferInfo::contains` never matches. Otherwise every
    /// buffer is scanned. If registered ranges overlap, which one is returned is
    /// unspecified because `HashMap` iteration order is arbitrary.
    pub fn lookup(&self, addr: u64) -> Option<&BufferInfo> {
        self.buffers
            .get(&addr)
            .or_else(|| self.buffers.values().find(|info| info.contains(addr)))
    }

    /// Describes `addr` relative to the buffer it falls in.
    ///
    /// The possible forms are:
    /// - `name[idx] (0xBASE + N bytes)` for an element-aligned address.
    /// - `name[idx]+k (0xBASE + N bytes)` for an address `k` bytes inside element `idx`.
    /// - `name @ 0xADDR` when the buffer was found but yields no offset (e.g. an exact
    ///   base-address hit on a zero-sized buffer).
    /// - `0xADDR (unknown buffer)` when no registered buffer matches.
    pub fn format_address(&self, addr: u64) -> String {
        let Some(info) = self.lookup(addr) else {
            return format!("0x{:X} (unknown buffer)", addr);
        };
        let Some(offset) = info.offset_of(addr) else {
            return format!("{} @ 0x{:X}", info.name, addr);
        };
        // A buffer registered with `element_size == 0` would make the division below
        // panic while formatting a diagnostic (poisoning `global()`'s mutex if the
        // caller holds its guard). Treat such a buffer as byte-addressed instead.
        let element_size = info.element_size.max(1);
        let elem_idx = offset / element_size;
        let byte_in_elem = offset % element_size;
        if byte_in_elem == 0 {
            format!("{}[{}] (0x{:X} + {} bytes)", info.name, elem_idx, info.ptr, offset)
        } else {
            format!(
                "{}[{}]+{} (0x{:X} + {} bytes)",
                info.name, elem_idx, byte_in_elem, info.ptr, offset
            )
        }
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryViolationType {
InvalidGlobalRead {
size: usize,
},
InvalidGlobalWrite {
size: usize,
},
InvalidSharedRead {
size: usize,
},
InvalidSharedWrite {
size: usize,
},
MisalignedAccess {
addr: u64,
},
RaceCondition,
Other(String),
}
/// One reported memory error: what happened, where in the kernel, which thread/block
/// raised it, and the faulting address.
#[derive(Debug, Clone)]
pub struct MemoryViolation {
    /// Classification of the error.
    pub violation_type: MemoryViolationType,
    /// Name of the kernel the error was reported in.
    pub kernel_name: String,
    /// Offset into the kernel's SASS at which the error was reported (rendered in hex).
    pub sass_offset: u64,
    /// Thread coordinates, rendered as a `(a, b, c)` triple.
    pub thread: (u32, u32, u32),
    /// Block coordinates, rendered as a `(a, b, c)` triple.
    pub block: (u32, u32, u32),
    /// Faulting address; resolved through an `AddressRegistry` when formatted.
    pub address: u64,
    /// Raw report text (presumably the tool output this was parsed from — confirm);
    /// not used by `format_with_registry`.
    pub raw_message: String,
}
impl MemoryViolation {
    /// Renders this violation as a five-line, box-drawn report (no trailing newline).
    ///
    /// The faulting address is passed through `registry.format_address`, so addresses
    /// inside a registered buffer print as a buffer/element reference rather than a
    /// bare hex value.
    pub fn format_with_registry(&self, registry: &AddressRegistry) -> String {
        let location = registry.format_address(self.address);

        let error = match &self.violation_type {
            MemoryViolationType::InvalidGlobalRead { size } => {
                format!("Invalid global read of {size} bytes")
            }
            MemoryViolationType::InvalidGlobalWrite { size } => {
                format!("Invalid global write of {size} bytes")
            }
            MemoryViolationType::InvalidSharedRead { size } => {
                format!("Invalid shared read of {size} bytes")
            }
            MemoryViolationType::InvalidSharedWrite { size } => {
                format!("Invalid shared write of {size} bytes")
            }
            MemoryViolationType::MisalignedAccess { addr } => {
                format!("Misaligned access at 0x{addr:X}")
            }
            MemoryViolationType::RaceCondition => String::from("Race condition detected"),
            MemoryViolationType::Other(msg) => msg.to_owned(),
        };

        let kernel = &self.kernel_name;
        let offset = self.sass_offset;
        let (tx, ty, tz) = self.thread;
        let (bx, by, bz) = self.block;

        [
            String::from("🛑 MEMORY VIOLATION"),
            format!("├─ Kernel: {kernel} @ SASS offset 0x{offset:X}"),
            format!("├─ Thread: ({tx}, {ty}, {tz}) in Block ({bx}, {by}, {bz})"),
            format!("├─ Error: {error}"),
            format!("└─ Address: {location}"),
        ]
        .join("\n")
    }
}
/// A position in source code.
#[derive(Debug, Clone)]
pub struct SourceLocation {
    /// Source file (path or name).
    pub file: String,
    /// Line number. NOTE(review): whether this is 1-based is not established here — confirm.
    pub line: u32,
    /// Column, if available.
    pub column: Option<u32>,
    /// Name of the enclosing function, if available.
    pub function: Option<String>,
}