use scirs2_core::numeric::{Float, FromPrimitive};
use std::f64;
use std::fmt::Debug;
/// Converts an `f64` constant into the generic float type `F`.
///
/// # Panics
/// Panics if `value` cannot be represented in `F` (the `NumCast::from` conversion
/// returns `None`).
#[inline]
fn const_f64<F: Float + FromPrimitive>(value: f64) -> F {
    match F::from(value) {
        Some(converted) => converted,
        None => panic!("Failed to convert constant to target float type"),
    }
}
/// Evaluates the Legendre polynomial `P_n(x)`.
///
/// Runs Bonnet's recurrence `(k + 1) P_{k+1}(x) = (2k + 1) x P_k(x) - k P_{k-1}(x)`
/// upward from `P_0(x) = 1` and `P_1(x) = x`.
///
/// # Panics
/// Panics if a recurrence index cannot be converted to `F`.
#[allow(dead_code)]
pub fn legendre<F: Float + FromPrimitive + Debug>(n: usize, x: F) -> F {
    match n {
        0 => F::one(),
        1 => x,
        _ => {
            // Carry the pair (P_{k-1}, P_k) forward; the second element ends as P_n.
            let (_, p_n) = (1..n).fold((F::one(), x), |(p_prev, p_curr), k| {
                let k_f = F::from(k).expect("Failed to convert to float");
                let p_next =
                    ((k_f + k_f + F::one()) * x * p_curr - k_f * p_prev) / (k_f + F::one());
                (p_curr, p_next)
            });
            p_n
        }
    }
}
/// Evaluates the associated Legendre function `P_n^m(x)` (Ferrers function of the first
/// kind), including the Condon–Shortley phase `(-1)^m`.
///
/// * `n` — degree.
/// * `m` — order; `|m| > n` yields `0`, and `m == 0` reduces to [`legendre`].
/// * `x` — argument; the function is defined on `[-1, 1]`. At `x = ±1` every `m != 0`
///   gives `0`.
///
/// For `m > 0` the value is built from
/// `P_m^m(x) = (-1)^m (2m - 1)!! (1 - x^2)^{m/2}`,
/// `P_{m+1}^m(x) = (2m + 1) x P_m^m(x)`, and the upward recurrence in the degree
/// `(k - m) P_k^m(x) = (2k - 1) x P_{k-1}^m(x) - (k + m - 1) P_{k-2}^m(x)`.
/// Negative orders use `P_n^{-m}(x) = (-1)^m (n - m)! / (n + m)! · P_n^m(x)`.
///
/// # Panics
/// Panics if an integer index cannot be converted to `F`.
#[allow(dead_code)]
pub fn legendre_assoc<F: Float + FromPrimitive + Debug>(n: usize, m: i32, x: F) -> F {
    let m_abs = m.unsigned_abs() as usize;
    if m_abs > n {
        return F::zero();
    }
    if m == 0 {
        return legendre(n, x);
    }
    // The factor (1 - x^2)^{|m|/2} vanishes at both endpoints for every m != 0.
    if x == F::one() || x == -F::one() {
        return F::zero();
    }

    let positive_order = legendre_assoc_positive_order(n, m_abs, x);
    if m > 0 {
        return positive_order;
    }

    // Negative order: (-1)^|m| * (n - |m|)! / (n + |m|)! * P_n^{|m|}(x).
    // Working with `m_abs` (usize) avoids overflowing `-m` for `i32::MIN`.
    let sign = if m_abs.is_multiple_of(2) {
        F::one()
    } else {
        -F::one()
    };
    let factorial_ratio = ((n - m_abs + 1)..=(n + m_abs)).fold(F::one(), |acc, k| {
        acc / F::from(k).expect("Failed to convert to float")
    });
    sign * factorial_ratio * positive_order
}

