burn_backend/data/
compare.rs

1use alloc::format;
2use alloc::string::String;
3use burn_std::{DType, bf16, f16};
4use num_traits::{Float, ToPrimitive};
5
6use super::TensorData;
7use crate::element::Element;
8
/// The tolerance used to compare two floating point numbers.
///
/// Two numbers `x` and `y` are approximately equal if `x == y` or
///
/// ```text
/// |x - y| < max(R * max(|x|, |y|), A)
/// ```
///
/// where `R` is the relative tolerance and `A` is the absolute tolerance.
/// The absolute tolerance keeps the comparison meaningful for values close to `0.0`,
/// where the relative bound alone would shrink towards zero.
///
/// The most common way to initialize this struct is `Tolerance::<F>::default()`, which is the
/// same as [`Tolerance::balanced`]: a relative tolerance of 0.5% and an absolute tolerance
/// of `1e-5`.
///
/// Another common initialization is
/// `Tolerance::<F>::rel_abs(1e-4, 1e-5).set_half_precision_relative(1e-2)`.
/// This sets explicit relative and absolute tolerances, then loosens the relative tolerance
/// when `F` is a half precision type (2 bytes wide, such as `f16` or `bf16`).
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    /// Relative tolerance `R`, scaled by the larger magnitude of the two compared values.
    relative: F,
    /// Absolute tolerance `A`: differences strictly below it are always accepted.
    absolute: F,
}
32
33impl<F: Float> Default for Tolerance<F> {
34    fn default() -> Self {
35        Self::balanced()
36    }
37}
38
39impl<F: Float> Tolerance<F> {
40    /// Create a tolerance with strict precision setting.
41    pub fn strict() -> Self {
42        Self {
43            relative: F::from(0.00).unwrap(),
44            absolute: F::from(64).unwrap() * F::min_positive_value(),
45        }
46    }
47    /// Create a tolerance with balanced precision setting.
48    pub fn balanced() -> Self {
49        Self {
50            relative: F::from(0.005).unwrap(), // 0.5%
51            absolute: F::from(1e-5).unwrap(),
52        }
53    }
54
55    /// Create a tolerance with permissive precision setting.
56    pub fn permissive() -> Self {
57        Self {
58            relative: F::from(0.01).unwrap(), // 1.0%
59            absolute: F::from(0.01).unwrap(),
60        }
61    }
62    /// When comparing two numbers, this uses both the relative and absolute differences.
63    ///
64    /// That is, `x` and `y` are approximately equal if
65    ///
66    /// ```text
67    /// |x - y| < max(R * (|x + y|), A)
68    /// ```
69    ///
70    /// where `R` is the `relative` tolerance and `A` is the `absolute` tolerance.
71    pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
72        let relative = Self::check_relative(relative);
73        let absolute = Self::check_absolute(absolute);
74
75        Self { relative, absolute }
76    }
77
78    /// When comparing two numbers, this uses only the relative difference.
79    ///
80    /// That is, `x` and `y` are approximately equal if
81    ///
82    /// ```text
83    /// |x - y| < R * max(|x|, |y|)
84    /// ```
85    ///
86    /// where `R` is the relative `tolerance`.
87    pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
88        let relative = Self::check_relative(tolerance);
89
90        Self {
91            relative,
92            absolute: F::from(0.0).unwrap(),
93        }
94    }
95
96    /// When comparing two numbers, this uses only the absolute difference.
97    ///
98    /// That is, `x` and `y` are approximately equal if
99    ///
100    /// ```text
101    /// |x - y| < A
102    /// ```
103    ///
104    /// where `A` is the absolute `tolerance`.
105    pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
106        let absolute = Self::check_absolute(tolerance);
107
108        Self {
109            relative: F::from(0.0).unwrap(),
110            absolute,
111        }
112    }
113
114    /// Change the relative tolerance to the given one.
115    pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
116        self.relative = Self::check_relative(tolerance);
117        self
118    }
119
120    /// Change the relative tolerance to the given one only if `F` is half precision.
121    pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
122        if core::mem::size_of::<F>() == 2 {
123            self.relative = Self::check_relative(tolerance);
124        }
125        self
126    }
127
128    /// Change the relative tolerance to the given one only if `F` is single precision.
129    pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
130        if core::mem::size_of::<F>() == 4 {
131            self.relative = Self::check_relative(tolerance);
132        }
133        self
134    }
135
136    /// Change the relative tolerance to the given one only if `F` is double precision.
137    pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
138        if core::mem::size_of::<F>() == 8 {
139            self.relative = Self::check_relative(tolerance);
140        }
141        self
142    }
143
144    /// Change the absolute tolerance to the given one.
145    pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
146        self.absolute = Self::check_absolute(tolerance);
147        self
148    }
149
150    /// Change the absolute tolerance to the given one only if `F` is half precision.
151    pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
152        if core::mem::size_of::<F>() == 2 {
153            self.absolute = Self::check_absolute(tolerance);
154        }
155        self
156    }
157
158    /// Change the absolute tolerance to the given one only if `F` is single precision.
159    pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
160        if core::mem::size_of::<F>() == 4 {
161            self.absolute = Self::check_absolute(tolerance);
162        }
163        self
164    }
165
166    /// Change the absolute tolerance to the given one only if `F` is double precision.
167    pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
168        if core::mem::size_of::<F>() == 8 {
169            self.absolute = Self::check_absolute(tolerance);
170        }
171        self
172    }
173
174    /// Checks if `x` and `y` are approximately equal given the tolerance.
175    pub fn approx_eq(&self, x: F, y: F) -> bool {
176        // See the accepted answer here
177        // https://stackoverflow.com/questions/4915462/how-should-i-do-floating-point-comparison
178
179        // This also handles the case where both a and b are infinity so that we don't need
180        // to manage it in the rest of the function.
181        if x == y {
182            return true;
183        }
184
185        let diff = (x - y).abs();
186        let max = F::max(x.abs(), y.abs());
187
188        diff < self.absolute.max(self.relative * max)
189    }
190
191    fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
192        let tolerance = F::from(tolerance).unwrap();
193        assert!(tolerance <= F::one());
194        tolerance
195    }
196
197    fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
198        let tolerance = F::from(tolerance).unwrap();
199        assert!(tolerance >= F::zero());
200        tolerance
201    }
202}
203
204impl TensorData {
205    /// Asserts the data is equal to another data.
206    ///
207    /// # Arguments
208    ///
209    /// * `other` - The other data.
210    /// * `strict` - If true, the data types must the be same.
211    ///   Otherwise, the comparison is done in the current data type.
212    ///
213    /// # Panics
214    ///
215    /// Panics if the data is not equal.
216    #[track_caller]
217    pub fn assert_eq(&self, other: &Self, strict: bool) {
218        if strict {
219            assert_eq!(
220                self.dtype, other.dtype,
221                "Data types differ ({:?} != {:?})",
222                self.dtype, other.dtype
223            );
224        }
225
226        match self.dtype {
227            DType::F64 => self.assert_eq_elem::<f64>(other),
228            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
229            DType::F16 => self.assert_eq_elem::<f16>(other),
230            DType::BF16 => self.assert_eq_elem::<bf16>(other),
231            DType::I64 => self.assert_eq_elem::<i64>(other),
232            DType::I32 => self.assert_eq_elem::<i32>(other),
233            DType::I16 => self.assert_eq_elem::<i16>(other),
234            DType::I8 => self.assert_eq_elem::<i8>(other),
235            DType::U64 => self.assert_eq_elem::<u64>(other),
236            DType::U32 => self.assert_eq_elem::<u32>(other),
237            DType::U16 => self.assert_eq_elem::<u16>(other),
238            DType::U8 => self.assert_eq_elem::<u8>(other),
239            DType::Bool => self.assert_eq_elem::<bool>(other),
240            DType::QFloat(q) => {
241                // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality
242                let q_other = if let DType::QFloat(q_other) = other.dtype {
243                    q_other
244                } else {
245                    panic!("Quantized data differs from other not quantized data")
246                };
247
248                // Data equality mostly depends on input quantization type, but we also check level
249                if q.value == q_other.value && q.level == q_other.level {
250                    self.assert_eq_elem::<i8>(other)
251                } else {
252                    panic!("Quantization schemes differ ({q:?} != {q_other:?})")
253                }
254            }
255        }
256    }
257
258    #[track_caller]
259    fn assert_eq_elem<E: Element>(&self, other: &Self) {
260        let mut message = String::new();
261        if self.shape != other.shape {
262            message += format!(
263                "\n  => Shape is different: {:?} != {:?}",
264                self.shape, other.shape
265            )
266            .as_str();
267        }
268
269        let mut num_diff = 0;
270        let max_num_diff = 5;
271        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
272            if a.cmp(&b).is_ne() {
273                // Only print the first 5 different values.
274                if num_diff < max_num_diff {
275                    message += format!("\n  => Position {i}: {a} != {b}").as_str();
276                }
277                num_diff += 1;
278            }
279        }
280
281        if num_diff >= max_num_diff {
282            message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
283        }
284
285        if !message.is_empty() {
286            panic!("Tensors are not eq:{message}");
287        }
288    }
289
290    /// Asserts the data is approximately equal to another data.
291    ///
292    /// # Arguments
293    ///
294    /// * `other` - The other data.
295    /// * `tolerance` - The tolerance of the comparison.
296    ///
297    /// # Panics
298    ///
299    /// Panics if the data is not approximately equal.
300    #[track_caller]
301    pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
302        let mut message = String::new();
303        if self.shape != other.shape {
304            message += format!(
305                "\n  => Shape is different: {:?} != {:?}",
306                self.shape, other.shape
307            )
308            .as_str();
309        }
310
311        let iter = self.iter::<F>().zip(other.iter::<F>());
312
313        let mut num_diff = 0;
314        let max_num_diff = 5;
315
316        for (i, (a, b)) in iter.enumerate() {
317            //if they are both nan, then they are equally nan
318            let both_nan = a.is_nan() && b.is_nan();
319            //this works for both infinities
320            let both_inf =
321                a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
322
323            if both_nan || both_inf {
324                continue;
325            }
326
327            if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
328                // Only print the first 5 different values.
329                if num_diff < max_num_diff {
330                    let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
331                    let max = F::max(a.abs(), b.abs());
332                    let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();
333
334                    let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
335                    let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
336
337                    message += format!(
338                        "\n  => Position {i}: {a} != {b}\n     diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
339                    )
340                    .as_str();
341                }
342                num_diff += 1;
343            }
344        }
345
346        if num_diff >= max_num_diff {
347            message += format!("\n{} more errors...", num_diff - 5).as_str();
348        }
349
350        if !message.is_empty() {
351            panic!("Tensors are not approx eq:{message}");
352        }
353    }
354
355    /// Asserts each value is within a given range.
356    ///
357    /// # Arguments
358    ///
359    /// * `range` - The range.
360    ///
361    /// # Panics
362    ///
363    /// If any value is not within the half-open range bounded inclusively below
364    /// and exclusively above (`start..end`).
365    pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
366        for elem in self.iter::<E>() {
367            if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
368                panic!("Element ({elem:?}) is not within range {range:?}");
369            }
370        }
371    }
372
373    /// Asserts each value is within a given inclusive range.
374    ///
375    /// # Arguments
376    ///
377    /// * `range` - The range.
378    ///
379    /// # Panics
380    ///
381    /// If any value is not within the half-open range bounded inclusively (`start..=end`).
382    pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
383        let start = range.start();
384        let end = range.end();
385
386        for elem in self.iter::<E>() {
387            if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
388                panic!("Element ({elem:?}) is not within range {range:?}");
389            }
390        }
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_assert_approx_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));
        data1.assert_approx_eq::<f16>(&data2, Tolerance::absolute(3e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    #[test]
    fn should_assert_approx_eq_nan_and_same_sign_infinity() {
        let data1 = TensorData::from([[f32::NAN, f32::INFINITY, f32::NEG_INFINITY]]);
        let data2 = TensorData::from([[f32::NAN, f32::INFINITY, f32::NEG_INFINITY]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::default());
    }

    #[test]
    #[should_panic]
    fn should_assert_approx_eq_opposite_infinities() {
        let data1 = TensorData::from([[f32::INFINITY]]);
        let data2 = TensorData::from([[f32::NEG_INFINITY]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::default());
    }

    #[test]
    #[should_panic(expected = "1 more errors...")]
    fn should_assert_eq_report_remaining_errors() {
        let data1 = TensorData::from([1, 2, 3, 4, 5, 6]);
        let data2 = TensorData::from([0, 0, 0, 0, 0, 0]);

        data1.assert_eq(&data2, true);
    }

    #[test]
    fn should_assert_within_range_inclusive_bounds() {
        let data = TensorData::from([0.0f32, 0.5, 1.0]);

        data.assert_within_range_inclusive::<f32>(0.0..=1.0);
    }

    #[test]
    #[should_panic]
    fn should_assert_within_range_exclusive_end() {
        let data = TensorData::from([0.0f32, 0.5, 1.0]);

        data.assert_within_range::<f32>(0.0..1.0);
    }
}