/// Asserts that two slices agree element-wise to within an absolute tolerance of `1e-3`.
///
/// Elements are paired by position; pairing stops at the end of the shorter slice.
///
/// # Panics
///
/// Panics on the first pair whose absolute difference is not strictly below the
/// tolerance (a NaN difference also fails). The message reports the `vec_b` element
/// as "got" and the `vec_a` element as "expected".
// May be unused in some builds; silence the dead-code lint.
#[allow(dead_code)]
pub fn approx_eq<A>(vec_a: &[A], vec_b: &[A])
where
    A: crate::types::FloatNum + std::fmt::Display,
{
    let tolerance = A::from_f64(1e-3).unwrap();
    vec_a.iter().zip(vec_b.iter()).for_each(|(want, got)| {
        let close_enough = (*want - *got).abs() < tolerance;
        assert!(
            close_enough,
            "Large difference of values, got {} expected {}.",
            got,
            want
        );
    });
}
/// Asserts that two slices of complex numbers agree element-wise to within an absolute
/// tolerance of `1e-3`, applied separately to the real and the imaginary parts.
///
/// Elements are paired by position; pairing stops at the end of the shorter slice.
///
/// # Panics
///
/// Panics on the first pair where the real parts OR the imaginary parts differ by
/// `1e-3` or more (a NaN difference also fails). The message reports the `vec_b`
/// element as "got" and the `vec_a` element as "expected".
// May be unused in some builds; silence the dead-code lint.
#[allow(dead_code)]
pub fn approx_eq_complex<A>(
    vec_a: &[num_complex::Complex<A>],
    vec_b: &[num_complex::Complex<A>],
) where
    A: crate::types::FloatNum + std::fmt::Display,
{
    let tol = A::from_f64(1e-3).unwrap();
    for (a, b) in vec_a.iter().zip(vec_b.iter()) {
        // Both components must be close. Joining the checks with `||` would accept a
        // value whose imaginary part is wrong as long as its real part matched.
        assert!(
            (a.re - b.re).abs() < tol && (a.im - b.im).abs() < tol,
            "Large difference of values, got {} expected {}.",
            b,
            a
        );
    }
}
/// Asserts that `result` matches `expected` element-wise to within an absolute
/// tolerance of `1e-3`.
///
/// Elements are paired in each array's logical iteration order; pairing stops when
/// either array runs out of elements.
///
/// # Panics
///
/// Panics on the first pair whose absolute difference is not strictly below the
/// tolerance (a NaN difference also fails), reporting the `result` element as "got"
/// and the `expected` element as "expected".
pub fn approx_eq_ndarray<A, S, D>(
    result: &ndarray::ArrayBase<S, D>,
    expected: &ndarray::ArrayBase<S, D>,
) where
    A: crate::types::FloatNum + std::fmt::Display,
    S: ndarray::Data<Elem = A>,
    D: ndarray::Dimension,
{
    let tolerance = A::from_f64(1e-3).unwrap();
    let pairs = expected.iter().zip(result.iter());
    for (want, got) in pairs {
        let close_enough = (*want - *got).abs() < tolerance;
        assert!(
            close_enough,
            "Large difference of values, got {} expected {}.",
            got,
            want
        );
    }
}
/// Asserts that the complex array `result` matches `expected` element-wise to within an
/// absolute tolerance of `1e-3`, applied separately to the real and the imaginary parts.
///
/// Elements are paired in each array's logical iteration order; pairing stops when
/// either array runs out of elements.
///
/// # Panics
///
/// Panics on the first pair where the real parts OR the imaginary parts differ by
/// `1e-3` or more (a NaN difference also fails), reporting the `result` element as
/// "got" and the `expected` element as "expected".
pub fn approx_eq_complex_ndarray<A, S, D>(
    result: &ndarray::ArrayBase<S, D>,
    expected: &ndarray::ArrayBase<S, D>,
) where
    A: crate::types::FloatNum + std::fmt::Display,
    S: ndarray::Data<Elem = num_complex::Complex<A>>,
    D: ndarray::Dimension,
{
    let tol = A::from_f64(1e-3).unwrap();
    for (a, b) in expected.iter().zip(result.iter()) {
        // Both components must be close. Joining the checks with `||` would accept a
        // value whose imaginary part is wrong as long as its real part matched.
        assert!(
            (a.re - b.re).abs() < tol && (a.im - b.im).abs() < tol,
            "Large difference of values, got {} expected {}.",
            b,
            a
        );
    }
}
/// Returns a zero-filled array shaped like `input`, except that axis `axis` has
/// length `size`.
///
/// Only the shape of `input` is read. Its element type `A` and the output element
/// type `T` are independent of each other.
///
/// # Panics
///
/// Indexing the shape with an out-of-range `axis` is expected to panic.
pub fn array_resized_axis<A, S, D, T>(
    input: &ndarray::ArrayBase<S, D>,
    size: usize,
    axis: usize,
) -> ndarray::Array<T, D>
where
    T: num_traits::Zero + std::clone::Clone,
    S: ndarray::Data<Elem = A>,
    D: ndarray::Dimension,
{
    let resized_shape = {
        let mut shape = input.raw_dim();
        shape[axis] = size;
        shape
    };
    ndarray::Array::<T, D>::zeros(resized_shape)
}
/// Asserts that `input` has length `size` along axis `axis`.
///
/// `function_name` is only used to identify the caller in the panic message.
///
/// # Panics
///
/// Panics if the length of `input` along `axis` differs from `size`. The message
/// reports the actual length as "got" and `size` as "expected". Indexing the shape
/// with an out-of-range `axis` also panics.
pub fn check_array_axis<A, S, D>(
    input: &ndarray::ArrayBase<S, D>,
    size: usize,
    axis: usize,
    function_name: &str,
) where
    S: ndarray::Data<Elem = A>,
    D: ndarray::Dimension,
{
    let actual = input.shape()[axis];
    // `actual` is what the array has and `size` is what the caller requires, so they
    // fill the "got" and "expected" slots respectively.
    assert!(
        actual == size,
        "Size mismatch in {}, got {} expected {} along axis {}",
        function_name,
        actual,
        size,
        axis
    );
}