Skip to main content

proof_engine/math/
statistics.rs

//! Statistics and probability: descriptive statistics, probability distributions,
//! hypothesis testing, regression, Bayesian inference, random number generation,
//! and information theory.
3
4use std::f64::consts::PI;
5
// ============================================================
// RANDOM NUMBER GENERATORS
// ============================================================
9
/// A source of pseudo-random 64-bit words.
///
/// Implementors only supply [`Rng::next_u64`]; the remaining methods are derived
/// from it.
pub trait Rng {
    /// Returns the next 64-bit output of the generator.
    fn next_u64(&mut self) -> u64;

    /// Returns a uniformly distributed `f64` in `[0, 1)`.
    ///
    /// Keeps the top 53 bits of [`Rng::next_u64`] (the width of an `f64`
    /// significand), so every result is an exact multiple of `2^-53`.
    fn next_f64(&mut self) -> f64 {
        let top53 = self.next_u64() >> 11;
        let scale = (1u64 << 53) as f64;
        top53 as f64 / scale
    }

    /// Returns the upper 32 bits of [`Rng::next_u64`].
    fn next_u32(&mut self) -> u32 {
        let word = self.next_u64();
        (word >> 32) as u32
    }
}
20
21/// Xorshift64 — fast, simple 64-bit RNG.
22#[derive(Clone, Debug)]
23pub struct Xorshift64 {
24    pub state: u64,
25}
26
27impl Xorshift64 {
28    pub fn new(seed: u64) -> Self { Self { state: seed.max(1) } }
29}
30
31impl Rng for Xorshift64 {
32    fn next_u64(&mut self) -> u64 {
33        let mut x = self.state;
34        x ^= x << 13;
35        x ^= x >> 7;
36        x ^= x << 17;
37        self.state = x;
38        x
39    }
40}
41
42/// PCG32 — Permuted Congruential Generator.
43#[derive(Clone, Debug)]
44pub struct Pcg32 {
45    pub state: u64,
46    pub inc: u64,
47}
48
49impl Pcg32 {
50    pub fn new(seed: u64, seq: u64) -> Self {
51        let mut rng = Self { state: 0, inc: (seq << 1) | 1 };
52        rng.state = rng.state.wrapping_add(seed);
53        rng.next_u64();
54        rng
55    }
56}
57
58impl Rng for Pcg32 {
59    fn next_u64(&mut self) -> u64 {
60        let old_state = self.state;
61        self.state = old_state
62            .wrapping_mul(6_364_136_223_846_793_005)
63            .wrapping_add(self.inc);
64        let xorshifted = ((old_state >> 18) ^ old_state) >> 27;
65        let rot = (old_state >> 59) as u32;
66        let result32 = xorshifted.rotate_right(rot) as u32;
67        (result32 as u64) | ((result32 as u64) << 32)
68    }
69}
70
71/// SplitMix64 — fast 64-bit generator suitable as seed scrambler.
72#[derive(Clone, Debug)]
73pub struct SplitMix64 {
74    pub state: u64,
75}
76
77impl SplitMix64 {
78    pub fn new(seed: u64) -> Self { Self { state: seed } }
79}
80
81impl Rng for SplitMix64 {
82    fn next_u64(&mut self) -> u64 {
83        self.state = self.state.wrapping_add(0x9e3779b97f4a7c15);
84        let mut z = self.state;
85        z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
86        z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
87        z ^ (z >> 31)
88    }
89}
90
91/// Linear Congruential Generator.
92#[derive(Clone, Debug)]
93pub struct Lcg {
94    pub state: u64,
95    pub a: u64,
96    pub c: u64,
97    pub m: u64,
98}
99
100impl Lcg {
101    pub fn new(seed: u64) -> Self {
102        Self {
103            state: seed,
104            a: 6_364_136_223_846_793_005,
105            c: 1_442_695_040_888_963_407,
106            m: u64::MAX,
107        }
108    }
109}
110
111impl Rng for Lcg {
112    fn next_u64(&mut self) -> u64 {
113        self.state = self.state.wrapping_mul(self.a).wrapping_add(self.c);
114        self.state
115    }
116}
117
118/// Fisher-Yates shuffle.
119pub fn shuffle<T>(data: &mut [T], rng: &mut impl Rng) {
120    let n = data.len();
121    for i in (1..n).rev() {
122        let j = (rng.next_u64() as usize) % (i + 1);
123        data.swap(i, j);
124    }
125}
126
127/// Sample k distinct indices from 0..n without replacement (Knuth's algorithm S).
128pub fn sample_without_replacement(n: usize, k: usize, rng: &mut impl Rng) -> Vec<usize> {
129    let k = k.min(n);
130    let mut result = Vec::with_capacity(k);
131    let mut needed = k;
132    let mut available = n;
133    for i in 0..n {
134        let u = rng.next_f64();
135        if u < needed as f64 / available as f64 {
136            result.push(i);
137            needed -= 1;
138            if needed == 0 { break; }
139        }
140        available -= 1;
141    }
142    result
143}
144
145/// Weighted sampling — draw one index proportional to weights.
146pub fn weighted_sample(weights: &[f64], rng: &mut impl Rng) -> usize {
147    let total: f64 = weights.iter().sum();
148    let mut r = rng.next_f64() * total;
149    for (i, &w) in weights.iter().enumerate() {
150        r -= w;
151        if r <= 0.0 { return i; }
152    }
153    weights.len() - 1
154}
155
// ============================================================
// DESCRIPTIVE STATISTICS
// ============================================================
159
/// Arithmetic mean of `data`. Returns `0.0` for an empty slice.
pub fn mean(data: &[f64]) -> f64 {
    match data.len() {
        0 => 0.0,
        len => data.iter().sum::<f64>() / len as f64,
    }
}
165
166/// Sample variance (Bessel's correction, n-1 denominator).
167pub fn variance(data: &[f64]) -> f64 {
168    let n = data.len();
169    if n < 2 { return 0.0; }
170    let m = mean(data);
171    data.iter().map(|x| (x - m).powi(2)).sum::<f64>() / (n - 1) as f64
172}
173
174/// Sample standard deviation.
175pub fn std_dev(data: &[f64]) -> f64 { variance(data).sqrt() }
176
/// Median of `data`.
///
/// Sorts the slice in place as a side effect. With an even count the two middle
/// values are averaged. An empty slice yields `0.0`.
///
/// # Panics
/// May panic if `data` contains NaN, because the sort comparator unwraps
/// `partial_cmp`.
pub fn median(data: &mut [f64]) -> f64 {
    data.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let len = data.len();
    let mid = len / 2;
    match (len, len % 2) {
        (0, _) => 0.0,
        (_, 1) => data[mid],
        _ => (data[mid - 1] + data[mid]) / 2.0,
    }
}
184
/// Every most-frequent value in `data`, in ascending order.
///
/// Values are sorted. Adjacent values closer than `1e-12` are counted as one run,
/// and each run is represented by its last element. Returns an empty vector for
/// empty input.
///
/// # Panics
/// May panic if `data` contains NaN (sort comparator unwraps `partial_cmp`).
pub fn mode(data: &[f64]) -> Vec<f64> {
    if data.is_empty() {
        return Vec::new();
    }
    let mut sorted = data.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());

    // Collapse the sorted values into (representative, run length) pairs.
    let mut runs: Vec<(f64, usize)> = Vec::new();
    let mut run_len = 1usize;
    for pair in sorted.windows(2) {
        if (pair[1] - pair[0]).abs() < 1e-12 {
            run_len += 1;
        } else {
            runs.push((pair[0], run_len));
            run_len = 1;
        }
    }
    runs.push((sorted[sorted.len() - 1], run_len));

    let best = runs.iter().map(|&(_, count)| count).max().unwrap_or(0);
    runs.into_iter()
        .filter(|&(_, count)| count == best)
        .map(|(value, _)| value)
        .collect()
}
207
/// The `p`-th percentile of `data` (`p` in `[0, 100]`), by linear interpolation
/// between the two closest ranks.
///
/// Sorts the slice in place. The fractional rank `p/100 · (n-1)` is clamped to the
/// valid index range, so out-of-range `p` returns the minimum or maximum. An empty
/// slice returns `0.0`.
pub fn percentile(data: &mut [f64], p: f64) -> f64 {
    data.sort_by(|a, b| a.partial_cmp(b).unwrap());
    if data.is_empty() {
        return 0.0;
    }
    let last = (data.len() - 1) as f64;
    let rank = (p / 100.0 * last).clamp(0.0, last);
    let below = data[rank.floor() as usize];
    let above = data[rank.ceil() as usize];
    let weight = rank - rank.floor();
    below + weight * (above - below)
}
219
220/// Interquartile range.
221pub fn iqr(data: &mut [f64]) -> f64 {
222    let q3 = percentile(data, 75.0);
223    let q1 = percentile(data, 25.0);
224    q3 - q1
225}
226
227/// Sample skewness.
228pub fn skewness(data: &[f64]) -> f64 {
229    let n = data.len() as f64;
230    if n < 3.0 { return 0.0; }
231    let m = mean(data);
232    let s = std_dev(data);
233    if s == 0.0 { return 0.0; }
234    let sum: f64 = data.iter().map(|x| ((x - m) / s).powi(3)).sum();
235    sum * n / ((n - 1.0) * (n - 2.0))
236}
237
238/// Sample excess kurtosis.
239pub fn kurtosis(data: &[f64]) -> f64 {
240    let n = data.len() as f64;
241    if n < 4.0 { return 0.0; }
242    let m = mean(data);
243    let s = std_dev(data);
244    if s == 0.0 { return 0.0; }
245    let sum: f64 = data.iter().map(|x| ((x - m) / s).powi(4)).sum();
246    let g2 = sum * n * (n + 1.0) / ((n - 1.0) * (n - 2.0) * (n - 3.0))
247        - 3.0 * (n - 1.0).powi(2) / ((n - 2.0) * (n - 3.0));
248    g2
249}
250
251/// Sample covariance.
252pub fn covariance(x: &[f64], y: &[f64]) -> f64 {
253    let n = x.len().min(y.len());
254    if n < 2 { return 0.0; }
255    let mx = mean(x);
256    let my = mean(y);
257    x.iter().zip(y.iter()).map(|(xi, yi)| (xi - mx) * (yi - my)).sum::<f64>() / (n - 1) as f64
258}
259
260/// Pearson correlation coefficient.
261pub fn pearson_r(x: &[f64], y: &[f64]) -> f64 {
262    let cov = covariance(x, y);
263    let sx = std_dev(x);
264    let sy = std_dev(y);
265    if sx == 0.0 || sy == 0.0 { return 0.0; }
266    cov / (sx * sy)
267}
268
269/// Spearman rank correlation.
270pub fn spearman_rho(x: &[f64], y: &[f64]) -> f64 {
271    let n = x.len().min(y.len());
272    if n < 2 { return 0.0; }
273    let rank = |data: &[f64]| -> Vec<f64> {
274        let mut indexed: Vec<(usize, f64)> = data.iter().copied().enumerate().collect();
275        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
276        let mut ranks = vec![0.0f64; indexed.len()];
277        let mut i = 0;
278        while i < indexed.len() {
279            let mut j = i;
280            while j < indexed.len() && (indexed[j].1 - indexed[i].1).abs() < 1e-12 { j += 1; }
281            let avg_rank = (i + j - 1) as f64 / 2.0 + 1.0;
282            for k in i..j { ranks[indexed[k].0] = avg_rank; }
283            i = j;
284        }
285        ranks
286    };
287    let rx: Vec<f64> = rank(&x[..n]);
288    let ry: Vec<f64> = rank(&y[..n]);
289    pearson_r(&rx, &ry)
290}
291
// ============================================================
// SPECIAL FUNCTIONS
// ============================================================
295
/// Error function `erf(x)`, using Abramowitz & Stegun formula 7.1.26 (absolute
/// error about `1.5e-7`).
pub fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A: [f64; 5] = [0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429];
    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation of t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5)))).
    let poly = t * A.iter().rev().fold(0.0, |acc, &a| a + t * acc);
    let magnitude = 1.0 - poly * (-x * x).exp();
    // erf is odd: erf(-x) = -erf(x).
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
308
/// Complementary error function `erfc(x) = 1 - erf(x)`.
///
/// Evaluates the Abramowitz & Stegun 7.1.26 tail term `t·P(t)·e^{-x²}` directly,
/// rather than subtracting `erf(x)` from 1. The subtraction cancels
/// catastrophically for large `x`: `1 - erf(10)` is exactly `0.0` in `f64`, while
/// `erfc(10) ≈ 2.1e-45`. Negative arguments use `erfc(-x) = 2 - erfc(x)`.
pub fn erfc(x: f64) -> f64 {
    let z = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * z);
    let poly = t
        * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    let tail = poly * (-z * z).exp();
    if x >= 0.0 {
        tail
    } else {
        2.0 - tail
    }
}
311
/// Natural logarithm of the absolute value of the gamma function, `ln|Γ(x)|`.
///
/// Uses the Lanczos approximation (g = 7, 9 coefficients) for `x >= 0.5`, and the
/// reflection formula `Γ(x)Γ(1-x) = π / sin(πx)` below that. The absolute value of
/// `sin(πx)` is used because `sin(πx)` is negative on intervals such as (-1, 0).
/// Without it, `ln` would be taken of a negative number and return NaN. Non-positive
/// integers are poles of Γ, and the value returned there is not meaningful.
pub fn lgamma(x: f64) -> f64 {
    const G: f64 = 7.0;
    const C: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];
    if x < 0.5 {
        return (PI / (PI * x).sin().abs()).ln() - lgamma(1.0 - x);
    }
    let z = x - 1.0;
    let t = z + G + 0.5;
    let mut s = C[0];
    for (i, &c) in C.iter().enumerate().skip(1) {
        s += c / (z + i as f64);
    }
    0.5 * (2.0 * PI).ln() + s.ln() + (z + 0.5) * t.ln() - t
}
335
336/// Gamma function.
337pub fn gamma(x: f64) -> f64 { lgamma(x).exp() }
338
339/// Regularized incomplete gamma function P(a, x) — lower.
340pub fn gammainc_lower(a: f64, x: f64) -> f64 {
341    if x <= 0.0 { return 0.0; }
342    if x < a + 1.0 {
343        // Series expansion
344        let mut term = 1.0 / a;
345        let mut sum = term;
346        for n in 1..200usize {
347            term *= x / (a + n as f64);
348            sum += term;
349            if term.abs() < sum.abs() * 1e-12 { break; }
350        }
351        sum * (-x + a * x.ln() - lgamma(a)).exp()
352    } else {
353        // Continued fraction (Lentz's method)
354        let eps = 1e-12;
355        let mut b = x + 1.0 - a;
356        let mut c = 1.0 / 1e-300;
357        let mut d = 1.0 / b;
358        let mut h = d;
359        for i in 1..200i64 {
360            let an = -i as f64 * (i as f64 - a);
361            b += 2.0;
362            d = an * d + b;
363            if d.abs() < 1e-300 { d = 1e-300; }
364            c = b + an / c;
365            if c.abs() < 1e-300 { c = 1e-300; }
366            d = 1.0 / d;
367            let del = d * c;
368            h *= del;
369            if (del - 1.0).abs() < eps { break; }
370        }
371        1.0 - (-x + a * x.ln() - lgamma(a)).exp() * h
372    }
373}
374
375/// Regularized incomplete beta function I_x(a,b).
376pub fn betainc(x: f64, a: f64, b: f64) -> f64 {
377    if x <= 0.0 { return 0.0; }
378    if x >= 1.0 { return 1.0; }
379    let lbeta = lgamma(a) + lgamma(b) - lgamma(a + b);
380    let factor = (a * x.ln() + b * (1.0 - x).ln() - lbeta).exp();
381    // Use symmetry relation for convergence
382    if x < (a + 1.0) / (a + b + 2.0) {
383        factor * betacf(x, a, b) / a
384    } else {
385        1.0 - factor * betacf(1.0 - x, b, a) / b
386    }
387}
388
/// Continued-fraction part of the regularized incomplete beta function, evaluated
/// with the modified Lentz method (at most 200 iterations, tolerance `1e-12`).
///
/// `betainc` multiplies the result by `x^a (1-x)^b / B(a, b)` and divides it by `a`.
fn betacf(x: f64, a: f64, b: f64) -> f64 {
    const MAX_ITER: usize = 200;
    const EPS: f64 = 1e-12;
    const TINY: f64 = 1e-300;
    // Replace a near-zero value with TINY so the following reciprocal stays finite.
    let guard = |v: f64| if v.abs() < TINY { TINY } else { v };

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;
    let mut c = 1.0;
    let mut d = 1.0 / guard(1.0 - qab * x / qap);
    let mut h = d;
    for step in 1..=MAX_ITER {
        let m = step as f64;
        let m2 = 2.0 * m;

        // Even step of the recurrence.
        let even = m * (b - m) * x / ((qam + m2) * (a + m2));
        d = 1.0 / guard(1.0 + even * d);
        c = guard(1.0 + even / c);
        h *= d * c;

        // Odd step of the recurrence.
        let odd = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
        d = 1.0 / guard(1.0 + odd * d);
        c = guard(1.0 + odd / c);
        let delta = d * c;
        h *= delta;
        if (delta - 1.0).abs() < EPS {
            break;
        }
    }
    h
}
422
/// Inverse of the standard normal CDF (the probit function), `Φ⁻¹(p)`.
///
/// Uses Peter Acklam's rational approximation, with relative error below about
/// `1.15e-9`. There is a central formula for `0.02425 <= p <= 0.97575` and tail
/// formulas outside that range. `p` is clamped to `[1e-12, 1 - 1e-12]`, so the
/// result is always finite (at most about `±7.03`).
pub fn probit(p: f64) -> f64 {
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    const P_LOW: f64 = 0.02425;

    let p = p.clamp(1e-12, 1.0 - 1e-12);

    // Lower-tail rational function in q = sqrt(-2 ln(prob)). It returns a negative quantile.
    let lower_tail = |prob: f64| {
        let q = (-2.0 * prob.ln()).sqrt();
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    if p < P_LOW {
        lower_tail(p)
    } else if p > 1.0 - P_LOW {
        -lower_tail(1.0 - p)
    } else {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    }
}
435
436/// Two-tailed p-value from t statistic with df degrees of freedom.
437pub fn p_value_from_t(t: f64, df: f64) -> f64 {
438    // CDF of t-distribution via regularized incomplete beta
439    let x = df / (df + t * t);
440    let p_one_tail = 0.5 * betainc(x, df / 2.0, 0.5);
441    (2.0 * p_one_tail).min(1.0)
442}
443
444/// p-value from chi-squared statistic with k degrees of freedom.
445pub fn p_value_from_chi2(chi2: f64, k: usize) -> f64 {
446    if chi2 <= 0.0 { return 1.0; }
447    1.0 - gammainc_lower(k as f64 / 2.0, chi2 / 2.0)
448}
449
// ============================================================
// PROBABILITY DISTRIBUTIONS
// ============================================================
453
454/// Normal (Gaussian) distribution.
455#[derive(Clone, Debug)]
456pub struct NormalDist {
457    pub mean: f64,
458    pub std_dev: f64,
459}
460
461impl NormalDist {
462    pub fn pdf(&self, x: f64) -> f64 {
463        let z = (x - self.mean) / self.std_dev;
464        (-0.5 * z * z).exp() / (self.std_dev * (2.0 * PI).sqrt())
465    }
466    pub fn cdf(&self, x: f64) -> f64 {
467        0.5 * (1.0 + erf((x - self.mean) / (self.std_dev * 2.0f64.sqrt())))
468    }
469    pub fn inv_cdf(&self, p: f64) -> f64 {
470        self.mean + self.std_dev * probit(p)
471    }
472    /// Box-Muller sampling. Returns two independent samples.
473    pub fn sample_pair(&self, rng: &mut impl Rng) -> (f64, f64) {
474        let u1 = rng.next_f64().max(1e-300);
475        let u2 = rng.next_f64();
476        let r = (-2.0 * u1.ln()).sqrt();
477        let theta = 2.0 * PI * u2;
478        let z0 = r * theta.cos();
479        let z1 = r * theta.sin();
480        (self.mean + self.std_dev * z0, self.mean + self.std_dev * z1)
481    }
482    pub fn sample(&self, rng: &mut impl Rng) -> f64 { self.sample_pair(rng).0 }
483}
484
485/// Continuous uniform distribution.
486#[derive(Clone, Debug)]
487pub struct UniformDist {
488    pub min: f64,
489    pub max: f64,
490}
491
492impl UniformDist {
493    pub fn pdf(&self, x: f64) -> f64 {
494        if x >= self.min && x <= self.max { 1.0 / (self.max - self.min) } else { 0.0 }
495    }
496    pub fn cdf(&self, x: f64) -> f64 {
497        ((x - self.min) / (self.max - self.min)).clamp(0.0, 1.0)
498    }
499    pub fn inv_cdf(&self, p: f64) -> f64 { self.min + p * (self.max - self.min) }
500    pub fn sample(&self, rng: &mut impl Rng) -> f64 { self.inv_cdf(rng.next_f64()) }
501}
502
503/// Exponential distribution.
504#[derive(Clone, Debug)]
505pub struct ExponentialDist {
506    pub lambda: f64,
507}
508
509impl ExponentialDist {
510    pub fn pdf(&self, x: f64) -> f64 {
511        if x < 0.0 { 0.0 } else { self.lambda * (-self.lambda * x).exp() }
512    }
513    pub fn cdf(&self, x: f64) -> f64 {
514        if x < 0.0 { 0.0 } else { 1.0 - (-self.lambda * x).exp() }
515    }
516    pub fn inv_cdf(&self, p: f64) -> f64 { -((1.0 - p).max(1e-300)).ln() / self.lambda }
517    pub fn sample(&self, rng: &mut impl Rng) -> f64 { self.inv_cdf(rng.next_f64()) }
518}
519
520/// Poisson distribution.
521#[derive(Clone, Debug)]
522pub struct PoissonDist {
523    pub lambda: f64,
524}
525
526impl PoissonDist {
527    pub fn pmf(&self, k: u64) -> f64 {
528        (-self.lambda).exp() * self.lambda.powi(k as i32) / gamma(k as f64 + 1.0)
529    }
530    pub fn cdf(&self, k: u64) -> f64 {
531        (0..=k).map(|i| self.pmf(i)).sum()
532    }
533    /// Knuth algorithm for Poisson sampling.
534    pub fn sample(&self, rng: &mut impl Rng) -> u64 {
535        let l = (-self.lambda).exp();
536        let mut k = 0u64;
537        let mut p = 1.0;
538        loop {
539            k += 1;
540            p *= rng.next_f64();
541            if p <= l { break; }
542        }
543        k - 1
544    }
545}
546
547/// Binomial distribution.
548#[derive(Clone, Debug)]
549pub struct BinomialDist {
550    pub n: u64,
551    pub p: f64,
552}
553
554impl BinomialDist {
555    pub fn pmf(&self, k: u64) -> f64 {
556        if k > self.n { return 0.0; }
557        let log_coeff = lgamma(self.n as f64 + 1.0)
558            - lgamma(k as f64 + 1.0)
559            - lgamma((self.n - k) as f64 + 1.0);
560        (log_coeff + k as f64 * self.p.ln() + (self.n - k) as f64 * (1.0 - self.p).ln()).exp()
561    }
562    pub fn cdf(&self, k: u64) -> f64 {
563        (0..=k).map(|i| self.pmf(i)).sum()
564    }
565    pub fn sample(&self, rng: &mut impl Rng) -> u64 {
566        (0..self.n).filter(|_| rng.next_f64() < self.p).count() as u64
567    }
568}
569
570/// Beta distribution (Johnk's method for sampling).
571#[derive(Clone, Debug)]
572pub struct BetaDist {
573    pub alpha: f64,
574    pub beta: f64,
575}
576
577impl BetaDist {
578    pub fn pdf(&self, x: f64) -> f64 {
579        if x <= 0.0 || x >= 1.0 { return 0.0; }
580        let lbeta = lgamma(self.alpha) + lgamma(self.beta) - lgamma(self.alpha + self.beta);
581        ((self.alpha - 1.0) * x.ln() + (self.beta - 1.0) * (1.0 - x).ln() - lbeta).exp()
582    }
583    pub fn cdf(&self, x: f64) -> f64 { betainc(x, self.alpha, self.beta) }
584    pub fn sample(&self, rng: &mut impl Rng) -> f64 {
585        // Johnk's method
586        loop {
587            let u = rng.next_f64();
588            let v = rng.next_f64();
589            let x = u.powf(1.0 / self.alpha);
590            let y = v.powf(1.0 / self.beta);
591            if x + y <= 1.0 { return x / (x + y); }
592        }
593    }
594}
595
596/// Gamma distribution (Marsaglia-Tsang method for alpha >= 1).
597#[derive(Clone, Debug)]
598pub struct GammaDist {
599    pub shape: f64,  // alpha / k
600    pub scale: f64,  // theta
601}
602
603impl GammaDist {
604    pub fn pdf(&self, x: f64) -> f64 {
605        if x <= 0.0 { return 0.0; }
606        let log_scale = self.scale.ln();
607        ((self.shape - 1.0) * x.ln() - x / self.scale - self.shape * log_scale - lgamma(self.shape)).exp()
608    }
609    pub fn cdf(&self, x: f64) -> f64 {
610        if x <= 0.0 { return 0.0; }
611        gammainc_lower(self.shape, x / self.scale)
612    }
613    pub fn sample(&self, rng: &mut impl Rng) -> f64 {
614        let alpha = self.shape;
615        let s = if alpha >= 1.0 {
616            // Marsaglia-Tsang
617            let d = alpha - 1.0 / 3.0;
618            let c = 1.0 / (9.0 * d).sqrt();
619            let norm = NormalDist { mean: 0.0, std_dev: 1.0 };
620            loop {
621                let x = norm.sample(rng);
622                let v = (1.0 + c * x).powi(3);
623                if v <= 0.0 { continue; }
624                let u = rng.next_f64();
625                if u < 1.0 - 0.0331 * (x * x).powi(2) { break d * v; }
626                if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) { break d * v; }
627            }
628        } else {
629            // alpha < 1: use alpha+1 and scale
630            let d = alpha + 1.0 - 1.0 / 3.0;
631            let c = 1.0 / (9.0 * d).sqrt();
632            let norm = NormalDist { mean: 0.0, std_dev: 1.0 };
633            let s_plus1 = loop {
634                let x = norm.sample(rng);
635                let v = (1.0 + c * x).powi(3);
636                if v <= 0.0 { continue; }
637                let u = rng.next_f64();
638                if u < 1.0 - 0.0331 * (x * x).powi(2) { break d * v; }
639                if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) { break d * v; }
640            };
641            s_plus1 * rng.next_f64().powf(1.0 / alpha)
642        };
643        s * self.scale
644    }
645}
646
647/// Log-Normal distribution.
648#[derive(Clone, Debug)]
649pub struct LogNormalDist {
650    pub mu: f64,
651    pub sigma: f64,
652}
653
654impl LogNormalDist {
655    pub fn pdf(&self, x: f64) -> f64 {
656        if x <= 0.0 { return 0.0; }
657        let z = (x.ln() - self.mu) / self.sigma;
658        (-0.5 * z * z).exp() / (x * self.sigma * (2.0 * PI).sqrt())
659    }
660    pub fn cdf(&self, x: f64) -> f64 {
661        if x <= 0.0 { return 0.0; }
662        0.5 * (1.0 + erf((x.ln() - self.mu) / (self.sigma * 2.0f64.sqrt())))
663    }
664    pub fn sample(&self, rng: &mut impl Rng) -> f64 {
665        let norm = NormalDist { mean: self.mu, std_dev: self.sigma };
666        norm.sample(rng).exp()
667    }
668}
669
670/// Weibull distribution.
671#[derive(Clone, Debug)]
672pub struct WeibullDist {
673    pub shape: f64,   // k
674    pub scale: f64,   // lambda
675}
676
677impl WeibullDist {
678    pub fn pdf(&self, x: f64) -> f64 {
679        if x < 0.0 { return 0.0; }
680        let k = self.shape; let l = self.scale;
681        (k / l) * (x / l).powf(k - 1.0) * (-(x / l).powf(k)).exp()
682    }
683    pub fn cdf(&self, x: f64) -> f64 {
684        if x < 0.0 { return 0.0; }
685        1.0 - (-(x / self.scale).powf(self.shape)).exp()
686    }
687    pub fn inv_cdf(&self, p: f64) -> f64 {
688        self.scale * (-(1.0 - p).ln()).powf(1.0 / self.shape)
689    }
690    pub fn sample(&self, rng: &mut impl Rng) -> f64 { self.inv_cdf(rng.next_f64()) }
691}
692
693/// Cauchy distribution.
694#[derive(Clone, Debug)]
695pub struct CauchyDist {
696    pub location: f64,
697    pub scale: f64,
698}
699
700impl CauchyDist {
701    pub fn pdf(&self, x: f64) -> f64 {
702        let z = (x - self.location) / self.scale;
703        1.0 / (PI * self.scale * (1.0 + z * z))
704    }
705    pub fn cdf(&self, x: f64) -> f64 {
706        0.5 + ((x - self.location) / self.scale).atan() / PI
707    }
708    pub fn inv_cdf(&self, p: f64) -> f64 {
709        self.location + self.scale * (PI * (p - 0.5)).tan()
710    }
711    pub fn sample(&self, rng: &mut impl Rng) -> f64 { self.inv_cdf(rng.next_f64()) }
712}
713
714/// Student's t-distribution.
715#[derive(Clone, Debug)]
716pub struct StudentTDist {
717    pub degrees_of_freedom: f64,
718}
719
720impl StudentTDist {
721    pub fn pdf(&self, t: f64) -> f64 {
722        let nu = self.degrees_of_freedom;
723        let coeff = gamma((nu + 1.0) / 2.0) / (gamma(nu / 2.0) * (nu * PI).sqrt());
724        coeff * (1.0 + t * t / nu).powf(-(nu + 1.0) / 2.0)
725    }
726    pub fn cdf(&self, t: f64) -> f64 {
727        let nu = self.degrees_of_freedom;
728        let x = nu / (nu + t * t);
729        let ib = betainc(x, nu / 2.0, 0.5) / 2.0;
730        if t > 0.0 { 1.0 - ib } else { ib }
731    }
732    pub fn sample(&self, rng: &mut impl Rng) -> f64 {
733        let z = NormalDist { mean: 0.0, std_dev: 1.0 }.sample(rng);
734        let chi2 = GammaDist { shape: self.degrees_of_freedom / 2.0, scale: 2.0 }.sample(rng);
735        z / (chi2 / self.degrees_of_freedom).sqrt()
736    }
737}
738
739/// Chi-squared distribution.
740#[derive(Clone, Debug)]
741pub struct ChiSquaredDist {
742    pub k: f64,
743}
744
745impl ChiSquaredDist {
746    pub fn pdf(&self, x: f64) -> f64 {
747        GammaDist { shape: self.k / 2.0, scale: 2.0 }.pdf(x)
748    }
749    pub fn cdf(&self, x: f64) -> f64 {
750        if x <= 0.0 { return 0.0; }
751        gammainc_lower(self.k / 2.0, x / 2.0)
752    }
753    pub fn sample(&self, rng: &mut impl Rng) -> f64 {
754        GammaDist { shape: self.k / 2.0, scale: 2.0 }.sample(rng)
755    }
756}
757
// ============================================================
// HYPOTHESIS TESTING
// ============================================================
761
762/// One-sample t-test against mu0.
763/// Returns (t-statistic, two-tailed p-value).
764pub fn t_test_one_sample(data: &[f64], mu0: f64) -> (f64, f64) {
765    let n = data.len() as f64;
766    if n < 2.0 { return (0.0, 1.0); }
767    let xbar = mean(data);
768    let s = std_dev(data);
769    if s == 0.0 { return (0.0, 1.0); }
770    let t = (xbar - mu0) / (s / n.sqrt());
771    let p = p_value_from_t(t, n - 1.0);
772    (t, p)
773}
774
775/// Welch's two-sample t-test.
776/// Returns (t-statistic, two-tailed p-value).
777pub fn t_test_two_sample(a: &[f64], b: &[f64]) -> (f64, f64) {
778    let na = a.len() as f64;
779    let nb = b.len() as f64;
780    if na < 2.0 || nb < 2.0 { return (0.0, 1.0); }
781    let ma = mean(a);
782    let mb = mean(b);
783    let sa2 = variance(a);
784    let sb2 = variance(b);
785    let se = (sa2 / na + sb2 / nb).sqrt();
786    if se == 0.0 { return (0.0, 1.0); }
787    let t = (ma - mb) / se;
788    // Welch-Satterthwaite degrees of freedom
789    let df = (sa2 / na + sb2 / nb).powi(2)
790        / ((sa2 / na).powi(2) / (na - 1.0) + (sb2 / nb).powi(2) / (nb - 1.0));
791    let p = p_value_from_t(t, df);
792    (t, p)
793}
794
795/// Chi-squared goodness-of-fit test.
796/// Returns (chi2-statistic, p-value).
797pub fn chi_squared_test(observed: &[f64], expected: &[f64]) -> (f64, f64) {
798    let chi2: f64 = observed
799        .iter()
800        .zip(expected.iter())
801        .map(|(o, e)| if *e > 0.0 { (o - e).powi(2) / e } else { 0.0 })
802        .sum();
803    let df = (observed.len() - 1).max(1);
804    let p = p_value_from_chi2(chi2, df);
805    (chi2, p)
806}
807
/// One-sample Kolmogorov–Smirnov test of `data` against a continuous CDF.
///
/// Returns `(D, p)`. `D = sup |F_n(x) - F(x)|` is measured on both sides of each
/// empirical step. `p` is the Kolmogorov tail probability
/// `Q(λ) = 2 Σ_{k≥1} (-1)^{k-1} e^{-2k²λ²}`, evaluated at Stephens'
/// small-sample-adjusted `λ = (√n + 0.12 + 0.11/√n) · D`. The series starts with a
/// *positive* term; with the sign inverted, every p-value would clamp to 0.
///
/// Empty data returns `(0.0, 1.0)`.
///
/// # Panics
/// May panic if `data` contains NaN (sort comparator unwraps `partial_cmp`).
pub fn ks_test(data: &[f64], cdf: impl Fn(f64) -> f64) -> (f64, f64) {
    // Kolmogorov survival function Q(λ). If the alternating series has not converged
    // after 100 terms (only for very small λ, where Q ≈ 1), 1.0 is returned.
    fn kolmogorov_q(lambda: f64) -> f64 {
        if lambda <= 0.0 {
            return 1.0;
        }
        let a2 = -2.0 * lambda * lambda;
        let mut factor = 2.0;
        let mut sum = 0.0;
        let mut prev_term = 0.0;
        for k in 1..=100 {
            let term = factor * (a2 * (k * k) as f64).exp();
            sum += term;
            if term.abs() <= 1e-3 * prev_term || term.abs() <= 1e-8 * sum {
                return sum.clamp(0.0, 1.0);
            }
            factor = -factor;
            prev_term = term.abs();
        }
        1.0
    }

    let n = data.len();
    if n == 0 {
        return (0.0, 1.0);
    }
    let mut sorted = data.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let nf = n as f64;
    let d = sorted.iter().enumerate().fold(0.0f64, |d, (i, &x)| {
        let theoretical = cdf(x);
        let upper = ((i + 1) as f64 / nf - theoretical).abs();
        let lower = (i as f64 / nf - theoretical).abs();
        d.max(upper).max(lower)
    });
    let sqrt_n = nf.sqrt();
    let lambda = (sqrt_n + 0.12 + 0.11 / sqrt_n) * d;
    (d, kolmogorov_q(lambda))
}
837
838/// Mann-Whitney U test (non-parametric, two-sample).
839/// Returns (U-statistic, approximate two-tailed p-value).
840pub fn mann_whitney_u(a: &[f64], b: &[f64]) -> (f64, f64) {
841    let na = a.len();
842    let nb = b.len();
843    let mut u = 0.0f64;
844    for &ai in a {
845        for &bi in b {
846            if ai > bi { u += 1.0; }
847            else if ai == bi { u += 0.5; }
848        }
849    }
850    let mean_u = na as f64 * nb as f64 / 2.0;
851    let std_u = ((na as f64 * nb as f64 * (na + nb + 1) as f64) / 12.0).sqrt();
852    if std_u == 0.0 { return (u, 1.0); }
853    let z = (u - mean_u) / std_u;
854    let norm = NormalDist { mean: 0.0, std_dev: 1.0 };
855    let p = 2.0 * (1.0 - norm.cdf(z.abs()));
856    (u, p)
857}
858
859/// Shapiro-Wilk test statistic W for normality.
860/// Uses first 20 a-coefficients approximation.
861pub fn shapiro_wilk_stat(data: &[f64]) -> f64 {
862    let n = data.len();
863    if n < 3 { return 1.0; }
864    let mut sorted = data.to_vec();
865    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
866    let m = mean(&sorted);
867    let ss: f64 = sorted.iter().map(|x| (x - m).powi(2)).sum();
868    if ss == 0.0 { return 1.0; }
869    // Approximate a coefficients using expected normal order statistics
870    let norm = NormalDist { mean: 0.0, std_dev: 1.0 };
871    let half = n / 2;
872    let mut b = 0.0f64;
873    for i in 0..half {
874        let expected_i = norm.inv_cdf((i as f64 + 0.625) / (n as f64 + 0.25));
875        let expected_n_i = norm.inv_cdf((n as f64 - 1.0 - i as f64 + 0.625) / (n as f64 + 0.25));
876        let a_i = expected_n_i - expected_i;
877        b += a_i * (sorted[n - 1 - i] - sorted[i]);
878    }
879    b * b / ss
880}
881
// ============================================================
// REGRESSION
// ============================================================
885
886/// Simple linear regression: y = slope * x + intercept.
887pub fn linear_regression(x: &[f64], y: &[f64]) -> (f64, f64) {
888    let n = x.len().min(y.len()) as f64;
889    if n < 2.0 { return (0.0, 0.0); }
890    let mx = mean(x);
891    let my = mean(y);
892    let ss_xx: f64 = x.iter().map(|xi| (xi - mx).powi(2)).sum();
893    let ss_xy: f64 = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - mx) * (yi - my)).sum();
894    if ss_xx == 0.0 { return (0.0, my); }
895    let slope = ss_xy / ss_xx;
896    let intercept = my - slope * mx;
897    (slope, intercept)
898}
899
900/// Polynomial regression of given degree. Returns coefficients [a0, a1, ..., a_deg].
901pub fn polynomial_regression(x: &[f64], y: &[f64], degree: usize) -> Vec<f64> {
902    let n = x.len().min(y.len());
903    let d = degree + 1;
904    // Build Vandermonde matrix X
905    let mut xmat = vec![vec![0.0f64; d]; n];
906    for i in 0..n {
907        for j in 0..d {
908            xmat[i][j] = x[i].powi(j as i32);
909        }
910    }
911    // X^T X
912    let mut xtx = vec![vec![0.0f64; d]; d];
913    for r in 0..d {
914        for c in 0..d {
915            for i in 0..n { xtx[r][c] += xmat[i][r] * xmat[i][c]; }
916        }
917    }
918    // X^T y
919    let mut xty = vec![0.0f64; d];
920    for r in 0..d {
921        for i in 0..n { xty[r] += xmat[i][r] * y[i]; }
922    }
923    // Solve xtx * coeffs = xty via Gaussian elimination
924    solve_system(&mut xtx, &mut xty).unwrap_or_else(|| vec![0.0; d])
925}
926
/// Solves the square system `a · x = b` by Gauss–Jordan elimination with partial
/// pivoting.
///
/// Both `a` and `b` are overwritten; on success `b` holds the solution, and a copy
/// is returned. Returns `None` when the best available pivot has magnitude below
/// `1e-12`, in which case the matrix is treated as singular.
fn solve_system(a: &mut Vec<Vec<f64>>, b: &mut Vec<f64>) -> Option<Vec<f64>> {
    let size = b.len();
    for col in 0..size {
        // Pick the row at or below `col` with the largest entry in this column.
        // On a tie the first such row wins.
        let (pivot_row, pivot_mag) =
            (col + 1..size).fold((col, a[col][col].abs()), |best, row| {
                let mag = a[row][col].abs();
                if mag > best.1 {
                    (row, mag)
                } else {
                    best
                }
            });
        if pivot_mag < 1e-12 {
            return None;
        }
        a.swap(col, pivot_row);
        b.swap(col, pivot_row);

        // Scale the pivot row so that the pivot becomes 1.
        let pivot = a[col][col];
        for v in &mut a[col][col..size] {
            *v /= pivot;
        }
        b[col] /= pivot;

        // Eliminate this column from every other row.
        let pivot_vals = a[col].clone();
        let pivot_rhs = b[col];
        for (row_idx, row) in a.iter_mut().enumerate().take(size) {
            if row_idx == col {
                continue;
            }
            let factor = row[col];
            for j in col..size {
                row[j] -= factor * pivot_vals[j];
            }
            b[row_idx] -= factor * pivot_rhs;
        }
    }
    Some(b.to_vec())
}
951
952/// Multiple linear regression (OLS). X is n_samples × n_features.
953/// Returns coefficient vector (including intercept as first element).
954pub fn multiple_linear_regression(x: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
955    let n = x.len().min(y.len());
956    if n == 0 { return vec![]; }
957    let p = x[0].len() + 1; // +1 for intercept
958    // Build design matrix with intercept column
959    let mut xmat = vec![vec![0.0f64; p]; n];
960    for i in 0..n {
961        xmat[i][0] = 1.0;
962        for j in 1..p { xmat[i][j] = x[i][j - 1]; }
963    }
964    // X^T X
965    let mut xtx = vec![vec![0.0f64; p]; p];
966    for r in 0..p {
967        for c in 0..p {
968            for i in 0..n { xtx[r][c] += xmat[i][r] * xmat[i][c]; }
969        }
970    }
971    // X^T y
972    let mut xty = vec![0.0f64; p];
973    for r in 0..p {
974        for i in 0..n { xty[r] += xmat[i][r] * y[i]; }
975    }
976    solve_system(&mut xtx, &mut xty).unwrap_or_else(|| vec![0.0; p])
977}
978
/// R-squared coefficient of determination: `1 - SS_res / SS_tot`.
///
/// Only the first `min(y_true.len(), y_pred.len())` pairs are used, for the
/// mean of `y_true` as well as for both sums of squares.
///
/// Edge cases:
/// - no pairs → `0.0`;
/// - constant `y_true` (`SS_tot == 0`), where R² is undefined → `1.0` if the
///   predictions are exact, otherwise `0.0`.
pub fn r_squared(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len().min(y_pred.len());
    if n == 0 {
        return 0.0;
    }
    // Restrict both inputs to the common prefix so every statistic sees the same samples.
    let (y_true, y_pred) = (&y_true[..n], &y_pred[..n]);
    let mean_true = y_true.iter().sum::<f64>() / n as f64;
    let ss_res: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(y, yh)| (y - yh).powi(2))
        .sum();
    let ss_tot: f64 = y_true.iter().map(|y| (y - mean_true).powi(2)).sum();
    if ss_tot == 0.0 {
        // Do not report a perfect fit unless the predictions really are perfect.
        return if ss_res == 0.0 { 1.0 } else { 0.0 };
    }
    1.0 - ss_res / ss_tot
}
989
990/// Ridge regression (L2 regularized OLS). Returns coefficients.
991pub fn ridge_regression(x: &[Vec<f64>], y: &[f64], lambda: f64) -> Vec<f64> {
992    let n = x.len().min(y.len());
993    if n == 0 { return vec![]; }
994    let p = x[0].len() + 1;
995    let mut xmat = vec![vec![0.0f64; p]; n];
996    for i in 0..n {
997        xmat[i][0] = 1.0;
998        for j in 1..p { xmat[i][j] = x[i][j - 1]; }
999    }
1000    let mut xtx = vec![vec![0.0f64; p]; p];
1001    for r in 0..p {
1002        for c in 0..p {
1003            for i in 0..n { xtx[r][c] += xmat[i][r] * xmat[i][c]; }
1004        }
1005    }
1006    // Add lambda * I (skip intercept at index 0)
1007    for j in 1..p { xtx[j][j] += lambda; }
1008    let mut xty = vec![0.0f64; p];
1009    for r in 0..p {
1010        for i in 0..n { xty[r] += xmat[i][r] * y[i]; }
1011    }
1012    solve_system(&mut xtx, &mut xty).unwrap_or_else(|| vec![0.0; p])
1013}
1014
/// Logistic regression trained by full-batch gradient descent.
///
/// `x` is n_samples × n_features and `y` holds the boolean labels; only the
/// first `min(x.len(), y.len())` samples are used. Weights start at zero and
/// each of the `epochs` steps subtracts `lr` times the mean log-loss gradient.
/// Returns `n_features + 1` weights with the intercept first, or an empty
/// vector when there are no samples.
pub fn logistic_regression(x: &[Vec<f64>], y: &[bool], lr: f64, epochs: usize) -> Vec<f64> {
    let n = x.len().min(y.len());
    if n == 0 {
        return vec![];
    }
    let n_features = x[0].len();
    let mut weights = vec![0.0f64; n_features + 1];
    let scale = n as f64;

    for _ in 0..epochs {
        let mut grad = vec![0.0f64; n_features + 1];
        for (row, &label) in x.iter().zip(y) {
            let feats = &row[..n_features];
            // Linear score: bias plus the weighted features, accumulated left to right.
            let z = feats
                .iter()
                .zip(&weights[1..])
                .fold(weights[0], |acc, (xv, wv)| acc + wv * xv);
            let pred = 1.0 / (1.0 + (-z).exp());
            let target = if label { 1.0 } else { 0.0 };
            let err = pred - target;
            grad[0] += err;
            for (g, xv) in grad[1..].iter_mut().zip(feats) {
                *g += err * xv;
            }
        }
        for (wj, gj) in weights.iter_mut().zip(&grad) {
            *wj -= lr * gj / scale;
        }
    }
    weights
}
1039
1040// ============================================================
1041// BAYESIAN INFERENCE
1042// ============================================================
1043
/// Beta-Bernoulli conjugate model: a Beta(`alpha`, `beta`) distribution over
/// the success probability of a Bernoulli trial.
///
/// The type is `Copy`, so a prior passed by value to [`update_beta_bernoulli`]
/// remains usable afterwards.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BetaBernoulli {
    /// First Beta shape parameter; observed successes are added to it.
    pub alpha: f64,
    /// Second Beta shape parameter; observed failures are added to it.
    pub beta: f64,
}

