use crate::error::Result;
use serde::{Deserialize, Serialize};
/// Category of waste ("muda") carried by a [`MudaWarning`].
///
/// NOTE(review): this file only declares the variants; which code patterns
/// are reported under which variant is decided by the `Analyzer`
/// implementations, so confirm the per-variant meanings there.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MudaType {
/// Transport waste. NOTE(review): presumably unnecessary data movement — confirm.
Transport,
/// Waiting waste. NOTE(review): presumably stalls or idle time — confirm.
Waiting,
/// Overprocessing waste. NOTE(review): presumably redundant work — confirm.
Overprocessing,
}
/// A single waste finding: its category plus free-form details.
///
/// The type is serde-serializable and comparable with `==`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MudaWarning {
/// Waste category this warning is filed under.
pub muda_type: MudaType,
/// Free-form text describing what was found.
pub description: String,
/// Free-form text describing the consequence of the finding.
pub impact: String,
/// Line number associated with the finding, if any.
/// NOTE(review): whether this is 0- or 1-based is not established here — confirm.
pub line: Option<usize>,
/// Optional free-form remediation hint.
pub suggestion: Option<String>,
}
/// Register counts split into five classes; `Default` yields all zeros.
///
/// NOTE(review): the field names resemble PTX register classes
/// (`.f32`, `.f64`, `.b32`, `.b64`, `.pred`) — confirm against the producers
/// of this struct.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct RegisterUsage {
/// Count of `f32` registers.
pub f32_regs: u32,
/// Count of `f64` registers.
pub f64_regs: u32,
/// Count of `b32` registers.
pub b32_regs: u32,
/// Count of `b64` registers.
pub b64_regs: u32,
/// Count of predicate registers.
pub pred_regs: u32,
}
impl RegisterUsage {
    /// Returns the sum of all five register counters.
    ///
    /// The sum saturates at `u32::MAX`. Plain `+` would panic on overflow in
    /// builds with overflow checks and silently wrap otherwise; a wrapped
    /// (small) total would make [`Self::estimated_occupancy`] report a high
    /// occupancy for an absurdly large register count.
    #[must_use]
    pub fn total(&self) -> u32 {
        self.f32_regs
            .saturating_add(self.f64_regs)
            .saturating_add(self.b32_regs)
            .saturating_add(self.b64_regs)
            .saturating_add(self.pred_regs)
    }

    /// Estimates occupancy as a fraction in `[0.0, 1.0]`.
    ///
    /// The result is `min(65536 / total, 2048) / 2048`, where the inner
    /// quotient uses integer division. A total of zero yields `1.0`.
    ///
    /// NOTE(review): 65536 and 2048 look like a per-multiprocessor register
    /// file size and a maximum resident thread count — confirm against the
    /// hardware this estimate is meant to model.
    #[must_use]
    pub fn estimated_occupancy(&self) -> f32 {
        const REGISTER_BUDGET: u32 = 65_536;
        const THREAD_LIMIT: u32 = 2_048;

        let total = self.total();
        if total == 0 {
            // No registers counted, so nothing limits the thread count.
            return 1.0;
        }
        // `total` is non-zero here, so the division cannot trap.
        let max_threads = (REGISTER_BUDGET / total).min(THREAD_LIMIT);
        // Both operands are <= 2048, so the conversions to `f32` are exact.
        max_threads as f32 / THREAD_LIMIT as f32
    }
}
/// Memory-access counters plus a coalescing ratio; `Default` yields all zeros.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct MemoryPattern {
/// Count of global loads.
pub global_loads: u32,
/// Count of global stores.
pub global_stores: u32,
/// Count of shared loads.
pub shared_loads: u32,
/// Count of shared stores.
pub shared_stores: u32,
/// Coalescing ratio.
/// NOTE(review): presumably a fraction in `[0.0, 1.0]` of coalesced accesses;
/// the range is not enforced here — confirm with the producers.
pub coalesced_ratio: f32,
}
/// Roofline-model summary values; `Default` yields `0.0`, `0.0`, `false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct RooflineMetric {
/// Arithmetic intensity.
/// NOTE(review): presumably operations per byte moved — units are not
/// established in this file; confirm.
pub arithmetic_intensity: f32,
/// Theoretical peak throughput, in GFLOP/s going by the field name.
pub theoretical_peak_gflops: f32,
/// Flag marking the workload as memory bound.
/// NOTE(review): the criterion that sets it is not visible here — confirm.
pub memory_bound: bool,
}
/// Aggregated result returned by [`Analyzer::analyze`].
///
/// Unlike the other types in this file, this struct does not derive
/// `PartialEq`, so reports cannot be compared with `==`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AnalysisReport {
/// Name of the analyzed item (the tests in this file use a kernel name).
pub name: String,
/// Name of the analysis target (the tests in this file use `"PTX"`).
pub target: String,
/// Register counts by class.
pub registers: RegisterUsage,
/// Memory-access counters.
pub memory: MemoryPattern,
/// Roofline-model summary.
pub roofline: RooflineMetric,
/// Waste findings; may be empty.
pub warnings: Vec<MudaWarning>,
/// Instruction count.
pub instruction_count: u32,
/// Stored occupancy estimate.
/// NOTE(review): nothing in this file ties this field to
/// `RegisterUsage::estimated_occupancy`; confirm how producers fill it.
pub estimated_occupancy: f32,
}
/// Interface implemented by code analyzers that produce an [`AnalysisReport`].
///
/// NOTE(review): no implementation is visible in this file; the method notes
/// below describe the signatures only — confirm semantics with implementors.
pub trait Analyzer {
/// Returns the name of the target this analyzer handles.
fn target_name(&self) -> &str;
/// Analyzes `code` and returns a full report.
///
/// # Errors
///
/// Returns the crate's error type on failure; the failure conditions are
/// defined by each implementation.
fn analyze(&self, code: &str) -> Result<AnalysisReport>;
/// Returns the waste warnings found in `code`; the list may be empty.
fn detect_muda(&self, code: &str) -> Vec<MudaWarning>;
/// Derives a roofline summary from an existing report.
fn estimate_roofline(&self, analysis: &AnalysisReport) -> RooflineMetric;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing occupancy values.
    const TOLERANCE: f32 = 0.01;

