use std::path::Path;

use anyhow::{anyhow, bail, Context, Result};
use serde::Deserialize;
/// One expected output tensor recorded in a golden fixture.
///
/// Values and dimensions are stored side by side; nothing here checks that
/// they agree with each other.
#[derive(Clone, Debug)]
#[derive(Deserialize)]
pub struct GoldenOutput {
    /// Output name as written in the fixture file.
    pub name: String,
    /// Expected values (presumably flattened — confirm against the fixture generator).
    pub data: Vec<f32>,
    /// Expected dimensions of the output tensor.
    pub shape: Vec<usize>,
}
/// A complete golden fixture: the model it was recorded against, the input
/// tensors fed to it, and the output tensors it is expected to produce.
#[derive(Clone, Debug)]
#[derive(Deserialize)]
pub struct GoldenFixture {
    /// Model path exactly as stored in the fixture; how it is resolved is up to the caller.
    pub model_path: String,
    /// Input tensors, in the order they appear in the fixture.
    pub inputs: Vec<GoldenInput>,
    /// Expected output tensors, in the order they appear in the fixture.
    pub outputs: Vec<GoldenOutput>,
}
/// One input tensor recorded in a golden fixture.
///
/// Values and dimensions are stored side by side; nothing here checks that
/// they agree with each other.
#[derive(Clone, Debug)]
#[derive(Deserialize)]
pub struct GoldenInput {
    /// Input name as written in the fixture file.
    pub name: String,
    /// Input values (presumably flattened — confirm against the fixture generator).
    pub data: Vec<f32>,
    /// Dimensions of the input tensor.
    pub shape: Vec<usize>,
}
impl GoldenFixture {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let text = std::fs::read_to_string(path)?;
Ok(serde_json::from_str(&text)?)
}
}
/// Checks that `actual` matches `expected` element-wise within a mixed
/// absolute/relative tolerance.
///
/// A finite pair passes when `|actual[i] - expected[i]| <= atol + rtol * |expected[i]|`.
/// Non-finite values are compared exactly: NaN matches only NaN, and an
/// infinity matches only an infinity of the same sign.
///
/// # Returns
///
/// The largest absolute difference seen over the finite element pairs
/// (`0.0` when the slices are empty or contain no finite pairs).
///
/// # Errors
///
/// - the slices have different lengths;
/// - the first element that is out of tolerance, or whose non-finite value
///   is not matched exactly.
pub fn assert_allclose(actual: &[f32], expected: &[f32], atol: f32, rtol: f32) -> Result<f32> {
    if actual.len() != expected.len() {
        bail!(
            "length mismatch: actual={} expected={}",
            actual.len(),
            expected.len()
        );
    }
    let mut max_abs = 0f32;
    for (i, (&a, &b)) in actual.iter().zip(expected).enumerate() {
        if !(a.is_finite() && b.is_finite()) {
            // With a NaN or infinity involved, `a - b` and `rtol * |b|` become
            // NaN/inf, and the tolerance test below would accept mismatches:
            // e.g. any finite value vs. an expected +inf (tol = inf, or NaN when
            // rtol == 0), or +inf vs. -inf. Require an exact match instead.
            let matched = (a.is_nan() && b.is_nan()) || a == b;
            if !matched {
                bail!("parity failed at index {i}: actual={a} expected={b} (non-finite mismatch)");
            }
            continue;
        }
        let abs = (a - b).abs();
        let tol = atol + rtol * b.abs();
        max_abs = max_abs.max(abs);
        if abs > tol {
            return Err(anyhow!(
                "parity failed at index {i}: actual={a} expected={b} \
                 |Δ|={abs} > tol={tol}"
            ));
        }
    }
    Ok(max_abs)
}
pub fn assert_shape(actual: &[usize], expected: &[usize]) -> Result<()> {
if actual != expected {
bail!("shape mismatch: actual={actual:?} expected={expected:?}");
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allclose_passes_on_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        // Identical inputs must report exactly zero difference.
        let err = assert_allclose(&a, &b, 1e-6, 1e-6).unwrap();
        assert_eq!(err, 0.0);
    }

    #[test]
    fn allclose_passes_within_tol() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32 + 1e-7, 2.0 - 1e-7, 3.0 + 1e-7];
        assert!(assert_allclose(&a, &b, 1e-5, 0.0).is_ok());
    }

    #[test]
    fn allclose_fails_above_tol() {
        let a = [1.0f32, 2.0];
        let b = [1.0f32, 3.0];
        assert!(assert_allclose(&a, &b, 1e-3, 1e-3).is_err());
    }

    #[test]
    fn allclose_reports_max_abs_diff() {
        let a = [1.0f32, 2.5, 3.25];
        let b = [1.0f32, 2.0, 3.0];
        assert_eq!(assert_allclose(&a, &b, 1.0, 0.0).unwrap(), 0.5);
    }

    #[test]
    fn allclose_rejects_length_mismatch() {
        assert!(assert_allclose(&[1.0], &[1.0, 2.0], 1.0, 1.0).is_err());
    }

    #[test]
    fn allclose_nan_handling() {
        // NaN is treated as matching NaN, but never a number.
        assert!(assert_allclose(&[f32::NAN], &[f32::NAN], 1e-6, 1e-6).is_ok());
        assert!(assert_allclose(&[f32::NAN], &[1.0], 1e-3, 1e-3).is_err());
        assert!(assert_allclose(&[1.0], &[f32::NAN], 1e-3, 1e-3).is_err());
    }

    #[test]
    fn shape_check() {
        assert!(assert_shape(&[2, 3], &[2, 3]).is_ok());
        assert!(assert_shape(&[2, 3], &[3, 2]).is_err());
        // A rank difference is also a mismatch.
        assert!(assert_shape(&[2, 3], &[2, 3, 1]).is_err());
    }
}