use core::f64::consts::PI;
#[allow(unused_imports)]
use num_traits::Float as _;
use crate::constants::{
BATTIN_THRESHOLD, HYPERGEOMETRIC_2F1_MAX_TERMS, HYPERGEOMETRIC_2F1_TOL, LAGRANGE_THRESHOLD,
};
/// Normalized time of flight `T(x)` for Lambert's problem, selecting one of three
/// formulations from the distance `|x - 1|` to the parabolic point.
///
/// * `x`      – free iteration parameter.
/// * `y`      – auxiliary value; expected to be [`compute_y`]`(x, lambda)`.
/// * `lambda` – geometry parameter λ.
/// * `m`      – number of complete revolutions.
///
/// Branches, checked in order:
/// 1. `m == 0` and `|x - 1| <= BATTIN_THRESHOLD`: Battin's hypergeometric series. That
///    formula has no revolution term, so it is never used for `m > 0`.
/// 2. `|x - 1| <= LAGRANGE_THRESHOLD`: Lancaster's expression.
/// 3. otherwise (including a NaN distance): Lagrange's closed form.
///
/// NOTE(review): the constant named `LAGRANGE_THRESHOLD` bounds the *Lancaster* region here,
/// and Lagrange is used beyond it. Confirm this naming/ordering is intended.
pub(crate) fn x_to_tof_with_y(x: f64, y: f64, lambda: f64, m: u32) -> f64 {
    let dist_from_parabola = (x - 1.0).abs();

    if m == 0 && dist_from_parabola <= BATTIN_THRESHOLD {
        return x_to_tof_battin(x, y, lambda);
    }
    if dist_from_parabola <= LAGRANGE_THRESHOLD {
        return x_to_tof_lancaster(x, y, lambda, m);
    }
    x_to_tof_lagrange(x, lambda, m)
}
/// Auxiliary quantity `y = sqrt(1 - λ²·(1 - x²))` shared by the TOF formulas.
///
/// The radicand is clamped at zero before the square root. Slightly negative values from
/// round-off (or `|λ| > 1`) therefore give `0.0` rather than `NaN`.
#[inline]
pub(crate) fn compute_y(x: f64, lambda: f64) -> f64 {
    let one_minus_x_sq = 1.0 - x * x;
    let radicand = 1.0 - lambda * lambda * one_minus_x_sq;
    radicand.max(0.0).sqrt()
}
/// Lagrange's closed-form expression for the normalized time of flight.
///
/// Let `a = 1 / (1 - x²)` (in Izzo's formulation this plays the role of the normalized
/// semi-major axis; assumed).
///
/// * `a > 0` (`|x| < 1`, elliptic):
///   `T = ½·a·√a·[(α - sin α) - (β - sin β) + 2πM]`,
///   with `α = 2·acos x` and `β = 2·asin(√(λ²/a))`.
/// * otherwise (hyperbolic):
///   `T = -½·a·√(-a)·[(β - sinh β) - (α - sinh α)]`,
///   with `α = 2·acosh x` and `β = 2·asinh(√(-λ²/a))`. The revolution count `m` is not used.
///
/// In both branches `β` is negated when `λ < 0`.
fn x_to_tof_lagrange(x: f64, lambda: f64, m: u32) -> f64 {
    let a = 1.0 / (1.0 - x * x);
    let lambda_sq = lambda * lambda;
    // A plain `< 0.0` test, not copysign, so that λ = -0.0 keeps β non-negative.
    let with_lambda_sign = |angle: f64| if lambda < 0.0 { -angle } else { angle };

    if a > 0.0 {
        let alpha = 2.0 * x.acos();
        let beta = with_lambda_sign(2.0 * (lambda_sq / a).abs().sqrt().asin());
        let revolutions = 2.0 * PI * f64::from(m);
        0.5 * a * a.sqrt() * ((alpha - alpha.sin()) - (beta - beta.sin()) + revolutions)
    } else {
        let alpha = 2.0 * x.acosh();
        let beta = with_lambda_sign(2.0 * (-lambda_sq / a).sqrt().asinh());
        -0.5 * a * (-a).sqrt() * ((beta - beta.sinh()) - (alpha - alpha.sinh()))
    }
}
/// Lancaster's expression for the normalized time of flight:
/// `T = [(ψ + Mπ) / √|1 - x²| - x + λy] / (1 - x²)`, with `ψ` from `compute_psi`.
///
/// Because it divides by `1 - x²`, the result is non-finite at `x = ±1`.
fn x_to_tof_lancaster(x: f64, y: f64, lambda: f64, m: u32) -> f64 {
    let e = 1.0 - x * x;
    let psi_total = compute_psi(x, y, lambda, e) + f64::from(m) * PI;
    (psi_total / e.abs().sqrt() - x + lambda * y) / e
}
/// Auxiliary angle ψ of Lancaster's formulation.
///
/// `one_m_x2` must be the caller's precomputed `1 - x²`.
///
/// * `|x| < 1`: `ψ = atan2((y - xλ)·√(1 - x²), xy + λ(1 - x²))`. When `y = compute_y(x, λ)`
///   and `|λ| ≤ 1`, this equals `acos(xy + λ(1 - x²))`.
/// * otherwise: `ψ = asinh((y - xλ)·√(x² - 1))`.
fn compute_psi(x: f64, y: f64, lambda: f64, one_m_x2: f64) -> f64 {
    let lever = y - x * lambda;
    let inside_unit_interval = x.abs() < 1.0;
    if inside_unit_interval {
        let sin_part = lever * one_m_x2.sqrt();
        let cos_part = x * y + lambda * one_m_x2;
        sin_part.atan2(cos_part)
    } else {
        (lever * (-one_m_x2).sqrt()).asinh()
    }
}
/// Battin's series form of the normalized time of flight.
///
/// With `η = y - λx` and `S₁ = ½(1 - λ - xη)`:
/// `T = ½·(η³·Q + 4λη)`, where `Q = (4/3)·₂F₁(3, 1; 5/2; S₁)`.
///
/// There is no revolution term, so this is only valid for single-revolution (`M = 0`) cases.
fn x_to_tof_battin(x: f64, y: f64, lambda: f64) -> f64 {
    let eta = y - lambda * x;
    let s1 = 0.5 * (1.0 - lambda - x * eta);
    let q = (4.0 / 3.0) * hypergeometric_2f1_special(s1);
    let eta_cubed = eta * eta * eta;
    0.5 * (eta_cubed * q + 4.0 * lambda * eta)
}
/// Gauss hypergeometric function `₂F₁(3, 1; 5/2; z)` by direct summation of its power series.
///
/// The terms follow `c₀ = 1`, `c_{j+1} = c_j·(3 + j)(1 + j)/(5/2 + j)·z/(j + 1)`.
/// Summation stops as soon as the change in the partial sum is below
/// `HYPERGEOMETRIC_2F1_TOL`. Otherwise it stops after `HYPERGEOMETRIC_2F1_MAX_TERMS` terms
/// and returns the partial sum. The power series only converges for `|z| < 1`; outside that
/// range the returned value is an unconverged partial sum.
fn hypergeometric_2f1_special(z: f64) -> f64 {
    let mut sum = 1.0;
    let mut term = 1.0;
    for j in 0..HYPERGEOMETRIC_2F1_MAX_TERMS {
        let k = f64::from(j);
        term *= (3.0 + k) * (1.0 + k) / (2.5 + k) * z / (k + 1.0);
        let previous = sum;
        sum += term;
        if (sum - previous).abs() < HYPERGEOMETRIC_2F1_TOL {
            break;
        }
    }
    sum
}
/// First, second and third derivatives of the normalized time of flight with respect to `x`.
///
/// With `E = 1 - x²`:
/// * `T'   = (3Tx - 2 + 2λ³x/y) / E`
/// * `T''  = (3T + 5xT' + 2(1 - λ²)λ³/y³) / E`
/// * `T''' = (7xT'' + 8T' - 6(1 - λ²)λ⁵x/y⁵) / E`
///
/// `tof` must be `T(x)` for the same `x` and `lambda` (and revolution count).
/// `y` must be `compute_y(x, lambda)`. The revolution count enters only through `tof`.
/// The results are non-finite at `x = ±1` or `y = 0` because both appear as divisors.
///
/// Returns `(T', T'', T''')`.
#[allow(clippy::similar_names)] // λ-power locals deliberately share a stem
pub(crate) fn tof_derivatives_with_y(
    x: f64,
    y: f64,
    lambda: f64,
    tof: f64,
) -> (f64, f64, f64) {
    let denom = 1.0 - x * x;
    let lam_sq = lambda * lambda;
    let lam_cu = lam_sq * lambda;
    let lam_pow5 = lam_cu * lam_sq;
    let one_minus_lam_sq = 1.0 - lam_sq;

    let first = (3.0 * tof * x - 2.0 + 2.0 * lam_cu * x / y) / denom;
    let second =
        (3.0 * tof + 5.0 * x * first + 2.0 * one_minus_lam_sq * lam_cu / (y * y * y)) / denom;
    let third = (7.0 * x * second + 8.0 * first
        - 6.0 * one_minus_lam_sq * lam_pow5 * x / y.powi(5))
        / denom;

    (first, second, third)
}
// Unit tests: TOF branch dispatch and analytic-vs-finite-difference derivative checks.
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test: near the parabolic point (`x = 0.995`), a multi-revolution case (`M = 1`)
    /// must not be routed to the Battin series, whose formula carries no revolution term. The
    /// dispatcher's output is compared directly against the Lancaster expression.
    // NOTE(review): assumes |0.995 - 1| = 0.005 lies within LAGRANGE_THRESHOLD, so Lancaster is
    // the branch the dispatcher should pick here. Confirm against crate::constants.
    #[test]
    fn multi_rev_near_parabolic_keeps_revolution_term() {
        let x = 0.995;
        let lambda = 0.4;
        let m = 1;
        let y = compute_y(x, lambda);
        let got = x_to_tof_with_y(x, y, lambda, m);
        let expected = x_to_tof_lancaster(x, y, lambda, m);
        assert!(
            (got - expected).abs() < 1e-12,
            "multi-rev TOF dispatch dropped M term: got {got}, expected {expected}"
        );
    }

    /// Cross-checks `tof_derivatives_with_y` against central finite differences of
    /// `x_to_tof_with_y`. Cases cover single- and multi-revolution (M = 0..=3), negative `x`,
    /// and negative `lambda`.
    #[test]
    #[allow(clippy::similar_names)] // dt/ddt/dddt and t_p1/t_m1/... deliberately share stems
    fn analytic_derivatives_match_finite_differences() {
        // (x, lambda, M) sample points.
        let cases = [
            (0.5_f64, 0.3_f64, 0_u32),
            (-0.3, 0.5, 0),
            (0.85, 0.2, 0),
            (0.5, 0.3, 2),
            (-0.5, -0.4, 1),
            (0.3, 0.7, 3),
        ];
        // Finite-difference step; the tolerances asserted below assume this step size.
        let h = 1e-3;
        // TOF evaluated through the real dispatcher, so FD samples may land in different
        // formula branches.
        let tof_at = |x: f64, lambda: f64, m: u32| {
            x_to_tof_with_y(x, compute_y(x, lambda), lambda, m)
        };
        for (x, lambda, m) in cases {
            let t = tof_at(x, lambda, m);
            let y = compute_y(x, lambda);
            let (dt_a, ddt_a, dddt_a) = tof_derivatives_with_y(x, y, lambda, t);
            let t_p1 = tof_at(x + h, lambda, m);
            let t_m1 = tof_at(x - h, lambda, m);
            let t_p2 = tof_at(x + 2.0 * h, lambda, m);
            let t_m2 = tof_at(x - 2.0 * h, lambda, m);
            // Central-difference stencils (second-order accurate) for the first three derivatives.
            let dt_fd = (t_p1 - t_m1) / (2.0 * h);
            let ddt_fd = (t_p1 - 2.0 * t + t_m1) / (h * h);
            let dddt_fd = (t_p2 - 2.0 * t_p1 + 2.0 * t_m1 - t_m2) / (2.0 * h * h * h);
            // Relative error, falling back to absolute error when |analytic| < 1.
            let rel_err = |a: f64, b: f64| (a - b).abs() / a.abs().max(1.0);
            // Tolerances loosen with derivative order because higher-order stencils amplify
            // truncation and round-off error.
            assert!(
                rel_err(dt_a, dt_fd) < 1e-5,
                "dT/dx mismatch at (x={x}, λ={lambda}, M={m}): \
                 analytic = {dt_a}, FD = {dt_fd}",
            );
            assert!(
                rel_err(ddt_a, ddt_fd) < 1e-3,
                "d²T/dx² mismatch at (x={x}, λ={lambda}, M={m}): \
                 analytic = {ddt_a}, FD = {ddt_fd}",
            );
            assert!(
                rel_err(dddt_a, dddt_fd) < 1e-1,
                "d³T/dx³ mismatch at (x={x}, λ={lambda}, M={m}): \
                 analytic = {dddt_a}, FD = {dddt_fd}",
            );
        }
    }
}