Skip to main content

cbtop/brick/
types.rs

1//! Core Brick types: assertions, kernel traces, divergence reports,
2//! performance budgets, and verification results.
3
4use std::any::Any;
5use std::time::{Duration, Instant};
6
/// Falsifiable assertion types.
///
/// Each variant carries only the data describing one claim about a brick.
#[derive(Debug, Clone)]
pub enum BrickAssertion {
    /// Minimum width requirement
    MinWidth(u16),
    /// Minimum height requirement
    MinHeight(u16),
    /// Maximum width requirement
    MaxWidth(u16),
    /// Maximum height requirement
    MaxHeight(u16),
    /// Maximum render time in milliseconds
    MaxRenderTimeMs(u32),
    /// Maximum latency in milliseconds
    MaxLatencyMs(u32),
    /// Value must be in range [min, max]
    ValueInRange { min: f64, max: f64 },
    /// Data must not be empty
    DataNonEmpty,
    /// Custom assertion identified by a name and a description.
    ///
    /// The variant stores no validator; it only records the two strings.
    Custom {
        /// Name returned by [`BrickAssertion::name`]
        name: &'static str,
        /// Free-form description (empty when built via [`BrickAssertion::custom`])
        description: &'static str,
    },
    /// CORRECTNESS-011: Checksum must match between backends (CPU vs GPU)
    /// Five-Whys: Hours of manual debugging → No automated divergence detection
    ChecksumMatch {
        /// Expected checksum from reference backend (e.g., CPU Scalar)
        expected: u64,
        /// Actual checksum from test backend (e.g., CUDA)
        actual: u64,
        /// Kernel name where divergence occurred
        kernel_name: String,
        /// Position/layer where divergence occurred
        position: u32,
    },
}

