use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::Zero;
use std::f64::consts::{E, PI};
use crate::error::{SpecialError, SpecialResult};
/// Euler's number `e` (underscore-prefixed, so no warning is emitted if it goes unused).
const _EXPN1: f64 = E;
/// `1/e`; the Lambert W branch point sits at `z = -1/e`.
const EXPN1_INV: f64 = 1.0 / E;
/// `2π`; branch `k` of Lambert W is offset by `2πk·i` from the principal logarithm.
const TWO_PI: f64 = 2.0 * PI;
/// Maximum number of Halley iterations performed by `lambert_w`.
const MAX_ITERATIONS: usize = 100;
/// Evaluates branch `k` of the Lambert W function, i.e. a solution `w` of `w·e^w = z`.
///
/// # Arguments
///
/// * `z` - Complex argument.
/// * `k` - Branch index; `0` is the principal branch.
/// * `tol` - Relative tolerance on the Halley step: iteration stops once
///   `|Δw| ≤ tol·|w|`. Values below `1e-15` (including negative or NaN values) are
///   raised to `1e-15`.
///
/// # Special values
///
/// * NaN input → `NaN + NaN·i`.
/// * Infinite input → `∞ + i·(arg z + 2πk)`, the limit of `ln z + 2πik`
///   (`+∞` gives `∞ + 2πk·i`; `−∞ + 0i` gives `∞ + (2k + 1)π·i`).
/// * `|z| < 1e-300` on branch 0 → `z` itself (`W_0(z) = z − z² + …`).
/// * `z = 0` on any other branch → `−∞`.
///
/// If the iteration has not converged after `MAX_ITERATIONS` steps, the last iterate is
/// returned.
///
/// # Errors
///
/// This implementation never returns `Err`; the `Result` is part of the public signature.
#[allow(dead_code)]
pub fn lambert_w(z: Complex64, k: i32, tol: f64) -> SpecialResult<Complex64> {
    if z.is_nan() {
        return Ok(Complex64::new(f64::NAN, f64::NAN));
    }
    if z.is_infinite() {
        // W_k(z) ~ ln z + 2πik as |z| → ∞, so Im W tends to arg z + 2πk.
        return Ok(Complex64::new(
            f64::INFINITY,
            z.arg() + TWO_PI * f64::from(k),
        ));
    }
    // W_0(z) = z − z² + …: below 1e-300 the quadratic term is far below f64 resolution.
    if k == 0 && z.norm() < 1e-300 {
        return Ok(z);
    }
    if z.is_zero() {
        // Branch 0 with z = 0 returned above; every other branch diverges to −∞ at 0.
        return Ok(Complex64::new(f64::NEG_INFINITY, 0.0));
    }

    // The step test is relative, so it behaves the same for tiny and huge |W|.
    let tol = tol.max(1e-15);
    let mut w = initial_guess(z, k);
    for _ in 0..MAX_ITERATIONS {
        match lambert_halley_correction(w, z) {
            None => {
                // Halley's step is not finite (e.g. w = −1 where f′ = 0): nudge w off the
                // singular point and retry.
                let nudge = if w.norm() > 1.0 {
                    w / w.norm()
                } else {
                    Complex64::new(1.0, 0.0)
                };
                w -= nudge * 0.1;
            }
            Some(delta) => {
                let step = delta.norm();
                if step > 10.0 {
                    // Cap the step length while still far from the root.
                    w -= delta * (10.0 / step);
                } else {
                    let next = w - delta;
                    if step <= tol * next.norm() {
                        return Ok(next);
                    }
                    w = next;
                }
            }
        }
    }
    Ok(w)
}

