#![allow(clippy::needless_range_loop)]
// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Bayesian inference: priors, likelihoods, posterior updates, MCMC,
//! Bayesian linear regression, Gaussian processes, and model selection.
//!
//! All distributions work with real-valued parameters; conjugate update
//! formulas are provided where analytic posteriors exist.

#![allow(dead_code)]

use std::f64::consts::PI;

// ─────────────────────────────────────────────────────────────────────────────
// Local LCG RNG
// ─────────────────────────────────────────────────────────────────────────────

/// Lightweight LCG random number generator for sampling.
struct BiRng {
    state: u64,
}

impl BiRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform sample in [0, 1) using the top 53 bits of the LCG state.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
    }

    /// Box-Muller standard normal sample.
    fn next_normal(&mut self) -> f64 {
        loop {
            let u1 = self.next_f64();
            let u2 = self.next_f64();
            if u1 > 0.0 {
                return (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
            }
        }
    }
}
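
// Illustrative sketch (not part of the original file): exercising the local
// RNG above. `BiRng` is deterministic for a fixed seed, and `next_normal`
// should produce roughly zero-mean draws (loose bound, fixed seed).
#[cfg(test)]
mod birng_demo {
    use super::*;

    #[test]
    fn deterministic_and_roughly_standard_normal() {
        // Same seed → same stream.
        let mut a = BiRng::new(123);
        let mut b = BiRng::new(123);
        assert_eq!(a.next_u64(), b.next_u64());

        // Sample mean of Box-Muller draws should be near zero.
        let mut rng = BiRng::new(7);
        let n = 10_000;
        let mean: f64 = (0..n).map(|_| rng.next_normal()).sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.1);
    }
}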

// ─────────────────────────────────────────────────────────────────────────────
// Prior
// ─────────────────────────────────────────────────────────────────────────────

/// Supported prior distribution families.
#[derive(Debug, Clone)]
pub enum Prior {
    /// Uniform prior on \[low, high\].
    Uniform {
        /// Lower bound.
        low: f64,
        /// Upper bound.
        high: f64,
    },
    /// Gaussian (normal) prior N(mean, std²).
    Gaussian {
        /// Prior mean.
        mean: f64,
        /// Prior standard deviation.
        std: f64,
    },
    /// Laplace (double-exponential) prior with location μ and scale b.
    Laplace {
        /// Location parameter.
        mu: f64,
        /// Scale parameter (b > 0).
        b: f64,
    },
    /// Jeffreys (improper) scale prior ∝ 1/θ for θ > 0.
    Jeffreys,
    /// Dirichlet prior over a simplex with concentration parameters α.
    Dirichlet {
        /// Concentration parameters (all positive).
        alpha: Vec<f64>,
    },
    /// Beta prior Beta(alpha, beta) on \[0, 1\].
    Beta {
        /// Shape parameter α > 0.
        alpha: f64,
        /// Shape parameter β > 0.
        beta: f64,
    },
    /// Gamma prior Gamma(shape, rate) on (0, ∞).
    Gamma {
        /// Shape k > 0.
        shape: f64,
        /// Rate λ > 0.
        rate: f64,
    },
}

impl Prior {
    /// Returns the log prior density log p(θ) for a scalar parameter θ.
    ///
    /// Returns `f64::NEG_INFINITY` when θ is outside the support or the
    /// hyperparameters are invalid.
    pub fn log_density(&self, theta: f64) -> f64 {
        match self {
            Prior::Uniform { low, high } => {
                if *high <= *low {
                    return f64::NEG_INFINITY;
                }
                if theta >= *low && theta <= *high {
                    -(*high - *low).ln()
                } else {
                    f64::NEG_INFINITY
                }
            }
            Prior::Gaussian { mean, std } => {
                if *std <= 0.0 {
                    return f64::NEG_INFINITY;
                }
                let z = (theta - mean) / std;
                -0.5 * z * z - std.ln() - 0.5 * (2.0 * PI).ln()
            }
            Prior::Laplace { mu, b } => {
                if *b <= 0.0 {
                    return f64::NEG_INFINITY;
                }
                -(theta - mu).abs() / b - (2.0 * b).ln()
            }
            Prior::Jeffreys => {
                if theta > 0.0 {
                    -theta.ln()
                } else {
                    f64::NEG_INFINITY
                }
            }
            Prior::Beta { alpha, beta } => {
                if theta <= 0.0 || theta >= 1.0 {
                    return f64::NEG_INFINITY;
                }
                (*alpha - 1.0) * theta.ln() + (*beta - 1.0) * (1.0 - theta).ln()
                    - log_beta(*alpha, *beta)
            }
            Prior::Gamma { shape, rate } => {
                if theta <= 0.0 {
                    return f64::NEG_INFINITY;
                }
                (*shape - 1.0) * theta.ln() - *rate * theta - log_gamma(*shape) + *shape * rate.ln()
            }
            Prior::Dirichlet { alpha: _ } => {
                // The scalar overload is not meaningful for a Dirichlet prior;
                // use `dirichlet_log_density` instead.
                f64::NEG_INFINITY
            }
        }
    }

    /// Returns the log density of a Dirichlet prior for a probability vector `x`.
    ///
    /// Returns `f64::NEG_INFINITY` when the variant is not `Dirichlet`, the
    /// dimensions do not match, or `x` does not lie on the simplex.
    pub fn dirichlet_log_density(&self, x: &[f64]) -> f64 {
        if let Prior::Dirichlet { alpha } = self {
            if alpha.len() != x.len() {
                return f64::NEG_INFINITY;
            }
            let sum: f64 = x.iter().sum();
            if (sum - 1.0).abs() > 1e-8 {
                return f64::NEG_INFINITY;
            }
            // Log normalising constant: ln Γ(Σ α_i) - Σ ln Γ(α_i)
            let sum_log_gamma: f64 = alpha.iter().map(|&a| log_gamma(a)).sum();
            let log_gamma_sum = log_gamma(alpha.iter().sum::<f64>());
            let log_z = log_gamma_sum - sum_log_gamma;
            let sum_term: f64 = alpha
                .iter()
                .zip(x.iter())
                .map(|(&a, &xi)| {
                    if xi <= 0.0 {
                        f64::NEG_INFINITY
                    } else {
                        (a - 1.0) * xi.ln()
                    }
                })
                .sum();
            log_z + sum_term
        } else {
            f64::NEG_INFINITY
        }
    }

    /// Samples from the prior using the given RNG seed.
    ///
    /// Returns `None` for the Jeffreys prior (improper) and the Dirichlet
    /// prior (multivariate), where scalar sampling is not defined.
    pub fn sample(&self, seed: u64) -> Option<f64> {
        let mut rng = BiRng::new(seed);
        match self {
            Prior::Uniform { low, high } => Some(low + rng.next_f64() * (high - low)),
            Prior::Gaussian { mean, std } => Some(mean + std * rng.next_normal()),
            Prior::Laplace { mu, b } => {
                // Inverse-CDF sampling with u ∈ (-0.5, 0.5).
                let u = rng.next_f64() - 0.5;
                Some(mu - b * u.signum() * (1.0 - 2.0 * u.abs()).ln())
            }
            Prior::Beta { alpha, beta } => {
                // Exact: X/(X+Y) with X ~ Gamma(α, 1), Y ~ Gamma(β, 1).
                let x = sample_gamma(*alpha, &mut rng);
                let y = sample_gamma(*beta, &mut rng);
                if x + y <= 0.0 {
                    Some(0.5)
                } else {
                    Some(x / (x + y))
                }
            }
            Prior::Gamma { shape, rate } => Some(sample_gamma(*shape, &mut rng) / rate),
            Prior::Jeffreys | Prior::Dirichlet { .. } => None,
        }
    }
}
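
// Illustrative sketch (not part of the original file): evaluating and
// sampling a `Prior`. A Beta(2, 2) density is symmetric about 0.5, and
// draws from Uniform[low, high] stay inside the support.
#[cfg(test)]
mod prior_demo {
    use super::*;

    #[test]
    fn beta_symmetry_and_uniform_support() {
        let beta = Prior::Beta { alpha: 2.0, beta: 2.0 };
        assert!((beta.log_density(0.3) - beta.log_density(0.7)).abs() < 1e-12);

        let uniform = Prior::Uniform { low: -1.0, high: 1.0 };
        for seed in 0..10_u64 {
            let s = uniform.sample(seed).unwrap();
            assert!((-1.0..=1.0).contains(&s));
        }
    }
}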

// ─────────────────────────────────────────────────────────────────────────────
// Likelihood
// ─────────────────────────────────────────────────────────────────────────────

/// Supported likelihood functions.
#[derive(Debug, Clone)]
pub enum Likelihood {
    /// Gaussian likelihood: x | μ, σ ~ N(μ, σ²).
    Gaussian {
        /// Observed data.
        data: Vec<f64>,
        /// Known noise standard deviation σ.
        sigma: f64,
    },
    /// Poisson likelihood: k | λ ~ Poisson(λ).
    Poisson {
        /// Observed counts.
        counts: Vec<u64>,
    },
    /// Bernoulli likelihood: x ∈ {0,1} | p ~ Bernoulli(p).
    Bernoulli {
        /// Observed outcomes (0 or 1).
        outcomes: Vec<u8>,
    },
    /// Multinomial likelihood over K categories.
    Multinomial {
        /// Observed counts per category.
        counts: Vec<u64>,
    },
}

impl Likelihood {
    /// Returns the log likelihood log p(data | θ) for a scalar parameter θ.
    ///
    /// * Gaussian: θ = μ (mean).
    /// * Poisson: θ = λ (rate).
    /// * Bernoulli: θ = p (probability).
    /// * Multinomial: not applicable; returns `NEG_INFINITY`
    ///   (use `multinomial_log_likelihood` instead).
    pub fn log_likelihood(&self, theta: f64) -> f64 {
        match self {
            Likelihood::Gaussian { data, sigma } => {
                if *sigma <= 0.0 {
                    return f64::NEG_INFINITY;
                }
                let n = data.len() as f64;
                let ss: f64 = data.iter().map(|&x| (x - theta).powi(2)).sum();
                -0.5 * n * (2.0 * PI * sigma * sigma).ln() - ss / (2.0 * sigma * sigma)
            }
            Likelihood::Poisson { counts } => {
                if theta <= 0.0 {
                    return f64::NEG_INFINITY;
                }
                counts
                    .iter()
                    .map(|&k| k as f64 * theta.ln() - theta - log_factorial(k))
                    .sum()
            }
            Likelihood::Bernoulli { outcomes } => {
                if theta <= 0.0 || theta >= 1.0 {
                    return f64::NEG_INFINITY;
                }
                outcomes
                    .iter()
                    .map(|&x| {
                        if x == 1 {
                            theta.ln()
                        } else {
                            (1.0 - theta).ln()
                        }
                    })
                    .sum()
            }
            Likelihood::Multinomial { counts: _ } => f64::NEG_INFINITY,
        }
    }

    /// Returns the log multinomial likelihood for probability vector `probs`,
    /// up to the multinomial coefficient (which is constant in `probs`).
    pub fn multinomial_log_likelihood(&self, probs: &[f64]) -> f64 {
        if let Likelihood::Multinomial { counts } = self {
            if counts.len() != probs.len() {
                return f64::NEG_INFINITY;
            }
            counts
                .iter()
                .zip(probs.iter())
                .map(|(&k, &p)| {
                    if p <= 0.0 {
                        if k == 0 { 0.0 } else { f64::NEG_INFINITY }
                    } else {
                        k as f64 * p.ln()
                    }
                })
                .sum()
        } else {
            f64::NEG_INFINITY
        }
    }
}
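
// Illustrative sketch (not part of the original file): the Bernoulli
// log-likelihood above is maximised at the empirical success frequency.
#[cfg(test)]
mod likelihood_demo {
    use super::*;

    #[test]
    fn bernoulli_peaks_at_empirical_frequency() {
        // 3 successes out of 4 → the MLE is p = 0.75.
        let ll = Likelihood::Bernoulli { outcomes: vec![1, 1, 1, 0] };
        let at_mle = ll.log_likelihood(0.75);
        assert!(at_mle > ll.log_likelihood(0.5));
        assert!(at_mle > ll.log_likelihood(0.9));
    }
}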

// ─────────────────────────────────────────────────────────────────────────────
// BayesianUpdate
// ─────────────────────────────────────────────────────────────────────────────

/// Conjugate Bayesian updates for standard distribution families.
///
/// Each method returns updated (posterior) hyperparameters given a prior and
/// observed data.
#[derive(Debug, Clone)]
pub struct BayesianUpdate;

impl BayesianUpdate {
    /// Normal-Normal conjugate update: known variance σ², Gaussian prior on μ.
    ///
    /// Prior: μ ~ N(μ₀, τ₀²).  Data: x_i ~ N(μ, σ²).
    /// Returns the posterior (μ_n, τ_n).
    pub fn normal_normal(
        prior_mean: f64,
        prior_std: f64,
        likelihood_std: f64,
        data: &[f64],
    ) -> (f64, f64) {
        let n = data.len() as f64;
        if n == 0.0 {
            return (prior_mean, prior_std);
        }
        let tau0_sq = prior_std * prior_std;
        let sigma_sq = likelihood_std * likelihood_std;
        let x_bar: f64 = data.iter().sum::<f64>() / n;
        let tau_n_sq = 1.0 / (1.0 / tau0_sq + n / sigma_sq);
        let mu_n = tau_n_sq * (prior_mean / tau0_sq + n * x_bar / sigma_sq);
        (mu_n, tau_n_sq.sqrt())
    }

    /// Beta-Bernoulli conjugate update: Beta prior on p, Bernoulli observations.
    ///
    /// Prior: p ~ Beta(α, β).  Data: k successes out of n.
    /// Returns the posterior (α', β').
    pub fn beta_bernoulli(
        prior_alpha: f64,
        prior_beta: f64,
        successes: u64,
        total: u64,
    ) -> (f64, f64) {
        // Clamp so the success count never exceeds the total.
        let successes = successes.min(total);
        let failures = total - successes;
        (prior_alpha + successes as f64, prior_beta + failures as f64)
    }

    /// Gamma-Poisson conjugate update: Gamma prior on λ, Poisson observations.
    ///
    /// Prior: λ ~ Gamma(α, β).  Data: counts k_i.
    /// Returns the posterior (α', β').
    pub fn gamma_poisson(prior_shape: f64, prior_rate: f64, counts: &[u64]) -> (f64, f64) {
        let n = counts.len() as f64;
        let sum_k: f64 = counts.iter().map(|&k| k as f64).sum();
        (prior_shape + sum_k, prior_rate + n)
    }

    /// Dirichlet-Multinomial conjugate update.
    ///
    /// Prior: p ~ Dir(α).  Data: counts k_i.
    /// Returns the posterior α' = α + k.
    pub fn dirichlet_multinomial(prior_alpha: &[f64], counts: &[u64]) -> Vec<f64> {
        prior_alpha
            .iter()
            .zip(counts.iter())
            .map(|(&a, &k)| a + k as f64)
            .collect()
    }

    /// Normal-inverse-Gamma conjugate update for unknown mean and variance.
    ///
    /// Prior hyperparameters: (μ₀, κ₀, α₀, β₀).
    /// Returns the updated hyperparameters (μ_n, κ_n, α_n, β_n).
    pub fn normal_inverse_gamma(
        mu0: f64,
        kappa0: f64,
        alpha0: f64,
        beta0: f64,
        data: &[f64],
    ) -> (f64, f64, f64, f64) {
        let n = data.len() as f64;
        if n == 0.0 {
            return (mu0, kappa0, alpha0, beta0);
        }
        let x_bar = data.iter().sum::<f64>() / n;
        let ss: f64 = data.iter().map(|&x| (x - x_bar).powi(2)).sum();
        let kappa_n = kappa0 + n;
        let mu_n = (kappa0 * mu0 + n * x_bar) / kappa_n;
        let alpha_n = alpha0 + n / 2.0;
        let beta_n = beta0 + 0.5 * ss + (kappa0 * n * (x_bar - mu0).powi(2)) / (2.0 * kappa_n);
        (mu_n, kappa_n, alpha_n, beta_n)
    }
}
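
// Illustrative sketch (not part of the original file): the Normal-Normal
// update above concentrates the posterior around the sample mean as data
// accumulate.
#[cfg(test)]
mod bayesian_update_demo {
    use super::*;

    #[test]
    fn posterior_tightens_with_more_data() {
        let few = [5.0, 5.0];
        let many = [5.0; 50];
        let (_, tau_few) = BayesianUpdate::normal_normal(0.0, 10.0, 1.0, &few);
        let (mu_many, tau_many) = BayesianUpdate::normal_normal(0.0, 10.0, 1.0, &many);
        assert!(tau_many < tau_few); // more data → tighter posterior
        assert!((mu_many - 5.0).abs() < 0.1); // diffuse prior → mean ≈ x̄
    }
}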

// ─────────────────────────────────────────────────────────────────────────────
// MarkovChainMonteCarlo
// ─────────────────────────────────────────────────────────────────────────────

/// Markov chain Monte Carlo samplers.
///
/// Implements random-walk Metropolis-Hastings, a Gibbs sampler for a
/// bivariate Gaussian, and a single-leapfrog Hamiltonian (NUTS-style)
/// accept/reject step.
#[derive(Debug, Clone)]
pub struct MarkovChainMonteCarlo {
    /// Step size (proposal std for MH, leapfrog step for `nuts_step`).
    pub step_size: f64,
    /// Number of warm-up (burn-in) steps.
    pub n_warmup: usize,
}

impl MarkovChainMonteCarlo {
    /// Creates a new MCMC sampler with the given step size and warm-up count.
    pub fn new(step_size: f64, n_warmup: usize) -> Self {
        Self {
            step_size,
            n_warmup,
        }
    }

    /// Runs a Metropolis-Hastings random-walk chain.
    ///
    /// # Arguments
    /// * `log_target` - Closure returning log π(θ) (up to normalisation).
    /// * `init`       - Initial parameter value.
    /// * `n_samples`  - Number of post-warmup samples to return.
    /// * `seed`       - RNG seed.
    pub fn metropolis_hastings<F>(
        &self,
        log_target: F,
        init: f64,
        n_samples: usize,
        seed: u64,
    ) -> Vec<f64>
    where
        F: Fn(f64) -> f64,
    {
        let mut rng = BiRng::new(seed);
        let mut current = init;
        let mut log_current = log_target(current);
        // Warm-up
        for _ in 0..self.n_warmup {
            let proposal = current + self.step_size * rng.next_normal();
            let log_proposal = log_target(proposal);
            let log_alpha = log_proposal - log_current;
            if rng.next_f64().ln() < log_alpha {
                current = proposal;
                log_current = log_proposal;
            }
        }
        // Sampling
        let mut samples = Vec::with_capacity(n_samples);
        for _ in 0..n_samples {
            let proposal = current + self.step_size * rng.next_normal();
            let log_proposal = log_target(proposal);
            let log_alpha = log_proposal - log_current;
            if rng.next_f64().ln() < log_alpha {
                current = proposal;
                log_current = log_proposal;
            }
            samples.push(current);
        }
        samples
    }

    /// Metropolis-Hastings for a vector-valued parameter.
    pub fn metropolis_hastings_vec<F>(
        &self,
        log_target: F,
        init: Vec<f64>,
        n_samples: usize,
        seed: u64,
    ) -> Vec<Vec<f64>>
    where
        F: Fn(&[f64]) -> f64,
    {
        let mut rng = BiRng::new(seed);
        let mut current = init;
        let mut log_current = log_target(&current);
        // Warm-up
        for _ in 0..self.n_warmup {
            let proposal: Vec<f64> = current
                .iter()
                .map(|&x| x + self.step_size * rng.next_normal())
                .collect();
            let log_proposal = log_target(&proposal);
            if rng.next_f64().ln() < log_proposal - log_current {
                current = proposal;
                log_current = log_proposal;
            }
        }
        // Sampling
        let mut samples = Vec::with_capacity(n_samples);
        for _ in 0..n_samples {
            let proposal: Vec<f64> = current
                .iter()
                .map(|&x| x + self.step_size * rng.next_normal())
                .collect();
            let log_proposal = log_target(&proposal);
            if rng.next_f64().ln() < log_proposal - log_current {
                current = proposal;
                log_current = log_proposal;
            }
            samples.push(current.clone());
        }
        samples
    }

    /// Gibbs sampler for a bivariate Gaussian with known full conditionals.
    ///
    /// Samples from N(\[μ1, μ2\], \[\[σ1², ρσ1σ2\],\[ρσ1σ2, σ2²\]\]).
    #[allow(clippy::too_many_arguments)]
    pub fn gibbs_bivariate_gaussian(
        mu1: f64,
        mu2: f64,
        sigma1: f64,
        sigma2: f64,
        rho: f64,
        n_samples: usize,
        seed: u64,
    ) -> Vec<[f64; 2]> {
        let mut rng = BiRng::new(seed);
        let mut x2 = mu2;
        let mut samples = Vec::with_capacity(n_samples);
        for _ in 0..n_samples {
            // x1 | x2 ~ N(μ1 + ρ(σ1/σ2)(x2-μ2), σ1²(1-ρ²))
            let cond_mean1 = mu1 + rho * (sigma1 / sigma2) * (x2 - mu2);
            let cond_std1 = sigma1 * (1.0 - rho * rho).sqrt();
            let x1 = cond_mean1 + cond_std1 * rng.next_normal();
            // x2 | x1 ~ N(μ2 + ρ(σ2/σ1)(x1-μ1), σ2²(1-ρ²))
            let cond_mean2 = mu2 + rho * (sigma2 / sigma1) * (x1 - mu1);
            let cond_std2 = sigma2 * (1.0 - rho * rho).sqrt();
            x2 = cond_mean2 + cond_std2 * rng.next_normal();
            samples.push([x1, x2]);
        }
        samples
    }

    /// Simplified NUTS-style step: a single leapfrog integration with a
    /// Metropolis accept/reject correction.
    ///
    /// Returns a single sample given the gradient function `grad_log_target`.
    pub fn nuts_step<F>(
        &self,
        log_target: F,
        grad_log_target: impl Fn(f64) -> f64,
        init: f64,
        seed: u64,
    ) -> f64
    where
        F: Fn(f64) -> f64,
    {
        let mut rng = BiRng::new(seed);
        let mut q = init;
        let mut p = rng.next_normal();
        let h_init = -log_target(q) + 0.5 * p * p;
        // Single leapfrog step
        let grad = grad_log_target(q);
        p += 0.5 * self.step_size * grad;
        q += self.step_size * p;
        p += 0.5 * self.step_size * grad_log_target(q);
        let h_prop = -log_target(q) + 0.5 * p * p;
        let log_alpha = -(h_prop - h_init);
        if rng.next_f64().ln() < log_alpha {
            q
        } else {
            init
        }
    }

    /// Returns the effective sample size (ESS) of a chain using the initial
    /// positive-autocorrelation truncation.
    pub fn effective_sample_size(chain: &[f64]) -> f64 {
        let n = chain.len();
        if n < 4 {
            return n as f64;
        }
        let mean = chain.iter().sum::<f64>() / n as f64;
        let variance = chain.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        if variance < 1e-30 {
            return n as f64;
        }
        let mut rho_sum = 0.0_f64;
        for lag in 1..(n / 2) {
            let acf: f64 = (0..n - lag)
                .map(|i| (chain[i] - mean) * (chain[i + lag] - mean))
                .sum::<f64>()
                / (n as f64 * variance);
            // Truncate at the first negative autocorrelation.
            if acf < 0.0 {
                break;
            }
            rho_sum += acf;
        }
        n as f64 / (1.0 + 2.0 * rho_sum)
    }
}
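
// Illustrative sketch (not part of the original file): random-walk
// Metropolis-Hastings targeting N(3, 1); the chain mean should land near 3
// (loose bound, fixed seed).
#[cfg(test)]
mod mcmc_demo {
    use super::*;

    #[test]
    fn mh_recovers_target_mean() {
        let mcmc = MarkovChainMonteCarlo::new(1.0, 1_000);
        // log N(3, 1) up to an additive constant.
        let target = |x: f64| -0.5 * (x - 3.0).powi(2);
        let samples = mcmc.metropolis_hastings(target, 0.0, 5_000, 2024);
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((mean - 3.0).abs() < 0.5);
    }
}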

// ─────────────────────────────────────────────────────────────────────────────
// BayesianLinearRegression
// ─────────────────────────────────────────────────────────────────────────────

/// Bayesian linear regression with a conjugate Gaussian prior on the weights.
///
/// Model: y = X w + ε, ε ~ N(0, σ² I).
/// Prior: w ~ N(0, α⁻¹ I).
#[derive(Debug, Clone)]
pub struct BayesianLinearRegression {
    /// Prior precision α (= 1/σ_prior²).
    pub alpha_prior: f64,
    /// Noise precision β (= 1/σ_noise²).
    pub beta_noise: f64,
    /// Posterior mean weights (set after fitting).
    pub posterior_mean: Vec<f64>,
    /// Posterior covariance matrix (row-major, set after fitting).
    pub posterior_cov: Vec<Vec<f64>>,
}

impl BayesianLinearRegression {
    /// Creates a new Bayesian linear regression model.
    ///
    /// # Arguments
    /// * `alpha_prior` - Prior precision (inverse variance) on weights.
    /// * `beta_noise`  - Noise precision.
    pub fn new(alpha_prior: f64, beta_noise: f64) -> Self {
        Self {
            alpha_prior,
            beta_noise,
            posterior_mean: vec![],
            posterior_cov: vec![],
        }
    }

    /// Fits the model to design matrix `x_mat` (n × d) and targets `y` (n).
    ///
    /// Computes the posterior mean and covariance analytically:
    /// S_N^{-1} = α I + β X^T X and m_N = β S_N X^T y.
    pub fn fit(&mut self, x_mat: &[Vec<f64>], y: &[f64]) {
        let n = x_mat.len();
        if n == 0 || y.is_empty() {
            return;
        }
        let d = x_mat[0].len();
        // Compute X^T X (d × d)
        let mut xtx = vec![vec![0.0_f64; d]; d];
        for row in x_mat {
            for i in 0..d {
                for j in 0..d {
                    xtx[i][j] += row[i] * row[j];
                }
            }
        }
        // S_N^{-1} = α I + β X^T X
        let mut s_inv = vec![vec![0.0_f64; d]; d];
        for i in 0..d {
            for j in 0..d {
                s_inv[i][j] = self.beta_noise * xtx[i][j];
            }
            s_inv[i][i] += self.alpha_prior;
        }
        // Invert S_N^{-1} using Gauss-Jordan elimination
        let s_n = mat_inverse(&s_inv);
        // X^T y (d vector)
        let mut xty = vec![0.0_f64; d];
        for (row, &yi) in x_mat.iter().zip(y.iter()) {
            for i in 0..d {
                xty[i] += row[i] * yi;
            }
        }
        // m_N = β S_N X^T y
        let mut m_n = vec![0.0_f64; d];
        for i in 0..d {
            for j in 0..d {
                m_n[i] += s_n[i][j] * xty[j];
            }
        }
        for v in m_n.iter_mut() {
            *v *= self.beta_noise;
        }
        self.posterior_mean = m_n;
        self.posterior_cov = s_n;
    }

    /// Returns the predictive mean m_N^T x for a new input vector `x_new`.
    pub fn predict_mean(&self, x_new: &[f64]) -> f64 {
        self.posterior_mean
            .iter()
            .zip(x_new.iter())
            .map(|(&w, &x)| w * x)
            .sum()
    }

    /// Returns the predictive variance for a new input `x_new`.
    ///
    /// σ²_pred = 1/β + x^T S_N x.
    pub fn predict_variance(&self, x_new: &[f64]) -> f64 {
        if self.posterior_cov.is_empty() {
            return 1.0 / self.beta_noise;
        }
        let d = x_new.len();
        let mut s_x = vec![0.0_f64; d];
        for i in 0..d {
            for j in 0..d {
                s_x[i] += self.posterior_cov[i][j] * x_new[j];
            }
        }
        let xtsx: f64 = x_new.iter().zip(s_x.iter()).map(|(&x, &sx)| x * sx).sum();
        1.0 / self.beta_noise + xtsx
    }

    /// Returns the log marginal likelihood (model evidence) log p(y | X, α, β).
    ///
    /// log p(y) = (d/2) ln α + (n/2) ln β - (1/2)(β ||y - X m_N||² + α ||m_N||²)
    ///            - (1/2) ln|S_N^{-1}| - (n/2) ln(2π)
    pub fn log_evidence(&self, x_mat: &[Vec<f64>], y: &[f64]) -> f64 {
        if self.posterior_mean.is_empty() || x_mat.is_empty() {
            return f64::NEG_INFINITY;
        }
        let n = y.len() as f64;
        let d = self.posterior_mean.len() as f64;
        // Residual sum of squares ||y - X m_N||²
        let mut ss_res = 0.0_f64;
        for (row, &yi) in x_mat.iter().zip(y.iter()) {
            let pred = self.predict_mean(row);
            ss_res += (yi - pred).powi(2);
        }
        let m_norm_sq: f64 = self.posterior_mean.iter().map(|&w| w * w).sum();
        let log_det_s: f64 = {
            // Diagonal approximation to ln|S_N| (exact when S_N is diagonal).
            self.posterior_cov
                .iter()
                .enumerate()
                .map(|(i, row)| row[i].abs().ln())
                .sum()
        };
        let alpha = self.alpha_prior;
        let beta = self.beta_noise;
        (d / 2.0) * alpha.ln() + (n / 2.0) * beta.ln() - 0.5 * (beta * ss_res + alpha * m_norm_sq)
            + 0.5 * log_det_s
            - (n / 2.0) * (2.0 * PI).ln()
    }
}
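
// Illustrative sketch (not part of the original file): fitting the Bayesian
// linear regression above to noiseless data from y = 1 + 2x. With a weak
// prior the posterior mean should recover the line closely.
#[cfg(test)]
mod blr_demo {
    use super::*;

    #[test]
    fn recovers_straight_line() {
        // Design matrix with a bias column: rows are [1, x].
        let x: Vec<Vec<f64>> = (0..20).map(|i| vec![1.0, i as f64 * 0.5]).collect();
        let y: Vec<f64> = x.iter().map(|row| 1.0 + 2.0 * row[1]).collect();
        let mut model = BayesianLinearRegression::new(1e-6, 100.0);
        model.fit(&x, &y);
        // Prediction at x = 3 should be close to 1 + 2·3 = 7.
        assert!((model.predict_mean(&[1.0, 3.0]) - 7.0).abs() < 0.1);
        // Predictive variance is at least the noise floor 1/β.
        assert!(model.predict_variance(&[1.0, 3.0]) >= 1.0 / 100.0);
    }
}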

// ─────────────────────────────────────────────────────────────────────────────
// GaussianProcess
// ─────────────────────────────────────────────────────────────────────────────

/// Kernel function types for Gaussian processes.
#[derive(Debug, Clone, Copy)]
pub enum Kernel {
    /// Radial basis function (squared-exponential) kernel.
    Rbf {
        /// Length scale ℓ.
        length_scale: f64,
        /// Signal variance σ_f².
        signal_variance: f64,
    },
    /// Matérn 3/2 kernel.
    Matern32 {
        /// Length scale ℓ.
        length_scale: f64,
        /// Signal variance σ_f².
        signal_variance: f64,
    },
    /// Matérn 5/2 kernel.
    Matern52 {
        /// Length scale ℓ.
        length_scale: f64,
        /// Signal variance σ_f².
        signal_variance: f64,
    },
    /// Linear (dot-product) kernel.
    Linear {
        /// Bias variance σ_b².
        bias_variance: f64,
        /// Slope variance σ_v².
        slope_variance: f64,
    },
}

impl Kernel {
    /// Evaluates k(x, y) for scalar inputs.
    pub fn eval(&self, x: f64, y: f64) -> f64 {
        match self {
            Kernel::Rbf {
                length_scale,
                signal_variance,
            } => {
                let r2 = (x - y).powi(2) / (length_scale * length_scale);
                signal_variance * (-0.5 * r2).exp()
            }
            Kernel::Matern32 {
                length_scale,
                signal_variance,
            } => {
                let r = (x - y).abs() / length_scale;
                let s3r = 3.0_f64.sqrt() * r;
                signal_variance * (1.0 + s3r) * (-s3r).exp()
            }
            Kernel::Matern52 {
                length_scale,
                signal_variance,
            } => {
                let r = (x - y).abs() / length_scale;
                let s5r = 5.0_f64.sqrt() * r;
                signal_variance * (1.0 + s5r + 5.0 * r * r / 3.0) * (-s5r).exp()
            }
            Kernel::Linear {
                bias_variance,
                slope_variance,
            } => bias_variance + slope_variance * x * y,
        }
    }
}
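
// Illustrative sketch (not part of the original file): the stationary
// kernels above peak at zero distance and decay as |x - y| grows.
#[cfg(test)]
mod kernel_demo {
    use super::*;

    #[test]
    fn rbf_decays_with_distance() {
        let k = Kernel::Rbf { length_scale: 1.0, signal_variance: 1.0 };
        let near = k.eval(0.0, 0.1);
        let far = k.eval(0.0, 2.0);
        assert!(near > far);
        // At zero distance the kernel equals the signal variance.
        assert!((k.eval(0.0, 0.0) - 1.0).abs() < 1e-12);
    }
}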

/// Gaussian Process regression with a scalar-input, scalar-output model.
///
/// Computes the posterior mean and variance given observed (x, y) pairs.
#[derive(Debug, Clone)]
pub struct GaussianProcess {
    /// Kernel function.
    pub kernel: Kernel,
    /// Noise variance σ_n².
    pub noise_variance: f64,
    /// Training inputs.
    pub x_train: Vec<f64>,
    /// Training targets.
    pub y_train: Vec<f64>,
    /// Cholesky factor L of (K + σ_n² I) for efficient prediction.
    chol: Vec<Vec<f64>>,
    /// α = L^{-T} L^{-1} y.
    alpha: Vec<f64>,
}

impl GaussianProcess {
    /// Creates a new GP with the specified kernel and noise variance.
    pub fn new(kernel: Kernel, noise_variance: f64) -> Self {
        Self {
            kernel,
            noise_variance,
            x_train: vec![],
            y_train: vec![],
            chol: vec![],
            alpha: vec![],
        }
    }

    /// Fits the GP to training data, computing the Cholesky factor.
    pub fn fit(&mut self, x_train: Vec<f64>, y_train: Vec<f64>) {
        let n = x_train.len();
        self.x_train = x_train;
        self.y_train = y_train;
        // Build kernel matrix K + σ_n² I
        let mut k = vec![vec![0.0_f64; n]; n];
        for i in 0..n {
            for j in 0..n {
                k[i][j] = self.kernel.eval(self.x_train[i], self.x_train[j]);
            }
            k[i][i] += self.noise_variance;
        }
        // Cholesky decomposition
        self.chol = cholesky(&k);
        // α = L^{-T} L^{-1} y via forward/backward substitution
        let v = forward_sub(&self.chol, &self.y_train);
        self.alpha = backward_sub_t(&self.chol, &v);
    }

    /// Returns the posterior mean at test point `x_star`.
    pub fn predict_mean(&self, x_star: f64) -> f64 {
        if self.x_train.is_empty() {
            return 0.0;
        }
        let k_star: Vec<f64> = self
            .x_train
            .iter()
            .map(|&xi| self.kernel.eval(xi, x_star))
            .collect();
        k_star
            .iter()
            .zip(self.alpha.iter())
            .map(|(&k, &a)| k * a)
            .sum()
    }

    /// Returns the posterior predictive variance (including observation
    /// noise σ_n²) at test point `x_star`.
    pub fn predict_variance(&self, x_star: f64) -> f64 {
        let k_ss = self.kernel.eval(x_star, x_star) + self.noise_variance;
        if self.chol.is_empty() {
            return k_ss;
        }
        let k_star: Vec<f64> = self
            .x_train
            .iter()
            .map(|&xi| self.kernel.eval(xi, x_star))
            .collect();
        let v = forward_sub(&self.chol, &k_star);
        let reduction: f64 = v.iter().map(|&vi| vi * vi).sum();
        (k_ss - reduction).max(0.0)
    }

    /// Returns the log marginal likelihood log p(y | X, θ).
    pub fn log_marginal_likelihood(&self) -> f64 {
        if self.chol.is_empty() || self.y_train.is_empty() {
            return f64::NEG_INFINITY;
        }
        let n = self.y_train.len() as f64;
        let y = &self.y_train;
        // Data-fit term y^T α (enters as -0.5 y^T α below)
        let data_fit: f64 = y
            .iter()
            .zip(self.alpha.iter())
            .map(|(&yi, &ai)| yi * ai)
            .sum();
        // Log-determinant term: ln|K + σ_n² I| = 2 Σ ln L_ii
        let log_det: f64 = self
            .chol
            .iter()
            .enumerate()
            .map(|(i, row)| row[i].abs().ln())
            .sum::<f64>()
            * 2.0;
        -0.5 * data_fit - 0.5 * log_det - (n / 2.0) * (2.0 * PI).ln()
    }
}
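
// Illustrative sketch (not part of the original file): GP predictive
// uncertainty should grow away from the training inputs.
#[cfg(test)]
mod gp_demo {
    use super::*;

    #[test]
    fn uncertainty_grows_far_from_data() {
        let kernel = Kernel::Rbf { length_scale: 1.0, signal_variance: 1.0 };
        let mut gp = GaussianProcess::new(kernel, 0.01);
        gp.fit(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 4.0]);
        // Near a training point the variance collapses toward the noise
        // level; far away it reverts to the prior variance.
        let var_near = gp.predict_variance(1.0);
        let var_far = gp.predict_variance(10.0);
        assert!(var_far > var_near);
    }
}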

// ─────────────────────────────────────────────────────────────────────────────
// ModelSelection
// ─────────────────────────────────────────────────────────────────────────────

/// Model selection criteria: AIC, BIC, Bayes factor, and cross-validation.
#[derive(Debug, Clone)]
pub struct ModelSelection;

impl ModelSelection {
    /// Akaike Information Criterion: AIC = 2k - 2 log L.
    ///
    /// # Arguments
    /// * `log_likelihood` - Maximised log-likelihood.
    /// * `n_params`       - Number of free parameters k.
    pub fn aic(log_likelihood: f64, n_params: usize) -> f64 {
        2.0 * n_params as f64 - 2.0 * log_likelihood
    }

    /// Corrected AIC for small samples: AICc = AIC + 2k(k+1)/(n-k-1).
    pub fn aicc(log_likelihood: f64, n_params: usize, n_data: usize) -> f64 {
        let k = n_params as f64;
        let n = n_data as f64;
        let aic = Self::aic(log_likelihood, n_params);
        if n > k + 1.0 {
            aic + 2.0 * k * (k + 1.0) / (n - k - 1.0)
        } else {
            aic
        }
    }

    /// Bayesian Information Criterion: BIC = k ln(n) - 2 log L.
    ///
    /// # Arguments
    /// * `log_likelihood` - Maximised log-likelihood.
    /// * `n_params`       - Number of free parameters k.
    /// * `n_data`         - Number of data points n.
    pub fn bic(log_likelihood: f64, n_params: usize, n_data: usize) -> f64 {
        n_params as f64 * (n_data as f64).ln() - 2.0 * log_likelihood
    }

    /// Bayes factor (in log scale): log BF₁₂ = log p(D|M₁) - log p(D|M₂).
    pub fn log_bayes_factor(log_evidence_1: f64, log_evidence_2: f64) -> f64 {
        log_evidence_1 - log_evidence_2
    }

    /// Interprets the log Bayes factor according to Jeffreys' scale.
    ///
    /// Returns a descriptive string.
    pub fn jeffreys_scale(log_bf: f64) -> &'static str {
        let bf = log_bf.exp();
        if bf < 1.0 {
            "Negative (favours M2)"
        } else if bf < 3.0 {
            "Barely worth mentioning"
        } else if bf < 10.0 {
            "Substantial"
        } else if bf < 30.0 {
            "Strong"
        } else if bf < 100.0 {
            "Very strong"
        } else {
            "Decisive"
        }
    }

    /// K-fold cross-validation mean squared error.
    ///
    /// # Arguments
    /// * `x`          - Feature matrix (n × d, row-major as `Vec<Vec<f64>>`).
    /// * `y`          - Target vector (length n).
    /// * `k`          - Number of folds.
    /// * `alpha`      - Prior precision for Bayesian linear regression.
    /// * `beta_noise` - Noise precision.
    pub fn k_fold_cv_mse(x: &[Vec<f64>], y: &[f64], k: usize, alpha: f64, beta_noise: f64) -> f64 {
        let n = x.len();
        if k == 0 || n < k {
            return f64::NAN;
        }
        let fold_size = n / k;
        let mut total_mse = 0.0_f64;
        let mut total_count = 0_usize;
        for fold in 0..k {
            let test_start = fold * fold_size;
            let test_end = if fold == k - 1 {
                n
            } else {
                test_start + fold_size
            };
            let x_train: Vec<Vec<f64>> = x[..test_start]
                .iter()
                .chain(x[test_end..].iter())
                .cloned()
                .collect();
            let y_train: Vec<f64> = y[..test_start]
                .iter()
                .chain(y[test_end..].iter())
                .cloned()
                .collect();
            let x_test = &x[test_start..test_end];
            let y_test = &y[test_start..test_end];
            let mut model = BayesianLinearRegression::new(alpha, beta_noise);
            model.fit(&x_train, &y_train);
            for (xi, &yi) in x_test.iter().zip(y_test.iter()) {
                let pred = model.predict_mean(xi);
                total_mse += (yi - pred).powi(2);
                total_count += 1;
            }
        }
        if total_count == 0 {
            f64::NAN
        } else {
            total_mse / total_count as f64
        }
    }

    /// Leave-one-out cross-validation log predictive density (a pseudo Bayes
    /// factor ingredient).
    pub fn loo_cv_log_predictive(x: &[Vec<f64>], y: &[f64], alpha: f64, beta_noise: f64) -> f64 {
        let n = x.len();
        let mut total = 0.0_f64;
        for i in 0..n {
            let x_train: Vec<Vec<f64>> = x
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .map(|(_, v)| v.clone())
                .collect();
            let y_train: Vec<f64> = y
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .map(|(_, &v)| v)
                .collect();
            let mut model = BayesianLinearRegression::new(alpha, beta_noise);
            model.fit(&x_train, &y_train);
            let mean = model.predict_mean(&x[i]);
            let var = model.predict_variance(&x[i]);
            // log N(y_i | mean, var)
            let log_p = -0.5 * (y[i] - mean).powi(2) / var - 0.5 * (2.0 * PI * var).ln();
            total += log_p;
        }
        total
    }
}
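
// Illustrative sketch (not part of the original file): BIC penalises
// parameters more heavily than AIC once ln(n) > 2, i.e. n ≥ 8.
#[cfg(test)]
mod model_selection_demo {
    use super::*;

    #[test]
    fn bic_penalty_exceeds_aic_for_large_n() {
        let log_l = -50.0;
        let aic = ModelSelection::aic(log_l, 3);
        let bic = ModelSelection::bic(log_l, 3, 100);
        assert!(bic > aic); // 3 ln 100 ≈ 13.8 > 2 · 3 = 6
    }
}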

// ─────────────────────────────────────────────────────────────────────────────
// Utility functions
// ─────────────────────────────────────────────────────────────────────────────

/// Returns ln Γ(x) using the Lanczos approximation.
pub fn log_gamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::INFINITY;
    }
    // Lanczos coefficients (g = 5, n = 7)
    let g = 5.0_f64;
    let c = [
        1.000000000190015,
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.001208650973866179,
        -5.395239384953e-6,
    ];
    let mut sum = c[0];
    let mut xp = x;
    for ci in c.iter().skip(1) {
        xp += 1.0;
        sum += ci / xp;
    }
    let t = x + g + 0.5;
    0.5 * (2.0 * PI).ln() + (x + 0.5) * t.ln() - t + sum.ln() - x.ln()
}

/// Returns ln B(a, b) = ln Γ(a) + ln Γ(b) - ln Γ(a+b).
pub fn log_beta(a: f64, b: f64) -> f64 {
    log_gamma(a) + log_gamma(b) - log_gamma(a + b)
}

/// Returns ln k! = ln Γ(k + 1).
fn log_factorial(k: u64) -> f64 {
    log_gamma(k as f64 + 1.0)
}

/// Samples from a Gamma(shape, 1) distribution using the Marsaglia-Tsang method.
fn sample_gamma(shape: f64, rng: &mut BiRng) -> f64 {
    if shape < 1.0 {
        // Boost: Gamma(a) = Gamma(a + 1) · U^{1/a}.
        return sample_gamma(1.0 + shape, rng) * rng.next_f64().powf(1.0 / shape);
    }
    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        let x = rng.next_normal();
        let v_raw = 1.0 + c * x;
        if v_raw <= 0.0 {
            continue;
        }
        let v = v_raw.powi(3);
        let u = rng.next_f64();
        // Fast squeeze test, then the exact acceptance test.
        if u < 1.0 - 0.0331 * (x * x) * (x * x) {
            return d * v;
        }
        if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
            return d * v;
        }
    }
}

/// Performs Cholesky decomposition of a positive-definite matrix (returns lower L).
fn cholesky(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    let mut l = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        for j in 0..=i {
            let sum: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
            if i == j {
                let val = a[i][i] - sum;
                // Clamp non-positive pivots so L stays usable for nearly
                // singular inputs.
                l[i][j] = if val > 0.0 { val.sqrt() } else { 1e-10 };
            } else if l[j][j].abs() < 1e-30 {
                l[i][j] = 0.0;
            } else {
                l[i][j] = (a[i][j] - sum) / l[j][j];
            }
        }
    }
    l
}

/// Forward substitution: solves L x = b for lower-triangular L.
fn forward_sub(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    let mut x = vec![0.0_f64; n];
    for i in 0..n {
        let s: f64 = (0..i).map(|j| l[i][j] * x[j]).sum();
        if l[i][i].abs() > 1e-30 {
            x[i] = (b[i] - s) / l[i][i];
        }
    }
    x
}

/// Backward substitution with transposed L: solves L^T x = b.
fn backward_sub_t(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let s: f64 = (i + 1..n).map(|j| l[j][i] * x[j]).sum();
        if l[i][i].abs() > 1e-30 {
            x[i] = (b[i] - s) / l[i][i];
        }
    }
    x
}

/// Inverts a matrix using Gauss-Jordan elimination with partial pivoting.
fn mat_inverse(a: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    if n == 0 {
        return vec![];
    }
    // Augmented matrix [a | I]
    let mut aug: Vec<Vec<f64>> = a
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mut r = row.clone();
            for j in 0..n {
                r.push(if i == j { 1.0 } else { 0.0 });
            }
            r
        })
        .collect();
    for col in 0..n {
        // Find pivot
        let mut max_row = col;
        for row in col + 1..n {
            if aug[row][col].abs() > aug[max_row][col].abs() {
                max_row = row;
            }
        }
        aug.swap(col, max_row);
        let pivot = aug[col][col];
        if pivot.abs() < 1e-30 {
            // Singular (or nearly singular) column: skip elimination.
            continue;
        }
        for j in 0..2 * n {
            aug[col][j] /= pivot;
        }
        for row in 0..n {
            if row != col {
                let factor = aug[row][col];
                for j in 0..2 * n {
                    let v = factor * aug[col][j];
                    aug[row][j] -= v;
                }
            }
        }
    }
    aug.into_iter().map(|row| row[n..].to_vec()).collect()
}
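
// Illustrative sketch (not part of the original file): round-tripping the
// helpers above. Solving L x = b and then L^T y = x solves (L L^T) y = b,
// so Cholesky plus the two substitutions inverts an SPD system; the result
// should match the Gauss-Jordan inverse.
#[cfg(test)]
mod linalg_demo {
    use super::*;

    #[test]
    fn cholesky_solve_matches_inverse() {
        // A simple symmetric positive-definite matrix.
        let a = vec![vec![4.0, 1.0], vec![1.0, 3.0]];
        let b = vec![1.0, 2.0];
        let l = cholesky(&a);
        let y = backward_sub_t(&l, &forward_sub(&l, &b));
        // Reference solution via the explicit inverse.
        let a_inv = mat_inverse(&a);
        let y_ref: Vec<f64> = (0..2)
            .map(|i| a_inv[i][0] * b[0] + a_inv[i][1] * b[1])
            .collect();
        assert!((y[0] - y_ref[0]).abs() < 1e-10);
        assert!((y[1] - y_ref[1]).abs() < 1e-10);
    }
}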

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── Prior ────────────────────────────────────────────────────────────────

    #[test]
    fn test_prior_uniform_inside() {
        let p = Prior::Uniform {
            low: 0.0,
            high: 1.0,
        };
        assert!(p.log_density(0.5).is_finite());
    }

    #[test]
    fn test_prior_uniform_outside() {
        let p = Prior::Uniform {
            low: 0.0,
            high: 1.0,
        };
        assert_eq!(p.log_density(2.0), f64::NEG_INFINITY);
    }

    #[test]
    fn test_prior_gaussian_mode() {
        let p = Prior::Gaussian {
            mean: 0.0,
            std: 1.0,
        };
        // Mode at 0 should be maximum
        let l0 = p.log_density(0.0);
        let l1 = p.log_density(1.0);
        assert!(l0 > l1);
    }

    #[test]
    fn test_prior_laplace_symmetric() {
        let p = Prior::Laplace { mu: 0.0, b: 1.0 };
        assert!((p.log_density(1.0) - p.log_density(-1.0)).abs() < 1e-12);
    }

    #[test]
    fn test_prior_jeffreys_positive() {
        let p = Prior::Jeffreys;
        assert!(p.log_density(2.0).is_finite());
    }

    #[test]
    fn test_prior_jeffreys_non_positive() {
        let p = Prior::Jeffreys;
        assert_eq!(p.log_density(0.0), f64::NEG_INFINITY);
    }

    #[test]
    fn test_prior_beta_at_half() {
        let p = Prior::Beta {
            alpha: 2.0,
            beta: 2.0,
        };
        assert!(p.log_density(0.5).is_finite());
    }

    #[test]
    fn test_prior_beta_outside_support() {
        let p = Prior::Beta {
            alpha: 2.0,
            beta: 2.0,
        };
        assert_eq!(p.log_density(1.5), f64::NEG_INFINITY);
    }

    #[test]
    fn test_prior_gamma_positive() {
        let p = Prior::Gamma {
            shape: 2.0,
            rate: 1.0,
        };
        assert!(p.log_density(1.0).is_finite());
    }

    #[test]
    fn test_prior_sample_uniform() {
        let p = Prior::Uniform {
            low: 0.0,
            high: 1.0,
        };
        let s = p.sample(42).unwrap();
        assert!((0.0..=1.0).contains(&s));
    }

    #[test]
    fn test_prior_sample_gaussian() {
        let p = Prior::Gaussian {
            mean: 5.0,
            std: 0.1,
        };
        let s = p.sample(7).unwrap();
        assert!((s - 5.0).abs() < 2.0);
    }

    #[test]
    fn test_prior_dirichlet_log_density() {
        let p = Prior::Dirichlet {
            alpha: vec![1.0, 1.0, 1.0],
        };
        let ld = p.dirichlet_log_density(&[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]);
        assert!(ld.is_finite());
    }

    // ── Likelihood ───────────────────────────────────────────────────────────

    #[test]
    fn test_likelihood_gaussian_log_ll_finite() {
        let ll = Likelihood::Gaussian {
            data: vec![1.0, 2.0, 3.0],
            sigma: 1.0,
        };
        assert!(ll.log_likelihood(2.0).is_finite());
    }

    #[test]
    fn test_likelihood_gaussian_mode_at_mean() {
        let data = vec![2.0, 2.0, 2.0];
        let ll = Likelihood::Gaussian {
            data: data.clone(),
            sigma: 1.0,
        };
        let l2 = ll.log_likelihood(2.0);
        let ll2 = Likelihood::Gaussian { data, sigma: 1.0 };
        let l3 = ll2.log_likelihood(3.0);
        assert!(l2 > l3);
    }

    #[test]
    fn test_likelihood_poisson_log_ll() {
        let ll = Likelihood::Poisson {
            counts: vec![3, 4, 5],
        };
        assert!(ll.log_likelihood(4.0).is_finite());
    }

    #[test]
    fn test_likelihood_poisson_non_positive_lambda() {
        let ll = Likelihood::Poisson { counts: vec![1] };
        assert_eq!(ll.log_likelihood(0.0), f64::NEG_INFINITY);
    }

    #[test]
    fn test_likelihood_bernoulli_log_ll() {
        let ll = Likelihood::Bernoulli {
            outcomes: vec![1, 0, 1],
        };
        assert!(ll.log_likelihood(0.6).is_finite());
    }

    #[test]
    fn test_likelihood_bernoulli_outside_range() {
        let ll = Likelihood::Bernoulli { outcomes: vec![1] };
        assert_eq!(ll.log_likelihood(1.0), f64::NEG_INFINITY);
    }

    #[test]
    fn test_likelihood_multinomial_log_ll() {
        let ll = Likelihood::Multinomial {
            counts: vec![3, 4, 3],
        };
        let probs = [0.3, 0.4, 0.3];
        let lp = ll.multinomial_log_likelihood(&probs);
        assert!(lp.is_finite());
    }

    // ── BayesianUpdate ───────────────────────────────────────────────────────

    #[test]
    fn test_normal_normal_posterior_shrinks_toward_data() {
        let data = vec![5.0, 5.0, 5.0, 5.0, 5.0];
        let (mu_n, _) = BayesianUpdate::normal_normal(0.0, 10.0, 1.0, &data);
        assert!(mu_n > 3.0); // Should be pulled toward 5
    }

    #[test]
    fn test_normal_normal_empty_data() {
        let (mu_n, sigma_n) = BayesianUpdate::normal_normal(1.0, 2.0, 1.0, &[]);
        assert_eq!((mu_n, sigma_n), (1.0, 2.0));
    }

    #[test]
    fn test_beta_bernoulli_update() {
        let (alpha_n, beta_n) = BayesianUpdate::beta_bernoulli(1.0, 1.0, 7, 10);
        assert!((alpha_n - 8.0).abs() < 1e-12);
        assert!((beta_n - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_gamma_poisson_update() {
        let (alpha_n, beta_n) = BayesianUpdate::gamma_poisson(2.0, 1.0, &[3, 4, 5]);
        assert!((alpha_n - 14.0).abs() < 1e-12);
        assert!((beta_n - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_dirichlet_multinomial_update() {
        let alpha_n = BayesianUpdate::dirichlet_multinomial(&[1.0, 1.0, 1.0], &[3, 2, 5]);
        assert!((alpha_n[0] - 4.0).abs() < 1e-12);
        assert!((alpha_n[2] - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_normal_inverse_gamma_update() {
        let (mu_n, kappa_n, alpha_n, beta_n) =
            BayesianUpdate::normal_inverse_gamma(0.0, 1.0, 2.0, 3.0, &[1.0, 2.0, 3.0]);
        assert!(kappa_n > 1.0);
        assert!(alpha_n > 2.0);
        assert!(beta_n > 3.0);
        let _ = mu_n;
    }

    // ── MCMC ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_mh_samples_correct_count() {
        let mcmc = MarkovChainMonteCarlo::new(0.5, 100);
        let samples = mcmc.metropolis_hastings(|x| -0.5 * x * x, 0.0, 200, 42);
        assert_eq!(samples.len(), 200);
    }

    #[test]
    fn test_mh_standard_normal_mean_close_to_zero() {
        let mcmc = MarkovChainMonteCarlo::new(1.0, 500);
        let samples = mcmc.metropolis_hastings(|x| -0.5 * x * x, 0.0, 1000, 99);
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!(mean.abs() < 0.3);
    }

    #[test]
    fn test_gibbs_bivariate_gaussian_count() {
        let samples =
            MarkovChainMonteCarlo::gibbs_bivariate_gaussian(0.0, 0.0, 1.0, 1.0, 0.5, 100, 7);
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_mh_vec_runs() {
        let mcmc = MarkovChainMonteCarlo::new(0.3, 50);
        let log_t = |x: &[f64]| -0.5 * x[0] * x[0] - 0.5 * x[1] * x[1];
        let samples = mcmc.metropolis_hastings_vec(log_t, vec![0.0, 0.0], 100, 13);
        assert_eq!(samples.len(), 100);
    }

    #[test]
    fn test_nuts_step_returns_finite() {
        let mcmc = MarkovChainMonteCarlo::new(0.1, 0);
        let result = mcmc.nuts_step(|x| -0.5 * x * x, |x| -x, 0.0, 5);
        assert!(result.is_finite());
    }

    #[test]
    fn test_ess_constant_chain() {
        let chain = vec![1.0; 100];
        let ess = MarkovChainMonteCarlo::effective_sample_size(&chain);
        assert!(ess > 0.0);
    }

    // ── BayesianLinearRegression ─────────────────────────────────────────────

    #[test]
    fn test_blr_fit_and_predict() {
        let x = vec![vec![1.0, 0.0], vec![1.0, 1.0], vec![1.0, 2.0]];
        let y = vec![1.0, 3.0, 5.0]; // y = 1 + 2x
        let mut blr = BayesianLinearRegression::new(1e-4, 10.0);
        blr.fit(&x, &y);
        let pred = blr.predict_mean(&[1.0, 1.5]);
        assert!((pred - 4.0).abs() < 1.0); // Should be close to 4
    }

    #[test]
    fn test_blr_variance_positive() {
        let x = vec![vec![1.0, 0.0], vec![1.0, 1.0]];
        let y = vec![0.0, 1.0];
        let mut blr = BayesianLinearRegression::new(1.0, 1.0);
        blr.fit(&x, &y);
        let var = blr.predict_variance(&[1.0, 0.5]);
        assert!(var > 0.0);
    }

    #[test]
    fn test_blr_log_evidence_finite_after_fit() {
        let x = vec![vec![1.0, 0.0], vec![1.0, 1.0], vec![1.0, 2.0]];
        let y = vec![1.0, 2.0, 3.0];
        let mut blr = BayesianLinearRegression::new(1.0, 1.0);
        blr.fit(&x, &y);
        let ev = blr.log_evidence(&x, &y);
        assert!(ev.is_finite());
    }

    // ── GaussianProcess ──────────────────────────────────────────────────────

    #[test]
    fn test_gp_rbf_kernel_at_same_point() {
        let k = Kernel::Rbf {
            length_scale: 1.0,
            signal_variance: 1.0,
        };
        assert!((k.eval(2.0, 2.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_gp_matern32_at_same_point() {
        let k = Kernel::Matern32 {
            length_scale: 1.0,
            signal_variance: 2.0,
        };
        assert!((k.eval(1.0, 1.0) - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_gp_matern52_at_same_point() {
        let k = Kernel::Matern52 {
            length_scale: 1.0,
            signal_variance: 3.0,
        };
        assert!((k.eval(0.5, 0.5) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_gp_fit_and_predict() {
        let kernel = Kernel::Rbf {
            length_scale: 1.0,
            signal_variance: 1.0,
        };
        let mut gp = GaussianProcess::new(kernel, 0.01);
        let x_train: Vec<f64> = (0..5).map(|i| i as f64).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 2.0).collect();
        gp.fit(x_train, y_train);
        let pred = gp.predict_mean(2.0);
        assert!((pred - 4.0).abs() < 1.0);
    }

    #[test]
    fn test_gp_variance_positive() {
        let kernel = Kernel::Rbf {
            length_scale: 1.0,
            signal_variance: 1.0,
        };
        let mut gp = GaussianProcess::new(kernel, 0.01);
        gp.fit(vec![0.0, 1.0], vec![0.0, 1.0]);
        let var = gp.predict_variance(5.0); // Far from training data
        assert!(var > 0.0);
    }

    #[test]
    fn test_gp_log_marginal_likelihood_finite() {
        let kernel = Kernel::Rbf {
            length_scale: 1.0,
            signal_variance: 1.0,
        };
        let mut gp = GaussianProcess::new(kernel, 0.1);
        gp.fit(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 0.0]);
        assert!(gp.log_marginal_likelihood().is_finite());
    }

    // ── ModelSelection ───────────────────────────────────────────────────────

    #[test]
    fn test_aic_basic() {
        let aic = ModelSelection::aic(-100.0, 5);
        assert!((aic - (10.0 + 200.0)).abs() < 1e-12);
    }

    #[test]
    fn test_bic_basic() {
        let bic = ModelSelection::bic(-100.0, 3, 50);
        let expected = 3.0 * 50.0_f64.ln() + 200.0;
        assert!((bic - expected).abs() < 1e-10);
    }

    #[test]
    fn test_aicc_greater_than_aic() {
        let aic = ModelSelection::aic(-100.0, 4);
        let aicc = ModelSelection::aicc(-100.0, 4, 20);
        assert!(aicc >= aic);
    }

    #[test]
    fn test_log_bayes_factor_symmetric() {
        let lbf = ModelSelection::log_bayes_factor(10.0, 10.0);
        assert_eq!(lbf, 0.0);
    }

    #[test]
    fn test_jeffreys_scale_decisive() {
        let label = ModelSelection::jeffreys_scale(5.0); // e^5 >> 100
        assert_eq!(label, "Decisive");
    }

    #[test]
    fn test_k_fold_cv_mse_finite() {
        let x: Vec<Vec<f64>> = (0..10).map(|i| vec![1.0, i as f64]).collect();
        let y: Vec<f64> = (0..10).map(|i| i as f64 * 2.0 + 1.0).collect();
        let mse = ModelSelection::k_fold_cv_mse(&x, &y, 5, 1.0, 1.0);
        assert!(mse.is_finite());
    }

    #[test]
    fn test_loo_cv_log_predictive_finite() {
        let x: Vec<Vec<f64>> = (0..5).map(|i| vec![1.0, i as f64]).collect();
        let y: Vec<f64> = (0..5).map(|i| i as f64).collect();
        let lp = ModelSelection::loo_cv_log_predictive(&x, &y, 1.0, 1.0);
        assert!(lp.is_finite());
    }

    // ── Utility ──────────────────────────────────────────────────────────────

    #[test]
    fn test_log_gamma_half() {
        // Γ(1/2) = √π → ln Γ(1/2) = 0.5 ln π
        let lg = log_gamma(0.5);
        let expected = 0.5 * PI.ln();
        assert!((lg - expected).abs() < 1e-6);
    }

    #[test]
    fn test_log_gamma_one() {
        // Γ(1) = 1 → ln Γ(1) = 0
        assert!(log_gamma(1.0).abs() < 1e-6);
    }

    #[test]
    fn test_log_beta_symmetry() {
        assert!((log_beta(2.0, 3.0) - log_beta(3.0, 2.0)).abs() < 1e-12);
    }

    #[test]
    fn test_cholesky_identity() {
        let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let l = cholesky(&eye);
        assert!((l[0][0] - 1.0).abs() < 1e-12);
        assert!((l[1][1] - 1.0).abs() < 1e-12);
        assert!(l[1][0].abs() < 1e-12);
    }

    #[test]
    fn test_mat_inverse_identity() {
        let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let inv = mat_inverse(&eye);
        assert!((inv[0][0] - 1.0).abs() < 1e-10);
        assert!((inv[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_sample_gamma_positive() {
        let mut rng = BiRng::new(42);
        let s = sample_gamma(2.0, &mut rng);
        assert!(s > 0.0);
    }
}