use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use crate::{AsView, Layout, TensorView};
/// Approximate equality comparison using absolute (`atol`) and relative (`rtol`) tolerances.
///
/// Implementors supply the default tolerances and the core comparison
/// [`ApproxEq::approx_eq_with_atol_rtol`]; the remaining methods are convenience wrappers that
/// fill in defaults.
pub trait ApproxEq: Sized {
    /// Absolute tolerance used by [`ApproxEq::approx_eq`].
    fn default_abs_tolerance() -> Self;

    /// Relative tolerance used by [`ApproxEq::approx_eq`] and
    /// [`ApproxEq::approx_eq_with_tolerance`].
    fn default_rel_tolerance() -> Self;

    /// Return `true` if `self` and `other` are equal within `atol` / `rtol`.
    ///
    /// The exact formula is defined by each implementation.
    fn approx_eq_with_atol_rtol(&self, other: &Self, atol: Self, rtol: Self) -> bool;

    /// Compare using `epsilon` as the absolute tolerance and the type's default relative
    /// tolerance.
    fn approx_eq_with_tolerance(&self, other: &Self, epsilon: Self) -> bool {
        let rtol = Self::default_rel_tolerance();
        self.approx_eq_with_atol_rtol(other, epsilon, rtol)
    }

    /// Compare using the type's default absolute and relative tolerances.
    fn approx_eq(&self, other: &Self) -> bool {
        let atol = Self::default_abs_tolerance();
        let rtol = Self::default_rel_tolerance();
        self.approx_eq_with_atol_rtol(other, atol, rtol)
    }
}
/// `f32` comparison in the style of numpy's `isclose`:
/// `|self - other| <= atol + rtol * |other|`.
///
/// The relative term scales with `other` only, so the comparison is not symmetric.
impl ApproxEq for f32 {
    #[inline]
    fn default_abs_tolerance() -> f32 {
        1e-8
    }

    #[inline]
    fn default_rel_tolerance() -> f32 {
        1e-5
    }

    #[inline]
    fn approx_eq_with_atol_rtol(&self, other: &f32, atol: f32, rtol: f32) -> bool {
        // Exact equality must be checked first: for matching infinities the subtraction
        // below yields NaN, and any comparison with NaN is false.
        self == other || (*self - *other).abs() <= atol + rtol * other.abs()
    }
}
/// Implement [`ApproxEq`] for integer types.
///
/// Integers default to exact comparison (both tolerances are 0). `rtol` is ignored; two values
/// are considered equal if their absolute difference is at most `atol`.
macro_rules! impl_approx_eq_for_ints {
    ($($type:ty),*) => {
        $(impl ApproxEq for $type {
            #[inline]
            fn default_abs_tolerance() -> $type {
                0
            }

            #[inline]
            fn default_rel_tolerance() -> $type {
                0
            }

            #[inline]
            fn approx_eq_with_atol_rtol(&self, other: &$type, atol: $type, _rtol: $type) -> bool {
                let (lo, hi) = if self <= other { (*self, *other) } else { (*other, *self) };
                // `hi - lo` can exceed the type's range for signed types (e.g.
                // `i32::MAX - i32::MIN`), which would panic in debug builds and wrap to a
                // negative value (falsely "equal") in release builds. A difference that does
                // not fit in the type is necessarily larger than any representable `atol`.
                match hi.checked_sub(lo) {
                    Some(diff) => diff <= atol,
                    None => false,
                }
            }
        })*
    };
}
impl_approx_eq_for_ints!(i8, i16, i32, i64, u8, u16, u32, u64);
/// Convert a linear (row-major) element offset into an N-dimensional index for `shape`.
///
/// The last dimension varies fastest. An empty `shape` yields an empty index.
///
/// # Panics
///
/// Panics if `lin_index` is not less than the product of `shape` (this includes any shape
/// containing a zero-sized dimension).
fn index_from_linear_index(shape: &[usize], lin_index: usize) -> Vec<usize> {
    let total: usize = shape.iter().product();
    assert!(
        lin_index < total,
        "Linear index {} is out of bounds for shape {:?}",
        lin_index,
        shape,
    );

    // Peel off coordinates from the innermost dimension outwards. The assertion above
    // guarantees every dimension size is non-zero, so the divisions cannot fault.
    let mut remaining = lin_index;
    let mut index = vec![0; shape.len()];
    for (coord, &dim_size) in index.iter_mut().zip(shape).rev() {
        *coord = remaining % dim_size;
        remaining /= dim_size;
    }
    index
}
/// Error produced when two tensors compared by `expect_equal_with_tolerance` differ.
#[derive(Debug)]
pub enum ExpectEqualError {
    /// The tensors have different shapes. Holds a human-readable description.
    ShapeMismatch(String),
    /// The shapes match but some element values differ. Holds a human-readable description.
    ValueMismatch(String),
}

impl Display for ExpectEqualError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Both variants carry a pre-formatted message, which is printed verbatim.
        let message = match self {
            Self::ShapeMismatch(details) | Self::ValueMismatch(details) => details,
        };
        f.write_str(message)
    }
}

