1use alloc::format;
2use alloc::string::String;
3use burn_std::{DType, bf16, f16};
4use num_traits::{Float, ToPrimitive};
5
6use super::TensorData;
7use crate::element::Element;
8
/// The tolerance used when comparing two floating point numbers for approximate equality.
///
/// Two numbers `x` and `y` are considered approximately equal (see `approx_eq`) when
/// they are exactly equal, or when
///
/// ```text
/// |x - y| < max(A, R * max(|x|, |y|))
/// ```
///
/// where `R` is the relative tolerance and `A` is the absolute tolerance.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    /// Relative tolerance `R`: scaled by the larger magnitude of the two compared values.
    relative: F,
    /// Absolute tolerance `A`: a fixed bound on the difference, independent of magnitude.
    absolute: F,
}
32
33impl<F: Float> Default for Tolerance<F> {
34 fn default() -> Self {
35 Self::balanced()
36 }
37}
38
39impl<F: Float> Tolerance<F> {
40 pub fn strict() -> Self {
42 Self {
43 relative: F::from(0.00).unwrap(),
44 absolute: F::from(64).unwrap() * F::min_positive_value(),
45 }
46 }
47 pub fn balanced() -> Self {
49 Self {
50 relative: F::from(0.005).unwrap(), absolute: F::from(1e-5).unwrap(),
52 }
53 }
54
55 pub fn permissive() -> Self {
57 Self {
58 relative: F::from(0.01).unwrap(), absolute: F::from(0.01).unwrap(),
60 }
61 }
62 pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
72 let relative = Self::check_relative(relative);
73 let absolute = Self::check_absolute(absolute);
74
75 Self { relative, absolute }
76 }
77
78 pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
88 let relative = Self::check_relative(tolerance);
89
90 Self {
91 relative,
92 absolute: F::from(0.0).unwrap(),
93 }
94 }
95
96 pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
106 let absolute = Self::check_absolute(tolerance);
107
108 Self {
109 relative: F::from(0.0).unwrap(),
110 absolute,
111 }
112 }
113
114 pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
116 self.relative = Self::check_relative(tolerance);
117 self
118 }
119
120 pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
122 if core::mem::size_of::<F>() == 2 {
123 self.relative = Self::check_relative(tolerance);
124 }
125 self
126 }
127
128 pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
130 if core::mem::size_of::<F>() == 4 {
131 self.relative = Self::check_relative(tolerance);
132 }
133 self
134 }
135
136 pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
138 if core::mem::size_of::<F>() == 8 {
139 self.relative = Self::check_relative(tolerance);
140 }
141 self
142 }
143
144 pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
146 self.absolute = Self::check_absolute(tolerance);
147 self
148 }
149
150 pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
152 if core::mem::size_of::<F>() == 2 {
153 self.absolute = Self::check_absolute(tolerance);
154 }
155 self
156 }
157
158 pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
160 if core::mem::size_of::<F>() == 4 {
161 self.absolute = Self::check_absolute(tolerance);
162 }
163 self
164 }
165
166 pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
168 if core::mem::size_of::<F>() == 8 {
169 self.absolute = Self::check_absolute(tolerance);
170 }
171 self
172 }
173
174 pub fn approx_eq(&self, x: F, y: F) -> bool {
176 if x == y {
182 return true;
183 }
184
185 let diff = (x - y).abs();
186 let max = F::max(x.abs(), y.abs());
187
188 diff < self.absolute.max(self.relative * max)
189 }
190
191 fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
192 let tolerance = F::from(tolerance).unwrap();
193 assert!(tolerance <= F::one());
194 tolerance
195 }
196
197 fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
198 let tolerance = F::from(tolerance).unwrap();
199 assert!(tolerance >= F::zero());
200 tolerance
201 }
202}
203
204impl TensorData {
205 #[track_caller]
217 pub fn assert_eq(&self, other: &Self, strict: bool) {
218 if strict {
219 assert_eq!(
220 self.dtype, other.dtype,
221 "Data types differ ({:?} != {:?})",
222 self.dtype, other.dtype
223 );
224 }
225
226 match self.dtype {
227 DType::F64 => self.assert_eq_elem::<f64>(other),
228 DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
229 DType::F16 => self.assert_eq_elem::<f16>(other),
230 DType::BF16 => self.assert_eq_elem::<bf16>(other),
231 DType::I64 => self.assert_eq_elem::<i64>(other),
232 DType::I32 => self.assert_eq_elem::<i32>(other),
233 DType::I16 => self.assert_eq_elem::<i16>(other),
234 DType::I8 => self.assert_eq_elem::<i8>(other),
235 DType::U64 => self.assert_eq_elem::<u64>(other),
236 DType::U32 => self.assert_eq_elem::<u32>(other),
237 DType::U16 => self.assert_eq_elem::<u16>(other),
238 DType::U8 => self.assert_eq_elem::<u8>(other),
239 DType::Bool => self.assert_eq_elem::<bool>(other),
240 DType::QFloat(q) => {
241 let q_other = if let DType::QFloat(q_other) = other.dtype {
243 q_other
244 } else {
245 panic!("Quantized data differs from other not quantized data")
246 };
247
248 if q.value == q_other.value && q.level == q_other.level {
250 self.assert_eq_elem::<i8>(other)
251 } else {
252 panic!("Quantization schemes differ ({q:?} != {q_other:?})")
253 }
254 }
255 }
256 }
257
258 #[track_caller]
259 fn assert_eq_elem<E: Element>(&self, other: &Self) {
260 let mut message = String::new();
261 if self.shape != other.shape {
262 message += format!(
263 "\n => Shape is different: {:?} != {:?}",
264 self.shape, other.shape
265 )
266 .as_str();
267 }
268
269 let mut num_diff = 0;
270 let max_num_diff = 5;
271 for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
272 if a.cmp(&b).is_ne() {
273 if num_diff < max_num_diff {
275 message += format!("\n => Position {i}: {a} != {b}").as_str();
276 }
277 num_diff += 1;
278 }
279 }
280
281 if num_diff >= max_num_diff {
282 message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
283 }
284
285 if !message.is_empty() {
286 panic!("Tensors are not eq:{message}");
287 }
288 }
289
290 #[track_caller]
301 pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
302 let mut message = String::new();
303 if self.shape != other.shape {
304 message += format!(
305 "\n => Shape is different: {:?} != {:?}",
306 self.shape, other.shape
307 )
308 .as_str();
309 }
310
311 let iter = self.iter::<F>().zip(other.iter::<F>());
312
313 let mut num_diff = 0;
314 let max_num_diff = 5;
315
316 for (i, (a, b)) in iter.enumerate() {
317 let both_nan = a.is_nan() && b.is_nan();
319 let both_inf =
321 a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
322
323 if both_nan || both_inf {
324 continue;
325 }
326
327 if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
328 if num_diff < max_num_diff {
330 let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
331 let max = F::max(a.abs(), b.abs());
332 let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();
333
334 let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
335 let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
336
337 message += format!(
338 "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
339 )
340 .as_str();
341 }
342 num_diff += 1;
343 }
344 }
345
346 if num_diff >= max_num_diff {
347 message += format!("\n{} more errors...", num_diff - 5).as_str();
348 }
349
350 if !message.is_empty() {
351 panic!("Tensors are not approx eq:{message}");
352 }
353 }
354
355 pub fn assert_within_range<E: Element>(&self, range: core::ops::Range<E>) {
366 for elem in self.iter::<E>() {
367 if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
368 panic!("Element ({elem:?}) is not within range {range:?}");
369 }
370 }
371 }
372
373 pub fn assert_within_range_inclusive<E: Element>(&self, range: core::ops::RangeInclusive<E>) {
383 let start = range.start();
384 let end = range.end();
385
386 for elem in self.iter::<E>() {
387 if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
388 panic!("Element ({elem:?}) is not within range {range:?}");
389 }
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// A difference just inside the absolute tolerance is accepted, both when the data
    /// is read as `f32` and when it is read as `f16`.
    #[test]
    fn should_assert_approx_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));
        data1.assert_approx_eq::<f16>(&data2, Tolerance::absolute(3e-2));
    }

    /// A difference larger than the absolute tolerance is reported at its position.
    #[test]
    #[should_panic(expected = "Position 0")]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    /// A shape mismatch is reported even when the overlapping elements are equal.
    #[test]
    #[should_panic(expected = "Shape is different")]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }
}
424}