use crate::error::{SpecialError, SpecialResult};
/// Half-width of the truncated period lattice: every lattice sum/product in this
/// module runs over `m, n ∈ [-N_LATTICE, N_LATTICE]`, skipping `(0, 0)`.
const N_LATTICE: i32 = 40;
/// Distance threshold used to reject `z` too close to a pole/lattice point and to
/// detect degenerate (coincident) cubic roots.
const POLE_TOL: f64 = 1e-10;
/// Real half-periods `(ω₁, ω₂ / i)` of the rectangular lattice with invariants `g2`, `g3`.
///
/// With the real roots `e₁ ≥ e₂ ≥ e₃` of `4t³ − g₂·t − g₃` and
/// `k² = (e₂ − e₃) / (e₁ − e₃)`, the half-periods are
/// `ω₁ = K(k) / √(e₁ − e₃)` and `ω₂ = i·K(k′) / √(e₁ − e₃)` with `k′ = √(1 − k²)`.
///
/// # Errors
/// Propagates errors from `cubic_roots_weierstrass`. Returns
/// `SpecialError::ComputationError` when `e₁ − e₃` is NaN or below `POLE_TOL`,
/// or when either half-period would be infinite/NaN (e.g. `e₂ = e₃`, where
/// `K(k′) = K(1) = ∞`); previously that case leaked NaN into every lattice sum
/// through `0 · ∞`.
fn half_periods_from_invariants(g2: f64, g3: f64) -> SpecialResult<(f64, f64)> {
    let (e1, e2, e3) = cubic_roots_weierstrass(g2, g3)?;
    let diff13 = e1 - e3;
    // NaN must be rejected explicitly: `NaN < POLE_TOL` is false.
    if diff13.is_nan() || diff13 < POLE_TOL {
        return Err(SpecialError::ComputationError(
            "degenerate lattice: e₁ = e₃".to_string(),
        ));
    }
    let k2 = (e2 - e3) / diff13;
    // Keep k strictly below 1 so K(k) stays finite when e₁ ≈ e₂.
    let k = k2.sqrt().clamp(0.0, 1.0 - 1e-15);
    let big_k = complete_elliptic_k(k);
    let big_k_prime = complete_elliptic_k((1.0 - k2).sqrt());
    if !big_k.is_finite() || !big_k_prime.is_finite() {
        return Err(SpecialError::ComputationError(
            "degenerate lattice: infinite half-period (Δ = 0)".to_string(),
        ));
    }
    let sqrt_diff = diff13.sqrt();
    let omega1 = big_k / sqrt_diff;
    let omega2_imag = big_k_prime / sqrt_diff;
    Ok((omega1, omega2_imag))
}
fn cubic_roots_weierstrass(g2: f64, g3: f64) -> SpecialResult<(f64, f64, f64)> {
let p = g2 / 4.0;
let q = g3 / 4.0;
let delta_quarter = p * p * p / 27.0 - q * q / 4.0;
if delta_quarter > 0.0 {
let m = 2.0 * (p / 3.0).sqrt();
let theta = (3.0 * q / (p * m)).acos() / 3.0;
let two_pi_3 = 2.0 * std::f64::consts::PI / 3.0;
let t1 = m * theta.cos();
let t2 = m * (theta - two_pi_3).cos();
let t3 = m * (theta - 2.0 * two_pi_3).cos();
let mut roots = [t1, t2, t3];
roots.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
Ok((roots[0], roots[1], roots[2]))
} else if delta_quarter == 0.0 {
let t1 = 3.0 * q / p;
let t2 = -3.0 * q / (2.0 * p);
if t1 >= t2 {
Ok((t1, t2, t2))
} else {
Ok((t2, t2, t1))
}
} else {
let sqrt_neg = (-delta_quarter).sqrt();
let u = (-q / 2.0 + sqrt_neg).cbrt();
let v = (-q / 2.0 - sqrt_neg).cbrt();
let t1 = u + v;
Ok((t1, t1, t1))
}
}
/// Complete elliptic integral of the first kind `K(k)` for modulus `k`.
///
/// Evaluated as `K(k) = (π/2) / AGM(1, √(1 − k²))`. Because `K` depends only on
/// `k²`, a negative modulus is folded onto `|k|`; previously every `k < 0`
/// returned `π/2`. Returns `π/2` at `k = 0`, `+∞` for `|k| ≥ 1`, and NaN for NaN.
fn complete_elliptic_k(k: f64) -> f64 {
    let k = k.abs();
    if k == 0.0 {
        return std::f64::consts::FRAC_PI_2;
    }
    if k >= 1.0 {
        return f64::INFINITY;
    }
    let b = (1.0 - k * k).sqrt();
    std::f64::consts::FRAC_PI_2 / agm(1.0, b)
}
/// Arithmetic–geometric mean of `a` and `b`.
///
/// Repeats `aₙ₊₁ = (aₙ + bₙ)/2`, `bₙ₊₁ = √(aₙ·bₙ)` until the new pair agrees to a
/// relative tolerance of `1e-15`. If that does not happen within 100 steps, it
/// returns the midpoint of the last pair.
fn agm(a: f64, b: f64) -> f64 {
    let mut pair = (a, b);
    for _step in 0..100 {
        let (x, y) = pair;
        let arith = 0.5 * (x + y);
        let geom = (x * y).sqrt();
        if (arith - geom).abs() < arith.abs() * 1e-15 {
            return arith;
        }
        pair = (arith, geom);
    }
    let (x, y) = pair;
    0.5 * (x + y)
}
/// Weierstrass elliptic function `℘(z; g₂, g₃)` for real `z`.
///
/// Any error from the lattice computation is reported as `f64::NAN`. Examples
/// are invariants that do not yield a usable lattice, or `z` too close to a
/// lattice point.
pub fn weierstrass_p(z: f64, g2: f64, g3: f64) -> f64 {
    weierstrass_p_impl(z, g2, g3).unwrap_or(f64::NAN)
}
/// Truncated lattice-sum evaluation of `℘(z)` for real `z`.
///
/// Computes `1/z² + Σ' [1/(z − ω)² − 1/ω²]` over
/// `ω = 2m·ω₁ + 2n·i·ω₂` with `|m|, |n| ≤ N_LATTICE`, `(m, n) ≠ (0, 0)`.
/// Only the real part of each term is accumulated.
///
/// # Errors
/// Propagates errors from `half_periods_from_invariants`. Returns
/// `SpecialError::DomainError` when `|z| < POLE_TOL` or `|z − ω| < POLE_TOL` for
/// some lattice point `ω`.
fn weierstrass_p_impl(z: f64, g2: f64, g3: f64) -> SpecialResult<f64> {
    let (half_real, half_imag) = half_periods_from_invariants(g2, g3)?;
    if z.abs() < POLE_TOL {
        return Err(SpecialError::DomainError(
            "℘: z is too close to the origin (pole)".to_string(),
        ));
    }
    let pole_tol_sq = POLE_TOL * POLE_TOL;
    let lattice = (-N_LATTICE..=N_LATTICE)
        .flat_map(|m| (-N_LATTICE..=N_LATTICE).map(move |n| (m, n)))
        .filter(|&(m, n)| (m, n) != (0, 0));
    let mut acc = 1.0 / (z * z);
    for (m, n) in lattice {
        let w_re = 2.0 * f64::from(m) * half_real;
        let w_im = 2.0 * f64::from(n) * half_imag;
        // z − ω = shifted − i·w_im
        let shifted = z - w_re;
        let dist_sq = shifted * shifted + w_im * w_im;
        if dist_sq < pole_tol_sq {
            return Err(SpecialError::DomainError(format!(
                "℘: z is too close to lattice point ({m},{n})"
            )));
        }
        let w_mod4 = (w_re * w_re + w_im * w_im).powi(2);
        if w_mod4 < 1e-300 {
            continue;
        }
        let inv_shift_sq_re = (shifted * shifted - w_im * w_im) / (dist_sq * dist_sq);
        let inv_w_sq_re = (w_re * w_re - w_im * w_im) / w_mod4;
        acc += inv_shift_sq_re - inv_w_sq_re;
    }
    Ok(acc)
}
/// Derivative `℘′(z; g₂, g₃)` for real `z`.
///
/// Any error from the lattice computation is reported as `f64::NAN`.
pub fn weierstrass_p_derivative(z: f64, g2: f64, g3: f64) -> f64 {
    weierstrass_p_derivative_impl(z, g2, g3).unwrap_or(f64::NAN)
}
/// Truncated lattice-sum evaluation of `℘′(z) = −2 Σ 1/(z − ω)³` for real `z`.
///
/// The sum runs over `ω = 2m·ω₁ + 2n·i·ω₂` with `|m|, |n| ≤ N_LATTICE`. The
/// `ω = 0` term `−2/z³` seeds the accumulator, and only real parts are summed.
///
/// # Errors
/// Propagates errors from `half_periods_from_invariants`. Returns
/// `SpecialError::DomainError` when `z` is within `POLE_TOL` of the origin or of
/// another lattice point.
fn weierstrass_p_derivative_impl(z: f64, g2: f64, g3: f64) -> SpecialResult<f64> {
    let (half_real, half_imag) = half_periods_from_invariants(g2, g3)?;
    if z.abs() < POLE_TOL {
        return Err(SpecialError::DomainError(
            "℘': z is too close to the origin (pole)".to_string(),
        ));
    }
    let pole_tol_sq = POLE_TOL * POLE_TOL;
    let mut acc = -2.0 / (z * z * z);
    for m in -N_LATTICE..=N_LATTICE {
        // The real offset depends only on m, so compute it once per row.
        let w_re = 2.0 * f64::from(m) * half_real;
        let shifted = z - w_re;
        for n in -N_LATTICE..=N_LATTICE {
            if (m, n) == (0, 0) {
                continue;
            }
            let w_im = 2.0 * f64::from(n) * half_imag;
            let dist_sq = shifted * shifted + w_im * w_im;
            if dist_sq < pole_tol_sq {
                return Err(SpecialError::DomainError(format!(
                    "℘': z too close to lattice point ({m},{n})"
                )));
            }
            // Re[(z − ω)³] with z − ω = shifted − i·w_im; 1/(z − ω)³ = conj(·)/|z − ω|⁶.
            let cube_re = shifted * shifted * shifted - 3.0 * shifted * w_im * w_im;
            acc += -2.0 * cube_re / dist_sq.powi(3);
        }
    }
    Ok(acc)
}
/// Weierstrass zeta function `ζ(z; g₂, g₃)` for real `z`.
///
/// Any error from the lattice computation is reported as `f64::NAN`.
pub fn weierstrass_zeta(z: f64, g2: f64, g3: f64) -> f64 {
    weierstrass_zeta_impl(z, g2, g3).unwrap_or(f64::NAN)
}
/// Truncated lattice-sum evaluation of `ζ(z)` for real `z`.
///
/// Computes `1/z + Σ' [1/(z − ω) + 1/ω + z/ω²]` over
/// `ω = 2m·ω₁ + 2n·i·ω₂` with `|m|, |n| ≤ N_LATTICE`, `(m, n) ≠ (0, 0)`.
/// Only real parts are accumulated.
///
/// # Errors
/// Propagates errors from `half_periods_from_invariants`. Returns
/// `SpecialError::DomainError` when `z` is within `POLE_TOL` of the origin or of
/// another lattice point.
fn weierstrass_zeta_impl(z: f64, g2: f64, g3: f64) -> SpecialResult<f64> {
    let (half_real, half_imag) = half_periods_from_invariants(g2, g3)?;
    if z.abs() < POLE_TOL {
        return Err(SpecialError::DomainError(
            "ζ: z too close to origin (pole)".to_string(),
        ));
    }
    let mut acc = 1.0 / z;
    for m in -N_LATTICE..=N_LATTICE {
        let w_re = 2.0 * f64::from(m) * half_real;
        let shifted = z - w_re;
        for n in (-N_LATTICE..=N_LATTICE).filter(|&n| m != 0 || n != 0) {
            let w_im = 2.0 * f64::from(n) * half_imag;
            let dist_sq = shifted * shifted + w_im * w_im;
            if dist_sq < POLE_TOL * POLE_TOL {
                return Err(SpecialError::DomainError(format!(
                    "ζ: z too close to lattice point ({m},{n})"
                )));
            }
            let w_mod2 = w_re * w_re + w_im * w_im;
            if w_mod2 < 1e-300 {
                continue;
            }
            let inv_shift_re = shifted / dist_sq;
            let inv_w_re = w_re / w_mod2;
            let z_over_w_sq_re = z * (w_re * w_re - w_im * w_im) / (w_mod2 * w_mod2);
            acc += inv_shift_re + inv_w_re + z_over_w_sq_re;
        }
    }
    Ok(acc)
}
/// Weierstrass sigma function `σ(z; g₂, g₃)` for real `z`.
///
/// Any error from the lattice computation is reported as `f64::NAN`.
pub fn weierstrass_sigma(z: f64, g2: f64, g3: f64) -> f64 {
    weierstrass_sigma_impl(z, g2, g3).unwrap_or(f64::NAN)
}
fn weierstrass_sigma_impl(z: f64, g2: f64, g3: f64) -> SpecialResult<f64> {
let (omega1, omega2_imag) = half_periods_from_invariants(g2, g3)?;
let mut log_sigma_no_z = 0.0_f64;
for m in -N_LATTICE..=N_LATTICE {
for n in -N_LATTICE..=N_LATTICE {
if m == 0 && n == 0 {
continue;
}
let omega_r = 2.0 * (m as f64) * omega1;
let omega_i = 2.0 * (n as f64) * omega2_imag;
let omega_mod2 = omega_r * omega_r + omega_i * omega_i;
if omega_mod2 < 1e-300 {
continue;
}
let re_z_over_omega = z * omega_r / omega_mod2;
let im_z_over_omega = -z * omega_i / omega_mod2;
let re_1mz = 1.0 - re_z_over_omega;
let im_1mz = -im_z_over_omega;
let mod1mz = (re_1mz * re_1mz + im_1mz * im_1mz).sqrt();
if mod1mz < 1e-300 {
return Ok(0.0);
}
let arg1mz = im_1mz.atan2(re_1mz);
let omega_mod4 = omega_mod2 * omega_mod2;
let re_inv_omega2 = (omega_r * omega_r - omega_i * omega_i) / omega_mod4;
let im_inv_omega2 = -2.0 * omega_r * omega_i / omega_mod4;
let re_exp_arg = re_z_over_omega + 0.5 * z * z * re_inv_omega2;
let im_exp_arg = im_z_over_omega + 0.5 * z * z * im_inv_omega2;
log_sigma_no_z += mod1mz.ln() + re_exp_arg;
let _ = (arg1mz + im_exp_arg); }
}
let sigma = z * log_sigma_no_z.exp();
Ok(sigma)
}
/// Invariants `(g₂, g₃)` of the rectangular lattice with half-periods
/// `ω₁ = omega1` and `ω₂ = i·omega2_imag`.
///
/// Uses truncated Eisenstein sums over `ω = 2m·ω₁ + 2n·i·ω₂`, with
/// `|m|, |n| ≤ N_LATTICE` and `(m, n) ≠ (0, 0)`:
/// `g₂ = 60 Σ' ω⁻⁴` and `g₃ = 140 Σ' ω⁻⁶`.
/// Only real parts are accumulated. Each is `Re(ωᵏ)/|ω|²ᵏ`.
pub fn lattice_invariants(omega1: f64, omega2_imag: f64) -> (f64, f64) {
    let (sum4, sum6) = (-N_LATTICE..=N_LATTICE)
        .flat_map(|m| (-N_LATTICE..=N_LATTICE).map(move |n| (m, n)))
        .filter(|&(m, n)| m != 0 || n != 0)
        .fold((0.0_f64, 0.0_f64), |(s4, s6), (m, n)| {
            let x = 2.0 * f64::from(m) * omega1;
            let y = 2.0 * f64::from(n) * omega2_imag;
            let r2 = x * x + y * y;
            if r2 < 1e-300 {
                return (s4, s6);
            }
            let r4 = r2 * r2;
            let r6 = r4 * r2;
            // Real parts of ω⁴ and ω⁶ for ω = x + i·y.
            let re_w4 = x * x * x * x - 6.0 * x * x * y * y + y * y * y * y;
            let re_w6 = x.powi(6) - 15.0 * x.powi(4) * y * y + 15.0 * x * x * y.powi(4)
                - y.powi(6);
            (s4 + re_w4 / (r4 * r4), s6 + re_w6 / (r6 * r6))
        });
    (60.0 * sum4, 140.0 * sum6)
}
/// Discriminant of the Weierstrass cubic: `Δ = g₂³ − 27·g₃²`.
pub fn discriminant(g2: f64, g3: f64) -> f64 {
    let g2_cubed = g2 * g2 * g2;
    let scaled_g3_sq = 27.0 * g3 * g3;
    g2_cubed - scaled_g3_sq
}
/// Klein's `j`-invariant `j = 1728·g₂³ / Δ`.
///
/// Returns `+∞` when `|Δ| < 1e-300`, i.e. for a (numerically) singular cubic.
pub fn j_invariant(g2: f64, g3: f64) -> f64 {
    let disc = discriminant(g2, g3);
    if disc.abs() < 1e-300 {
        f64::INFINITY
    } else {
        1728.0 * g2 * g2 * g2 / disc
    }
}
/// Absolute residual of the Weierstrass ODE `(℘′)² = 4℘³ − g₂℘ − g₃` at `z`.
///
/// Returns NaN when either `℘(z)` or `℘′(z)` is not finite.
pub fn check_differential_equation(z: f64, g2: f64, g3: f64) -> f64 {
    let p = weierstrass_p(z, g2, g3);
    let dp = weierstrass_p_derivative(z, g2, g3);
    if p.is_finite() && dp.is_finite() {
        let lhs = dp * dp;
        let rhs = 4.0 * p * p * p - g2 * p - g3;
        (lhs - rhs).abs()
    } else {
        f64::NAN
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, FRAC_PI_2};

    /// Tolerance for symmetry properties that the truncated lattice sums satisfy
    /// only up to summation-order rounding.
    const EPS_MED: f64 = 1e-4;

    #[test]
    fn test_discriminant_lemniscate() {
        let d = discriminant(4.0, 0.0);
        assert!((d - 64.0).abs() < 1e-10, "Δ = {d}");
    }

    #[test]
    fn test_discriminant_equianharmonic() {
        let g3 = 4.0_f64;
        let d = discriminant(0.0, g3);
        assert!((d - (-27.0 * g3 * g3)).abs() < 1e-10, "Δ = {d}");
    }

    #[test]
    fn test_discriminant_double_root() {
        // g2 = 3, g3 = 1: 27 − 27 = 0 exactly.
        assert_eq!(discriminant(3.0, 1.0), 0.0);
    }

    #[test]
    fn test_j_invariant_lemniscate() {
        let j = j_invariant(4.0, 0.0);
        assert!((j - 1728.0).abs() < 1e-6, "j = {j}");
    }

    #[test]
    fn test_lattice_invariants_finite() {
        let (g2, g3) = lattice_invariants(1.0, 1.0);
        assert!(g2.is_finite(), "g2 not finite: {g2}");
        assert!(g3.is_finite(), "g3 not finite: {g3}");
    }

    #[test]
    fn test_lattice_invariants_square_lattice_g3_vanishes() {
        // Swapping m and n on the square lattice negates Re(ω⁶), so g3 cancels.
        let (g2, g3) = lattice_invariants(1.0, 1.0);
        assert!(g2 > 0.0, "g2 = {g2}");
        assert!(g3.abs() < 1e-10, "g3 = {g3}");
    }

    #[test]
    fn test_complete_elliptic_k_special_values() {
        assert_eq!(complete_elliptic_k(0.0), FRAC_PI_2);
        let k_lemn = complete_elliptic_k(FRAC_1_SQRT_2);
        assert!(
            (k_lemn - 1.854_074_677_301_372).abs() < 1e-12,
            "K(1/√2) = {k_lemn}"
        );
        assert!(complete_elliptic_k(1.0).is_infinite());
    }

    #[test]
    fn test_weierstrass_p_finite() {
        let v = weierstrass_p(0.5, 1.0, 0.0);
        assert!(v.is_finite(), "℘ not finite: {v}");
    }

    #[test]
    fn test_weierstrass_p_near_pole() {
        let v = weierstrass_p(1e-12, 1.0, 0.0);
        assert!(v.is_nan() || v.abs() > 1e10, "Expected large or NaN near pole");
    }

    #[test]
    fn test_weierstrass_p_even() {
        let z = 0.4_f64;
        let p_pos = weierstrass_p(z, 1.0, 0.0);
        let p_neg = weierstrass_p(-z, 1.0, 0.0);
        assert!(
            p_pos.is_finite() && p_neg.is_finite(),
            "℘ not finite: ℘({z})={p_pos}, ℘(-{z})={p_neg}"
        );
        assert!(
            (p_pos - p_neg).abs() < EPS_MED,
            "℘ not even: ℘({z})={p_pos}, ℘(-{z})={p_neg}"
        );
    }

    #[test]
    fn test_weierstrass_p_derivative_odd() {
        let g2 = 1.0_f64;
        let g3 = 0.0_f64;
        let z = 0.4_f64;
        let dp_pos = weierstrass_p_derivative(z, g2, g3);
        let dp_neg = weierstrass_p_derivative(-z, g2, g3);
        // Previously a NaN result made this test pass vacuously.
        assert!(
            dp_pos.is_finite() && dp_neg.is_finite(),
            "℘' not finite: dp({z})={dp_pos}, dp(-{z})={dp_neg}"
        );
        assert!(
            (dp_pos + dp_neg).abs() < EPS_MED,
            "℘' not odd: dp({z})={dp_pos}, dp(-{z})={dp_neg}"
        );
    }

    #[test]
    fn test_weierstrass_zeta_finite() {
        let v = weierstrass_zeta(0.5, 1.0, 0.0);
        assert!(v.is_finite(), "ζ not finite: {v}");
    }

    #[test]
    fn test_weierstrass_sigma_odd() {
        let z = 0.3_f64;
        let sp = weierstrass_sigma(z, 1.0, 0.0);
        let sm = weierstrass_sigma(-z, 1.0, 0.0);
        assert!(
            sp.is_finite() && sm.is_finite(),
            "σ not finite: σ({z})={sp}, σ(-{z})={sm}"
        );
        assert!(
            (sp + sm).abs() < EPS_MED,
            "σ not odd: σ({z})={sp}, σ(-{z})={sm}"
        );
    }

    #[test]
    fn test_weierstrass_sigma_zero_at_origin() {
        let v = weierstrass_sigma(1e-8, 1.0, 0.0);
        assert!(v.abs() < 1e-6, "σ near 0: {v}");
    }

    #[test]
    fn test_weierstrass_ode_lemniscate() {
        let residual = check_differential_equation(0.5, 4.0, 0.0);
        assert!(
            !residual.is_nan() && residual < 1.0,
            "ODE residual too large: {residual}"
        );
    }

    #[test]
    fn test_cubic_roots_sum_zero() {
        // Δ = 0 here (roots 1, −1/2, −1/2), so real roots must be returned.
        match cubic_roots_weierstrass(3.0, 1.0) {
            Ok((e1, e2, e3)) => {
                assert!(
                    (e1 + e2 + e3).abs() < 1e-12,
                    "roots sum = {}",
                    e1 + e2 + e3
                );
            }
            Err(_) => panic!("cubic_roots_weierstrass failed for g2 = 3, g3 = 1"),
        }
    }

    #[test]
    fn test_cubic_roots_three_real() {
        // 4t³ − 4t = 0 has roots 1, 0, −1.
        match cubic_roots_weierstrass(4.0, 0.0) {
            Ok((e1, e2, e3)) => {
                assert!((e1 - 1.0).abs() < 1e-12, "e1 = {e1}");
                assert!(e2.abs() < 1e-12, "e2 = {e2}");
                assert!((e3 + 1.0).abs() < 1e-12, "e3 = {e3}");
            }
            Err(_) => panic!("expected three real roots for g2 = 4, g3 = 0"),
        }
    }
}