1#![allow(clippy::needless_range_loop)]
2#![allow(dead_code)]
12
13use std::f64::consts::{PI, TAU};
14
/// Log-density of a univariate normal N(mean, var) evaluated at `x`.
fn log_normal_pdf(x: f64, mean: f64, var: f64) -> f64 {
    let diff = x - mean;
    // log N(x; m, v) = -0.5 * ((x - m)^2 / v + ln v + ln 2π)
    -0.5 * (diff * diff / var + var.ln() + TAU.ln())
}
23
/// Density of a univariate normal N(mean, var) evaluated at `x`.
fn normal_pdf(x: f64, mean: f64, var: f64) -> f64 {
    let diff = x - mean;
    let norm = (TAU * var).sqrt();
    (-diff * diff / (2.0 * var)).exp() / norm
}
28
/// Numerically stable log(Σᵢ exp(vᵢ)) using the max-shift trick.
///
/// Returns `f64::NEG_INFINITY` for an empty slice (log of an empty sum) or
/// when every value is -∞, and `f64::INFINITY` when any value is +∞ (the
/// sum diverges).
fn log_sum_exp(values: &[f64]) -> f64 {
    let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Bug fix: the previous check `max.is_infinite()` also matched +∞ and
    // returned -∞ for it. A non-finite max already determines the answer,
    // so return it directly (-∞ stays -∞, +∞ stays +∞).
    if !max.is_finite() {
        return max;
    }
    let sum: f64 = values.iter().map(|&v| (v - max).exp()).sum();
    max + sum.ln()
}
38
/// Softmax of `logits`, shifted by the max for numerical stability.
/// The denominator is clamped away from zero so the output stays finite.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut out: Vec<f64> = logits.iter().map(|&l| (l - max).exp()).collect();
    let total = out.iter().sum::<f64>().max(1e-300);
    for p in out.iter_mut() {
        *p /= total;
    }
    out
}
46
/// Log-density of a multivariate normal with diagonal covariance `var`,
/// evaluated at `x`. Variances are clamped at 1e-300 to avoid division by
/// zero and ln(0).
fn mvn_log_pdf_diag(x: &[f64], mean: &[f64], var: &[f64]) -> f64 {
    let dim = x.len() as f64;
    // Log-determinant of the diagonal covariance (sum over all entries).
    let log_det: f64 = var.iter().map(|v| v.max(1e-300).ln()).sum();
    // Mahalanobis distance over the zipped (x, mean, var) triples.
    let mut maha = 0.0f64;
    for ((&xi, &mi), &vi) in x.iter().zip(mean.iter()).zip(var.iter()) {
        let d = xi - mi;
        maha += d * d / vi.max(1e-300);
    }
    -0.5 * (dim * TAU.ln() + log_det + maha)
}
59
#[derive(Debug, Clone)]
/// A single discrete random variable in a Bayesian network.
pub struct BnNode {
    /// Human-readable label for the node.
    pub name: String,
    /// Number of discrete states the variable can take.
    pub n_states: usize,
    /// Indices of the parent nodes within the owning network.
    pub parents: Vec<usize>,
    /// Conditional probability table, laid out row-major:
    /// `cpt[parent_config * n_states + state]`.
    pub cpt: Vec<f64>,
}

impl BnNode {
    /// Assembles a node from its label, state count, parent list and CPT.
    pub fn new(
        name: impl Into<String>,
        n_states: usize,
        parents: Vec<usize>,
        cpt: Vec<f64>,
    ) -> Self {
        Self {
            name: name.into(),
            n_states,
            parents,
            cpt,
        }
    }