/// Update a Beta prior with new Bernoulli observations.
///
/// Returns `Beta(alpha + successes, beta + failures)`.
pub fn update_beta_bernoulli(prior: BetaBernoulli, successes: u32, failures: u32) -> BetaBernoulli {
    BetaBernoulli {
        alpha: prior.alpha + f64::from(successes),
        beta: prior.beta + f64::from(failures),
    }
}

/// Posterior mean `alpha / (alpha + beta)` of a Beta-Bernoulli model.
///
/// Not finite when `alpha + beta == 0`.
pub fn posterior_mean(dist: &BetaBernoulli) -> f64 {
    dist.alpha / (dist.alpha + dist.beta)
}
1063
1064/// Equal-tailed credible interval for Beta distribution.
1065pub fn credible_interval(dist: &BetaBernoulli, level: f64) -> (f64, f64) {
1066    let tail = (1.0 - level) / 2.0;
1067    let beta = BetaDist { alpha: dist.alpha, beta: dist.beta };
1068    // Numerical inversion of beta CDF
1069    let inv_beta_cdf = |p: f64| -> f64 {
1070        let mut lo = 0.0f64;
1071        let mut hi = 1.0f64;
1072        for _ in 0..100 {
1073            let mid = (lo + hi) * 0.5;
1074            if beta.cdf(mid) < p { lo = mid; } else { hi = mid; }
1075        }
1076        (lo + hi) * 0.5
1077    };
1078    (inv_beta_cdf(tail), inv_beta_cdf(1.0 - tail))
1079}
1080
/// Gaussian-Gaussian conjugate model (known variance).
///
/// Prior `N(prior_mean, prior_variance)` on an unknown mean, with observations
/// modelled as `N(mean, likelihood_variance)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GaussianGaussian {
    /// Mean of the prior on the unknown mean.
    pub prior_mean: f64,
    /// Variance of the prior on the unknown mean.
    pub prior_variance: f64,
    /// Known variance of each observation.
    pub likelihood_variance: f64,
}

