1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
//! Lambert W function implementation
//!
//! The Lambert W function is the inverse function of `w * exp(w)`.
//! It's a multivalued function with infinitely many branches,
//! indexed by the integer k.
//!
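//! This implementation follows the approach of SciPy's `lambertw` function:
//! Halley's iteration, started from a branch-point series, a Padé approximant
//! near the origin, or a logarithmic/asymptotic approximation, depending on the
//! branch and the region of the complex plane.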
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::Zero;
use std::f64::consts::{E, PI};
use crate::error::{SpecialError, SpecialResult};
/// Euler's number e itself (not e^-1, despite the name).
/// Not referenced in this module; the leading underscore silences the unused-constant lint.
const _EXPN1: f64 = E;
/// 1/e. The point z = -1/e is the branch point shared by the k = 0 and k = -1 branches.
const EXPN1_INV: f64 = 1.0 / E;
/// 2π, used for the 2πik imaginary offset that distinguishes branch k.
const TWO_PI: f64 = 2.0 * PI;
/// Upper bound on the number of Halley iterations performed by `lambert_w`.
const MAX_ITERATIONS: usize = 100;
/// Lambert W function for real and complex arguments.
///
/// The Lambert W function W(z) is defined as the inverse function of w * exp(w).
/// In other words, the value of W(z) is such that z = W(z) * exp(W(z)) for any complex number z.
///
/// # Arguments
///
/// * `z` - Complex input argument
/// * `k` - Branch index (integer)
/// * `tol` - Tolerance for convergence (default = 1e-8)
///
/// # Returns
///
/// * `Complex64` - The value of W(z) on branch k
///
/// # Examples
///
/// ```
/// use scirs2_special::lambert_w;
/// use scirs2_core::numeric::Complex64;
///
/// let w = lambert_w(Complex64::new(1.0, 0.0), 0, 1e-8).unwrap();
/// assert!((w - Complex64::new(0.56714329040978384, 0.0)).norm() < 1e-10);
///
/// // Verify that w * exp(w) = z
/// let z = Complex64::new(1.0, 0.0);
/// let w_exp_w = w * w.exp();
/// assert!((w_exp_w - z).norm() < 1e-10);
/// ```
#[allow(dead_code)]
pub fn lambert_w(z: Complex64, k: i32, tol: f64) -> SpecialResult<Complex64> {
if z.is_nan() {
return Ok(Complex64::new(f64::NAN, f64::NAN));
}
// Special cases for infinite inputs
if z.is_infinite() {
if k == 0 {
return Ok(Complex64::new(f64::INFINITY, 0.0));
} else if k == 1 {
return Ok(Complex64::new(f64::INFINITY, TWO_PI));
} else if k == -1 {
return Ok(Complex64::new(f64::INFINITY, 3.0 * PI));
} else {
// For other branches with infinite input
// This may not match SciPy for all k values, but it's a reasonable default
let imag = (2.0 * k as f64 + 1.0) * PI;
return Ok(Complex64::new(f64::INFINITY, imag));
}
}
// For k=0 branch and very small inputs, the result is approximately the input
if k == 0 && z.norm() < 1e-300 {
return Ok(z);
}
// Special case for z = 0
if z.is_zero() {
if k == 0 {
return Ok(Complex64::new(0.0, 0.0));
} else {
// All other branches have a logarithmic singularity at z = 0
return Ok(Complex64::new(f64::NEG_INFINITY, 0.0));
}
}
// Compute the initial guess based on the branch
let mut w = initial_guess(z, k);
// Halley's iteration to refine the result
for _ in 0..MAX_ITERATIONS {
// Handle extreme values of w carefully to avoid overflow
if w.re > 700.0 {
// For very large w.re, exp(w) would overflow, so handle specially
return Ok(w); // At these extreme values, further refinement is unlikely to help
}
let ew = w.exp();
let wew = w * ew;
let wewz = wew - z;
// Check if we've converged
// Using both absolute and relative tolerance for better stability
let abs_tol = tol.max(1e-15);
let rel_tol = tol * w.norm().max(1.0);
if wewz.norm() < abs_tol || wewz.norm() < rel_tol {
break;
}
// Compute the next iteration using Halley's method
// The formula is: w_next = w - f(w)/f'(w) * (1 + f(w)*f''(w)/(2*f'(w)^2))
// where f(w) = w*e^w - z
let w1 = w + Complex64::new(1.0, 0.0);
let w1ew = w1 * ew;
let denominator =
w1ew - (w + Complex64::new(2.0, 0.0)) * wewz / (Complex64::new(2.0, 0.0) * w1);
// More robust handling of potential numerical issues
if denominator.norm() < 1e-15 {
// In case of near-zero denominator, use a dampened step
let safe_step = Complex64::new(0.1, 0.0)
* if w.norm() > 1.0 {
w / w.norm()
} else {
Complex64::new(1.0, 0.0)
};
w -= safe_step;
} else {
let delta = wewz / denominator;
// Limit step size to prevent overshooting
let delta_norm = delta.norm();
if delta_norm > 10.0 {
w -= delta * (10.0 / delta_norm);
} else {
w -= delta;
}
}
}
Ok(w)
}
/// Initial guess for the Halley iteration in `lambert_w`.
///
/// The approximation is chosen by branch and region, following SciPy's `lambertw`:
/// 1. Near the branch point -1/e, for `k = 0` and `k = -1`: the series
///    `W ≈ -1 + p - p²/3` with `p = ±sqrt(2(e z + 1))`. The `+` sign is used for
///    `k = 0` and the `-` sign for `k = -1`.
/// 2. `k = 0` near the origin: the (2, 1) Padé approximant `z (2 + z) / (2 + 3z)`, which
///    matches the Taylor series `W(z) = z - z² + 3z³/2 - …` through the cubic term.
/// 3. `k = -1` on the real segment `[-1/e, 0)`: `ln(-z)`. This starting point is real, so the
///    iteration stays on the real-valued part of that branch.
/// 4. Everywhere else: the asymptotic expansion `L1 - ln(L1)` with `L1 = ln z + 2πik`.
#[allow(dead_code)]
fn initial_guess(z: Complex64, k: i32) -> Complex64 {
    // 1. Branch point series.
    if (k == 0 || k == -1) && (z + EXPN1_INV).norm() < 0.3 {
        let root = (2.0 * (E * z + 1.0)).sqrt();
        let p = if k == 0 { root } else { -root };
        return -1.0 + p - p * p / 3.0;
    }

    // 2. Padé approximant on the principal branch. The region is SciPy's empirically chosen
    //    one. It also excludes a neighbourhood of the approximant's pole at z = -2/3, and it
    //    sends real z in (-1, -0.2] to the complex-valued asymptotic guess below.
    let in_pade_region = z.re > -1.0
        && z.re < 1.5
        && z.im.abs() < 1.0
        && z.re > -2.5 * z.im.abs() - 0.2;
    if k == 0 && in_pade_region {
        return z * (2.0 + z) / (2.0 + 3.0 * z);
    }

    // 3. Real segment of the k = -1 branch.
    if k == -1 && z.im == 0.0 && z.re < 0.0 && z.re >= -EXPN1_INV {
        return Complex64::new((-z.re).ln(), 0.0);
    }

    // 4. Asymptotic expansion W_k(z) ≈ L1 - ln(L1), with L1 = ln z + 2πik.
    let l1 = z.ln() + Complex64::new(0.0, TWO_PI * f64::from(k));
    if l1.is_zero() {
        // Only k = 0, z = 1 gives L1 = 0, and that point lies in the Padé region above.
        // This guard just avoids ln(0).
        return l1;
    }
    l1 - l1.ln()
}
/// Lambert W function for real arguments on the principal branch (k=0).
///
/// The principal branch is real-valued exactly on `[-1/e, ∞)`. For `x < -1/e` its value is
/// complex, and an error is returned instead.
///
/// # Arguments
///
/// * `x` - Real input value
/// * `tol` - Convergence tolerance forwarded to `lambert_w` (1e-8 is a typical choice)
///
/// # Returns
///
/// * `Ok(W_0(x))` for `x >= -1/e`; `x == -1/e` (as rounded to `f64`) gives exactly `-1.0`
/// * `Ok(NaN)` for a NaN input
///
/// # Errors
///
/// * `SpecialError::DomainError` when `x < -1/e`, or when the computed value unexpectedly
///   has an imaginary part of magnitude `>= 1e-15`.
///
/// # Examples
///
/// ```
/// use scirs2_special::lambert_w_real;
///
/// let w = lambert_w_real(1.0, 1e-8).unwrap();
/// assert!((w - 0.56714329040978384).abs() < 1e-10);
/// ```
#[allow(dead_code)]
pub fn lambert_w_real(x: f64, tol: f64) -> SpecialResult<f64> {
    let complex_result_error = || {
        SpecialError::DomainError(format!(
            "Lambert W function gives a complex result for x={x}"
        ))
    };

    if x.is_nan() {
        return Ok(f64::NAN);
    }
    // Reject the complex region before doing any iteration.
    if x < -EXPN1_INV {
        return Err(complex_result_error());
    }
    // Branch point: W_0(-1/e) = -1 is real. The iteration is degenerate here, so answer directly.
    if x == -EXPN1_INV {
        return Ok(-1.0);
    }

    let result = lambert_w(Complex64::new(x, 0.0), 0, tol)?;
    if result.im.abs() < 1e-15 {
        Ok(result.re)
    } else {
        Err(complex_result_error())
    }
}