// cubek_test_utils/correctness/base.rs
use crate::correctness::color_printer::ColorPrinter;
use crate::test_mode::{TestMode, current_test_mode};
use crate::{HostData, ValidationResult};
4
5pub fn assert_equals_approx(
6 actual: &HostData,
7 expected: &HostData,
8 epsilon: f32,
9) -> ValidationResult {
10 if actual.shape != expected.shape {
11 return ValidationResult::Fail(format!(
12 "Shape mismatch: got {:?}, expected {:?}",
13 actual.shape, expected.shape,
14 ));
15 }
16
17 let shape = &actual.shape;
18 let test_mode = current_test_mode();
19
20 let mut visitor: Box<dyn CompareVisitor> = match test_mode.clone() {
21 TestMode::Print {
22 filter,
23 fail_only: _,
24 } => {
25 if !filter.is_empty() && filter.len() != shape.len() {
26 return ValidationResult::Skipped(format!(
27 "Print mode activated with invalid filter rank. Got {:?}, expected {:?}",
28 filter.len(),
29 shape.len()
30 ));
31 }
32 Box::new(ColorPrinter::new(filter))
33 }
34 _ => Box::new(FailFast),
35 };
36
37 let test_failed = compare_tensors(
38 actual,
39 expected,
40 shape,
41 epsilon,
42 &mut *visitor,
43 &mut Vec::new(),
44 );
45
46 match test_failed {
47 true => ValidationResult::Fail("Got incorrect results".to_string()),
48 false => ValidationResult::Pass,
49 }
50}
51
52#[derive(Debug)]
53pub(crate) enum ElemStatus {
54 Correct { got: f32 },
55 Wrong(WrongStatus),
56}
57
/// Reason why an element was rejected by the comparison.
#[derive(Debug)]
pub(crate) enum WrongStatus {
    /// Both values were comparable but did not match.
    GotWrongValue {
        /// Actual value.
        got: f32,
        /// Reference value.
        expected: f32,
        /// Difference reported by the comparison.
        diff: f32,
        /// Tolerance reported by the comparison.
        epsilon: f32,
    },
    /// The reference value was NaN while the actual value `got` was not.
    ExpectedNan { got: f32 },
    /// The actual value was NaN while the reference value `expected` was not.
    GotNan { expected: f32 },
}
73
74pub(crate) trait CompareVisitor {
75 fn visit(&mut self, index: &[usize], status: ElemStatus);
76}
77
78pub(crate) struct FailFast;
79
80impl CompareVisitor for FailFast {
81 fn visit(&mut self, index: &[usize], status: ElemStatus) {
82 if let ElemStatus::Wrong(w) = status {
83 panic!("Mismatch at {:?}: {:?}", index, w);
84 }
85 }
86}
87
88#[inline]
89fn compare_elem(got: f32, expected: f32, epsilon: f32) -> ElemStatus {
90 let eps = (epsilon * expected).abs().max(epsilon).min(0.99);
91
92 if got.is_nan() && expected.is_nan() {
94 return ElemStatus::Correct { got };
95 }
96
97 if got.is_nan() || expected.is_nan() {
99 return if expected.is_nan() {
100 ElemStatus::Wrong(WrongStatus::ExpectedNan { got })
101 } else {
102 ElemStatus::Wrong(WrongStatus::GotNan { expected })
103 };
104 }
105
106 if got.is_infinite() && expected.is_infinite() {
108 if got.signum() == expected.signum() {
109 return ElemStatus::Correct { got };
110 } else {
111 return ElemStatus::Wrong(WrongStatus::GotWrongValue {
112 got,
113 expected,
114 diff: f32::INFINITY,
115 epsilon: eps,
116 });
117 }
118 }
119
120 let diff = (got - expected).abs();
122 if diff < eps {
123 ElemStatus::Correct { got }
124 } else {
125 ElemStatus::Wrong(WrongStatus::GotWrongValue {
126 got,
127 expected,
128 diff,
129 epsilon: eps,
130 })
131 }
132}
133
134fn compare_tensors(
135 actual: &HostData,
136 expected: &HostData,
137 shape: &[usize],
138 epsilon: f32,
139 visitor: &mut dyn CompareVisitor,
140 index: &mut Vec<usize>,
141) -> bool {
142 let mut failed = false;
143
144 let dim = index.len();
145 if dim == shape.len() {
146 let got = actual.get_f32(index);
147 let exp = expected.get_f32(index);
148
149 let status = compare_elem(got, exp, epsilon);
150 if matches!(status, ElemStatus::Wrong(_)) {
151 failed = true;
152 }
153 visitor.visit(index, status);
154 return failed;
155 }
156
157 for i in 0..shape[dim] {
158 index.push(i);
159 if compare_tensors(actual, expected, shape, epsilon, visitor, index) {
160 failed = true;
161 }
162 index.pop();
163 }
164
165 failed
166}