cbtop/brick/profiler.rs
1//! CORRECTNESS-011: BrickProfiler for CPU/GPU Divergence Detection
2
3use super::types::{DivergenceReport, KernelTrace};
4
5/// BrickProfiler collects per-kernel traces for automated divergence detection.
6///
7/// Five-Whys Root Cause: Hours of manual "let me check X in Y" debugging
8/// → No automated tool identified which kernel diverged
9/// → BrickProfiler only captured timing, not checksums
10/// → Missing feature: per-kernel checksum capture
11/// → ROOT CAUSE: Brick Profiling lacked correctness instrumentation
12///
13/// # Usage
14///
15/// ```rust,ignore
16/// use cbtop::{BrickProfiler, KernelTrace};
17///
18/// // CPU execution
19/// let mut cpu_profiler = BrickProfiler::new("cpu_baseline");
20/// cpu_profiler.add_trace(KernelTrace::new("rope_neox", 0, pos, "CPU")
21/// .with_input_checksum(&input)
22/// .with_output_checksum(&output));
23///
24/// // GPU execution
25/// let mut gpu_profiler = BrickProfiler::new("cuda_test");
26/// gpu_profiler.add_trace(KernelTrace::new("rope_neox", 0, pos, "CUDA")
27/// .with_input_checksum(&input)
28/// .with_output_checksum(&output));
29///
30/// // Automated divergence detection
31/// let report = cpu_profiler.compare(&gpu_profiler);
32/// if !report.matched {
33/// eprintln!("FIVE-WHYS ALERT: {}", report.diagnosis);
34/// }
35/// ```
36#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
37pub struct BrickProfiler {
38 /// Run identifier (e.g., "cpu_baseline", "cuda_test")
39 pub run_id: String,
40 /// Collected kernel traces
41 pub traces: Vec<KernelTrace>,
42 /// Total execution time in microseconds
43 pub total_time_us: f64,
44 /// Whether any divergence was detected
45 pub diverged: bool,
46 /// Divergence diagnosis (if any)
47 pub divergence_diagnosis: String,
48}
49
50impl BrickProfiler {
51 /// Create a new profiler for a run
52 pub fn new(run_id: &str) -> Self {
53 Self {
54 run_id: run_id.to_string(),
55 traces: Vec::new(),
56 total_time_us: 0.0,
57 diverged: false,
58 divergence_diagnosis: String::new(),
59 }
60 }
61
62 /// Add a kernel trace
63 pub fn add_trace(&mut self, trace: KernelTrace) {
64 self.total_time_us += trace.time_us;
65 self.traces.push(trace);
66 }
67
68 /// Check if divergence was detected
69 pub fn is_diverged(&self) -> bool {
70 self.diverged
71 }
72
73 /// Compare this profiler's traces against a reference (e.g., CPU vs GPU)
74 ///
75 /// Returns a DivergenceReport identifying the first divergent kernel.
76 /// Matching is done by (kernel_name, layer_idx, position) triple.
77 pub fn compare(&self, reference: &BrickProfiler) -> DivergenceReport {
78 // Build index from reference traces
79 let ref_index: std::collections::HashMap<(&str, usize, u32), &KernelTrace> = reference
80 .traces
81 .iter()
82 .map(|t| ((t.kernel_name.as_str(), t.layer_idx, t.position), t))
83 .collect();
84
85 let mut kernels_compared = 0;
86
87 for actual_trace in &self.traces {
88 let key = (
89 actual_trace.kernel_name.as_str(),
90 actual_trace.layer_idx,
91 actual_trace.position,
92 );
93
94 if let Some(expected_trace) = ref_index.get(&key) {
95 kernels_compared += 1;
96
97 // Compare output checksums
98 if actual_trace.output_checksum != expected_trace.output_checksum {
99 return DivergenceReport::diverged(
100 (*expected_trace).clone(),
101 actual_trace.clone(),
102 kernels_compared,
103 );
104 }
105 }
106 }
107
108 DivergenceReport::matched(kernels_compared)
109 }
110
111 /// Compare and set internal divergence state
112 pub fn compare_and_mark(&mut self, reference: &BrickProfiler) -> DivergenceReport {
113 let report = self.compare(reference);
114 self.diverged = !report.matched;
115 self.divergence_diagnosis = report.diagnosis.clone();
116 report
117 }
118
119 /// Get traces for a specific kernel name
120 pub fn traces_for_kernel(&self, kernel_name: &str) -> Vec<&KernelTrace> {
121 self.traces
122 .iter()
123 .filter(|t| t.kernel_name == kernel_name)
124 .collect()
125 }
126
127 /// Get traces for a specific layer
128 pub fn traces_for_layer(&self, layer_idx: usize) -> Vec<&KernelTrace> {
129 self.traces
130 .iter()
131 .filter(|t| t.layer_idx == layer_idx)
132 .collect()
133 }
134
135 /// Clear all traces (for reuse)
136 pub fn clear(&mut self) {
137 self.traces.clear();
138 self.total_time_us = 0.0;
139 self.diverged = false;
140 self.divergence_diagnosis.clear();
141 }
142
143 /// Serialize to JSON for pmat brick-score consumption
144 pub fn to_json(&self) -> Result<String, serde_json::Error> {
145 serde_json::to_string_pretty(self)
146 }
147
148 /// Deserialize from JSON
149 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
150 serde_json::from_str(json)
151 }
152}