/// Returns the Halley correction `δ` for `f(w) = w·e^w − z` (the next iterate is `w − δ`),
/// `Some(0)` if `w` already solves the equation to within rounding, or `None` when `δ` is
/// not finite (e.g. a zero denominator).
///
/// Halley's step `f / (f′ − f″·f / (2f′))` is unchanged when `f`, `f′` and `f″` are all
/// multiplied by the same non-zero factor. For `Re w > 0` they are multiplied by `e^{−w}`,
/// so `e^w` is never formed and huge `|z|` (where `e^w` would overflow) is handled.
#[allow(dead_code)]
fn lambert_halley_correction(w: Complex64, z: Complex64) -> Option<Complex64> {
    let w_plus_1 = w + 1.0;
    // `f` and `df` are f(w) and f′(w) times a common factor; `target` is z times that factor.
    let (f, df, target) = if w.re > 0.0 {
        // f·e^{−w} = w − z·e^{−w},  f′·e^{−w} = w + 1.
        let scaled_z = z * (-w).exp();
        (w - scaled_z, w_plus_1, scaled_z)
    } else {
        let ew = w.exp();
        (w * ew - z, w_plus_1 * ew, z)
    };
    // Residual at rounding level: nothing left to correct (also avoids 0/0 at w = −1).
    if f.norm() <= 4.0 * f64::EPSILON * target.norm() {
        return Some(Complex64::new(0.0, 0.0));
    }
    // f″/f′ = (w + 2)/(w + 1) does not depend on the scaling.
    let delta = f / (df - (w + 2.0) * f / (2.0 * w_plus_1));
    if delta.is_finite() {
        Some(delta)
    } else {
        None
    }
}
/// Returns a starting point for the Halley iteration in `lambert_w` on branch `k`.
///
/// The guess is chosen by region:
///
/// * Branches 0 and −1 with `|z + 1/e| < 0.3` (near the branch point): the series
///   `W ≈ −1 ± p − p²/3` with `p = √(2(e·z + 1))`, `+` for branch 0 and `−` for branch −1.
/// * `|z| > 3`: the leading asymptotic terms `W_k(z) ≈ L₁ − ln L₁` with `L₁ = ln z + 2πik`.
/// * Branch 0 with `|z| < 1` and `Re z > −2.5·|Im z| − 0.2`: the rational approximation
///   `z·P(z)/Q(z)`, which vanishes at `z = 0` and agrees with `W_0(z) = z − z² + …`
///   through the `z²` term.
/// * Anything else: `ln z + 2πik` (with `ln z = 0` replaced by `1e-300`).
#[allow(dead_code)]
fn initial_guess(z: Complex64, k: i32) -> Complex64 {
    // Shift a principal-branch value onto branch k; k = 0 is left untouched.
    let onto_branch = |w: Complex64| {
        if k == 0 {
            w
        } else {
            w + Complex64::new(0.0, TWO_PI * f64::from(k))
        }
    };

    if (k == 0 || k == -1) && (z + EXPN1_INV).norm() < 0.3 {
        let p = (2.0 * (E * z + 1.0)).sqrt();
        let p = if k == 0 { p } else { -p };
        return Complex64::new(-1.0, 0.0) + p - p.powi(2) / 3.0;
    }

    if z.norm() > 3.0 {
        // |z| > 3 gives Re ln z > ln 3 > 0, so L₁ ≠ 0 and ln L₁ is finite.
        let l1 = onto_branch(z.ln());
        return l1 - l1.ln();
    }

    // Skip the wedge Re z ≤ −2.5·|Im z| − 0.2: it contains real z < −1/e, where W_0 is not
    // real, and a real starting point would keep the complex iteration on the real axis.
    if k == 0 && z.norm() < 1.0 && z.re > -2.5 * z.im.abs() - 0.2 {
        const P: [f64; 4] = [1.0, 2.331_643_981_597_124, 1.812_187_885_639_363_4, 0.1];
        const Q: [f64; 3] = [1.0, 3.331_643_981_597_124, 1.812_187_885_639_363_4];
        let numerator = P[0] + z * (P[1] + z * (P[2] + z * P[3]));
        let denominator = Q[0] + z * (Q[1] + z * Q[2]);
        // P/Q ≈ 1 − z + …, so multiplying by z reproduces W_0(z) = z − z² + ….
        return z * numerator / denominator;
    }

    let mut w = z.ln();
    if w.is_zero() {
        // z = 1: avoid starting the iteration exactly at 0.
        w = Complex64::new(1e-300, 0.0);
    }
    onto_branch(w)
}
/// Evaluates the principal branch `W_0(x)` of the Lambert W function for a real argument.
///
/// `W_0` is real on `[-1/e, ∞)`; at the branch point `x = -1/e` this returns exactly `-1`.
///
/// # Arguments
///
/// * `x` - Real argument.
/// * `tol` - Convergence tolerance forwarded to `lambert_w`.
///
/// # Errors
///
/// Returns `SpecialError::DomainError` if `x` is NaN, if `x < -1/e`, or if the computed
/// value has an imaginary part of magnitude `1e-15` or more.
#[allow(dead_code)]
pub fn lambert_w_real(x: f64, tol: f64) -> SpecialResult<f64> {
    let complex_result = || {
        SpecialError::DomainError(format!(
            "Lambert W function gives a complex result for x={x}"
        ))
    };
    if x.is_nan() || x < -EXPN1_INV {
        return Err(complex_result());
    }
    // The branch point itself is real: W_0(-1/e) = -1. W is a double root there, so
    // return the exact value instead of iterating.
    if x == -EXPN1_INV {
        return Ok(-1.0);
    }
    let result = lambert_w(Complex64::new(x, 0.0), 0, tol)?;
    if result.im.abs() < 1e-15 {
        Ok(result.re)
    } else {
        Err(complex_result())
    }
}