#![allow(clippy::type_complexity)]
use std::ops::RangeInclusive;
use nalgebra::{Complex, ComplexField};
use crate::{
assert_close, assert_is_derivative,
basis::{Basis, DifferentialBasis, IntegralBasis, OrthogonalBasis, Root, RootFindingBasis},
display::PolynomialDisplay,
statistics::DomainNormalizer,
value::{SteppedValues, Value},
Polynomial,
};
/// Checks `Basis::fill_matrix_row` against the expected basis values at `x`.
///
/// `x` is first normalized with `basis.normalize_x`. Then, for every start
/// column from `0` up to and including `expected.len()`:
/// - a fresh zeroed `1 × expected.len()` matrix is filled starting at that
///   column, and every column before it must still be exactly zero;
/// - the same row is refilled starting at column `0`, and every column must be
///   close to the corresponding entry of `expected`.
///
/// # Panics
/// Panics on the first column that violates either check.
pub fn assert_basis_matrix_row<B: Basis<T>, T: Value>(basis: &B, x: T, expected: &[T]) {
    let width = expected.len();
    let x = basis.normalize_x(x);

    for zeros in 0..=width {
        let mut matrix = nalgebra::DMatrix::<T>::zeros(1, width);

        // Offset fill: the leading `zeros` columns must not be written to.
        basis.fill_matrix_row(zeros, x, matrix.row_mut(0));
        for (i, cell) in matrix.row(0).iter().take(zeros).enumerate() {
            assert_eq!(*cell, T::zero(), "Matrix col {i} should be zero");
        }

        // Full refill from column 0 over whatever the offset fill left behind.
        basis.fill_matrix_row(0, x, matrix.row_mut(0));
        for i in 0..width {
            assert_close!(matrix[(0, i)], expected[i], "Matrix col {i}");
        }
    }
}
pub fn assert_basis_functions_close<B: Basis<T>, T: Value>(
basis: &B,
x: T,
expected: &[T],
tol: T,
) {
let mut actual = vec![T::zero(); expected.len()];
for i in 0..expected.len() {
actual[i] = basis.solve_function(i, x);
}
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
if a.abs_sub(e) > tol {
eprintln!("Expected ∑{expected:?}");
eprintln!("Got ∑{actual:?}");
panic!("Basis function {i} differs: {a:?} != {e:?} (tol {tol:?})");
}
}
}
/// Asserts that the Gram matrix returned by `basis.gauss_matrix(functions, nodes)`
/// matches the basis's orthogonality relations within `tol`.
///
/// Only the upper triangle (`j >= i`) is inspected:
/// - off-diagonal entries must have magnitude `<= tol`;
/// - diagonal entries must be within `tol` of `basis.gauss_normalization(i)`.
///
/// # Panics
/// Panics if `nodes < functions`, or on the first entry outside tolerance. The
/// failure message includes the whole matrix.
pub fn assert_basis_orthogonal<B, T>(basis: &B, functions: usize, nodes: usize, tol: T)
where
    T: Value,
    B: OrthogonalBasis<T>,
{
    assert!(nodes >= functions, "need >= `functions` quadrature nodes");
    let gram_matrix = basis.gauss_matrix(functions, nodes);

    let upper_triangle = (0..functions).flat_map(|i| (i..functions).map(move |j| (i, j)));
    for (i, j) in upper_triangle {
        let val = gram_matrix[(i, j)];

        if i != j {
            // Distinct basis functions must be orthogonal.
            let abs_val = Value::abs(val);
            assert!(
                abs_val <= tol,
                "gram[{i},{j}] : {val:?} != 0 ; {abs_val:?} > {tol:?}\n{gram_matrix}"
            );
            continue;
        }

        // Diagonal: the squared norm reported by the basis.
        let expected = basis.gauss_normalization(i);
        let err = Value::abs(val - expected);
        assert!(
            err <= tol,
            "gram[{i},{j}] : {val:?} != {expected:?} ; {err:?} > {tol:?}\n{gram_matrix}"
        );
    }
}
/// Asserts that `basis.normalize_x` maps the endpoints of `src_range` onto the
/// matching endpoints of `dst_range`.
///
/// The lower endpoint is checked first, then the upper one.
///
/// # Panics
/// Panics via `assert_close!` if either endpoint is not mapped as expected.
pub fn assert_basis_normalizes<B: Basis<T>, T: Value>(
    basis: &B,
    src_range: (T, T),
    dst_range: (T, T),
) {
    let (src_lo, src_hi) = src_range;

    let min = basis.normalize_x(src_lo);
    assert_close!(min, dst_range.0, "Min normalization failed");

    let max = basis.normalize_x(src_hi);
    assert_close!(max, dst_range.1, "Max normalization failed");
}
/// Differentiates `f`, checks the result with `assert_is_derivative!` over the
/// normalizer's source range, and returns the derivative.
///
/// With the `plotting` feature enabled, it also plots `f`, `f'` and the critical
/// points of `f` over that range before running the check.
///
/// # Panics
/// Panics in these cases:
/// - the derivative cannot be computed;
/// - with `plotting`, the critical points cannot be computed;
/// - `assert_is_derivative!` fails.
pub fn test_derivation<'a, B, T>(
    f: &crate::Polynomial<'a, B, T>,
    norm: &DomainNormalizer<T>,
) -> Polynomial<'a, <B as DifferentialBasis<T>>::B2, T>
where
    B: Basis<T> + PolynomialDisplay<T> + DifferentialBasis<T>,
    <B as DifferentialBasis<T>>::B2: RootFindingBasis<T>,
    T: Value,
{
    // Inclusive range over the normalizer's source (un-normalized) domain.
    let domain = norm.src_range();
    let domain = domain.0..=domain.1;
    let f_prime = f.derivative().expect("Failed to compute first derivative");
    #[cfg(feature = "plotting")]
    {
        use crate::{basis::CriticalPoint, plot};
        let critical_points = f
            .critical_points(domain.clone())
            .expect("Failed to compute critical points");
        let crit_markers = CriticalPoint::as_plotting_element(&critical_points);
        plot!([f, f_prime, crit_markers], {
            x_range: Some(*domain.start()..*domain.end()),
        });
    }
    assert_is_derivative!(f, f_prime, domain);
    f_prime
}
/// Differentiates `f`, integrates the derivative back, and checks that the
/// derivative is the derivative of the re-integrated polynomial.
///
/// The derivative step is checked by [`test_derivation`]. The integration
/// constant is `f`'s first coefficient. The round trip is checked with
/// `assert_is_derivative!` over the normalizer's source range.
///
/// NOTE(review): the re-integrated `f2` is never compared with `f` itself. This
/// therefore checks `(∫f')' == f'`, not `∫f' == f`.
///
/// # Panics
/// Panics if `f` has no coefficients, if differentiation or integration fails,
/// or if either derivative check fails.
pub fn test_reversible_derivation<B, T>(f: &crate::Polynomial<B, T>, norm: &DomainNormalizer<T>)
where
    B: Basis<T> + PolynomialDisplay<T> + DifferentialBasis<T>,
    <B as DifferentialBasis<T>>::B2: RootFindingBasis<T>,
    <<B as DifferentialBasis<T>>::B2 as DifferentialBasis<T>>::B2: IntegralBasis<T>,
    <B as DifferentialBasis<T>>::B2: IntegralBasis<T>,
    T: Value,
{
    let domain = norm.src_range();
    let domain = domain.0..=domain.1;
    let f_prime = test_derivation(f, norm);
    let c0 = f.coefficients()[0];
    let f2 = f_prime
        .integral(Some(c0))
        .expect("Failed to integrate f'(x)");
    // `f_lbl` names the first argument, here the antiderivative `f2 = ∫f'`.
    // `fprime_lbl` names the second argument, its derivative `f'`.
    assert_is_derivative!(
        f2,
        f_prime,
        &domain,
        f_lbl = "∫(d(f(x))/dx)",
        fprime_lbl = "d(f(x))/dx"
    );
}
/// Integrates `f`, checks with `assert_is_derivative!` that `f` is the
/// derivative of the result over the normalizer's source range, and returns the
/// integral.
///
/// The integration constant is `f`'s second coefficient when it has one.
/// Otherwise `None` is passed and `integral` applies its own default. Any
/// constant is valid here, because the derivative check does not depend on it.
///
/// # Panics
/// Panics if integration fails or if the derivative check fails.
pub fn test_integration<'a, B, T>(
    f: &Polynomial<'a, B, T>,
    norm: &DomainNormalizer<T>,
) -> Polynomial<'a, <B as IntegralBasis<T>>::B2, T>
where
    B: Basis<T> + PolynomialDisplay<T> + DifferentialBasis<T> + IntegralBasis<T>,
    T: Value,
{
    let domain = norm.src_range();
    let domain = domain.0..=domain.1;
    // `get` avoids an out-of-bounds panic on constant (single-coefficient) polynomials.
    let c = f.coefficients().get(1).copied();
    let g = f.integral(c).expect("Failed to compute integral");
    assert_is_derivative!(g, f, domain, f_lbl = "∫f(x)", fprime_lbl = "f(x)");
    g
}
/// Integrates `f` (checked by [`test_integration`]), differentiates the integral
/// again, and checks with `assert_is_derivative!` that the result is the
/// derivative of the integral over the normalizer's source range.
///
/// NOTE(review): the recovered `f2` is only checked against `g`. It is never
/// compared with the original `f`, so this does not directly assert `(∫f)' == f`.
///
/// # Panics
/// Panics if integration or differentiation fails, or if a derivative check fails.
pub fn test_reversible_integration<B, T>(f: &crate::Polynomial<B, T>, norm: &DomainNormalizer<T>)
where
    B: Basis<T> + PolynomialDisplay<T> + DifferentialBasis<T> + IntegralBasis<T>,
    <B as IntegralBasis<T>>::B2: DifferentialBasis<T>,
    T: Value,
{
    // Inclusive range over the normalizer's source (un-normalized) domain.
    let domain = norm.src_range();
    let domain = domain.0..=domain.1;
    let g = test_integration(f, norm);
    let f2 = g.derivative().expect("Failed to compute first derivative");
    assert_is_derivative!(g, f2, domain, f_lbl = "∫f(x)", fprime_lbl = "f(x)");
}
/// Cross-checks the real roots of `f` over `x_range` against the `Root::Real`
/// entries of `f.roots(..)`.
///
/// Each real root is matched to, and removes, the first not-yet-matched real
/// entry within a tolerance of `ε·1e12 / max(|f'(root)|, 1e-5)`. The tolerance is
/// looser at flat crossings, where a root is numerically ill-conditioned. Any
/// real entry left unmatched afterwards is reported as an extra root.
///
/// # Panics
/// Panics in these cases:
/// - root finding or differentiation fails;
/// - a real root has no match within tolerance;
/// - unmatched real entries remain.
///
/// Root lists are printed to stderr before a mismatch panic.
pub fn test_root_finding<
    B: Basis<T> + PolynomialDisplay<T> + DifferentialBasis<T> + RootFindingBasis<T>,
    T: Value,
