use std::collections::HashMap;
use super::types::{MemoryViolation, MemoryViolationType, SourceLocation};
/// Extracts structured [`MemoryViolation`]s from memory-checker text output.
///
/// A report starts at a header line containing `"Invalid __"` (for example
/// `Invalid __global__ write of size 4 bytes`). The lines that follow, up to
/// the next header, are searched for:
/// - the kernel location, `" at NAME+0xOFFSET"`;
/// - the faulting thread and block, `"thread (x,y,z) ... block (x,y,z)"`;
/// - the faulting address, `"Address 0xHEX"`.
pub struct SanitizerParser;

impl SanitizerParser {
    /// Substring that marks the first line of a violation report.
    const HEADER_MARKER: &'static str = "Invalid __";
    /// How many lines after a header are searched for report details.
    const MAX_DETAIL_LINES: usize = 10;
    /// How many lines (header included) are copied into `raw_message`.
    const RAW_MESSAGE_LINES: usize = 5;
    /// Access size used when the header has no parseable `of size N`.
    const DEFAULT_ACCESS_SIZE: usize = 4;

    /// Parses every violation report in `output`, in the order the headers appear.
    ///
    /// Lines outside a report are ignored. A report whose detail lines cannot
    /// be parsed is still returned, with missing fields left at their defaults
    /// (kernel name `"unknown"`, zero offset/address, `(0, 0, 0)` thread/block).
    pub fn parse(output: &str) -> Vec<MemoryViolation> {
        let lines: Vec<&str> = output.lines().collect();
        lines
            .iter()
            .enumerate()
            .filter(|(_, line)| line.contains(Self::HEADER_MARKER))
            .filter_map(|(start, _)| Self::parse_violation(&lines[start..]))
            .collect()
    }

    /// Builds one violation from `lines`, whose first element is the report header.
    ///
    /// Only lines belonging to this report are examined. Scanning stops at the
    /// next header, so a closely following report can no longer overwrite this
    /// one's kernel, thread/block or address, nor leak into its `raw_message`.
    /// Returns `None` only when `lines` is empty.
    fn parse_violation(lines: &[&str]) -> Option<MemoryViolation> {
        let (header, rest) = lines.split_first()?;
        // Detail lines owned by this report: up to (not including) the next
        // header, or the end of the input.
        let detail_len = rest
            .iter()
            .position(|line| line.contains(Self::HEADER_MARKER))
            .unwrap_or(rest.len());

        let mut kernel_name = String::from("unknown");
        let mut sass_offset = 0u64;
        let mut thread = (0u32, 0u32, 0u32);
        let mut block = (0u32, 0u32, 0u32);
        let mut address = 0u64;
        // If a field matches on several lines of the same report, the last match wins.
        for line in rest.iter().take(detail_len.min(Self::MAX_DETAIL_LINES)) {
            if let Some((name, offset)) = Self::parse_kernel_location(line) {
                kernel_name = name;
                sass_offset = offset;
            }
            if let Some((t, b)) = Self::parse_thread_block(line) {
                thread = t;
                block = b;
            }
            if let Some(addr) = Self::parse_address(line) {
                address = addr;
            }
        }

        let raw_len = (detail_len + 1).min(Self::RAW_MESSAGE_LINES);
        Some(MemoryViolation {
            violation_type: Self::classify_violation(header),
            kernel_name,
            sass_offset,
            thread,
            block,
            address,
            raw_message: lines[..raw_len].join("\n"),
        })
    }