impl Error for ExpectEqualError {}
/// Check that `x` and `y` have the same shape and approximately equal elements, using the
/// element type's default absolute and relative tolerances.
///
/// This is `expect_equal_with_tolerance` with `ApproxEq::default_abs_tolerance()` and
/// `ApproxEq::default_rel_tolerance()` as the tolerances; see that function for the error
/// conditions.
pub fn expect_equal<V: AsView>(x: &V, y: &V) -> Result<(), ExpectEqualError>
where
    V::Elem: Clone + Debug + ApproxEq,
{
    let atol = V::Elem::default_abs_tolerance();
    let rtol = V::Elem::default_rel_tolerance();
    expect_equal_with_tolerance(x, y, atol, rtol)
}
/// Check that `x` and `y` have the same shape and that each pair of corresponding elements
/// is approximately equal under the given tolerances.
///
/// Elements are compared with `ApproxEq::approx_eq_with_atol_rtol`, with the element of `x`
/// as the receiver and the element of `y` as `other`.
///
/// # Errors
///
/// - `ExpectEqualError::ShapeMismatch` if the shapes differ.
/// - `ExpectEqualError::ValueMismatch` if any elements differ. The message reports the total
///   number of mismatches and lists up to 16 of them, in iteration order, as
///   `(index, x_value, y_value)` tuples, followed by `...` if more were omitted.
pub fn expect_equal_with_tolerance<V: AsView>(
    x: &V,
    y: &V,
    atol: V::Elem,
    rtol: V::Elem,
) -> Result<(), ExpectEqualError>
where
    V::Elem: Clone + Debug + ApproxEq,
{
    // Maximum number of mismatching elements listed in the error message.
    const MAX_EXAMPLES: usize = 16;

    if x.shape() != y.shape() {
        return Err(ExpectEqualError::ShapeMismatch(format!(
            "Tensors have different shapes. {:?} vs. {:?}",
            x.shape(),
            y.shape()
        )));
    }

    let mut mismatches = x
        .iter()
        .zip(y.iter())
        .enumerate()
        .filter(|(_, (xi, yi))| !xi.approx_eq_with_atol_rtol(yi, atol.clone(), rtol.clone()));

    // Only the reported examples need an N-d index. Computing (and allocating) one for every
    // mismatch is wasted work when most elements differ, so count the rest without them.
    // NOTE(review): the index mapping assumes `iter()` yields elements in row-major logical
    // order, as the original implementation did — confirm against `AsView::iter`.
    let examples: Vec<_> = mismatches
        .by_ref()
        .take(MAX_EXAMPLES)
        .map(|(i, (xi, yi))| (index_from_linear_index(x.shape().as_ref(), i), xi, yi))
        .collect();
    let mismatch_count = examples.len() + mismatches.count();

    if mismatch_count == 0 {
        return Ok(());
    }

    Err(ExpectEqualError::ValueMismatch(format!(
        "Tensor values differ at {} of {} indexes: {:?}{}",
        mismatch_count,
        x.len(),
        examples,
        if mismatch_count > MAX_EXAMPLES {
            "..."
        } else {
            ""
        }
    )))
}
/// Return `true` if `a` and `b` have the same shape and elementwise-equal values, where two
/// NaNs at the same position also count as equal.
///
/// Non-NaN elements are compared with `==`, so no tolerance is applied.
/// NOTE(review): relies on `TensorView`'s default element type being a float (it calls
/// `is_nan`) — presumably `f32`; confirm.
pub fn eq_with_nans(a: TensorView, b: TensorView) -> bool {
    if a.shape() != b.shape() {
        return false;
    }
    a.iter()
        .zip(b.iter())
        .all(|(lhs, rhs)| lhs == rhs || (lhs.is_nan() && rhs.is_nan()))
}
#[cfg(test)]
mod tests {
    use super::{index_from_linear_index, ApproxEq};

    #[test]
    fn test_approx_eq_i32() {
        let vals = [
            -5,
            -1,
            0,
            1,
            5,
            i32::MIN,
            i32::MIN + 1,
            i32::MAX,
            i32::MAX - 1,
        ];
        for val in vals {
            assert!(val.approx_eq(&val));
            if val > i32::MIN {
                assert!(!val.approx_eq(&(val - 1)));
            }
            if val < i32::MAX {
                assert!(!val.approx_eq(&(val + 1)));
            }
        }
    }

    #[test]
    fn test_approx_eq_f32() {
        let vals = [-1000., -5., -0.5, 0., 0.5, 5., 1000.];
        for val in vals {
            assert!(val.approx_eq(&val));
        }
        // Just inside the default tolerances (atol = 1e-8, rtol = 1e-5).
        for val in vals {
            let close = val + 9e-9 + val * 9e-6;
            assert_ne!(val, close);
            assert!(val.approx_eq(&close));
        }
        // Just outside the default tolerances.
        for val in vals {
            let not_close = val + 2e-8 + val * 2e-5;
            assert_ne!(val, not_close);
            assert!(!val.approx_eq(&not_close));
        }
        // Infinities compare equal to themselves via the exact-equality path.
        let vals = [f32::NEG_INFINITY, f32::INFINITY];
        for val in vals {
            assert!(val.approx_eq(&val));
        }
        // NaN is never approximately equal to anything, including itself.
        assert!(!f32::NAN.approx_eq(&f32::NAN));
    }

    #[test]
    fn test_index_from_linear_index() {
        let shape = [2, 3, 4];
        assert_eq!(index_from_linear_index(&shape, 0), vec![0, 0, 0]);
        assert_eq!(index_from_linear_index(&shape, 13), vec![1, 0, 1]);
        assert_eq!(index_from_linear_index(&shape, 23), vec![1, 2, 3]);
        assert_eq!(index_from_linear_index(&[], 0), Vec::<usize>::new());
    }
}