use std::ops::RangeInclusive;
use nalgebra::Complex;
use nalgebra::ComplexField;
use nalgebra::MatrixViewMut;
use nalgebra::Normed;
use crate::value::bisect;
use crate::value::FloatClampedCast;
use crate::value::SteppedValues;
use crate::{error::Result, value::Value};
pub(crate) mod monomial;
pub use monomial::MonomialBasis;
pub(crate) mod chebyshev;
pub use chebyshev::{ChebyshevBasis, SecondFormChebyshevBasis, ThirdFormChebyshevBasis};
pub(crate) mod augmented_fourier;
pub use augmented_fourier::{AugmentedFourierBasis, FourierBasis, LinearAugmentedFourierBasis};
pub(crate) mod legendre;
pub use legendre::LegendreBasis;
pub(crate) mod hermite;
pub use hermite::PhysicistsHermiteBasis;
pub use hermite::ProbabilistsHermiteBasis;
pub(crate) mod laguerre;
pub use laguerre::LaguerreBasis;
pub(crate) mod logarithmic;
pub use logarithmic::LogarithmicBasis;
/// A family of basis functions in which polynomial coefficients are expressed.
///
/// A polynomial in this basis is evaluated as the weighted sum
/// `coefficients[0] * solve_function(0, x) + coefficients[1] * solve_function(1, x) + ...`
/// (see [`Basis::solve`]).
pub trait Basis<T: Value>: Sized + Clone + std::fmt::Debug + Send + Sync {
    /// Builds a basis configured for the input range `x_range`.
    fn from_range(x_range: std::ops::RangeInclusive<T>) -> Self;

    /// Returns the number of basis terms `k` used for a polynomial of `degree`.
    ///
    /// The default is `degree + 1`; it is the inverse of [`Basis::degree`].
    #[inline(always)]
    fn k(&self, degree: usize) -> usize {
        1 + degree
    }

    /// Returns the polynomial degree corresponding to `k` basis terms.
    ///
    /// The default is `k - 1`, or `None` when `k == 0`.
    #[inline(always)]
    fn degree(&self, k: usize) -> Option<usize> {
        k.checked_sub(1)
    }

    /// Writes the basis-function values at `x` into `row`, starting at `start_index`.
    ///
    /// NOTE(review): presumably fills one row of a fitting design matrix; the exact column
    /// layout is defined by each implementation.
    fn fill_matrix_row<R: nalgebra::Dim, C: nalgebra::Dim, RS: nalgebra::Dim, CS: nalgebra::Dim>(
        &self,
        start_index: usize,
        x: T,
        row: MatrixViewMut<T, R, C, RS, CS>,
    );

    /// Maps a raw `x` into the basis's internal domain (implementation-defined).
    fn normalize_x(&self, x: T) -> T;

    /// Maps an internal-domain `x` back to the raw domain; presumably the inverse of
    /// [`Basis::normalize_x`].
    fn denormalize_x(&self, x: T) -> T;

    /// Evaluates the `j`-th basis function at `x`.
    fn solve_function(&self, j: usize, x: T) -> T;

    /// Evaluates the polynomial with the given `coefficients` at `x`.
    ///
    /// Each coefficient is paired with the basis function of the same index, and the
    /// products are summed in index order starting from zero.
    fn solve(&self, x: T, coefficients: &[T]) -> T {
        coefficients
            .iter()
            .copied()
            .enumerate()
            .fold(T::zero(), |acc, (j, coef)| acc + coef * self.solve_function(j, x))
    }
}
/// Bases whose coefficients can be re-expressed in the monomial (power) basis.
///
/// NOTE(review): `as_monomial` takes `&mut [T]`, so the conversion presumably rewrites
/// `coefficients` in place; confirm against the implementations.
pub trait IntoMonomialBasis<T: Value>: Basis<T> {
/// Converts `coefficients` from this basis into monomial coefficients.
///
/// # Errors
/// Implementation-defined; see the concrete basis.
fn as_monomial(&self, coefficients: &mut [T]) -> Result<()>;
}
/// Bases that can differentiate a polynomial expressed in them.
///
/// The derivative may live in a different basis, given by the associated type `B2`.
pub trait DifferentialBasis<T: Value>: Basis<T> {
    /// Basis of the differentiated polynomial; it must also be displayable as a polynomial.
    type B2: Basis<T> + crate::display::PolynomialDisplay<T>;