    /// Looks up P(state | parent configuration) in the flat CPT.
    pub fn cpt_value(&self, state: usize, parent_config: usize) -> f64 {
        self.cpt[parent_config * self.n_states + state]
    }
}
102
#[derive(Debug, Clone)]
/// A discrete Bayesian network over `BnNode`s. Inference (`marginal`,
/// `conditional`) is by brute-force enumeration of all joint assignments,
/// so it is only practical for small networks.
pub struct BayesianNetwork {
    // Nodes in insertion order; `BnNode::parents` entries index this vec.
    pub nodes: Vec<BnNode>,
}
112
113impl BayesianNetwork {
114 pub fn new() -> Self {
116 Self { nodes: Vec::new() }
117 }
118
119 pub fn add_node(&mut self, node: BnNode) -> usize {
121 let idx = self.nodes.len();
122 self.nodes.push(node);
123 idx
124 }
125
126 pub fn joint_probability(&self, assignment: &[usize]) -> f64 {
130 let mut prob = 1.0f64;
131 for (i, node) in self.nodes.iter().enumerate() {
132 let parent_config = self.parent_config_index(i, assignment);
133 prob *= node.cpt_value(assignment[i], parent_config);
134 }
135 prob
136 }
137
138 fn parent_config_index(&self, node_idx: usize, assignment: &[usize]) -> usize {
140 let node = &self.nodes[node_idx];
141 let mut config = 0usize;
142 for &p in &node.parents {
143 let p_states = self.nodes[p].n_states;
144 config = config * p_states + assignment[p];
145 }
146 config
147 }
148
149 pub fn marginal(&self, target: usize, target_state: usize) -> f64 {
152 let n = self.nodes.len();
153 let n_states: Vec<usize> = self.nodes.iter().map(|nd| nd.n_states).collect();
155 let total: usize = n_states.iter().product();
156 let mut prob = 0.0f64;
157 let mut assignment = vec![0usize; n];
158 for _ in 0..total {
159 if assignment[target] == target_state {
160 prob += self.joint_probability(&assignment);
161 }
162 let mut carry = 1;
164 for i in (0..n).rev() {
165 let next = assignment[i] + carry;
166 assignment[i] = next % n_states[i];
167 carry = next / n_states[i];
168 if carry == 0 {
169 break;
170 }
171 }
172 }
173 prob
174 }
175
176 pub fn marginal_all(&self, target: usize) -> Vec<f64> {
178 let n_states = self.nodes[target].n_states;
179 (0..n_states).map(|s| self.marginal(target, s)).collect()
180 }
181
182 pub fn conditional(
186 &self,
187 target: usize,
188 target_state: usize,
189 evidence: &[(usize, usize)],
190 ) -> f64 {
191 let n = self.nodes.len();
192 let n_states: Vec<usize> = self.nodes.iter().map(|nd| nd.n_states).collect();
193 let total: usize = n_states.iter().product();
194 let mut num = 0.0f64;
195 let mut denom = 0.0f64;
196 let mut assignment = vec![0usize; n];
197 for _ in 0..total {
198 let consistent = evidence.iter().all(|&(ni, s)| assignment[ni] == s);
200 if consistent {
201 let p = self.joint_probability(&assignment);
202 denom += p;
203 if assignment[target] == target_state {
204 num += p;
205 }
206 }
207 let mut carry = 1;
208 for i in (0..n).rev() {
209 let next = assignment[i] + carry;
210 assignment[i] = next % n_states[i];
211 carry = next / n_states[i];
212 if carry == 0 {
213 break;
214 }
215 }
216 }
217 if denom < 1e-300 { 0.0 } else { num / denom }
218 }
219
220 pub fn validate(&self) -> bool {
222 for node in &self.nodes {
223 let n_configs = if node.parents.is_empty() {
224 1
225 } else {
226 node.cpt.len() / node.n_states
227 };
228 for cfg in 0..n_configs {
229 let sum: f64 = (0..node.n_states).map(|s| node.cpt_value(s, cfg)).sum();
230 if (sum - 1.0).abs() > 1e-6 {
231 return false;
232 }
233 }
234 }
235 true
236 }
237}
238
239impl Default for BayesianNetwork {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
#[derive(Debug, Clone)]
/// Hidden Markov model with K discrete hidden states and one univariate
/// Gaussian emission distribution per state.
pub struct HiddenMarkovModel {
    // Number of hidden states K.
    pub n_states: usize,
    // Initial state distribution, length K.
    pub initial: Vec<f64>,
    // Transition matrix, K x K; row i is the distribution out of state i.
    pub transition: Vec<Vec<f64>>,
    // Emission mean per state, length K.
    pub emission_mean: Vec<f64>,
    // Emission variance per state, length K.
    pub emission_var: Vec<f64>,
}
268
269impl HiddenMarkovModel {
270 pub fn new(
272 n_states: usize,
273 initial: Vec<f64>,
274 transition: Vec<Vec<f64>>,
275 emission_mean: Vec<f64>,
276 emission_var: Vec<f64>,
277 ) -> Self {
278 Self {
279 n_states,
280 initial,
281 transition,
282 emission_mean,
283 emission_var,
284 }
285 }
286
287 pub fn uniform(n_states: usize) -> Self {
289 let p = 1.0 / n_states as f64;
290 let initial = vec![p; n_states];
291 let transition = vec![vec![p; n_states]; n_states];
292 let emission_mean: Vec<f64> = (0..n_states).map(|i| i as f64).collect();
293 let emission_var = vec![1.0; n_states];
294 Self::new(n_states, initial, transition, emission_mean, emission_var)
295 }
296
297 fn log_emit(&self, s: usize, obs: f64) -> f64 {
299 log_normal_pdf(obs, self.emission_mean[s], self.emission_var[s])
300 }
301
302 pub fn forward(&self, observations: &[f64]) -> f64 {
304 let t_len = observations.len();
305 if t_len == 0 {
306 return 0.0;
307 }
308 let k = self.n_states;
309 let mut alpha = vec![0.0f64; k];
310 for s in 0..k {
312 alpha[s] = self.initial[s].ln() + self.log_emit(s, observations[0]);
313 }
314 for t in 1..t_len {
316 let mut alpha_new = vec![f64::NEG_INFINITY; k];
317 for j in 0..k {
318 let log_emit_j = self.log_emit(j, observations[t]);
319 let terms: Vec<f64> = (0..k)
320 .map(|i| alpha[i] + self.transition[i][j].max(1e-300).ln())
321 .collect();
322 alpha_new[j] = log_sum_exp(&terms) + log_emit_j;
323 }
324 alpha = alpha_new;
325 }
326 log_sum_exp(&alpha)
327 }
328
329 pub fn viterbi(&self, observations: &[f64]) -> Vec<usize> {
331 let t_len = observations.len();
332 if t_len == 0 {
333 return Vec::new();
334 }
335 let k = self.n_states;
336 let mut delta = vec![vec![0.0f64; k]; t_len];
337 let mut psi = vec![vec![0usize; k]; t_len];
338
339 for s in 0..k {
341 delta[0][s] = self.initial[s].max(1e-300).ln() + self.log_emit(s, observations[0]);
342 }
343
344 for t in 1..t_len {
346 for j in 0..k {
347 let (best_s, best_val) = (0..k)
348 .map(|i| {
349 let v = delta[t - 1][i] + self.transition[i][j].max(1e-300).ln();
350 (i, v)
351 })
352 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
353 .expect("states iterator is non-empty");
354 delta[t][j] = best_val + self.log_emit(j, observations[t]);
355 psi[t][j] = best_s;
356 }
357 }
358
359 let mut path = vec![0usize; t_len];
361 path[t_len - 1] = (0..k)
362 .max_by(|&a, &b| {
363 delta[t_len - 1][a]
364 .partial_cmp(&delta[t_len - 1][b])
365 .unwrap_or(std::cmp::Ordering::Equal)
366 })
367 .expect("k states is non-empty");
368 for t in (0..t_len - 1).rev() {
369 path[t] = psi[t + 1][path[t + 1]];
370 }
371 path
372 }
373
374 pub fn baum_welch(&mut self, observations: &[f64], n_iter: usize) -> Vec<f64> {
378 let t_len = observations.len();
379 let k = self.n_states;
380 let mut ll_history = Vec::new();
381
382 for _iter in 0..n_iter {
383 let mut log_alpha = vec![vec![0.0f64; k]; t_len];
386 for s in 0..k {
387 log_alpha[0][s] =
388 self.initial[s].max(1e-300).ln() + self.log_emit(s, observations[0]);
389 }
390 for t in 1..t_len {
391 for j in 0..k {
392 let terms: Vec<f64> = (0..k)
393 .map(|i| log_alpha[t - 1][i] + self.transition[i][j].max(1e-300).ln())
394 .collect();
395 log_alpha[t][j] = log_sum_exp(&terms) + self.log_emit(j, observations[t]);
396 }
397 }
398 let log_ll = log_sum_exp(&log_alpha[t_len - 1]);
399 ll_history.push(log_ll);
400
401 let mut log_beta = vec![vec![0.0f64; k]; t_len];
403 for t in (0..t_len - 1).rev() {
405 for i in 0..k {
406 let terms: Vec<f64> = (0..k)
407 .map(|j| {
408 self.transition[i][j].max(1e-300).ln()
409 + self.log_emit(j, observations[t + 1])
410 + log_beta[t + 1][j]
411 })
412 .collect();
413 log_beta[t][i] = log_sum_exp(&terms);
414 }
415 }
416
417 let mut gamma = vec![vec![0.0f64; k]; t_len];
420 for t in 0..t_len {
421 let log_probs: Vec<f64> =
422 (0..k).map(|s| log_alpha[t][s] + log_beta[t][s]).collect();
423 let norm = log_sum_exp(&log_probs);
424 for s in 0..k {
425 gamma[t][s] = (log_probs[s] - norm).exp();
426 }
427 }
428
429 let mut xi = vec![vec![vec![0.0f64; k]; k]; t_len.saturating_sub(1)];
431 for t in 0..t_len.saturating_sub(1) {
432 let mut xi_t = vec![vec![0.0f64; k]; k];
433 let mut log_xi_t = vec![vec![0.0f64; k]; k];
434 for i in 0..k {
435 for j in 0..k {
436 log_xi_t[i][j] = log_alpha[t][i]
437 + self.transition[i][j].max(1e-300).ln()
438 + self.log_emit(j, observations[t + 1])
439 + log_beta[t + 1][j];
440 }
441 }
442 let flat: Vec<f64> = log_xi_t.iter().flat_map(|r| r.iter().copied()).collect();
443 let norm = log_sum_exp(&flat);
444 for i in 0..k {
445 for j in 0..k {
446 xi_t[i][j] = (log_xi_t[i][j] - norm).exp();
447 }
448 }
449 xi[t] = xi_t;
450 }
451
452 for s in 0..k {
455 self.initial[s] = gamma[0][s].max(1e-300);
456 }
457 let init_sum: f64 = self.initial.iter().sum::<f64>().max(1e-300);
458 for s in 0..k {
459 self.initial[s] /= init_sum;
460 }
461
462 for i in 0..k {
464 let denom: f64 = (0..t_len.saturating_sub(1))
465 .map(|t| gamma[t][i])
466 .sum::<f64>()
467 .max(1e-300);
468 for j in 0..k {
469 let num: f64 = (0..t_len.saturating_sub(1)).map(|t| xi[t][i][j]).sum();
470 self.transition[i][j] = (num / denom).max(1e-300);
471 }
472 let row_sum: f64 = self.transition[i].iter().sum::<f64>().max(1e-300);
474 for j in 0..k {
475 self.transition[i][j] /= row_sum;
476 }
477 }
478
479 for s in 0..k {
481 let denom: f64 = (0..t_len).map(|t| gamma[t][s]).sum::<f64>().max(1e-300);
482 let new_mean: f64 = (0..t_len)
483 .map(|t| gamma[t][s] * observations[t])
484 .sum::<f64>()
485 / denom;
486 let new_var: f64 = ((0..t_len)
487 .map(|t| gamma[t][s] * (observations[t] - new_mean).powi(2))
488 .sum::<f64>()
489 / denom)
490 .max(1e-6);
491 self.emission_mean[s] = new_mean;
492 self.emission_var[s] = new_var;
493 }
494 }
495 ll_history
496 }
497}
498
#[derive(Debug, Clone, Copy, PartialEq)]
/// Covariance function families supported by [`GaussianProcess`].
pub enum KernelType {
    /// Squared-exponential (RBF) kernel.
    Rbf,
    /// Matern kernel with nu = 3/2.
    Matern32,
    /// Matern kernel with nu = 5/2.
    Matern52,
    /// Exp-sine-squared periodic kernel.
    Periodic,
}

#[derive(Debug, Clone)]
/// One-dimensional Gaussian-process regressor with an exact Cholesky-based
/// fit (O(n^3) in the number of training points).
pub struct GaussianProcess {
    /// Which covariance function to use.
    pub kernel: KernelType,
    /// Kernel output scale: the value of k(x, x).
    pub signal_var: f64,
    /// Kernel length scale.
    pub length_scale: f64,
    /// Period of the periodic kernel (ignored by the other kernels).
    pub period: f64,
    /// Observation-noise variance added to the covariance diagonal.
    pub noise_var: f64,
    /// Training inputs from the last `fit`.
    pub x_train: Vec<f64>,
    /// Training targets from the last `fit`.
    pub y_train: Vec<f64>,
    // Lower-triangular Cholesky factor L of K + noise*I, row-major n x n.
    chol: Vec<f64>,
    // Precomputed alpha = (K + noise*I)^-1 y for O(n) mean prediction.
    alpha: Vec<f64>,
}

impl GaussianProcess {
    /// Unfitted GP with the given kernel and hyperparameters.
    /// `period` defaults to 1.0; override it with [`Self::with_period`].
    pub fn new(kernel: KernelType, signal_var: f64, length_scale: f64, noise_var: f64) -> Self {
        Self {
            kernel,
            signal_var,
            length_scale,
            period: 1.0,
            noise_var,
            x_train: Vec::new(),
            y_train: Vec::new(),
            chol: Vec::new(),
            alpha: Vec::new(),
        }
    }

    /// Builder-style override of the periodic kernel's period.
    pub fn with_period(mut self, period: f64) -> Self {
        self.period = period;
        self
    }

    /// Kernel (covariance) value k(x1, x2).
    pub fn k(&self, x1: f64, x2: f64) -> f64 {
        let r = (x1 - x2).abs();
        match self.kernel {
            KernelType::Rbf => {
                self.signal_var * (-r * r / (2.0 * self.length_scale * self.length_scale)).exp()
            }
            KernelType::Matern32 => {
                let sq3r = 3.0f64.sqrt() * r / self.length_scale;
                self.signal_var * (1.0 + sq3r) * (-sq3r).exp()
            }
            KernelType::Matern52 => {
                let sq5r = 5.0f64.sqrt() * r / self.length_scale;
                self.signal_var * (1.0 + sq5r + sq5r * sq5r / 3.0) * (-sq5r).exp()
            }
            KernelType::Periodic => {
                let arg = PI * r / self.period;
                self.signal_var
                    * (-2.0 * arg.sin().powi(2) / (self.length_scale * self.length_scale)).exp()
            }
        }
    }

    /// Fits the GP: builds K + noise*I, computes its lower Cholesky factor
    /// (pivots clamped at 1e-12 for numerical safety) and precomputes
    /// alpha = (K + noise*I)^-1 y by forward/back substitution.
    pub fn fit(&mut self, x_train: Vec<f64>, y_train: Vec<f64>) {
        let n = x_train.len();
        self.x_train = x_train;
        self.y_train = y_train.clone();

        // Kernel matrix with noise on the diagonal. Factored in place below,
        // so no extra copy is needed (the old code cloned it).
        let mut l = vec![0.0f64; n * n];
        for i in 0..n {
            for j in 0..n {
                l[i * n + j] = self.k(self.x_train[i], self.x_train[j]);
            }
            l[i * n + i] += self.noise_var;
        }

        // In-place lower-triangular Cholesky factorization.
        for i in 0..n {
            for j in 0..=i {
                let mut s = l[i * n + j];
                for k_idx in 0..j {
                    s -= l[i * n + k_idx] * l[j * n + k_idx];
                }
                if i == j {
                    l[i * n + j] = s.max(1e-12).sqrt();
                } else {
                    l[i * n + j] = s / l[j * n + j].max(1e-12);
                }
            }
            // Zero the strict upper triangle so `chol` is a clean L.
            for j in i + 1..n {
                l[i * n + j] = 0.0;
            }
        }

        // Solve L w = y (forward substitution) ...
        let mut w = y_train;
        for i in 0..n {
            let mut s = w[i];
            for j in 0..i {
                s -= l[i * n + j] * w[j];
            }
            w[i] = s / l[i * n + i].max(1e-12);
        }
        // ... then L^T alpha = w (back substitution).
        let mut alpha = w;
        for i in (0..n).rev() {
            let mut s = alpha[i];
            for j in i + 1..n {
                s -= l[j * n + i] * alpha[j];
            }
            alpha[i] = s / l[i * n + i].max(1e-12);
        }
        // Store the factor once, at the end (the old code also cloned it
        // mid-way and then kept using the local copy).
        self.chol = l;
        self.alpha = alpha;
    }

    /// Predictive (mean, variance) at `x_star`. An unfitted GP returns the
    /// prior: mean 0 and variance signal_var + noise_var.
    pub fn predict(&self, x_star: f64) -> (f64, f64) {
        let n = self.x_train.len();
        if n == 0 {
            return (0.0, self.signal_var + self.noise_var);
        }

        let k_star: Vec<f64> = self.x_train.iter().map(|&xi| self.k(x_star, xi)).collect();

        // Predictive mean: k_*^T alpha.
        let mean: f64 = k_star
            .iter()
            .zip(self.alpha.iter())
            .map(|(a, b)| a * b)
            .sum();

        // Predictive variance: k(x*, x*) - ||L^-1 k_*||^2, via forward
        // substitution (k_star is consumed; no clone needed).
        let mut v = k_star;
        for i in 0..n {
            let mut s = v[i];
            for j in 0..i {
                s -= self.chol[i * n + j] * v[j];
            }
            v[i] = s / self.chol[i * n + i].max(1e-12);
        }
        let var = (self.k(x_star, x_star) - v.iter().map(|vi| vi * vi).sum::<f64>()).max(1e-12);

        (mean, var)
    }

    /// Log marginal likelihood of the training data:
    /// -0.5 y^T alpha - sum_i ln L_ii - (n/2) ln(2 pi).
    /// Returns 0.0 for an unfitted GP.
    pub fn log_marginal_likelihood(&self) -> f64 {
        let n = self.x_train.len();
        if n == 0 {
            return 0.0;
        }
        let data_fit: f64 = self
            .y_train
            .iter()
            .zip(self.alpha.iter())
            .map(|(y, a)| y * a)
            .sum::<f64>();
        let log_det: f64 = (0..n)
            .map(|i| self.chol[i * n + i].max(1e-300).ln())
            .sum::<f64>();
        -0.5 * data_fit - log_det - 0.5 * n as f64 * TAU.ln()
    }
}
695
#[derive(Debug, Clone)]
/// Chinese-restaurant-process clustering state with per-cluster running
/// moments (Welford-style updates). Assignment here is deterministic MAP
/// (argmax of the CRP seating probabilities), not sampled.
pub struct DirichletProcess {
    /// Concentration parameter; larger values favour opening new clusters.
    pub alpha: f64,
    /// Cluster index chosen for each point, in arrival order.
    pub assignments: Vec<usize>,
    /// Number of points currently in each cluster.
    pub cluster_counts: Vec<usize>,
    /// Running mean of each cluster.
    pub cluster_means: Vec<f64>,
    /// Running sum of squared deviations (Welford M2) per cluster.
    pub cluster_ss: Vec<f64>,
    /// Total number of points assigned so far.
    pub n_assigned: usize,
}

impl DirichletProcess {
    /// Empty process with concentration `alpha`.
    pub fn new(alpha: f64) -> Self {
        Self {
            alpha,
            assignments: Vec::new(),
            cluster_counts: Vec::new(),
            cluster_means: Vec::new(),
            cluster_ss: Vec::new(),
            n_assigned: 0,
        }
    }

    /// Number of clusters opened so far.
    pub fn n_clusters(&self) -> usize {
        self.cluster_counts.len()
    }

    /// Assigns `x` to the cluster with the largest CRP seating probability
    /// (or opens a new one), updates that cluster's running mean and sum of
    /// squares, and returns the chosen cluster index.
    pub fn crp_assign(&mut self, x: f64) -> usize {
        let n_seen = self.n_assigned as f64;
        let n_existing = self.cluster_counts.len();

        // Seating probabilities: count/(n + alpha) per existing cluster,
        // alpha/(n + alpha) for a brand-new one (appended last).
        let mut probs: Vec<f64> = self
            .cluster_counts
            .iter()
            .map(|&cnt| cnt as f64 / (n_seen + self.alpha))
            .collect();
        probs.push(self.alpha / (n_seen + self.alpha));

        // Deterministic MAP choice. `max_by` returns the LAST of tied
        // maxima, so ties lean toward the new-cluster slot.
        let chosen = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(n_existing);

        if chosen == n_existing {
            // Open a fresh cluster seeded at x.
            self.cluster_counts.push(1);
            self.cluster_means.push(x);
            self.cluster_ss.push(0.0);
        } else {
            // Welford incremental update of mean and M2.
            let old_count = self.cluster_counts[chosen] as f64;
            let old_mean = self.cluster_means[chosen];
            self.cluster_counts[chosen] += 1;
            let new_mean = old_mean + (x - old_mean) / (old_count + 1.0);
            self.cluster_ss[chosen] += (x - old_mean) * (x - new_mean);
            self.cluster_means[chosen] = new_mean;
        }
        self.assignments.push(chosen);
        self.n_assigned += 1;
        chosen
    }

    /// Deterministic stick-breaking approximation to `k` mixture weights:
    /// each break uses the mean of Beta(1, alpha) with a small
    /// index-dependent taper instead of a random draw. Leftover stick mass
    /// is folded into the last weight and the result is renormalised.
    pub fn stick_breaking_weights(&self, k: usize) -> Vec<f64> {
        let mut weights = Vec::with_capacity(k);
        let mut stick = 1.0f64;
        let mean_beta = 1.0 / (1.0 + self.alpha);
        for i in 0..k {
            let frac = (mean_beta * (1.0 - 0.1 * i as f64 / (k as f64 + 1.0)))
                .clamp(1e-6, 1.0 - 1e-6);
            weights.push(stick * frac);
            stick *= 1.0 - frac;
        }
        if let Some(last) = weights.last_mut() {
            *last += stick;
        }
        let total: f64 = weights.iter().sum::<f64>().max(1e-300);
        for w in weights.iter_mut() {
            *w /= total;
        }
        weights
    }

    /// Unbiased sample variance per cluster; singletons report 1.0.
    pub fn cluster_variances(&self) -> Vec<f64> {
        self.cluster_counts
            .iter()
            .zip(self.cluster_ss.iter())
            .map(|(&cnt, &ss)| if cnt > 1 { ss / (cnt - 1) as f64 } else { 1.0 })
            .collect()
    }

    /// CRP expectation of the cluster count: alpha * ln(1 + n/alpha).
    pub fn expected_clusters(alpha: f64, n: usize) -> f64 {
        alpha * (1.0 + n as f64 / alpha).ln()
    }
}
827
#[derive(Debug, Clone)]
/// Mean-field variational mixture of Gaussians: each component carries a
/// Gaussian variational posterior over its mean, plus softmax-normalised
/// mixture weights, fitted by coordinate-ascent VI (CAVI).
pub struct VariationalInference {
    // Number of mixture components K.
    pub n_components: usize,
    // Unnormalised log mixture weights (normalised via softmax where used).
    pub log_weights: Vec<f64>,
    // Variational posterior mean per component.
    pub var_mean: Vec<f64>,
    // Variational posterior variance per component.
    pub var_var: Vec<f64>,
    // Prior mean shared by all component means.
    pub prior_mean: f64,
    // Prior variance shared by all component means.
    pub prior_var: f64,
    // Fixed observation (likelihood) variance.
    pub obs_var: f64,
    // ELBO recorded after every CAVI step.
    pub elbo_history: Vec<f64>,
}
855
856impl VariationalInference {
857 pub fn new(n_components: usize, prior_mean: f64, prior_var: f64, obs_var: f64) -> Self {
859 let log_weights = vec![-(n_components as f64).ln(); n_components];
860 let var_mean: Vec<f64> = (0..n_components).map(|i| i as f64).collect();
861 let var_var = vec![1.0f64; n_components];
862 Self {
863 n_components,
864 log_weights,
865 var_mean,
866 var_var,
867 prior_mean,
868 prior_var,
869 obs_var,
870 elbo_history: Vec::new(),
871 }
872 }
873
874 pub fn elbo(&self, observations: &[f64]) -> f64 {
876 let weights = softmax(&self.log_weights);
877 let mut elbo = 0.0f64;
878 for &x in observations {
880 let ll_terms: Vec<f64> = (0..self.n_components)
881 .map(|k| {
882 weights[k].max(1e-300).ln()
883 + log_normal_pdf(x, self.var_mean[k], self.obs_var + self.var_var[k])
884 })
885 .collect();
886 elbo += log_sum_exp(&ll_terms);
887 }
888 for k in 0..self.n_components {
890 let kl = 0.5
892 * (self.prior_var / self.var_var[k].max(1e-12)
893 + (self.var_mean[k] - self.prior_mean).powi(2) / self.prior_var
894 - 1.0
895 + (self.var_var[k] / self.prior_var).ln());
896 elbo -= weights[k] * kl;
897 }
898 elbo
899 }
900
901 pub fn cavi_step(&mut self, observations: &[f64]) -> f64 {
905 let n = observations.len() as f64;
906 for k in 0..self.n_components {
908 let weights = softmax(&self.log_weights);
909 let r_k: Vec<f64> = observations
911 .iter()
912 .map(|&x| weights[k] * normal_pdf(x, self.var_mean[k], self.obs_var))
913 .collect();
914 let r_sum: f64 = r_k.iter().sum::<f64>().max(1e-300);
915
916 let prior_prec = 1.0 / self.prior_var.max(1e-12);
918 let lik_prec = r_sum / self.obs_var.max(1e-12);
919 let post_prec = prior_prec + lik_prec;
920 let post_var = 1.0 / post_prec.max(1e-12);
921 let data_sum: f64 = r_k
922 .iter()
923 .zip(observations.iter())
924 .map(|(r, x)| r * x)
925 .sum();
926 let post_mean =
927 post_var * (prior_prec * self.prior_mean + data_sum / self.obs_var.max(1e-12));
928
929 self.var_mean[k] = post_mean;
930 self.var_var[k] = post_var;
931
932 self.log_weights[k] = r_sum.max(1e-300).ln();
934 }
935 let lse = log_sum_exp(&self.log_weights.clone());
937 for k in 0..self.n_components {
938 self.log_weights[k] -= lse;
939 }
940 let _ = n;
941 let elbo_val = self.elbo(observations);
942 self.elbo_history.push(elbo_val);
943 elbo_val
944 }
945
946 pub fn fit(&mut self, observations: &[f64], n_iter: usize) -> f64 {
948 for _ in 0..n_iter {
949 self.cavi_step(observations);
950 }
951 *self.elbo_history.last().unwrap_or(&f64::NEG_INFINITY)
952 }
953
954 pub fn reparameterize(&self, k: usize, eps: f64) -> f64 {
956 self.var_mean[k] + self.var_var[k].sqrt() * eps
957 }
958
959 pub fn predictive_density(&self, x: f64) -> f64 {
961 let weights = softmax(&self.log_weights);
962 (0..self.n_components)
963 .map(|k| weights[k] * normal_pdf(x, self.var_mean[k], self.obs_var + self.var_var[k]))
964 .sum()
965 }
966}
967
#[derive(Debug, Clone)]
/// One univariate Gaussian mixture component.
pub struct GmmComponent {
    /// Mixing weight (not necessarily normalised across components).
    pub weight: f64,
    /// Component mean.
    pub mean: f64,
    /// Component variance.
    pub var: f64,
}

impl GmmComponent {
    /// Bundles a (weight, mean, variance) triple into a component.
    pub fn new(weight: f64, mean: f64, var: f64) -> Self {
        GmmComponent { weight, mean, var }
    }
}
989
#[derive(Debug, Clone)]
/// Univariate Gaussian mixture model fitted by expectation-maximization.
pub struct ExpectationMaximization {
    // Number of mixture components K.
    pub n_components: usize,
    // The K Gaussian components (weight, mean, variance each).
    pub components: Vec<GmmComponent>,
    // Log-likelihood after each EM iteration of the last `fit` call.
    pub ll_history: Vec<f64>,
    // Convergence threshold on the change in successive log-likelihoods.
    pub tol: f64,
}
1005
1006impl ExpectationMaximization {
1007 pub fn new(n_components: usize) -> Self {
1009 let components = (0..n_components)
1010 .map(|i| GmmComponent::new(1.0 / n_components as f64, i as f64, 1.0))
1011 .collect();
1012 Self {
1013 n_components,
1014 components,
1015 ll_history: Vec::new(),
1016 tol: 1e-6,
1017 }
1018 }
1019
1020 pub fn with_tol(mut self, tol: f64) -> Self {
1022 self.tol = tol;
1023 self
1024 }
1025
1026 pub fn kmeans_init(&mut self, data: &[f64]) {
1028 if data.is_empty() {
1029 return;
1030 }
1031 let mut sorted = data.to_vec();
1032 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1033 let k = self.n_components;
1034 for i in 0..k {
1035 let idx = (sorted.len() * (2 * i + 1)) / (2 * k);
1036 self.components[i].mean = sorted[idx.min(sorted.len() - 1)];
1037 self.components[i].var = 1.0;
1038 self.components[i].weight = 1.0 / k as f64;
1039 }
1040 }
1041
1042 pub fn log_likelihood(&self, data: &[f64]) -> f64 {
1044 data.iter()
1045 .map(|&x| {
1046 let terms: Vec<f64> = self
1047 .components
1048 .iter()
1049 .map(|c| c.weight.max(1e-300).ln() + log_normal_pdf(x, c.mean, c.var))
1050 .collect();
1051 log_sum_exp(&terms)
1052 })
1053 .sum()
1054 }
1055
1056 pub fn bic(&self, data: &[f64]) -> f64 {
1060 let n = data.len() as f64;
1061 let ll = self.log_likelihood(data);
1062 let n_params = (3 * self.n_components - 1) as f64;
1064 n_params * n.ln() - 2.0 * ll
1065 }
1066
1067 fn e_step(&self, data: &[f64]) -> Vec<Vec<f64>> {
1069 data.iter()
1070 .map(|&x| {
1071 let log_probs: Vec<f64> = self
1072 .components
1073 .iter()
1074 .map(|c| c.weight.max(1e-300).ln() + log_normal_pdf(x, c.mean, c.var))
1075 .collect();
1076 softmax(&log_probs)
1077 })
1078 .collect()
1079 }
1080
1081 fn m_step(&mut self, data: &[f64], responsibilities: &[Vec<f64>]) {
1083 let n = data.len() as f64;
1084 for k in 0..self.n_components {
1085 let r_sum: f64 = responsibilities
1086 .iter()
1087 .map(|r| r[k])
1088 .sum::<f64>()
1089 .max(1e-300);
1090 let new_weight = r_sum / n;
1091 let new_mean: f64 = responsibilities
1092 .iter()
1093 .zip(data.iter())
1094 .map(|(r, &x)| r[k] * x)
1095 .sum::<f64>()
1096 / r_sum;
1097 let new_var: f64 = (responsibilities
1098 .iter()
1099 .zip(data.iter())
1100 .map(|(r, &x)| r[k] * (x - new_mean).powi(2))
1101 .sum::<f64>()
1102 / r_sum)
1103 .max(1e-6);
1104 self.components[k].weight = new_weight;
1105 self.components[k].mean = new_mean;
1106 self.components[k].var = new_var;
1107 }
1108 }
1109
1110 pub fn fit(&mut self, data: &[f64], max_iter: usize) -> f64 {
1114 self.ll_history.clear();
1115 let mut prev_ll = f64::NEG_INFINITY;
1116 for _ in 0..max_iter {
1117 let resp = self.e_step(data);
1118 self.m_step(data, &resp);
1119 let ll = self.log_likelihood(data);
1120 self.ll_history.push(ll);
1121 if (ll - prev_ll).abs() < self.tol {
1122 break;
1123 }
1124 prev_ll = ll;
1125 }
1126 *self.ll_history.last().unwrap_or(&f64::NEG_INFINITY)
1127 }
1128
1129 pub fn predict(&self, x: f64) -> usize {
1131 self.components
1132 .iter()
1133 .enumerate()
1134 .map(|(k, c)| {
1135 let ll = c.weight.max(1e-300).ln() + log_normal_pdf(x, c.mean, c.var);
1136 (k, ll)
1137 })
1138 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
1139 .map(|(k, _)| k)
1140 .unwrap_or(0)
1141 }
1142
1143 pub fn normalized_weights(&self) -> Vec<f64> {
1145 let sum: f64 = self
1146 .components
1147 .iter()
1148 .map(|c| c.weight)
1149 .sum::<f64>()
1150 .max(1e-300);
1151 self.components.iter().map(|c| c.weight / sum).collect()
1152 }
1153}
1154
1155#[cfg(test)]
1160mod tests {
1161 use super::*;
1162
1163 fn make_simple_bn() -> BayesianNetwork {
1166 let mut bn = BayesianNetwork::new();
1167 bn.add_node(BnNode::new("Rain", 2, vec![], vec![0.3, 0.7]));
1169 bn.add_node(BnNode::new(
1171 "Sprinkler",
1172 2,
1173 vec![0],
1174 vec![0.1, 0.9, 0.5, 0.5], ));
1176 bn
1177 }
1178
1179 #[test]
1180 fn test_bn_validate() {
1181 let bn = make_simple_bn();
1182 assert!(bn.validate());
1183 }
1184
1185 #[test]
1186 fn test_bn_joint_probability_sums_to_one() {
1187 let mut bn = BayesianNetwork::new();
1188 bn.add_node(BnNode::new("A", 2, vec![], vec![0.4, 0.6]));
1189 bn.add_node(BnNode::new("B", 2, vec![0], vec![0.7, 0.3, 0.2, 0.8]));
1190 let total: f64 = (0..4)
1192 .map(|i| {
1193 let a = i / 2;
1194 let b_val = i % 2;
1195 bn.joint_probability(&[a, b_val])
1196 })
1197 .sum();
1198 assert!((total - 1.0).abs() < 1e-10);
1199 }
1200
1201 #[test]
1202 fn test_bn_marginal_sums_to_one() {
1203 let bn = make_simple_bn();
1204 let m0 = bn.marginal(0, 0);
1205 let m1 = bn.marginal(0, 1);
1206 assert!((m0 + m1 - 1.0).abs() < 1e-6);
1207 }
1208
1209 #[test]
1210 fn test_bn_marginal_root_equals_prior() {
1211 let bn = make_simple_bn();
1212 let m0 = bn.marginal(0, 0);
1213 assert!((m0 - 0.3).abs() < 1e-8);
1214 }
1215
1216 #[test]
1217 fn test_bn_conditional_valid() {
1218 let bn = make_simple_bn();
1219 let p = bn.conditional(1, 0, &[(0, 0)]);
1220 assert!((0.0..=1.0).contains(&p));
1221 }
1222
1223 #[test]
1224 fn test_bn_marginal_all_sums_to_one() {
1225 let bn = make_simple_bn();
1226 let m = bn.marginal_all(0);
1227 let s: f64 = m.iter().sum();
1228 assert!((s - 1.0).abs() < 1e-8);
1229 }
1230
1231 #[test]
1232 fn test_bn_cpt_value() {
1233 let node = BnNode::new("X", 2, vec![], vec![0.4, 0.6]);
1234 assert!((node.cpt_value(0, 0) - 0.4).abs() < 1e-10);
1235 assert!((node.cpt_value(1, 0) - 0.6).abs() < 1e-10);
1236 }
1237
1238 #[test]
1239 fn test_bn_single_node() {
1240 let mut bn = BayesianNetwork::new();
1241 bn.add_node(BnNode::new("X", 3, vec![], vec![0.2, 0.5, 0.3]));
1242 let p = bn.joint_probability(&[1]);
1243 assert!((p - 0.5).abs() < 1e-10);
1244 }
1245
1246 fn make_hmm() -> HiddenMarkovModel {
1249 HiddenMarkovModel::new(
1250 2,
1251 vec![0.6, 0.4],
1252 vec![vec![0.7, 0.3], vec![0.4, 0.6]],
1253 vec![0.0, 3.0],
1254 vec![1.0, 1.0],
1255 )
1256 }
1257
1258 #[test]
1259 fn test_hmm_forward_returns_finite() {
1260 let hmm = make_hmm();
1261 let obs = vec![0.1, 0.2, 0.3, 2.8, 3.1];
1262 let ll = hmm.forward(&obs);
1263 assert!(ll.is_finite());
1264 }
1265
1266 #[test]
1267 fn test_hmm_forward_empty() {
1268 let hmm = make_hmm();
1269 assert_eq!(hmm.forward(&[]), 0.0);
1270 }
1271
1272 #[test]
1273 fn test_hmm_viterbi_length() {
1274 let hmm = make_hmm();
1275 let obs = vec![0.1, 0.2, 0.3, 2.8, 3.1];
1276 let path = hmm.viterbi(&obs);
1277 assert_eq!(path.len(), obs.len());
1278 }
1279
1280 #[test]
1281 fn test_hmm_viterbi_valid_states() {
1282 let hmm = make_hmm();
1283 let obs = vec![0.1, 2.9, 0.2, 3.0];
1284 let path = hmm.viterbi(&obs);
1285 assert!(path.iter().all(|&s| s < 2));
1286 }
1287
1288 #[test]
1289 fn test_hmm_viterbi_empty() {
1290 let hmm = make_hmm();
1291 assert_eq!(hmm.viterbi(&[]).len(), 0);
1292 }
1293
1294 #[test]
1295 fn test_hmm_baum_welch_ll_increases() {
1296 let mut hmm = make_hmm();
1297 let obs: Vec<f64> = (0..20)
1298 .map(|i| if i % 3 == 0 { 0.1 } else { 2.9 })
1299 .collect();
1300 let ll_hist = hmm.baum_welch(&obs, 5);
1301 for i in 1..ll_hist.len() {
1303 assert!(ll_hist[i] >= ll_hist[i - 1] - 1e-4);
1304 }
1305 }
1306
1307 #[test]
1308 fn test_hmm_uniform_creation() {
1309 let hmm = HiddenMarkovModel::uniform(3);
1310 assert_eq!(hmm.n_states, 3);
1311 let row_sum: f64 = hmm.transition[0].iter().sum();
1312 assert!((row_sum - 1.0).abs() < 1e-10);
1313 }
1314
1315 #[test]
1318 fn test_gp_rbf_kernel_diagonal() {
1319 let gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-3);
1320 assert!((gp.k(0.0, 0.0) - 1.0).abs() < 1e-10);
1321 }
1322
1323 #[test]
1324 fn test_gp_rbf_kernel_decays() {
1325 let gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-3);
1326 assert!(gp.k(0.0, 10.0) < gp.k(0.0, 1.0));
1327 }
1328
1329 #[test]
1330 fn test_gp_matern32_diagonal() {
1331 let gp = GaussianProcess::new(KernelType::Matern32, 2.0, 1.0, 1e-3);
1332 assert!((gp.k(0.0, 0.0) - 2.0).abs() < 1e-10);
1333 }
1334
1335 #[test]
1336 fn test_gp_matern52_diagonal() {
1337 let gp = GaussianProcess::new(KernelType::Matern52, 1.5, 1.0, 1e-3);
1338 assert!((gp.k(0.0, 0.0) - 1.5).abs() < 1e-10);
1339 }
1340
1341 #[test]
1342 fn test_gp_periodic_diagonal() {
1343 let gp = GaussianProcess::new(KernelType::Periodic, 1.0, 1.0, 1e-3);
1344 assert!((gp.k(0.0, 0.0) - 1.0).abs() < 1e-10);
1345 }
1346
1347 #[test]
1348 fn test_gp_fit_predict_mean() {
1349 let mut gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-4);
1350 let x = vec![0.0, 1.0, 2.0, 3.0];
1351 let y = vec![0.0, 1.0, 4.0, 9.0];
1352 gp.fit(x, y);
1353 let (mean, _var) = gp.predict(1.0);
1354 assert!((mean - 1.0).abs() < 0.5); }
1356
1357 #[test]
1358 fn test_gp_predict_variance_positive() {
1359 let mut gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-4);
1360 gp.fit(vec![0.0, 1.0], vec![0.0, 1.0]);
1361 let (_mean, var) = gp.predict(5.0); assert!(var > 0.0);
1363 }
1364
1365 #[test]
1366 fn test_gp_predict_empty() {
1367 let gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 1e-3);
1368 let (mean, var) = gp.predict(0.5);
1369 assert_eq!(mean, 0.0);
1370 assert!(var > 0.0);
1371 }
1372
1373 #[test]
1374 fn test_gp_log_marginal_likelihood() {
1375 let mut gp = GaussianProcess::new(KernelType::Rbf, 1.0, 1.0, 0.1);
1376 gp.fit(vec![0.0, 1.0, 2.0], vec![0.0, 1.0, 0.0]);
1377 let lml = gp.log_marginal_likelihood();
1378 assert!(lml.is_finite());
1379 }
1380
1381 #[test]
1384 fn test_dp_initial_state() {
1385 let dp = DirichletProcess::new(1.0);
1386 assert_eq!(dp.n_clusters(), 0);
1387 assert_eq!(dp.n_assigned, 0);
1388 }
1389
1390 #[test]
1391 fn test_dp_crp_first_point() {
1392 let mut dp = DirichletProcess::new(1.0);
1393 let c = dp.crp_assign(0.0);
1394 assert_eq!(c, 0);
1395 assert_eq!(dp.n_clusters(), 1);
1396 }
1397
1398 #[test]
1399 fn test_dp_crp_multiple_points() {
1400 let mut dp = DirichletProcess::new(0.1); for i in 0..10 {
1402 dp.crp_assign(i as f64 * 0.01);
1403 }
1404 assert!(dp.n_clusters() <= 5);
1406 }
1407
1408 #[test]
1409 fn test_dp_stick_breaking_sums_to_one() {
1410 let dp = DirichletProcess::new(2.0);
1411 let w = dp.stick_breaking_weights(10);
1412 let sum: f64 = w.iter().sum();
1413 assert!((sum - 1.0).abs() < 1e-6);
1414 }
1415
1416 #[test]
1417 fn test_dp_stick_breaking_positive() {
1418 let dp = DirichletProcess::new(1.0);
1419 let w = dp.stick_breaking_weights(5);
1420 assert!(w.iter().all(|&wi| wi > 0.0));
1421 }
1422
1423 #[test]
1424 fn test_dp_expected_clusters() {
1425 let e = DirichletProcess::expected_clusters(1.0, 100);
1426 assert!(e > 3.0 && e < 10.0);
1427 }
1428
1429 #[test]
1430 fn test_dp_cluster_variances() {
1431 let mut dp = DirichletProcess::new(0.5);
1432 for i in 0..5 {
1433 dp.crp_assign(i as f64);
1434 }
1435 let vars = dp.cluster_variances();
1436 assert!(vars.iter().all(|&v| v >= 0.0));
1437 }
1438
1439 #[test]
1442 fn test_vi_elbo_finite() {
1443 let vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
1444 let obs = vec![0.0, 1.0, -1.0, 2.0];
1445 let elbo = vi.elbo(&obs);
1446 assert!(elbo.is_finite());
1447 }
1448
1449 #[test]
1450 fn test_vi_cavi_step_updates_params() {
1451 let mut vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
1452 let old_mean = vi.var_mean[0];
1453 let obs = vec![3.0, 3.1, 3.2, -3.0, -3.1, -3.2];
1454 vi.cavi_step(&obs);
1455 assert!((vi.var_mean[0] - old_mean).abs() > 0.0);
1457 }
1458
1459 #[test]
1460 fn test_vi_fit_returns_finite() {
1461 let mut vi = VariationalInference::new(2, 0.0, 2.0, 1.0);
1462 let obs: Vec<f64> = (0..20)
1463 .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
1464 .collect();
1465 let elbo = vi.fit(&obs, 10);
1466 assert!(elbo.is_finite());
1467 }
1468
1469 #[test]
1470 fn test_vi_reparameterize() {
1471 let vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
1472 let sample = vi.reparameterize(0, 1.0);
1473 let expected = vi.var_mean[0] + vi.var_var[0].sqrt();
1475 assert!((sample - expected).abs() < 1e-10);
1476 }
1477
1478 #[test]
1479 fn test_vi_predictive_density_positive() {
1480 let vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
1481 let p = vi.predictive_density(0.0);
1482 assert!(p > 0.0);
1483 }
1484
1485 #[test]
1486 fn test_vi_elbo_history_grows() {
1487 let mut vi = VariationalInference::new(2, 0.0, 1.0, 1.0);
1488 let obs = vec![1.0, -1.0, 2.0];
1489 vi.fit(&obs, 5);
1490 assert_eq!(vi.elbo_history.len(), 5);
1491 }
1492
1493 #[test]
1496 fn test_em_initial_weights_sum_to_one() {
1497 let em = ExpectationMaximization::new(3);
1498 let sum: f64 = em.normalized_weights().iter().sum();
1499 assert!((sum - 1.0).abs() < 1e-10);
1500 }
1501
1502 #[test]
1503 fn test_em_kmeans_init() {
1504 let mut em = ExpectationMaximization::new(2);
1505 let data = vec![0.0, 0.1, 0.2, 5.0, 5.1, 5.2];
1506 em.kmeans_init(&data);
1507 let means: Vec<f64> = em.components.iter().map(|c| c.mean).collect();
1509 assert!(means.iter().any(|&m| m < 1.0));
1510 assert!(means.iter().any(|&m| m > 4.0));
1511 }
1512
1513 #[test]
1514 fn test_em_log_likelihood_finite() {
1515 let em = ExpectationMaximization::new(2);
1516 let data = vec![0.0, 1.0, 2.0];
1517 assert!(em.log_likelihood(&data).is_finite());
1518 }
1519
1520 #[test]
1521 fn test_em_fit_ll_increases() {
1522 let mut em = ExpectationMaximization::new(2);
1523 let data: Vec<f64> = (0..30)
1524 .map(|i| {
1525 if i < 15 {
1526 i as f64 * 0.1
1527 } else {
1528 5.0 + i as f64 * 0.1
1529 }
1530 })
1531 .collect();
1532 em.kmeans_init(&data);
1533 em.fit(&data, 20);
1534 let ll = &em.ll_history;
1535 for i in 1..ll.len() {
1536 assert!(ll[i] >= ll[i - 1] - 1e-4);
1537 }
1538 }
1539
1540 #[test]
1541 fn test_em_predict_valid_component() {
1542 let em = ExpectationMaximization::new(3);
1543 let pred = em.predict(0.5);
1544 assert!(pred < 3);
1545 }
1546
1547 #[test]
1548 fn test_em_bic_finite() {
1549 let em = ExpectationMaximization::new(2);
1550 let data = vec![0.0, 1.0, 5.0, 6.0];
1551 let bic = em.bic(&data);
1552 assert!(bic.is_finite());
1553 }
1554
1555 #[test]
1556 fn test_em_fit_separates_clusters() {
1557 let mut em = ExpectationMaximization::new(2);
1558 let mut data: Vec<f64> = (0..20).map(|i| i as f64 * 0.05).collect(); let data2: Vec<f64> = (0..20).map(|i| 10.0 + i as f64 * 0.05).collect(); data.extend(data2);
1561 em.kmeans_init(&data);
1562 em.fit(&data, 50);
1563 let means: Vec<f64> = em.components.iter().map(|c| c.mean).collect();
1565 assert!(means.iter().any(|&m| m < 3.0));
1566 assert!(means.iter().any(|&m| m > 7.0));
1567 }
1568
1569 #[test]
1570 fn test_em_n_components() {
1571 let em = ExpectationMaximization::new(4);
1572 assert_eq!(em.n_components, 4);
1573 assert_eq!(em.components.len(), 4);
1574 }
1575
1576 #[test]
1579 fn test_log_sum_exp_empty() {
1580 assert_eq!(log_sum_exp(&[]), f64::NEG_INFINITY);
1581 }
1582
1583 #[test]
1584 fn test_log_sum_exp_single() {
1585 assert!((log_sum_exp(&[2.0]) - 2.0).abs() < 1e-10);
1586 }
1587
1588 #[test]
1589 fn test_softmax_sums_to_one() {
1590 let s = softmax(&[1.0, 2.0, 3.0]);
1591 assert!((s.iter().sum::<f64>() - 1.0).abs() < 1e-10);
1592 }
1593
1594 #[test]
1595 fn test_normal_pdf_peak() {
1596 let p = normal_pdf(0.0, 0.0, 1.0);
1597 assert!((p - 1.0 / (TAU).sqrt()).abs() < 1e-10);
1598 }
1599
1600 #[test]
1601 fn test_mvn_log_pdf_diag() {
1602 let x = vec![0.0, 0.0];
1603 let mean = vec![0.0, 0.0];
1604 let var = vec![1.0, 1.0];
1605 let lp = mvn_log_pdf_diag(&x, &mean, &var);
1606 assert!(lp.is_finite());
1607 }
1608}