Skip to main content

batuta_common/
math.rs

//! Shared mathematical functions for the Batuta stack.
//!
//! Provides common math operations (statistics, special functions, vector
//! similarity, and usage percentages) used across pmat, trueno, aprender, and
//! trueno-viz.

// =============================================================================
// ERROR FUNCTION (Abramowitz & Stegun approximation)
// =============================================================================

/// Error function `erf(x)`, approximated with Abramowitz & Stegun formula 7.1.26.
///
/// The rational approximation is evaluated on `|x|` and the sign of `x` is
/// re-applied afterwards, so `erf(-x) == -erf(x)` holds exactly. The absolute
/// error of the approximation is below 1.5 × 10⁻⁷. `±∞` maps to `±1.0`, and NaN
/// propagates.
///
/// # Examples
/// ```
/// use batuta_common::math::erf;
/// assert!((erf(0.0) - 0.0).abs() < 1e-6);
/// assert!((erf(1.0) - 0.842_700_8).abs() < 1e-5);
/// assert!((erf(-1.0) + 0.842_700_8).abs() < 1e-5);
/// ```
#[must_use]
pub fn erf(x: f64) -> f64 {
    // A&S 7.1.26 polynomial coefficients, highest degree (a5) first so that a
    // left fold performs Horner evaluation: ((((a5·t + a4)·t + a3)·t + a2)·t + a1).
    const COEFFS: [f64; 5] = [
        1.061_405_429,
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];
    const P: f64 = 0.327_591_1;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS.iter().fold(0.0, |acc, &c| acc * t + c) * t;
    let approx = 1.0 - poly * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -approx
    } else {
        approx
    }
}

39/// Compute erf(x) with f32 precision.
40///
41/// Convenience wrapper for f32 callers; internally delegates to the f64 version.
42#[must_use]
43pub fn erf_f32(x: f32) -> f32 {
44    erf(f64::from(x)) as f32
45}
46
// =============================================================================
// STANDARD DEVIATION
// =============================================================================

/// Sample standard deviation of `samples`, using Bessel's correction (divides by n-1).
///
/// Two passes are made: one for the arithmetic mean, one for the sum of squared
/// deviations from it. Inputs with fewer than two elements yield 0.0.
///
/// # Examples
/// ```
/// use batuta_common::math::std_dev;
/// let data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
/// assert!((std_dev(&data) - 2.138).abs() < 0.01);
/// assert_eq!(std_dev(&[1.0]), 0.0);
/// assert_eq!(std_dev(&[]), 0.0);
/// ```
#[must_use]
pub fn std_dev(samples: &[f64]) -> f64 {
    let count = samples.len();
    if count <= 1 {
        return 0.0;
    }

    let n = count as f64;
    let mean = samples.iter().sum::<f64>() / n;
    let sum_sq: f64 = samples.iter().map(|v| (v - mean).powi(2)).sum();

    (sum_sq / (n - 1.0)).sqrt()
}

/// Compute sample standard deviation for `f32` data (Bessel's correction, n-1).
///
/// The mean and the sum of squared deviations are accumulated in `f64`, and only
/// the final result is narrowed to `f32`. Pure `f32` accumulation has two problems:
/// - It loses precision on long inputs or on data with a large common offset.
/// - It overflows to infinity once a deviation exceeds about 1.8e19, because
///   `sqrt(f32::MAX)` ≈ 1.8e19.
///
/// Returns 0.0 if fewer than 2 elements.
#[must_use]
pub fn std_dev_f32(samples: &[f32]) -> f32 {
    if samples.len() < 2 {
        return 0.0;
    }
    let n = samples.len() as f64;
    // f32 -> f64 is lossless, so the wide accumulation only adds accuracy.
    let mean = samples.iter().map(|&x| f64::from(x)).sum::<f64>() / n;
    let sum_sq: f64 = samples
        .iter()
        .map(|&x| (f64::from(x) - mean).powi(2))
        .sum();
    (sum_sq / (n - 1.0)).sqrt() as f32
}