    /// Differentiates `coefficients`, returning the derivative's basis and coefficients.
    ///
    /// # Errors
    /// Implementation-defined; see the concrete basis.
    fn derivative(&self, coefficients: &[T]) -> Result<(Self::B2, Vec<T>)>;

    /// Differentiates twice by chaining [`DifferentialBasis::derivative`] through `B2`.
    ///
    /// # Errors
    /// Returns the first error produced by either differentiation step.
    fn second_derivative(&self, coefficients: &[T]) -> Result<(SecondDerivative<Self, T>, Vec<T>)>
    where
        Self::B2: DifferentialBasis<T>,
    {
        self.derivative(coefficients)
            .and_then(|(first_basis, first_coefs)| first_basis.derivative(&first_coefs))
    }
}
/// Basis produced by differentiating a polynomial in basis `B` twice: the `B2` of `B`'s `B2`.
pub type SecondDerivative<B, T> = <<B as DifferentialBasis<T>>::B2 as DifferentialBasis<T>>::B2;
/// Strategy a root-finding basis reports via `root_finding_method`.
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum RootFindingMethod {
    /// Roots are obtained by an analytical (non-sampled) method.
    Analytical,
    /// Roots are located numerically by sampling and iterative refinement.
    Iterative,
}

impl std::fmt::Display for RootFindingMethod {
    /// Writes the variant name verbatim (no padding or alignment is applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Analytical => "Analytical",
            Self::Iterative => "Iterative",
        };
        f.write_str(label)
    }
}
/// Bases whose polynomials support real-root finding.
///
/// The default [`RootFindingBasis::roots`] delegates to
/// [`RootFindingBasis::roots_iterative`] using
/// [`RootFindingBasis::DEFAULT_ROOT_FINDING_SAMPLES`] and
/// [`RootFindingBasis::DEFAULT_ROOT_FINDING_MAX_ITERATIONS`].
pub trait RootFindingBasis<T: Value>: Basis<T> + DifferentialBasis<T> {
    /// Base sample count passed to `roots_iterative` by the default `roots`.
    const DEFAULT_ROOT_FINDING_SAMPLES: usize = 5000;

    /// Per-phase iteration cap passed to `roots_iterative` by the default `roots`.
    const DEFAULT_ROOT_FINDING_MAX_ITERATIONS: usize = 20;

    /// Reports the root-finding strategy used by this basis.
    ///
    /// Defaults to [`RootFindingMethod::Iterative`], matching the default `roots`.
    fn root_finding_method(&self) -> RootFindingMethod {
        RootFindingMethod::Iterative
    }

    /// Finds the roots of the polynomial with coefficients `coefs` inside `x_range`.
    ///
    /// # Errors
    /// Returns any error produced while differentiating `coefs`.
    fn roots(&self, coefs: &[T], x_range: RangeInclusive<T>) -> Result<Vec<Root<T>>> {
        self.roots_iterative(
            coefs,
            x_range,
            Self::DEFAULT_ROOT_FINDING_SAMPLES,
            Self::DEFAULT_ROOT_FINDING_MAX_ITERATIONS,
        )
    }

