1use crate::error::TruenoError;
10
11use super::jidoka::{JidokaError, JidokaGuard};
12
13pub fn gemm_reference(
27 m: usize,
28 n: usize,
29 k: usize,
30 a: &[f32],
31 b: &[f32],
32 c: &mut [f32],
33) -> Result<(), TruenoError> {
34 if a.len() != m * k {
36 return Err(TruenoError::InvalidInput(format!(
37 "A size mismatch: expected {}x{}={}, got {}",
38 m,
39 k,
40 m * k,
41 a.len()
42 )));
43 }
44 if b.len() != k * n {
45 return Err(TruenoError::InvalidInput(format!(
46 "B size mismatch: expected {}x{}={}, got {}",
47 k,
48 n,
49 k * n,
50 b.len()
51 )));
52 }
53 if c.len() != m * n {
54 return Err(TruenoError::InvalidInput(format!(
55 "C size mismatch: expected {}x{}={}, got {}",
56 m,
57 n,
58 m * n,
59 c.len()
60 )));
61 }
62
63 for i in 0..m {
65 for j in 0..n {
66 let mut sum = 0.0f32;
67 for p in 0..k {
68 sum += a[i * k + p] * b[p * n + j];
69 }
70 c[i * n + j] += sum;
71 }
72 }
73
74 Ok(())
75}
76
77#[inline(always)]
79pub(super) fn jidoka_check_output(
80 val: f32,
81 idx: usize,
82 sample_rate: usize,
83) -> Result<(), JidokaError> {
84 if idx % sample_rate == 0 {
85 if val.is_nan() {
86 return Err(JidokaError::NaNDetected { location: "output" });
87 }
88 if val.is_infinite() {
89 return Err(JidokaError::InfDetected { location: "output" });
90 }
91 }
92 Ok(())
93}
94
95pub(super) fn jidoka_check_inputs(
97 a: &[f32],
98 b: &[f32],
99 guard: &JidokaGuard,
100) -> Result<(), JidokaError> {
101 for (idx, &val) in a.iter().enumerate() {
102 if idx % guard.sample_rate == 0 {
103 guard.check_input(val, "matrix A")?;
104 }
105 }
106 for (idx, &val) in b.iter().enumerate() {
107 if idx % guard.sample_rate == 0 {
108 guard.check_input(val, "matrix B")?;
109 }
110 }
111 Ok(())
112}
113
114pub fn gemm_reference_with_jidoka(
118 m: usize,
119 n: usize,
120 k: usize,
121 a: &[f32],
122 b: &[f32],
123 c: &mut [f32],
124 guard: &JidokaGuard,
125) -> Result<(), JidokaError> {
126 jidoka_check_inputs(a, b, guard)?;
127
128 for i in 0..m {
129 for j in 0..n {
130 let mut sum = 0.0f32;
131 for p in 0..k {
132 sum += a[i * k + p] * b[p * n + j];
133 }
134 let output = c[i * n + j] + sum;
135 jidoka_check_output(output, i * n + j, guard.sample_rate)?;
136 c[i * n + j] = output;
137 }
138 }
139
140 Ok(())
141}