trueno/brick/profiler/divergence.rs
1//! Checksum recording and divergence detection for BrickProfiler.
2//!
3//! CORRECTNESS-011: Per-kernel checksum capture for CPU/GPU divergence detection.
4//! Extracted from mod.rs to keep file sizes manageable.
5
6use super::BrickProfiler;
7use crate::brick::profiler::checksum::{fnv1a_f32_checksum, DivergenceInfo, KernelChecksum};
8
9impl BrickProfiler {
10 // =======================================================================
11 // CORRECTNESS-011: Per-kernel checksum capture for divergence detection
12 // =======================================================================
13
14 /// Record a kernel trace with output checksum for divergence detection.
15 ///
16 /// This enables automated CPU/GPU divergence detection by capturing
17 /// output checksums alongside timing data. When GPU produces wrong output,
18 /// this identifies WHICH kernel diverged without hours of manual debugging.
19 ///
20 /// Five-Whys Root Cause: Hours of manual "let me check X in Y" debugging
21 /// -> No automated tool identified which kernel diverged
22 /// -> BrickProfiler only captured timing, not checksums
23 /// -> Missing feature: per-kernel checksum capture
24 ///
25 /// # Arguments
26 /// - `name`: Brick/kernel name
27 /// - `layer_idx`: Layer index (0-N for transformer layers)
28 /// - `position`: Position in sequence
29 /// - `output`: Output tensor data (first 64 floats checksummed)
30 ///
31 /// # Example
32 /// ```rust,ignore
33 /// // After RoPE kernel
34 /// profiler.record_checksum("RopeNeox", layer_idx, position, &q_rotated);
35 /// ```
36 pub fn record_checksum(&mut self, name: &str, layer_idx: usize, position: u32, output: &[f32]) {
37 if !self.enabled {
38 return;
39 }
40 let checksum = fnv1a_f32_checksum(output);
41 let trace = KernelChecksum { name: name.to_string(), layer_idx, position, checksum };
42 self.kernel_checksums.push(trace);
43 }
44
45 /// Get all kernel checksums for divergence comparison.
46 #[must_use]
47 pub fn get_checksums(&self) -> &[KernelChecksum] {
48 &self.kernel_checksums
49 }
50
51 /// Compare checksums with a reference profiler (e.g., CPU baseline).
52 ///
53 /// Returns None if all checksums match, or the first divergent kernel.
54 #[must_use]
55 pub fn find_divergence(&self, reference: &BrickProfiler) -> Option<DivergenceInfo> {
56 use std::collections::HashMap;
57
58 // Index reference checksums by (name, layer_idx, position)
59 let ref_index: HashMap<(&str, usize, u32), u64> = reference
60 .kernel_checksums
61 .iter()
62 .map(|t| ((t.name.as_str(), t.layer_idx, t.position), t.checksum))
63 .collect();
64
65 // Check each of our checksums against reference
66 for trace in &self.kernel_checksums {
67 let key = (trace.name.as_str(), trace.layer_idx, trace.position);
68 if let Some(&expected) = ref_index.get(&key) {
69 if trace.checksum != expected {
70 return Some(DivergenceInfo {
71 kernel_name: trace.name.clone(),
72 layer_idx: trace.layer_idx,
73 position: trace.position,
74 expected_checksum: expected,
75 actual_checksum: trace.checksum,
76 });
77 }
78 }
79 }
80 None
81 }
82
83 /// Reset checksum tracking (call before new forward pass).
84 pub fn reset_checksums(&mut self) {
85 self.kernel_checksums.clear();
86 }
87}