Skip to main content

entrenar/
numerical.rs

1//! Numerical utilities for numerically stable computation
2//!
3//! Provides compensated summation (Kahan) and other numerical primitives
4//! used in training loops and loss computation.
5
/// Kahan compensated summation — reduces floating-point accumulation error
/// from O(n * eps) to O(eps) for n addends.
///
/// A running correction term records the low-order bits that are rounded away
/// each time an addend is folded into the total and feeds them back into the
/// next addition.
///
/// An empty slice sums to `0.0`. Non-finite data behaves like ordinary
/// IEEE-754 addition: an infinite addend or an overflowing total yields
/// `±inf`, while `NaN` addends and `inf + (-inf)` yield `NaN`.
///
/// Reference: Kahan (1965) "Pracniques: Further remarks on reducing truncation errors"
///
/// # Example
/// ```
/// use entrenar::numerical::kahan_sum;
///
/// // Many small values where naive summation drifts
/// let values: Vec<f32> = vec![0.1; 100_000];
/// let compensated = kahan_sum(&values);
/// assert!((compensated - 10_000.0).abs() < 0.01);
/// ```
pub fn kahan_sum(values: &[f32]) -> f32 {
    let mut sum = 0.0_f32;
    let mut compensation = 0.0_f32;

    for &val in values {
        // Re-inject the rounding error lost by the previous addition.
        let y = val - compensation;
        let t = sum + y;
        // Algebraically zero; in floating point it is the part of `y` that
        // did not make it into `t`.
        let lost = (t - sum) - y;
        // Once the total is infinite this expression degenerates to
        // `inf - inf = NaN` (or `inf`), which would turn every later partial
        // sum into NaN. Dropping such a correction keeps the IEEE result.
        compensation = if lost.is_finite() { lost } else { 0.0 };
        sum = t;
    }

    sum
}
33
/// Kahan compensated summation for f64 values
///
/// Same algorithm as the `f32` variant: a running correction term carries the
/// rounding error of each addition into the next one.
///
/// An empty slice sums to `0.0`. Non-finite data behaves like ordinary
/// IEEE-754 addition: an infinite addend or an overflowing total yields
/// `±inf`, while `NaN` addends and `inf + (-inf)` yield `NaN`.
pub fn kahan_sum_f64(values: &[f64]) -> f64 {
    let mut sum = 0.0_f64;
    let mut compensation = 0.0_f64;

    for &val in values {
        // Re-inject the rounding error lost by the previous addition.
        let y = val - compensation;
        let t = sum + y;
        // Algebraically zero; in floating point it is the part of `y` that
        // did not make it into `t`.
        let lost = (t - sum) - y;
        // Once the total is infinite this expression degenerates to
        // `inf - inf = NaN` (or `inf`), which would turn every later partial
        // sum into NaN. Dropping such a correction keeps the IEEE result.
        compensation = if lost.is_finite() { lost } else { 0.0 };
        sum = t;
    }

    sum
}
48
49/// Numerically stable mean computation using Kahan summation
50pub fn kahan_mean(values: &[f32]) -> f32 {
51    if values.is_empty() {
52        return 0.0;
53    }
54    kahan_sum(values) / values.len() as f32
55}
56
57/// Numerically stable variance computation (two-pass with Kahan)
58pub fn kahan_variance(values: &[f32]) -> f32 {
59    if values.len() < 2 {
60        return 0.0;
61    }
62    let mean = kahan_mean(values);
63    let sq_diffs: Vec<f32> = values.iter().map(|&x| (x - mean) * (x - mean)).collect();
64    kahan_sum(&sq_diffs) / values.len() as f32
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// A handful of small integers must add up to their exact total.
    #[test]
    fn test_kahan_sum_basic() {
        let inputs = [1.0_f32, 2.0, 3.0, 4.0];
        let total = kahan_sum(&inputs);
        assert!((total - 10.0).abs() < 1e-7);
    }

    /// 100_000 copies of 0.1 should total ~10000.0; the compensated result
    /// must be close to that and no worse than plain accumulation.
    #[test]
    fn test_kahan_sum_accumulated_error() {
        let expected = 10_000.0_f32;
        let samples = vec![0.1_f32; 100_000];

        let compensated = kahan_sum(&samples);
        let plain: f32 = samples.iter().copied().sum();

        let kahan_err = (compensated - expected).abs();
        let naive_err = (plain - expected).abs();

        // Compensated summation must not lose to the naive loop.
        assert!(
            kahan_err <= naive_err + 1e-3,
            "Kahan error {kahan_err} should be <= naive error {naive_err}"
        );
        // ...and must land near the true total.
        assert!(
            kahan_err < 0.01,
            "Kahan sum = {compensated}, expected ~{expected}, error = {kahan_err}"
        );
    }

    /// 1_000_000 copies of 1e-7 should total ~0.1.
    #[test]
    fn test_kahan_sum_many_small_values() {
        let expected = 0.1_f32;
        let samples = vec![1e-7_f32; 1_000_000];
        let compensated = kahan_sum(&samples);

        let deviation = (compensated - expected).abs();
        assert!(
            deviation < 1e-6,
            "Kahan sum of 1M * 1e-7 = {compensated}, expected ~{expected}"
        );
    }

    /// Summing nothing gives zero.
    #[test]
    fn test_kahan_sum_empty() {
        let nothing: [f32; 0] = [];
        assert_eq!(kahan_sum(&nothing), 0.0);
    }

    /// A single element is returned unchanged.
    #[test]
    fn test_kahan_sum_single() {
        let lone = [42.0_f32];
        assert_eq!(kahan_sum(&lone), 42.0);
    }

    /// Mean of 1..=5 is 3.
    #[test]
    fn test_kahan_mean() {
        let inputs = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let average = kahan_mean(&inputs);
        assert!((average - 3.0).abs() < 1e-7);
    }

    /// Textbook dataset whose variance (dividing by n) is 4.0.
    #[test]
    fn test_kahan_variance() {
        let inputs = [2.0_f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let var = kahan_variance(&inputs);
        assert!(
            (var - 4.0).abs() < 1e-5,
            "Kahan variance = {var}, expected 4.0"
        );
    }

    /// One million f64 copies of 0.1 should total 100000.0 to within 1e-8.
    #[test]
    fn test_kahan_sum_f64_precision() {
        let expected = 100_000.0_f64;
        let samples = vec![0.1_f64; 1_000_000];
        let compensated = kahan_sum_f64(&samples);

        let deviation = (compensated - expected).abs();
        assert!(
            deviation < 1e-8,
            "Kahan f64 sum = {compensated}, expected {expected}"
        );
    }

    /// Numerical validation: verify kahan sum against analytical_solution (EDD-03)
    #[test]
    fn verification_test_kahan_analytical_solution() {
        // Closed form for the triangular number: sum(1..=N) = N*(N+1)/2
        let upper = 1000;
        let analytical_solution = (upper * (upper + 1) / 2) as f32;

        let series: Vec<f32> = (1..=upper).map(|k| k as f32).collect();
        let computed = kahan_sum(&series);

        assert!(
            (computed - analytical_solution).abs() < 1e-3,
            "convergence_test: computed={computed}, exact={analytical_solution}"
        );
    }

    /// Interleaving huge and tiny addends stresses the accumulator; the
    /// compensated total must be at least as close to the exact value as the
    /// naive one.
    #[test]
    fn test_kahan_vs_naive_accuracy() {
        let count = 10_000;
        let samples: Vec<f32> = (0..count)
            .map(|idx| match idx % 2 {
                0 => 1e6,
                _ => 1e-6,
            })
            .collect();

        let half = (count / 2) as f32;
        let exact = half * 1e6 + half * 1e-6;

        let compensated = kahan_sum(&samples);
        let plain: f32 = samples.iter().copied().sum();

        let kahan_err = (compensated - exact).abs();
        let naive_err = (plain - exact).abs();
        assert!(
            kahan_err <= naive_err + 1e-3,
            "Kahan error {kahan_err} should be <= naive error {naive_err}"
        );
    }
}