bunsen 0.21.0-pre.2

`bunsen` is a community companion library for `burn`.
Documentation
//! Private module for internal use in the `bimm` crate.
use std::fmt::Debug;

use num_traits::float::Float;

// Choose the backend used by performance-oriented tests from the enabled
// cargo features. `cfg_select!` expands only the FIRST arm whose predicate
// holds, so the arm order below is the precedence order when several
// features are enabled at once: `cuda`, then `metal`, then `wgpu`. The `_`
// arm is the fallback when none of those three features is enabled.
cfg_select! {
    feature = "cuda" => {
        /// Testing backend selected for performance.
        ///
        /// Chosen because the `cuda` feature is enabled (highest precedence).
        pub type PerfTestBackend = ::burn::backend::Cuda;
    }
    feature = "metal" => {
        /// Testing backend selected for performance.
        ///
        /// Chosen because the `metal` feature is enabled and `cuda` is not.
        pub type PerfTestBackend = ::burn::backend::Metal;
    }
    feature = "wgpu" => {
        /// Testing backend selected for performance.
        ///
        /// Chosen because the `wgpu` feature is enabled and neither `cuda`
        /// nor `metal` is.
        pub type PerfTestBackend = ::burn::backend::Wgpu;
    }
    _ => {
        /// Testing backend selected for performance.
        ///
        /// Fallback used when none of the `cuda`, `metal`, or `wgpu`
        /// features is enabled.
        pub type PerfTestBackend = ::burn::backend::Flex;
    }
}
/// Testing backend selected for fast setup.
///
/// Unlike `PerfTestBackend`, this does not depend on any cargo feature; it is
/// always `::burn::backend::Flex`.
pub type SetupTestBackend = ::burn::backend::Flex;

/// Asserts that two slices of floating-point numbers are element-wise close
/// to each other within a given absolute tolerance.
///
/// Elements `a` (from `actual`) and `e` (from `expected`) at the same index
/// are considered close when they compare equal (which lets matching
/// infinities pass) or when `|a - e| <= tolerance`. Any pair involving NaN is
/// treated as a mismatch. `tolerance` is expected to be non-negative; with a
/// negative or NaN tolerance only exactly-equal elements are accepted.
///
/// # Panics
///
/// Panics if the slices have different lengths or if any pair of elements is
/// not close. The message prints the tolerance and both slices. The panic is
/// attributed to the caller's location (`#[track_caller]`), so a failing test
/// points at its own assertion line rather than at this helper.
#[track_caller]
pub fn assert_close_to_vec<T>(
    actual: &[T],
    expected: &[T],
    tolerance: T,
) where
    T: Float + Debug,
{
    // Phrased as "is close" rather than "differs by more than the tolerance":
    // comparisons against NaN are always false, so the positive form makes a
    // NaN difference count as a failure instead of silently passing.
    // The `a == e` shortcut is needed for infinities, since `inf - inf` is NaN.
    let pass = actual.len() == expected.len()
        && actual
            .iter()
            .zip(expected)
            .all(|(&a, &e)| a == e || (a - e).abs() <= tolerance);
    if !pass {
        panic!("Expected (+/- {tolerance:?}):\n{expected:?}\nActual:\n{actual:?}");
    }
}

#[cfg(test)]
mod tests {
    use crate::support::testing::assert_close_to_vec;

    /// Identical slices, and slices whose elements differ by less than the
    /// tolerance, are accepted without panicking.
    #[test]
    fn test_assert_close_to_vec() {
        let reference = [1.0, 2.0, 3.0];

        // Exact match under a tight tolerance.
        assert_close_to_vec(&[1.0, 2.0, 3.0], &reference, 0.01);

        // Last element is off by 0.1, which the looser 0.2 tolerance absorbs.
        assert_close_to_vec(&[1.0, 2.0, 3.1], &reference, 0.2);
    }

    /// An element outside the tolerance must trigger a panic.
    #[test]
    #[should_panic]
    fn test_assert_close_to_vec_bad_values() {
        let (observed, reference) = ([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]);
        assert_close_to_vec(&observed, &reference, 0.01);
    }

    /// Slices of unequal length must trigger a panic even though their
    /// shared prefix matches exactly.
    #[test]
    #[should_panic]
    fn test_assert_close_to_vec_different_lengths() {
        let observed = [1.0, 2.0];
        let reference = [1.0, 2.0, 3.0];
        assert_close_to_vec(&observed, &reference, 0.01);
    }
}