/// Verdict of the AC-MM-001 output-shape check (`verdict_from_output_shape`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mm001Verdict {
    /// Every dimension was non-zero and the observed shape was `m x n`.
    Pass,
    /// A dimension was zero or the observed shape differed from `m x n`.
    Fail,
}
/// AC-MM-001: checks the shape of a matrix-product result.
///
/// `m`, `p` and `n` are the product dimensions; the result is expected to be
/// `observed_rows == m` by `observed_cols == n`. Any zero dimension is a
/// failure regardless of the observed shape.
#[must_use]
pub const fn verdict_from_output_shape(
    m: u64,
    p: u64,
    n: u64,
    observed_rows: u64,
    observed_cols: u64,
) -> Mm001Verdict {
    let any_dim_zero = m == 0 || p == 0 || n == 0;
    let shape_is_m_by_n = observed_rows == m && observed_cols == n;
    if !any_dim_zero && shape_is_m_by_n {
        Mm001Verdict::Pass
    } else {
        Mm001Verdict::Fail
    }
}
/// AC-MM-002: largest absolute per-element difference between an observed
/// matmul output and the reference that still counts as a pass.
pub const AC_MM_002_TOLERANCE: f32 = 1e-4;
/// Verdict of the AC-MM-002 numerical-accuracy check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mm002Verdict {
    Pass,
    Fail,
}
/// Naive reference matrix product `C = A · B` over row-major `f32` slices.
///
/// `a` must hold `m * p` elements (an `m x p` matrix) and `b` must hold
/// `p * n` elements (a `p x n` matrix). The result holds `m * n` elements.
/// Each output element is accumulated in index order `k = 0..p` starting from
/// `0.0`, so results are deterministic.
///
/// Returns an empty `Vec` when a slice length does not match its dimensions,
/// or when any of `m * p`, `p * n`, `m * n` overflows `usize`. Note that an
/// empty `Vec` is also the correct result when `m == 0` or `n == 0`.
#[must_use]
pub fn matmul_reference(a: &[f32], b: &[f32], m: usize, p: usize, n: usize) -> Vec<f32> {
    // Checked products: a wrapped length could otherwise match a short slice
    // (release builds) or panic on overflow (debug builds).
    let (a_len, b_len, c_len) = match (m.checked_mul(p), p.checked_mul(n), m.checked_mul(n)) {
        (Some(a_len), Some(b_len), Some(c_len)) => (a_len, b_len, c_len),
        _ => return vec![],
    };
    if a.len() != a_len || b.len() != b_len {
        return vec![];
    }
    let mut c = vec![0.0_f32; c_len];
    for i in 0..m {
        // Both ranges are within the validated lengths, so these cannot overflow.
        let a_row = &a[i * p..(i + 1) * p];
        let c_row = &mut c[i * n..(i + 1) * n];
        for (j, out) in c_row.iter_mut().enumerate() {
            // Column j of B is every n-th element starting at index j.
            *out = a_row
                .iter()
                .enumerate()
                .fold(0.0_f32, |acc, (k, &a_ik)| acc + a_ik * b[k * n + j]);
        }
    }
    c
}
/// AC-MM-002: compares an observed matmul output with a reference, element by
/// element.
///
/// Fails when either slice is empty, the lengths differ, any element on either
/// side is NaN or infinite, or any pair differs by more than
/// `AC_MM_002_TOLERANCE` in absolute value.
#[must_use]
pub fn verdict_from_numerical_accuracy(observed: &[f32], reference: &[f32]) -> Mm002Verdict {
    let same_nonempty_len = !observed.is_empty() && observed.len() == reference.len();
    if !same_nonempty_len {
        return Mm002Verdict::Fail;
    }
    // Both operands are checked finite first, so the difference is never NaN
    // and `<=` is the exact complement of the `>` rejection test.
    let close_enough = |(&got, &want): (&f32, &f32)| {
        got.is_finite() && want.is_finite() && (got - want).abs() <= AC_MM_002_TOLERANCE
    };
    if observed.iter().zip(reference).all(close_enough) {
        Mm002Verdict::Pass
    } else {
        Mm002Verdict::Fail
    }
}
/// AC-MM-003: largest ULP distance between a scalar result and the matching
/// SIMD result that still counts as a pass.
pub const AC_MM_003_ULP_TOLERANCE: u32 = 4;
/// Verdict of the AC-MM-003 scalar/SIMD parity check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mm003Verdict {
    Pass,
    Fail,
}
/// Distance between two finite `f32` values in units in the last place
/// (ULPs): the number of representable-float steps from one to the other.
///
/// Returns `0` when `a == b` (so `+0.0` and `-0.0` are 0 ULPs apart) and
/// `u32::MAX` when either input is NaN or infinite.
#[must_use]
pub fn ulp_distance(a: f32, b: f32) -> u32 {
    // Maps an IEEE-754 sign-magnitude bit pattern onto a monotonic integer
    // line: positives keep their magnitude bits, negatives become the negated
    // magnitude. Both zeros land on 0 and neighbouring floats differ by 1.
    fn ordered_key(x: f32) -> i64 {
        let bits = x.to_bits();
        let magnitude = i64::from(bits & 0x7FFF_FFFF);
        if bits >> 31 == 0 {
            magnitude
        } else {
            -magnitude
        }
    }
    if !a.is_finite() || !b.is_finite() {
        return u32::MAX;
    }
    if a == b {
        return 0;
    }
    // i64 arithmetic: the widest finite span (-f32::MAX to f32::MAX) is
    // 2 * 0x7F7F_FFFF, which overflows i32 but fits in u32.
    let distance = (ordered_key(a) - ordered_key(b)).unsigned_abs();
    distance.min(u64::from(u32::MAX)) as u32
}
/// AC-MM-003: checks that a SIMD result matches the scalar result to within
/// `AC_MM_003_ULP_TOLERANCE` ULPs per element.
///
/// Fails on empty input, mismatched lengths, or any NaN/infinite element on
/// either side.
#[must_use]
pub fn verdict_from_simd_parity(scalar: &[f32], simd: &[f32]) -> Mm003Verdict {
    // A non-empty `scalar` with an equal length implies a non-empty `simd`.
    if scalar.is_empty() || scalar.len() != simd.len() {
        return Mm003Verdict::Fail;
    }
    let lanes_agree = scalar.iter().zip(simd).all(|(&s, &v)| {
        s.is_finite() && v.is_finite() && ulp_distance(s, v) <= AC_MM_003_ULP_TOLERANCE
    });
    if lanes_agree {
        Mm003Verdict::Pass
    } else {
        Mm003Verdict::Fail
    }
}
/// Verdict of the AC-MM-004 quantized-accuracy check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mm004Verdict {
    Pass,
    Fail,
}
/// Dot product of two `i8` vectors, accumulated exactly in `i64` and rescaled
/// once at the end by `scale_a * scale_b`.
///
/// Returns `0.0` when either slice is empty or the lengths differ (checked
/// first), and NaN when either scale is NaN or infinite.
#[must_use]
pub fn quantized_dot(a: &[i8], b: &[i8], scale_a: f32, scale_b: f32) -> f32 {
    let shapes_ok = !a.is_empty() && a.len() == b.len();
    if !shapes_ok {
        return 0.0;
    }
    if !(scale_a.is_finite() && scale_b.is_finite()) {
        return f32::NAN;
    }
    let int_dot: i64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| i64::from(x) * i64::from(y))
        .sum();
    // The i64 -> f32 conversion rounds once |int_dot| exceeds 2^24.
    scale_a * scale_b * (int_dot as f32)
}
/// Dot product of two `i8` vectors computed in `f32`: every element is scaled
/// before multiplying, and the products are summed left to right from `0.0`.
///
/// Returns `0.0` when either slice is empty or the lengths differ.
#[must_use]
pub fn float_dot(a: &[i8], b: &[i8], scale_a: f32, scale_b: f32) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Explicit fold from +0.0 keeps the exact accumulation order and sign of
    // zero.
    a.iter().zip(b).fold(0.0_f32, |sum, (&x, &y)| {
        let lhs = f32::from(x) * scale_a;
        let rhs = f32::from(y) * scale_b;
        sum + lhs * rhs
    })
}
/// AC-MM-004: checks that the integer-accumulated dot product
/// (`quantized_dot`) is within `bound` of the float-accumulated one
/// (`float_dot`) for the same `i8` inputs and scales.
///
/// Fails when the slices are empty or of different lengths, when either scale
/// is non-finite or not strictly positive, when `bound` is non-finite or
/// negative, or when either dot product is non-finite.
#[must_use]
pub fn verdict_from_quantized_accuracy(
    a: &[i8],
    b: &[i8],
    scale_a: f32,
    scale_b: f32,
    bound: f32,
) -> Mm004Verdict {
    let usable_scale = |s: f32| s.is_finite() && s > 0.0;
    let inputs_valid = !a.is_empty()
        && a.len() == b.len()
        && usable_scale(scale_a)
        && usable_scale(scale_b)
        && bound.is_finite()
        && bound >= 0.0;
    if !inputs_valid {
        return Mm004Verdict::Fail;
    }
    let int_path = quantized_dot(a, b, scale_a, scale_b);
    let float_path = float_dot(a, b, scale_a, scale_b);
    let agree = int_path.is_finite()
        && float_path.is_finite()
        && (int_path - float_path).abs() <= bound;
    if agree {
        Mm004Verdict::Pass
    } else {
        Mm004Verdict::Fail
    }
}
/// AC-MM-005: largest absolute per-element difference between `A` and its
/// products with the identity that still counts as a pass.
pub const AC_MM_005_TOLERANCE: f32 = 1e-6;
/// Verdict of the AC-MM-005 identity-preservation check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mm005Verdict {
    Pass,
    Fail,
}
/// Builds the `n x n` identity matrix as a flat row-major `Vec<f32>` of length
/// `n * n` (empty when `n == 0`).
#[must_use]
pub fn make_identity(n: usize) -> Vec<f32> {
    let mut eye = vec![0.0_f32; n * n];
    // In a flat n x n layout, consecutive diagonal entries are n + 1 slots apart.
    for diag in eye.iter_mut().step_by(n + 1) {
        *diag = 1.0;
    }
    eye
}
/// AC-MM-005: checks that `A · I_n` and `I_m · A` both reproduce `A`, where `a`
/// is an `m x n` matrix stored in `m * n` elements.
///
/// Fails when `a` is empty, a dimension is zero, `a.len() != m * n`, any
/// element is NaN or infinite, or any product element differs from `a` by more
/// than `AC_MM_005_TOLERANCE`.
#[must_use]
pub fn verdict_from_identity_preservation(a: &[f32], m: usize, n: usize) -> Mm005Verdict {
    // Short-circuiting keeps `m * n` from being evaluated for zero dimensions.
    let well_formed = !a.is_empty() && m != 0 && n != 0 && a.len() == m * n;
    if !well_formed || a.iter().any(|v| !v.is_finite()) {
        return Mm005Verdict::Fail;
    }
    // A product reproduces `a` when it has the same length and no element
    // strays beyond the tolerance (same `>` rejection test as before).
    let reproduces_a = |product: &[f32]| {
        product.len() == a.len()
            && !a
                .iter()
                .zip(product)
                .any(|(&x, &y)| (x - y).abs() > AC_MM_005_TOLERANCE)
    };
    let right_product = matmul_reference(a, &make_identity(n), m, n, n);
    if !reproduces_a(&right_product) {
        return Mm005Verdict::Fail;
    }
    let left_product = matmul_reference(&make_identity(m), a, m, m, n);
    if reproduces_a(&left_product) {
        Mm005Verdict::Pass
    } else {
        Mm005Verdict::Fail
    }
}
// Unit tests for the AC-MM-001..005 verdict helpers and their numeric kernels.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mm001_pass_canonical() {
        assert_eq!(verdict_from_output_shape(3, 4, 5, 3, 5), Mm001Verdict::Pass);
    }

    #[test]
    fn mm001_fail_swapped() {
        assert_eq!(verdict_from_output_shape(3, 4, 5, 5, 3), Mm001Verdict::Fail);
    }

    #[test]
    fn mm001_fail_wrong_rows() {
        assert_eq!(verdict_from_output_shape(3, 4, 5, 4, 5), Mm001Verdict::Fail);
    }

    #[test]
    fn mm001_fail_zero() {
        assert_eq!(verdict_from_output_shape(0, 4, 5, 0, 5), Mm001Verdict::Fail);
        assert_eq!(verdict_from_output_shape(3, 0, 5, 3, 5), Mm001Verdict::Fail);
        assert_eq!(verdict_from_output_shape(3, 4, 0, 3, 0), Mm001Verdict::Fail);
    }

    #[test]
    fn mm002_pass_canonical() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0];
        let b = vec![5.0_f32, 6.0, 7.0, 8.0];
        let c_ref = matmul_reference(&a, &b, 2, 2, 2);
        assert_eq!(c_ref, vec![19.0_f32, 22.0, 43.0, 50.0]);
        let observed = vec![19.0_f32, 22.0, 43.0, 50.0];
        assert_eq!(verdict_from_numerical_accuracy(&observed, &c_ref), Mm002Verdict::Pass);
    }

    #[test]
    fn mm002_pass_within_tolerance() {
        let a = vec![1.0_f32];
        let b = vec![1.0_f32 + 1e-6];
        assert_eq!(verdict_from_numerical_accuracy(&a, &b), Mm002Verdict::Pass);
    }

    #[test]
    fn mm002_fail_above_tolerance() {
        let a = vec![1.0_f32];
        let b = vec![1.001_f32];
        assert_eq!(verdict_from_numerical_accuracy(&a, &b), Mm002Verdict::Fail);
    }

    #[test]
    fn mm002_fail_length_mismatch() {
        let a = vec![1.0_f32];
        let b = vec![1.0_f32, 2.0];
        assert_eq!(verdict_from_numerical_accuracy(&a, &b), Mm002Verdict::Fail);
    }

    #[test]
    fn mm002_fail_empty_or_non_finite() {
        let empty: [f32; 0] = [];
        assert_eq!(verdict_from_numerical_accuracy(&empty, &empty), Mm002Verdict::Fail);
        assert_eq!(verdict_from_numerical_accuracy(&[f32::NAN], &[f32::NAN]), Mm002Verdict::Fail);
        assert_eq!(verdict_from_numerical_accuracy(&[1.0], &[f32::INFINITY]), Mm002Verdict::Fail);
    }

    #[test]
    fn mm003_pass_identical() {
        let a = vec![1.0_f32, 2.0];
        assert_eq!(verdict_from_simd_parity(&a, &a), Mm003Verdict::Pass);
    }

    #[test]
    fn mm003_pass_within_4_ulp() {
        let a = vec![1.0_f32];
        let b = vec![f32::from_bits(1.0_f32.to_bits() + 3)];
        assert_eq!(verdict_from_simd_parity(&a, &b), Mm003Verdict::Pass);
    }

    #[test]
    fn mm003_pass_at_4_ulp_boundary() {
        let a = vec![1.0_f32];
        let b = vec![f32::from_bits(1.0_f32.to_bits() + 4)];
        assert_eq!(verdict_from_simd_parity(&a, &b), Mm003Verdict::Pass);
    }

    #[test]
    fn mm003_fail_above_4_ulp() {
        let a = vec![1.0_f32];
        let b = vec![f32::from_bits(1.0_f32.to_bits() + 5)];
        assert_eq!(verdict_from_simd_parity(&a, &b), Mm003Verdict::Fail);
    }

    #[test]
    fn mm003_fail_length() {
        let a = vec![1.0_f32];
        let b = vec![1.0_f32, 2.0];
        assert_eq!(verdict_from_simd_parity(&a, &b), Mm003Verdict::Fail);
    }

    #[test]
    fn mm003_fail_non_finite() {
        assert_eq!(verdict_from_simd_parity(&[f32::NAN], &[f32::NAN]), Mm003Verdict::Fail);
        assert_eq!(verdict_from_simd_parity(&[f32::INFINITY], &[f32::INFINITY]), Mm003Verdict::Fail);
    }

    #[test]
    fn ulp_distance_basic_properties() {
        let x = 1.0_f32;
        let y = f32::from_bits(x.to_bits() + 2);
        assert_eq!(ulp_distance(x, x), 0);
        assert_eq!(ulp_distance(0.0, -0.0), 0);
        assert_eq!(ulp_distance(x, y), 2);
        assert_eq!(ulp_distance(y, x), 2);
        assert_eq!(ulp_distance(-x, -y), 2);
        assert_eq!(ulp_distance(f32::NAN, x), u32::MAX);
        assert_eq!(ulp_distance(x, f32::INFINITY), u32::MAX);
    }

    #[test]
    fn mm004_pass_canonical() {
        let a = vec![1_i8, 2];
        let b = vec![3_i8, 4];
        assert_eq!(verdict_from_quantized_accuracy(&a, &b, 0.1, 0.05, 1e-3), Mm004Verdict::Pass);
    }

    #[test]
    fn mm004_pass_exact_scales_zero_bound() {
        let a = [1_i8, 2];
        let b = [3_i8, 4];
        assert_eq!(quantized_dot(&a, &b, 0.5, 0.25), 1.375);
        assert_eq!(float_dot(&a, &b, 0.5, 0.25), 1.375);
        assert_eq!(verdict_from_quantized_accuracy(&a, &b, 0.5, 0.25, 0.0), Mm004Verdict::Pass);
    }

    #[test]
    fn mm004_fail_zero_scale() {
        let a = vec![1_i8];
        let b = vec![1_i8];
        assert_eq!(verdict_from_quantized_accuracy(&a, &b, 0.0, 0.05, 1e-3), Mm004Verdict::Fail);
    }

    #[test]
    fn mm004_fail_non_finite_scale() {
        let a = vec![1_i8];
        let b = vec![1_i8];
        assert_eq!(verdict_from_quantized_accuracy(&a, &b, f32::NAN, 0.05, 1e-3), Mm004Verdict::Fail);
        assert_eq!(verdict_from_quantized_accuracy(&a, &b, 0.1, f32::INFINITY, 1e-3), Mm004Verdict::Fail);
    }

    #[test]
    fn mm004_fail_negative_bound() {
        let a = vec![1_i8];
        let b = vec![1_i8];
        assert_eq!(verdict_from_quantized_accuracy(&a, &b, 0.1, 0.05, -1e-3), Mm004Verdict::Fail);
    }

    #[test]
    fn mm004_fail_length_mismatch() {
        let a = vec![1_i8, 2];
        let b = vec![1_i8];
        assert_eq!(verdict_from_quantized_accuracy(&a, &b, 0.1, 0.05, 1e-3), Mm004Verdict::Fail);
    }

    #[test]
    fn dots_reject_mismatched_lengths() {
        assert_eq!(quantized_dot(&[1], &[1, 2], 1.0, 1.0), 0.0);
        assert_eq!(float_dot(&[1], &[1, 2], 1.0, 1.0), 0.0);
    }

    #[test]
    fn mm005_pass_2x3() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(verdict_from_identity_preservation(&a, 2, 3), Mm005Verdict::Pass);
    }

    #[test]
    fn mm005_pass_random() {
        let a: Vec<f32> = (0..16).map(|i| (i as f32) * 0.1 - 0.5).collect();
        assert_eq!(verdict_from_identity_preservation(&a, 4, 4), Mm005Verdict::Pass);
    }

    #[test]
    fn mm005_pass_single_element() {
        assert_eq!(verdict_from_identity_preservation(&[-2.5], 1, 1), Mm005Verdict::Pass);
    }

    #[test]
    fn mm005_fail_dim_mismatch() {
        let a = vec![1.0_f32; 5];
        assert_eq!(verdict_from_identity_preservation(&a, 2, 3), Mm005Verdict::Fail);
    }

    #[test]
    fn mm005_fail_zero_dim() {
        let a = vec![1.0_f32];
        assert_eq!(verdict_from_identity_preservation(&a, 0, 1), Mm005Verdict::Fail);
    }

    #[test]
    fn mm005_fail_nan() {
        let a = vec![1.0_f32, f32::NAN];
        assert_eq!(verdict_from_identity_preservation(&a, 1, 2), Mm005Verdict::Fail);
    }

    #[test]
    fn mm005_fail_infinite() {
        let a = vec![f32::INFINITY, 1.0];
        assert_eq!(verdict_from_identity_preservation(&a, 1, 2), Mm005Verdict::Fail);
    }

    #[test]
    fn matmul_2x2_canonical() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0];
        let b = vec![5.0_f32, 6.0, 7.0, 8.0];
        assert_eq!(matmul_reference(&a, &b, 2, 2, 2), vec![19.0_f32, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn matmul_rectangular() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(matmul_reference(&a, &b, 1, 3, 2), vec![22.0_f32, 28.0]);
    }

    #[test]
    fn matmul_rejects_length_mismatch() {
        assert!(matmul_reference(&[1.0], &[1.0], 2, 1, 1).is_empty());
    }

    #[test]
    fn matmul_empty_inner_dimension_gives_zeros() {
        let empty: [f32; 0] = [];
        assert_eq!(matmul_reference(&empty, &empty, 2, 0, 3), vec![0.0_f32; 6]);
    }

    #[test]
    fn identity_3() {
        assert_eq!(make_identity(3), vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn identity_0_and_1() {
        assert!(make_identity(0).is_empty());
        assert_eq!(make_identity(1), vec![1.0_f32]);
    }

    #[test]
    fn provenance_constants() {
        assert!((AC_MM_002_TOLERANCE - 1e-4).abs() < 1e-9);
        assert_eq!(AC_MM_003_ULP_TOLERANCE, 4);
        assert!((AC_MM_005_TOLERANCE - 1e-6).abs() < 1e-12);
    }
}