    /// Numerically locates real roots of `coefs` inside `x_range`.
    ///
    /// The range is swept in fixed steps. Wherever two consecutive samples change sign,
    /// or either sample is near zero, the bracket is refined in three stages:
    /// a coarse `bisect`, then `max_iterations` bisection steps, then up to
    /// `max_iterations` Newton steps. The Newton steps use the derivative basis and are
    /// clamped to the sampled bracket.
    ///
    /// * `samples` - base sample count; it is scaled by the domain width and by the
    ///   precision of `T` to get the actual step.
    /// * `max_iterations` - iteration cap for each of the bisection and Newton phases.
    ///
    /// Every returned root is a [`Root::Real`].
    ///
    /// NOTE(review): a sample that lands exactly on a root triggers on two consecutive
    /// steps, so that root may be reported twice.
    ///
    /// # Errors
    /// Returns any error produced while differentiating `coefs`.
    fn roots_iterative(
        &self,
        coefs: &[T],
        x_range: RangeInclusive<T>,
        samples: usize,
        max_iterations: usize,
    ) -> Result<Vec<Root<T>>> {
        let mut roots = vec![];
        let mut prev_x = *x_range.start();
        let mut prev_y = self.solve(prev_x, coefs);
        let domain_width = *x_range.end() - *x_range.start();

        // Scale the sample count with the width of the domain.
        // NOTE(review): log10(|w| + 1) - 1 crosses zero at |w| == 9. Just below that,
        // 1 / -x grows without bound (huge sample counts). Just above it, x -> 0+ (almost
        // no samples). The intended scaling policy should be confirmed.
        let domain_scalar = match (Value::abs(domain_width) + T::one()).log10() - T::one() {
            x if x > T::zero() => x,
            x if x < T::zero() => T::one() / -x,
            _ => T::one(),
        };
        // NOTE(review): this ratio is 1 for f64 but very large for lower-precision types,
        // which multiplies the sample count; confirm the intended direction.
        let precision_scalar = T::epsilon() / f64::EPSILON.clamped_cast::<T>();
        let num_samples = T::from_positive_int(samples) * domain_scalar * precision_scalar;

        let (dx_basis, dx_coefs) = self.derivative(coefs)?;
        let sqrt_eps = T::epsilon().sqrt();

        for x in SteppedValues::new(x_range, domain_width / num_samples) {
            let y = self.solve(x, coefs);
            if (prev_y * y).is_sign_negative() || prev_y.is_near_zero() || y.is_near_zero() {
                // The sign change (or near-zero sample) lies in the sampled bracket; keep its
                // ordered bounds so Newton iterates can never leave it.
                let (lo, hi) = if prev_x <= x { (prev_x, x) } else { (x, prev_x) };

                // Coarse refinement of the bracket.
                let (coarse_x, _) = bisect(&|t| self.solve(t, coefs), prev_x, x, prev_y, y, 4);

                // Additional bisection between the left sample and the coarse estimate.
                let mut a = prev_x;
                let mut b = coarse_x;
                let mut fa = prev_y;
                for _ in 0..max_iterations {
                    let m = (a + b) / T::two();
                    let fm = self.solve(m, coefs);
                    if (fa * fm).is_sign_negative() {
                        b = m;
                    } else {
                        a = m;
                        fa = fm;
                    }
                }

                // Newton polish starting from the bisection midpoint.
                let mut newton_prev_x = (a + b) / T::two();
                for _ in 0..max_iterations {
                    let fx = self.solve(newton_prev_x, coefs);
                    let dfx = dx_basis.solve(newton_prev_x, &dx_coefs);
                    let step = fx / dfx;
                    // A zero (or non-finite) derivative would produce inf/NaN; keep the
                    // current estimate instead of propagating it.
                    if !step.is_finite() {
                        break;
                    }
                    let newton_x = Value::clamp(newton_prev_x - step, lo, hi);
                    // Scale-aware tolerance, same form as the one used in
                    // `Root::roots_from_complex`: sqrt(eps) * (1 + |x|).
                    let tol = sqrt_eps * (T::one() + Value::abs(newton_x));
                    if Value::abs(fx) <= tol || Value::abs(newton_x - newton_prev_x) <= tol {
                        break;
                    }
                    newton_prev_x = newton_x;
                }
                roots.push(newton_prev_x);
            }
            prev_x = x;
            prev_y = y;
        }

        Ok(roots.into_iter().map(Root::Real).collect())
    }

