/// Returns `(sin x, cos x)` for an angle `x` in radians.
///
/// Thin, always-inlined wrapper over the standard-library sine and cosine.
#[inline(always)]
pub fn sin_cos(x: f64) -> (f64, f64) {
    let sine = x.sin();
    let cosine = x.cos();
    (sine, cosine)
}
/// Square root of `x`.
///
/// Despite the name, no approximation is performed: this is the standard
/// library's correctly rounded `f64::sqrt`, so negative inputs yield NaN.
#[inline(always)]
pub fn sqrt_fast(x: f64) -> f64 {
    f64::sqrt(x)
}
/// Fused multiply-add: computes `a * b + c` with a single rounding step.
#[inline(always)]
pub fn fma(a: f64, b: f64, c: f64) -> f64 {
    f64::mul_add(a, b, c)
}
#[inline]
pub fn stumpff_c(z: f64) -> f64 {
const TOL: f64 = 1e-6;
if z > TOL {
let sqrt_z = sqrt_fast(z);
(1.0 - sqrt_z.cos()) / z
} else if z < -TOL {
let sqrt_neg_z = sqrt_fast(-z);
(1.0 - sqrt_neg_z.cosh()) / z
} else {
0.5 - z / 24.0 + z * z / 720.0
}
}
/// Stumpff function c₃(z), as used in universal-variable Kepler propagation.
///
/// * `z > 0`: (√z − sin √z) / (√z)³
/// * `z < 0`: (sinh √−z − √−z) / (√−z)³
/// * `|z| ≤ 1e-6`: truncated Maclaurin series 1/6 − z/120 + z²/5040
///
/// c₃(z) > 0 for every real `z`. The closed-form numerators cancel for small
/// √|z|, so roughly nine significant digits survive just outside the series
/// band.
#[inline]
pub fn stumpff_s(z: f64) -> f64 {
    const TOL: f64 = 1e-6;
    if z > TOL {
        let sqrt_z = sqrt_fast(z);
        (sqrt_z - sqrt_z.sin()) / (z * sqrt_z)
    } else if z < -TOL {
        let sqrt_neg_z = sqrt_fast(-z);
        // (√−z)³ = (−z)·√−z is positive; dividing by z·√−z would flip the sign
        // and make c₃ jump from ≈ +1/6 to ≈ −1/6 across the series boundary.
        (sqrt_neg_z.sinh() - sqrt_neg_z) / (-z * sqrt_neg_z)
    } else {
        1.0 / 6.0 - z / 120.0 + z * z / 5040.0
    }
}
#[inline]
pub fn stumpff_cs(z: f64) -> (f64, f64) {
const TOL: f64 = 1e-6;
if z > TOL {
let sqrt_z = sqrt_fast(z);
let (sin_sqrt_z, cos_sqrt_z) = sin_cos(sqrt_z);
let c2 = (1.0 - cos_sqrt_z) / z;
let c3 = (sqrt_z - sin_sqrt_z) / (z * sqrt_z);
(c2, c3)
} else if z < -TOL {
let sqrt_neg_z = sqrt_fast(-z);
let sinh_val = sqrt_neg_z.sinh();
let cosh_val = sqrt_neg_z.cosh();
let c2 = (1.0 - cosh_val) / z;
let c3 = (sinh_val - sqrt_neg_z) / (z * sqrt_neg_z);
(c2, c3)
} else {
let c2 = 0.5 - z / 24.0 + z * z / 720.0;
let c3 = 1.0 / 6.0 - z / 120.0 + z * z / 5040.0;
(c2, c3)
}
}
/// Derivatives `(dc₂/dz, dc₃/dz)` of the Stumpff functions.
///
/// `c2` and `c3` must be c₂(z) and c₃(z) evaluated at the same `z`. They are
/// only read outside the series band; for |z| < 1e-2 they are ignored.
///
/// Closed forms (from differentiating the definitions with s = √z):
/// * dc₂/dz = (1 − z·c₃ − 2·c₂) / (2z)
/// * dc₃/dz = (c₂ − 3·c₃) / (2z)
///
/// Both numerators are O(z) differences of O(1) quantities, so the closed
/// forms lose about log₁₀(1/|z|) digits as z → 0. Inside |z| < 1e-2 the
/// Maclaurin series through z⁴ is used instead; its truncation error there is
/// below 1e-20.
#[inline]
pub fn stumpff_derivatives(z: f64, c2: f64, c3: f64) -> (f64, f64) {
    const SERIES_TOL: f64 = 1e-2;
    // Coefficients of z^(k−1), k = 1..=5, in ascending order:
    //   dc₂/dz: k·(−1)^k / (2k + 2)!
    //   dc₃/dz: k·(−1)^k / (2k + 3)!
    const C2_PRIME: [f64; 5] = [
        -1.0 / 24.0,
        1.0 / 360.0,
        -1.0 / 13_440.0,
        1.0 / 907_200.0,
        -1.0 / 95_800_320.0,
    ];
    const C3_PRIME: [f64; 5] = [
        -1.0 / 120.0,
        1.0 / 2_520.0,
        -1.0 / 120_960.0,
        1.0 / 9_979_200.0,
        -1.0 / 1_245_404_160.0,
    ];

    /// Evaluates the polynomial with ascending `coeffs` at `z` (Horner's rule).
    fn horner(coeffs: &[f64], z: f64) -> f64 {
        coeffs.iter().rev().fold(0.0, |acc, &c| acc * z + c)
    }

    if z.abs() < SERIES_TOL {
        (horner(&C2_PRIME, z), horner(&C3_PRIME, z))
    } else {
        let two_z = 2.0 * z;
        let c2_prime = (1.0 - z * c3 - 2.0 * c2) / two_z;
        let c3_prime = (c2 - 3.0 * c3) / two_z;
        (c2_prime, c3_prime)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    /// Maclaurin series of c₂ through z³ (next term is z⁴/10!).
    fn c2_series(z: f64) -> f64 {
        0.5 - z / 24.0 + z * z / 720.0 - z * z * z / 40_320.0
    }

    /// Maclaurin series of c₃ through z³ (next term is z⁴/11!).
    fn c3_series(z: f64) -> f64 {
        1.0 / 6.0 - z / 120.0 + z * z / 5_040.0 - z * z * z / 362_880.0
    }

    #[test]
    fn test_sin_cos_accuracy() {
        let test_values: [f64; 7] = [0.0, PI / 6.0, PI / 4.0, PI / 3.0, PI / 2.0, PI, 2.0 * PI];
        for x in test_values {
            let (s, c) = sin_cos(x);
            assert_relative_eq!(s, x.sin(), epsilon = 1e-15);
            assert_relative_eq!(c, x.cos(), epsilon = 1e-15);
        }
    }

    #[test]
    fn test_sqrt_fast_accuracy() {
        let test_values: [f64; 8] = [0.0, 1.0, 2.0, 4.0, 9.0, 16.0, 100.0, 1e6];
        for x in test_values {
            assert_relative_eq!(sqrt_fast(x), x.sqrt(), epsilon = 1e-15);
        }
    }

    #[test]
    fn test_fma_correctness() {
        assert_eq!(fma(2.0, 3.0, 4.0), 10.0);
        assert_eq!(fma(0.0, 5.0, 7.0), 7.0);
        assert_eq!(fma(1.0, 1.0, 1.0), 2.0);
    }

    #[test]
    fn test_fma_single_rounding() {
        // (1 + ε)² = 1 + 2ε + ε². The unfused product rounds the ε² term away,
        // so only a fused multiply-add recovers it after subtracting 1 + 2ε.
        let a = 1.0 + f64::EPSILON;
        let c = -(1.0 + 2.0 * f64::EPSILON);
        assert_eq!(fma(a, a, c), f64::EPSILON * f64::EPSILON);
        assert_eq!(a * a + c, 0.0);
    }

    #[test]
    fn test_stumpff_cs_elliptic() {
        let z: f64 = 1.0;
        let (c2, c3) = stumpff_cs(z);
        let sqrt_z = z.sqrt();
        assert_relative_eq!(c2, (1.0 - sqrt_z.cos()) / z, epsilon = 1e-14);
        assert_relative_eq!(c3, (sqrt_z - sqrt_z.sin()) / (z * sqrt_z), epsilon = 1e-14);
    }

    #[test]
    fn test_stumpff_cs_hyperbolic() {
        let z: f64 = -1.0;
        let (c2, c3) = stumpff_cs(z);
        let y = (-z).sqrt();
        // c₂ = (cosh y − 1)/y² and c₃ = (sinh y − y)/y³ with y = √−z; both positive.
        assert_relative_eq!(c2, (y.cosh() - 1.0) / (y * y), epsilon = 1e-14);
        assert_relative_eq!(c3, (y.sinh() - y) / (y * y * y), epsilon = 1e-14);
        assert!(c2 > 0.0 && c3 > 0.0);
    }

    #[test]
    fn test_stumpff_cs_parabolic() {
        let z = 1e-8;
        let (c2, c3) = stumpff_cs(z);
        assert_relative_eq!(c2, 0.5, epsilon = 1e-6);
        assert_relative_eq!(c3, 1.0 / 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_stumpff_match_series_for_small_z() {
        for z in [-0.01, 0.01] {
            let (c2, c3) = stumpff_cs(z);
            assert_relative_eq!(c2, c2_series(z), epsilon = 1e-12);
            assert_relative_eq!(c3, c3_series(z), epsilon = 1e-12);
            assert_relative_eq!(stumpff_c(z), c2_series(z), epsilon = 1e-12);
            assert_relative_eq!(stumpff_s(z), c3_series(z), epsilon = 1e-12);
        }
    }

    #[test]
    fn test_stumpff_cs_continuous_across_tolerance() {
        // Points just inside and just outside the ±1e-6 series band must agree.
        for z in [1e-6, -1e-6] {
            let (c2_in, c3_in) = stumpff_cs(z * 0.999);
            let (c2_out, c3_out) = stumpff_cs(z * 1.001);
            assert_relative_eq!(c2_in, c2_out, epsilon = 1e-9);
            assert_relative_eq!(c3_in, c3_out, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_stumpff_cs_consistency() {
        let test_values: [f64; 7] = [-10.0, -1.0, -0.1, 0.0, 0.1, 1.0, 10.0];
        for z in test_values {
            let (c2_combined, c3_combined) = stumpff_cs(z);
            assert_relative_eq!(c2_combined, stumpff_c(z), epsilon = 1e-14);
            assert_relative_eq!(c3_combined, stumpff_s(z), epsilon = 1e-14);
        }
    }

    #[test]
    fn test_stumpff_derivatives_at_zero() {
        let (c2_prime, c3_prime) = stumpff_derivatives(0.0, 0.5, 1.0 / 6.0);
        assert_relative_eq!(c2_prime, -1.0 / 24.0, epsilon = 1e-17);
        assert_relative_eq!(c3_prime, -1.0 / 120.0, epsilon = 1e-17);
    }

    #[test]
    fn test_stumpff_derivatives_match_finite_differences() {
        // Central differences of stumpff_cs, covering both sides of the
        // derivative's series band and both orbit types.
        let h = 1e-5;
        let test_values: [f64; 7] = [-5.0, -1.0, -0.02, 0.005, 0.02, 1.0, 5.0];
        for z in test_values {
            let (c2, c3) = stumpff_cs(z);
            let (c2_prime, c3_prime) = stumpff_derivatives(z, c2, c3);
            let (c2_hi, c3_hi) = stumpff_cs(z + h);
            let (c2_lo, c3_lo) = stumpff_cs(z - h);
            assert_relative_eq!(c2_prime, (c2_hi - c2_lo) / (2.0 * h), epsilon = 1e-8);
            assert_relative_eq!(c3_prime, (c3_hi - c3_lo) / (2.0 * h), epsilon = 1e-8);
        }
    }
}