/// `P_n^m(x)` for `1 <= m <= n` (the caller guarantees this range).
fn legendre_assoc_positive_order<F: Float + FromPrimitive>(n: usize, m: usize, x: F) -> F {
    let m_f = F::from(m).expect("Failed to convert to float");

    // (2m - 1)!! = 1 * 3 * 5 * ... * (2m - 1)
    let double_factorial = (1..=m).fold(F::one(), |acc, i| {
        acc * F::from(2 * i - 1).expect("Failed to convert to float")
    });
    // Condon–Shortley phase.
    let sign = if m.is_multiple_of(2) {
        F::one()
    } else {
        -F::one()
    };
    // (1 - x)(1 + x) loses less precision than 1 - x*x close to |x| = 1.
    let one_minus_x2 = (F::one() - x) * (F::one() + x);
    let half_m = F::from(m as f64 / 2.0).expect("Failed to convert to float");
    let p_mm = sign * double_factorial * one_minus_x2.powf(half_m);
    if n == m {
        return p_mm;
    }

    // P_{m+1}^m, then recur upward in the degree until k == n.
    let mut p_prev = p_mm;
    let mut p_curr = (m_f + m_f + F::one()) * x * p_mm;
    for k in (m + 2)..=n {
        let k_f = F::from(k).expect("Failed to convert to float");
        let p_next = ((k_f + k_f - F::one()) * x * p_curr - (k_f + m_f - F::one()) * p_prev)
            / (k_f - m_f);
        p_prev = p_curr;
        p_curr = p_next;
    }
    p_curr
}
/// Evaluates the Laguerre polynomial `L_n(x)`.
///
/// Uses `(k + 1) L_{k+1}(x) = (2k + 1 - x) L_k(x) - k L_{k-1}(x)` starting from
/// `L_0(x) = 1` and `L_1(x) = 1 - x`.
///
/// # Panics
/// Panics if a recurrence index cannot be converted to `F`.
#[allow(dead_code)]
pub fn laguerre<F: Float + FromPrimitive + Debug>(n: usize, x: F) -> F {
    if n == 0 {
        return F::one();
    }
    let mut older = F::one();
    let mut newer = F::one() - x;
    let mut k = 1;
    while k < n {
        let k_f = F::from(k).expect("Failed to convert to float");
        let next = ((k_f + k_f + F::one() - x) * newer - k_f * older) / (k_f + F::one());
        older = newer;
        newer = next;
        k += 1;
    }
    newer
}
/// Evaluates the generalized (associated) Laguerre polynomial `L_n^{(alpha)}(x)`.
///
/// Uses `(k + 1) L_{k+1} = (2k + 1 + alpha - x) L_k - (k + alpha) L_{k-1}` seeded with
/// `L_0 = 1` and `L_1 = 1 + alpha - x`.
///
/// # Panics
/// Panics if a recurrence index cannot be converted to `F`.
#[allow(dead_code)]
pub fn laguerre_generalized<F: Float + FromPrimitive + Debug>(n: usize, alpha: F, x: F) -> F {
    if n == 0 {
        return F::one();
    }
    let seed = (F::one(), F::one() + alpha - x);
    let (_, l_n) = (1..n).fold(seed, |(l_prev, l_curr), k| {
        let k_f = F::from(k).expect("Failed to convert to float");
        let l_next = ((k_f + k_f + F::one() + alpha - x) * l_curr - (k_f + alpha) * l_prev)
            / (k_f + F::one());
        (l_curr, l_next)
    });
    l_n
}
/// Evaluates the physicists' Hermite polynomial `H_n(x)`.
///
/// Uses `H_{k+1}(x) = 2x H_k(x) - 2k H_{k-1}(x)` from `H_0 = 1` and `H_1 = 2x`.
///
/// # Panics
/// Panics if a recurrence index cannot be converted to `F`.
#[allow(dead_code)]
pub fn hermite<F: Float + FromPrimitive + Debug>(n: usize, x: F) -> F {
    if n == 0 {
        return F::one();
    }
    let two_x = x + x;
    let (mut h_prev, mut h_curr) = (F::one(), two_x);
    for k in 1..n {
        let k_f = F::from(k).expect("Failed to convert to float");
        let h_next = two_x * h_curr - (k_f + k_f) * h_prev;
        h_prev = std::mem::replace(&mut h_curr, h_next);
    }
    h_curr
}
/// Evaluates the probabilists' Hermite polynomial `He_n(x)`.
///
/// Uses `He_{k+1}(x) = x He_k(x) - k He_{k-1}(x)` from `He_0 = 1` and `He_1 = x`.
///
/// # Panics
/// Panics if a recurrence index cannot be converted to `F`.
#[allow(dead_code)]
pub fn hermite_prob<F: Float + FromPrimitive + Debug>(n: usize, x: F) -> F {
    match n {
        0 => F::one(),
        1 => x,
        _ => {
            (1..n)
                .fold((F::one(), x), |(he_prev, he_curr), k| {
                    let k_f = F::from(k).expect("Failed to convert to float");
                    (he_curr, x * he_curr - k_f * he_prev)
                })
                .1
        }
    }
}
/// Evaluates a Chebyshev polynomial at `x`.
///
/// * `firstkind == true`: `T_n(x)`. It uses `cos(n·acos x)` on `[-1, 1]` and
///   `cosh(n·acosh x)` for `x > 1`. For `x < -1` it uses `(-1)^n cosh(n·acosh(-x))`.
/// * `firstkind == false`: `U_n(x)`. It uses the recurrence `U_{k+1} = 2x U_k - U_{k-1}`
///   starting from `U_0 = 1` and `U_1 = 2x`. The recurrence is valid for every real `x`,
///   including `|x| >= 1`.
///
/// # Panics
/// Panics if `n` cannot be converted to `F`.
#[allow(dead_code)]
pub fn chebyshev<F: Float + FromPrimitive + Debug>(n: usize, x: F, firstkind: bool) -> F {
    if firstkind {
        chebyshev_first_kind(n, x)
    } else {
        chebyshev_second_kind(n, x)
    }
}

