Skip to main content

cjc_runtime/
distributions.rs

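//! Probability distributions — CDF, PDF/PMF and PPF functions for the Normal,
//! Student's t, Chi-squared, F, Binomial, Poisson, Beta, Gamma, Exponential
//! and Weibull distributions, plus RNG-driven samplers and space-filling
//! designs (Latin hypercube, low-discrepancy sequences).
//!
//! # Determinism Contract
//! The CDF/PDF/PPF functions are pure math — no randomness, no iteration-order
//! dependency; approximations use deterministic, fixed-sequence algorithms.
//! The sampling functions consume draws from a caller-supplied RNG (or a seed)
//! in a fixed order, so the same RNG state yields bit-identical output.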
7
8use std::f64::consts::PI;
9
10// ---------------------------------------------------------------------------
11// Helper: ln_gamma (Lanczos approximation)
12// ---------------------------------------------------------------------------
13
/// Natural logarithm of the absolute value of the gamma function, `ln|Γ(x)|`.
///
/// Uses the Lanczos approximation (g = 7, n = 9 coefficients) for `x >= 0.5`
/// and the reflection formula `Γ(x)·Γ(1 - x) = π / sin(πx)` for `x < 0.5`.
///
/// Returns `+∞` at the poles of Γ (zero and the negative integers). For
/// negative non-integer `x`, where `Γ(x)` may be negative, the logarithm of
/// its magnitude is returned (the same convention as C's `lgamma`).
///
/// # Determinism Contract
/// Pure math with a fixed coefficient sequence — same input ⇒ same output.
pub fn ln_gamma(x: f64) -> f64 {
    const G: f64 = 7.0;
    const COEFF: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 && x == x.floor() {
        // Poles of Γ at 0, -1, -2, ...
        return f64::INFINITY;
    }
    if x < 0.5 {
        // Reflection: ln|Γ(x)| = ln(π / |sin(πx)|) - ln|Γ(1 - x)|.
        // The abs() matters for negative x, where sin(πx) can be negative.
        return (PI / (PI * x).sin().abs()).ln() - ln_gamma(1.0 - x);
    }

    let xx = x - 1.0;
    let mut sum = COEFF[0];
    for (i, &c) in COEFF.iter().enumerate().skip(1) {
        sum += c / (xx + i as f64);
    }
    let t = xx + G + 0.5;
    0.5 * (2.0 * PI).ln() + t.ln() * (xx + 0.5) - t + sum.ln()
}
46
47// ---------------------------------------------------------------------------
48// Helper: regularized incomplete beta function (Lentz continued fraction)
49// ---------------------------------------------------------------------------
50
51/// Regularized incomplete beta function I_x(a, b).
52/// Uses continued fraction (Lentz's method) for numerical stability.
53fn regularized_incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
54    if x <= 0.0 { return 0.0; }
55    if x >= 1.0 { return 1.0; }
56
57    // Use symmetry if x > (a+1)/(a+b+2) for better convergence
58    if x > (a + 1.0) / (a + b + 2.0) {
59        return 1.0 - regularized_incomplete_beta(b, a, 1.0 - x);
60    }
61
62    let lbeta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
63    let front = (x.ln() * a + (1.0 - x).ln() * b - lbeta).exp() / a;
64
65    // Lentz continued fraction
66    let eps = 1e-14;
67    let tiny = 1e-30;
68    let max_iter = 200;
69
70    let mut f = 1.0;
71    let mut c = 1.0;
72    let mut d = 1.0 - (a + b) * x / (a + 1.0);
73    if d.abs() < tiny { d = tiny; }
74    d = 1.0 / d;
75    f = d;
76
77    for m in 1..=max_iter {
78        let m_f = m as f64;
79        // Even step
80        let num_even = m_f * (b - m_f) * x / ((a + 2.0 * m_f - 1.0) * (a + 2.0 * m_f));
81        d = 1.0 + num_even * d;
82        if d.abs() < tiny { d = tiny; }
83        c = 1.0 + num_even / c;
84        if c.abs() < tiny { c = tiny; }
85        d = 1.0 / d;
86        f *= c * d;
87
88        // Odd step
89        let num_odd = -((a + m_f) * (a + b + m_f) * x)
90            / ((a + 2.0 * m_f) * (a + 2.0 * m_f + 1.0));
91        d = 1.0 + num_odd * d;
92        if d.abs() < tiny { d = tiny; }
93        c = 1.0 + num_odd / c;
94        if c.abs() < tiny { c = tiny; }
95        d = 1.0 / d;
96        let delta = c * d;
97        f *= delta;
98
99        if (delta - 1.0).abs() < eps {
100            break;
101        }
102    }
103
104    front * f
105}
106
107// ---------------------------------------------------------------------------
108// Helper: regularized lower incomplete gamma function
109// ---------------------------------------------------------------------------
110
111/// Regularized lower incomplete gamma function P(a, x) = gamma(a, x) / Gamma(a).
112/// Uses series expansion for x < a+1, continued fraction otherwise.
113fn regularized_gamma_p(a: f64, x: f64) -> f64 {
114    if x < 0.0 { return 0.0; }
115    if x == 0.0 { return 0.0; }
116
117    if x < a + 1.0 {
118        // Series expansion
119        gamma_series(a, x)
120    } else {
121        // Continued fraction
122        1.0 - gamma_cf(a, x)
123    }
124}
125
126fn gamma_series(a: f64, x: f64) -> f64 {
127    let max_iter = 200;
128    let eps = 1e-14;
129    let mut sum = 1.0 / a;
130    let mut term = 1.0 / a;
131    for n in 1..=max_iter {
132        term *= x / (a + n as f64);
133        sum += term;
134        if term.abs() < sum.abs() * eps {
135            break;
136        }
137    }
138    sum * (-x + a * x.ln() - ln_gamma(a)).exp()
139}
140
141fn gamma_cf(a: f64, x: f64) -> f64 {
142    // Lentz continued fraction for upper incomplete gamma Q(a,x)
143    // CF: 1/(x+1-a+ K_{n=1}^{inf} n*(n-a)/(x+2n+1-a))
144    let max_iter = 200;
145    let eps = 1e-14;
146    let tiny = 1e-30;
147
148    let mut b = x + 1.0 - a;
149    let mut c = 1.0 / tiny;
150    let mut d = 1.0 / b;
151    let mut f = d;
152
153    for i in 1..=max_iter {
154        let an = -(i as f64) * (i as f64 - a);
155        b += 2.0;
156        d = an * d + b;
157        if d.abs() < tiny { d = tiny; }
158        c = b + an / c;
159        if c.abs() < tiny { c = tiny; }
160        d = 1.0 / d;
161        let delta = d * c;
162        f *= delta;
163        if (delta - 1.0).abs() < eps {
164            break;
165        }
166    }
167
168    (-x + a * x.ln() - ln_gamma(a)).exp() * f
169}
170
171// ---------------------------------------------------------------------------
172// Error functions (Bastion ABI primitives)
173// ---------------------------------------------------------------------------
174
175/// Error function erf(x) using Horner-form rational approximation.
176///
177/// Uses the Abramowitz & Stegun 7.1.28 formula for |x| via the complementary
178/// error function. Maximum error: |eps| < 1.5e-7.
179///
180/// # Determinism Contract
181/// Pure math — same input => identical output. No iteration-order dependency.
182pub fn erf(x: f64) -> f64 {
183    // erf(x) = 1 - erfc(x)
184    1.0 - erfc(x)
185}
186
/// Complementary error function `erfc(x) = 1 - erf(x)`.
///
/// Uses Abramowitz & Stegun formula 7.1.26. With `z = |x|` and
/// `t = 1 / (1 + p·z)`:
/// `erfc(z) ≈ (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · e^{-z²}`.
/// Negative arguments use `erfc(-z) = 2 - erfc(z)`. Absolute error is below
/// `1.5e-7`.
///
/// Exact special cases: `erfc(0) = 1`, `erfc(+∞) = 0`, `erfc(-∞) = 2`.
///
/// # Determinism Contract
/// Pure math — deterministic for all finite inputs. NaN in ⇒ NaN out.
pub fn erfc(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // a1 ..= a5 of A&S 7.1.26, lowest order first.
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    if x.is_nan() {
        return f64::NAN;
    }
    if x == 0.0 {
        return 1.0;
    }
    if x.is_infinite() {
        return if x > 0.0 { 0.0 } else { 2.0 };
    }

    let z = x.abs();
    let t = 1.0 / (1.0 + P * z);
    // Horner scheme from the highest coefficient down, then one more factor t.
    let poly = A.iter().rev().fold(0.0, |acc, &coef| acc * t + coef) * t;
    let tail = poly * (-z * z).exp();

    if x < 0.0 {
        2.0 - tail
    } else {
        tail
    }
}
225
226// ---------------------------------------------------------------------------
227// Normal distribution
228// ---------------------------------------------------------------------------
229
230/// Normal distribution CDF using Abramowitz & Stegun approximation (7.1.26).
231/// Maximum error: |eps| < 1.5e-7. Deterministic.
232///
233/// Equivalent to 0.5 * erfc(-x / sqrt(2)).
234pub fn normal_cdf(x: f64) -> f64 {
235    // Constants for the approximation
236    let a1 = 0.254829592;
237    let a2 = -0.284496736;
238    let a3 = 1.421413741;
239    let a4 = -1.453152027;
240    let a5 = 1.061405429;
241    let p = 0.3275911;
242
243    let sign = if x < 0.0 { -1.0 } else { 1.0 };
244    let x_abs = x.abs() / 2.0_f64.sqrt();
245    let t = 1.0 / (1.0 + p * x_abs);
246    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x_abs * x_abs).exp();
247
248    0.5 * (1.0 + sign * y)
249}
250
/// Standard normal probability density, `φ(x) = exp(-x²/2) / √(2π)`.
///
/// `φ(±∞) = 0`; NaN in ⇒ NaN out.
pub fn normal_pdf(x: f64) -> f64 {
    let inv_sqrt_two_pi = 1.0 / (2.0 * PI).sqrt();
    let gaussian = (-x * x / 2.0).exp();
    inv_sqrt_two_pi * gaussian
}
255
/// Standard normal PPF (inverse CDF / quantile function).
///
/// Uses Peter Acklam's rational approximation, with relative error about
/// `1.15e-9`:
/// * a central rational function in `q = p - 0.5` on `[0.02425, 0.97575]`;
/// * tail rational functions in `q = √(-2·ln p)` (or `√(-2·ln(1-p))`) outside
///   that interval, with the upper tail mirrored by sign.
///
/// # Errors
/// Returns `Err` unless `0 < p < 1`. The test is written so that NaN is also
/// rejected instead of falling through to the upper-tail branch.
///
/// # Determinism Contract
/// Pure math — same input ⇒ identical output.
pub fn normal_ppf(p: f64) -> Result<f64, String> {
    // Negated range test: rejects NaN as well as out-of-range values.
    if !(p > 0.0 && p < 1.0) {
        return Err(format!("normal_ppf: p must be in (0,1), got {p}"));
    }

    // Acklam's coefficients: central numerator/denominator (A, B),
    // tail numerator/denominator (C, D).
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    let tail = |q: f64| {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    let result = if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    };
    Ok(result)
}
316
317// ---------------------------------------------------------------------------
318// Student's t-distribution
319// ---------------------------------------------------------------------------
320
321/// Student's t-distribution CDF.
322/// Uses regularized incomplete beta function.
323pub fn t_cdf(x: f64, df: f64) -> f64 {
324    let t2 = x * x;
325    let ix = df / (df + t2);
326    let beta = 0.5 * regularized_incomplete_beta(df / 2.0, 0.5, ix);
327    if x >= 0.0 {
328        1.0 - beta
329    } else {
330        beta
331    }
332}
333
334/// Student's t-distribution PPF via bisection.
335pub fn t_ppf(p: f64, df: f64) -> Result<f64, String> {
336    if p <= 0.0 || p >= 1.0 {
337        return Err(format!("t_ppf: p must be in (0,1), got {p}"));
338    }
339    // Bisection search
340    let mut lo = -1000.0;
341    let mut hi = 1000.0;
342    for _ in 0..100 {
343        let mid = (lo + hi) / 2.0;
344        if t_cdf(mid, df) < p {
345            lo = mid;
346        } else {
347            hi = mid;
348        }
349    }
350    Ok((lo + hi) / 2.0)
351}
352
353// ---------------------------------------------------------------------------
354// Chi-squared distribution
355// ---------------------------------------------------------------------------
356
357/// Chi-squared distribution CDF.
358/// Uses regularized lower incomplete gamma function.
359pub fn chi2_cdf(x: f64, df: f64) -> f64 {
360    if x <= 0.0 { return 0.0; }
361    regularized_gamma_p(df / 2.0, x / 2.0)
362}
363
364/// Chi-squared distribution PPF via bisection.
365pub fn chi2_ppf(p: f64, df: f64) -> Result<f64, String> {
366    if p <= 0.0 || p >= 1.0 {
367        return Err(format!("chi2_ppf: p must be in (0,1), got {p}"));
368    }
369    let mut lo = 0.0;
370    let mut hi = df + 10.0 * (2.0 * df).sqrt().max(10.0);
371    for _ in 0..100 {
372        let mid = (lo + hi) / 2.0;
373        if chi2_cdf(mid, df) < p {
374            lo = mid;
375        } else {
376            hi = mid;
377        }
378    }
379    Ok((lo + hi) / 2.0)
380}
381
382// ---------------------------------------------------------------------------
383// F-distribution
384// ---------------------------------------------------------------------------
385
386/// F-distribution CDF.
387/// Uses regularized incomplete beta function.
388pub fn f_cdf(x: f64, df1: f64, df2: f64) -> f64 {
389    if x <= 0.0 { return 0.0; }
390    let ix = df1 * x / (df1 * x + df2);
391    regularized_incomplete_beta(df1 / 2.0, df2 / 2.0, ix)
392}
393
394/// F-distribution PPF via bisection.
395pub fn f_ppf(p: f64, df1: f64, df2: f64) -> Result<f64, String> {
396    if p <= 0.0 || p >= 1.0 {
397        return Err(format!("f_ppf: p must be in (0,1), got {p}"));
398    }
399    let mut lo = 0.0;
400    let mut hi = 1000.0;
401    for _ in 0..100 {
402        let mid = (lo + hi) / 2.0;
403        if f_cdf(mid, df1, df2) < p {
404            lo = mid;
405        } else {
406            hi = mid;
407        }
408    }
409    Ok((lo + hi) / 2.0)
410}
411
412// ---------------------------------------------------------------------------
413// Binomial distribution
414// ---------------------------------------------------------------------------
415
416/// Binomial PMF: C(n,k) * p^k * (1-p)^(n-k).
417pub fn binomial_pmf(k: u64, n: u64, p: f64) -> f64 {
418    if k > n { return 0.0; }
419    let log_coeff = ln_gamma(n as f64 + 1.0) - ln_gamma(k as f64 + 1.0) - ln_gamma((n - k) as f64 + 1.0);
420    let log_prob = k as f64 * p.ln() + (n - k) as f64 * (1.0 - p).ln();
421    (log_coeff + log_prob).exp()
422}
423
424/// Binomial CDF: sum_{i=0}^{k} binomial_pmf(i, n, p).
425pub fn binomial_cdf(k: u64, n: u64, p: f64) -> f64 {
426    let mut sum = cjc_repro::KahanAccumulatorF64::new();
427    for i in 0..=k {
428        sum.add(binomial_pmf(i, n, p));
429    }
430    sum.finalize()
431}
432
433// ---------------------------------------------------------------------------
434// Poisson distribution
435// ---------------------------------------------------------------------------
436
437/// Poisson PMF: (lambda^k * e^-lambda) / k!
438pub fn poisson_pmf(k: u64, lambda: f64) -> f64 {
439    let log_prob = k as f64 * lambda.ln() - lambda - ln_gamma(k as f64 + 1.0);
440    log_prob.exp()
441}
442
443/// Poisson CDF: sum_{i=0}^{k} poisson_pmf(i, lambda).
444pub fn poisson_cdf(k: u64, lambda: f64) -> f64 {
445    let mut sum = cjc_repro::KahanAccumulatorF64::new();
446    for i in 0..=k {
447        sum.add(poisson_pmf(i, lambda));
448    }
449    sum.finalize()
450}
451
452// ---------------------------------------------------------------------------
453// Phase B6: Beta, Gamma, Exponential, Weibull distributions
454// ---------------------------------------------------------------------------
455
456/// Beta distribution PDF: x^(a-1) * (1-x)^(b-1) / B(a,b).
457/// x in [0,1], a > 0, b > 0.
458pub fn beta_pdf(x: f64, a: f64, b: f64) -> f64 {
459    if x < 0.0 || x > 1.0 { return 0.0; }
460    if x == 0.0 && a < 1.0 { return f64::INFINITY; }
461    if x == 1.0 && b < 1.0 { return f64::INFINITY; }
462    let log_beta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
463    ((a - 1.0) * x.ln() + (b - 1.0) * (1.0 - x).ln() - log_beta).exp()
464}
465
466/// Beta distribution CDF via regularized incomplete beta function.
467pub fn beta_cdf(x: f64, a: f64, b: f64) -> f64 {
468    if x <= 0.0 { return 0.0; }
469    if x >= 1.0 { return 1.0; }
470    regularized_incomplete_beta(a, b, x)
471}
472
473/// Gamma distribution PDF: x^(k-1) * exp(-x/theta) / (theta^k * Gamma(k)).
474/// x >= 0, k > 0 (shape), theta > 0 (scale).
475pub fn gamma_pdf(x: f64, k: f64, theta: f64) -> f64 {
476    if x < 0.0 { return 0.0; }
477    if x == 0.0 {
478        if k < 1.0 { return f64::INFINITY; }
479        if k == 1.0 { return 1.0 / theta; }
480        return 0.0;
481    }
482    let log_pdf = (k - 1.0) * x.ln() - x / theta - k * theta.ln() - ln_gamma(k);
483    log_pdf.exp()
484}
485
486/// Gamma distribution CDF via regularized lower incomplete gamma function.
487pub fn gamma_cdf(x: f64, k: f64, theta: f64) -> f64 {
488    if x <= 0.0 { return 0.0; }
489    regularized_gamma_p(k, x / theta)
490}
491
/// Exponential density with rate `lambda`: `λ · e^{-λx}` for `x >= 0`, and
/// 0 for negative `x`.
pub fn exp_pdf(x: f64, lambda: f64) -> f64 {
    if x < 0.0 {
        0.0
    } else {
        let decay = (-lambda * x).exp();
        lambda * decay
    }
}
498
/// Exponential CDF with rate `lambda`: `1 - e^{-λx}`, and 0 for `x <= 0`.
///
/// Computed as `-expm1(-λx)`, which keeps full relative precision for small
/// `λx`. The literal `1 - exp(-λx)` form rounds to exactly 0 once `λx` falls
/// below roughly `5e-17`, and loses digits well before that.
pub fn exp_cdf(x: f64, lambda: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    -(-lambda * x).exp_m1()
}
504
/// Weibull density with shape `k` and scale `lambda`:
/// `(k/λ) · (x/λ)^(k-1) · e^{-(x/λ)^k}`.
///
/// Zero for negative `x`. At the origin the value is `+∞` for `k < 1`,
/// `1/λ` for `k = 1`, and 0 for `k > 1`.
pub fn weibull_pdf(x: f64, k: f64, lambda: f64) -> f64 {
    if x < 0.0 {
        return 0.0;
    }
    if x == 0.0 {
        return if k < 1.0 {
            f64::INFINITY
        } else if k == 1.0 {
            1.0 / lambda
        } else {
            0.0
        };
    }
    let z = x / lambda;
    (k / lambda) * z.powf(k - 1.0) * (-z.powf(k)).exp()
}
516
/// Weibull CDF with shape `k` and scale `lambda`: `1 - e^{-(x/λ)^k}`, and 0
/// for `x <= 0`.
///
/// Computed as `-expm1(-(x/λ)^k)` so that small `(x/λ)^k` keeps full relative
/// precision. The literal `1 - exp(..)` form cancels to 0 there.
pub fn weibull_cdf(x: f64, k: f64, lambda: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    -(-(x / lambda).powf(k)).exp_m1()
}
522
523// ===========================================================================
524// Phase 4: Distribution Sampling Functions
525// ===========================================================================
526//
527// All sampling functions are deterministic given the same RNG state.
528// They consume RNG draws in a fixed, predictable order.
529// Floating-point reductions use Kahan summation where applicable.
530
531// ---------------------------------------------------------------------------
532// Normal sampling
533// ---------------------------------------------------------------------------
534
535/// Sample `n` values from Normal(mu, sigma) using Box-Muller via `rng.next_normal_f64()`.
536///
537/// # Determinism Contract
538/// Same RNG state => identical output vector, bit-for-bit.
539pub fn normal_sample(mu: f64, sigma: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<f64> {
540    let mut out = Vec::with_capacity(n);
541    for _ in 0..n {
542        out.push(mu + sigma * rng.next_normal_f64());
543    }
544    out
545}
546
547// ---------------------------------------------------------------------------
548// Uniform sampling
549// ---------------------------------------------------------------------------
550
551/// Sample `n` values from Uniform(a, b).
552///
553/// Each sample: a + (b - a) * U where U ~ Uniform[0, 1).
554pub fn uniform_sample(a: f64, b: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<f64> {
555    let mut out = Vec::with_capacity(n);
556    for _ in 0..n {
557        out.push(a + (b - a) * rng.next_f64());
558    }
559    out
560}
561
562// ---------------------------------------------------------------------------
563// Exponential sampling
564// ---------------------------------------------------------------------------
565
566/// Sample `n` values from Exponential(lambda) using inverse CDF: -ln(1 - U) / lambda.
567///
568/// Uses `1.0 - rng.next_f64()` to avoid ln(0).
569pub fn exponential_sample(lambda: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<f64> {
570    let mut out = Vec::with_capacity(n);
571    for _ in 0..n {
572        // 1.0 - next_f64() gives (0, 1] which avoids ln(0)
573        let u = 1.0 - rng.next_f64();
574        out.push(-u.ln() / lambda);
575    }
576    out
577}
578
579// ---------------------------------------------------------------------------
580// Gamma sampling — Marsaglia-Tsang method
581// ---------------------------------------------------------------------------
582
583/// Sample a single Gamma(shape, 1) value using Marsaglia-Tsang's method.
584///
585/// For shape >= 1: direct Marsaglia-Tsang.
586/// For shape < 1: sample Gamma(shape + 1, 1) then multiply by U^(1/shape)
587/// where U ~ Uniform(0,1).
588///
589/// Reference: Marsaglia & Tsang, "A Simple Method for Generating Gamma Variables" (2000).
590fn gamma_sample_single(shape: f64, rng: &mut cjc_repro::Rng) -> f64 {
591    if shape < 1.0 {
592        // Shape augmentation trick: Gamma(a, 1) = Gamma(a+1, 1) * U^(1/a)
593        let g = gamma_sample_single(shape + 1.0, rng);
594        let u = rng.next_f64();
595        // Use 1.0 - u to avoid u = 0 (which would give 0^(1/a) = 0 always)
596        return g * (1.0 - u).powf(1.0 / shape);
597    }
598
599    let d = shape - 1.0 / 3.0;
600    let c = 1.0 / (9.0 * d).sqrt();
601
602    loop {
603        let x = rng.next_normal_f64();
604        let v_base = 1.0 + c * x;
605        if v_base <= 0.0 {
606            continue;
607        }
608        let v = v_base * v_base * v_base;
609        let u = rng.next_f64();
610
611        // Squeeze test
612        if u < 1.0 - 0.0331 * (x * x) * (x * x) {
613            return d * v;
614        }
615        // Full test
616        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
617            return d * v;
618        }
619    }
620}
621
622/// Sample `n` values from Gamma(shape_k, scale_theta).
623///
624/// Uses Marsaglia-Tsang method (shape >= 1) with shape augmentation for shape < 1.
625/// Result = Gamma(shape, 1) * scale.
626pub fn gamma_sample(
627    shape_k: f64,
628    scale_theta: f64,
629    n: usize,
630    rng: &mut cjc_repro::Rng,
631) -> Vec<f64> {
632    let mut out = Vec::with_capacity(n);
633    for _ in 0..n {
634        out.push(gamma_sample_single(shape_k, rng) * scale_theta);
635    }
636    out
637}
638
639// ---------------------------------------------------------------------------
640// Beta sampling — via Gamma
641// ---------------------------------------------------------------------------
642
643/// Sample `n` values from Beta(a, b) using the gamma ratio method.
644///
645/// X ~ Gamma(a, 1), Y ~ Gamma(b, 1), then X / (X + Y) ~ Beta(a, b).
646pub fn beta_sample(a: f64, b: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<f64> {
647    let mut out = Vec::with_capacity(n);
648    for _ in 0..n {
649        let x = gamma_sample_single(a, rng);
650        let y = gamma_sample_single(b, rng);
651        out.push(x / (x + y));
652    }
653    out
654}
655
656// ---------------------------------------------------------------------------
657// Chi-squared sampling — via Gamma
658// ---------------------------------------------------------------------------
659
660/// Sample `n` values from Chi-squared(df).
661///
662/// Chi-squared(df) = Gamma(df/2, 2).
663pub fn chi2_sample(df: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<f64> {
664    gamma_sample(df / 2.0, 2.0, n, rng)
665}
666
667// ---------------------------------------------------------------------------
668// Student's t sampling
669// ---------------------------------------------------------------------------
670
671/// Sample `n` values from Student's t(df).
672///
673/// t = Z / sqrt(V / df) where Z ~ Normal(0,1), V ~ Chi-squared(df).
674pub fn t_sample(df: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<f64> {
675    let mut out = Vec::with_capacity(n);
676    for _ in 0..n {
677        let z = rng.next_normal_f64();
678        let v = gamma_sample_single(df / 2.0, rng) * 2.0; // Chi-squared(df)
679        out.push(z / (v / df).sqrt());
680    }
681    out
682}
683
684// ---------------------------------------------------------------------------
685// Poisson sampling
686// ---------------------------------------------------------------------------
687
688/// Sample a single Poisson(lambda) value using Knuth's algorithm (lambda < 30)
689/// or the transformed rejection method (lambda >= 30).
690fn poisson_sample_single(lambda: f64, rng: &mut cjc_repro::Rng) -> i64 {
691    if lambda < 30.0 {
692        // Knuth's algorithm
693        let l = (-lambda).exp();
694        let mut k: i64 = 0;
695        let mut p = 1.0;
696        loop {
697            k += 1;
698            p *= rng.next_f64();
699            if p <= l {
700                return k - 1;
701            }
702        }
703    } else {
704        // Transformed rejection method (Hoermann, "The Transformed Rejection Method")
705        // Approximation: Poisson ~ floor(Normal(lambda, sqrt(lambda)) + 0.5) with
706        // acceptance-rejection correction.
707        let sqrt_lam = lambda.sqrt();
708        let log_lam = lambda.ln();
709        let b = 0.931 + 2.53 * sqrt_lam;
710        let a = -0.059 + 0.02483 * b;
711        let inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
712        let v_r = 0.9277 - 3.6224 / (b - 2.0);
713
714        loop {
715            let u = rng.next_f64() - 0.5;
716            let v = rng.next_f64();
717            let us = 0.5 - u.abs();
718            let k = ((2.0 * a / us + b) * u + lambda + 0.43).floor() as i64;
719
720            if k < 0 {
721                continue;
722            }
723
724            // Squeeze acceptance
725            if us >= 0.07 && v <= v_r {
726                return k;
727            }
728
729            // Full acceptance check
730            let kf = k as f64;
731            let log_fk = ln_gamma(kf + 1.0);
732            let log_prob = kf * log_lam - lambda - log_fk;
733
734            if (us >= 0.013 || v <= us)
735                && v.ln() + inv_alpha.ln() - (a / (us * us) + b).ln()
736                    <= log_prob
737            {
738                return k;
739            }
740        }
741    }
742}
743
744/// Sample `n` values from Poisson(lambda).
745///
746/// Uses Knuth's algorithm for lambda < 30 and transformed rejection for lambda >= 30.
747pub fn poisson_sample(lambda: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<i64> {
748    let mut out = Vec::with_capacity(n);
749    for _ in 0..n {
750        out.push(poisson_sample_single(lambda, rng));
751    }
752    out
753}
754
755// ---------------------------------------------------------------------------
756// Binomial sampling
757// ---------------------------------------------------------------------------
758
759/// Sample a single Binomial(n_trials, p) value.
760///
761/// For small n_trials (< 25): direct simulation (sum of Bernoulli trials).
762/// For large n_trials: normal approximation with continuity correction,
763/// clamped to [0, n_trials].
764fn binomial_sample_single(n_trials: usize, p: f64, rng: &mut cjc_repro::Rng) -> i64 {
765    if n_trials < 25 {
766        // Direct simulation
767        let mut count: i64 = 0;
768        for _ in 0..n_trials {
769            if rng.next_f64() < p {
770                count += 1;
771            }
772        }
773        count
774    } else {
775        // Normal approximation: X ~ round(Normal(np, sqrt(np(1-p))))
776        let np = n_trials as f64 * p;
777        let sigma = (np * (1.0 - p)).sqrt();
778        let z = rng.next_normal_f64();
779        let x = (np + sigma * z).round() as i64;
780        // Clamp to valid range
781        x.max(0).min(n_trials as i64)
782    }
783}
784
785/// Sample `n` values from Binomial(n_trials, p).
786///
787/// Uses direct simulation for small n_trials (< 25) and normal approximation
788/// for larger values.
789pub fn binomial_sample(
790    n_trials: usize,
791    p: f64,
792    n: usize,
793    rng: &mut cjc_repro::Rng,
794) -> Vec<i64> {
795    let mut out = Vec::with_capacity(n);
796    for _ in 0..n {
797        out.push(binomial_sample_single(n_trials, p, rng));
798    }
799    out
800}
801
802// ---------------------------------------------------------------------------
803// Bernoulli sampling
804// ---------------------------------------------------------------------------
805
806/// Sample `n` Bernoulli(p) values.
807///
808/// Returns `true` with probability `p`, `false` with probability `1-p`.
809pub fn bernoulli_sample(p: f64, n: usize, rng: &mut cjc_repro::Rng) -> Vec<bool> {
810    let mut out = Vec::with_capacity(n);
811    for _ in 0..n {
812        out.push(rng.next_f64() < p);
813    }
814    out
815}
816
817// ---------------------------------------------------------------------------
818// Dirichlet sampling
819// ---------------------------------------------------------------------------
820
821/// Sample a single Dirichlet(alpha) vector.
822///
823/// Each component X_i ~ Gamma(alpha_i, 1), then normalize: X_i / sum(X).
824/// Normalization uses Kahan summation for deterministic stability.
825pub fn dirichlet_sample(alpha: &[f64], rng: &mut cjc_repro::Rng) -> Vec<f64> {
826    let k = alpha.len();
827    let mut raw = Vec::with_capacity(k);
828    let mut sum = cjc_repro::KahanAccumulatorF64::new();
829
830    for &a in alpha {
831        let g = gamma_sample_single(a, rng);
832        raw.push(g);
833        sum.add(g);
834    }
835
836    let total = sum.finalize();
837    for x in &mut raw {
838        *x /= total;
839    }
840    raw
841}
842
843// ---------------------------------------------------------------------------
844// Multinomial (categorical) sampling
845// ---------------------------------------------------------------------------
846
847/// Sample `n` categorical draws from the given probability vector.
848///
849/// Returns indices 0..probs.len()-1 sampled according to probabilities.
850/// Probabilities are normalized internally using Kahan summation.
851/// Uses CDF search for each draw.
852pub fn multinomial_sample(probs: &[f64], n: usize, rng: &mut cjc_repro::Rng) -> Vec<usize> {
853    if probs.is_empty() {
854        return Vec::new();
855    }
856
857    // Build normalized CDF using Kahan summation
858    let mut total_acc = cjc_repro::KahanAccumulatorF64::new();
859    for &p in probs {
860        total_acc.add(p);
861    }
862    let total = total_acc.finalize();
863
864    let k = probs.len();
865    let mut cdf = Vec::with_capacity(k);
866    let mut cum_acc = cjc_repro::KahanAccumulatorF64::new();
867    for &p in probs {
868        cum_acc.add(p / total);
869        cdf.push(cum_acc.finalize());
870    }
871    // Ensure last entry is exactly 1.0 to avoid floating-point edge cases
872    if let Some(last) = cdf.last_mut() {
873        *last = 1.0;
874    }
875
876    let mut out = Vec::with_capacity(n);
877    for _ in 0..n {
878        let u = rng.next_f64();
879        // Linear search through CDF (deterministic ordering)
880        let mut idx = 0;
881        while idx < k - 1 && u >= cdf[idx] {
882            idx += 1;
883        }
884        out.push(idx);
885    }
886    out
887}
888
889// ---------------------------------------------------------------------------
890// Latin Hypercube Sampling
891// ---------------------------------------------------------------------------
892
893/// Latin Hypercube Sampling — generates n samples in `dims` dimensions.
894///
895/// Each dimension is divided into n equal strata. Exactly one sample
896/// is drawn from each stratum per dimension, then columns are shuffled
897/// independently using the provided seed for deterministic output.
898///
899/// Returns a Tensor of shape [n, dims] with values in [0, 1).
900pub fn latin_hypercube_sample(n: usize, dims: usize, seed: u64) -> crate::tensor::Tensor {
901    if n == 0 || dims == 0 {
902        return crate::tensor::Tensor::from_vec_unchecked(Vec::new(), &[0, dims.max(1)]);
903    }
904
905    // We use separate RNG streams per dimension to ensure independence.
906    // Seed for dimension d is derived deterministically from the base seed.
907    let mut data = vec![0.0f64; n * dims];
908
909    for d in 0..dims {
910        // Derive a deterministic per-dimension seed using a simple mixing function.
911        let dim_seed = seed
912            .wrapping_add(d as u64)
913            .wrapping_mul(6364136223846793005)
914            .wrapping_add(1442695040888963407);
915        let mut rng = cjc_repro::Rng::seeded(dim_seed);
916
917        // Build strata indices [0, 1, ..., n-1].
918        let mut strata: Vec<usize> = (0..n).collect();
919
920        // Fisher-Yates shuffle of strata indices.
921        for i in (1..n).rev() {
922            // Generate a uniform integer in [0, i] using next_f64.
923            let j = (rng.next_f64() * (i + 1) as f64) as usize;
924            let j = j.min(i); // guard against rounding to i+1
925            strata.swap(i, j);
926        }
927
928        // Place one random point within each assigned stratum.
929        for i in 0..n {
930            let stratum = strata[i];
931            let offset = rng.next_f64(); // uniform in [0, 1)
932            let value = (stratum as f64 + offset) / n as f64;
933            data[i * dims + d] = value;
934        }
935    }
936
937    crate::tensor::Tensor::from_vec_unchecked(data, &[n, dims])
938}
939
940// ---------------------------------------------------------------------------
// Sobol-like low-discrepancy sequence (Halton: Van der Corput in prime bases)
942// ---------------------------------------------------------------------------
943
944/// Generate a Sobol-like low-discrepancy sequence.
945///
946/// Uses a simple bit-reversal approach (Van der Corput sequence)
947/// for each dimension with different bases. Not a true Sobol sequence
948/// but provides good space-filling properties for moderate dimensions.
949///
950/// Returns a Tensor of shape [n, dims] with values in [0, 1).
951pub fn sobol_sequence(n: usize, dims: usize) -> crate::tensor::Tensor {
952    if n == 0 || dims == 0 {
953        return crate::tensor::Tensor::from_vec_unchecked(Vec::new(), &[0, dims.max(1)]);
954    }
955
956    // First 30 primes used as bases for successive dimensions.
957    const PRIMES: [u64; 30] = [
958        2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
959        31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
960        73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
961    ];
962
963    let mut data = vec![0.0f64; n * dims];
964
965    for d in 0..dims {
966        let base = PRIMES[d % PRIMES.len()];
967        for i in 0..n {
968            data[i * dims + d] = van_der_corput(i as u64, base);
969        }
970    }
971
972    crate::tensor::Tensor::from_vec_unchecked(data, &[n, dims])
973}
974
/// Radical inverse of `index` in `base` (the Van der Corput sequence term).
///
/// The base-`base` digits of `index` are mirrored across the radix point:
/// the least-significant digit becomes the first fractional digit, and so on.
/// For example, index 6 = 110₂ maps to 0.011₂ = 0.375. Index 0 maps to 0.0,
/// and results lie in [0, 1).
fn van_der_corput(mut index: u64, base: u64) -> f64 {
    let radix = base as f64;
    let mut scale = 1.0f64;
    let mut inverse = 0.0f64;
    while index != 0 {
        let digit = index % base;
        index /= base;
        // Each successive digit lands one radix place further right.
        scale *= radix;
        inverse += digit as f64 / scale;
    }
    inverse
}
989
990// ---------------------------------------------------------------------------
991// Tests
992// ---------------------------------------------------------------------------
993
#[cfg(test)]
mod sampling_tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Helpers
    // -----------------------------------------------------------------------

    /// Sample mean using Kahan-compensated summation, which reduces rounding
    /// error accumulated over large sample counts.
    fn kahan_mean(data: &[f64]) -> f64 {
        let mut acc = cjc_repro::KahanAccumulatorF64::new();
        for &x in data {
            acc.add(x);
        }
        acc.finalize() / data.len() as f64
    }

    /// Assert two float sample vectors are bit-for-bit identical.
    ///
    /// Lengths are checked first: comparing with `zip` alone would silently
    /// ignore trailing elements (and pass trivially if both were empty).
    fn assert_bitwise_eq(a: &[f64], b: &[f64]) {
        assert_eq!(a.len(), b.len(), "sample lengths differ");
        for (i, (x, y)) in a.iter().zip(b).enumerate() {
            assert_eq!(x.to_bits(), y.to_bits(), "sample {i} differs: {x} vs {y}");
        }
    }

    // -----------------------------------------------------------------------
    // Determinism tests: same seed => same output (bit-for-bit)
    // -----------------------------------------------------------------------

    #[test]
    fn test_normal_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(42);
        let mut r2 = cjc_repro::Rng::seeded(42);
        let a = normal_sample(0.0, 1.0, 100, &mut r1);
        let b = normal_sample(0.0, 1.0, 100, &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_uniform_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(7);
        let mut r2 = cjc_repro::Rng::seeded(7);
        let a = uniform_sample(0.0, 1.0, 100, &mut r1);
        let b = uniform_sample(0.0, 1.0, 100, &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_exponential_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(99);
        let mut r2 = cjc_repro::Rng::seeded(99);
        let a = exponential_sample(2.0, 100, &mut r1);
        let b = exponential_sample(2.0, 100, &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_gamma_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(13);
        let mut r2 = cjc_repro::Rng::seeded(13);
        let a = gamma_sample(2.5, 1.0, 100, &mut r1);
        let b = gamma_sample(2.5, 1.0, 100, &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_beta_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(55);
        let mut r2 = cjc_repro::Rng::seeded(55);
        let a = beta_sample(2.0, 5.0, 100, &mut r1);
        let b = beta_sample(2.0, 5.0, 100, &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_chi2_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(77);
        let mut r2 = cjc_repro::Rng::seeded(77);
        let a = chi2_sample(5.0, 100, &mut r1);
        let b = chi2_sample(5.0, 100, &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_t_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(111);
        let mut r2 = cjc_repro::Rng::seeded(111);
        let a = t_sample(10.0, 100, &mut r1);
        let b = t_sample(10.0, 100, &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_poisson_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(33);
        let mut r2 = cjc_repro::Rng::seeded(33);
        let a = poisson_sample(5.0, 100, &mut r1);
        let b = poisson_sample(5.0, 100, &mut r2);
        assert_eq!(a, b);
    }

    #[test]
    fn test_poisson_large_lambda_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(44);
        let mut r2 = cjc_repro::Rng::seeded(44);
        let a = poisson_sample(50.0, 100, &mut r1);
        let b = poisson_sample(50.0, 100, &mut r2);
        assert_eq!(a, b);
    }

    #[test]
    fn test_binomial_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(88);
        let mut r2 = cjc_repro::Rng::seeded(88);
        let a = binomial_sample(20, 0.4, 100, &mut r1);
        let b = binomial_sample(20, 0.4, 100, &mut r2);
        assert_eq!(a, b);
    }

    #[test]
    fn test_bernoulli_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(22);
        let mut r2 = cjc_repro::Rng::seeded(22);
        let a = bernoulli_sample(0.7, 100, &mut r1);
        let b = bernoulli_sample(0.7, 100, &mut r2);
        assert_eq!(a, b);
    }

    #[test]
    fn test_dirichlet_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(66);
        let mut r2 = cjc_repro::Rng::seeded(66);
        let a = dirichlet_sample(&[1.0, 2.0, 3.0], &mut r1);
        let b = dirichlet_sample(&[1.0, 2.0, 3.0], &mut r2);
        assert_bitwise_eq(&a, &b);
    }

    #[test]
    fn test_multinomial_sample_determinism() {
        let mut r1 = cjc_repro::Rng::seeded(101);
        let mut r2 = cjc_repro::Rng::seeded(101);
        let a = multinomial_sample(&[0.2, 0.3, 0.5], 100, &mut r1);
        let b = multinomial_sample(&[0.2, 0.3, 0.5], 100, &mut r2);
        assert_eq!(a, b);
    }

    // -----------------------------------------------------------------------
    // Range correctness tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_uniform_range() {
        let mut rng = cjc_repro::Rng::seeded(1);
        let samples = uniform_sample(2.0, 5.0, 1000, &mut rng);
        for &x in &samples {
            assert!(x >= 2.0 && x < 5.0, "uniform out of range: {x}");
        }
    }

    #[test]
    fn test_exponential_positive() {
        let mut rng = cjc_repro::Rng::seeded(2);
        let samples = exponential_sample(1.5, 1000, &mut rng);
        for &x in &samples {
            assert!(x > 0.0, "exponential not positive: {x}");
        }
    }

    #[test]
    fn test_gamma_positive() {
        let mut rng = cjc_repro::Rng::seeded(3);
        // Test both shape < 1 and shape > 1
        let samples_small = gamma_sample(0.5, 2.0, 500, &mut rng);
        let samples_large = gamma_sample(5.0, 1.0, 500, &mut rng);
        for &x in samples_small.iter().chain(samples_large.iter()) {
            assert!(x > 0.0, "gamma not positive: {x}");
        }
    }

    #[test]
    fn test_beta_unit_interval() {
        let mut rng = cjc_repro::Rng::seeded(4);
        let samples = beta_sample(2.0, 5.0, 1000, &mut rng);
        for &x in &samples {
            assert!(x >= 0.0 && x <= 1.0, "beta out of [0,1]: {x}");
        }
    }

    #[test]
    fn test_chi2_positive() {
        let mut rng = cjc_repro::Rng::seeded(5);
        let samples = chi2_sample(3.0, 1000, &mut rng);
        for &x in &samples {
            assert!(x > 0.0, "chi2 not positive: {x}");
        }
    }

    #[test]
    fn test_poisson_non_negative() {
        let mut rng = cjc_repro::Rng::seeded(6);
        let samples = poisson_sample(4.0, 1000, &mut rng);
        for &x in &samples {
            assert!(x >= 0, "poisson negative: {x}");
        }
    }

    #[test]
    fn test_poisson_large_non_negative() {
        let mut rng = cjc_repro::Rng::seeded(60);
        let samples = poisson_sample(50.0, 1000, &mut rng);
        for &x in &samples {
            assert!(x >= 0, "poisson(50) negative: {x}");
        }
    }

    #[test]
    fn test_binomial_range() {
        let mut rng = cjc_repro::Rng::seeded(7);
        let samples = binomial_sample(10, 0.5, 1000, &mut rng);
        for &x in &samples {
            assert!(x >= 0 && x <= 10, "binomial out of range: {x}");
        }
    }

    #[test]
    fn test_bernoulli_values() {
        let mut rng = cjc_repro::Rng::seeded(8);
        let samples = bernoulli_sample(0.5, 1000, &mut rng);
        // With p = 0.5 and 1000 draws, both outcomes must appear.
        let trues = samples.iter().filter(|&&x| x).count();
        let falses = samples.len() - trues;
        assert!(trues > 0, "no true values");
        assert!(falses > 0, "no false values");
    }

    #[test]
    fn test_dirichlet_simplex() {
        let mut rng = cjc_repro::Rng::seeded(9);
        let sample = dirichlet_sample(&[1.0, 2.0, 3.0, 4.0], &mut rng);
        assert_eq!(sample.len(), 4);
        for &x in &sample {
            assert!(x >= 0.0 && x <= 1.0, "dirichlet component out of [0,1]: {x}");
        }
        let mut sum_acc = cjc_repro::KahanAccumulatorF64::new();
        for &x in &sample {
            sum_acc.add(x);
        }
        let sum = sum_acc.finalize();
        assert!((sum - 1.0).abs() < 1e-12, "dirichlet does not sum to 1: {sum}");
    }

    #[test]
    fn test_multinomial_valid_indices() {
        let mut rng = cjc_repro::Rng::seeded(10);
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        let samples = multinomial_sample(&probs, 1000, &mut rng);
        for &idx in &samples {
            assert!(idx < probs.len(), "multinomial index out of range: {idx}");
        }
    }

    // -----------------------------------------------------------------------
    // Mean convergence tests (large n, loose tolerance)
    // -----------------------------------------------------------------------

    #[test]
    fn test_normal_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1000);
        let mu = 3.0;
        let samples = normal_sample(mu, 1.0, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            (mean - mu).abs() < 0.05,
            "normal mean {mean} not close to {mu}"
        );
    }

    #[test]
    fn test_uniform_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1001);
        let (a, b) = (2.0, 8.0);
        let expected = (a + b) / 2.0;
        let samples = uniform_sample(a, b, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            (mean - expected).abs() < 0.05,
            "uniform mean {mean} not close to {expected}"
        );
    }

    #[test]
    fn test_exponential_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1002);
        let lambda = 2.0;
        let expected = 1.0 / lambda;
        let samples = exponential_sample(lambda, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            (mean - expected).abs() < 0.02,
            "exponential mean {mean} not close to {expected}"
        );
    }

    #[test]
    fn test_gamma_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1003);
        let (shape, scale) = (3.0, 2.0);
        let expected = shape * scale;
        let samples = gamma_sample(shape, scale, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            (mean - expected).abs() < 0.1,
            "gamma mean {mean} not close to {expected}"
        );
    }

    #[test]
    fn test_gamma_small_shape_mean() {
        let mut rng = cjc_repro::Rng::seeded(1004);
        let (shape, scale) = (0.5, 2.0);
        let expected = shape * scale;
        let samples = gamma_sample(shape, scale, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            (mean - expected).abs() < 0.1,
            "gamma(0.5) mean {mean} not close to {expected}"
        );
    }

    #[test]
    fn test_beta_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1005);
        let (a, b) = (2.0, 5.0);
        let expected = a / (a + b);
        let samples = beta_sample(a, b, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            (mean - expected).abs() < 0.02,
            "beta mean {mean} not close to {expected}"
        );
    }

    #[test]
    fn test_chi2_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1006);
        let df = 5.0;
        let samples = chi2_sample(df, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            (mean - df).abs() < 0.1,
            "chi2 mean {mean} not close to df={df}"
        );
    }

    #[test]
    fn test_t_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1007);
        let df = 10.0; // mean is 0 for df > 1
        let samples = t_sample(df, 50_000, &mut rng);
        let mean = kahan_mean(&samples);
        assert!(
            mean.abs() < 0.05,
            "t mean {mean} not close to 0"
        );
    }

    #[test]
    fn test_poisson_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1008);
        let lambda = 7.0;
        let samples = poisson_sample(lambda, 50_000, &mut rng);
        let fsamples: Vec<f64> = samples.iter().map(|&x| x as f64).collect();
        let mean = kahan_mean(&fsamples);
        assert!(
            (mean - lambda).abs() < 0.1,
            "poisson mean {mean} not close to {lambda}"
        );
    }

    #[test]
    fn test_poisson_large_lambda_mean() {
        let mut rng = cjc_repro::Rng::seeded(1009);
        let lambda = 50.0;
        let samples = poisson_sample(lambda, 50_000, &mut rng);
        let fsamples: Vec<f64> = samples.iter().map(|&x| x as f64).collect();
        let mean = kahan_mean(&fsamples);
        assert!(
            (mean - lambda).abs() < 0.5,
            "poisson(50) mean {mean} not close to {lambda}"
        );
    }

    #[test]
    fn test_binomial_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1010);
        let (n_trials, p) = (20, 0.3);
        let expected = n_trials as f64 * p;
        let samples = binomial_sample(n_trials, p, 50_000, &mut rng);
        let fsamples: Vec<f64> = samples.iter().map(|&x| x as f64).collect();
        let mean = kahan_mean(&fsamples);
        assert!(
            (mean - expected).abs() < 0.1,
            "binomial mean {mean} not close to {expected}"
        );
    }

    #[test]
    fn test_bernoulli_mean_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1011);
        let p = 0.7;
        let samples = bernoulli_sample(p, 50_000, &mut rng);
        let fsamples: Vec<f64> = samples.iter().map(|&x| if x { 1.0 } else { 0.0 }).collect();
        let mean = kahan_mean(&fsamples);
        assert!(
            (mean - p).abs() < 0.02,
            "bernoulli mean {mean} not close to {p}"
        );
    }

    #[test]
    fn test_multinomial_frequency_convergence() {
        let mut rng = cjc_repro::Rng::seeded(1012);
        let probs = vec![0.1, 0.2, 0.3, 0.4];
        let n = 50_000;
        let samples = multinomial_sample(&probs, n, &mut rng);
        let mut counts = vec![0usize; probs.len()];
        for &idx in &samples {
            counts[idx] += 1;
        }
        for (i, (&expected_p, &count)) in probs.iter().zip(counts.iter()).enumerate() {
            let empirical_p = count as f64 / n as f64;
            assert!(
                (empirical_p - expected_p).abs() < 0.02,
                "multinomial category {i}: empirical={empirical_p}, expected={expected_p}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // Latin Hypercube Sampling tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_lhs_shape() {
        let samples = latin_hypercube_sample(100, 3, 42);
        assert_eq!(samples.shape(), &[100, 3]);
    }

    #[test]
    fn test_lhs_bounds() {
        let samples = latin_hypercube_sample(50, 2, 123);
        for &v in samples.to_vec().iter() {
            assert!(v >= 0.0 && v < 1.0, "LHS value out of [0, 1): {v}");
        }
    }

    #[test]
    fn test_lhs_stratification() {
        // Each dimension should have exactly one sample per stratum.
        const DIMS: usize = 2;
        let n = 20;
        let data = latin_hypercube_sample(n, DIMS, 42).to_vec();
        for dim in 0..DIMS {
            let mut strata = vec![false; n];
            for i in 0..n {
                let val = data[i * DIMS + dim];
                let stratum = (val * n as f64) as usize;
                // Report a clear failure instead of an index-out-of-bounds panic.
                assert!(stratum < n, "value {val} in dim {dim} maps outside [0, {n})");
                assert!(!strata[stratum], "Stratum {} used twice in dim {}", stratum, dim);
                strata[stratum] = true;
            }
        }
    }

    #[test]
    fn test_lhs_determinism() {
        let s1 = latin_hypercube_sample(50, 3, 42);
        let s2 = latin_hypercube_sample(50, 3, 42);
        assert_eq!(s1.to_vec(), s2.to_vec(), "LHS must be deterministic");
    }

    // -----------------------------------------------------------------------
    // Sobol sequence tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_sobol_shape() {
        let seq = sobol_sequence(100, 4);
        assert_eq!(seq.shape(), &[100, 4]);
    }

    #[test]
    fn test_sobol_bounds() {
        let seq = sobol_sequence(50, 3);
        for &v in seq.to_vec().iter() {
            assert!(v >= 0.0 && v < 1.0, "sequence value out of [0, 1): {v}");
        }
    }

    #[test]
    fn test_sobol_determinism() {
        let s1 = sobol_sequence(50, 3);
        let s2 = sobol_sequence(50, 3);
        assert_eq!(s1.to_vec(), s2.to_vec());
    }

    #[test]
    fn test_sobol_first_dimension_is_base2() {
        // Dimension 0 uses base 2: radical inverses of 0, 1, 2, 3.
        let dims = 2;
        let data = sobol_sequence(4, dims).to_vec();
        let dim0: Vec<f64> = (0..4).map(|i| data[i * dims]).collect();
        assert_eq!(dim0, vec![0.0, 0.5, 0.25, 0.75]);
    }
}
1505
1506// ---------------------------------------------------------------------------
1507// Existing tests (PDF/CDF/PPF)
1508// ---------------------------------------------------------------------------
1509
#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------
    // Normal distribution
    // -------------------------------------------------------------------

    #[test]
    fn test_normal_cdf_zero() {
        // The standard normal is symmetric about 0: half its mass lies below.
        let r = normal_cdf(0.0);
        let deviation = (r - 0.5).abs();
        assert!(deviation < 1e-6, "CDF(0) = {r}");
    }

    #[test]
    fn test_normal_cdf_196() {
        // 1.96 is the familiar two-sided 95% critical value.
        let r = normal_cdf(1.96);
        let deviation = (r - 0.975).abs();
        assert!(deviation < 1e-3, "CDF(1.96) = {r}");
    }

    #[test]
    fn test_normal_pdf_zero() {
        // Peak density of N(0, 1) is 1 / sqrt(2 * pi).
        let peak = 1.0 / (2.0 * PI).sqrt();
        let density = normal_pdf(0.0);
        assert!((density - peak).abs() < 1e-12);
    }

    #[test]
    fn test_normal_ppf_half() {
        // The median of a symmetric distribution is its center.
        let r = normal_ppf(0.5).unwrap();
        assert!(r.abs() < 1e-6, "PPF(0.5) = {r}");
    }

    #[test]
    fn test_normal_ppf_975() {
        // Inverse direction of test_normal_cdf_196.
        let r = normal_ppf(0.975).unwrap();
        let deviation = (r - 1.96).abs();
        assert!(deviation < 0.01, "PPF(0.975) = {r}");
    }

    // -------------------------------------------------------------------
    // t, chi-squared, F, discrete distributions, ln_gamma
    // -------------------------------------------------------------------

    #[test]
    fn test_t_cdf_symmetry() {
        // Student's t is symmetric about 0, so CDF(0) = 0.5.
        let at_center = t_cdf(0.0, 10.0);
        assert!((at_center - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_chi2_cdf_basic() {
        // 3.841 is (approximately) the 95th percentile of chi-squared, df = 1.
        let r = chi2_cdf(3.841, 1.0);
        let deviation = (r - 0.95).abs();
        assert!(deviation < 0.01, "chi2_cdf = {r}");
    }

    #[test]
    fn test_f_cdf_basic() {
        // The F distribution has no mass below 0.
        let at_origin = f_cdf(0.0, 5.0, 10.0);
        assert_eq!(at_origin, 0.0);
    }

    #[test]
    fn test_binomial_pmf() {
        // P(X = 0) for n = 10, p = 0.5 is 0.5^10 = 1/1024.
        let expected = 1.0 / 1024.0;
        let r = binomial_pmf(0, 10, 0.5);
        assert!((r - expected).abs() < 1e-12);
    }

    #[test]
    fn test_poisson_pmf() {
        // P(X = 0) for lambda = 1 is e^-1.
        let expected = (-1.0_f64).exp();
        let r = poisson_pmf(0, 1.0);
        assert!((r - expected).abs() < 1e-12);
    }

    #[test]
    fn test_ln_gamma_basic() {
        // Gamma(1) = 0! = 1, so ln Gamma(1) = 0.
        assert!(ln_gamma(1.0).abs() < 1e-12);
        // Gamma(5) = 4! = 24, so ln Gamma(5) = ln 24.
        let ln_24 = 24.0_f64.ln();
        assert!((ln_gamma(5.0) - ln_24).abs() < 1e-10);
    }

    #[test]
    fn test_determinism() {
        // Repeated evaluation must be bit-identical.
        let first = normal_cdf(1.5);
        let second = normal_cdf(1.5);
        assert_eq!(first.to_bits(), second.to_bits());
    }

    // --- B6: New distribution tests ---

    #[test]
    fn test_beta_pdf_symmetric() {
        // Beta(2, 2) has density 6x(1 - x); at x = 0.5 that is 1.5.
        let r = beta_pdf(0.5, 2.0, 2.0);
        let deviation = (r - 1.5).abs();
        assert!(deviation < 1e-10, "beta_pdf(0.5, 2, 2) = {r}");
    }

    #[test]
    fn test_beta_cdf_uniform() {
        // Beta(1, 1) is Uniform[0, 1], whose CDF is the identity on [0, 1].
        [0.1, 0.3, 0.5, 0.7, 0.9].iter().for_each(|&x| {
            let r = beta_cdf(x, 1.0, 1.0);
            assert!((r - x).abs() < 1e-6, "beta_cdf({x}, 1, 1) = {r}");
        });
    }

    #[test]
    fn test_beta_cdf_endpoints() {
        // The CDF is 0 at the lower edge of the support and 1 at the upper edge.
        let lower = beta_cdf(0.0, 2.0, 3.0);
        assert!(lower.abs() < 1e-12);
        let upper = beta_cdf(1.0, 2.0, 3.0);
        assert!((upper - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_gamma_cdf_exponential() {
        // Gamma(shape = 1, scale = 1/lambda) coincides with Exp(lambda).
        let lambda = 2.0;
        let scale = 1.0 / lambda;
        for &x in [0.5, 1.0, 2.0].iter() {
            let gc = gamma_cdf(x, 1.0, scale);
            let ec = exp_cdf(x, lambda);
            assert!((gc - ec).abs() < 1e-6, "gamma_cdf({x}) = {gc}, exp_cdf = {ec}");
        }
    }

    #[test]
    fn test_exp_cdf_memoryless() {
        // At the mean 1/lambda, the exponential CDF equals 1 - e^-1.
        let lambda = 3.0;
        let expected = 1.0 - (-1.0_f64).exp();
        let r = exp_cdf(1.0 / lambda, lambda);
        assert!((r - expected).abs() < 1e-10, "exp_cdf = {r}, expected {expected}");
    }

    #[test]
    fn test_exp_pdf_integral() {
        // Left Riemann sum of the PDF over [0, 20) should be ~1.0. The grid is
        // built by repeated addition of dx and summed left-to-right from 0.0.
        let lambda = 1.5;
        let dx = 0.001;
        let sum = std::iter::successors(Some(0.0_f64), |&x| Some(x + dx))
            .take_while(|&x| x < 20.0)
            .map(|x| exp_pdf(x, lambda) * dx)
            .fold(0.0, |acc, term| acc + term);
        assert!((sum - 1.0).abs() < 0.01, "integral = {sum}");
    }

    #[test]
    fn test_weibull_cdf_exponential() {
        // Weibull with shape k = 1 and scale lambda is Exp(rate = 1/lambda).
        let lambda = 2.0;
        let rate = 1.0 / lambda;
        for &x in [0.5, 1.0, 3.0].iter() {
            let wc = weibull_cdf(x, 1.0, lambda);
            let ec = exp_cdf(x, rate);
            assert!((wc - ec).abs() < 1e-10, "weibull_cdf({x}) = {wc}, exp_cdf = {ec}");
        }
    }

    #[test]
    fn test_weibull_pdf_mode() {
        // For k > 1 the mode is lambda * ((k - 1) / k)^(1/k); the density there
        // must be at least as large as at nearby points on either side.
        let (k, lambda): (f64, f64) = (3.0, 2.0);
        let mode = lambda * ((k - 1.0) / k).powf(1.0_f64 / k);
        let peak = weibull_pdf(mode, k, lambda);
        let left = weibull_pdf(mode - 0.01, k, lambda);
        let right = weibull_pdf(mode + 0.01, k, lambda);
        assert!(peak >= left, "mode not a max left");
        assert!(peak >= right, "mode not a max right");
    }

    #[test]
    fn test_b6_dist_determinism() {
        // Repeated evaluation of the B6 distributions must be bit-identical.
        let beta_runs = [beta_pdf(0.3, 2.0, 5.0), beta_pdf(0.3, 2.0, 5.0)];
        assert_eq!(beta_runs[0].to_bits(), beta_runs[1].to_bits());
        let gamma_runs = [gamma_cdf(1.5, 3.0, 2.0), gamma_cdf(1.5, 3.0, 2.0)];
        assert_eq!(gamma_runs[0].to_bits(), gamma_runs[1].to_bits());
    }

    // -------------------------------------------------------------------
    // erf / erfc tests (Bastion ABI)
    // -------------------------------------------------------------------

    #[test]
    fn test_erf_known_values() {
        // Exact values at 0 and at the infinities.
        let anchors = [(0.0, 0.0), (f64::INFINITY, 1.0), (f64::NEG_INFINITY, -1.0)];
        for &(x, want) in anchors.iter() {
            assert!((erf(x) - want).abs() < 1e-10);
        }
        // erf is odd: erf(-x) = -erf(x).
        for &x in [1.0, 0.5].iter() {
            assert!((erf(x) + erf(-x)).abs() < 1e-10);
        }
    }

    #[test]
    fn test_erf_reference_values() {
        // Reference values from mathematical tables.
        let table = [
            (0.5, 0.5204998778),
            (1.0, 0.8427007929),
            (1.5, 0.9661051465),
            (2.0, 0.9953222650),
            (3.0, 0.9999779095),
        ];
        for &(x, reference) in table.iter() {
            assert!((erf(x) - reference).abs() < 2e-7);
        }
    }

    #[test]
    fn test_erfc_known_values() {
        // Exact values at 0 and at the infinities.
        let anchors = [(0.0, 1.0), (f64::INFINITY, 0.0), (f64::NEG_INFINITY, 2.0)];
        for &(x, want) in anchors.iter() {
            assert!((erfc(x) - want).abs() < 1e-10);
        }
        // Reflection: erfc(x) + erfc(-x) = 2.
        assert!((erfc(1.0) + erfc(-1.0) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_erfc_reference_values() {
        let table = [
            (0.5, 0.4795001222),
            (1.0, 0.1572992071),
            (2.0, 0.0046777350),
        ];
        for &(x, reference) in table.iter() {
            assert!((erfc(x) - reference).abs() < 2e-7);
        }
    }

    #[test]
    fn test_erf_erfc_consistency() {
        // erf(x) + erfc(x) = 1 for all x.
        for &x in [0.0, 0.1, 0.5, 1.0, 1.5, 2.0, 3.0, -0.5, -1.0, -2.0].iter() {
            let total = erf(x) + erfc(x);
            assert!(
                (total - 1.0).abs() < 1e-12,
                "erf({x}) + erfc({x}) != 1: got {}",
                total
            );
        }
    }

    #[test]
    fn test_erf_nan() {
        // NaN must propagate through both functions.
        assert!(erf(f64::NAN).is_nan());
        assert!(erfc(f64::NAN).is_nan());
    }

    #[test]
    fn test_erf_normal_cdf_consistency() {
        // normal_cdf(x) should agree with 0.5 * erfc(-x / sqrt(2)). The two are
        // evaluated along different code paths and may differ slightly; the
        // 2e-7 tolerance is presumably chosen to cover the A&S 7.1.26 error
        // bound (NOTE(review): confirm both paths use that approximation).
        let sqrt2 = 2.0_f64.sqrt();
        for &x in [-3.0, -1.5, -0.5, 0.0, 0.5, 1.5, 3.0].iter() {
            let via_erfc = 0.5 * erfc(-x / sqrt2);
            let via_cdf = normal_cdf(x);
            let diff = (via_erfc - via_cdf).abs();
            assert!(
                diff < 2e-7,
                "normal_cdf({x}) vs erfc route: cdf={via_cdf}, erfc={via_erfc}, diff={}",
                diff
            );
        }
    }

    #[test]
    fn test_erf_determinism() {
        // Repeated evaluation at the same point must be bit-identical.
        let probe = 1.23456789;
        let (erf_a, erf_b) = (erf(probe), erf(probe));
        assert_eq!(erf_a.to_bits(), erf_b.to_bits());
        let (erfc_a, erfc_b) = (erfc(probe), erfc(probe));
        assert_eq!(erfc_a.to_bits(), erfc_b.to_bits());
    }
}
1775}