use serde::Deserialize;
use std::fs;
use std::path::Path;
/// One golden-data record, as deserialized from the JSON files read by `load_golden_data`.
#[allow(dead_code)] // NOTE(review): presumably not every test binary reads every field — confirm.
#[derive(Debug, Deserialize)]
pub struct GoldenData {
    /// Identifier of the recorded case (presumably used only for reporting — confirm).
    pub test_name: String,
    /// Free-form parameters kept as raw JSON; this struct imposes no schema on them.
    pub params: serde_json::Value,
    /// Input section of the record (see [`GoldenInput`]).
    pub input: GoldenInput,
    /// Expected-output section of the record (see [`GoldenExpected`]).
    pub expected: GoldenExpected,
    /// Comparison tolerance. When the JSON omits this key it falls back to
    /// `default_tolerance()`, which is `1e-10`.
    /// NOTE(review): presumably passed as `epsilon` to the `assert_*_close` helpers — confirm.
    #[serde(default = "default_tolerance")]
    pub tolerance: f64,
}
/// Serde fallback for `GoldenData::tolerance` when a record does not specify one.
fn default_tolerance() -> f64 {
    const FALLBACK_TOLERANCE: f64 = 1.0e-10;
    FALLBACK_TOLERANCE
}
/// Input section of a golden-data record.
///
/// Every field is an `Option`, so keys absent from the JSON deserialize to `None`.
#[allow(dead_code)] // NOTE(review): presumably not every test binary reads every field — confirm.
#[derive(Debug, Deserialize)]
pub struct GoldenInput {
    /// Nested numeric array stored under the JSON key `"X"` (upper-case in the file).
    #[serde(rename = "X")]
    pub x: Option<Vec<Vec<f64>>>,
    /// String values under the JSON key `"y"` (presumably labels — confirm).
    pub y: Option<Vec<String>>,
    /// Nested numeric array stored under the JSON key `"X_test"`
    /// (presumably held-out samples — confirm).
    #[serde(rename = "X_test")]
    pub x_test: Option<Vec<Vec<f64>>>,
    /// Flat numeric vector under the JSON key `"a"`.
    pub a: Option<Vec<f64>>,
    /// Flat numeric vector under the JSON key `"b"`.
    pub b: Option<Vec<f64>>,
}
/// Expected-output section of a golden-data record.
///
/// Every field is an `Option`, so a record only needs to populate the outputs that apply to
/// it; keys absent from the JSON deserialize to `None`.
#[allow(dead_code)] // NOTE(review): presumably not every test binary reads every field — confirm.
#[derive(Debug, Deserialize)]
pub struct GoldenExpected {
    /// Two-level nested numeric output (presumably compared via `assert_batch_close` — confirm).
    pub output: Option<Vec<Vec<f64>>>,
    /// Single numeric result.
    pub scalar: Option<f64>,
    /// Three-level nested numeric output
    /// (presumably compared via `assert_image_batch_close` — confirm).
    pub output_3d: Option<Vec<Vec<Vec<f64>>>>,
    /// Four-level nested numeric output.
    pub output_4d: Option<Vec<Vec<Vec<Vec<f64>>>>>,
    /// Nested byte values; their meaning is not determined here — TODO confirm with callers.
    pub symbolic: Option<Vec<Vec<u8>>>,
    /// Expected string outputs.
    pub strings: Option<Vec<String>>,
    /// Expected string predictions (presumably one per input sample — confirm).
    pub predictions: Option<Vec<String>>,
    /// Single numeric score.
    pub score: Option<f64>,
}
/// Reads and deserializes a golden-data JSON file.
///
/// `path` is joined onto `<CARGO_MANIFEST_DIR>/tests/golden_data`. The manifest directory is
/// captured at compile time via `env!`. The file must hold a JSON array of [`GoldenData`]
/// records.
///
/// # Panics
/// Panics if the file cannot be read or if its contents fail to deserialize. The panic
/// message names the resolved path and includes the underlying error.
pub fn load_golden_data(path: &str) -> Vec<GoldenData> {
    let golden_dir = Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("tests")
        .join("golden_data");
    let full_path = golden_dir.join(path);

    let content = match fs::read_to_string(&full_path) {
        Ok(text) => text,
        Err(err) => panic!(
            "Failed to read golden data at {}: {}",
            full_path.display(),
            err
        ),
    };

    match serde_json::from_str::<Vec<GoldenData>>(&content) {
        Ok(records) => records,
        Err(err) => panic!(
            "Failed to parse golden data at {}: {}",
            full_path.display(),
            err
        ),
    }
}
/// Asserts that two `f64` slices have equal length and are element-wise close.
///
/// Elements `a` and `e` count as close in either of two cases:
/// - they are exactly equal, which covers matching infinities and lets `epsilon == 0.0` mean
///   "exact match";
/// - `|a - e| <= epsilon`.
///
/// A `NaN` on either side never counts as close. `name` prefixes every failure message so the
/// offending element can be located.
///
/// # Panics
/// Panics if the slice lengths differ or if any element pair is not close.
#[allow(dead_code)] // NOTE(review): presumably not every test binary calls every helper — confirm.
pub fn assert_slice_close(name: &str, actual: &[f64], expected: &[f64], epsilon: f64) {
    assert_eq!(
        actual.len(),
        expected.len(),
        "{name}: length mismatch: actual={}, expected={}",
        actual.len(),
        expected.len()
    );
    for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
        // The `a == e` short-circuit matters: with `epsilon == 0.0`, identical values would
        // fail a strict `<` test, and equal infinities yield `inf - inf = NaN`, which fails
        // any comparison.
        let close = a == e || (a - e).abs() <= epsilon;
        assert!(
            close,
            "{name}[{i}]: actual={a} != expected={e} (eps={epsilon})"
        );
    }
}
/// Asserts that two batches of `f64` rows match row by row.
///
/// The batches must hold the same number of rows. Each row pair is then checked with
/// `assert_slice_close` under the label `"{name}[row]"`.
///
/// # Panics
/// Panics on a row-count mismatch, or whenever `assert_slice_close` panics for a row.
#[allow(dead_code)] // NOTE(review): presumably not every test binary calls every helper — confirm.
pub fn assert_batch_close(name: &str, actual: &[Vec<f64>], expected: &[Vec<f64>], epsilon: f64) {
    let (n_actual, n_expected) = (actual.len(), expected.len());
    assert_eq!(
        n_actual,
        n_expected,
        "{name}: sample count mismatch: actual={}, expected={}",
        n_actual,
        n_expected
    );
    actual
        .iter()
        .zip(expected)
        .enumerate()
        .for_each(|(row, (got, want))| {
            let label = format!("{name}[{row}]");
            assert_slice_close(&label, got, want, epsilon);
        });
}
/// Asserts that two batches of 2-D `f64` samples match sample by sample.
///
/// The batches must hold the same number of samples. Each sample pair is then checked with
/// `assert_batch_close` under the label `"{name}[sample]"`.
///
/// # Panics
/// Panics on a sample-count mismatch, or whenever `assert_batch_close` panics for a sample.
#[allow(dead_code)] // NOTE(review): presumably not every test binary calls every helper — confirm.
pub fn assert_image_batch_close(
    name: &str,
    actual: &[Vec<Vec<f64>>],
    expected: &[Vec<Vec<f64>>],
    epsilon: f64,
) {
    let sample_count = actual.len();
    let expected_count = expected.len();
    assert_eq!(
        sample_count,
        expected_count,
        "{name}: sample count mismatch: actual={}, expected={}",
        sample_count,
        expected_count
    );
    for (sample, pair) in actual.iter().zip(expected.iter()).enumerate() {
        let (got, want) = pair;
        assert_batch_close(&format!("{name}[{sample}]"), got, want, epsilon);
    }
}