    /// Evaluates the polynomial `coefs` at the complex point `z`.
    ///
    /// NOTE(review): semantics inferred from the signature; confirm against implementations.
    fn complex_y(&self, z: Complex<T>, coefs: &[T]) -> Complex<T>;
}
/// Bases that can integrate a polynomial expressed in them.
///
/// The integral may live in a different basis, given by the associated type `B2`.
pub trait IntegralBasis<T: Value>: Basis<T> {
/// Basis of the integrated polynomial; it must also be displayable as a polynomial.
type B2: Basis<T> + crate::display::PolynomialDisplay<T>;
/// Integrates `coefficients`, returning the resulting basis and its coefficients.
///
/// NOTE(review): `constant` is presumably the constant of integration; confirm.
///
/// # Errors
/// Implementation-defined; see the concrete basis.
fn integral(&self, coefficients: &[T], constant: T) -> Result<(Self::B2, Vec<T>)>;
}
/// A classified critical point of a polynomial.
///
/// Every variant stores the point's coordinates in `(x, y)` order.
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum CriticalPoint<T: Value> {
/// A minimum at `(x, y)`.
Minima(T, T),
/// A maximum at `(x, y)`.
Maxima(T, T),
/// An inflection point at `(x, y)`.
Inflection(T, T),
}
impl<T: Value> std::fmt::Display for CriticalPoint<T> {
    /// Formats the point as `Kind(x, y)`, with both coordinates shown to two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Minima(x, y) => write!(f, "Minima({x:.2}, {y:.2})"),
            Self::Maxima(x, y) => write!(f, "Maxima({x:.2}, {y:.2})"),
            Self::Inflection(x, y) => write!(f, "Inflection({x:.2}, {y:.2})"),
        }
    }
}
impl<T: Value> CriticalPoint<T> {
    /// Builds a plotting element with one labelled marker per point.
    ///
    /// Each marker sits at the point's `(x, y)` and is labelled with its `Display` text.
    #[cfg(feature = "plotting")]
    pub fn as_plotting_element(points: &[Self]) -> crate::plotting::PlottingElement<T> {
        let markers = points.iter().map(|point| {
            let (x, y) = point.coords();
            let label = point.to_string();
            (x, y, Some(label))
        });
        crate::plotting::PlottingElement::from_markers(markers)
    }

    /// Returns the x-coordinate of the point.
    pub fn x(&self) -> T {
        self.coords().0
    }

    /// Returns the y-coordinate of the point.
    pub fn y(&self) -> T {
        self.coords().1
    }

    /// Returns the `(x, y)` coordinates of the point, whatever its kind.
    pub fn coords(&self) -> (T, T) {
        match *self {
            Self::Minima(x, y) | Self::Maxima(x, y) | Self::Inflection(x, y) => (x, y),
        }
    }
}
/// A root of a polynomial, either real or complex.
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum Root<T: Value> {
/// A real root.
Real(T),
/// A single complex root.
Complex(Complex<T>),
/// Two complex roots whose real parts agree and whose imaginary parts are negatives of
/// each other (a conjugate pair), as matched by `Root::roots_from_complex`.
ComplexPair(Complex<T>, Complex<T>),
}
impl<T: Value> Root<T> {
    /// Returns `true` for [`Root::Real`].
    pub fn is_real(&self) -> bool {
        matches!(self, Root::Real(_))
    }

    /// Returns `true` for [`Root::Complex`] and [`Root::ComplexPair`].
    pub fn is_complex(&self) -> bool {
        matches!(self, Root::Complex(_) | Root::ComplexPair(_, _))
    }

    /// Returns the value of a real root, or `None` for complex roots.
    pub fn as_real(&self) -> Option<T> {
        match self {
            Root::Real(x) => Some(*x),
            _ => None,
        }
    }

    /// Returns the complex value(s) of a complex root, or `None` for a real root.
    ///
    /// A [`Root::ComplexPair`] yields both members, in stored order.
    pub fn as_complex(&self) -> Option<Vec<Complex<T>>> {
        match self {
            Root::Complex(z) => Some(vec![*z]),
            Root::ComplexPair(z1, z2) => Some(vec![*z1, *z2]),
            Root::Real(_) => None,
        }
    }

