Skip to main content

gam_math/
special.rs

//! Scalar special-function primitives shared across the workspace.
//!
//! These are pure (`std`/`libm`-only) numeric kernels with no upward crate
//! dependencies, so they live in the lowest crate (`gam-math`) and can be
//! consumed by any term/basis/inference code without introducing a
//! dependency cycle (an SCC edge) in the crate graph.

/// Numerically stable binomial coefficient `C(n, k) = n! / (k!·(n−k)!)` as `f64`.
///
/// Uses the symmetry `C(n, k) = C(n, n−k)`, so the loop runs `min(k, n−k)`
/// times. It also uses the multiplicative recurrence
/// `C(n, j+1) = C(n, j)·(n−j)/(j+1)` instead of separate factorial
/// evaluations, which would overflow almost immediately.
///
/// Returns `0.0` for `k > n` and `1.0` for `k == 0` or `k == n`. Results at
/// or below `2^53` are exact integers. Larger results are correctly rounded
/// as long as the recurrence stays within `u128`, and approximate (a few ulp)
/// beyond that. Coefficients larger than `f64::MAX` return `f64::INFINITY`.
#[inline]
pub fn binomial_coefficient_f64(n: usize, k: usize) -> f64 {
    if k > n {
        return 0.0;
    }
    if k == 0 || k == n {
        return 1.0;
    }
    let k_eff = k.min(n - k);
    // Carry the recurrence in u128, not f64. At step `j` the running product
    // equals the integer `C(n, j)`, and `C(n, j)·(n−j) = C(n, j+1)·(j+1)`, so
    // each integer division by `(j + 1)` is exact and no rounding accumulates.
    // An all-`f64` recurrence divides in floating point, where `(n−j)/(j+1)`
    // is generally inexact. That drift pushes results off the true integer
    // well below `2^53` (e.g. `C(54,24)` comes back one short). Converting
    // the exact `u128` at the end rounds to nearest, which is bit-exact for
    // every value at or below `2^53`.
    let mut num: u128 = 1;
    for j in 0..k_eff {
        match num.checked_mul((n - j) as u128) {
            Some(scaled) => num = scaled / (j as u128 + 1),
            None => {
                // The true coefficient overflows u128. That is astronomically
                // above `2^53`, where the exactness contract no longer applies,
                // so finish the (now necessarily inexact) recurrence in f64.
                // Divide *before* multiplying: `out·(n−jj)` can exceed
                // `f64::MAX` on the last step even when `C(n, jj+1)` itself is
                // representable (e.g. `C(1029, 514) ≈ 1.43e308`). By contrast,
                // `out / (jj+1)` never exceeds `out`.
                let mut out = num as f64;
                for jj in j..k_eff {
                    out = out / (jj + 1) as f64 * (n - jj) as f64;
                }
                return out;
            }
        }
    }
    num as f64
}

/// Evaluate `Σ_k coeffs[k]·x^k` by Horner's rule.
///
/// `coeffs` is in ascending order (`coeffs[k]` multiplies `x^k`). Returns
/// `0.0` for empty `coeffs`.
#[inline]
fn horner_polynomial(x: f64, coeffs: &[f64]) -> f64 {
    coeffs.iter().rev().fold(0.0, |acc, &c| acc * x + c)
}

/// Evaluate `(Σ_k coeffs[k]·x^k) · exp(−x)` while avoiding spurious overflow.
///
/// `coeffs` is in ascending order (`coeffs[k]` multiplies `x^k`).
///
/// * For `x ≤ 600` the direct form is used: Horner in `x`, times `exp(−x)`.
///   `exp(−x)` is a normal double there, so this is accurate and cheap.
/// * For `x > 600`, the product is rewritten as
///   `exp(d·ln x − x) · Σ_k coeffs[k]·(1/x)^(d−k)` with `d = coeffs.len() − 1`,
///   and the sum is evaluated by Horner in `1/x`. This keeps both factors in
///   double range whenever the result itself is representable.
/// * The same rewrite is used for `1 < x ≤ 600` when the direct polynomial
///   overflows, which happens for high degree (e.g. `x^150` at `x = 600`).
///
/// Returns `0.0` for non-finite `x` or empty `coeffs`.
#[inline]
pub fn stable_polynomial_times_exp_neg(x: f64, coeffs: &[f64]) -> f64 {
    if coeffs.is_empty() || !x.is_finite() {
        return 0.0;
    }
    // Below this argument `(-x).exp()` is a normal double (exp(-600) ≈ 2.6e-261),
    // so the direct Horner-times-exp form is both accurate and cheapest.
    // Above it the factor heads toward subnormal/zero, so the scaled form is
    // used to retain the significant digits.
    const DIRECT_EXP_SWITCH: f64 = 600.0;
    if x <= DIRECT_EXP_SWITCH {
        let direct = horner_polynomial(x, coeffs) * (-x).exp();
        // For x ≤ 1 the scaled form has nothing to offer (1/x ≥ 1, ln x ≤ 0)
        // and is undefined for x ≤ 0, so a non-finite direct value stands.
        if direct.is_finite() || x <= 1.0 {
            return direct;
        }
        // The raw polynomial overflowed (high degree). Retry in scaled form,
        // but keep the direct value if the scaled form cannot do better either
        // (e.g. the true product is itself beyond f64 range).
        let scaled = scaled_polynomial_times_exp_neg(x, coeffs);
        return if scaled.is_finite() { scaled } else { direct };
    }
    scaled_polynomial_times_exp_neg(x, coeffs)
}

