use crate::error::{SpecialError, SpecialResult};
/// Truncation threshold for the theta series: summation stops at the first term whose
/// q-power factor falls below this value.
const THETA_TOL: f64 = 1e-15;
/// Hard cap on the number of series terms summed by each theta evaluation.
///
/// NOTE(review): for nomes very close to 1 (roughly q > 0.9998) this cap is reached
/// before `THETA_TOL` and the sum is truncated without any error being reported.
const THETA_MAX_TERMS: usize = 500;
/// Checks that the nome lies in the half-open interval `[0, 1)`.
///
/// `NaN` fails the range test and is therefore rejected as well.
///
/// # Errors
/// Returns `SpecialError::DomainError` naming the offending `q` when it is out of range.
#[inline]
fn validate_nome(q: f64) -> SpecialResult<()> {
    if (0.0..1.0).contains(&q) {
        Ok(())
    } else {
        Err(SpecialError::DomainError(format!(
            "theta: nome q must satisfy 0 ≤ q < 1, got q = {q}"
        )))
    }
}
/// Jacobi theta function θ₁(z, q) for real argument `z` and nome `q`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta1(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta1_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta1`]: sums
/// `θ₁(z, q) = 2 Σ_{n≥0} (-1)ⁿ q^{(n+1/2)²} sin((2n+1) z)`.
///
/// Each power is split as `q^{(n+1/2)²} = q^{1/4} · q^{n(n+1)}`. The common factor
/// `q^{1/4}` is applied once at the end, so the `THETA_TOL` cut-off is relative to the
/// leading term. Very small nomes (where `q^{1/4}` itself is below `THETA_TOL`) then still
/// produce their small, nonzero value instead of 0.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta1_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q == 0.0 {
        // θ₁ vanishes identically at q = 0.
        return Ok(0.0);
    }
    let mut sum = 0.0_f64;
    for n in 0..THETA_MAX_TERMS {
        let nf = n as f64;
        // Relative weight of term n; equals 1 for n = 0, so the leading term is always kept.
        let rel_pow = q.powf(nf * (nf + 1.0));
        if rel_pow < THETA_TOL {
            break;
        }
        let sign = if n % 2 == 0 { 1.0_f64 } else { -1.0_f64 };
        sum += sign * rel_pow * ((2.0 * nf + 1.0) * z).sin();
    }
    Ok(2.0 * q.powf(0.25) * sum)
}
/// Jacobi theta function θ₂(z, q) for real argument `z` and nome `q`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta2(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta2_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta2`]: sums
/// `θ₂(z, q) = 2 Σ_{n≥0} q^{(n+1/2)²} cos((2n+1) z)`.
///
/// Each power is split as `q^{(n+1/2)²} = q^{1/4} · q^{n(n+1)}`. The common factor
/// `q^{1/4}` is applied once at the end, so the `THETA_TOL` cut-off is relative to the
/// leading term. Very small nomes therefore keep their small, nonzero value instead of
/// collapsing to 0.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta2_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q == 0.0 {
        // θ₂ vanishes identically at q = 0.
        return Ok(0.0);
    }
    let mut sum = 0.0_f64;
    for n in 0..THETA_MAX_TERMS {
        let nf = n as f64;
        // Relative weight of term n; equals 1 for n = 0, so the leading term is always kept.
        let rel_pow = q.powf(nf * (nf + 1.0));
        if rel_pow < THETA_TOL {
            break;
        }
        sum += rel_pow * ((2.0 * nf + 1.0) * z).cos();
    }
    Ok(2.0 * q.powf(0.25) * sum)
}
/// Jacobi theta function θ₃(z, q) for real argument `z` and nome `q`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta3(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta3_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta3`]: evaluates `θ₃(z, q) = 1 + 2 Σ_{n≥1} q^{n²} cos(2 n z)`.
///
/// For `q < 1e-300` the value is taken to be exactly `1`. Otherwise terms are added in
/// order until the first `q^{n²}` below `THETA_TOL`, or until `THETA_MAX_TERMS` terms.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta3_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q < 1e-300 {
        return Ok(1.0);
    }
    let total = (1..=THETA_MAX_TERMS)
        .map(|n| {
            let m = n as f64;
            (m, q.powf(m * m))
        })
        .take_while(|&(_, weight)| weight >= THETA_TOL)
        // Folding from 1.0 keeps the same left-to-right accumulation order.
        .fold(1.0_f64, |acc, (m, weight)| {
            acc + 2.0 * weight * (2.0 * m * z).cos()
        });
    Ok(total)
}
/// Jacobi theta function θ₄(z, q) for real argument `z` and nome `q`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta4(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta4_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta4`]: evaluates
/// `θ₄(z, q) = 1 + 2 Σ_{n≥1} (-1)ⁿ q^{n²} cos(2 n z)`.
///
/// For `q < 1e-300` the value is taken to be exactly `1`. Otherwise terms are added in
/// order until the first `q^{n²}` below `THETA_TOL`, or until `THETA_MAX_TERMS` terms.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta4_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q < 1e-300 {
        return Ok(1.0);
    }
    let total = (1..=THETA_MAX_TERMS)
        .map(|n| {
            let m = n as f64;
            (n, m, q.powf(m * m))
        })
        .take_while(|&(_, _, weight)| weight >= THETA_TOL)
        // Folding from 1.0 keeps the same left-to-right accumulation order.
        .fold(1.0_f64, |acc, (n, m, weight)| {
            let sign = if n % 2 == 0 { 1.0_f64 } else { -1.0_f64 };
            acc + 2.0 * sign * weight * (2.0 * m * z).cos()
        });
    Ok(total)
}
/// Derivative ∂θ₁/∂z at `(z, q)`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta1_derivative(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta1_derivative_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta1_derivative`]: sums
/// `θ₁'(z, q) = 2 Σ_{n≥0} (-1)ⁿ (2n+1) q^{(n+1/2)²} cos((2n+1) z)`.
///
/// The common factor `q^{1/4}` is applied once at the end. Truncation therefore tests
/// the relative weight `q^{n(n+1)}`, and tiny nomes no longer lose every term.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta1_derivative_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q == 0.0 {
        return Ok(0.0);
    }
    let mut sum = 0.0_f64;
    for n in 0..THETA_MAX_TERMS {
        let nf = n as f64;
        // q^{(n+1/2)²} = q^{1/4} · q^{n(n+1)}; only the second factor decides truncation.
        let rel_pow = q.powf(nf * (nf + 1.0));
        if rel_pow < THETA_TOL {
            break;
        }
        let sign = if n % 2 == 0 { 1.0_f64 } else { -1.0_f64 };
        let freq = 2.0 * nf + 1.0;
        sum += sign * freq * rel_pow * (freq * z).cos();
    }
    Ok(2.0 * q.powf(0.25) * sum)
}
/// Derivative ∂θ₂/∂z at `(z, q)`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta2_derivative(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta2_derivative_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta2_derivative`]: sums
/// `θ₂'(z, q) = -2 Σ_{n≥0} (2n+1) q^{(n+1/2)²} sin((2n+1) z)`.
///
/// The common factor `q^{1/4}` is applied once at the end. Truncation therefore tests
/// the relative weight `q^{n(n+1)}`, and tiny nomes no longer lose every term.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta2_derivative_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q == 0.0 {
        return Ok(0.0);
    }
    let mut sum = 0.0_f64;
    for n in 0..THETA_MAX_TERMS {
        let nf = n as f64;
        // q^{(n+1/2)²} = q^{1/4} · q^{n(n+1)}; only the second factor decides truncation.
        let rel_pow = q.powf(nf * (nf + 1.0));
        if rel_pow < THETA_TOL {
            break;
        }
        let freq = 2.0 * nf + 1.0;
        sum += freq * rel_pow * (freq * z).sin();
    }
    Ok(-2.0 * q.powf(0.25) * sum)
}
/// Derivative ∂θ₃/∂z at `(z, q)`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta3_derivative(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta3_derivative_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta3_derivative`]: sums
/// `θ₃'(z, q) = -4 Σ_{n≥1} n q^{n²} sin(2 n z)`.
///
/// The leading term is of size `q`, so testing the raw power `q^{n²}` against
/// `THETA_TOL` discarded the whole series for `q < 1e-15`. The factor `q` is now applied
/// once at the end, and truncation tests the relative weight `q^{n²-1}`.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta3_derivative_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q == 0.0 {
        return Ok(0.0);
    }
    let mut sum = 0.0_f64;
    for n in 1..=THETA_MAX_TERMS {
        let nf = n as f64;
        // q^{n²} = q · q^{n²-1}; the relative weight is 1 for n = 1.
        let rel_pow = q.powf(nf * nf - 1.0);
        if rel_pow < THETA_TOL {
            break;
        }
        sum += nf * rel_pow * (2.0 * nf * z).sin();
    }
    Ok(-4.0 * q * sum)
}
/// Derivative ∂θ₄/∂z at `(z, q)`.
///
/// Returns `NaN` when `q` lies outside `[0, 1)`.
pub fn theta4_derivative(z: f64, q: f64) -> f64 {
    Result::unwrap_or(theta4_derivative_impl(z, q), f64::NAN)
}
/// Fallible core of [`theta4_derivative`]: sums
/// `θ₄'(z, q) = -4 Σ_{n≥1} (-1)ⁿ n q^{n²} sin(2 n z)`.
///
/// The leading term is of size `q`, so testing the raw power `q^{n²}` against
/// `THETA_TOL` discarded the whole series for `q < 1e-15`. The factor `q` is now applied
/// once at the end, and truncation tests the relative weight `q^{n²-1}`.
///
/// # Errors
/// Propagates the domain error from `validate_nome` for `q` outside `[0, 1)`.
fn theta4_derivative_impl(z: f64, q: f64) -> SpecialResult<f64> {
    validate_nome(q)?;
    if q == 0.0 {
        return Ok(0.0);
    }
    let mut sum = 0.0_f64;
    for n in 1..=THETA_MAX_TERMS {
        let nf = n as f64;
        // q^{n²} = q · q^{n²-1}; the relative weight is 1 for n = 1.
        let rel_pow = q.powf(nf * nf - 1.0);
        if rel_pow < THETA_TOL {
            break;
        }
        let sign = if n % 2 == 0 { 1.0_f64 } else { -1.0_f64 };
        sum += sign * nf * rel_pow * (2.0 * nf * z).sin();
    }
    Ok(-4.0 * q * sum)
}
/// Nome `q = exp(-π K'(k) / K(k))` corresponding to the elliptic modulus `k`.
///
/// Returns `NaN` for `k` outside `[0, 1)` and `0.0` for `k == 0`.
pub fn q_from_k(k: f64) -> f64 {
    if !(0.0..1.0).contains(&k) {
        return f64::NAN;
    }
    if k == 0.0 {
        return 0.0;
    }
    let big_k = agm_elliptic_k(k);
    // K'(k) = K(k') = π / (2·AGM(1, k)), because the complementary modulus of k' is k.
    // Evaluating it from k directly avoids forming k' = sqrt(1 - k²) and then recovering
    // sqrt(1 - k'²) inside `agm_elliptic_k`. That round trip rounds k' to exactly 1.0
    // (making K' infinite and q zero) once k drops below roughly 1e-8, and loses digits
    // well before that.
    let big_k_prime = std::f64::consts::FRAC_PI_2 / agm(1.0, k);
    (-std::f64::consts::PI * big_k_prime / big_k).exp()
}

