Skip to main content

cbtop/brick/
profiler.rs

1//! CORRECTNESS-011: BrickProfiler for CPU/GPU Divergence Detection
2
3use super::types::{DivergenceReport, KernelTrace};
4
5/// BrickProfiler collects per-kernel traces for automated divergence detection.
6///
7/// Five-Whys Root Cause: Hours of manual "let me check X in Y" debugging
8/// → No automated tool identified which kernel diverged
9/// → BrickProfiler only captured timing, not checksums
10/// → Missing feature: per-kernel checksum capture
11/// → ROOT CAUSE: Brick Profiling lacked correctness instrumentation
12///
13/// # Usage
14///
15/// ```rust,ignore
16/// use cbtop::{BrickProfiler, KernelTrace};
17///
18/// // CPU execution
19/// let mut cpu_profiler = BrickProfiler::new("cpu_baseline");
20/// cpu_profiler.add_trace(KernelTrace::new("rope_neox", 0, pos, "CPU")
21///     .with_input_checksum(&input)
22///     .with_output_checksum(&output));
23///
24/// // GPU execution
25/// let mut gpu_profiler = BrickProfiler::new("cuda_test");
26/// gpu_profiler.add_trace(KernelTrace::new("rope_neox", 0, pos, "CUDA")
27///     .with_input_checksum(&input)
28///     .with_output_checksum(&output));
29///
30/// // Automated divergence detection
31/// let report = cpu_profiler.compare(&gpu_profiler);
32/// if !report.matched {
33///     eprintln!("FIVE-WHYS ALERT: {}", report.diagnosis);
34/// }
35/// ```
36#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
37pub struct BrickProfiler {
38    /// Run identifier (e.g., "cpu_baseline", "cuda_test")
39    pub run_id: String,
40    /// Collected kernel traces
41    pub traces: Vec<KernelTrace>,
42    /// Total execution time in microseconds
43    pub total_time_us: f64,
44    /// Whether any divergence was detected
45    pub diverged: bool,
46    /// Divergence diagnosis (if any)
47    pub divergence_diagnosis: String,
48}
49
50impl BrickProfiler {
51    /// Create a new profiler for a run
52    pub fn new(run_id: &str) -> Self {
53        Self {
54            run_id: run_id.to_string(),
55            traces: Vec::new(),
56            total_time_us: 0.0,
57            diverged: false,
58            divergence_diagnosis: String::new(),
59        }
60    }
61
62    /// Add a kernel trace
63    pub fn add_trace(&mut self, trace: KernelTrace) {
64        self.total_time_us += trace.time_us;
65        self.traces.push(trace);
66    }
67
68    /// Check if divergence was detected
69    pub fn is_diverged(&self) -> bool {
70        self.diverged
71    }
72
73    /// Compare this profiler's traces against a reference (e.g., CPU vs GPU)
74    ///
75    /// Returns a DivergenceReport identifying the first divergent kernel.
76    /// Matching is done by (kernel_name, layer_idx, position) triple.
77    pub fn compare(&self, reference: &BrickProfiler) -> DivergenceReport {
78        // Build index from reference traces
79        let ref_index: std::collections::HashMap<(&str, usize, u32), &KernelTrace> = reference
80            .traces
81            .iter()
82            .map(|t| ((t.kernel_name.as_str(), t.layer_idx, t.position), t))
83            .collect();
84
85        let mut kernels_compared = 0;
86
87        for actual_trace in &self.traces {
88            let key = (
89                actual_trace.kernel_name.as_str(),
90                actual_trace.layer_idx,
91                actual_trace.position,
92            );
93
94            if let Some(expected_trace) = ref_index.get(&key) {
95                kernels_compared += 1;
96
97                // Compare output checksums
98                if actual_trace.output_checksum != expected_trace.output_checksum {
99                    return DivergenceReport::diverged(
100                        (*expected_trace).clone(),
101                        actual_trace.clone(),
102                        kernels_compared,
103                    );
104                }
105            }
106        }
107
108        DivergenceReport::matched(kernels_compared)
109    }
110
111    /// Compare and set internal divergence state
112    pub fn compare_and_mark(&mut self, reference: &BrickProfiler) -> DivergenceReport {
113        let report = self.compare(reference);
114        self.diverged = !report.matched;
115        self.divergence_diagnosis = report.diagnosis.clone();
116        report
117    }
118
119    /// Get traces for a specific kernel name
120    pub fn traces_for_kernel(&self, kernel_name: &str) -> Vec<&KernelTrace> {
121        self.traces
122            .iter()
123            .filter(|t| t.kernel_name == kernel_name)
124            .collect()
125    }
126
127    /// Get traces for a specific layer
128    pub fn traces_for_layer(&self, layer_idx: usize) -> Vec<&KernelTrace> {
129        self.traces
130            .iter()
131            .filter(|t| t.layer_idx == layer_idx)
132            .collect()
133    }
134
135    /// Clear all traces (for reuse)
136    pub fn clear(&mut self) {
137        self.traces.clear();
138        self.total_time_us = 0.0;
139        self.diverged = false;
140        self.divergence_diagnosis.clear();
141    }
142
143    /// Serialize to JSON for pmat brick-score consumption
144    pub fn to_json(&self) -> Result<String, serde_json::Error> {
145        serde_json::to_string_pretty(self)
146    }
147
148    /// Deserialize from JSON
149    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
150        serde_json::from_str(json)
151    }
152}