    /// Maps a report header to a [`MemoryViolationType`].
    ///
    /// The checks are applied in order. A header matching none of them becomes
    /// `Other` carrying the whole header line.
    fn classify_violation(line: &str) -> MemoryViolationType {
        let size = || Self::extract_size(line).unwrap_or(Self::DEFAULT_ACCESS_SIZE);
        if line.contains("__shared__ read") {
            MemoryViolationType::InvalidSharedRead { size: size() }
        } else if line.contains("__shared__ write") {
            MemoryViolationType::InvalidSharedWrite { size: size() }
        } else if line.contains("__global__ read") {
            MemoryViolationType::InvalidGlobalRead { size: size() }
        } else if line.contains("__global__ write") {
            MemoryViolationType::InvalidGlobalWrite { size: size() }
        } else if line.contains("misaligned") {
            // NOTE(review): the address is not extracted here; `addr` is always 0.
            MemoryViolationType::MisalignedAccess { addr: 0 }
        } else {
            MemoryViolationType::Other(line.to_string())
        }
    }

    /// Parses `" at NAME+0xOFFSET"` into `(NAME, OFFSET)`.
    ///
    /// The name is trimmed. A missing or unparseable hex offset yields 0.
    fn parse_kernel_location(line: &str) -> Option<(String, u64)> {
        let (_, after_at) = line.split_once(" at ")?;
        let (name, after_prefix) = after_at.split_once("+0x")?;
        let offset = Self::parse_hex_prefix(after_prefix).unwrap_or(0);
        Some((name.trim().to_string(), offset))
    }

    /// Parses the hex value following `"Address 0x"`.
    ///
    /// Yields `Some(0)` when no valid hex digits follow the prefix.
    fn parse_address(line: &str) -> Option<u64> {
        let (_, after_prefix) = line.split_once("Address 0x")?;
        Some(Self::parse_hex_prefix(after_prefix).unwrap_or(0))
    }

    /// Parses the run of ASCII hex digits at the start of `s`.
    ///
    /// Returns `None` when there are no digits or the value overflows `u64`.
    fn parse_hex_prefix(s: &str) -> Option<u64> {
        let end = s.find(|c: char| !c.is_ascii_hexdigit()).unwrap_or(s.len());
        u64::from_str_radix(&s[..end], 16).ok()
    }

    /// Reads the decimal number following `"of size "`, e.g. `4` in
    /// `"of size 4 bytes"`.
    ///
    /// Only the leading digit run is used, so trailing punctuation such as
    /// `"of size 16,"` still parses.
    fn extract_size(line: &str) -> Option<usize> {
        let (_, rest) = line.split_once("of size ")?;
        let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
        rest[..end].parse().ok()
    }

    /// Parses `"thread (x,y,z)"` and `"block (x,y,z)"` from the same line.
    ///
    /// Whitespace around the components is allowed. Both triples must be
    /// present and valid.
    fn parse_thread_block(line: &str) -> Option<((u32, u32, u32), (u32, u32, u32))> {
        let thread = Self::parse_triple_after(line, "thread (")?;
        let block = Self::parse_triple_after(line, "block (")?;
        Some((thread, block))
    }

    /// Parses the `x,y,z` triple between `marker` (ending in `(`) and the next `)`.
    fn parse_triple_after(line: &str, marker: &str) -> Option<(u32, u32, u32)> {
        let (_, after_marker) = line.split_once(marker)?;
        let (inner, _) = after_marker.split_once(')')?;
        let mut parts = inner.split(',').map(|part| part.trim().parse::<u32>());
        let x = parts.next()?.ok()?;
        let y = parts.next()?.ok()?;
        let z = parts.next()?.ok()?;
        if parts.next().is_some() {
            return None;
        }
        Some((x, y, z))
    }
}
/// Index over the text of a PTX module: where each label is defined and which
/// PTX lines carry `.loc` debug-location directives.
pub struct PtxSourceMap {
    /// Label name (trailing `:` removed) mapped to the 1-based PTX line defining it.
    pub(super) label_lines: HashMap<String, u32>,
    /// 1-based PTX line of a `.loc` directive mapped to the source location it names.
    line_map: HashMap<u32, SourceLocation>,
    /// The complete PTX text the two maps were built from.
    ptx_source: String,
}
impl PtxSourceMap {
    /// Builds the label and `.loc` indexes for `ptx_source`.
    ///
    /// The text is copied into the map, so later lookups do not borrow the input.
    pub fn new(ptx_source: &str) -> Self {
        let mut map = Self {
            ptx_source: ptx_source.to_string(),
            line_map: HashMap::new(),
            label_lines: HashMap::new(),
        };
        map.parse_ptx();
        map
    }