/// Elliptic modulus `k = θ₂(0, q)² / θ₃(0, q)²` for the nome `q`.
///
/// Returns `NaN` for `q` outside `[0, 1)` and `0.0` for `q == 0`.
pub fn k_from_q(q: f64) -> f64 {
    if !(0.0..1.0).contains(&q) {
        return f64::NAN;
    }
    if q == 0.0 {
        return 0.0;
    }
    let t2 = theta2(0.0, q);
    let t3 = theta3(0.0, q);
    // Defensive: θ₃(0, q) ≥ 1 for valid q, so this guard is not expected to trigger.
    if t3.abs() < 1e-300 {
        return f64::NAN;
    }
    (t2 / t3) * (t2 / t3)
}

/// Complete elliptic integral of the first kind, `K(k) = π / (2·AGM(1, sqrt(1 - k²)))`.
///
/// Returns `π/2` for `k <= 0` and `+∞` for `k >= 1`.
fn agm_elliptic_k(k: f64) -> f64 {
    if k <= 0.0 {
        return std::f64::consts::FRAC_PI_2;
    }
    if k >= 1.0 {
        return f64::INFINITY;
    }
    // (1 - k)(1 + k) avoids the cancellation in 1 - k*k when k is close to 1.
    let b = ((1.0 - k) * (1.0 + k)).sqrt();
    std::f64::consts::FRAC_PI_2 / agm(1.0, b)
}

