1use crate::error::Result;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11pub enum MudaType {
12 Transport,
14 Waiting,
16 Overprocessing,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
22pub struct MudaWarning {
23 pub muda_type: MudaType,
25 pub description: String,
27 pub impact: String,
29 pub line: Option<usize>,
31 pub suggestion: Option<String>,
33}
34
35#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
37pub struct RegisterUsage {
38 pub f32_regs: u32,
40 pub f64_regs: u32,
42 pub b32_regs: u32,
44 pub b64_regs: u32,
46 pub pred_regs: u32,
48}
49
50impl RegisterUsage {
51 #[must_use]
53 pub fn total(&self) -> u32 {
54 self.f32_regs + self.f64_regs + self.b32_regs + self.b64_regs + self.pred_regs
55 }
56
57 #[must_use]
60 pub fn estimated_occupancy(&self) -> f32 {
61 let total = self.total();
62 if total == 0 {
63 return 1.0;
64 }
65 let max_threads = (65536 / total.max(1)).min(2048);
68 max_threads as f32 / 2048.0
69 }
70}
71
72#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
74pub struct MemoryPattern {
75 pub global_loads: u32,
77 pub global_stores: u32,
79 pub shared_loads: u32,
81 pub shared_stores: u32,
83 pub coalesced_ratio: f32,
85}
86
87#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
89pub struct RooflineMetric {
90 pub arithmetic_intensity: f32,
92 pub theoretical_peak_gflops: f32,
94 pub memory_bound: bool,
96}
97
98#[derive(Debug, Clone, Default, Serialize, Deserialize)]
100pub struct AnalysisReport {
101 pub name: String,
103 pub target: String,
105 pub registers: RegisterUsage,
107 pub memory: MemoryPattern,
109 pub roofline: RooflineMetric,
111 pub warnings: Vec<MudaWarning>,
113 pub instruction_count: u32,
115 pub estimated_occupancy: f32,
117}
118
119pub trait Analyzer {
121 fn target_name(&self) -> &str;
123
124 fn analyze(&self, code: &str) -> Result<AnalysisReport>;
130
131 fn detect_muda(&self, code: &str) -> Vec<MudaWarning>;
133
134 fn estimate_roofline(&self, analysis: &AnalysisReport) -> RooflineMetric;
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for occupancy comparisons.
    const TOLERANCE: f32 = 0.01;

    /// True when `actual` is within [`TOLERANCE`] of `expected`.
    fn close_to(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Usage with only `f32_regs` set; every other class is zero.
    fn f32_only(count: u32) -> RegisterUsage {
        RegisterUsage {
            f32_regs: count,
            ..Default::default()
        }
    }

    #[test]
    fn test_register_usage_total() {
        let mixed = RegisterUsage {
            f32_regs: 10,
            f64_regs: 5,
            b32_regs: 8,
            b64_regs: 4,
            pred_regs: 2,
        };
        assert_eq!(mixed.total(), 29);
    }

    #[test]
    fn test_register_usage_total_empty() {
        assert_eq!(RegisterUsage::default().total(), 0);
    }

    #[test]
    fn test_occupancy_low_registers() {
        assert!(close_to(f32_only(16).estimated_occupancy(), 1.0));
    }

    #[test]
    fn test_occupancy_high_registers() {
        assert!(close_to(f32_only(128).estimated_occupancy(), 0.25));
    }

    #[test]
    fn test_occupancy_zero_registers() {
        assert!(close_to(RegisterUsage::default().estimated_occupancy(), 1.0));
    }

    #[test]
    fn test_muda_warning_serialization() {
        let original = MudaWarning {
            muda_type: MudaType::Transport,
            description: "5 register spills detected".to_string(),
            impact: "High latency local memory access".to_string(),
            line: Some(42),
            suggestion: Some("Reduce live variables".to_string()),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MudaWarning = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_analysis_report_serialization() {
        let registers = RegisterUsage {
            f32_regs: 24,
            b32_regs: 18,
            ..Default::default()
        };
        let memory = MemoryPattern {
            global_loads: 100,
            coalesced_ratio: 0.95,
            ..Default::default()
        };
        let report = AnalysisReport {
            name: "test_kernel".to_string(),
            target: "PTX".to_string(),
            registers,
            memory,
            warnings: Vec::new(),
            instruction_count: 150,
            estimated_occupancy: 0.875,
            ..Default::default()
        };

        let pretty = serde_json::to_string_pretty(&report).unwrap();
        for needle in ["test_kernel", "PTX"] {
            assert!(pretty.contains(needle));
        }
    }
}