impl BrickAssertion {
    /// Get the assertion name for reporting.
    ///
    /// Built-in variants map to a fixed snake_case identifier; `Custom`
    /// returns the name it was created with. Every possible result is a
    /// `'static` string, so the returned name may be kept after the assertion
    /// itself has been dropped.
    pub fn name(&self) -> &'static str {
        match self {
            Self::MinWidth(_) => "min_width",
            Self::MinHeight(_) => "min_height",
            Self::MaxWidth(_) => "max_width",
            Self::MaxHeight(_) => "max_height",
            Self::MaxRenderTimeMs(_) => "max_render_time_ms",
            Self::MaxLatencyMs(_) => "max_latency_ms",
            Self::ValueInRange { .. } => "value_in_range",
            Self::DataNonEmpty => "data_non_empty",
            // Copy the stored `&'static str` out of the borrowed field so the
            // result is not tied to the lifetime of `self`.
            Self::Custom { name, .. } => *name,
            Self::ChecksumMatch { .. } => "checksum_match",
        }
    }

    /// Create a custom assertion with the given name.
    ///
    /// The `_validator` argument is accepted for API compatibility only: it
    /// is never invoked and is not stored. The resulting assertion has an
    /// empty description.
    pub fn custom<F>(name: &'static str, _validator: F) -> Self
    where
        F: Fn(&dyn Any) -> bool,
    {
        Self::Custom {
            name,
            description: "",
        }
    }

    /// Create max latency assertion (milliseconds)
    pub const fn max_latency_ms(ms: u32) -> Self {
        Self::MaxLatencyMs(ms)
    }

    /// CORRECTNESS-011: Create checksum match assertion.
    ///
    /// Stores both checksums as given, together with an owned copy of
    /// `kernel_name` and the position. No comparison is performed here.
    pub fn checksum_match(expected: u64, actual: u64, kernel_name: &str, position: u32) -> Self {
        Self::ChecksumMatch {
            expected,
            actual,
            kernel_name: kernel_name.to_string(),
            position,
        }
    }
}
89
90/// CORRECTNESS-011: Per-kernel trace for divergence detection
91/// Captures input/output checksums for every kernel launch
92#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct KernelTrace {
94    /// Kernel name (e.g., "rope_neox_indirect_12_128")
95    pub kernel_name: String,
96    /// Layer index (0-27 for transformer layers)
97    pub layer_idx: usize,
98    /// Position in sequence (for RoPE, attention)
99    pub position: u32,
100    /// Input checksum (FNV-1a of first 64 floats)
101    pub input_checksum: u64,
102    /// Output checksum (FNV-1a of first 64 floats)
103    pub output_checksum: u64,
104    /// Kernel parameters (JSON serialized for flexibility)
105    pub params: String,
106    /// Execution time in microseconds
107    pub time_us: f64,
108    /// Backend that executed this kernel (e.g., "CPU", "CUDA", "Vulkan")
109    pub backend: String,
110}
111
112impl KernelTrace {
113    /// Create a new kernel trace
114    pub fn new(kernel_name: &str, layer_idx: usize, position: u32, backend: &str) -> Self {
115        Self {
116            kernel_name: kernel_name.to_string(),
117            layer_idx,
118            position,
119            input_checksum: 0,
120            output_checksum: 0,
121            params: String::new(),
122            time_us: 0.0,
123            backend: backend.to_string(),
124        }
125    }
126
127    /// Set input checksum from float slice (FNV-1a hash of first 64 elements)
128    pub fn with_input_checksum(mut self, data: &[f32]) -> Self {
129        self.input_checksum = fnv1a_f32(data);
130        self
131    }
132
133    /// Set output checksum from float slice
134    pub fn with_output_checksum(mut self, data: &[f32]) -> Self {
135        self.output_checksum = fnv1a_f32(data);
136        self
137    }
138
139    /// Set kernel parameters as JSON
140    pub fn with_params(mut self, params: &str) -> Self {
141        self.params = params.to_string();
142        self
143    }
144
145    /// Set execution time
146    pub fn with_time_us(mut self, time_us: f64) -> Self {
147        self.time_us = time_us;
148        self
149    }
150}
151
152/// CORRECTNESS-011: Divergence report identifying first mismatch
153#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
154pub struct DivergenceReport {
155    /// Whether CPU and GPU outputs matched
156    pub matched: bool,
157    /// First kernel that diverged (None if all matched)
158    pub first_divergent_kernel: Option<KernelTrace>,
159    /// Expected (CPU) trace at divergence point
160    pub expected_trace: Option<KernelTrace>,
161    /// Actual (GPU) trace at divergence point
162    pub actual_trace: Option<KernelTrace>,
163    /// Number of kernels compared before finding divergence
164    pub kernels_compared: usize,
165    /// Human-readable diagnosis
166    pub diagnosis: String,
167}
168
169impl DivergenceReport {
170    /// Create a report indicating no divergence
171    pub fn matched(kernels_compared: usize) -> Self {
172        Self {
173            matched: true,
174            first_divergent_kernel: None,
175            expected_trace: None,
176            actual_trace: None,
177            kernels_compared,
178            diagnosis: format!(
179                "All {} kernels matched between CPU and GPU",
180                kernels_compared
181            ),
182        }
183    }
184
185    /// Create a report indicating divergence at specific kernel
186    pub fn diverged(expected: KernelTrace, actual: KernelTrace, kernels_compared: usize) -> Self {
187        let diagnosis = format!(
188            "DIVERGENCE at kernel '{}' (layer {}, position {}): \
189             CPU checksum 0x{:016X} != GPU checksum 0x{:016X}. \
190             Params: {}",
191            actual.kernel_name,
192            actual.layer_idx,
193            actual.position,
194            expected.output_checksum,
195            actual.output_checksum,
196            actual.params,
197        );
198        Self {
199            matched: false,
200            first_divergent_kernel: Some(actual.clone()),
201            expected_trace: Some(expected),
202            actual_trace: Some(actual),
203            kernels_compared,
204            diagnosis,
205        }
206    }
207}
208
/// FNV-1a (64-bit) hash over the leading elements of an `f32` slice.
///
/// At most the first 64 elements are hashed; anything beyond that is ignored.
/// Each element contributes its four little-endian bytes, in order. An empty
/// slice yields the FNV offset basis unchanged.
///
/// Public for use in divergence detection across crates.
pub fn fnv1a_f32(data: &[f32]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;
    const MAX_ELEMENTS: usize = 64;

    data.iter()
        .take(MAX_ELEMENTS)
        .flat_map(|value| value.to_le_bytes())
        // FNV-1a step: xor the byte in first, then multiply by the prime.
        .fold(OFFSET_BASIS, |state, byte| {
            (state ^ u64::from(byte)).wrapping_mul(PRIME)
        })
}
226
/// Performance budget per phase (Muda elimination)
///
/// Reference: Ohno, T. (1988). "Toyota Production System"
///
/// The derived `Default` is an all-zero budget.
#[derive(Debug, Clone, Copy, Default)]
pub struct BrickBudget {
    /// Collection phase budget (ms)
    pub collect_ms: u32,
    /// Layout calculation budget (ms)
    pub layout_ms: u32,
    /// Rendering phase budget (ms)
    pub render_ms: u32,
}

