1use alloc::format;
2use alloc::string::String;
3use burn_std::{BoolStore, DType, bf16, f16};
4use num_traits::{Float, ToPrimitive};
5
6use super::TensorData;
7use crate::{Element, ElementOrdered};
8
/// Tolerance used when deciding whether two floating-point values are approximately equal.
///
/// Two values pass [`Tolerance::approx_eq`] when they are exactly equal, or when their
/// absolute difference is strictly below `max(absolute, relative * max(|x|, |y|))`.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance<F> {
    /// Relative tolerance; it is scaled by the larger magnitude of the two compared values.
    relative: F,
    /// Absolute tolerance; it acts as a floor for the allowed difference.
    absolute: F,
}
32
33impl<F: Float> Default for Tolerance<F> {
34 fn default() -> Self {
35 Self::balanced()
36 }
37}
38
39impl<F: Float> Tolerance<F> {
40 pub fn strict() -> Self {
42 Self {
43 relative: F::from(0.00).unwrap(),
44 absolute: F::from(64).unwrap() * F::min_positive_value(),
45 }
46 }
47 pub fn balanced() -> Self {
49 Self {
50 relative: F::from(0.005).unwrap(), absolute: F::from(1e-5).unwrap(),
52 }
53 }
54
55 pub fn permissive() -> Self {
57 Self {
58 relative: F::from(0.01).unwrap(), absolute: F::from(0.01).unwrap(),
60 }
61 }
62 pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {
72 let relative = Self::check_relative(relative);
73 let absolute = Self::check_absolute(absolute);
74
75 Self { relative, absolute }
76 }
77
78 pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {
88 let relative = Self::check_relative(tolerance);
89
90 Self {
91 relative,
92 absolute: F::from(0.0).unwrap(),
93 }
94 }
95
96 pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {
106 let absolute = Self::check_absolute(tolerance);
107
108 Self {
109 relative: F::from(0.0).unwrap(),
110 absolute,
111 }
112 }
113
114 pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
116 self.relative = Self::check_relative(tolerance);
117 self
118 }
119
120 pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
122 if core::mem::size_of::<F>() == 2 {
123 self.relative = Self::check_relative(tolerance);
124 }
125 self
126 }
127
128 pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
130 if core::mem::size_of::<F>() == 4 {
131 self.relative = Self::check_relative(tolerance);
132 }
133 self
134 }
135
136 pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
138 if core::mem::size_of::<F>() == 8 {
139 self.relative = Self::check_relative(tolerance);
140 }
141 self
142 }
143
144 pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
146 self.absolute = Self::check_absolute(tolerance);
147 self
148 }
149
150 pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
152 if core::mem::size_of::<F>() == 2 {
153 self.absolute = Self::check_absolute(tolerance);
154 }
155 self
156 }
157
158 pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
160 if core::mem::size_of::<F>() == 4 {
161 self.absolute = Self::check_absolute(tolerance);
162 }
163 self
164 }
165
166 pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {
168 if core::mem::size_of::<F>() == 8 {
169 self.absolute = Self::check_absolute(tolerance);
170 }
171 self
172 }
173
174 pub fn approx_eq(&self, x: F, y: F) -> bool {
176 if x == y {
182 return true;
183 }
184
185 let diff = (x - y).abs();
186 let max = F::max(x.abs(), y.abs());
187
188 diff < self.absolute.max(self.relative * max)
189 }
190
191 fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {
192 let tolerance = F::from(tolerance).unwrap();
193 assert!(tolerance <= F::one());
194 tolerance
195 }
196
197 fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {
198 let tolerance = F::from(tolerance).unwrap();
199 assert!(tolerance >= F::zero());
200 tolerance
201 }
202}
203
204impl TensorData {
205 #[track_caller]
217 pub fn assert_eq(&self, other: &Self, strict: bool) {
218 if strict {
219 assert_eq!(
220 self.dtype, other.dtype,
221 "Data types differ ({:?} != {:?})",
222 self.dtype, other.dtype
223 );
224 }
225
226 match self.dtype {
227 DType::F64 => self.assert_eq_elem::<f64>(other),
228 DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),
229 DType::F16 => self.assert_eq_elem::<f16>(other),
230 DType::BF16 => self.assert_eq_elem::<bf16>(other),
231 DType::I64 => self.assert_eq_elem::<i64>(other),
232 DType::I32 => self.assert_eq_elem::<i32>(other),
233 DType::I16 => self.assert_eq_elem::<i16>(other),
234 DType::I8 => self.assert_eq_elem::<i8>(other),
235 DType::U64 => self.assert_eq_elem::<u64>(other),
236 DType::U32 => self.assert_eq_elem::<u32>(other),
237 DType::U16 => self.assert_eq_elem::<u16>(other),
238 DType::U8 => self.assert_eq_elem::<u8>(other),
239 DType::Bool(BoolStore::Native) => self.assert_eq_elem::<bool>(other),
240 DType::Bool(BoolStore::U8) => self.assert_eq_elem::<u8>(other),
241 DType::Bool(BoolStore::U32) => self.assert_eq_elem::<u32>(other),
242 DType::QFloat(q) => {
243 let q_other = if let DType::QFloat(q_other) = other.dtype {
245 q_other
246 } else {
247 panic!("Quantized data differs from other not quantized data")
248 };
249
250 if q.value == q_other.value && q.level == q_other.level {
252 self.assert_eq_elem::<i8>(other)
253 } else {
254 panic!("Quantization schemes differ ({q:?} != {q_other:?})")
255 }
256 }
257 }
258 }
259
260 #[track_caller]
261 fn assert_eq_elem<E: Element>(&self, other: &Self) {
262 let mut message = String::new();
263 if self.shape != other.shape {
264 message += format!(
265 "\n => Shape is different: {:?} != {:?}",
266 self.shape, other.shape
267 )
268 .as_str();
269 }
270
271 let mut num_diff = 0;
272 let max_num_diff = 5;
273 for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {
274 if !a.eq(&b) {
275 if num_diff < max_num_diff {
277 message += format!("\n => Position {i}: {a} != {b}").as_str();
278 }
279 num_diff += 1;
280 }
281 }
282
283 if num_diff >= max_num_diff {
284 message += format!("\n{} more errors...", num_diff - max_num_diff).as_str();
285 }
286
287 if !message.is_empty() {
288 panic!("Tensors are not eq:{message}");
289 }
290 }
291
292 #[track_caller]
303 pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {
304 let mut message = String::new();
305 if self.shape != other.shape {
306 message += format!(
307 "\n => Shape is different: {:?} != {:?}",
308 self.shape, other.shape
309 )
310 .as_str();
311 }
312
313 let iter = self.iter::<F>().zip(other.iter::<F>());
314
315 let mut num_diff = 0;
316 let max_num_diff = 5;
317
318 for (i, (a, b)) in iter.enumerate() {
319 let both_nan = a.is_nan() && b.is_nan();
321 let both_inf =
323 a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));
324
325 if both_nan || both_inf {
326 continue;
327 }
328
329 if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {
330 if num_diff < max_num_diff {
332 let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();
333 let max = F::max(a.abs(), b.abs());
334 let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();
335
336 let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();
337 let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();
338
339 message += format!(
340 "\n => Position {i}: {a} != {b}\n diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})"
341 )
342 .as_str();
343 }
344 num_diff += 1;
345 }
346 }
347
348 if num_diff >= max_num_diff {
349 message += format!("\n{} more errors...", num_diff - 5).as_str();
350 }
351
352 if !message.is_empty() {
353 panic!("Tensors are not approx eq:{message}");
354 }
355 }
356
357 pub fn assert_within_range<E: ElementOrdered>(&self, range: core::ops::Range<E>) {
368 for elem in self.iter::<E>() {
369 if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {
370 panic!("Element ({elem:?}) is not within range {range:?}");
371 }
372 }
373 }
374
375 pub fn assert_within_range_inclusive<E: ElementOrdered>(
385 &self,
386 range: core::ops::RangeInclusive<E>,
387 ) {
388 let start = range.start();
389 let end = range.end();
390
391 for elem in self.iter::<E>() {
392 if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {
393 panic!("Element ({elem:?}) is not within range {range:?}");
394 }
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Values nominally 0.03 apart pass with an absolute tolerance of `3e-2` when
    /// compared as `f32` and as `f16`.
    #[test]
    fn should_assert_approx_eq_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));
        data1.assert_approx_eq::<f16>(&data2, Tolerance::absolute(3e-2));
    }

    /// A difference larger than the absolute tolerance must make the assertion panic.
    #[test]
    #[should_panic]
    fn should_assert_approx_eq_above_limit() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);
        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }

    /// Tensors with different shapes must make the assertion panic, even when the
    /// overlapping elements are equal.
    #[test]
    #[should_panic]
    fn should_assert_approx_eq_check_shape() {
        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);
        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);

        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));
    }
}