oxiphysics_core/probabilistic_models.rs

#![allow(clippy::needless_range_loop)]
// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Probabilistic models including Bayesian networks, HMMs, Gaussian processes,
//! Dirichlet processes, variational inference, and expectation-maximization.
//!
//! These models provide foundational probabilistic machinery for physics-informed
//! machine learning, uncertainty quantification, and data-driven modeling.

#![allow(dead_code)]

use std::f64::consts::{PI, TAU};

// ---------------------------------------------------------------------------
// Helper math utilities
// ---------------------------------------------------------------------------
/// Returns the log of the normal density at `x` with mean `mean` and variance `var`.
fn log_normal_pdf(x: f64, mean: f64, var: f64) -> f64 {
    -0.5 * ((x - mean).powi(2) / var + var.ln() + TAU.ln())
}

/// Returns the normal density at `x`.
fn normal_pdf(x: f64, mean: f64, var: f64) -> f64 {
    (-(x - mean).powi(2) / (2.0 * var)).exp() / (TAU * var).sqrt()
}

/// Log-sum-exp trick for numerical stability.
fn log_sum_exp(values: &[f64]) -> f64 {
    let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max.is_infinite() {
        // Empty or all -inf inputs give -inf; a +inf term dominates and gives +inf.
        return max;
    }
    let sum: f64 = values.iter().map(|&v| (v - max).exp()).sum();
    max + sum.ln()
}

/// Softmax of a slice, returns normalized probabilities.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exp: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exp.iter().sum::<f64>().max(1e-300);
    exp.iter().map(|&e| e / sum).collect()
}

/// Computes the multivariate Gaussian log-density (diagonal covariance).
fn mvn_log_pdf_diag(x: &[f64], mean: &[f64], var: &[f64]) -> f64 {
    let d = x.len() as f64;
    let log_det: f64 = var.iter().map(|v| v.max(1e-300).ln()).sum();
    let maha: f64 = x
        .iter()
        .zip(mean.iter())
        .zip(var.iter())
        .map(|((&xi, &mi), &vi)| (xi - mi).powi(2) / vi.max(1e-300))
        .sum();
    -0.5 * (d * TAU.ln() + log_det + maha)
}

// ---------------------------------------------------------------------------
// BayesianNetwork — DAG with CPTs, exact inference by enumeration
// ---------------------------------------------------------------------------

/// A node in a Bayesian network.
#[derive(Debug, Clone)]
pub struct BnNode {
    /// Name of the variable.
    pub name: String,
    /// Number of discrete states.
    pub n_states: usize,
    /// Parent node indices.
    pub parents: Vec<usize>,
    /// Conditional probability table.
    ///
    /// Shape: `[parent_config_count, n_states]` flattened row-major.
    /// For a root node with no parents: length = `n_states`.
    pub cpt: Vec<f64>,
}

impl BnNode {
    /// Creates a new `BnNode`.
    pub fn new(
        name: impl Into<String>,
        n_states: usize,
        parents: Vec<usize>,
        cpt: Vec<f64>,
    ) -> Self {
        Self {
            name: name.into(),
            n_states,
            parents,
            cpt,
        }
    }

    /// Returns the conditional probability P(state | parent_config).
    pub fn cpt_value(&self, state: usize, parent_config: usize) -> f64 {
        let offset = parent_config * self.n_states;
        self.cpt[offset + state]
    }
}

/// A Bayesian Network: directed acyclic graph with conditional probability tables.
///
/// Inference is exact but exhaustive: marginals and conditionals are computed by
/// enumerating every joint assignment, so it is only practical for small networks.
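///
/// A minimal usage sketch (illustrative only; assumes this module is exported
/// as `oxiphysics_core::probabilistic_models`):
///
/// ```ignore
/// use oxiphysics_core::probabilistic_models::{BayesianNetwork, BnNode};
///
/// let mut bn = BayesianNetwork::new();
/// // P(Rain=0) = 0.3, P(Rain=1) = 0.7
/// let rain = bn.add_node(BnNode::new("Rain", 2, vec![], vec![0.3, 0.7]));
/// // P(Wet | Rain): one CPT row per parent configuration.
/// let wet = bn.add_node(BnNode::new("Wet", 2, vec![rain], vec![0.9, 0.1, 0.2, 0.8]));
/// assert!(bn.validate());
/// // P(Wet=1 | Rain=1) recovered by exhaustive inference.
/// let p = bn.conditional(wet, 1, &[(rain, 1)]);
/// assert!((p - 0.8).abs() < 1e-9);
/// ```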
#[derive(Debug, Clone)]
pub struct BayesianNetwork {
    /// Nodes in topological order.
    pub nodes: Vec<BnNode>,
}

impl BayesianNetwork {
    /// Creates a new empty `BayesianNetwork`.
    pub fn new() -> Self {
        Self { nodes: Vec::new() }
    }

    /// Adds a node and returns its index.
    pub fn add_node(&mut self, node: BnNode) -> usize {
        let idx = self.nodes.len();
        self.nodes.push(node);
        idx
    }

    /// Computes the joint probability of a complete assignment.
    ///
    /// `assignment[i]` is the state of node `i`.
    pub fn joint_probability(&self, assignment: &[usize]) -> f64 {
        let mut prob = 1.0f64;
        for (i, node) in self.nodes.iter().enumerate() {
            let parent_config = self.parent_config_index(i, assignment);
            prob *= node.cpt_value(assignment[i], parent_config);
        }
        prob
    }

    /// Computes the parent configuration index for node `i` given full assignment.
    fn parent_config_index(&self, node_idx: usize, assignment: &[usize]) -> usize {
        let node = &self.nodes[node_idx];
        let mut config = 0usize;
        for &p in &node.parents {
            let p_states = self.nodes[p].n_states;
            config = config * p_states + assignment[p];
        }
        config
    }

    /// Computes the marginal probability of node `target` being in `state`
    /// by summing over all other assignments (exact, exponential complexity).
    pub fn marginal(&self, target: usize, target_state: usize) -> f64 {
        let n = self.nodes.len();
        // Enumerate all assignments via mixed-radix counter
        let n_states: Vec<usize> = self.nodes.iter().map(|nd| nd.n_states).collect();
        let total: usize = n_states.iter().product();
        let mut prob = 0.0f64;
        let mut assignment = vec![0usize; n];
        for _ in 0..total {
            if assignment[target] == target_state {
                prob += self.joint_probability(&assignment);
            }
            // Increment mixed-radix counter
            let mut carry = 1;
            for i in (0..n).rev() {
                let next = assignment[i] + carry;
                assignment[i] = next % n_states[i];
                carry = next / n_states[i];
                if carry == 0 {
                    break;
                }
            }
        }
        prob
    }

    /// Returns marginal probabilities for all states of node `target`.
    pub fn marginal_all(&self, target: usize) -> Vec<f64> {
        let n_states = self.nodes[target].n_states;
        (0..n_states).map(|s| self.marginal(target, s)).collect()
    }

    /// Computes conditional probability P(target=state | evidence).
    ///
    /// `evidence` is a list of `(node_idx, state)` observations.
    pub fn conditional(
        &self,
        target: usize,
        target_state: usize,
        evidence: &[(usize, usize)],
    ) -> f64 {
        let n = self.nodes.len();
        let n_states: Vec<usize> = self.nodes.iter().map(|nd| nd.n_states).collect();
        let total: usize = n_states.iter().product();
        let mut num = 0.0f64;
        let mut denom = 0.0f64;
        let mut assignment = vec![0usize; n];
        for _ in 0..total {
            // Check evidence consistency
            let consistent = evidence.iter().all(|&(ni, s)| assignment[ni] == s);
            if consistent {
                let p = self.joint_probability(&assignment);
                denom += p;
                if assignment[target] == target_state {
                    num += p;
                }
            }
            let mut carry = 1;
            for i in (0..n).rev() {
                let next = assignment[i] + carry;
                assignment[i] = next % n_states[i];
                carry = next / n_states[i];
                if carry == 0 {
                    break;
                }
            }
        }
        if denom < 1e-300 { 0.0 } else { num / denom }
    }

    /// Checks that all CPTs are valid (non-negative, rows sum to 1).
    pub fn validate(&self) -> bool {
        for node in &self.nodes {
            if node.cpt.iter().any(|&p| p < 0.0) {
                return false;
            }
            let n_configs = if node.parents.is_empty() {
                1
            } else {
                node.cpt.len() / node.n_states
            };
            for cfg in 0..n_configs {
                let sum: f64 = (0..node.n_states).map(|s| node.cpt_value(s, cfg)).sum();
                if (sum - 1.0).abs() > 1e-6 {
                    return false;
                }
            }
        }
        true
    }
}

impl Default for BayesianNetwork {
    fn default() -> Self {
        Self::new()
    }
}

// ---------------------------------------------------------------------------
// HiddenMarkovModel
// ---------------------------------------------------------------------------

/// A Hidden Markov Model with discrete hidden states and Gaussian emissions.
///
/// Supports:
/// - Forward algorithm (likelihood computation)
/// - Viterbi algorithm (MAP state sequence)
/// - Baum-Welch EM (parameter learning)
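///
/// A minimal usage sketch (illustrative only; assumes this module is exported
/// as `oxiphysics_core::probabilistic_models`):
///
/// ```ignore
/// use oxiphysics_core::probabilistic_models::HiddenMarkovModel;
///
/// let hmm = HiddenMarkovModel::new(
///     2,
///     vec![0.6, 0.4],                       // initial distribution π
///     vec![vec![0.7, 0.3], vec![0.4, 0.6]], // transition matrix A
///     vec![0.0, 3.0],                       // emission means
///     vec![1.0, 1.0],                       // emission variances
/// );
/// let obs = [0.1, 0.2, 2.9, 3.1];
/// let log_likelihood = hmm.forward(&obs);
/// let map_states = hmm.viterbi(&obs);
/// assert!(log_likelihood.is_finite());
/// assert_eq!(map_states.len(), obs.len());
/// ```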
#[derive(Debug, Clone)]
pub struct HiddenMarkovModel {
    /// Number of hidden states.
    pub n_states: usize,
    /// Initial state distribution π.
    pub initial: Vec<f64>,
    /// Transition matrix A\[i\]\[j\] = P(s_t=j | s_{t-1}=i).
    pub transition: Vec<Vec<f64>>,
    /// Emission mean for each state.
    pub emission_mean: Vec<f64>,
    /// Emission variance for each state.
    pub emission_var: Vec<f64>,
}

impl HiddenMarkovModel {
    /// Creates a new `HiddenMarkovModel` with given parameters.
    pub fn new(
        n_states: usize,
        initial: Vec<f64>,
        transition: Vec<Vec<f64>>,
        emission_mean: Vec<f64>,
        emission_var: Vec<f64>,
    ) -> Self {
        Self {
            n_states,
            initial,
            transition,
            emission_mean,
            emission_var,
        }
    }

    /// Creates a uniform HMM with `n_states` states.
    pub fn uniform(n_states: usize) -> Self {
        let p = 1.0 / n_states as f64;
        let initial = vec![p; n_states];
        let transition = vec![vec![p; n_states]; n_states];
        let emission_mean: Vec<f64> = (0..n_states).map(|i| i as f64).collect();
        let emission_var = vec![1.0; n_states];
        Self::new(n_states, initial, transition, emission_mean, emission_var)
    }

    /// Computes log emission probability of `obs` in state `s`.
    fn log_emit(&self, s: usize, obs: f64) -> f64 {
        log_normal_pdf(obs, self.emission_mean[s], self.emission_var[s])
    }

    /// Forward algorithm: returns log-likelihood of observation sequence.
    pub fn forward(&self, observations: &[f64]) -> f64 {
        let t_len = observations.len();
        if t_len == 0 {
            return 0.0;
        }
        let k = self.n_states;
        let mut alpha = vec![0.0f64; k];
        // Initialization
        for s in 0..k {
            alpha[s] = self.initial[s].max(1e-300).ln() + self.log_emit(s, observations[0]);
        }
        // Recursion
        for t in 1..t_len {
            let mut alpha_new = vec![f64::NEG_INFINITY; k];
            for j in 0..k {
                let log_emit_j = self.log_emit(j, observations[t]);
                let terms: Vec<f64> = (0..k)
                    .map(|i| alpha[i] + self.transition[i][j].max(1e-300).ln())
                    .collect();
                alpha_new[j] = log_sum_exp(&terms) + log_emit_j;
            }
            alpha = alpha_new;
        }
        log_sum_exp(&alpha)
    }

    /// Viterbi algorithm: returns the most likely state sequence.
    pub fn viterbi(&self, observations: &[f64]) -> Vec<usize> {
        let t_len = observations.len();
        if t_len == 0 {
            return Vec::new();
        }
        let k = self.n_states;
        let mut delta = vec![vec![0.0f64; k]; t_len];
        let mut psi = vec![vec![0usize; k]; t_len];

        // Initialization
        for s in 0..k {
            delta[0][s] = self.initial[s].max(1e-300).ln() + self.log_emit(s, observations[0]);
        }

        // Recursion
        for t in 1..t_len {
            for j in 0..k {
                let (best_s, best_val) = (0..k)
                    .map(|i| {
                        let v = delta[t - 1][i] + self.transition[i][j].max(1e-300).ln();
                        (i, v)
                    })
                    .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
                    .expect("states iterator is non-empty");
                delta[t][j] = best_val + self.log_emit(j, observations[t]);
                psi[t][j] = best_s;
            }
        }

        // Backtrack
        let mut path = vec![0usize; t_len];
        path[t_len - 1] = (0..k)
            .max_by(|&a, &b| {
                delta[t_len - 1][a]
                    .partial_cmp(&delta[t_len - 1][b])
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .expect("k states is non-empty");
        for t in (0..t_len - 1).rev() {
            path[t] = psi[t + 1][path[t + 1]];
        }
        path
    }

    /// Baum-Welch EM algorithm for parameter estimation.
    ///
    /// Returns the log-likelihood at each iteration; an empty observation
    /// sequence returns an empty history.
    pub fn baum_welch(&mut self, observations: &[f64], n_iter: usize) -> Vec<f64> {
        let t_len = observations.len();
        let k = self.n_states;
        let mut ll_history = Vec::new();
        if t_len == 0 {
            return ll_history;
        }

        for _iter in 0..n_iter {
            // E-step: Forward-Backward
            // Forward pass (log scale)
            let mut log_alpha = vec![vec![0.0f64; k]; t_len];
            for s in 0..k {
                log_alpha[0][s] =
                    self.initial[s].max(1e-300).ln() + self.log_emit(s, observations[0]);
            }
            for t in 1..t_len {
                for j in 0..k {
                    let terms: Vec<f64> = (0..k)
                        .map(|i| log_alpha[t - 1][i] + self.transition[i][j].max(1e-300).ln())
                        .collect();
                    log_alpha[t][j] = log_sum_exp(&terms) + self.log_emit(j, observations[t]);
                }
            }
            let log_ll = log_sum_exp(&log_alpha[t_len - 1]);
            ll_history.push(log_ll);

            // Backward pass
            let mut log_beta = vec![vec![0.0f64; k]; t_len];
            // log_beta[T-1][s] = log(1) = 0
            for t in (0..t_len - 1).rev() {
                for i in 0..k {
                    let terms: Vec<f64> = (0..k)
                        .map(|j| {
                            self.transition[i][j].max(1e-300).ln()
                                + self.log_emit(j, observations[t + 1])
                                + log_beta[t + 1][j]
                        })
                        .collect();
                    log_beta[t][i] = log_sum_exp(&terms);
                }
            }

            // Compute gamma and xi
            // gamma[t][s] = P(S_t=s | obs)
            let mut gamma = vec![vec![0.0f64; k]; t_len];
            for t in 0..t_len {
                let log_probs: Vec<f64> =
                    (0..k).map(|s| log_alpha[t][s] + log_beta[t][s]).collect();
                let norm = log_sum_exp(&log_probs);
                for s in 0..k {
                    gamma[t][s] = (log_probs[s] - norm).exp();
                }
            }

            // xi[t][i][j] = P(S_t=i, S_{t+1}=j | obs)
            let mut xi = vec![vec![vec![0.0f64; k]; k]; t_len.saturating_sub(1)];
            for t in 0..t_len.saturating_sub(1) {
                let mut xi_t = vec![vec![0.0f64; k]; k];
                let mut log_xi_t = vec![vec![0.0f64; k]; k];
                for i in 0..k {
                    for j in 0..k {
                        log_xi_t[i][j] = log_alpha[t][i]
                            + self.transition[i][j].max(1e-300).ln()
                            + self.log_emit(j, observations[t + 1])
                            + log_beta[t + 1][j];
                    }
                }
                let flat: Vec<f64> = log_xi_t.iter().flat_map(|r| r.iter().copied()).collect();
                let norm = log_sum_exp(&flat);
                for i in 0..k {
                    for j in 0..k {
                        xi_t[i][j] = (log_xi_t[i][j] - norm).exp();
                    }
                }
                xi[t] = xi_t;
            }

            // M-step: update parameters
            // Update initial
            for s in 0..k {
                self.initial[s] = gamma[0][s].max(1e-300);
            }
            let init_sum: f64 = self.initial.iter().sum::<f64>().max(1e-300);
            for s in 0..k {
                self.initial[s] /= init_sum;
            }

            // Update transition
            for i in 0..k {
                let denom: f64 = (0..t_len.saturating_sub(1))
                    .map(|t| gamma[t][i])
                    .sum::<f64>()
                    .max(1e-300);
                for j in 0..k {
                    let num: f64 = (0..t_len.saturating_sub(1)).map(|t| xi[t][i][j]).sum();
                    self.transition[i][j] = (num / denom).max(1e-300);
                }
                // Renormalize row
                let row_sum: f64 = self.transition[i].iter().sum::<f64>().max(1e-300);
                for j in 0..k {
                    self.transition[i][j] /= row_sum;
                }
            }

            // Update emission parameters
            for s in 0..k {
                let denom: f64 = (0..t_len).map(|t| gamma[t][s]).sum::<f64>().max(1e-300);
                let new_mean: f64 = (0..t_len)
                    .map(|t| gamma[t][s] * observations[t])
                    .sum::<f64>()
                    / denom;
                let new_var: f64 = ((0..t_len)
                    .map(|t| gamma[t][s] * (observations[t] - new_mean).powi(2))
                    .sum::<f64>()
                    / denom)
                    .max(1e-6);
                self.emission_mean[s] = new_mean;
                self.emission_var[s] = new_var;
            }
        }
        ll_history
    }
}

// ---------------------------------------------------------------------------
// GaussianProcess
// ---------------------------------------------------------------------------

/// Available kernel functions for Gaussian Processes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KernelType {
    /// Squared exponential (RBF) kernel: `k(x,x') = σ² exp(-||x-x'||²/(2ℓ²))`.
    Rbf,
    /// Matérn 3/2 kernel: `k(x,x') = σ²(1+√3r/ℓ)exp(-√3r/ℓ)`.
    Matern32,
    /// Matérn 5/2 kernel.
    Matern52,
    /// Periodic kernel: `k(x,x') = σ² exp(-2sin²(π|x-x'|/p)/ℓ²)`.
    Periodic,
}

/// Gaussian Process for regression with various kernel functions.
///
/// Maintains training data and supports posterior mean/variance prediction.
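///
/// A minimal usage sketch (illustrative only; assumes this module is exported
/// as `oxiphysics_core::probabilistic_models`):
///
/// ```ignore
/// use oxiphysics_core::probabilistic_models::{GaussianProcess, KernelType};
///
/// let mut gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-4);
/// gp.fit(vec![0.0, 1.0, 2.0], vec![0.0, 0.8, 0.9]);
/// let (mean, var) = gp.predict(1.5);
/// assert!(var > 0.0);
/// // Far from the training data the posterior variance reverts toward the
/// // prior signal variance, so it should exceed the in-sample variance.
/// let (_, var_far) = gp.predict(10.0);
/// assert!(var_far > var);
/// ```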
#[derive(Debug, Clone)]
pub struct GaussianProcess {
    /// Kernel type.
    pub kernel: KernelType,
    /// Signal variance σ².
    pub signal_var: f64,
    /// Length scale ℓ.
    pub length_scale: f64,
    /// Period (for periodic kernel).
    pub period: f64,
    /// Noise variance.
    pub noise_var: f64,
    /// Training inputs (1D for simplicity).
    pub x_train: Vec<f64>,
    /// Training targets.
    pub y_train: Vec<f64>,
    /// Lower-triangular Cholesky factor of K + noise·I (row-major flattened).
    chol: Vec<f64>,
    /// alpha = L^{-T} L^{-1} y.
    alpha: Vec<f64>,
}

impl GaussianProcess {
    /// Creates a new `GaussianProcess` with given hyperparameters.
    pub fn new(kernel: KernelType, signal_var: f64, length_scale: f64, noise_var: f64) -> Self {
        Self {
            kernel,
            signal_var,
            length_scale,
            period: 1.0,
            noise_var,
            x_train: Vec::new(),
            y_train: Vec::new(),
            chol: Vec::new(),
            alpha: Vec::new(),
        }
    }

    /// Sets the period for periodic kernel.
    pub fn with_period(mut self, period: f64) -> Self {
        self.period = period;
        self
    }

    /// Evaluates the kernel between two scalar inputs.
    pub fn k(&self, x1: f64, x2: f64) -> f64 {
        let r = (x1 - x2).abs();
        match self.kernel {
            KernelType::Rbf => {
                self.signal_var * (-r * r / (2.0 * self.length_scale * self.length_scale)).exp()
            }
            KernelType::Matern32 => {
                let sq3r = 3.0f64.sqrt() * r / self.length_scale;
                self.signal_var * (1.0 + sq3r) * (-sq3r).exp()
            }
            KernelType::Matern52 => {
                let sq5r = 5.0f64.sqrt() * r / self.length_scale;
                self.signal_var * (1.0 + sq5r + sq5r * sq5r / 3.0) * (-sq5r).exp()
            }
            KernelType::Periodic => {
                let arg = PI * r / self.period;
                self.signal_var
                    * (-2.0 * arg.sin().powi(2) / (self.length_scale * self.length_scale)).exp()
            }
        }
    }

    /// Fits the GP to training data by computing the Cholesky decomposition.
    pub fn fit(&mut self, x_train: Vec<f64>, y_train: Vec<f64>) {
        let n = x_train.len();
        self.x_train = x_train;
        self.y_train = y_train.clone();

        // Build K + noise*I
        let mut k_mat = vec![0.0f64; n * n];
        for i in 0..n {
            for j in 0..n {
                k_mat[i * n + j] = self.k(self.x_train[i], self.x_train[j]);
            }
            k_mat[i * n + i] += self.noise_var;
        }

        // Cholesky (lower triangular, in-place)
        let mut l = k_mat;
        for i in 0..n {
            for j in 0..=i {
                let mut s = l[i * n + j];
                for k_idx in 0..j {
                    s -= l[i * n + k_idx] * l[j * n + k_idx];
                }
                if i == j {
                    l[i * n + j] = s.max(1e-12).sqrt();
                } else {
                    l[i * n + j] = s / l[j * n + j].max(1e-12);
                }
            }
            // zero upper triangle
            for j in i + 1..n {
                l[i * n + j] = 0.0;
            }
        }
        self.chol = l.clone();

        // Solve L * w = y  →  alpha = L^T \ w
        let mut w = y_train;
        // Forward substitution: L w = y
        for i in 0..n {
            let mut s = w[i];
            for j in 0..i {
                s -= l[i * n + j] * w[j];
            }
            w[i] = s / l[i * n + i].max(1e-12);
        }
        // Back substitution: L^T alpha = w
        let mut alpha = w;
        for i in (0..n).rev() {
            let mut s = alpha[i];
            for j in i + 1..n {
                s -= l[j * n + i] * alpha[j];
            }
            alpha[i] = s / l[i * n + i].max(1e-12);
        }
        self.alpha = alpha;
    }

    /// Predicts posterior mean and variance at test input `x_star`.
    pub fn predict(&self, x_star: f64) -> (f64, f64) {
        let n = self.x_train.len();
        if n == 0 {
            return (0.0, self.signal_var + self.noise_var);
        }

        // k_star = K(x_star, X_train)
        let k_star: Vec<f64> = self.x_train.iter().map(|&xi| self.k(x_star, xi)).collect();

        // mean = k_star^T alpha
        let mean: f64 = k_star
            .iter()
            .zip(self.alpha.iter())
            .map(|(a, b)| a * b)
            .sum();

        // variance = k(x*,x*) - k_star^T (K+sI)^{-1} k_star
        // = k(x*,x*) - v^T v  where v = L^{-1} k_star
        let mut v = k_star.clone();
        for i in 0..n {
            let mut s = v[i];
            for j in 0..i {
                s -= self.chol[i * n + j] * v[j];
            }
            v[i] = s / self.chol[i * n + i].max(1e-12);
        }
        let var = (self.k(x_star, x_star) - v.iter().map(|vi| vi * vi).sum::<f64>()).max(1e-12);

        (mean, var)
    }

    /// Computes the log marginal likelihood.
    pub fn log_marginal_likelihood(&self) -> f64 {
        let n = self.x_train.len();
        if n == 0 {
            return 0.0;
        }
        // log p(y|X,θ) = -0.5 y^T α - Σ log L_ii - n/2 log(2π)
        let data_fit: f64 = self
            .y_train
            .iter()
            .zip(self.alpha.iter())
            .map(|(y, a)| y * a)
            .sum::<f64>();
        let log_det: f64 = (0..n)
            .map(|i| self.chol[i * n + i].max(1e-300).ln())
            .sum::<f64>();
        -0.5 * data_fit - log_det - 0.5 * n as f64 * TAU.ln()
    }
}

// ---------------------------------------------------------------------------
// DirichletProcess
// ---------------------------------------------------------------------------

/// Dirichlet Process mixture model using the Chinese Restaurant Process.
///
/// New data points are assigned to existing clusters with probability
/// proportional to cluster size, or start a new cluster with probability α/(n+α).
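///
/// A minimal usage sketch (illustrative only; assumes this module is exported
/// as `oxiphysics_core::probabilistic_models`):
///
/// ```ignore
/// use oxiphysics_core::probabilistic_models::DirichletProcess;
///
/// let mut dp = DirichletProcess::new(1.0);
/// for &x in &[0.0, 0.1, 5.0, 5.1, 0.05] {
///     dp.crp_assign(x);
/// }
/// // Assignment here is deterministic (argmax of the CRP probabilities),
/// // so repeated runs produce the same clustering.
/// assert!(dp.n_clusters() >= 1);
/// assert_eq!(dp.assignments.len(), 5);
/// ```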
#[derive(Debug, Clone)]
pub struct DirichletProcess {
    /// Concentration parameter α.
    pub alpha: f64,
    /// Cluster assignments for each observed data point.
    pub assignments: Vec<usize>,
    /// Number of points in each cluster.
    pub cluster_counts: Vec<usize>,
    /// Cluster means (updated incrementally).
    pub cluster_means: Vec<f64>,
    /// Cluster sum of squares (for variance estimation).
    pub cluster_ss: Vec<f64>,
    /// Total number of data points assigned.
    pub n_assigned: usize,
}

impl DirichletProcess {
    /// Creates a new `DirichletProcess` with concentration `alpha`.
    pub fn new(alpha: f64) -> Self {
        Self {
            alpha,
            assignments: Vec::new(),
            cluster_counts: Vec::new(),
            cluster_means: Vec::new(),
            cluster_ss: Vec::new(),
            n_assigned: 0,
        }
    }

    /// Returns the number of clusters.
    pub fn n_clusters(&self) -> usize {
        self.cluster_counts.len()
    }

    /// Assigns a new data point via Chinese Restaurant Process probabilities.
    ///
    /// Returns the cluster index assigned (deterministic: picks highest prob).
    pub fn crp_assign(&mut self, x: f64) -> usize {
        let n = self.n_assigned as f64;
        let k = self.cluster_counts.len();

        // Compute unnormalized probabilities
        let mut probs: Vec<f64> = self
            .cluster_counts
            .iter()
            .map(|&cnt| cnt as f64 / (n + self.alpha))
            .collect();
        probs.push(self.alpha / (n + self.alpha)); // new cluster

        // Pick cluster with highest probability (deterministic for reproducibility)
        let chosen = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(k);

        if chosen == k {
            // New cluster
            self.cluster_counts.push(1);
            self.cluster_means.push(x);
            self.cluster_ss.push(0.0);
        } else {
            // Existing cluster — update incrementally
            let cnt = self.cluster_counts[chosen] as f64;
            let old_mean = self.cluster_means[chosen];
            self.cluster_counts[chosen] += 1;
            let new_mean = old_mean + (x - old_mean) / (cnt + 1.0);
            self.cluster_ss[chosen] += (x - old_mean) * (x - new_mean);
            self.cluster_means[chosen] = new_mean;
        }
        self.assignments.push(chosen);
        self.n_assigned += 1;
        chosen
    }

    /// Stick-breaking construction: returns the first `k` mixture weights.
    ///
    /// Returns weights `(w_1, ..., w_k)` normalized so that `Σ w_i = 1`.
    /// Uses a deterministic approximation (the mean of Beta(1, α), slightly
    /// perturbed per component) in place of random beta draws.
    pub fn stick_breaking_weights(&self, k: usize) -> Vec<f64> {
        let mut weights = Vec::with_capacity(k);
        let mut remaining = 1.0f64;
        for i in 0..k {
            // Deterministic approximation: mean of Beta(1, alpha)
            let mean_beta = 1.0 / (1.0 + self.alpha);
            // Slight variation per component
            let v = mean_beta * (1.0 - 0.1 * i as f64 / (k as f64 + 1.0));
            let v = v.clamp(1e-6, 1.0 - 1e-6);
            let w = remaining * v;
            weights.push(w);
            remaining *= 1.0 - v;
        }
        // Add the remaining stick to the last component so weights sum to 1
        if let Some(last) = weights.last_mut() {
            *last += remaining;
        }
        // Normalize to sum exactly to 1
        let total: f64 = weights.iter().sum::<f64>().max(1e-300);
        weights.iter_mut().for_each(|w| *w /= total);
        weights
    }

    /// Returns cluster variance estimates (unbiased).
    pub fn cluster_variances(&self) -> Vec<f64> {
        self.cluster_counts
            .iter()
            .zip(self.cluster_ss.iter())
            .map(
                |(&cnt, &ss)| {
                    if cnt > 1 { ss / (cnt - 1) as f64 } else { 1.0 }
                },
            )
            .collect()
    }

    /// Returns the expected number of clusters for n observations (approximation).
    ///
    /// `E[K_n] ≈ α ln(1 + n/α)`
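    ///
    /// For example, with `α = 1` and `n = 100`, `E[K_n] ≈ ln(101) ≈ 4.6`.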
    pub fn expected_clusters(alpha: f64, n: usize) -> f64 {
        alpha * (1.0 + n as f64 / alpha).ln()
    }
}

// ---------------------------------------------------------------------------
// VariationalInference
// ---------------------------------------------------------------------------

/// Mean-field variational inference for a Gaussian mixture model.
///
/// Maximizes the Evidence Lower BOund (ELBO) by coordinate ascent
/// over the variational posterior factors.
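///
/// A minimal usage sketch (illustrative only; assumes this module is exported
/// as `oxiphysics_core::probabilistic_models`):
///
/// ```ignore
/// use oxiphysics_core::probabilistic_models::VariationalInference;
///
/// let mut vi = VariationalInference::new(2, 0.0, 4.0, 1.0);
/// let obs = vec![-2.1, -1.9, -2.0, 2.0, 1.9, 2.1];
/// let elbo = vi.fit(&obs, 20);
/// assert!(elbo.is_finite());
/// // The predictive density mixes the fitted components.
/// assert!(vi.predictive_density(0.0) > 0.0);
/// ```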
#[derive(Debug, Clone)]
pub struct VariationalInference {
    /// Number of mixture components.
    pub n_components: usize,
    /// Variational component weights (log-scale unnormalized).
    pub log_weights: Vec<f64>,
    /// Variational mean for each component.
    pub var_mean: Vec<f64>,
    /// Variational variance for each component.
    pub var_var: Vec<f64>,
    /// Prior mean.
    pub prior_mean: f64,
    /// Prior variance.
    pub prior_var: f64,
    /// Observation noise variance.
    pub obs_var: f64,
    /// ELBO history.
    pub elbo_history: Vec<f64>,
}

impl VariationalInference {
    /// Creates a new `VariationalInference` instance.
    pub fn new(n_components: usize, prior_mean: f64, prior_var: f64, obs_var: f64) -> Self {
        let log_weights = vec![-(n_components as f64).ln(); n_components];
        let var_mean: Vec<f64> = (0..n_components).map(|i| i as f64).collect();
        let var_var = vec![1.0f64; n_components];
        Self {
            n_components,
            log_weights,
            var_mean,
            var_var,
            prior_mean,
            prior_var,
            obs_var,
            elbo_history: Vec::new(),
        }
    }

    /// Computes the ELBO for current variational parameters given observations.
    pub fn elbo(&self, observations: &[f64]) -> f64 {
        let weights = softmax(&self.log_weights);
        let mut elbo = 0.0f64;
        // Expected log likelihood
        for &x in observations {
            let ll_terms: Vec<f64> = (0..self.n_components)
                .map(|k| {
                    weights[k].max(1e-300).ln()
                        + log_normal_pdf(x, self.var_mean[k], self.obs_var + self.var_var[k])
                })
                .collect();
            elbo += log_sum_exp(&ll_terms);
        }
        // KL divergence: Σ_k w_k KL(q(z_k) || p(z_k))
        for k in 0..self.n_components {
            // KL(N(μ_q, σ_q²) || N(μ_p, σ_p²))
            //   = ½ (σ_q²/σ_p² + (μ_q - μ_p)²/σ_p² - 1 + ln(σ_p²/σ_q²))
            let kl = 0.5
                * (self.var_var[k] / self.prior_var.max(1e-12)
                    + (self.var_mean[k] - self.prior_mean).powi(2) / self.prior_var.max(1e-12)
                    - 1.0
                    + (self.prior_var / self.var_var[k].max(1e-12)).ln());
            elbo -= weights[k] * kl;
        }
        elbo
    }

    /// Performs one CAVI update step.
    ///
    /// Returns the new ELBO.
    pub fn cavi_step(&mut self, observations: &[f64]) -> f64 {
        // Update variational parameters for each component
        for k in 0..self.n_components {
            let weights = softmax(&self.log_weights);
            // Responsibility of component k for each observation
            let r_k: Vec<f64> = observations
                .iter()
                .map(|&x| weights[k] * normal_pdf(x, self.var_mean[k], self.obs_var))
                .collect();
            let r_sum: f64 = r_k.iter().sum::<f64>().max(1e-300);

            // Update mean: posterior precision = prior_prec + r_sum/obs_var
            let prior_prec = 1.0 / self.prior_var.max(1e-12);
            let lik_prec = r_sum / self.obs_var.max(1e-12);
            let post_prec = prior_prec + lik_prec;
            let post_var = 1.0 / post_prec.max(1e-12);
            let data_sum: f64 = r_k
                .iter()
                .zip(observations.iter())
                .map(|(r, x)| r * x)
                .sum();
            let post_mean =
                post_var * (prior_prec * self.prior_mean + data_sum / self.obs_var.max(1e-12));

            self.var_mean[k] = post_mean;
            self.var_var[k] = post_var;

            // Update weight (log)
            self.log_weights[k] = r_sum.max(1e-300).ln();
        }
        // Renormalize log_weights
        let lse = log_sum_exp(&self.log_weights);
        for k in 0..self.n_components {
            self.log_weights[k] -= lse;
        }
        let elbo_val = self.elbo(observations);
        self.elbo_history.push(elbo_val);
        elbo_val
    }

    /// Runs CAVI for `n_iter` iterations.
    pub fn fit(&mut self, observations: &[f64], n_iter: usize) -> f64 {
        for _ in 0..n_iter {
            self.cavi_step(observations);
        }
        *self.elbo_history.last().unwrap_or(&f64::NEG_INFINITY)
    }

    /// Reparameterization trick: samples from `q(z) = N(mu, sigma^2)` using `eps ~ N(0,1)`.
    pub fn reparameterize(&self, k: usize, eps: f64) -> f64 {
        self.var_mean[k] + self.var_var[k].sqrt() * eps
    }

    /// Returns the predictive density at `x` under the variational posterior.
    pub fn predictive_density(&self, x: f64) -> f64 {
        let weights = softmax(&self.log_weights);
        (0..self.n_components)
            .map(|k| weights[k] * normal_pdf(x, self.var_mean[k], self.obs_var + self.var_var[k]))
            .sum()
    }
}

// ---------------------------------------------------------------------------
// ExpectationMaximization — Gaussian Mixture Model
// ---------------------------------------------------------------------------

/// A Gaussian mixture model component.
#[derive(Debug, Clone)]
pub struct GmmComponent {
    /// Mixture weight.
    pub weight: f64,
    /// Mean.
    pub mean: f64,
    /// Variance.
    pub var: f64,
}

impl GmmComponent {
    /// Creates a new `GmmComponent`.
    pub fn new(weight: f64, mean: f64, var: f64) -> Self {
        Self { weight, mean, var }
    }
}

/// Expectation-Maximization for Gaussian Mixture Models.
///
/// Supports quantile-based initialization, the BIC criterion for model
/// selection, and EM iteration with a convergence tolerance.
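///
/// A minimal usage sketch (illustrative only; assumes this module is exported
/// as `oxiphysics_core::probabilistic_models`):
///
/// ```ignore
/// use oxiphysics_core::probabilistic_models::ExpectationMaximization;
///
/// let mut em = ExpectationMaximization::new(2);
/// let data = vec![0.0, 0.2, 0.1, 5.0, 5.2, 5.1];
/// em.kmeans_init(&data);
/// let ll = em.fit(&data, 100);
/// assert!(ll.is_finite());
/// // Every point maps to one of the two fitted components.
/// assert!(em.predict(0.1) < 2);
/// ```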
#[derive(Debug, Clone)]
pub struct ExpectationMaximization {
    /// Number of components.
    pub n_components: usize,
    /// Mixture components.
    pub components: Vec<GmmComponent>,
    /// Log-likelihood history.
    pub ll_history: Vec<f64>,
    /// Convergence tolerance.
    pub tol: f64,
}

impl ExpectationMaximization {
    /// Creates a new `ExpectationMaximization` with evenly spaced initial means;
    /// call `kmeans_init` to seed the means from data instead.
    pub fn new(n_components: usize) -> Self {
        let components = (0..n_components)
            .map(|i| GmmComponent::new(1.0 / n_components as f64, i as f64, 1.0))
            .collect();
        Self {
            n_components,
            components,
            ll_history: Vec::new(),
            tol: 1e-6,
        }
    }

    /// Sets convergence tolerance.
    pub fn with_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Initializes component means at evenly spaced quantiles of the sorted data
    /// (a cheap, deterministic stand-in for k-means seeding).
    pub fn kmeans_init(&mut self, data: &[f64]) {
        if data.is_empty() {
            return;
        }
        let mut sorted = data.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let k = self.n_components;
        for i in 0..k {
            let idx = (sorted.len() * (2 * i + 1)) / (2 * k);
            self.components[i].mean = sorted[idx.min(sorted.len() - 1)];
            self.components[i].var = 1.0;
            self.components[i].weight = 1.0 / k as f64;
        }
    }

    /// Computes log-likelihood of data under current model.
    pub fn log_likelihood(&self, data: &[f64]) -> f64 {
        data.iter()
            .map(|&x| {
                let terms: Vec<f64> = self
                    .components
                    .iter()
                    .map(|c| c.weight.max(1e-300).ln() + log_normal_pdf(x, c.mean, c.var))
                    .collect();
                log_sum_exp(&terms)
            })
            .sum()
    }

    /// Computes the Bayesian Information Criterion.
    ///
    /// `BIC = k ln(n) - 2 ln(L̂)`
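    ///
    /// Lower BIC indicates a better trade-off between fit and model size; an
    /// illustrative use is fitting several values of `n_components` and keeping
    /// the model with the smallest BIC.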
    pub fn bic(&self, data: &[f64]) -> f64 {
        let n = data.len() as f64;
        let ll = self.log_likelihood(data);
        // Parameters: n_components means + variances + weights - 1
        let n_params = (3 * self.n_components - 1) as f64;
        n_params * n.ln() - 2.0 * ll
    }

    /// E-step: computes responsibilities for each data point.
    fn e_step(&self, data: &[f64]) -> Vec<Vec<f64>> {
        data.iter()
            .map(|&x| {
                let log_probs: Vec<f64> = self
                    .components
                    .iter()
                    .map(|c| c.weight.max(1e-300).ln() + log_normal_pdf(x, c.mean, c.var))
                    .collect();
                softmax(&log_probs)
            })
            .collect()
    }

    /// M-step: updates parameters given responsibilities.
    fn m_step(&mut self, data: &[f64], responsibilities: &[Vec<f64>]) {
        let n = data.len() as f64;
        for k in 0..self.n_components {
            let r_sum: f64 = responsibilities
                .iter()
                .map(|r| r[k])
                .sum::<f64>()
                .max(1e-300);
            let new_weight = r_sum / n;
            let new_mean: f64 = responsibilities
                .iter()
                .zip(data.iter())
                .map(|(r, &x)| r[k] * x)
                .sum::<f64>()
                / r_sum;
            let new_var: f64 = (responsibilities
                .iter()
                .zip(data.iter())
                .map(|(r, &x)| r[k] * (x - new_mean).powi(2))
                .sum::<f64>()
                / r_sum)
                .max(1e-6);
            self.components[k].weight = new_weight;
            self.components[k].mean = new_mean;
            self.components[k].var = new_var;
        }
    }

    /// Runs the EM algorithm.
    ///
    /// Returns the final log-likelihood.
    pub fn fit(&mut self, data: &[f64], max_iter: usize) -> f64 {
        self.ll_history.clear();
        let mut prev_ll = f64::NEG_INFINITY;
        for _ in 0..max_iter {
            let resp = self.e_step(data);
            self.m_step(data, &resp);
            let ll = self.log_likelihood(data);
            self.ll_history.push(ll);
            if (ll - prev_ll).abs() < self.tol {
                break;
            }
            prev_ll = ll;
        }
        *self.ll_history.last().unwrap_or(&f64::NEG_INFINITY)
    }

    /// Predicts cluster assignment (most likely component) for a data point.
    pub fn predict(&self, x: f64) -> usize {
        self.components
            .iter()
            .enumerate()
            .map(|(k, c)| {
                let ll = c.weight.max(1e-300).ln() + log_normal_pdf(x, c.mean, c.var);
                (k, ll)
            })
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(k, _)| k)
            .unwrap_or(0)
    }

    /// Returns component weights normalized to sum to 1.
    pub fn normalized_weights(&self) -> Vec<f64> {
        let sum: f64 = self
            .components
            .iter()
            .map(|c| c.weight)
            .sum::<f64>()
            .max(1e-300);
        self.components.iter().map(|c| c.weight / sum).collect()
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // --- BayesianNetwork ---

    fn make_simple_bn() -> BayesianNetwork {
        let mut bn = BayesianNetwork::new();
        // Node 0: Rain (2 states) — prior
        bn.add_node(BnNode::new("Rain", 2, vec![], vec![0.3, 0.7]));
        // Node 1: Sprinkler (2 states) — parent: Rain
        bn.add_node(BnNode::new(
            "Sprinkler",
            2,
            vec![0],
            vec![0.1, 0.9, 0.5, 0.5], // [rain=0: s=0,s=1; rain=1: s=0,s=1]
        ));
        bn
    }

    #[test]
    fn test_bn_validate() {
        let bn = make_simple_bn();
        assert!(bn.validate());
    }

    #[test]
    fn test_bn_joint_probability_sums_to_one() {
        let mut bn = BayesianNetwork::new();
        bn.add_node(BnNode::new("A", 2, vec![], vec![0.4, 0.6]));
        bn.add_node(BnNode::new("B", 2, vec![0], vec![0.7, 0.3, 0.2, 0.8]));
        // Sum over all assignments
        let total: f64 = (0..4)
            .map(|i| {
                let a = i / 2;
                let b_val = i % 2;
                bn.joint_probability(&[a, b_val])
            })
            .sum();
        assert!((total - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_bn_marginal_sums_to_one() {
        let bn = make_simple_bn();
        let m0 = bn.marginal(0, 0);
        let m1 = bn.marginal(0, 1);
        assert!((m0 + m1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_bn_marginal_root_equals_prior() {
        let bn = make_simple_bn();
        let m0 = bn.marginal(0, 0);
        assert!((m0 - 0.3).abs() < 1e-8);
    }

    #[test]
    fn test_bn_conditional_valid() {
        let bn = make_simple_bn();
        let p = bn.conditional(1, 0, &[(0, 0)]);
        assert!((0.0..=1.0).contains(&p));
    }

    #[test]
    fn test_bn_marginal_all_sums_to_one() {
        let bn = make_simple_bn();
        let m = bn.marginal_all(0);
        let s: f64 = m.iter().sum();
        assert!((s - 1.0).abs() < 1e-8);
    }

    #[test]
    fn test_bn_cpt_value() {
        let node = BnNode::new("X", 2, vec![], vec![0.4, 0.6]);
        assert!((node.cpt_value(0, 0) - 0.4).abs() < 1e-10);
        assert!((node.cpt_value(1, 0) - 0.6).abs() < 1e-10);
    }

    #[test]
    fn test_bn_single_node() {
        let mut bn = BayesianNetwork::new();
        bn.add_node(BnNode::new("X", 3, vec![], vec![0.2, 0.5, 0.3]));
        let p = bn.joint_probability(&[1]);
        assert!((p - 0.5).abs() < 1e-10);
    }

    // --- HiddenMarkovModel ---

    fn make_hmm() -> HiddenMarkovModel {
        HiddenMarkovModel::new(
            2,
            vec![0.6, 0.4],
            vec![vec![0.7, 0.3], vec![0.4, 0.6]],
            vec![0.0, 3.0],
            vec![1.0, 1.0],
        )
    }

    #[test]
    fn test_hmm_forward_returns_finite() {
        let hmm = make_hmm();
        let obs = vec![0.1, 0.2, 0.3, 2.8, 3.1];
        let ll = hmm.forward(&obs);
        assert!(ll.is_finite());
    }

    #[test]
    fn test_hmm_forward_empty() {
        let hmm = make_hmm();
        assert_eq!(hmm.forward(&[]), 0.0);
    }

    #[test]
    fn test_hmm_viterbi_length() {
        let hmm = make_hmm();
        let obs = vec![0.1, 0.2, 0.3, 2.8, 3.1];
        let path = hmm.viterbi(&obs);
        assert_eq!(path.len(), obs.len());
    }

    #[test]
    fn test_hmm_viterbi_valid_states() {
        let hmm = make_hmm();
        let obs = vec![0.1, 2.9, 0.2, 3.0];
        let path = hmm.viterbi(&obs);
        assert!(path.iter().all(|&s| s < 2));
    }

    #[test]
    fn test_hmm_viterbi_empty() {
        let hmm = make_hmm();
        assert_eq!(hmm.viterbi(&[]).len(), 0);
    }

    #[test]
    fn test_hmm_baum_welch_ll_increases() {
        let mut hmm = make_hmm();
        let obs: Vec<f64> = (0..20)
            .map(|i| if i % 3 == 0 { 0.1 } else { 2.9 })
            .collect();
        let ll_hist = hmm.baum_welch(&obs, 5);
        // Log-likelihood should be non-decreasing
        for i in 1..ll_hist.len() {
            assert!(ll_hist[i] >= ll_hist[i - 1] - 1e-4);
        }
    }

    #[test]
    fn test_hmm_uniform_creation() {
        let hmm = HiddenMarkovModel::uniform(3);
        assert_eq!(hmm.n_states, 3);
        let row_sum: f64 = hmm.transition[0].iter().sum();
        assert!((row_sum - 1.0).abs() < 1e-10);
    }

    // --- GaussianProcess ---

    #[test]
    fn test_gp_rbf_kernel_diagonal() {
        let gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-3);
        assert!((gp.k(0.0, 0.0) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_gp_rbf_kernel_decays() {
        let gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-3);
        assert!(gp.k(0.0, 10.0) < gp.k(0.0, 1.0));
    }

    #[test]
    fn test_gp_matern32_diagonal() {
        let gp = GaussianProcess::new(KernelType::Matern32, 2.0, 1.0, 1e-3);
        assert!((gp.k(0.0, 0.0) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_gp_matern52_diagonal() {
        let gp = GaussianProcess::new(KernelType::Matern52, 1.5, 1.0, 1e-3);
        assert!((gp.k(0.0, 0.0) - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_gp_periodic_diagonal() {
        let gp = GaussianProcess::new(KernelType::Periodic, 1.0, 1.0, 1e-3);
        assert!((gp.k(0.0, 0.0) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_gp_fit_predict_mean() {
        let mut gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-4);
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y = vec![0.0, 1.0, 4.0, 9.0];
        gp.fit(x, y);
        let (mean, _var) = gp.predict(1.0);
        assert!((mean - 1.0).abs() < 0.5); // near training point
    }

    #[test]
    fn test_gp_predict_variance_positive() {
        let mut gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-4);
        gp.fit(vec![0.0, 1.0], vec![0.0, 1.0]);
        let (_mean, var) = gp.predict(5.0); // far from training data
        assert!(var > 0.0);
    }

    #[test]
    fn test_gp_predict_empty() {
        let gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-3);
        let (mean, var) = gp.predict(0.5);
        assert_eq!(mean, 0.0);
        assert!(var > 0.0);
    }

    #[test]
    fn test_gp_log_marginal_likelihood() {
        let mut gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 0.1);
        gp.fit(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 0.0]);
        let lml = gp.log_marginal_likelihood();
        assert!(lml.is_finite());
    }

    // --- DirichletProcess ---

    #[test]
    fn test_dp_initial_state() {
        let dp = DirichletProcess::new(1.0);
        assert_eq!(dp.n_clusters(), 0);
        assert_eq!(dp.n_assigned, 0);
    }

    #[test]
    fn test_dp_crp_first_point() {
        let mut dp = DirichletProcess::new(1.0);
        let c = dp.crp_assign(0.0);
        assert_eq!(c, 0);
        assert_eq!(dp.n_clusters(), 1);
    }

    #[test]
    fn test_dp_crp_multiple_points() {
        let mut dp = DirichletProcess::new(0.1); // low alpha = prefer existing clusters
        for i in 0..10 {
            dp.crp_assign(i as f64 * 0.01);
        }
        // With low alpha, should form few clusters
        assert!(dp.n_clusters() <= 5);
    }

    #[test]
    fn test_dp_stick_breaking_sums_to_one() {
        let dp = DirichletProcess::new(2.0);
        let w = dp.stick_breaking_weights(10);
        let sum: f64 = w.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dp_stick_breaking_positive() {
        let dp = DirichletProcess::new(1.0);
        let w = dp.stick_breaking_weights(5);
        assert!(w.iter().all(|&wi| wi > 0.0));
    }

    #[test]
    fn test_dp_expected_clusters() {
        let e = DirichletProcess::expected_clusters(1.0, 100);
        assert!(e > 3.0 && e < 10.0);
    }

    #[test]
    fn test_dp_cluster_variances() {
        let mut dp = DirichletProcess::new(0.5);
        for i in 0..5 {
            dp.crp_assign(i as f64);
        }
        let vars = dp.cluster_variances();
        assert!(vars.iter().all(|&v| v >= 0.0));
    }

    // --- VariationalInference ---

    #[test]
    fn test_vi_elbo_finite() {
        let vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
        let obs = vec![0.0, 1.0, -1.0, 2.0];
        let elbo = vi.elbo(&obs);
        assert!(elbo.is_finite());
    }

    #[test]
    fn test_vi_cavi_step_updates_params() {
        let mut vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
        let old_var = vi.var_var[0];
        let obs = vec![3.0, 3.1, 3.2, -3.0, -3.1, -3.2];
        vi.cavi_step(&obs);
        // Observing data tightens the variational posterior, so the variance
        // must shrink (the mean can stay at 0 for this symmetric data set).
        assert!(vi.var_var[0] < old_var);
    }

    #[test]
    fn test_vi_fit_returns_finite() {
        let mut vi = VariationalInference::new(2, 0.0, 2.0, 1.0);
        let obs: Vec<f64> = (0..20)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let elbo = vi.fit(&obs, 10);
        assert!(elbo.is_finite());
    }

    #[test]
    fn test_vi_reparameterize() {
        let vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
        let sample = vi.reparameterize(0, 1.0);
        // sample = mean[0] + sqrt(var[0]) * 1.0
        let expected = vi.var_mean[0] + vi.var_var[0].sqrt();
        assert!((sample - expected).abs() < 1e-10);
    }

    #[test]
    fn test_vi_predictive_density_positive() {
        let vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
        let p = vi.predictive_density(0.0);
        assert!(p > 0.0);
    }

    #[test]
    fn test_vi_elbo_history_grows() {
        let mut vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
        let obs = vec![1.0, -1.0, 2.0];
        vi.fit(&obs, 5);
        assert_eq!(vi.elbo_history.len(), 5);
    }

    // --- ExpectationMaximization ---

    #[test]
    fn test_em_initial_weights_sum_to_one() {
        let em = ExpectationMaximization::new(3);
        let sum: f64 = em.normalized_weights().iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_em_kmeans_init() {
        let mut em = ExpectationMaximization::new(2);
        let data = vec![0.0, 0.1, 0.2, 5.0, 5.1, 5.2];
        em.kmeans_init(&data);
        // Means should be near 0 and 5
        let means: Vec<f64> = em.components.iter().map(|c| c.mean).collect();
        assert!(means.iter().any(|&m| m < 1.0));
        assert!(means.iter().any(|&m| m > 4.0));
    }

    #[test]
    fn test_em_log_likelihood_finite() {
        let em = ExpectationMaximization::new(2);
        let data = vec![0.0, 1.0, 2.0];
        assert!(em.log_likelihood(&data).is_finite());
    }

    #[test]
    fn test_em_fit_ll_increases() {
        let mut em = ExpectationMaximization::new(2);
        let data: Vec<f64> = (0..30)
            .map(|i| {
                if i < 15 {
                    i as f64 * 0.1
                } else {
                    5.0 + i as f64 * 0.1
                }
            })
            .collect();
        em.kmeans_init(&data);
        em.fit(&data, 20);
        let ll = &em.ll_history;
        for i in 1..ll.len() {
            assert!(ll[i] >= ll[i - 1] - 1e-4);
        }
    }

    #[test]
    fn test_em_predict_valid_component() {
        let em = ExpectationMaximization::new(3);
        let pred = em.predict(0.5);
        assert!(pred < 3);
    }

    #[test]
    fn test_em_bic_finite() {
        let em = ExpectationMaximization::new(2);
        let data = vec![0.0, 1.0, 5.0, 6.0];
        let bic = em.bic(&data);
        assert!(bic.is_finite());
    }

    #[test]
    fn test_em_fit_separates_clusters() {
        let mut em = ExpectationMaximization::new(2);
        let mut data: Vec<f64> = (0..20).map(|i| i as f64 * 0.05).collect(); // 0..1
        let data2: Vec<f64> = (0..20).map(|i| 10.0 + i as f64 * 0.05).collect(); // 10..11
        data.extend(data2);
        em.kmeans_init(&data);
        em.fit(&data, 50);
        // One component should be near 0.5, other near 10.5
        let means: Vec<f64> = em.components.iter().map(|c| c.mean).collect();
        assert!(means.iter().any(|&m| m < 3.0));
        assert!(means.iter().any(|&m| m > 7.0));
    }

    #[test]
    fn test_em_n_components() {
        let em = ExpectationMaximization::new(4);
        assert_eq!(em.n_components, 4);
        assert_eq!(em.components.len(), 4);
    }

    // --- Helper functions ---

    #[test]
    fn test_log_sum_exp_empty() {
        assert_eq!(log_sum_exp(&[]), f64::NEG_INFINITY);
    }

    #[test]
    fn test_log_sum_exp_single() {
        assert!((log_sum_exp(&[2.0]) - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let s = softmax(&[1.0, 2.0, 3.0]);
        assert!((s.iter().sum::<f64>() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_peak() {
        let p = normal_pdf(0.0, 0.0, 1.0);
        assert!((p - 1.0 / TAU.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_mvn_log_pdf_diag() {
        let x = vec![0.0, 0.0];
        let mean = vec![0.0, 0.0];
        let var = vec![1.0, 1.0];
        let lp = mvn_log_pdf_diag(&x, &mean, &var);
        assert!(lp.is_finite());
    }
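
    // Consistency check between the two helper densities: exp(log_normal_pdf)
    // should reproduce normal_pdf up to floating-point rounding.
    #[test]
    fn test_log_normal_pdf_matches_normal_pdf() {
        let (x, mean, var) = (0.7, 0.2, 2.5);
        let lp = log_normal_pdf(x, mean, var).exp();
        let p = normal_pdf(x, mean, var);
        assert!((lp - p).abs() < 1e-12);
    }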
}