impl BrickBudget {
    /// Create uniform budget (same `ms` value for each of the three phases)
    pub const fn uniform(ms: u32) -> Self {
        Self {
            collect_ms: ms,
            layout_ms: ms,
            render_ms: ms,
        }
    }

    /// 60fps budget: 16ms total
    pub const FRAME_60FPS: Self = Self {
        collect_ms: 5,
        layout_ms: 3,
        render_ms: 8,
    };

    /// 30fps budget: 33ms total
    pub const FRAME_30FPS: Self = Self {
        collect_ms: 10,
        layout_ms: 6,
        render_ms: 17,
    };

    /// Total budget in milliseconds (sum of the three phases).
    ///
    /// The sum saturates at `u32::MAX` instead of overflowing: the fields are
    /// public and unconstrained, so a plain `+` would panic in debug builds
    /// and wrap to a misleadingly small total in release builds.
    pub const fn total_ms(&self) -> u32 {
        self.collect_ms
            .saturating_add(self.layout_ms)
            .saturating_add(self.render_ms)
    }
}
269
270/// Verification result with pass/fail tracking
271#[derive(Debug, Clone)]
272pub struct BrickVerification {
273    /// Passed assertions
274    pub passed: Vec<BrickAssertion>,
275    /// Failed assertions with reason
276    pub failed: Vec<(BrickAssertion, String)>,
277    /// Time taken to verify
278    pub verification_time: Duration,
279    /// Timestamp
280    pub timestamp: Instant,
281}
282
283impl BrickVerification {
284    /// Create new verification result
285    pub fn new() -> Self {
286        Self {
287            passed: Vec::new(),
288            failed: Vec::new(),
289            verification_time: Duration::ZERO,
290            timestamp: Instant::now(),
291        }
292    }
293
294    /// Create a passing verification
295    pub fn pass() -> Self {
296        Self::new()
297    }
298
299    /// Add a passed assertion
300    pub fn add_pass(&mut self, assertion: BrickAssertion) {
301        self.passed.push(assertion);
302    }
303
304    /// Add a failed assertion with reason
305    pub fn add_fail(&mut self, assertion: BrickAssertion, reason: impl Into<String>) {
306        self.failed.push((assertion, reason.into()));
307    }
308
309    /// Check an assertion and add to passed list (simplified version)
310    pub fn check(&mut self, assertion: &BrickAssertion) {
311        // For now, assume assertions pass (real implementation would validate)
312        self.passed.push(assertion.clone());
313    }
314
315    /// Is verification successful? (Jidoka gate)
316    pub fn is_valid(&self) -> bool {
317        self.failed.is_empty()
318    }
319
320    /// Falsification score: passed / total
321    pub fn score(&self) -> f64 {
322        let total = self.passed.len() + self.failed.len();
323        if total == 0 {
324            1.0
325        } else {
326            self.passed.len() as f64 / total as f64
327        }
328    }
329
330    /// Get failure count
331    pub fn failure_count(&self) -> usize {
332        self.failed.len()
333    }
334}
335
336impl Default for BrickVerification {
337    fn default() -> Self {
338        Self::new()
339    }
340}