// vkml 0.0.2
//
// High-level Vulkan-based machine learning library.
use std::path::PathBuf;

/// Asserts that `actual` matches `expected` element-wise within a tolerance.
///
/// A pair `(a, b)` taken from `(actual, expected)` is accepted when:
/// - both values are finite and `|a - b| <= atol + rtol * |b|`, or
/// - both values are NaN, or
/// - both values are the same infinity.
///
/// Every other combination involving a NaN or an infinity is a mismatch, as is
/// any pair whose tolerance evaluates to NaN. Mismatched non-finite pairs are
/// reported with an absolute error of `inf`.
///
/// The first 10 mismatches are printed to stdout, followed by a `STATS` line
/// with the maximum and mean absolute error over all elements (the mean is
/// reported as `0` for empty inputs).
///
/// # Panics
///
/// Panics if the slices differ in length, or if at least one element pair is
/// outside the tolerance.
pub fn assert_tensors_close(actual: &[f32], expected: &[f32], rtol: f32, atol: f32, name: &str) {
    // Upper bound on per-element FAIL lines so a fully wrong tensor does not flood the log.
    const MAX_REPORTED_FAILURES: usize = 10;

    assert_eq!(
        actual.len(),
        expected.len(),
        "Tensor length mismatch for {}: actual={}, expected={}",
        name,
        actual.len(),
        expected.len()
    );

    let mut max_abs_error: f32 = 0.0;
    let mut sum_abs_error: f64 = 0.0;
    let mut failure_count: usize = 0;

    for (i, (&a, &b)) in actual.iter().zip(expected).enumerate() {
        let tolerance = atol + rtol * b.abs();

        let (abs_error, within_tolerance) = if a.is_finite() && b.is_finite() {
            let err = (a - b).abs();
            // `<=` rather than `!(err > tolerance)`: a NaN tolerance must reject, not accept.
            (err, err <= tolerance)
        } else {
            // Subtracting non-finite values yields NaN (e.g. `inf - inf`), and every
            // comparison against NaN is false, so these pairs are decided explicitly.
            let same = a == b || (a.is_nan() && b.is_nan());
            (if same { 0.0 } else { f32::INFINITY }, same)
        };

        max_abs_error = max_abs_error.max(abs_error);
        sum_abs_error += f64::from(abs_error);

        if !within_tolerance {
            if failure_count < MAX_REPORTED_FAILURES {
                println!(
                    "FAIL [{}]: index {}, actual={}, expected={}, abs_error={}, tolerance={}",
                    name, i, a, b, abs_error, tolerance
                );
            }
            failure_count += 1;
        }
    }

    // Avoid 0/0 = NaN in the report when both tensors are empty.
    let mean_abs_error = if actual.is_empty() {
        0.0
    } else {
        (sum_abs_error / actual.len() as f64) as f32
    };

    println!(
        "STATS [{}]: Max Abs Error = {:.8}, Mean Abs Error = {:.8}",
        name, max_abs_error, mean_abs_error
    );

    if failure_count > 0 {
        panic!(
            "Tensor comparison failed for {} ({} elements exceeded tolerance). Max Abs Error = {:.8}",
            name, failure_count, max_abs_error
        );
    }
}

/// Returns the path of `file` inside this crate's `tests/data` directory.
///
/// The base directory is the `CARGO_MANIFEST_DIR` value captured at compile
/// time. The path is only assembled here; whether the file exists is not
/// checked.
pub fn get_test_data_path(file: &str) -> PathBuf {
    let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    path.push("tests");
    path.push("data");
    path.push(file);
    path
}