use alloc::format;
use alloc::string::String;
use burn_std::{DType, bf16, f16};
use num_traits::{Float, ToPrimitive};
use super::TensorData;
use crate::element::Element;
/// Pair of thresholds used to decide whether two floating-point values are
/// approximately equal.
///
/// The `relative` threshold is multiplied by the larger magnitude of the two
/// compared values, while the `absolute` threshold is a fixed floor on the
/// accepted difference.
#[derive(Clone, Copy, Debug)]
pub struct Tolerance<F> {
    /// Accepted difference as a fraction of the larger magnitude of the two values.
    relative: F,
    /// Accepted difference independent of the magnitude of the values.
    absolute: F,
}
impl<F: Float> Default for Tolerance<F> {
fn default() -> Self {
Self::balanced()
}
}
impl<F: Float> Tolerance<F> {
pub fn strict() -> Self {
Self {
relative: F::from(0.00).unwrap(),
absolute: F::from(64).unwrap() * F::min_positive_value(),
}
}
pub fn balanced() -> Self {
Self {
relative: F::from(0.005).unwrap(), absolute: F::from(1e-5).unwrap(),
}
}
pub fn permissive() -> Self {
Self {
relative: F::from(0.01).unwrap(), absolute: F::from(0.01).unwrap(),
}
}
pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
let relative = Self::check_relative(relative);
let absolute = Self::check_absolute(absolute);
Self { relative, absolute }
}
pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
let relative = Self::check_relative(tolerance);
Self {
relative,
absolute: F::from(0.0).unwrap(),
}
}
pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
let absolute = Self::check_absolute(tolerance);
Self {
relative: F::from(0.0).unwrap(),
absolute,
}
}
pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
self.relative = Self::check_relative(tolerance);
self
}
pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
if core::mem::size_of::<F>() == 2 {
self.relative = Self::check_relative(tolerance);
}
self
}
pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
if core::mem::size_of::<F>() == 4 {
self.relative = Self::check_relative(tolerance);
}
self
}
pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
if core::mem::size_of::<F>() == 8 {
self.relative = Self::check_relative(tolerance);
}
self
}
pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
self.absolute = Self::check_absolute(tolerance);
self
}
pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
if core::mem::size_of::<F>() == 2 {
self.absolute = Self::check_absolute(tolerance);
}
self
}
pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
if core::mem::size_of::<F>() == 4 {
self.absolute = Self::check_absolute(tolerance);
}
self
}
pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
if core::mem::size_of::<F>() == 8 {
self.absolute = Self::check_absolute(tolerance);
}
self
}
pub fn approx_eq(&self, x: F, y: F) -> bool {
if x == y {
return true;
}
let diff = (x - y).abs();
let max = F::max(x.abs(), y.abs());
diff < self.absolute.max(self.relative * max)
}
fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
let tolerance = F::from(tolerance).unwrap();
assert!(tolerance <= F::one());
tolerance
}
fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
let tolerance = F::from(tolerance).unwrap();
assert!(tolerance >= F::zero());
tolerance
}
}
impl TensorData {
    /// Maximum number of mismatching positions spelled out in a failure message.
    /// Further mismatches are only counted.
    const MAX_REPORTED_DIFFS: usize = 5;

    /// Asserts that `self` and `other` hold equal data.
    ///
    /// The element type used for the comparison is chosen from `self.dtype`.
    /// Quantized data is compared through `assert_eq_elem::<i8>`.
    ///
    /// # Panics
    ///
    /// - If `strict` is `true` and the data types differ.
    /// - If `self` is quantized but `other` is not.
    /// - If both are quantized but their quantization `value` or `level` differ.
    /// - If the shapes differ or any compared element differs.
    #[track_caller]
    pub fn assert_eq(&self, other: &Self, strict: bool) {
        if strict {
            assert_eq!(
                self.dtype, other.dtype,
                "Data types differ ({:?} != {:?})",
                self.dtype, other.dtype
            );
        }
        match self.dtype {
            DType::F64 => self.assert_eq_elem::<f64>(other),
            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
            DType::F16 => self.assert_eq_elem::<f16>(other),
            DType::BF16 => self.assert_eq_elem::<bf16>(other),
            DType::I64 => self.assert_eq_elem::<i64>(other),
            DType::I32 => self.assert_eq_elem::<i32>(other),
            DType::I16 => self.assert_eq_elem::<i16>(other),
            DType::I8 => self.assert_eq_elem::<i8>(other),
            DType::U64 => self.assert_eq_elem::<u64>(other),
            DType::U32 => self.assert_eq_elem::<u32>(other),
            DType::U16 => self.assert_eq_elem::<u16>(other),
            DType::U8 => self.assert_eq_elem::<u8>(other),
            DType::Bool => self.assert_eq_elem::<bool>(other),
            DType::QFloat(q) => {
                let DType::QFloat(q_other) = other.dtype else {
                    panic!("Quantized data differs from other not quantized data")
                };
                if q.value == q_other.value && q.level == q_other.level {
                    self.assert_eq_elem::<i8>(other)
                } else {
                    panic!("Quantization schemes differ ({q:?} != {q_other:?})")
                }
            }
        }
    }

