use num_complex::{Complex, ComplexFloat};
use num_traits::{Float, Signed, Zero};
use crate::NEG_INV_E;
use core::{
f64::consts::{E, PI},
ops::{Add, Mul, Sub, SubAssign},
};
/// Upper bound on the number of iteration steps `lambert_w_generic` takes
/// before it gives up refining and returns its current estimate.
const MAX_ITERS: u8 = 255;
#[must_use = "this is a pure function that only returns a value and has no side effects"]
pub fn lambert_w(k: i32, z_re: f64, z_im: f64, error_tolerance: f64) -> (f64, f64) {
let w = lambert_w_generic(k, num_complex::Complex64::new(z_re, z_im), error_tolerance);
(w.re, w.im)
}
#[must_use = "this is a pure function that only returns a value and has no side effects"]
pub fn lambert_wf(k: i16, z_re: f32, z_im: f32, error_tolerance: f32) -> (f32, f32) {
let w = lambert_w_generic(k, num_complex::Complex32::new(z_re, z_im), error_tolerance);
(w.re, w.im)
}
/// Precision-generic core of the complex Lambert W evaluation.
///
/// Approximates the solution `w` of `w * exp(w) = z` on the branch selected by
/// `k`, starting from the guess produced by `determine_start_point` and refining
/// it with Halley's method applied to `f(w) = w * exp(w) - z`.
///
/// # Arguments
///
/// * `k` - branch index.
/// * `z` - the complex argument.
/// * `error_tolerance` - relative tolerance handed to `are_nearly_equal` to decide
///   when two successive iterates are close enough to stop.
///
/// # Returns
///
/// * `NaN + NaN*i` if `z` is not finite or `error_tolerance` is NaN.
/// * For `z == 0`: `0` on branch `0`, negative infinity on every other branch.
/// * `-1` when `z` equals the crate constant `NEG_INV_E` (presumably -1/e) and
///   `k` is `0` or `-1`.
/// * `1` when `z` equals `e` and `k` is `0`.
/// * Otherwise the latest iterate once it converges, becomes non-finite, or
///   `MAX_ITERS` steps have been taken. If the iteration starts alternating
///   between two values, the iterate preceding the repeat is returned.
fn lambert_w_generic<T, U>(k: U, z: Complex<T>, error_tolerance: T) -> Complex<T>
where
    U: Signed + Copy,
    T: Float
        + AsCastFrom<f64>
        + From<U>
        + From<<Complex<T> as ComplexFloat>::Real>
        + Mul<Complex<T>, Output = Complex<T>>
        + Add<Complex<T>, Output = Complex<T>>
        + Sub<Complex<T>, Output = Complex<T>>,
    Complex<T>: ComplexFloat
        + SubAssign
        + Mul<T, Output = Complex<T>>
        + Add<T, Output = Complex<T>>
        + Sub<T, Output = Complex<T>>,
{
    // Neither a non-finite argument nor a NaN tolerance can yield a meaningful result.
    if !z.is_finite() || error_tolerance.is_nan() {
        return Complex::<T>::new(T::nan(), T::nan());
    }

    let i_zero = U::zero();
    let i_one = U::one();

    let d_zero = T::zero();
    let d_one = T::one();
    let d_two = d_one + d_one;
    let d_e: T = T::as_cast_from(E);
    let d_neg_inv_e: T = T::as_cast_from(NEG_INV_E);

    let z_zero = Complex::<T>::from(d_zero);
    let z_one = Complex::<T>::from(d_one);

    // Arguments with exactly known results are answered without iterating.
    if z == z_zero {
        return if k == i_zero {
            z_zero
        } else {
            T::neg_infinity().into()
        };
    }
    if z == d_neg_inv_e.into() && (k == i_zero || k == -i_one) {
        return -z_one;
    }
    if z == d_e.into() && k == i_zero {
        return z_one;
    }

    let mut w = determine_start_point(k, z);
    // The iterate from two steps back; `None` until two steps have been taken.
    let mut w_prev_prev = None;

    let mut iter = 0;
    loop {
        let w_prev = w;
        let ew = w.exp();

        // Halley step for f(w) = w*e^w - z, with the common factor e^w cancelled
        // from numerator and denominator.
        w -= d_two * (w + d_one) * (w * ew - z)
            / (ew * (w * w + d_two * w + d_two) + (w + d_two) * z);

        iter += 1;

        // The iteration has entered a cycle of two values and will never settle;
        // stop here instead of burning the remaining iteration budget.
        if Some(w) == w_prev_prev {
            return w_prev;
        }

        if are_nearly_equal(w, w_prev, error_tolerance) || iter == MAX_ITERS || !w.is_finite() {
            return w;
        }

        // Remember the iterate that preceded this step, so that on the next pass
        // it really is the value from two steps back. (Storing `w` here would make
        // the cycle check above compare against the immediately preceding iterate
        // and never detect an oscillation.)
        w_prev_prev = Some(w_prev);
    }
}
/// Reports whether `a` and `b` agree to within the relative tolerance `epsilon`.
///
/// * Exactly equal values are always considered nearly equal.
/// * If either value is NaN the answer is `false`.
/// * If either value is zero, or the larger of the two magnitudes is below the
///   smallest positive normal value of `T`, the absolute difference is compared
///   against `epsilon` scaled by that smallest normal value.
/// * Otherwise the difference is measured relative to the larger magnitude
///   (capped at `T::max_value()`) and compared against `epsilon`.
pub(crate) fn are_nearly_equal<T>(a: Complex<T>, b: Complex<T>, epsilon: T) -> bool
where
    T: Float + From<<Complex<T> as ComplexFloat>::Real>,
    Complex<T>: ComplexFloat,
{
    if a == b {
        return true;
    }
    if a.is_nan() || b.is_nan() {
        return false;
    }

    let largest_magnitude: T = a.abs().max(b.abs()).into();
    let distance: T = (a - b).abs().into();
    let origin = Complex::<T>::zero();
    let smallest_normal = T::min_positive_value();

    // Relative error is meaningless next to zero, so fall back to an absolute test there.
    let near_origin = a == origin || b == origin || largest_magnitude < smallest_normal;
    if near_origin {
        distance < epsilon * smallest_normal
    } else {
        distance / largest_magnitude.min(T::max_value()) < epsilon
    }
}
/// Chooses the initial guess for the iteration on branch `k` at the point `z`.
///
/// The candidates, in order of precedence:
///
/// 1. `|z - 1/2| <= 1/2` and `k == 0`: a fixed rational approximation in `z`.
/// 2. `|z - 1/2| <= 1/2` and `k == -1`: a fixed rational approximation in `z`
///    with complex coefficients.
/// 3. `z` within distance 1 of the crate constant `NEG_INV_E`: a cubic series in
///    `p = sqrt(2 * (e*z + 1))`, namely `-1 + p - p^2/3 + 11*p^3/72` for `k == 0`,
///    and `-1 - p - p^2/3 - 11*p^3/72` for `k == 1` with `Im z < 0` or `k == -1`
///    with `Im z > 0`.
/// 4. Otherwise `L - ln(L)` with `L = ln(z) + 2*pi*k*i`.
fn determine_start_point<T, U>(k: U, z: Complex<T>) -> Complex<T>
where
    U: Signed + Copy,
    T: Float
        + AsCastFrom<f64>
        + From<U>
        + Mul<Complex<T>, Output = Complex<T>>
        + Add<Complex<T>, Output = Complex<T>>
        + Sub<Complex<T>, Output = Complex<T>>,
    Complex<T>: ComplexFloat
        + SubAssign
        + Mul<T, Output = Complex<T>>
        + Add<T, Output = Complex<T>>
        + Sub<T, Output = Complex<T>>,
{
    let branch_zero = U::zero();
    let branch_one = U::one();

    let zero = T::zero();
    let one = T::one();
    let two = one + one;
    let half = one / two;
    let e: T = T::as_cast_from(E);
    let pi: T = T::as_cast_from(PI);
    let neg_inv_e: T = T::as_cast_from(NEG_INV_E);

    let imag_unit = Complex::<T>::i();
    let c_one = Complex::<T>::from(one);
    let c_two = c_one + c_one;

    // The radii are obtained through `abs()` on complex values so that they have
    // the same type as the magnitudes they are compared with below.
    let unit_radius = c_one.abs();
    let half_radius = (c_one / c_two).abs();

    let inside_half_disc = (z - half).abs() <= half_radius;

    // Rational approximation for branch 0 on the disc |z - 1/2| <= 1/2.
    if inside_half_disc && k == branch_zero {
        return (T::as_cast_from(0.351_733_71)
            * (T::as_cast_from(0.123_716_6) + T::as_cast_from(7.061_302_897) * z))
            / (two + T::as_cast_from(0.827_184) * (one + two * z));
    }

    // Rational approximation for branch -1 on the same disc.
    if inside_half_disc && k == -branch_one {
        let one_plus_2z = one + two * z;
        let numerator = (T::as_cast_from(2.259_158_898_5) + T::as_cast_from(4.220_96) * imag_unit)
            * ((T::as_cast_from(-14.073_271) - T::as_cast_from(33.767_687_754) * imag_unit) * z
                - (T::as_cast_from(12.712_7) - T::as_cast_from(19.071_643) * imag_unit)
                    * one_plus_2z);
        let denominator = two
            - (T::as_cast_from(17.231_03) - T::as_cast_from(10.629_721) * imag_unit) * one_plus_2z;
        return -(numerator / denominator);
    }

    // Series in p = sqrt(2 * (e*z + 1)) for points within distance 1 of NEG_INV_E.
    let near_branch_point = (z - Complex::<T>::from(neg_inv_e)).abs() <= unit_radius;
    if near_branch_point {
        let p = (two * (e * z + one)).sqrt();
        let quadratic = T::as_cast_from(1.0 / 3.0) * p * p;
        let cubic = T::as_cast_from(11.0 / 72.0) * p * p * p;

        if k == branch_zero {
            return -one + p - quadratic + cubic;
        }
        if (k == branch_one && z.im < zero) || (k == -branch_one && z.im > zero) {
            return -one - p - quadratic - cubic;
        }
    }

    // General guess: L - ln(L) with L = ln(z) + 2*pi*k*i.
    let two_pi_k_i = two * pi * <T as From<U>>::from(k) * imag_unit;
    let shifted_log = z.ln() + two_pi_k_i;
    shifted_log - shifted_log.ln()
}
/// Conversion into `Self` from a value of type `U` with the semantics of an
/// `as` cast, i.e. the conversion is allowed to lose precision.
///
/// This lets generic code materialise `f64` literals and constants in whichever
/// float type it is instantiated with.
trait AsCastFrom<U> {
    /// Converts `x` into `Self`, rounding if `Self` cannot represent it exactly.
    fn as_cast_from(x: U) -> Self;
}

impl AsCastFrom<f64> for f32 {
    #[inline]
    fn as_cast_from(x: f64) -> Self {
        // Deliberate narrowing: rounds to the nearest `f32`.
        x as f32
    }
}

impl AsCastFrom<f64> for f64 {
    #[inline]
    fn as_cast_from(x: f64) -> Self {
        x
    }
}