Skip to main content

trueno_explain/
compare.rs

1//! Backend comparison module
2//!
3//! Compares kernel performance characteristics across different backends or configurations.
4
5use crate::analyzer::AnalysisReport;
6use serde::{Deserialize, Serialize};
7
8/// Comparison result for a single metric
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct MetricComparison {
11    /// Metric name
12    pub name: String,
13    /// Value in first report
14    pub value_a: f32,
15    /// Value in second report
16    pub value_b: f32,
17    /// Winner ("A", "B", or "Tie")
18    pub winner: String,
19    /// Notes about the comparison
20    pub notes: String,
21}
22
23/// Full comparison report
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ComparisonReport {
26    /// First report name
27    pub report_a_name: String,
28    /// Second report name
29    pub report_b_name: String,
30    /// Individual metric comparisons
31    pub metrics: Vec<MetricComparison>,
32    /// Overall recommendation
33    pub recommendation: String,
34}
35
36/// Compare two analysis reports
37#[must_use]
38pub fn compare_analyses(report_a: &AnalysisReport, report_b: &AnalysisReport) -> ComparisonReport {
39    let mut metrics = Vec::new();
40
41    // Compare register usage (lower is better)
42    let regs_a = report_a.registers.total() as f32;
43    let regs_b = report_b.registers.total() as f32;
44    metrics.push(MetricComparison {
45        name: "Register Count".to_string(),
46        value_a: regs_a,
47        value_b: regs_b,
48        winner: if regs_a < regs_b {
49            "A".to_string()
50        } else if regs_b < regs_a {
51            "B".to_string()
52        } else {
53            "Tie".to_string()
54        },
55        notes: "Lower is better (higher occupancy)".to_string(),
56    });
57
58    // Compare instruction count (lower is better for same work)
59    let inst_a = report_a.instruction_count as f32;
60    let inst_b = report_b.instruction_count as f32;
61    metrics.push(MetricComparison {
62        name: "Instruction Count".to_string(),
63        value_a: inst_a,
64        value_b: inst_b,
65        winner: if inst_a < inst_b {
66            "A".to_string()
67        } else if inst_b < inst_a {
68            "B".to_string()
69        } else {
70            "Tie".to_string()
71        },
72        notes: "Lower is better (less work)".to_string(),
73    });
74
75    // Compare occupancy (higher is better)
76    let occ_a = report_a.estimated_occupancy;
77    let occ_b = report_b.estimated_occupancy;
78    metrics.push(MetricComparison {
79        name: "Estimated Occupancy".to_string(),
80        value_a: occ_a * 100.0,
81        value_b: occ_b * 100.0,
82        winner: if occ_a > occ_b {
83            "A".to_string()
84        } else if occ_b > occ_a {
85            "B".to_string()
86        } else {
87            "Tie".to_string()
88        },
89        notes: "Higher is better (GPU utilization)".to_string(),
90    });
91
92    // Compare warning counts (lower is better)
93    let warns_a = report_a.warnings.len() as f32;
94    let warns_b = report_b.warnings.len() as f32;
95    metrics.push(MetricComparison {
96        name: "Muda Warnings".to_string(),
97        value_a: warns_a,
98        value_b: warns_b,
99        winner: if warns_a < warns_b {
100            "A".to_string()
101        } else if warns_b < warns_a {
102            "B".to_string()
103        } else {
104            "Tie".to_string()
105        },
106        notes: "Lower is better (less waste)".to_string(),
107    });
108
109    // Compare memory coalescing (higher is better)
110    let coal_a = report_a.memory.coalesced_ratio;
111    let coal_b = report_b.memory.coalesced_ratio;
112    metrics.push(MetricComparison {
113        name: "Memory Coalescing".to_string(),
114        value_a: coal_a * 100.0,
115        value_b: coal_b * 100.0,
116        winner: if coal_a > coal_b {
117            "A".to_string()
118        } else if coal_b > coal_a {
119            "B".to_string()
120        } else {
121            "Tie".to_string()
122        },
123        notes: "Higher is better (bandwidth efficiency)".to_string(),
124    });
125
126    // Count wins
127    let a_wins = metrics.iter().filter(|m| m.winner == "A").count();
128    let b_wins = metrics.iter().filter(|m| m.winner == "B").count();
129
130    let recommendation = match a_wins.cmp(&b_wins) {
131        std::cmp::Ordering::Greater => {
132            format!("{} wins {} to {} metrics", report_a.name, a_wins, b_wins)
133        }
134        std::cmp::Ordering::Less => {
135            format!("{} wins {} to {} metrics", report_b.name, b_wins, a_wins)
136        }
137        std::cmp::Ordering::Equal => "Both configurations are comparable".to_string(),
138    };
139
140    ComparisonReport {
141        report_a_name: report_a.name.clone(),
142        report_b_name: report_b.name.clone(),
143        metrics,
144        recommendation,
145    }
146}
147
148/// Format comparison report as text
149#[must_use]
150pub fn format_comparison_text(report: &ComparisonReport) -> String {
151    let mut output = String::new();
152
153    output.push_str(&format!(
154        "╔══ Comparison: {} vs {} ══╗\n\n",
155        report.report_a_name, report.report_b_name
156    ));
157
158    output.push_str(&format!(
159        "{:<25} {:>12} {:>12} {:>8}\n",
160        "Metric", &report.report_a_name, &report.report_b_name, "Winner"
161    ));
162    output.push_str(&format!("{}\n", "─".repeat(60)));
163
164    for metric in &report.metrics {
165        let winner_icon = match metric.winner.as_str() {
166            "A" => "◀",
167            "B" => "▶",
168            _ => "═",
169        };
170        output.push_str(&format!(
171            "{:<25} {:>12.1} {:>12.1} {:>6} {}\n",
172            metric.name, metric.value_a, metric.value_b, winner_icon, metric.winner
173        ));
174    }
175
176    output.push_str(&format!("\n{}\n", report.recommendation));
177
178    output
179}
180
181/// Format comparison report as JSON
182#[must_use]
183pub fn format_comparison_json(report: &ComparisonReport) -> String {
184    serde_json::to_string_pretty(report).unwrap_or_else(|_| "{}".to_string())
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer::{MemoryPattern, MudaWarning, RegisterUsage, RooflineMetric};

    /// Build an `AnalysisReport` for comparison tests.
    ///
    /// Only `f32_regs` is set on the register usage and only `coalesced_ratio` on the
    /// memory pattern; every other field of those structs uses `Default`. `warns`
    /// identical placeholder warnings are attached.
    fn make_report(
        name: &str,
        regs: u32,
        inst: u32,
        occ: f32,
        warns: usize,
        coal: f32,
    ) -> AnalysisReport {
        AnalysisReport {
            name: name.to_string(),
            target: "PTX".to_string(),
            registers: RegisterUsage {
                f32_regs: regs,
                ..Default::default()
            },
            memory: MemoryPattern {
                coalesced_ratio: coal,
                ..Default::default()
            },
            roofline: RooflineMetric::default(),
            warnings: (0..warns)
                .map(|_| MudaWarning {
                    muda_type: crate::analyzer::MudaType::Transport,
                    description: "test".to_string(),
                    impact: "test".to_string(),
                    line: None,
                    suggestion: None,
                })
                .collect(),
            instruction_count: inst,
            estimated_occupancy: occ,
        }
    }

    /// Count how many metrics were won by `side` ("A" or "B").
    fn wins(comparison: &ComparisonReport, side: &str) -> usize {
        comparison
            .metrics
            .iter()
            .filter(|m| m.winner == side)
            .count()
    }

    #[test]
    fn test_compare_identical() {
        let report_a = make_report("A", 32, 100, 0.75, 0, 0.95);
        let report_b = make_report("B", 32, 100, 0.75, 0, 0.95);

        let comparison = compare_analyses(&report_a, &report_b);

        // Every metric ties, so neither side is recommended.
        assert!(comparison.metrics.iter().all(|m| m.winner == "Tie"));
        assert_eq!(
            comparison.recommendation,
            "Both configurations are comparable"
        );
    }

    #[test]
    fn test_compare_clear_winner() {
        let report_a = make_report("Optimized", 16, 50, 0.90, 0, 0.98);
        let report_b = make_report("Baseline", 64, 200, 0.50, 3, 0.70);

        let comparison = compare_analyses(&report_a, &report_b);

        // A is strictly better on every one of the five metrics.
        assert_eq!(comparison.metrics.len(), 5);
        assert_eq!(wins(&comparison, "A"), 5, "Optimized should win every metric");
        assert_eq!(comparison.recommendation, "Optimized wins 5 to 0 metrics");
    }

    #[test]
    fn test_compare_report_b_wins() {
        let report_a = make_report("Baseline", 64, 200, 0.50, 3, 0.70);
        let report_b = make_report("Optimized", 16, 50, 0.90, 0, 0.98);

        let comparison = compare_analyses(&report_a, &report_b);

        // Mirror of the clear-winner case: the recommendation must name report B.
        assert_eq!(wins(&comparison, "B"), 5);
        assert_eq!(comparison.recommendation, "Optimized wins 5 to 0 metrics");
    }

    #[test]
    fn test_compare_mixed() {
        // A has fewer registers and higher occupancy, but more warnings and worse
        // coalescing; instruction counts are equal.
        let report_a = make_report("LowReg", 16, 100, 0.90, 5, 0.80);
        let report_b = make_report("HighReg", 64, 100, 0.50, 0, 0.95);

        let comparison = compare_analyses(&report_a, &report_b);

        assert_eq!(wins(&comparison, "A"), 2);
        assert_eq!(wins(&comparison, "B"), 2);
        assert_eq!(
            comparison.recommendation,
            "Both configurations are comparable"
        );
    }

    #[test]
    fn test_format_text() {
        let report_a = make_report("A", 32, 100, 0.75, 1, 0.90);
        let report_b = make_report("B", 48, 150, 0.60, 2, 0.85);

        let comparison = compare_analyses(&report_a, &report_b);
        let text = format_comparison_text(&comparison);

        assert!(text.contains("Comparison"));
        assert!(text.contains("Register Count"));
        assert!(text.contains("Instruction Count"));
        assert!(text.ends_with("A wins 5 to 0 metrics\n"));
    }

    #[test]
    fn test_format_json() {
        let report_a = make_report("A", 32, 100, 0.75, 0, 0.90);
        let report_b = make_report("B", 32, 100, 0.75, 0, 0.90);

        let comparison = compare_analyses(&report_a, &report_b);
        let json = format_comparison_json(&comparison);

        assert!(json.contains("\"report_a_name\": \"A\""));
        assert!(json.contains("\"report_b_name\": \"B\""));

        // The JSON must also deserialize back into a report.
        let parsed: ComparisonReport =
            serde_json::from_str(&json).expect("formatted JSON should deserialize");
        assert_eq!(parsed.metrics.len(), comparison.metrics.len());
        assert_eq!(parsed.recommendation, comparison.recommendation);
    }
}