    /// Returns `true` when `actual` lies strictly within `TOLERANCE` of `expected`.
    fn close_to(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Builds a usage record in which only the `f32` counter is set.
    fn f32_only(count: u32) -> RegisterUsage {
        RegisterUsage {
            f32_regs: count,
            ..Default::default()
        }
    }

    /// All five counters contribute to the total.
    #[test]
    fn test_register_usage_total() {
        let regs = RegisterUsage {
            f32_regs: 10,
            f64_regs: 5,
            b32_regs: 8,
            b64_regs: 4,
            pred_regs: 2,
        };
        let sum = regs.total();
        assert_eq!(sum, 29);
    }

    /// A default record has a total of zero.
    #[test]
    fn test_register_usage_total_empty() {
        assert_eq!(RegisterUsage::default().total(), 0);
    }

    /// Sixteen registers leave the estimate at full occupancy.
    #[test]
    fn test_occupancy_low_registers() {
        let occupancy = f32_only(16).estimated_occupancy();
        assert!(close_to(occupancy, 1.0));
    }

    /// One hundred twenty-eight registers reduce the estimate to a quarter.
    #[test]
    fn test_occupancy_high_registers() {
        let occupancy = f32_only(128).estimated_occupancy();
        assert!(close_to(occupancy, 0.25));
    }

    /// A record without registers reports full occupancy.
    #[test]
    fn test_occupancy_zero_registers() {
        let occupancy = RegisterUsage::default().estimated_occupancy();
        assert!(close_to(occupancy, 1.0));
    }

    /// A warning survives a JSON round trip unchanged.
    #[test]
    fn test_muda_warning_serialization() {
        let original = MudaWarning {
            muda_type: MudaType::Transport,
            description: String::from("5 register spills detected"),
            impact: String::from("High latency local memory access"),
            line: Some(42),
            suggestion: Some(String::from("Reduce live variables")),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let round_tripped = serde_json::from_str::<MudaWarning>(&encoded).unwrap();

        assert_eq!(original, round_tripped);
    }

    /// The pretty-printed report contains the kernel name and the target.
    #[test]
    fn test_analysis_report_serialization() {
        let registers = RegisterUsage {
            f32_regs: 24,
            b32_regs: 18,
            ..Default::default()
        };
        let memory = MemoryPattern {
            global_loads: 100,
            coalesced_ratio: 0.95,
            ..Default::default()
        };
        let report = AnalysisReport {
            name: String::from("test_kernel"),
            target: String::from("PTX"),
            registers,
            memory,
            warnings: Vec::new(),
            instruction_count: 150,
            estimated_occupancy: 0.875,
            ..Default::default()
        };

        let rendered = serde_json::to_string_pretty(&report).unwrap();

        for needle in &["test_kernel", "PTX"] {
            assert!(rendered.contains(*needle));
        }
    }
}