    /// Turns candidate complex roots (e.g. eigenvalues) into a filtered, sorted root list.
    ///
    /// For each candidate `z`, with `tol = (1 + |z|) * sqrt(eps)`:
    /// * non-finite candidates are dropped;
    /// * candidates whose residual `|solver(z)|` exceeds `tol` are dropped;
    /// * later candidates within `tol` of `z` are treated as duplicates and skipped;
    /// * if `|Im z| < tol`, `z` becomes a [`Root::Real`] holding its real part;
    /// * otherwise, if a later, unconsumed candidate is its conjugate (within `tol`),
    ///   both are emitted once as a [`Root::ComplexPair`];
    /// * otherwise `z` is emitted as a [`Root::Complex`].
    ///
    /// The result lists real roots first (ascending), then conjugate pairs by the sum of
    /// their magnitudes, then lone complex roots by magnitude.
    #[allow(
        clippy::match_same_arms,
        reason = "This is more readable as a match statement, even if some arms are the same"
    )]
    pub fn roots_from_complex<F: Fn(&Complex<T>) -> Complex<T>>(
        eigenvalues: &[Complex<T>],
        solver: F,
    ) -> Vec<Root<T>> {
        let sqrt_eps = T::epsilon().sqrt();
        let mut roots = Vec::new();
        // `skip[j]` marks candidates already consumed as a duplicate or conjugate partner.
        let mut skip = vec![false; eigenvalues.len()];
        for i in 0..eigenvalues.len() {
            if skip[i] {
                continue;
            }
            let z = eigenvalues[i];
            if !z.imaginary().is_finite() || !z.real().is_finite() {
                continue;
            }
            // Scale-aware tolerance, used for the residual check and for matching.
            let tol = (T::one() + z.norm()) * sqrt_eps;
            if solver(&z).norm() > tol {
                continue;
            }

            // Collapse near-duplicates of `z`.
            for j in (i + 1)..eigenvalues.len() {
                if (z - eigenvalues[j]).norm() < tol {
                    skip[j] = true;
                }
            }

            if Value::abs(z.imaginary()) < tol {
                roots.push(Root::Real(z.real()));
                continue;
            }

            // Look for a not-yet-consumed conjugate partner.
            let partner = ((i + 1)..eigenvalues.len()).find(|&j| {
                !skip[j]
                    && Value::abs(z.real() - eigenvalues[j].real()) < tol
                    && Value::abs(z.imaginary() + eigenvalues[j].imaginary()) < tol
            });
            match partner {
                Some(j) => {
                    skip[j] = true;
                    roots.push(Root::ComplexPair(z, eigenvalues[j]));
                }
                // Only a lone root is reported as `Complex`; a paired root was already
                // reported above and must not be duplicated.
                None => roots.push(Root::Complex(z)),
            }
        }
        roots.sort_by(|a, b| match (a, b) {
            (Root::Real(x), Root::Real(y)) => x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal),
            (Root::Real(_), _) => std::cmp::Ordering::Less,
            (_, Root::Real(_)) => std::cmp::Ordering::Greater,
            (Root::ComplexPair(a1, a2), Root::ComplexPair(b1, b2)) => {
                let mag_a = a1.norm() + a2.norm();
                let mag_b = b1.norm() + b2.norm();
                mag_a
                    .partial_cmp(&mag_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }
            (Root::ComplexPair(_, _), Root::Complex(_)) => std::cmp::Ordering::Less,
            (Root::Complex(_), Root::ComplexPair(_, _)) => std::cmp::Ordering::Greater,
            (Root::Complex(a), Root::Complex(b)) => a
                .norm()
                .partial_cmp(&b.norm())
                .unwrap_or(std::cmp::Ordering::Equal),
        });
        roots
    }
}
/// Bases equipped with Gaussian-quadrature data for inner products of their functions.
pub trait OrthogonalBasis<T: Value>: Basis<T> {
    /// Weight function of the quadrature, evaluated at `x` (implementation-defined).
    fn gauss_weight(&self, x: T) -> T;

    /// Returns `n` quadrature points as `(node, weight)` pairs.
    fn gauss_nodes(&self, n: usize) -> Vec<(T, T)>;

    /// Normalization constant associated with index `n` (implementation-defined).
    fn gauss_normalization(&self, n: usize) -> T;

    /// Approximates the inner product of basis functions `i` and `j` by quadrature.
    ///
    /// Computes `sum(f_i(x) * f_j(x) * w)` over the `(x, w)` pairs in `nodes`, in order.
    fn inner_product(&self, i: usize, j: usize, nodes: &[(T, T)]) -> T {
        nodes.iter().fold(T::zero(), |acc, &(x, w)| {
            acc + self.solve_function(i, x) * self.solve_function(j, x) * w
        })
    }

    /// Builds the symmetric `functions x functions` Gram matrix of inner products.
    ///
    /// Uses `nodes` quadrature points. Only the upper triangle is computed; each value is
    /// mirrored into the lower triangle.
    fn gauss_matrix(&self, functions: usize, nodes: usize) -> nalgebra::DMatrix<T> {
        let quadrature = self.gauss_nodes(nodes);
        let mut gram = nalgebra::DMatrix::<T>::zeros(functions, functions);
        for row in 0..functions {
            for col in row..functions {
                let value = self.inner_product(row, col, &quadrature);
                // On the diagonal both writes target the same cell with the same value.
                gram[(row, col)] = value;
                gram[(col, row)] = value;
            }
        }
        gram
    }
}