funspace/utils.rs
1/// Test approx equality of two arrays element-wise
2///
3/// # Panics
4/// Panics when difference is larger than 1e-3.
5#[allow(dead_code)]
6pub fn approx_eq<A>(vec_a: &[A], vec_b: &[A])
7where
8 A: crate::types::FloatNum + std::fmt::Display,
9{
10 let tol = A::from_f64(1e-3).unwrap();
11 for (a, b) in vec_a.iter().zip(vec_b.iter()) {
12 assert!(
13 ((*a - *b).abs() < tol),
14 "Large difference of values, got {} expected {}.",
15 b,
16 a
17 );
18 }
19}
20
21/// Test approx equality of two arrays element-wise
22///
23/// # Panics
24/// Panics when difference is larger than 1e-3.
25#[allow(dead_code)]
26pub fn approx_eq_complex<A>(vec_a: &[num_complex::Complex<A>], vec_b: &[num_complex::Complex<A>])
27where
28 A: crate::types::FloatNum + std::fmt::Display,
29{
30 let tol = A::from_f64(1e-3).unwrap();
31 for (a, b) in vec_a.iter().zip(vec_b.iter()) {
32 assert!(
33 ((a.re - b.re).abs() < tol || (a.im - b.im).abs() < tol),
34 "Large difference of values, got {} expected {}.",
35 b,
36 a
37 );
38 }
39}
40
41/// Test approx equality of two arrays element-wise
42///
43/// # Panics
44/// Panics when difference is larger than 1e-3.
45pub fn approx_eq_ndarray<A, S, D>(
46 result: &ndarray::ArrayBase<S, D>,
47 expected: &ndarray::ArrayBase<S, D>,
48) where
49 A: crate::types::FloatNum + std::fmt::Display,
50 S: ndarray::Data<Elem = A>,
51 D: ndarray::Dimension,
52{
53 let tol = A::from_f64(1e-3).unwrap();
54 for (a, b) in expected.iter().zip(result.iter()) {
55 assert!(
56 ((*a - *b).abs() < tol),
57 "Large difference of values, got {} expected {}.",
58 b,
59 a
60 );
61 }
62}
63
64/// Test approx equality of two arrays element-wise
65///
66/// # Panics
67/// Panics when difference is larger than 1e-3.
68pub fn approx_eq_complex_ndarray<A, S, D>(
69 result: &ndarray::ArrayBase<S, D>,
70 expected: &ndarray::ArrayBase<S, D>,
71) where
72 A: crate::types::FloatNum + std::fmt::Display,
73 S: ndarray::Data<Elem = num_complex::Complex<A>>,
74 D: ndarray::Dimension,
75{
76 let tol = A::from_f64(1e-3).unwrap();
77 for (a, b) in expected.iter().zip(result.iter()) {
78 assert!(
79 ((a.re - b.re).abs() < tol || (a.im - b.im).abs() < tol),
80 "Large difference of values, got {} expected {}.",
81 b,
82 a
83 );
84 }
85}
86
87/// Returns a new array with same dimensionality
88/// but different size *n* along the specified *axis*.
89///
90/// # Example
91/// ```
92/// use funspace::utils::array_resized_axis;
93/// let array = ndarray::Array2::<f64>::zeros((5, 3));
94/// let resized: ndarray::Array2<f64> = array_resized_axis(&array, 2, 1);
95/// assert!(resized == ndarray::Array2::zeros((5, 2)));
96/// ```
97pub fn array_resized_axis<A, S, D, T>(
98 input: &ndarray::ArrayBase<S, D>,
99 size: usize,
100 axis: usize,
101) -> ndarray::Array<T, D>
102where
103 T: num_traits::Zero + std::clone::Clone,
104 S: ndarray::Data<Elem = A>,
105 D: ndarray::Dimension,
106{
107 // Get dim
108 let mut dim = input.raw_dim();
109
110 // Replace position in dim
111 dim[axis] = size;
112
113 // Return
114 ndarray::Array::<T, D>::zeros(dim)
115}
116
117/// Checks size of axis.
118///
119/// # Panics
120/// Panics when inputs shape does not match
121/// axis' size
122///
123/// # Example
124/// ```should_panic
125/// use funspace::utils::check_array_axis;
126/// let array = ndarray::Array2::<f64>::zeros((5, 3));
127/// check_array_axis(&array, 3, 0, "");
128/// ```
129pub fn check_array_axis<A, S, D>(
130 input: &ndarray::ArrayBase<S, D>,
131 size: usize,
132 axis: usize,
133 function_name: &str,
134) where
135 S: ndarray::Data<Elem = A>,
136 D: ndarray::Dimension,
137{
138 // Arrays size
139 let m = input.shape()[axis];
140
141 assert!(
142 input.shape()[axis] == size,
143 "Size mismatch in {}, got {} expected {} along axis {}",
144 function_name,
145 size,
146 m,
147 axis
148 );
149}
150
151// /// Checks size of axis.
152// ///
153// /// # Panics
154// /// Panics when inputs shape does not match
155// /// axis' size
156// ///
157// /// # Example
158// /// ```should_panic
159// /// use funspace::utils::check_array_axis;
160// /// let array = ndarray::Array2::<f64>::zeros((5, 3));
161// /// check_array_axis(&array, 3, 0, None);
162// /// ```
163// pub fn check_array_axis<A, S, D>(
164// input: &ndarray::ArrayBase<S, D>,
165// size: usize,
166// axis: usize,
167// function_name: Option<&str>,
168// ) where
169// A: ndarray::LinalgScalar,
170// S: ndarray::Data<Elem = A>,
171// D: ndarray::Dimension,
172// {
173// // Arrays size
174// let m = input.shape()[axis];
175//
176// // Panic
177// if size != m {
178// if let Some(name) = function_name {
179// panic!(
180// "Size mismatch in {}, got {} expected {} along axis {}",
181// name, size, m, axis
182// );
183// } else {
184// panic!(
185// "Size mismatch, got {} expected {} along axis {}",
186// size, m, axis
187// );
188// };
189// }
190// }