Skip to main content

trueno/brick/profiler/
divergence.rs

//! Checksum recording and divergence detection for BrickProfiler.
//!
//! CORRECTNESS-011: Per-kernel checksum capture for CPU/GPU divergence detection.
//! Extracted from mod.rs to keep file sizes manageable.
5
6use super::BrickProfiler;
7use crate::brick::profiler::checksum::{fnv1a_f32_checksum, DivergenceInfo, KernelChecksum};
8
9impl BrickProfiler {
10    // =======================================================================
11    // CORRECTNESS-011: Per-kernel checksum capture for divergence detection
12    // =======================================================================
13
14    /// Record a kernel trace with output checksum for divergence detection.
15    ///
16    /// This enables automated CPU/GPU divergence detection by capturing
17    /// output checksums alongside timing data. When GPU produces wrong output,
18    /// this identifies WHICH kernel diverged without hours of manual debugging.
19    ///
20    /// Five-Whys Root Cause: Hours of manual "let me check X in Y" debugging
21    /// -> No automated tool identified which kernel diverged
22    /// -> BrickProfiler only captured timing, not checksums
23    /// -> Missing feature: per-kernel checksum capture
24    ///
25    /// # Arguments
26    /// - `name`: Brick/kernel name
27    /// - `layer_idx`: Layer index (0-N for transformer layers)
28    /// - `position`: Position in sequence
29    /// - `output`: Output tensor data (first 64 floats checksummed)
30    ///
31    /// # Example
32    /// ```rust,ignore
33    /// // After RoPE kernel
34    /// profiler.record_checksum("RopeNeox", layer_idx, position, &q_rotated);
35    /// ```
36    pub fn record_checksum(&mut self, name: &str, layer_idx: usize, position: u32, output: &[f32]) {
37        if !self.enabled {
38            return;
39        }
40        let checksum = fnv1a_f32_checksum(output);
41        let trace = KernelChecksum { name: name.to_string(), layer_idx, position, checksum };
42        self.kernel_checksums.push(trace);
43    }
44
45    /// Get all kernel checksums for divergence comparison.
46    #[must_use]
47    pub fn get_checksums(&self) -> &[KernelChecksum] {
48        &self.kernel_checksums
49    }
50
51    /// Compare checksums with a reference profiler (e.g., CPU baseline).
52    ///
53    /// Returns None if all checksums match, or the first divergent kernel.
54    #[must_use]
55    pub fn find_divergence(&self, reference: &BrickProfiler) -> Option<DivergenceInfo> {
56        use std::collections::HashMap;
57
58        // Index reference checksums by (name, layer_idx, position)
59        let ref_index: HashMap<(&str, usize, u32), u64> = reference
60            .kernel_checksums
61            .iter()
62            .map(|t| ((t.name.as_str(), t.layer_idx, t.position), t.checksum))
63            .collect();
64
65        // Check each of our checksums against reference
66        for trace in &self.kernel_checksums {
67            let key = (trace.name.as_str(), trace.layer_idx, trace.position);
68            if let Some(&expected) = ref_index.get(&key) {
69                if trace.checksum != expected {
70                    return Some(DivergenceInfo {
71                        kernel_name: trace.name.clone(),
72                        layer_idx: trace.layer_idx,
73                        position: trace.position,
74                        expected_checksum: expected,
75                        actual_checksum: trace.checksum,
76                    });
77                }
78            }
79        }
80        None
81    }
82
83    /// Reset checksum tracking (call before new forward pass).
84    pub fn reset_checksums(&mut self) {
85        self.kernel_checksums.clear();
86    }
87}