// ndarray_numtest/assert.rs
use ndarray::{Array, Dimension, IntoDimension};
use float_cmp::ApproxEqRatio;
use num_complex::Complex;
6
/// Closeness assertion between a single value and a reference value.
///
/// The trait method returns nothing, so a failed check has to be reported by
/// panicking; the implementations in this file do exactly that.
pub trait AssertClose: Sized + Copy {
    /// Type of the tolerance accepted by [`AssertClose::assert_close`].
    type Tol;

    /// Asserts that `self` is close to `truth` within the tolerance `rtol`.
    ///
    /// How "close" is measured is left to the implementation.
    fn assert_close(self, truth: Self, rtol: Self::Tol);
}
12
13macro_rules! impl_AssertClose {
14 ($scalar:ty) => {
15impl AssertClose for $scalar {
16 type Tol = $scalar;
17 fn assert_close(self, truth: Self, rtol: Self::Tol) {
18 if !self.approx_eq_ratio(&truth, rtol) {
19 panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol);
20 }
21 }
22}
23impl AssertClose for Complex<$scalar> {
24 type Tol = $scalar;
25 fn assert_close(self, truth: Self, rtol: Self::Tol) {
26 if !(self.re.approx_eq_ratio(&truth.re, rtol) && self.im.approx_eq_ratio(&truth.im, rtol)) {
27 panic!("Not close: val={}, truth={}, rtol={}", self, truth, rtol);
28 }
29 }
30}
31}} impl_AssertClose!(f64);
33impl_AssertClose!(f32);
34
/// Closeness assertions between a collection of values and a reference
/// collection of the same type.
///
/// Both methods return nothing, so a failed check has to be reported by
/// panicking; the implementations in this file do exactly that.
pub trait AssertAllClose {
    /// Type of the tolerance accepted by both assertions.
    type Tol;

    /// Asserts closeness in the L2 norm with tolerance `rtol`.
    ///
    /// The implementations in this file compare the summed squared deviation
    /// from `truth`, divided by the summed squared magnitude of `truth`,
    /// against `rtol` squared.
    fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol);

    /// Asserts closeness in the maximum norm with tolerance `atol`.
    ///
    /// The implementations in this file compare the magnitude of every
    /// element-wise difference against `atol`.
    fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol);
}
43
44macro_rules! impl_AssertAllClose {
45 ($scalar:ty, $float:ty, $abs:ident) => {
46impl AssertAllClose for [$scalar]{
47 type Tol = $float;
48 fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
49 for (x, y) in self.iter().zip(truth.iter()) {
50 let tol = (x - y).$abs();
51 if tol > atol {
52 panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}",
53 atol, self, truth);
54 }
55 }
56 }
57 fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
58 let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum();
59 let dev: Self::Tol = self.iter().zip(truth.iter()).map(|(x, y)| (x-y).$abs().powi(2)).sum();
60 if dev / nrm > rtol.powi(2) {
61 panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}",
62 rtol, self, truth);
63 }
64 }
65}
66
67impl AssertAllClose for Vec<$scalar> {
68 type Tol = $float;
69 fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
70 self.as_slice().assert_allclose_inf(&truth, atol);
71 }
72 fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
73 self.as_slice().assert_allclose_l2(&truth, rtol);
74 }
75}
76
77impl<D: Dimension> AssertAllClose for Array<$scalar, D> {
78 type Tol = $float;
79 fn assert_allclose_inf(&self, truth: &Self, atol: Self::Tol) {
80 if self.shape() != truth.shape() {
81 panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape());
82 }
83 for (idx, val) in self.indexed_iter() {
84 let t = truth[idx.into_dimension()];
85 let tol = (*val - t).$abs();
86 if tol > atol {
87 panic!("Not close in inf-norm (atol={}): \ntest = \n{:?}\nTruth = \n{:?}",
88 atol, self, truth);
89 }
90 }
91 }
92 fn assert_allclose_l2(&self, truth: &Self, rtol: Self::Tol) {
93 if self.shape() != truth.shape() {
94 panic!("Shape missmatch: self={:?}, truth={:?}", self.shape(), truth.shape());
95 }
96 let nrm: Self::Tol = truth.iter().map(|x| x.$abs().powi(2)).sum();
97 let dev: Self::Tol = self.indexed_iter().map(|(idx, val)| (truth[idx.into_dimension()] - val).$abs().powi(2)).sum();
98 if dev / nrm > rtol.powi(2) {
99 panic!("Not close in L2-norm (rtol={}): \ntest = \n{:?}\nTruth = \n{:?}",
100 rtol, self, truth);
101 }
102 }
103}
104}} impl_AssertAllClose!(f64, f64, abs);
107impl_AssertAllClose!(f32, f32, abs);
108impl_AssertAllClose!(Complex<f64>, f64, norm);
109impl_AssertAllClose!(Complex<f32>, f32, norm);