Skip to main content

trueno/blis/
reference.rs

1//! Scalar reference GEMM implementations for Jidoka validation.
2//!
3//! These are the "gold standard" implementations used to validate optimized versions.
4//!
5//! # References
6//!
7//! - Golub & Van Loan (2013), Matrix Computations, 4th ed., Algorithm 1.1.1.
8
9use crate::error::TruenoError;
10
11use super::jidoka::{JidokaError, JidokaGuard};
12
13/// Scalar reference GEMM for Jidoka validation
14///
15/// Computes C += A * B where:
16/// - A is M x K (row-major)
17/// - B is K x N (row-major)
18/// - C is M x N (row-major)
19///
20/// This is the "gold standard" implementation used to validate optimized versions.
21///
22/// # References
23///
24/// This implements the naive O(MNK) algorithm as described in
25/// Golub & Van Loan (2013), Matrix Computations, 4th ed., Algorithm 1.1.1.
26pub fn gemm_reference(
27    m: usize,
28    n: usize,
29    k: usize,
30    a: &[f32],
31    b: &[f32],
32    c: &mut [f32],
33) -> Result<(), TruenoError> {
34    // Poka-yoke: dimension validation
35    if a.len() != m * k {
36        return Err(TruenoError::InvalidInput(format!(
37            "A size mismatch: expected {}x{}={}, got {}",
38            m,
39            k,
40            m * k,
41            a.len()
42        )));
43    }
44    if b.len() != k * n {
45        return Err(TruenoError::InvalidInput(format!(
46            "B size mismatch: expected {}x{}={}, got {}",
47            k,
48            n,
49            k * n,
50            b.len()
51        )));
52    }
53    if c.len() != m * n {
54        return Err(TruenoError::InvalidInput(format!(
55            "C size mismatch: expected {}x{}={}, got {}",
56            m,
57            n,
58            m * n,
59            c.len()
60        )));
61    }
62
63    // Scalar triple-nested loop
64    for i in 0..m {
65        for j in 0..n {
66            let mut sum = 0.0f32;
67            for p in 0..k {
68                sum += a[i * k + p] * b[p * n + j];
69            }
70            c[i * n + j] += sum;
71        }
72    }
73
74    Ok(())
75}
76
77/// Validate a sampled output value for NaN/Inf (Jidoka andon cord).
78#[inline(always)]
79pub(super) fn jidoka_check_output(
80    val: f32,
81    idx: usize,
82    sample_rate: usize,
83) -> Result<(), JidokaError> {
84    if idx % sample_rate == 0 {
85        if val.is_nan() {
86            return Err(JidokaError::NaNDetected { location: "output" });
87        }
88        if val.is_infinite() {
89            return Err(JidokaError::InfDetected { location: "output" });
90        }
91    }
92    Ok(())
93}
94
95/// Validate sampled input values for NaN/Inf.
96pub(super) fn jidoka_check_inputs(
97    a: &[f32],
98    b: &[f32],
99    guard: &JidokaGuard,
100) -> Result<(), JidokaError> {
101    for (idx, &val) in a.iter().enumerate() {
102        if idx % guard.sample_rate == 0 {
103            guard.check_input(val, "matrix A")?;
104        }
105    }
106    for (idx, &val) in b.iter().enumerate() {
107        if idx % guard.sample_rate == 0 {
108            guard.check_input(val, "matrix B")?;
109        }
110    }
111    Ok(())
112}
113
114/// Scalar reference GEMM with Jidoka validation
115///
116/// Same as `gemm_reference` but validates outputs against known-good computation.
117pub fn gemm_reference_with_jidoka(
118    m: usize,
119    n: usize,
120    k: usize,
121    a: &[f32],
122    b: &[f32],
123    c: &mut [f32],
124    guard: &JidokaGuard,
125) -> Result<(), JidokaError> {
126    jidoka_check_inputs(a, b, guard)?;
127
128    for i in 0..m {
129        for j in 0..n {
130            let mut sum = 0.0f32;
131            for p in 0..k {
132                sum += a[i * k + p] * b[p * n + j];
133            }
134            let output = c[i * n + j] + sum;
135            jidoka_check_output(output, i * n + j, guard.sample_rate)?;
136            c[i * n + j] = output;
137        }
138    }
139
140    Ok(())
141}