1use itertools::repeat_n;
2use nalgebra::{DMatrix, RowDVector};
3use rand::Rng;
4use statrs::distribution::MultivariateNormal;
5use crate::stats::ContinuousBatchwise;
6use crate::utils::{col_normalize_log_weights, replacement_sampling_weighted};
7
8
9pub trait ThinParams: Clone + Send + Sync {
10 fn n_clusters(&self) -> usize;
12
13 fn cluster_dist(&self, cluster_id: usize) -> &MultivariateNormal;
15
16 fn cluster_weights(&self) -> &[f64];
18
19 fn cluster_aux_dist(&self, cluster_id: usize, aux_id: usize) -> &MultivariateNormal;
21
22 fn cluster_aux_weights(&self, cluster_id: usize) -> &[f64; 2];
24
25 fn n_params(&self) -> usize {
27 let dim = self.cluster_dist(0).mu().len();
28 self.n_clusters() * (dim * dim + dim + 1)
29 }
30}
31
32#[derive(Debug, Clone, PartialEq)]
33pub struct OwnedThinParams {
34 pub clusters: Vec<MultivariateNormal>,
35 pub cluster_weights: Vec<f64>,
36 pub clusters_aux: Vec<[MultivariateNormal; 2]>,
37 pub cluster_weights_aux: Vec<[f64; 2]>,
38}
39
40impl ThinParams for OwnedThinParams {
41 fn n_clusters(&self) -> usize {
42 self.clusters.len()
43 }
44
45 fn cluster_dist(&self, cluster_id: usize) -> &MultivariateNormal {
46 &self.clusters[cluster_id]
47 }
48
49 fn cluster_weights(&self) -> &[f64] {
50 &self.cluster_weights
51 }
52
53 fn cluster_aux_dist(&self, cluster_id: usize, aux_id: usize) -> &MultivariateNormal {
54 &self.clusters_aux[cluster_id][aux_id]
55 }
56
57 fn cluster_aux_weights(&self, cluster_id: usize) -> &[f64; 2] {
58 &self.cluster_weights_aux[cluster_id]
59 }
60}
61
62pub struct SuperMixtureParams<'a, D: ThinParams>(pub(crate) &'a D);
64
65impl<'a, D: ThinParams> MixtureParams for SuperMixtureParams<'a, D> {
66 fn n_clusters(&self) -> usize {
67 self.0.n_clusters()
68 }
69
70 fn dist(&self, cluster_id: usize) -> &MultivariateNormal {
71 self.0.cluster_dist(cluster_id)
72 }
73
74 fn weights(&self) -> &[f64] {
75 self.0.cluster_weights()
76 }
77}
78
79pub struct AuxMixtureParams<'a, D: ThinParams>(pub(crate) &'a D, pub(crate) usize);
81
82impl<'a, D: ThinParams> MixtureParams for AuxMixtureParams<'a, D> {
83 fn n_clusters(&self) -> usize {
84 2
85 }
86
87 fn dist(&self, cluster_id: usize) -> &MultivariateNormal {
88 self.0.cluster_aux_dist(self.1, cluster_id)
89 }
90
91 fn weights(&self) -> &[f64] {
92 self.0.cluster_aux_weights(self.1)
93 }
94}
95
96pub trait MixtureParams {
97 fn n_clusters(&self) -> usize;
99
100 fn dist(&self, cluster_id: usize) -> &MultivariateNormal;
102
103 fn weights(&self) -> &[f64];
105
106 fn log_likelihood(&self, data: DMatrix<f64>) -> DMatrix<f64> {
108 let mut ll = DMatrix::zeros(self.n_clusters(), data.ncols());
109 for (cluster_id, data) in repeat_n(data, self.n_clusters()).enumerate() {
111 ll.row_mut(cluster_id).copy_from_slice(
112 self.dist(cluster_id)
113 .batchwise_ln_pdf(data)
114 .as_slice()
115 );
116 }
117
118 let weights = self.weights();
120 for (prim, mut row) in ll.row_iter_mut().enumerate() {
121 let ln_weight = weights[prim].ln();
122 row.apply(|x| *x += ln_weight);
123 }
124
125 ll
126 }
127
128 fn predict(&self, data: DMatrix<f64>) -> (DMatrix<f64>, RowDVector<usize>) {
130 let mut labels = RowDVector::zeros(data.ncols());
131 let log_likelihood = self.log_likelihood(data);
132 hard_assignment(&log_likelihood, labels.as_mut_slice());
133 let probs = col_normalize_log_weights(log_likelihood);
134
135 (probs, labels)
136 }
137}
138
139
140pub fn hard_assignment(
159 log_likelihood: &DMatrix<f64>,
160 labels: &mut [usize],
161) {
162 for (i, row) in log_likelihood.column_iter().enumerate() {
163 labels[i] = row.argmax().0;
164 }
165}
166
167
168pub fn soft_assignment(
189 log_likelihood: DMatrix<f64>,
190 labels: &mut [usize],
191 rng: &mut impl Rng,
192) {
193 let probs = col_normalize_log_weights(log_likelihood);
194 for (i, col) in probs.column_iter().enumerate() {
195 replacement_sampling_weighted(rng, col.into_iter().cloned(), &mut labels[i..=i]);
196 }
197}