use crate::error::TruenoError;
use super::jidoka::{JidokaError, JidokaGuard};
/// Reference (naive triple-loop) GEMM that accumulates `A * B` into `C`.
///
/// The matrices are flat slices indexed in row-major order:
/// `a[i * k + p]` is element `(i, p)` of the `m x k` matrix A,
/// `b[p * n + j]` is element `(p, j)` of the `k x n` matrix B, and
/// `c[i * n + j]` is element `(i, j)` of the `m x n` matrix C.
///
/// The product is *added* to the existing contents of `c`; callers that want
/// a plain product must zero `c` first.
///
/// # Errors
///
/// Returns `TruenoError::InvalidInput` when a slice length does not equal the
/// element count implied by the dimensions, or when that element count does
/// not fit in `usize`. `c` is left untouched in either case.
pub fn gemm_reference(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    // One shared validator for all three operands. `checked_mul` matters: a
    // plain `rows * cols` panics in debug builds and silently wraps in release
    // builds, which would let inconsistent dimensions pass the length check.
    let check_len =
        |name: &str, rows: usize, cols: usize, actual: usize| -> Result<(), TruenoError> {
            match rows.checked_mul(cols) {
                Some(expected) if expected == actual => Ok(()),
                Some(expected) => Err(TruenoError::InvalidInput(format!(
                    "{} size mismatch: expected {}x{}={}, got {}",
                    name, rows, cols, expected, actual
                ))),
                None => Err(TruenoError::InvalidInput(format!(
                    "{} size overflow: {}x{} exceeds usize::MAX",
                    name, rows, cols
                ))),
            }
        };

    check_len("A", m, k, a.len())?;
    check_len("B", k, n, b.len())?;
    check_len("C", m, n, c.len())?;

    // After validation every product below is bounded by m*k, k*n or m*n,
    // so none of the index arithmetic can overflow or go out of bounds.
    for i in 0..m {
        let a_row = &a[i * k..(i + 1) * k];
        let c_row = &mut c[i * n..(i + 1) * n];
        for (j, out) in c_row.iter_mut().enumerate() {
            // Accumulate strictly left-to-right from +0.0 so the floating-point
            // result matches a plain sequential dot product.
            let mut dot = 0.0f32;
            for (p, &a_ip) in a_row.iter().enumerate() {
                dot += a_ip * b[p * n + j];
            }
            *out += dot;
        }
    }
    Ok(())
}
/// Sampled sanity check for a single output value.
///
/// The value is inspected only when `idx` is a multiple of `sample_rate`;
/// every other index returns `Ok(())` without looking at `val`.
///
/// # Errors
///
/// For a sampled index, returns `JidokaError::NaNDetected` if `val` is NaN and
/// `JidokaError::InfDetected` if it is infinite, both with location `"output"`.
///
/// # Panics
///
/// Panics if `sample_rate` is zero (remainder by zero).
#[inline(always)]
pub(super) fn jidoka_check_output(
    val: f32,
    idx: usize,
    sample_rate: usize,
) -> Result<(), JidokaError> {
    // Non-sampled indices are accepted unconditionally.
    if idx % sample_rate != 0 {
        return Ok(());
    }

    if val.is_nan() {
        Err(JidokaError::NaNDetected { location: "output" })
    } else if val.is_infinite() {
        Err(JidokaError::InfDetected { location: "output" })
    } else {
        Ok(())
    }
}
/// Runs the guard's input check over a sample of both input matrices.
///
/// Every element whose index is a multiple of `guard.sample_rate` is passed to
/// `guard.check_input`, labelled `"matrix A"` or `"matrix B"`. All of `a` is
/// scanned before `b`.
///
/// # Errors
///
/// Propagates the first error returned by `guard.check_input`; later elements
/// are not examined.
///
/// # Panics
///
/// Panics if `guard.sample_rate` is zero and a scanned slice is non-empty
/// (remainder by zero).
pub(super) fn jidoka_check_inputs(
    a: &[f32],
    b: &[f32],
    guard: &JidokaGuard,
) -> Result<(), JidokaError> {
    // Shared scan so both operands go through exactly the same sampling rule.
    let scan = |values: &[f32], label: &'static str| -> Result<(), JidokaError> {
        for (pos, &value) in values.iter().enumerate() {
            if pos % guard.sample_rate == 0 {
                guard.check_input(value, label)?;
            }
        }
        Ok(())
    };

    scan(a, "matrix A")?;
    scan(b, "matrix B")?;
    Ok(())
}
/// Reference GEMM (`C += A * B`) with sampled Jidoka checks on inputs and outputs.
///
/// Indexing is row-major over flat slices: `a[i * k + p]`, `b[p * n + j]`,
/// `c[i * n + j]`. The inputs are checked first via `jidoka_check_inputs`.
/// Each accumulated output is then passed to `jidoka_check_output` together
/// with its flat index and `guard.sample_rate` *before* it is stored.
///
/// # Errors
///
/// Returns the first `JidokaError` reported by the input or output checks.
/// When an output check fails, elements of `c` computed earlier have already
/// been written; the failing element and all later ones keep their old values.
///
/// # Panics
///
/// This function does not validate slice lengths: it panics on an
/// out-of-bounds index if `a`, `b` or `c` is shorter than the dimensions
/// require. It also panics if `guard.sample_rate` is zero and any sampling
/// check runs.
pub fn gemm_reference_with_jidoka(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    guard: &JidokaGuard,
) -> Result<(), JidokaError> {
    jidoka_check_inputs(a, b, guard)?;

    for row in 0..m {
        for col in 0..n {
            // Sequential left-to-right dot product of A's row with B's column.
            let dot = (0..k).fold(0.0f32, |acc, inner| {
                acc + a[row * k + inner] * b[inner * n + col]
            });

            let flat = row * n + col;
            let updated = c[flat] + dot;
            // Validate before storing so a rejected value never reaches `c`.
            jidoka_check_output(updated, flat, guard.sample_rate)?;
            c[flat] = updated;
        }
    }
    Ok(())
}