use crate::error::{SpecialError, SpecialResult};
use crate::lambert::{lambert_w, lambert_w_real};
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
/// ω(0) = W(1), the omega constant Ω (the solution of Ω + ln Ω = 0).
const OMEGA_0: f64 = 0.5671432904097838;
/// ω(1) = W(e) = 1.
const OMEGA_1: f64 = 1.0;
/// ω(2) = W(e²), the solution of w + ln w = 2.
const OMEGA_2: f64 = 1.5571455989976;
/// ω(−1) = W(1/e), the solution of w + ln w = −1.
const OMEGA_NEG_1: f64 = 0.2784645427610738;
/// Euler–Mascheroni constant γ.
///
/// `#[allow(dead_code)]` follows the convention used on the functions in this file.
#[allow(dead_code)]
const EULER_MASCHERONI: f64 = 0.5772156649015329;
/// Evaluates the Wright omega function ω(z) for complex `z`.
///
/// ω(z) is computed as `W_k(e^z)`, the branch of the Lambert W function selected by the
/// unwinding number `k = ⌈(Im z − π) / (2π)⌉`, so that `ω + ln ω = z`.
///
/// Special cases handled before calling [`lambert_w`]:
/// * a NaN component gives `NaN + NaN·i`;
/// * `Re z = +∞` returns `z`; `Re z = −∞` returns a signed zero chosen from `Im z`;
///   any other infinite component returns `z` unchanged;
/// * `z` within `tol` of `−1 ± iπ` returns `−1`;
/// * on the real axis (`|Im z| < tol`): `Re z > 1e20` returns `Re z`, and `Re z < −50`
///   returns `e^{Re z}`.
///
/// # Arguments
/// * `z` - complex argument
/// * `tol` - tolerance (default `1e-10`), passed to [`lambert_w`] and used for the
///   special-case tests above
///
/// # Errors
/// Propagates any error returned by [`lambert_w`].
#[allow(dead_code)]
pub fn wright_omega(z: Complex64, tol: Option<f64>) -> SpecialResult<Complex64> {
    let tolerance = tol.unwrap_or(1e-10);
    if z.re.is_nan() || z.im.is_nan() {
        return Ok(Complex64::new(f64::NAN, f64::NAN));
    }
    if z.re.is_infinite() || z.im.is_infinite() {
        if z.re == f64::INFINITY {
            return Ok(z);
        } else if z.re == f64::NEG_INFINITY {
            let angle = z.im;
            if angle.abs() <= PI / 2.0 {
                let zero = if angle >= 0.0 { 0.0 } else { -0.0 };
                return Ok(Complex64::new(0.0, zero));
            } else {
                let zero = if angle >= 0.0 { -0.0 } else { 0.0 };
                return Ok(Complex64::new(zero, 0.0));
            }
        }
        return Ok(z);
    }
    // ω(z) = −1 at the singular points z = −1 ± iπ.
    if (z.re + 1.0).abs() < tolerance && (z.im.abs() - PI).abs() < tolerance {
        return Ok(Complex64::new(-1.0, 0.0));
    }
    // Unwinding number selecting the Lambert W branch. The `as` cast saturates for
    // extremely large |Im z|.
    let k = ((z.im - PI) / (2.0 * PI)).ceil() as i32;
    if z.im.abs() < tolerance {
        // ω(x) = x − ln ω(x); for x > 1e20 the correction ln ω / x is below 5e-19.
        if z.re > 1e20 {
            return Ok(Complex64::new(z.re, 0.0));
        }
        // ω(x) = e^{x − ω(x)} = eˣ(1 − eˣ + …): ω(x) → 0 like eˣ as x → −∞.
        // For x < −50 the relative error of eˣ is below e^{-50}.
        if z.re < -50.0 {
            return Ok(Complex64::new(z.re.exp(), 0.0));
        }
    }
    lambert_w(z.exp(), k, tolerance)
}
/// Evaluates the Wright omega function ω(x) for real `x`.
///
/// For real `x`, ω(x) = W₀(eˣ) is the positive solution of `w + ln w = x`. Values are
/// obtained from [`lambert_w_real`] for `x ≥ −1`. Arguments below `−1` are routed through
/// the complex [`wright_omega`], and the result is accepted only if its imaginary part is
/// below `tol`.
///
/// # Arguments
/// * `x` - real argument
/// * `tol` - tolerance (default `1e-10`)
///
/// # Returns
/// `NaN` for `NaN`, `+∞` for `+∞`, `0` for `−∞`, otherwise ω(x).
///
/// # Errors
/// `SpecialError::DomainError` if the complex evaluation for `x < −1` has a non-negligible
/// imaginary part. Errors from [`wright_omega`] or [`lambert_w_real`] are propagated.
#[allow(dead_code)]
pub fn wright_omega_real(x: f64, tol: Option<f64>) -> SpecialResult<f64> {
    let tolerance = tol.unwrap_or(1e-10);
    if x.is_nan() {
        return Ok(f64::NAN);
    }
    if x == f64::INFINITY {
        return Ok(f64::INFINITY);
    } else if x == f64::NEG_INFINITY {
        return Ok(0.0);
    }
    // ω(x) = x − ln ω(x); for x > 1e20 the correction ln ω / x is below 5e-19.
    if x > 1e20 {
        return Ok(x);
    }
    // ω(x) = e^{x − ω(x)} = eˣ(1 − eˣ + …): ω(x) → 0 like eˣ as x → −∞.
    // For x < −50 the relative error of eˣ is below e^{-50}.
    if x < -50.0 {
        return Ok(x.exp());
    }
    if x < -1.0 {
        let complex_result = wright_omega(Complex64::new(x, 0.0), Some(tolerance))?;
        if complex_result.im.abs() < tolerance {
            return Ok(complex_result.re);
        } else {
            return Err(SpecialError::DomainError(
                "Wright Omega function not real for this input".to_string(),
            ));
        }
    }
    lambert_w_real(x.exp(), tolerance)
}
/// Evaluates the Wright omega function ω(x) for real `x` with Newton's method.
///
/// For real `x`, ω(x) is the unique `w > 0` with `w + ln w = x`. Newton's method is applied
/// to `f(w) = w + ln w − x`, with `f'(w) = 1 + 1/w`.
///
/// Because `f` is increasing and concave on `w > 0`, any Newton step from a positive point
/// that lands at a positive point lands at or below the root. Every later iterate then
/// increases monotonically towards ω(x) without leaving the domain of `ln`. The seeds are:
/// * `x > 1`: `x − ln x`. This is ≤ ω(x) because ω(x) ≤ x there.
/// * otherwise: `eˣ / (1 + eˣ)`, the Newton step taken from `eˣ`. Since `eˣ > ω(x)`, this
///   seed lies below the root.
///
/// # Arguments
/// * `x` - real argument
/// * `tol` - relative tolerance on the Newton update (default `1e-10`)
///
/// # Returns
/// `NaN` for `NaN`, `+∞` for `+∞`, `0` for `−∞`. Exact tabulated values for
/// `x ∈ {−1, 0, 1, 2}`, and ω(x) otherwise.
///
/// # Errors
/// This implementation never returns an error; the `Result` is kept for API compatibility.
#[allow(dead_code)]
pub fn wright_omega_real_optimized(x: f64, tol: Option<f64>) -> SpecialResult<f64> {
    const MAX_ITERATIONS: usize = 50;
    let tolerance = tol.unwrap_or(1e-10);
    if x.is_nan() {
        return Ok(f64::NAN);
    }
    if x == f64::INFINITY {
        return Ok(f64::INFINITY);
    } else if x == f64::NEG_INFINITY {
        return Ok(0.0);
    }

    // Tabulated values, used only for exact integer arguments so `tol` is never bypassed.
    let tabulated = if x == 0.0 {
        Some(OMEGA_0)
    } else if x == 1.0 {
        Some(OMEGA_1)
    } else if x == 2.0 {
        Some(OMEGA_2)
    } else if x == -1.0 {
        Some(OMEGA_NEG_1)
    } else {
        None
    };
    if let Some(value) = tabulated {
        return Ok(value);
    }

    // ω(x) = e^{x − ω(x)} = eˣ(1 − eˣ + …): for x < −50 the relative error of eˣ is below
    // e^{-50}. This also avoids `ln` of an underflowed iterate.
    if x < -50.0 {
        return Ok(x.exp());
    }

    let mut w = if x > 1.0 {
        x - x.ln()
    } else {
        let ex = x.exp();
        ex / (1.0 + ex)
    };
    for _ in 0..MAX_ITERATIONS {
        let residual = w + w.ln() - x;
        // Newton step f / f' with f'(w) = 1 + 1/w. This form cannot overflow for large w.
        let next = w - residual / (1.0 + w.recip());
        if (next - w).abs() <= tolerance * next.abs() {
            return Ok(next);
        }
        w = next;
    }
    Ok(w)
}
#[allow(dead_code)]
pub fn wright_omega_optimized(z: Complex64, tol: Option<f64>) -> SpecialResult<Complex64> {
let tolerance = tol.unwrap_or(1e-10);
if z.re.is_nan() || z.im.is_nan() {
return Ok(Complex64::new(f64::NAN, f64::NAN));
}
if z.re.is_infinite() || z.im.is_infinite() {
if z.re == f64::INFINITY {
return Ok(z); } else if z.re == f64::NEG_INFINITY {
let angle = z.im;
if angle.abs() <= PI / 2.0 {
let zero = if angle >= 0.0 { 0.0 } else { -0.0 };
return Ok(Complex64::new(0.0, zero));
} else {
let zero = if angle >= 0.0 { -0.0 } else { 0.0 };
return Ok(Complex64::new(zero, 0.0));
}
}
return Ok(z); }
if (z.re + 1.0).abs() < tolerance && (z.im.abs() - PI).abs() < tolerance {
return Ok(Complex64::new(-1.0, 0.0));
}
if z.im.abs() < tolerance {
if z.re > 1e10 {
return Ok(Complex64::new(z.re, 0.0));
}
if z.re < -50.0 {
return Ok(Complex64::new((-z.re).exp(), 0.0));
}
if let Ok(w_real) = wright_omega_real_optimized(z.re, Some(tolerance)) {
return Ok(Complex64::new(w_real, 0.0));
}
}
let mut w = if z.norm() < 1.0 {
z
} else if z.re > 0.0 && z.im.abs() < z.re {
z.ln()
} else {
let r = z.norm();
let theta = z.im.atan2(z.re);
Complex64::new(r.ln().cos(), r.ln().sin()) * Complex64::new(theta.cos(), theta.sin())
};
let max_iterations = 30; let mut converged = false;
for _ in 0..max_iterations {
let w_exp = w.exp();
let w_exp_w = w * w_exp;
let f = w_exp_w - z;
if f.norm() < tolerance {
converged = true;
break;
}
let f_prime = w_exp * (w + Complex64::new(1.0, 0.0));
let f_double_prime = w_exp * (w + Complex64::new(2.0, 0.0));
let factor = Complex64::new(2.0, 0.0) * f_prime * f;
let denominator = Complex64::new(2.0, 0.0) * f_prime * f_prime - f * f_double_prime;
if denominator.norm() < 1e-10 {
w -= f / f_prime * Complex64::new(0.5, 0.0);
} else {
let step = factor / denominator;
let damping = if step.norm() > 1.0 {
Complex64::new(0.7, 0.0)
} else {
Complex64::new(1.0, 0.0)
};
w -= step * damping;
}
}
if !converged {
match wright_omega_fallback_methods(z, tolerance) {
Ok(w_fallback) => Ok(w_fallback),
Err(_) => Err(SpecialError::ConvergenceError(
"Wright Omega function failed to converge even with fallback methods".to_string(),
)),
}
} else {
Ok(w)
}
}
#[allow(dead_code)]
fn wright_omega_fallback_methods(z: Complex64, tolerance: f64) -> SpecialResult<Complex64> {
if z.norm() < 0.5 {
if let Ok(w) = wright_omega_series_expansion(z, tolerance) {
return Ok(w);
}
}
if z.norm() > 10.0 {
if let Ok(w) = wright_omega_asymptotic(z, tolerance) {
return Ok(w);
}
}
let initial_guesses = vec![
z / 2.0,
z.ln() - z.ln().ln(),
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(-1.0, 0.0),
z.sqrt(),
];
for initial_guess in initial_guesses {
if let Ok(w) = wright_omega_enhanced_newton(z, initial_guess, tolerance) {
return Ok(w);
}
}
if (z.re + 1.0).abs() < 2.0 && z.im.abs() < 2.0 * PI {
if let Ok(w) = wright_omega_branch_aware(z, tolerance) {
return Ok(w);
}
}
if z.norm() > 1e-10 {
let approx = if z.norm() > 1.0 {
z.ln()
} else {
z * (1.0 - z + z.powi(2) / 2.0)
};
let error = (approx * approx.exp() - z).norm();
if error < tolerance * 100.0 {
return Ok(approx);
}
}
Err(SpecialError::ConvergenceError(
"All fallback methods failed for Wright Omega function".to_string(),
))
}
#[allow(dead_code)]
fn wright_omega_series_expansion(z: Complex64, tolerance: f64) -> SpecialResult<Complex64> {
if z.norm() > 0.6 {
return Err(SpecialError::DomainError(
"Series expansion not valid for large |z|".to_string(),
));
}
let mut result = Complex64::new(0.0, 0.0);
let mut term;
let mut power = z;
let coefficients = [
1.0, -1.0, 3.0 / 2.0, -8.0 / 3.0, 125.0 / 24.0, -54.0 / 5.0, 16807.0 / 720.0, -16384.0 / 315.0, ];
for (n, &coeff) in coefficients.iter().enumerate() {
if n > 0 {
power *= z;
}
term = coeff * power;
result += term;
if term.norm() < tolerance {
break;
}
}
Ok(result)
}
/// Approximates ω(z) for large `|z|` by its asymptotic expansion.
///
/// With `ℓ = ln z` (principal branch),
///
/// `ω(z) ≈ z − ℓ + ℓ/z + ℓ(ℓ − 2)/(2z²) + ℓ(2ℓ² − 9ℓ + 6)/(6z³)`.
///
/// The expansion does not describe the left strip `|Im z| < π, Re z → −∞`, where ω ≈ e^z.
/// The truncated series is therefore accepted only if it satisfies the defining equation:
/// `|w + ln w − z| < tolerance·|z|`.
///
/// # Errors
/// `SpecialError::DomainError` if `|z| < 5`. `SpecialError::ConvergenceError` if the
/// residual check fails.
#[allow(dead_code)]
fn wright_omega_asymptotic(z: Complex64, tolerance: f64) -> SpecialResult<Complex64> {
    if z.norm() < 5.0 {
        return Err(SpecialError::DomainError(
            "Asymptotic expansion not valid for small |z|".to_string(),
        ));
    }
    let log_z = z.ln();
    let inv_z = z.inv();
    let inv_z2 = inv_z * inv_z;
    let result = z - log_z
        + log_z * inv_z
        + log_z * (log_z - 2.0) * inv_z2 * 0.5
        + log_z * (log_z * log_z * 2.0 - log_z * 9.0 + 6.0) * inv_z2 * inv_z / 6.0;
    let error = (result + result.ln() - z).norm();
    if error < tolerance * z.norm() {
        Ok(result)
    } else {
        Err(SpecialError::ConvergenceError(
            "Asymptotic expansion did not converge".to_string(),
        ))
    }
}
/// Newton's method for ω(z), starting from `initial_guess`.
///
/// Solves `w + ln w = z` (principal logarithm) using `f(w) = w + ln w − z` and
/// `f'(w) = 1 + 1/w`. Each step is capped in length at `(1 + |w|)/2`.
///
/// Iteration succeeds when `|f(w)| ≤ tolerance·max(1, |z|)`, or when an uncapped step is
/// below working precision. The latter bounds the residual by about `4ε(1 + |w|)`.
///
/// # Errors
/// `SpecialError::ConvergenceError` in any of these cases:
/// * the iterate becomes zero or non-finite, leaving the domain of `ln`;
/// * `|f'(w)| < 1e-15`, i.e. the iterate is at the singular value `w = −1`;
/// * 50 iterations pass without convergence.
#[allow(dead_code)]
fn wright_omega_enhanced_newton(
    z: Complex64,
    initial_guess: Complex64,
    tolerance: f64,
) -> SpecialResult<Complex64> {
    let max_iterations = 50;
    let residual_tolerance = tolerance * z.norm().max(1.0);
    let mut w = initial_guess;
    for _ in 0..max_iterations {
        let out_of_domain =
            !(w.re.is_finite() && w.im.is_finite()) || (w.re == 0.0 && w.im == 0.0);
        if out_of_domain {
            break;
        }
        let residual = w + w.ln() - z;
        if residual.norm() <= residual_tolerance {
            return Ok(w);
        }
        let f_prime = w.inv() + 1.0;
        if f_prime.norm() < 1e-15 {
            break;
        }
        let raw_step = residual / f_prime;
        // Damp only large steps; full Newton steps keep quadratic convergence near the root.
        let max_step = 0.5 * (1.0 + w.norm());
        let raw_len = raw_step.norm();
        let step = if raw_len > max_step {
            raw_step * (max_step / raw_len)
        } else {
            raw_step
        };
        w -= step;
        if step.norm() <= 4.0 * f64::EPSILON * w.norm() {
            return Ok(w);
        }
    }
    Err(SpecialError::ConvergenceError(
        "Enhanced Newton method failed to converge".to_string(),
    ))
}
#[allow(dead_code)]
fn wright_omega_branch_aware(z: Complex64, tolerance: f64) -> SpecialResult<Complex64> {
let k = (z.im / (2.0 * PI)).round() as i32;
let nearest_branch = Complex64::new(-1.0, 2.0 * PI * k as f64);
let distance_to_branch = (z - nearest_branch).norm();
if distance_to_branch < 0.1 {
if k == 0 && (z.re + 1.0).abs() < 0.1 && z.im.abs() < 0.1 {
let delta = z + 1.0;
let sqrt_term = (2.0 * std::f64::consts::E * delta).sqrt();
return Ok(Complex64::new(-1.0, 0.0) + sqrt_term);
}
let shifted_z = z - nearest_branch;
if shifted_z.norm() > 1e-10 {
let w_shifted = shifted_z.sqrt() * (1.0 - shifted_z / 6.0);
return Ok(Complex64::new(-1.0, 0.0) + w_shifted);
}
}
let modified_initial = if z.re > -0.5 {
z.ln() } else {
Complex64::new(-1.0, 0.0) + (z + 1.0) * 0.5
};
wright_omega_enhanced_newton(z, modified_initial, tolerance)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::numeric::Complex64;

    /// Reference values: ω(0) = W(1), ω(1) = 1 and ω(2) = W(e²).
    const KNOWN_VALUES: [(f64, f64); 3] = [
        (0.0, 0.5671432904097838),
        (1.0, 1.0),
        (2.0, 1.5571455989976),
    ];

    /// Unwraps a result with the message shared by every test in this module.
    fn ok<T>(result: SpecialResult<T>) -> T {
        result.expect("Operation failed")
    }

    #[test]
    fn test_wright_omega_real() {
        for &(x, expected) in KNOWN_VALUES.iter() {
            assert_relative_eq!(ok(wright_omega_real(x, None)), expected, epsilon = 1e-10);
        }
        // Every value must satisfy the defining equation ω + ln ω = x.
        for &x in [-1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 5.0, 10.0, 100.0].iter() {
            let omega = ok(wright_omega_real(x, None));
            assert_relative_eq!(omega + omega.ln(), x, epsilon = 1e-10);
        }
        assert_eq!(ok(wright_omega_real(f64::INFINITY, None)), f64::INFINITY);
        assert_eq!(ok(wright_omega_real(f64::NEG_INFINITY, None)), 0.0);
        assert!(ok(wright_omega_real(f64::NAN, None)).is_nan());
    }

    #[test]
    fn test_wright_omega_complex() {
        let origin = ok(wright_omega(Complex64::new(0.0, 0.0), None));
        assert_relative_eq!(origin.re, KNOWN_VALUES[0].1, epsilon = 1e-10);
        assert_relative_eq!(origin.im, 0.0, epsilon = 1e-10);

        let samples = [
            Complex64::new(0.5, 3.0),
            Complex64::new(-1.0, 2.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(-0.5, -0.5),
        ];
        for z in samples.iter() {
            let omega = ok(wright_omega(*z, None));
            let recovered = omega + omega.ln();
            assert_relative_eq!(recovered.re, z.re, epsilon = 1e-10);
            assert_relative_eq!(recovered.im, z.im, epsilon = 1e-10);
        }

        let infinite = ok(wright_omega(Complex64::new(f64::INFINITY, 10.0), None));
        assert_eq!(infinite.re, f64::INFINITY);
        assert_eq!(infinite.im, 10.0);

        let undefined = ok(wright_omega(Complex64::new(f64::NAN, 0.0), None));
        assert!(undefined.re.is_nan());
        assert!(undefined.im.is_nan());
    }

    #[test]
    fn test_wright_omega_real_optimized() {
        for &(x, expected) in KNOWN_VALUES.iter() {
            assert_relative_eq!(
                ok(wright_omega_real_optimized(x, None)),
                expected,
                epsilon = 1e-10
            );
        }
        for &x in [-0.5, 0.0, 0.5, 1.0, 2.0, 5.0, 10.0, 100.0].iter() {
            let omega = ok(wright_omega_real_optimized(x, None));
            assert_relative_eq!(omega + omega.ln(), x, epsilon = 1e-8);
        }
        assert_eq!(
            ok(wright_omega_real_optimized(f64::INFINITY, None)),
            f64::INFINITY
        );
        assert_eq!(ok(wright_omega_real_optimized(f64::NEG_INFINITY, None)), 0.0);
        assert!(ok(wright_omega_real_optimized(f64::NAN, None)).is_nan());
    }

    #[test]
    fn test_wright_omega_optimized() {
        let origin = ok(wright_omega_optimized(Complex64::new(0.0, 0.0), None));
        assert_relative_eq!(origin.re, KNOWN_VALUES[0].1, epsilon = 1e-10);
        assert_relative_eq!(origin.im, 0.0, epsilon = 1e-10);

        let half = ok(wright_omega_optimized(Complex64::new(0.5, 0.0), None));
        assert!(half.re > 0.7 && half.re < 0.9);
        assert!(half.im.abs() < 1e-8);

        let infinite = ok(wright_omega_optimized(
            Complex64::new(f64::INFINITY, 10.0),
            None,
        ));
        assert_eq!(infinite.re, f64::INFINITY);
        assert_eq!(infinite.im, 10.0);

        let undefined = ok(wright_omega_optimized(Complex64::new(f64::NAN, 0.0), None));
        assert!(undefined.re.is_nan());
        assert!(undefined.im.is_nan());
    }

    #[test]
    fn test_compare_implementations() {
        let x = 0.0;
        assert_relative_eq!(
            ok(wright_omega_real(x, Some(1e-10))),
            ok(wright_omega_real_optimized(x, Some(1e-10))),
            epsilon = 1e-8
        );

        let z = Complex64::new(0.0, 0.0);
        let reference = ok(wright_omega(z, None));
        let fast = ok(wright_omega_optimized(z, None));
        assert_relative_eq!(reference.re, fast.re, epsilon = 1e-8);
        assert_relative_eq!(reference.im, fast.im, epsilon = 1e-8);
    }

    #[test]
    fn test_performance() {
        for &x in [0.0, 1.0, 2.0].iter() {
            let reference = ok(wright_omega_real(x, Some(1e-10)));
            let fast = ok(wright_omega_real_optimized(x, Some(1e-10)));
            assert_relative_eq!(reference, fast, epsilon = 1e-6);
        }
        for z in [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)].iter() {
            let reference = ok(wright_omega(*z, Some(1e-10)));
            let fast = ok(wright_omega_optimized(*z, Some(1e-10)));
            assert_relative_eq!(reference.re, fast.re, epsilon = 1e-6);
            assert_relative_eq!(reference.im, fast.im, epsilon = 1e-6);
        }
    }
}