    /// Element-wise exact comparison of `self` and `other` read as `E`.
    ///
    /// The first [`Self::MAX_REPORTED_DIFFS`] mismatches are listed; any beyond
    /// that are summarized as a count.
    #[track_caller]
    fn assert_eq_elem<E: Element>(&self, other: &Self) {
        let mut message = self.shape_mismatch_message(other);
        let mut num_diff = 0;
        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
            if a.cmp(&b).is_ne() {
                if num_diff < Self::MAX_REPORTED_DIFFS {
                    message.push_str(&format!("\n => Position {i}: {a} != {b}"));
                }
                num_diff += 1;
            }
        }
        Self::push_overflow_summary(&mut message, num_diff);
        if !message.is_empty() {
            panic!("Tensors are not eq:{message}");
        }
    }

    /// Asserts that `self` and `other`, read as `F`, are approximately equal
    /// under `tolerance`.
    ///
    /// Pairs that are both NaN, or both infinite with the same sign, are
    /// treated as equal.
    ///
    /// # Panics
    ///
    /// If the shapes differ or any pair fails [`Tolerance::approx_eq`]. The
    /// message reports the relative/absolute difference next to the tolerance.
    #[track_caller]
    pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
        let mut message = self.shape_mismatch_message(other);
        let mut num_diff = 0;
        for (i, (a, b)) in self.iter::<F>().zip(other.iter::<F>()).enumerate() {
            let both_nan = a.is_nan() && b.is_nan();
            let both_inf =
                a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
            if both_nan || both_inf {
                continue;
            }
            // `a` and `b` are already `F`; no conversion is needed.
            if !tolerance.approx_eq(a, b) {
                if num_diff < Self::MAX_REPORTED_DIFFS {
                    let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
                    let max = F::max(a.abs(), b.abs());
                    let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();
                    let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
                    let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
                    message.push_str(&format!(
                        "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
                    ));
                }
                num_diff += 1;
            }
        }
        Self::push_overflow_summary(&mut message, num_diff);
        if !message.is_empty() {
            panic!("Tensors are not approx eq:{message}");
        }
    }

    /// Asserts that every element, read as `E`, lies in the half-open `range`
    /// (`start <= elem < end`, using `Element::cmp`).
    ///
    /// # Panics
    ///
    /// On the first element outside the range.
    #[track_caller]
    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
        for elem in self.iter::<E>() {
            if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }

    /// Asserts that every element, read as `E`, lies in the closed `range`
    /// (`start <= elem <= end`, using `Element::cmp`).
    ///
    /// # Panics
    ///
    /// On the first element outside the range.
    #[track_caller]
    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
        let start = range.start();
        let end = range.end();
        for elem in self.iter::<E>() {
            if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
                panic!("Element ({elem:?}) is not within range {range:?}");
            }
        }
    }

    /// Returns the shape-mismatch line of a failure message. Returns an empty
    /// string when the shapes agree.
    fn shape_mismatch_message(&self, other: &Self) -> String {
        if self.shape == other.shape {
            String::new()
        } else {
            format!(
                "\n => Shape is different: {:?} != {:?}",
                self.shape, other.shape
            )
        }
    }

    /// Appends the count of mismatches that were not listed individually.
    ///
    /// Nothing is appended when every mismatch was already listed, which avoids
    /// a misleading "0 more errors...".
    fn push_overflow_summary(message: &mut String, num_diff: usize) {
        if num_diff > Self::MAX_REPORTED_DIFFS {
            message.push_str(&format!(
                "\n{} more errors...",
                num_diff - Self::MAX_REPORTED_DIFFS
            ));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_assert_approx_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);
        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));
        data1.assert_approx_eq::<f16>(&data2, Tolerance::absolute(3e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);
        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);
        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    fn balanced_tolerance_scales_with_magnitude() {
        let tolerance = Tolerance::<f32>::balanced();
        // 0.005 * 100.4 = 0.502 > 0.4.
        assert!(tolerance.approx_eq(100.0, 100.4));
        // 0.005 * 101 = 0.505 < 1.
        assert!(!tolerance.approx_eq(100.0, 101.0));
    }

    #[test]
    fn nan_is_never_approx_eq() {
        assert!(!Tolerance::<f32>::permissive().approx_eq(f32::NAN, f32::NAN));
    }

    #[test]
    fn precision_specific_setters_only_apply_to_matching_width() {
        let on_f64 = Tolerance::<f64>::absolute(0.0).set_double_precision_absolute(0.5);
        assert!(on_f64.approx_eq(1.0, 1.4));
        let on_f32 = Tolerance::<f32>::absolute(0.0).set_double_precision_absolute(0.5);
        assert!(!on_f32.approx_eq(1.0, 1.4));
    }

    #[test]
    #[should_panic]
    fn relative_tolerance_above_one_is_rejected() {
        let _ = Tolerance::<f32>::relative(1.5);
    }
}