>(
    f: &crate::Polynomial<B, T>,
    x_range: RangeInclusive<T>,
) {
    let roots = f.roots(x_range.clone()).expect("Failed to compute roots");
    // NOTE(review): this is a second call to `roots()`, so the cross-check below
    // compares `roots()` with itself. The variable name and the panic messages
    // suggest a separate real-root query was intended. Confirm against the
    // `Polynomial` API.
    let real_roots = f.roots(x_range).expect("Failed to compute real roots");
    let mut real_from_roots: Vec<_> = roots
        .iter()
        .filter_map(|r| {
            if let Root::Real(root_r) = r {
                Some(*root_r)
            } else {
                None
            }
        })
        .collect();
    let df = f.derivative().expect("Failed to compute derivative");

    // Loop-invariant pieces of the tolerance.
    let stability_factor = T::from_f64(1e12).unwrap_or(T::zero());
    let slope_factor = T::from_f64(1e-5).unwrap_or(T::zero());
    let scaled_epsilon = T::epsilon() * stability_factor;

    for root in &real_roots {
        let Some(root) = root.as_real() else {
            continue;
        };
        // Use the slope's magnitude. A negative slope must not fall back to
        // `slope_factor`, which would inflate the tolerance by orders of magnitude.
        let slope = Value::abs(df.y(root));
        let tol = scaled_epsilon / Value::max(slope, slope_factor);
        // Use the absolute distance. In num-traits, `abs_sub` is the positive
        // difference max(a - b, 0), which would accept any candidate below `root`.
        let found_at = real_from_roots
            .iter()
            .position(|&r| Value::abs(r - root) <= tol);
        if let Some(i) = found_at {
            real_from_roots.remove(i);
        } else {
            eprintln!("Roots: {roots:?}");
            eprintln!("Real roots: {real_roots:?}");
            panic!("Real root {root:?} not found in roots within tolerance {tol:?}");
        }
    }
    if !real_from_roots.is_empty() {
        eprintln!("Roots: {roots:?}");
        eprintln!("Real roots: {real_roots:?}");
        panic!(
            "Found extra real roots in roots() that are not in real_roots(): {real_from_roots:?}"
        );
    }
}
/// Checks that `complex_y`, evaluated on the real axis, agrees with the
/// real-valued `f.y`.
///
/// The samples come from `SteppedValues::new(x_range, 100)`. Each sample is
/// normalized through `f.basis()` and evaluated as a complex number with a zero
/// imaginary part. The real part must match `f.y` within `1e-6`, and the
/// imaginary part must have magnitude at most `T::epsilon()`.
///
/// # Panics
/// Panics on the first sample that violates either check.
pub fn test_complex_y<B: Basis<T> + PolynomialDisplay<T> + RootFindingBasis<T>, T: Value>(
    f: &crate::Polynomial<B, T>,
    x_range: RangeInclusive<T>,
) {
    let basis = f.basis();
    let step = T::from_positive_int(100);

    for value in SteppedValues::new(x_range, step) {
        let real_y = f.y(value);
        let on_real_axis = Complex::from_real(basis.normalize_x(value));
        let complex_y = basis.complex_y(on_real_axis, f.coefficients());

        assert_close!(
            complex_y.re,
            real_y,
            epsilon = T::from_f64(1e-6).unwrap_or(T::zero()),
            "{f}\nReal part of complex_y should match solve_function on the real axis (x = {value:?})"
        );
        assert!(
            Value::abs(complex_y.im) <= T::epsilon(),
            "{f}\nImaginary part of complex_y should be close to zero on the real axis (x = {value:?})"
        );
    }
}