funspace/
utils.rs

1/// Test approx equality of two arrays element-wise
2///
3/// # Panics
4/// Panics when difference is larger than 1e-3.
5#[allow(dead_code)]
6pub fn approx_eq<A>(vec_a: &[A], vec_b: &[A])
7where
8    A: crate::types::FloatNum + std::fmt::Display,
9{
10    let tol = A::from_f64(1e-3).unwrap();
11    for (a, b) in vec_a.iter().zip(vec_b.iter()) {
12        assert!(
13            ((*a - *b).abs() < tol),
14            "Large difference of values, got {} expected {}.",
15            b,
16            a
17        );
18    }
19}
20
21/// Test approx equality of two arrays element-wise
22///
23/// # Panics
24/// Panics when difference is larger than 1e-3.
25#[allow(dead_code)]
26pub fn approx_eq_complex<A>(vec_a: &[num_complex::Complex<A>], vec_b: &[num_complex::Complex<A>])
27where
28    A: crate::types::FloatNum + std::fmt::Display,
29{
30    let tol = A::from_f64(1e-3).unwrap();
31    for (a, b) in vec_a.iter().zip(vec_b.iter()) {
32        assert!(
33            ((a.re - b.re).abs() < tol || (a.im - b.im).abs() < tol),
34            "Large difference of values, got {} expected {}.",
35            b,
36            a
37        );
38    }
39}
40
41/// Test approx equality of two arrays element-wise
42///
43/// # Panics
44/// Panics when difference is larger than 1e-3.
45pub fn approx_eq_ndarray<A, S, D>(
46    result: &ndarray::ArrayBase<S, D>,
47    expected: &ndarray::ArrayBase<S, D>,
48) where
49    A: crate::types::FloatNum + std::fmt::Display,
50    S: ndarray::Data<Elem = A>,
51    D: ndarray::Dimension,
52{
53    let tol = A::from_f64(1e-3).unwrap();
54    for (a, b) in expected.iter().zip(result.iter()) {
55        assert!(
56            ((*a - *b).abs() < tol),
57            "Large difference of values, got {} expected {}.",
58            b,
59            a
60        );
61    }
62}
63
64/// Test approx equality of two arrays element-wise
65///
66/// # Panics
67/// Panics when difference is larger than 1e-3.
68pub fn approx_eq_complex_ndarray<A, S, D>(
69    result: &ndarray::ArrayBase<S, D>,
70    expected: &ndarray::ArrayBase<S, D>,
71) where
72    A: crate::types::FloatNum + std::fmt::Display,
73    S: ndarray::Data<Elem = num_complex::Complex<A>>,
74    D: ndarray::Dimension,
75{
76    let tol = A::from_f64(1e-3).unwrap();
77    for (a, b) in expected.iter().zip(result.iter()) {
78        assert!(
79            ((a.re - b.re).abs() < tol || (a.im - b.im).abs() < tol),
80            "Large difference of values, got {} expected {}.",
81            b,
82            a
83        );
84    }
85}
86
87/// Returns a new array with same dimensionality
88/// but different size *n* along the specified *axis*.
89///
90/// # Example
91/// ```
92/// use funspace::utils::array_resized_axis;
93/// let array = ndarray::Array2::<f64>::zeros((5, 3));
94/// let resized: ndarray::Array2<f64> = array_resized_axis(&array, 2, 1);
95/// assert!(resized == ndarray::Array2::zeros((5, 2)));
96/// ```
97pub fn array_resized_axis<A, S, D, T>(
98    input: &ndarray::ArrayBase<S, D>,
99    size: usize,
100    axis: usize,
101) -> ndarray::Array<T, D>
102where
103    T: num_traits::Zero + std::clone::Clone,
104    S: ndarray::Data<Elem = A>,
105    D: ndarray::Dimension,
106{
107    // Get dim
108    let mut dim = input.raw_dim();
109
110    // Replace position in dim
111    dim[axis] = size;
112
113    // Return
114    ndarray::Array::<T, D>::zeros(dim)
115}
116
117/// Checks size of axis.
118///
119/// # Panics
120/// Panics when inputs shape does not match
121/// axis' size
122///
123/// # Example
124/// ```should_panic
125/// use funspace::utils::check_array_axis;
126/// let array = ndarray::Array2::<f64>::zeros((5, 3));
127/// check_array_axis(&array, 3, 0, "");
128/// ```
129pub fn check_array_axis<A, S, D>(
130    input: &ndarray::ArrayBase<S, D>,
131    size: usize,
132    axis: usize,
133    function_name: &str,
134) where
135    S: ndarray::Data<Elem = A>,
136    D: ndarray::Dimension,
137{
138    // Arrays size
139    let m = input.shape()[axis];
140
141    assert!(
142        input.shape()[axis] == size,
143        "Size mismatch in {}, got {} expected {} along axis {}",
144        function_name,
145        size,
146        m,
147        axis
148    );
149}
150
151// /// Checks size of axis.
152// ///
153// /// # Panics
154// /// Panics when inputs shape does not match
155// /// axis' size
156// ///
157// /// # Example
158// /// ```should_panic
159// /// use funspace::utils::check_array_axis;
160// /// let array = ndarray::Array2::<f64>::zeros((5, 3));
161// /// check_array_axis(&array, 3, 0, None);
162// /// ```
163// pub fn check_array_axis<A, S, D>(
164//     input: &ndarray::ArrayBase<S, D>,
165//     size: usize,
166//     axis: usize,
167//     function_name: Option<&str>,
168// ) where
169//     A: ndarray::LinalgScalar,
170//     S: ndarray::Data<Elem = A>,
171//     D: ndarray::Dimension,
172// {
173//     // Arrays size
174//     let m = input.shape()[axis];
175//
176//     // Panic
177//     if size != m {
178//         if let Some(name) = function_name {
179//             panic!(
180//                 "Size mismatch in {}, got {} expected {} along axis {}",
181//                 name, size, m, axis
182//             );
183//         } else {
184//             panic!(
185//                 "Size mismatch, got {} expected {} along axis {}",
186//                 size, m, axis
187//             );
188//         };
189//     }
190// }