// cubek_test_utils/correctness/base.rs

1use crate::correctness::color_printer::ColorPrinter;
2use crate::test_mode::{TestMode, current_test_mode};
3use crate::{HostData, ValidationResult};
4
5pub fn assert_equals_approx(
6    actual: &HostData,
7    expected: &HostData,
8    epsilon: f32,
9) -> ValidationResult {
10    if actual.shape != expected.shape {
11        return ValidationResult::Fail(format!(
12            "Shape mismatch: got {:?}, expected {:?}",
13            actual.shape, expected.shape,
14        ));
15    }
16
17    let shape = &actual.shape;
18    let test_mode = current_test_mode();
19
20    let mut visitor: Box<dyn CompareVisitor> = match test_mode.clone() {
21        TestMode::Print {
22            filter,
23            fail_only: _,
24        } => {
25            if !filter.is_empty() && filter.len() != shape.len() {
26                return ValidationResult::Skipped(format!(
27                    "Print mode activated with invalid filter rank. Got {:?}, expected {:?}",
28                    filter.len(),
29                    shape.len()
30                ));
31            }
32            Box::new(ColorPrinter::new(filter))
33        }
34        _ => Box::new(FailFast),
35    };
36
37    let test_failed = compare_tensors(
38        actual,
39        expected,
40        shape,
41        epsilon,
42        &mut *visitor,
43        &mut Vec::new(),
44    );
45
46    match test_failed {
47        true => ValidationResult::Fail("Got incorrect results".to_string()),
48        false => ValidationResult::Pass,
49    }
50}
51
52#[derive(Debug)]
53pub(crate) enum ElemStatus {
54    Correct { got: f32 },
55    Wrong(WrongStatus),
56}
57
/// Reason an element comparison failed.
#[derive(Debug)]
pub(crate) enum WrongStatus {
    /// Neither value is NaN, but they are not within tolerance.
    GotWrongValue {
        /// Value read from the actual tensor.
        got: f32,
        /// Value read from the expected tensor.
        expected: f32,
        /// `|got - expected|` (infinite when an infinity is involved).
        diff: f32,
        /// Effective tolerance that was applied (not the caller's raw epsilon).
        epsilon: f32,
    },
    /// NaN was expected, but the actual value `got` is not NaN.
    ExpectedNan { got: f32 },
    /// The actual value is NaN, but `expected` is not.
    GotNan { expected: f32 },
}
73
74pub(crate) trait CompareVisitor {
75    fn visit(&mut self, index: &[usize], status: ElemStatus);
76}
77
78pub(crate) struct FailFast;
79
80impl CompareVisitor for FailFast {
81    fn visit(&mut self, index: &[usize], status: ElemStatus) {
82        if let ElemStatus::Wrong(w) = status {
83            panic!("Mismatch at {:?}: {:?}", index, w);
84        }
85    }
86}
87
88#[inline]
89fn compare_elem(got: f32, expected: f32, epsilon: f32) -> ElemStatus {
90    let eps = (epsilon * expected).abs().max(epsilon).min(0.99);
91
92    // NaN check: pass if both are NaN
93    if got.is_nan() && expected.is_nan() {
94        return ElemStatus::Correct { got };
95    }
96
97    // NaN mismatch
98    if got.is_nan() || expected.is_nan() {
99        return if expected.is_nan() {
100            ElemStatus::Wrong(WrongStatus::ExpectedNan { got })
101        } else {
102            ElemStatus::Wrong(WrongStatus::GotNan { expected })
103        };
104    }
105
106    // Infinite check: pass if both inf with same sign
107    if got.is_infinite() && expected.is_infinite() {
108        if got.signum() == expected.signum() {
109            return ElemStatus::Correct { got };
110        } else {
111            return ElemStatus::Wrong(WrongStatus::GotWrongValue {
112                got,
113                expected,
114                diff: f32::INFINITY,
115                epsilon: eps,
116            });
117        }
118    }
119
120    // Regular numeric comparison
121    let diff = (got - expected).abs();
122    if diff < eps {
123        ElemStatus::Correct { got }
124    } else {
125        ElemStatus::Wrong(WrongStatus::GotWrongValue {
126            got,
127            expected,
128            diff,
129            epsilon: eps,
130        })
131    }
132}
133
134fn compare_tensors(
135    actual: &HostData,
136    expected: &HostData,
137    shape: &[usize],
138    epsilon: f32,
139    visitor: &mut dyn CompareVisitor,
140    index: &mut Vec<usize>,
141) -> bool {
142    let mut failed = false;
143
144    let dim = index.len();
145    if dim == shape.len() {
146        let got = actual.get_f32(index);
147        let exp = expected.get_f32(index);
148
149        let status = compare_elem(got, exp, epsilon);
150        if matches!(status, ElemStatus::Wrong(_)) {
151            failed = true;
152        }
153        visitor.visit(index, status);
154        return failed;
155    }
156
157    for i in 0..shape[dim] {
158        index.push(i);
159        if compare_tensors(actual, expected, shape, epsilon, visitor, index) {
160            failed = true;
161        }
162        index.pop();
163    }
164
165    failed
166}