/// Sample standard deviation of `samples` around a caller-supplied `mean`
/// (Bessel's correction, n-1).
///
/// This saves a pass when the mean is already known. The supplied `mean` is
/// used as-is and is not checked against the data. Inputs with fewer than two
/// elements yield 0.0.
#[must_use]
pub fn std_dev_with_mean(samples: &[f64], mean: f64) -> f64 {
    match samples.len() {
        0 | 1 => 0.0,
        len => {
            let sum_sq: f64 = samples.iter().map(|x| (x - mean).powi(2)).sum();
            (sum_sq / (len - 1) as f64).sqrt()
        }
    }
}

/// Compute sample standard deviation for `f32` data given a pre-computed mean
/// (Bessel's correction, n-1).
///
/// Squared deviations are accumulated in `f64` and only the final result is
/// narrowed to `f32`. This avoids precision loss on long inputs and avoids
/// overflow when a deviation exceeds about 1.8e19 (`sqrt(f32::MAX)`). The
/// supplied `mean` is used as-is and is not checked against the data.
///
/// Returns 0.0 if fewer than 2 elements.
#[must_use]
pub fn std_dev_f32_with_mean(samples: &[f32], mean: f32) -> f32 {
    if samples.len() < 2 {
        return 0.0;
    }
    let mean = f64::from(mean);
    let sum_sq: f64 = samples
        .iter()
        .map(|&x| (f64::from(x) - mean).powi(2))
        .sum();
    (sum_sq / (samples.len() - 1) as f64).sqrt() as f32
}

// =============================================================================
// COSINE SIMILARITY
// =============================================================================

/// Compute cosine similarity between two `f32` vectors.
///
/// The dot product and both norms are accumulated in `f64`, and only the final
/// ratio is narrowed to `f32`. In pure `f32`, components above about 1.8e19
/// square to infinity and the result becomes NaN. Components below about 1e-23
/// square to zero and the vector is misreported as having zero norm.
///
/// Returns 0.0 if either vector has zero norm. If the slices differ in length,
/// only the overlapping prefix contributes to the dot product, while each norm
/// still spans its whole slice.
///
/// # Examples
/// ```
/// use batuta_common::math::cosine_similarity_f32;
/// let a = [1.0f32, 0.0, 0.0];
/// let b = [0.0f32, 1.0, 0.0];
/// assert!((cosine_similarity_f32(&a, &b) - 0.0).abs() < 1e-6);
///
/// let c = [1.0f32, 2.0, 3.0];
/// assert!((cosine_similarity_f32(&c, &c) - 1.0).abs() < 1e-6);
/// ```
#[must_use]
pub fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
    // f32 -> f64 is lossless, and f64 has enough exponent range that squaring
    // any finite f32 neither overflows nor underflows.
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = a.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)) as f32
}

/// Cosine of the angle between two `f64` vectors: `a·b / (‖a‖·‖b‖)`.
///
/// Yields 0.0 when either vector has zero norm (including empty slices). When
/// the slices differ in length, the dot product covers only the overlapping
/// prefix, while each norm is taken over its full slice.
#[must_use]
pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let euclidean_norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (euclidean_norm(a), euclidean_norm(b));
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}

// =============================================================================
// USAGE PERCENT
// =============================================================================

/// Percentage of `total` that `used` represents, i.e. `used / total × 100`.
///
/// A `total` of 0 yields 0.0 instead of dividing by zero. The result is not
/// clamped, so `used > total` gives a value above 100.
///
/// # Examples
/// ```
/// use batuta_common::math::usage_percent;
/// assert!((usage_percent(750, 1000) - 75.0).abs() < 1e-10);
/// assert_eq!(usage_percent(0, 0), 0.0);
/// assert!((usage_percent(1024, 4096) - 25.0).abs() < 1e-10);
/// ```
#[must_use]
pub fn usage_percent(used: u64, total: u64) -> f64 {
    match total {
        0 => 0.0,
        _ => used as f64 / total as f64 * 100.0,
    }
}