/// `T_n(x)` via its trigonometric / hyperbolic closed forms.
fn chebyshev_first_kind<F: Float + FromPrimitive>(n: usize, x: F) -> F {
    if n == 0 {
        return F::one();
    }
    if n == 1 {
        return x;
    }
    let n_f = F::from(n).expect("Failed to convert to float");
    if x >= -F::one() && x <= F::one() {
        (n_f * x.acos()).cos()
    } else if x > F::one() {
        (n_f * x.acosh()).cosh()
    } else {
        // T_n(-y) = (-1)^n T_n(y) for y > 1.
        let magnitude = (n_f * (-x).acosh()).cosh();
        if n.is_multiple_of(2) {
            magnitude
        } else {
            -magnitude
        }
    }
}

/// `U_n(x)` via the three-term recurrence `U_{k+1} = 2x U_k - U_{k-1}`.
fn chebyshev_second_kind<F: Float>(n: usize, x: F) -> F {
    if n == 0 {
        return F::one();
    }
    let two_x = x + x;
    let (mut u_prev, mut u_curr) = (F::one(), two_x);
    for _ in 1..n {
        let u_next = two_x * u_curr - u_prev;
        u_prev = u_curr;
        u_curr = u_next;
    }
    u_curr
}
/// Evaluates the Gegenbauer (ultraspherical) polynomial `C_n^{(λ)}(x)`.
///
/// For `λ != 0` the value comes from the three-term recurrence
///
/// ```text
/// (k + 1) C_{k+1}(x) = 2 (k + λ) x C_k(x) - (k + 2λ - 1) C_{k-1}(x)
/// ```
///
/// seeded with `C_0(x) = 1` and `C_1(x) = 2λx`.
///
/// `C_n^{(0)}` vanishes identically for `n >= 1`. For `λ == 0` this function instead
/// returns the customary normalized limit `lim_{λ→0} C_n^{(λ)}(x) / λ = (2 / n) T_n(x)`
/// for `n >= 1`, and `1` for `n == 0`. Here `T_n` is the Chebyshev polynomial of the
/// first kind.
///
/// # Panics
/// Panics if a recurrence index cannot be converted to `F`.
#[allow(dead_code)]
pub fn gegenbauer<F: Float + FromPrimitive + Debug>(n: usize, lambda: F, x: F) -> F {
    if n == 0 {
        return F::one();
    }
    let two_x = x + x;

    if lambda == F::zero() {
        // (2/n) T_n(x), with T_n from T_{k+1} = 2x T_k - T_{k-1}.
        let (mut t_prev, mut t_curr) = (F::one(), x);
        for _ in 1..n {
            let t_next = two_x * t_curr - t_prev;
            t_prev = t_curr;
            t_curr = t_next;
        }
        let n_f = F::from(n).expect("Failed to convert to float");
        return (t_curr + t_curr) / n_f;
    }

    let mut c_prev = F::one();
    let mut c_curr = two_x * lambda;
    for k in 1..n {
        let k_f = F::from(k).expect("Failed to convert to float");
        let c_next = ((k_f + lambda) * two_x * c_curr
            - (k_f + lambda + lambda - F::one()) * c_prev)
            / (k_f + F::one());
        c_prev = c_curr;
        c_curr = c_next;
    }
    c_curr
}