/// Arithmetic–geometric mean of `a` and `b`.
///
/// Iterates until the relative gap between the two means is below `1e-15`, or for at
/// most 100 steps, after which the current arithmetic mean is returned.
fn agm(mut a: f64, mut b: f64) -> f64 {
    for _ in 0..100 {
        let a_new = (a + b) * 0.5;
        let b_new = (a * b).sqrt();
        if (a_new - b_new).abs() < 1e-15 * a_new.abs() {
            return a_new;
        }
        a = a_new;
        b = b_new;
    }
    (a + b) * 0.5
}
/// Logarithmic derivative `θ₁'(z, q) / θ₁(z, q)`.
///
/// Returns `NaN` wherever `|θ₁(z, q)| < 1e-300` (for example at `z = 0`). An invalid
/// nome also yields `NaN`, propagated from the underlying theta evaluations.
pub fn theta1_log_derivative(z: f64, q: f64) -> f64 {
    let denominator = theta1(z, q);
    if denominator.abs() < 1e-300 {
        return f64::NAN;
    }
    let numerator = theta1_derivative(z, q);
    numerator / denominator
}
/// Complementary nome `q' = exp(-π K / K')` of the nome `q = exp(-π K' / K)`.
///
/// Since `ln q = -π K'/K`, we have `K/K' = -π / ln q`, which gives the exact closed form
/// `q' = exp(π² / ln q)`. This replaces a round trip through the modulus `k` and its
/// complement `k' = sqrt(1 - k²)`. For small `q` that round trip rounds `k'` to exactly
/// `1.0` and returned `NaN`.
///
/// The map is an involution: `complementary_nome(complementary_nome(q)) == q` up to
/// rounding.
///
/// Returns `NaN` for `q` outside `[0, 1)` and the limiting value `1.0` for `q == 0`.
pub fn complementary_nome(q: f64) -> f64 {
    if !(0.0..1.0).contains(&q) {
        return f64::NAN;
    }
    if q == 0.0 {
        return 1.0;
    }
    let pi = std::f64::consts::PI;
    // ln q < 0 on (0, 1), so the exponent is negative and the result lies in (0, 1).
    (pi * pi / q.ln()).exp()
}
/// Jacobi elliptic function `sn(u, k)`, evaluated through theta-function quotients.
///
/// Uses `sn(u, k) = (θ₃(0)/θ₂(0)) · θ₁(ζ)/θ₄(ζ)` with `K = (π/2) θ₃(0)²` and
/// `ζ = π u / (2K)`, all at the nome `q = q_from_k(k)`.
///
/// Returns `NaN` for `k` outside `[0, 1)` or if a denominator degenerates.
/// For `k == 0`, and when `θ₂(0, q)` is numerically zero (the nome has underflowed for
/// vanishingly small `k`), the circular limit `sin(u)` is returned.
pub fn jacobi_sn(u: f64, k: f64) -> f64 {
    if !(0.0..1.0).contains(&k) {
        return f64::NAN;
    }
    if k == 0.0 {
        return u.sin();
    }
    let q = q_from_k(k);
    let t2_0 = theta2(0.0, q);
    if t2_0.abs() < 1e-300 {
        // θ₂(0, q) ≥ 2 q^{1/4}, so it is only zero here when q is (numerically) zero.
        // That happens only for k so small that the O(k²·u) correction to sin(u) is below
        // double precision. Previously the 0/0 quotient was reported as NaN.
        return u.sin();
    }
    let t3_0 = theta3(0.0, q);
    if t3_0.abs() < 1e-300 {
        return f64::NAN;
    }
    let big_k = std::f64::consts::FRAC_PI_2 * t3_0 * t3_0;
    let z = std::f64::consts::FRAC_PI_2 * u / big_k;
    let t1_z = theta1(z, q);
    let t4_z = theta4(z, q);
    if t4_z.abs() < 1e-300 {
        return f64::NAN;
    }
    (t3_0 / t2_0) * (t1_z / t4_z)
}
/// Jacobi elliptic function `cn(u, k)`, evaluated through theta-function quotients.
///
/// Uses `cn(u, k) = (θ₄(0)/θ₂(0)) · θ₂(ζ)/θ₄(ζ)` with `K = (π/2) θ₃(0)²` and
/// `ζ = π u / (2K)`, all at the nome `q = q_from_k(k)`.
///
/// Returns `NaN` for `k` outside `[0, 1)` or if a denominator degenerates.
/// For `k == 0`, and when `θ₂(0, q)` is numerically zero (the nome has underflowed for
/// vanishingly small `k`), the circular limit `cos(u)` is returned.
pub fn jacobi_cn(u: f64, k: f64) -> f64 {
    if !(0.0..1.0).contains(&k) {
        return f64::NAN;
    }
    if k == 0.0 {
        return u.cos();
    }
    let q = q_from_k(k);
    let t2_0 = theta2(0.0, q);
    if t2_0.abs() < 1e-300 {
        // θ₂(0, q) ≥ 2 q^{1/4}, so it is only zero here when q is (numerically) zero.
        // That happens only for k so small that the O(k²·u) correction to cos(u) is below
        // double precision. Previously the 0/0 quotient was reported as NaN.
        return u.cos();
    }
    let t3_0 = theta3(0.0, q);
    if t3_0.abs() < 1e-300 {
        return f64::NAN;
    }
    let big_k = std::f64::consts::FRAC_PI_2 * t3_0 * t3_0;
    let z = std::f64::consts::FRAC_PI_2 * u / big_k;
    let t2_z = theta2(z, q);
    let t4_z = theta4(z, q);
    let t4_0 = theta4(0.0, q);
    if t4_z.abs() < 1e-300 {
        return f64::NAN;
    }
    (t4_0 / t2_0) * (t2_z / t4_z)
}
/// Jacobi elliptic function `dn(u, k)`, evaluated through theta-function quotients.
///
/// Uses `dn(u, k) = (θ₄(0)/θ₃(0)) · θ₃(ζ)/θ₄(ζ)` with `K = (π/2) θ₃(0)²` and
/// `ζ = π u / (2K)`, all at the nome `q = q_from_k(k)`.
///
/// Returns `1.0` for `k == 0`, and `NaN` for `k` outside `[0, 1)` or if a denominator
/// degenerates.
pub fn jacobi_dn(u: f64, k: f64) -> f64 {
    use std::f64::consts::FRAC_PI_2;

    if !(0.0..1.0).contains(&k) {
        return f64::NAN;
    }
    if k == 0.0 {
        return 1.0;
    }
    let nome = q_from_k(k);
    let theta3_at_0 = theta3(0.0, nome);
    // Checked once; the original repeated this identical test after the evaluations below.
    if theta3_at_0.abs() < 1e-300 {
        return f64::NAN;
    }
    let quarter_period = FRAC_PI_2 * theta3_at_0 * theta3_at_0;
    let zeta = FRAC_PI_2 * u / quarter_period;
    let theta3_at_zeta = theta3(zeta, nome);
    let theta4_at_zeta = theta4(zeta, nome);
    let theta4_at_0 = theta4(0.0, nome);
    if theta4_at_zeta.abs() < 1e-300 {
        return f64::NAN;
    }
    (theta4_at_0 / theta3_at_0) * (theta3_at_zeta / theta4_at_zeta)
}
// Unit tests: symmetry, special values and classical identities of the theta and
// Jacobi elliptic functions defined above.
#[cfg(test)]
mod tests {
    use super::*;
    const EPS: f64 = 1e-12;
    #[test]
    fn test_theta1_zero_z() {
        for &q in &[0.0, 0.1, 0.3, 0.5, 0.7] {
            assert!(
                theta1(0.0, q).abs() < EPS,
                "theta1(0, {q}) should be 0, got {}",
                theta1(0.0, q)
            );
        }
    }
    #[test]
    fn test_theta1_odd_symmetry() {
        let q = 0.3;
        for &z in &[0.2, 0.5, 1.0, 1.5] {
            let pos = theta1(z, q);
            let neg = theta1(-z, q);
            assert!(
                (pos + neg).abs() < EPS,
                "theta1 odd symmetry failed at z={z}: {pos} + {neg} = {}",
                pos + neg
            );
        }
    }
    #[test]
    fn test_theta2_even_symmetry() {
        let q = 0.2;
        for &z in &[0.1, 0.4, 0.8, 1.2] {
            let pos = theta2(z, q);
            let neg = theta2(-z, q);
            assert!(
                (pos - neg).abs() < EPS,
                "theta2 even symmetry failed at z={z}"
            );
        }
    }
    #[test]
    fn test_theta2_positive_at_zero() {
        for &q in &[0.05, 0.2, 0.5] {
            assert!(theta2(0.0, q) > 0.0, "theta2(0, {q}) should be positive");
        }
    }
    #[test]
    fn test_theta3_q0() {
        assert!((theta3(0.0, 0.0) - 1.0).abs() < EPS);
    }
    #[test]
    fn test_theta3_even_symmetry() {
        let q = 0.15;
        for &z in &[0.3, 0.7, 1.1] {
            assert!(
                (theta3(z, q) - theta3(-z, q)).abs() < EPS,
                "theta3 even symmetry failed at z={z}"
            );
        }
    }
    #[test]
    fn test_theta4_q0() {
        assert!((theta4(0.0, 0.0) - 1.0).abs() < EPS);
    }
    #[test]
    fn test_theta1_derivative_z0() {
        let d = theta1_derivative(0.0, 0.1);
        assert!(d > 0.0);
    }
    #[test]
    fn test_theta2_derivative_z0() {
        assert!(theta2_derivative(0.0, 0.3).abs() < EPS);
    }
    #[test]
    fn test_theta3_derivative_z0() {
        assert!(theta3_derivative(0.0, 0.2).abs() < EPS);
    }
    #[test]
    fn test_q_from_k_zero() {
        assert!((q_from_k(0.0) - 0.0).abs() < EPS);
    }
    #[test]
    fn test_k_from_q_zero() {
        assert!((k_from_q(0.0) - 0.0).abs() < EPS);
    }
    #[test]
    fn test_k_q_roundtrip() {
        for &k in &[0.1, 0.3, 0.5, 0.7, 0.9] {
            let q = q_from_k(k);
            let k2 = k_from_q(q);
            assert!(
                (k2 - k).abs() < 1e-10,
                "round-trip failed for k={k}: got {k2}"
            );
        }
    }
    #[test]
    fn test_jacobi_sn_zero() {
        assert!(jacobi_sn(0.0, 0.5).abs() < EPS);
    }
    #[test]
    fn test_jacobi_cn_zero() {
        assert!((jacobi_cn(0.0, 0.5) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_jacobi_dn_zero() {
        assert!((jacobi_dn(0.0, 0.7) - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_jacobi_identity_sn2_cn2() {
        let k = 0.6_f64;
        for &u in &[0.3, 0.8, 1.2] {
            let sn = jacobi_sn(u, k);
            let cn = jacobi_cn(u, k);
            assert!(
                (sn * sn + cn * cn - 1.0).abs() < 1e-10,
                "sn² + cn² ≠ 1 at u={u}, k={k}: sn={sn}, cn={cn}"
            );
        }
    }
    #[test]
    fn test_invalid_q() {
        assert!(theta1(0.0, -0.1).is_nan());
        assert!(theta1(0.0, 1.0).is_nan());
        assert!(theta2(0.0, 1.5).is_nan());
    }
    #[test]
    fn test_theta3_known_value() {
        // θ₃(0, e^{-π}) = π^{1/4} / Γ(3/4).
        let q = (-std::f64::consts::PI).exp();
        let expected = 1.086_434_811_213_308_f64;
        let got = theta3(0.0, q);
        assert!(
            (got - expected).abs() < EPS,
            "theta3(0, e^-π) = {got}, expected {expected}"
        );
    }
    #[test]
    fn test_theta_quartic_identity() {
        // Jacobi's identity θ₃(0)⁴ = θ₂(0)⁴ + θ₄(0)⁴.
        for &q in &[0.05, 0.3, 0.6] {
            let lhs = theta3(0.0, q).powi(4);
            let rhs = theta2(0.0, q).powi(4) + theta4(0.0, q).powi(4);
            assert!(
                (lhs - rhs).abs() < 1e-10 * lhs,
                "θ3⁴ ≠ θ2⁴ + θ4⁴ at q={q}: {lhs} vs {rhs}"
            );
        }
    }
    #[test]
    fn test_jacobi_identity_dn2_k2sn2() {
        let k = 0.8_f64;
        for &u in &[0.25, 0.9, 1.7] {
            let sn = jacobi_sn(u, k);
            let dn = jacobi_dn(u, k);
            assert!(
                (dn * dn + k * k * sn * sn - 1.0).abs() < 1e-10,
                "dn² + k²sn² ≠ 1 at u={u}, k={k}: sn={sn}, dn={dn}"
            );
        }
    }
    #[test]
    fn test_complementary_nome_self_dual() {
        // k = 1/√2 has K = K', so its nome e^{-π} is its own complement.
        let q = (-std::f64::consts::PI).exp();
        assert!((complementary_nome(q) - q).abs() < 1e-10);
    }
}