    /// Scans the PTX text once, recording label definitions and `.loc` directives.
    ///
    /// Any trailing `//` comment is ignored. This lets labels emitted as
    /// `$L__BB0_2:   // %bb3` be recognised, while comment-only lines such as
    /// `// %bb.0:` are not treated as labels. For `.loc FILE LINE [COL][, attrs...]`,
    /// tokens are whitespace-separated (spaces or tabs), and a trailing `,` is
    /// stripped so the column survives `function_name` / `inlined_at` attributes.
    fn parse_ptx(&mut self) {
        for (idx, raw_line) in self.ptx_source.lines().enumerate() {
            let line_num = (idx + 1) as u32;
            let code = raw_line
                .split_once("//")
                .map_or(raw_line, |(before_comment, _)| before_comment)
                .trim();

            if code.ends_with(':') {
                let label = code.trim_end_matches(':');
                if !label.is_empty() {
                    self.label_lines.insert(label.to_string(), line_num);
                }
                continue;
            }

            let mut tokens = code.split_whitespace();
            if tokens.next() != Some(".loc") {
                continue;
            }
            // fields[0] = file index, fields[1] = source line, fields[2] = column (optional).
            let fields: Vec<&str> = tokens.map(|tok| tok.trim_end_matches(',')).collect();
            if let Some(src_line) = fields.get(1).and_then(|s| s.parse::<u32>().ok()) {
                self.line_map.insert(
                    line_num,
                    SourceLocation {
                        // NOTE(review): the file name is hard-coded; `.file` directives are not read.
                        file: "kernel.rs".to_string(),
                        line: src_line,
                        column: fields.get(2).and_then(|s| s.parse().ok()),
                        function: None,
                    },
                );
            }
        }
    }

    /// Returns the label defined closest at or before 1-based `target_line`.
    ///
    /// Returns `None` when no label is defined on lines `1..=target_line`.
    pub fn label_at_line(&self, target_line: u32) -> Option<&str> {
        self.label_lines
            .iter()
            .filter(|&(_, &line)| line >= 1 && line <= target_line)
            .max_by_key(|&(_, &line)| line)
            .map(|(label, _)| label.as_str())
    }

    /// Renders up to `context_lines` PTX lines on each side of `label`.
    ///
    /// Each rendered line is `"<marker> <line-no>: <text>\n"`. The marker is
    /// `>>>` on the label's own line and three spaces elsewhere, so the columns
    /// stay aligned. The window is clamped to the start and end of the source.
    /// Returns `None` if the label is unknown or its recorded line lies outside
    /// the source.
    pub fn context_around_label(&self, label: &str, context_lines: usize) -> Option<String> {
        let label_line = *self.label_lines.get(label)? as usize;
        let lines: Vec<&str> = self.ptx_source.lines().collect();
        // `label_line` is 1-based; slicing needs the 0-based index.
        let label_idx = label_line.checked_sub(1)?;
        if label_idx >= lines.len() {
            return None;
        }
        let start = label_idx.saturating_sub(context_lines);
        let end = label_idx
            .saturating_add(context_lines)
            .saturating_add(1)
            .min(lines.len());

        let mut result = String::new();
        for (offset, text) in lines[start..end].iter().enumerate() {
            let line_no = start + offset + 1;
            let marker = if line_no == label_line { ">>>" } else { "   " };
            result.push_str(&format!("{} {:4}: {}\n", marker, line_no, text));
        }
        Some(result)
    }
}