/// Evaluates the Jacobi polynomial `P_n^{(α, β)}(x)`.
///
/// The seeds are `P_0 = 1` and `P_1 = (α + 1) + (α + β + 2)(x - 1) / 2`. For `k >= 2` the
/// function uses the standard three-term recurrence (DLMF 18.9.2):
///
/// ```text
/// 2k (k + α + β)(2k + α + β - 2) P_k
///     = (2k + α + β - 1) [ (2k + α + β)(2k + α + β - 2) x + α² - β² ] P_{k-1}
///       - 2 (k + α - 1)(k + β - 1)(2k + α + β) P_{k-2}
/// ```
///
/// When `α = β = 0` the function delegates to [`legendre`].
///
/// The recurrence is well defined for `α, β > -1`. Outside that range a denominator can
/// vanish, for example when `α + β` is a negative integer; the result is then `inf`/`NaN`.
///
/// # Panics
/// Panics if a recurrence index cannot be converted to `F`.
#[allow(dead_code)]
pub fn jacobi<F: Float + FromPrimitive + Debug>(n: usize, alpha: F, beta: F, x: F) -> F {
    if n == 0 {
        return F::one();
    }
    let two = const_f64::<F>(2.0);
    let p_1 = (alpha + F::one()) + ((alpha + beta + two) * (x - F::one())) / two;
    if n == 1 {
        return p_1;
    }
    if alpha == F::zero() && beta == F::zero() {
        return legendre(n, x);
    }

    let ab_sum = alpha + beta;
    // α² - β², factored for better accuracy when α ≈ β.
    let ab_sq_diff = (alpha - beta) * ab_sum;
    let (mut p_prev, mut p_curr) = (F::one(), p_1);
    for k in 2..=n {
        let k_f = F::from(k).expect("Failed to convert to float");
        let two_k_ab = k_f + k_f + ab_sum; // 2k + α + β
        let denom = two * k_f * (k_f + ab_sum) * (two_k_ab - two);
        let a = (two_k_ab - F::one()) * (two_k_ab * (two_k_ab - two) * x + ab_sq_diff);
        let c = two * (k_f + alpha - F::one()) * (k_f + beta - F::one()) * two_k_ab;
        let p_next = (a * p_curr - c * p_prev) / denom;
        p_prev = p_curr;
        p_curr = p_next;
    }
    p_curr
}