impl GaussianGaussian {
    /// Posterior given `n` observations whose sample mean is `sample_mean`.
    ///
    /// Returns `(posterior_mean, posterior_variance)` where
    /// `posterior_variance = 1 / (1/prior_variance + n/likelihood_variance)` and
    /// `posterior_mean = posterior_variance * (prior_mean/prior_variance
    /// + n * sample_mean/likelihood_variance)`. With `n == 0` the prior is returned.
    pub fn update(&self, sample_mean: f64, n: usize) -> (f64, f64) {
        let n = n as f64;
        let lv = self.likelihood_variance;
        let pv = self.prior_variance;
        // Precisions add: prior precision plus n observation precisions.
        let post_var = 1.0 / (1.0 / pv + n / lv);
        // Precision-weighted combination of the prior mean and the data mean.
        let post_mean = post_var * (self.prior_mean / pv + n * sample_mean / lv);
        (post_mean, post_var)
    }
}
1100
/// Bayesian Information Criterion: `k · ln(n) - 2 · ln L`.
///
/// `n_params` is `k`, `n_samples` is `n`; lower values indicate a better
/// trade-off between fit and complexity.
pub fn bayesian_information_criterion(log_likelihood: f64, n_params: usize, n_samples: usize) -> f64 {
    let complexity_penalty = n_params as f64 * (n_samples as f64).ln();
    complexity_penalty - 2.0 * log_likelihood
}
1105
/// Akaike Information Criterion: `2k - 2 · ln L`, with `k = n_params`.
pub fn akaike_information_criterion(log_likelihood: f64, n_params: usize) -> f64 {
    let k = n_params as f64;
    2.0 * k - 2.0 * log_likelihood
}
1110
1111// ============================================================
1112// INFORMATION THEORY
1113// ============================================================
1114
/// Shannon entropy `-Σ p · ln p`, in nats (natural logarithm).
///
/// Entries that are not strictly positive contribute nothing (the usual
/// `0 · ln 0 = 0` convention).
pub fn entropy(probs: &[f64]) -> f64 {
    let mut h = 0.0;
    for &p in probs {
        if p > 0.0 {
            h -= p * p.ln();
        }
    }
    h
}
1122
/// Cross-entropy `H(P, Q) = -Σ_x P(x) · ln Q(x)`, in nats.
///
/// Outcomes with `P(x) <= 0` contribute nothing. If some outcome has
/// `P(x) > 0` but `Q(x) <= 0`, the cross-entropy is infinite and
/// `f64::INFINITY` is returned rather than silently dropping that term.
/// Pairs beyond the shorter of the two slices are ignored.
pub fn cross_entropy(p: &[f64], q: &[f64]) -> f64 {
    p.iter()
        .zip(q)
        .filter(|&(&pi, _)| pi > 0.0)
        .map(|(&pi, &qi)| {
            if qi <= 0.0 {
                // Q assigns no mass where P does: log Q(x) = -inf.
                f64::INFINITY
            } else {
                -pi * qi.ln()
            }
        })
        .sum()
}
1131
/// KL divergence `D_KL(P || Q) = Σ_x P(x) · ln(P(x) / Q(x))`, in nats.
///
/// Outcomes with `P(x) <= 0` contribute nothing (`0 · ln 0 = 0`). If some
/// outcome has `P(x) > 0` but `Q(x) <= 0`, P is not absolutely continuous
/// with respect to Q and the divergence is infinite: `f64::INFINITY` is
/// returned instead of silently skipping the term (which previously made,
/// e.g., two disjoint distributions look identical). Pairs beyond the shorter
/// slice are ignored.
pub fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
    p.iter()
        .zip(q)
        .filter(|&(&pi, _)| pi > 0.0)
        .map(|(&pi, &qi)| {
            if qi <= 0.0 {
                f64::INFINITY
            } else {
                pi * (pi / qi).ln()
            }
        })
        .sum()
}
1140
/// Mutual information `I(X;Y) = Σ p(x,y) · ln(p(x,y) / (p(x)·p(y)))`, in nats,
/// from a joint probability matrix (rows index X, columns index Y).
///
/// Marginals are the row and column sums of `joint`; the column count is taken
/// from the first row. Cells where the joint or either marginal is not strictly
/// positive contribute nothing. An empty matrix yields `0.0`.
pub fn mutual_information(joint: &[Vec<f64>]) -> f64 {
    if joint.is_empty() {
        return 0.0;
    }
    let n_cols = joint[0].len();

    // p(x): full row sums.
    let row_marginal: Vec<f64> = joint.iter().map(|row| row.iter().sum()).collect();

    // p(y): column sums over the first `n_cols` entries of each row.
    let mut col_marginal = vec![0.0f64; n_cols];
    for row in joint {
        for (acc, &v) in col_marginal.iter_mut().zip(&row[..n_cols]) {
            *acc += v;
        }
    }

    let mut mi = 0.0;
    for (row, &px) in joint.iter().zip(&row_marginal) {
        for (&pij, &py) in row[..n_cols].iter().zip(&col_marginal) {
            if pij > 0.0 && px > 0.0 && py > 0.0 {
                mi += pij * (pij / (px * py)).ln();
            }
        }
    }
    mi
}
1159
/// Jensen-Shannon divergence `½·KL(P || M) + ½·KL(Q || M)` with `M = (P + Q) / 2`.
///
/// Symmetric, and bounded by `[0, ln 2]` for probability vectors. Both
/// Kullback–Leibler terms are accumulated in a single pass, skipping entries
/// whose numerator mass or mixture mass is not strictly positive. Pairs beyond
/// the shorter slice are ignored.
pub fn jensen_shannon_divergence(p: &[f64], q: &[f64]) -> f64 {
    let mut kl_pm = 0.0f64;
    let mut kl_qm = 0.0f64;
    for (&pi, &qi) in p.iter().zip(q) {
        let mi = (pi + qi) * 0.5;
        if pi > 0.0 && mi > 0.0 {
            kl_pm += pi * (pi / mi).ln();
        }
        if qi > 0.0 && mi > 0.0 {
            kl_qm += qi * (qi / mi).ln();
        }
    }
    0.5 * kl_pm + 0.5 * kl_qm
}
1165
1166// ============================================================
1167// TESTS
1168// ============================================================
1169
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`,
    /// reporting both values (at the caller's location) on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} within {} but got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_mean() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_close(mean(&data), 3.0, 1e-10);
    }

    #[test]
    fn test_variance() {
        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        assert_close(variance(&data), 4.571428571428571, 1e-8);
    }

    #[test]
    fn test_std_dev() {
        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        assert_close(std_dev(&data), 2.138, 0.001);
    }

    #[test]
    fn test_median_odd() {
        let mut data = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        assert_eq!(median(&mut data), 3.0);
    }

    #[test]
    fn test_median_even() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        assert_close(median(&mut data), 2.5, 1e-10);
    }

    #[test]
    fn test_percentile() {
        let mut data: Vec<f64> = (1..=100).map(|x| x as f64).collect();
        assert_close(percentile(&mut data, 50.0), 50.5, 0.5);
    }

    #[test]
    fn test_pearson_r_perfect() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y: Vec<f64> = x.iter().map(|xi| 2.0 * xi + 1.0).collect();
        assert_close(pearson_r(&x, &y), 1.0, 1e-10);
    }

    #[test]
    fn test_spearman_rho() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
        assert_close(spearman_rho(&x, &y), -1.0, 1e-10);
    }

    #[test]
    fn test_normal_dist_cdf() {
        let n = NormalDist { mean: 0.0, std_dev: 1.0 };
        assert_close(n.cdf(0.0), 0.5, 1e-6);
        assert_close(n.cdf(1.96), 0.975, 0.001);
    }

    #[test]
    fn test_normal_dist_sample() {
        let n = NormalDist { mean: 5.0, std_dev: 2.0 };
        let mut rng = Xorshift64::new(42);
        let samples: Vec<f64> = (0..10000).map(|_| n.sample(&mut rng)).collect();
        assert_close(mean(&samples), 5.0, 0.1);
    }

    #[test]
    fn test_exponential_sample() {
        let e = ExponentialDist { lambda: 2.0 };
        let mut rng = Xorshift64::new(42);
        let samples: Vec<f64> = (0..10000).map(|_| e.sample(&mut rng)).collect();
        assert_close(mean(&samples), 0.5, 0.05);
    }

    #[test]
    fn test_poisson_sample() {
        let p = PoissonDist { lambda: 3.0 };
        let mut rng = Xorshift64::new(42);
        let samples: Vec<f64> = (0..10000).map(|_| p.sample(&mut rng) as f64).collect();
        assert_close(mean(&samples), 3.0, 0.1);
    }

    #[test]
    fn test_gamma_sample() {
        let g = GammaDist { shape: 2.0, scale: 3.0 };
        let mut rng = Xorshift64::new(99);
        let samples: Vec<f64> = (0..10000).map(|_| g.sample(&mut rng)).collect();
        // Expected mean = shape * scale = 6
        assert_close(mean(&samples), 6.0, 0.2);
    }

    #[test]
    fn test_t_test_one_sample() {
        let data = vec![10.0, 11.0, 9.5, 10.5, 10.2, 9.8, 10.1, 9.9, 10.3, 10.4];
        let (_t, p) = t_test_one_sample(&data, 10.0);
        assert!(p > 0.05, "should not reject null at 10.0; p={}", p);
    }

    #[test]
    fn test_t_test_two_sample() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![6.0, 7.0, 8.0, 9.0, 10.0];
        let (_t, p) = t_test_two_sample(&a, &b);
        assert!(p < 0.05, "should reject null; p={}", p);
    }

    #[test]
    fn test_chi_squared_test() {
        let obs = vec![10.0, 20.0, 30.0];
        let exp = vec![10.0, 20.0, 30.0];
        let (chi2, p) = chi_squared_test(&obs, &exp);
        assert_close(chi2, 0.0, 1e-10);
        assert!(p > 0.9, "p={}", p);
    }

    #[test]
    fn test_linear_regression() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y: Vec<f64> = x.iter().map(|xi| 3.0 * xi + 1.0).collect();
        let (slope, intercept) = linear_regression(&x, &y);
        assert_close(slope, 3.0, 1e-10);
        assert_close(intercept, 1.0, 1e-10);
    }

    #[test]
    fn test_polynomial_regression() {
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y: Vec<f64> = x.iter().map(|xi| xi * xi + 2.0 * xi + 1.0).collect();
        let coeffs = polynomial_regression(&x, &y, 2);
        assert_eq!(coeffs.len(), 3);
        assert_close(coeffs[0], 1.0, 1e-6);
        assert_close(coeffs[1], 2.0, 1e-6);
        assert_close(coeffs[2], 1.0, 1e-6);
    }

    #[test]
    fn test_r_squared() {
        let y_true = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_pred = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_close(r_squared(&y_true, &y_pred), 1.0, 1e-10);
    }

    #[test]
    fn test_logistic_regression_separable() {
        let x = vec![vec![-2.0], vec![-1.0], vec![1.0], vec![2.0]];
        let y = vec![false, false, true, true];
        let w = logistic_regression(&x, &y, 0.5, 500);
        let sigmoid = |z: f64| 1.0 / (1.0 + (-z).exp());
        let pred_neg = sigmoid(w[0] + w[1] * (-2.0));
        let pred_pos = sigmoid(w[0] + w[1] * 2.0);
        assert!(pred_neg < 0.5, "negative class should have prob < 0.5");
        assert!(pred_pos > 0.5, "positive class should have prob > 0.5");
    }

    #[test]
    fn test_bayesian_update() {
        let prior = BetaBernoulli { alpha: 1.0, beta: 1.0 };
        let posterior = update_beta_bernoulli(prior, 6, 4);
        assert_close(posterior.alpha, 7.0, 1e-10);
        assert_close(posterior.beta, 5.0, 1e-10);
        assert_close(posterior_mean(&posterior), 7.0 / 12.0, 1e-10);
    }

    #[test]
    fn test_entropy() {
        let uniform = vec![0.25, 0.25, 0.25, 0.25];
        assert_close(entropy(&uniform), (4.0f64).ln(), 1e-10);
    }

    #[test]
    fn test_kl_divergence() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];
        assert_close(kl_divergence(&p, &q), 0.0, 1e-10);
    }

    #[test]
    fn test_jsd() {
        let p = vec![1.0, 0.0];
        let q = vec![0.0, 1.0];
        assert_close(jensen_shannon_divergence(&p, &q), 2.0f64.ln(), 1e-10);
    }

    #[test]
    fn test_pcg32() {
        let mut rng = Pcg32::new(42, 1);
        let v: Vec<f64> = (0..1000).map(|_| rng.next_f64()).collect();
        assert_close(mean(&v), 0.5, 0.05);
    }

    #[test]
    fn test_splitmix64() {
        let mut rng = SplitMix64::new(12345);
        let v: Vec<f64> = (0..1000).map(|_| rng.next_f64()).collect();
        assert_close(mean(&v), 0.5, 0.05);
    }

    #[test]
    fn test_shuffle() {
        let mut data = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let mut rng = Xorshift64::new(7);
        shuffle(&mut data, &mut rng);
        // The order may or may not change, but the multiset of elements must not.
        let mut sorted = data.clone();
        sorted.sort();
        assert_eq!(sorted, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_weighted_sample() {
        let weights = vec![0.0, 1.0, 0.0]; // must pick index 1
        let mut rng = Xorshift64::new(42);
        let idx = weighted_sample(&weights, &mut rng);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_sample_without_replacement() {
        let mut rng = Xorshift64::new(42);
        let sample = sample_without_replacement(100, 10, &mut rng);
        assert_eq!(sample.len(), 10);
        // All unique
        let mut s = sample.clone();
        s.sort();
        s.dedup();
        assert_eq!(s.len(), 10);
    }
}