// =============================================================================
// TESTS
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // --- erf ---

    #[test]
    fn test_erf_zero() {
        assert!(erf(0.0).abs() < 1e-6);
    }

    #[test]
    fn test_erf_positive() {
        assert!((erf(1.0) - 0.842_700_793).abs() < 1e-6);
    }

    #[test]
    fn test_erf_negative_symmetry() {
        assert!((erf(-1.0) + erf(1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_erf_large() {
        assert!((erf(5.0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_erf_within_documented_error_bound() {
        // Reference erf values (10 significant digits); the A&S 7.1.26
        // approximation promises an absolute error below 1.5e-7.
        let cases = [
            (0.5, 0.520_499_877_8),
            (1.5, 0.966_105_146_5),
            (2.0, 0.995_322_265_0),
        ];
        for &(x, expected) in cases.iter() {
            assert!((erf(x) - expected).abs() < 2e-7, "erf({}) out of bound", x);
        }
    }

    #[test]
    fn test_erf_infinite_limits() {
        assert_eq!(erf(f64::INFINITY), 1.0);
        assert_eq!(erf(f64::NEG_INFINITY), -1.0);
    }

    #[test]
    fn test_erf_f32_matches() {
        let f32_val = erf_f32(1.0_f32);
        let f64_val = erf(1.0) as f32;
        assert!((f32_val - f64_val).abs() < 1e-6);
    }

    // --- std_dev ---

    #[test]
    fn test_std_dev_known_value() {
        // Sample std_dev with Bessel's correction (n-1):
        // Mean = 5.0, sum_sq_diff = 32, variance = 32/7 ≈ 4.571, sd ≈ 2.138
        let data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        assert!((std_dev(&data) - 2.138).abs() < 0.01);
    }

    #[test]
    fn test_std_dev_single_element() {
        assert_eq!(std_dev(&[42.0]), 0.0);
    }

    #[test]
    fn test_std_dev_empty() {
        assert_eq!(std_dev(&[]), 0.0);
    }

    #[test]
    fn test_std_dev_identical_values() {
        assert_eq!(std_dev(&[5.0, 5.0, 5.0, 5.0]), 0.0);
    }

    #[test]
    fn test_std_dev_f32() {
        let data: Vec<f32> = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        assert!((std_dev_f32(&data) - 2.138).abs() < 0.02);
    }

    #[test]
    fn test_std_dev_with_mean_matches() {
        let data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let mean = data.iter().sum::<f64>() / data.len() as f64;
        let sd1 = std_dev(&data);
        let sd2 = std_dev_with_mean(&data, mean);
        assert!((sd1 - sd2).abs() < 1e-10);
    }

    #[test]
    fn test_std_dev_f32_with_mean_matches() {
        let data: [f32; 8] = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let sd1 = std_dev_f32(&data);
        let sd2 = std_dev_f32_with_mean(&data, 5.0);
        assert!((sd1 - sd2).abs() < 1e-5);
    }

    #[test]
    fn test_std_dev_with_mean_short_input() {
        assert_eq!(std_dev_with_mean(&[], 0.0), 0.0);
        assert_eq!(std_dev_with_mean(&[3.0], 3.0), 0.0);
        assert_eq!(std_dev_f32_with_mean(&[], 0.0), 0.0);
        assert_eq!(std_dev_f32_with_mean(&[3.0], 3.0), 0.0);
    }

    // --- cosine_similarity ---

    #[test]
    fn test_cosine_identical() {
        let a = [1.0, 2.0, 3.0];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_opposite() {
        let a = [1.0, 0.0];
        let b = [-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_scale_invariant() {
        let a = [1.0, 2.0, 3.0];
        let b = [2.0, 4.0, 6.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_zero_vector() {
        let a = [0.0, 0.0, 0.0];
        let b = [1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_f32() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(cosine_similarity_f32(&a, &b).abs() < 1e-6);
    }

    // --- usage_percent ---

    #[test]
    fn test_usage_percent_normal() {
        assert!((usage_percent(750, 1000) - 75.0).abs() < 1e-10);
    }

    #[test]
    fn test_usage_percent_zero_total() {
        assert_eq!(usage_percent(0, 0), 0.0);
    }

    #[test]
    fn test_usage_percent_full() {
        assert!((usage_percent(1000, 1000) - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_usage_percent_empty() {
        assert!(usage_percent(0, 1000).abs() < 1e-10);
    }

    #[test]
    fn test_usage_percent_over_capacity_is_not_clamped() {
        assert!((usage_percent(1500, 1000) - 150.0).abs() < 1e-10);
    }
}