/// Evaluates the gamma function `Γ(x)` for real `x`.
///
/// * Poles (`x` equal to zero or a negative integer) return `+∞`.
/// * Positive integers up to 21 use the exact product `Γ(n) = (n - 1)!`.
/// * All other arguments use the Lanczos approximation (g = 7, nine coefficients).
///   Arguments below `0.5`, including negative non-integers, go through the reflection
///   formula `Γ(x) = π / (sin(πx) Γ(1 - x))`.
///
/// # Panics
/// Panics if `x` cannot be converted to `f64` or a constant cannot be converted to `F`.
#[allow(dead_code)]
fn gamma<F: Float + FromPrimitive>(x: F) -> F {
    let x_f64 = x.to_f64().expect("Test/example failed");
    // Γ has poles only at 0, -1, -2, ...; negative non-integers are finite.
    if x_f64 <= 0.0 && x_f64.fract() == 0.0 {
        return F::infinity();
    }
    if x_f64.fract() == 0.0 && x_f64 <= 21.0 {
        let n = x_f64 as usize;
        return (1..n).fold(F::one(), |acc, i| {
            acc * F::from(i as f64).expect("Failed to convert to float")
        });
    }
    let p = [
        const_f64::<F>(676.5203681218851),
        const_f64::<F>(-1259.1392167224028),
        F::from(771.323_428_777_653_1).expect("Failed to convert to float"),
        F::from(-176.615_029_162_140_6).expect("Failed to convert to float"),
        const_f64::<F>(12.507343278686905),
        const_f64::<F>(-0.13857109526572012),
        F::from(9.984_369_578_019_572e-6).expect("Failed to convert to float"),
        const_f64::<F>(1.5056327351493116e-7),
    ];
    let reflect = x < const_f64::<F>(0.5);
    // Lanczos evaluates Γ(z + 1); reflected arguments evaluate Γ(1 - x) instead.
    let z = (if reflect { F::one() - x } else { x }) - F::one();
    let series = p.iter().enumerate().fold(
        F::from(0.999_999_999_999_809_9).expect("Failed to convert to float"),
        |acc, (i, &p_val)| acc + p_val / (z + F::from(i + 1).expect("Failed to convert to float")),
    );
    let t = z + F::from(p.len() as f64 - 0.5).expect("Test/example failed");
    let sqrt_2pi = F::from(2.506_628_274_631_000_7).expect("Failed to convert to float");
    let lanczos = sqrt_2pi * t.powf(z + const_f64::<F>(0.5)) * (-t).exp() * series;
    if reflect {
        let pi = F::from(std::f64::consts::PI).expect("Failed to convert to float");
        pi / ((pi * x).sin() * lanczos)
    } else {
        lanczos
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    #[test]
    fn test_legendre() {
        assert_relative_eq!(legendre(0, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(legendre(1, 0.5), 0.5, epsilon = 1e-10);
        assert_relative_eq!(
            legendre(2, 0.5),
            (3.0 * 0.5 * 0.5 - 1.0) / 2.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            legendre(3, 0.5),
            (5.0 * 0.5 * 0.5 * 0.5 - 3.0 * 0.5) / 2.0,
            epsilon = 1e-10
        );
        let p4 = (35.0 * 0.5 * 0.5 * 0.5 * 0.5 - 30.0 * 0.5 * 0.5 + 3.0) / 8.0;
        assert_relative_eq!(legendre(4, 0.5), p4, epsilon = 1e-10);
    }
    #[test]
    fn test_legendre_assoc() {
        assert_relative_eq!(legendre_assoc(0, 0, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(legendre_assoc(1, 0, 0.5), 0.5, epsilon = 1e-10);
        assert_relative_eq!(
            legendre_assoc(1, 1, 0.5),
            -(1.0 - 0.5 * 0.5).sqrt(),
            epsilon = 1e-10
        );
        assert_relative_eq!(
            legendre_assoc(2, 1, 0.5),
            -3.0 * 0.5 * (1.0 - 0.5 * 0.5).sqrt(),
            epsilon = 1e-10
        );
        assert_relative_eq!(
            legendre_assoc(2, 2, 0.5),
            3.0 * (1.0 - 0.5 * 0.5),
            epsilon = 1e-10
        );
    }
    #[test]
    fn test_laguerre() {
        assert_relative_eq!(laguerre(0, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(laguerre(1, 0.5), 0.5, epsilon = 1e-10);
        assert_relative_eq!(
            laguerre(2, 0.5),
            1.0 - 2.0 * 0.5 + 0.5 * 0.5 / 2.0,
            epsilon = 1e-10
        );
        let l3 = 1.0 - 3.0 * 0.5 + 3.0 * 0.5 * 0.5 / 2.0 - 0.5 * 0.5 * 0.5 / 6.0;
        assert_relative_eq!(laguerre(3, 0.5), l3, epsilon = 1e-10);
    }
    #[test]
    fn test_laguerre_generalized() {
        assert_relative_eq!(
            laguerre_generalized(0, 0.0, 0.5),
            laguerre(0, 0.5),
            epsilon = 1e-10
        );
        assert_relative_eq!(
            laguerre_generalized(1, 0.0, 0.5),
            laguerre(1, 0.5),
            epsilon = 1e-10
        );
        assert_relative_eq!(
            laguerre_generalized(1, 1.0, 0.5),
            2.0 - 0.5,
            epsilon = 1e-10
        );
        let l2_1 = 3.0 - 3.0 * 0.5 + 0.5 * 0.5 / 2.0;
        assert_relative_eq!(laguerre_generalized(2, 1.0, 0.5), l2_1, epsilon = 1e-10);
    }
    #[test]
    fn test_hermite() {
        assert_relative_eq!(hermite(0, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(hermite(1, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(hermite(2, 0.5), -1.0, epsilon = 1e-10);
        assert_relative_eq!(
            hermite(3, 0.5),
            8.0 * 0.5 * 0.5 * 0.5 - 12.0 * 0.5,
            epsilon = 1e-10
        );
        let h4 = 16.0 * 0.5 * 0.5 * 0.5 * 0.5 - 48.0 * 0.5 * 0.5 + 12.0;
        assert_relative_eq!(hermite(4, 0.5), h4, epsilon = 1e-10);
    }
    #[test]
    fn test_hermite_prob() {
        assert_relative_eq!(hermite_prob(0, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(hermite_prob(1, 0.5), 0.5, epsilon = 1e-10);
        assert_relative_eq!(hermite_prob(2, 0.5), 0.5 * 0.5 - 1.0, epsilon = 1e-10);
        assert_relative_eq!(
            hermite_prob(3, 0.5),
            0.5 * 0.5 * 0.5 - 3.0 * 0.5,
            epsilon = 1e-10
        );
    }
    #[test]
    fn test_chebyshev() {
        assert_relative_eq!(chebyshev(0, 0.5, true), 1.0, epsilon = 1e-10);
        assert_relative_eq!(chebyshev(1, 0.5, true), 0.5, epsilon = 1e-10);
        assert_relative_eq!(
            chebyshev(2, 0.5, true),
            2.0 * 0.5 * 0.5 - 1.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            chebyshev(3, 0.5, true),
            4.0 * 0.5 * 0.5 * 0.5 - 3.0 * 0.5,
            epsilon = 1e-10
        );
        assert_relative_eq!(chebyshev(0, 0.5, false), 1.0, epsilon = 1e-10);
        assert_relative_eq!(chebyshev(1, 0.5, false), 1.0, epsilon = 1e-10);
        assert_relative_eq!(
            chebyshev(2, 0.5, false),
            4.0 * 0.5 * 0.5 - 1.0,
            epsilon = 1e-10
        );
    }
    #[test]
    fn test_gegenbauer() {
        assert_relative_eq!(gegenbauer(0, 1.0, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(gegenbauer(1, 1.0, 0.5), 1.0, epsilon = 1e-10);
        // C_n^{(1)} = U_n: U_2(x) = 4x^2 - 1, U_3(x) = 8x^3 - 4x.
        assert_relative_eq!(
            gegenbauer(2, 1.0, 0.5),
            4.0 * 0.5 * 0.5 - 1.0,
            epsilon = 1e-10
        );
        assert_relative_eq!(
            gegenbauer(3, 1.0, 0.5),
            8.0 * 0.5 * 0.5 * 0.5 - 4.0 * 0.5,
            epsilon = 1e-10
        );
        // C_n^{(1/2)} = P_n.
        assert_relative_eq!(gegenbauer(2, 0.5, 0.5), legendre(2, 0.5), epsilon = 1e-10);
        assert_relative_eq!(gegenbauer(3, 0.5, 0.5), legendre(3, 0.5), epsilon = 1e-10);
        // λ = 0 returns the normalized limit (2/n) T_n(x).
        assert_relative_eq!(gegenbauer(1, 0.0, 0.5), 1.0, epsilon = 1e-10);
        assert_relative_eq!(
            gegenbauer(3, 0.0, 0.5),
            2.0 / 3.0 * chebyshev(3, 0.5, true),
            epsilon = 1e-10
        );
    }
    #[test]
    fn test_jacobi() {
        assert_relative_eq!(jacobi(2, 0.0, 0.0, 0.5), legendre(2, 0.5), epsilon = 1e-10);
        // P_n^{(-1/2,-1/2)} = (2n-1)!! / (2n)!! * T_n.
        assert_relative_eq!(
            jacobi(2, -0.5, -0.5, 0.5),
            3.0 / 8.0 * chebyshev(2, 0.5, true),
            epsilon = 1e-10
        );
        assert_relative_eq!(
            jacobi(3, -0.5, -0.5, 0.5),
            5.0 / 16.0 * chebyshev(3, 0.5, true),
            epsilon = 1e-10
        );
        // P_n^{(α,α)} = (α+1)_n / (2α+1)_n * C_n^{(α+1/2)}; for α = 1, n = 2 the factor is 1/2.
        let factor = gamma(4.0) * gamma(3.0) / (gamma(2.0) * gamma(5.0));
        assert_relative_eq!(
            jacobi(2, 1.0, 1.0, 0.5),
            factor * gegenbauer(2, 1.5, 0.5),
            epsilon = 1e-10
        );
        assert_relative_eq!(
            jacobi(2, 1.0, 1.0, 0.5),
            3.75 * 0.5 * 0.5 - 0.75,
            epsilon = 1e-10
        );
        // Asymmetric parameters; value from the explicit finite-sum formula.
        assert_relative_eq!(jacobi(2, 1.0, 0.0, 0.5), 0.625, epsilon = 1e-10);
    }
    #[test]
    fn test_gamma() {
        let sqrt_pi = std::f64::consts::PI.sqrt();
        assert_relative_eq!(gamma(1.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(gamma(5.0), 24.0, epsilon = 1e-10);
        assert_relative_eq!(gamma(0.5), sqrt_pi, epsilon = 1e-10);
        assert_relative_eq!(gamma(2.5), 0.75 * sqrt_pi, epsilon = 1e-10);
        assert_relative_eq!(gamma(-0.5), -2.0 * sqrt_pi, epsilon = 1e-10);
        assert!(gamma(0.0_f64).is_infinite());
        assert!(gamma(-2.0_f64).is_infinite());
    }
}