/// Compute `(Σ_k coeffs[k]·x^k)·exp(−x)` in the overflow-safe form
/// `exp(d·ln x − x) · Σ_k coeffs[k]·(1/x)^(d−k)`, where `d = coeffs.len() − 1`.
///
/// Requires `x > 0` and non-empty `coeffs`.
fn scaled_polynomial_times_exp_neg(x: f64, coeffs: &[f64]) -> f64 {
    let inv_x = x.recip();
    // Horner in 1/x over ascending coefficients: c_0 ends up multiplied by
    // (1/x)^d and c_d by 1, i.e. the polynomial divided by x^d.
    let tail = coeffs.iter().fold(0.0, |acc, &c| acc * inv_x + c);
    let degree = (coeffs.len() - 1) as f64;
    (degree * x.ln() - x).exp() * tail
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `|got − want| / |want|`; callers only pass a nonzero reference `want`.
    fn relative_error(got: f64, want: f64) -> f64 {
        (got - want).abs() / want.abs()
    }

    #[test]
    fn binom_k_exceeds_n_returns_zero() {
        for &(n, k) in &[(3, 5), (0, 1), (10, 11)] {
            assert_eq!(binomial_coefficient_f64(n, k), 0.0);
        }
    }

    #[test]
    fn binom_k_zero_returns_one() {
        for &n in &[0, 5, 100] {
            assert_eq!(binomial_coefficient_f64(n, 0), 1.0);
        }
    }

    #[test]
    fn binom_k_equals_n_returns_one() {
        for &n in &[1, 5, 20] {
            assert_eq!(binomial_coefficient_f64(n, n), 1.0);
        }
    }

    #[test]
    fn binom_small_exact_values() {
        let table = [(5, 2, 10.0), (10, 3, 120.0), (20, 10, 184_756.0), (6, 3, 20.0)];
        for &(n, k, want) in &table {
            assert_eq!(binomial_coefficient_f64(n, k), want);
        }
    }

    #[test]
    fn binom_symmetry() {
        // C(n, k) must equal its mirror C(n, n − k); (54, 24) also exercises
        // a value large enough to have drifted under a pure-f64 recurrence.
        for &(n, k) in &[(10, 3), (20, 5), (54, 24)] {
            assert_eq!(
                binomial_coefficient_f64(n, k),
                binomial_coefficient_f64(n, n - k)
            );
        }
    }

    #[test]
    fn binom_c54_24_is_exact() {
        // Regression guard: the former all-f64 recurrence produced
        // 1_402_659_561_581_459 here, one below the true integer.
        assert_eq!(binomial_coefficient_f64(54, 24), 1_402_659_561_581_460.0);
    }

    #[test]
    fn poly_exp_empty_coeffs_returns_zero() {
        for &x in &[1.0, 0.0, 700.0] {
            assert_eq!(stable_polynomial_times_exp_neg(x, &[]), 0.0);
        }
    }

    #[test]
    fn poly_exp_nonfinite_x_returns_zero() {
        let linear = [1.0, 2.0];
        assert_eq!(stable_polynomial_times_exp_neg(f64::INFINITY, &linear), 0.0);
        assert_eq!(stable_polynomial_times_exp_neg(f64::NEG_INFINITY, &linear), 0.0);
        assert_eq!(stable_polynomial_times_exp_neg(f64::NAN, &[1.0]), 0.0);
    }

    #[test]
    fn poly_exp_constant_at_zero() {
        // exp(0) = 1 and every x^k with k ≥ 1 vanishes, leaving coeffs[0].
        assert_eq!(stable_polynomial_times_exp_neg(0.0, &[5.0]), 5.0);
        assert_eq!(stable_polynomial_times_exp_neg(0.0, &[3.0, 1.0, 2.0]), 3.0);
    }

    #[test]
    fn poly_exp_constant_poly_direct_path() {
        // x = 2 sits below the 600 switch, so the direct Horner·exp(−x) form runs.
        let x = 2.0_f64;
        let expected = 3.0 * (-x).exp();
        let got = stable_polynomial_times_exp_neg(x, &[3.0]);
        assert!((got - expected).abs() < 1e-14, "got={got} expected={expected}");
    }

    #[test]
    fn poly_exp_linear_poly_direct_path() {
        // coeffs = [a, b] encodes the polynomial a + b·x.
        let (a, b, x): (f64, f64, f64) = (2.0, 3.0, 1.5);
        let expected = (a + b * x) * (-x).exp();
        let got = stable_polynomial_times_exp_neg(x, &[a, b]);
        assert!((got - expected).abs() < 1e-14, "got={got} expected={expected}");
    }

    #[test]
    fn poly_exp_constant_poly_asymptotic_path() {
        // x = 700 lies past the 600 switch (scaled branch). The constant
        // polynomial 1 reduces the expected result to exp(−700).
        let x = 700.0_f64;
        let expected = (-x).exp();
        let got = stable_polynomial_times_exp_neg(x, &[1.0]);
        let rel = relative_error(got, expected);
        assert!(rel < 1e-12, "got={got} expected={expected} rel={rel}");
    }

    #[test]
    fn poly_exp_quadratic_asymptotic_path() {
        // coeffs = [0, 0, 1] encodes x², so the target is x²·exp(−x).
        // x = 620 still takes the scaled branch (switch at 600) but keeps the
        // answer a normal double (~1e-264). At x = 800 both the function and
        // this reference underflow to 0.0, so the relative check would degenerate.
        let x = 620.0_f64;
        let expected = (2.0 * x.ln() - x).exp();
        let got = stable_polynomial_times_exp_neg(x, &[0.0, 0.0, 1.0]);
        let rel = relative_error(got, expected);
        assert!(rel < 1e-12, "got={got} expected={expected} rel={rel}");
    }
}