mixturs/params/
thin.rs

1use itertools::repeat_n;
2use nalgebra::{DMatrix, RowDVector};
3use rand::Rng;
4use statrs::distribution::MultivariateNormal;
5use crate::stats::ContinuousBatchwise;
6use crate::utils::{col_normalize_log_weights, replacement_sampling_weighted};
7
8
9pub trait ThinParams: Clone + Send + Sync {
10    /// Number of clusters.
11    fn n_clusters(&self) -> usize;
12
13    /// Distribution of the primary cluster.
14    fn cluster_dist(&self, cluster_id: usize) -> &MultivariateNormal;
15
16    /// Weights of the primary clusters.
17    fn cluster_weights(&self) -> &[f64];
18
19    /// Distribution of the auxiliary clusters given the primary cluster.
20    fn cluster_aux_dist(&self, cluster_id: usize, aux_id: usize) -> &MultivariateNormal;
21
22    /// Weights of the auxiliary clusters given the primary cluster.
23    fn cluster_aux_weights(&self, cluster_id: usize) -> &[f64; 2];
24
25    /// Number of parameters in the model.
26    fn n_params(&self) -> usize {
27        let dim = self.cluster_dist(0).mu().len();
28        self.n_clusters() * (dim * dim + dim + 1)
29    }
30}
31
32#[derive(Debug, Clone, PartialEq)]
33pub struct OwnedThinParams {
34    pub clusters: Vec<MultivariateNormal>,
35    pub cluster_weights: Vec<f64>,
36    pub clusters_aux: Vec<[MultivariateNormal; 2]>,
37    pub cluster_weights_aux: Vec<[f64; 2]>,
38}
39
40impl ThinParams for OwnedThinParams {
41    fn n_clusters(&self) -> usize {
42        self.clusters.len()
43    }
44
45    fn cluster_dist(&self, cluster_id: usize) -> &MultivariateNormal {
46        &self.clusters[cluster_id]
47    }
48
49    fn cluster_weights(&self) -> &[f64] {
50        &self.cluster_weights
51    }
52
53    fn cluster_aux_dist(&self, cluster_id: usize, aux_id: usize) -> &MultivariateNormal {
54        &self.clusters_aux[cluster_id][aux_id]
55    }
56
57    fn cluster_aux_weights(&self, cluster_id: usize) -> &[f64; 2] {
58        &self.cluster_weights_aux[cluster_id]
59    }
60}
61
62/// Selects super cluster params from thin params
63pub struct SuperMixtureParams<'a, D: ThinParams>(pub(crate) &'a D);
64
65impl<'a, D: ThinParams> MixtureParams for SuperMixtureParams<'a, D> {
66    fn n_clusters(&self) -> usize {
67        self.0.n_clusters()
68    }
69
70    fn dist(&self, cluster_id: usize) -> &MultivariateNormal {
71        self.0.cluster_dist(cluster_id)
72    }
73
74    fn weights(&self) -> &[f64] {
75        self.0.cluster_weights()
76    }
77}
78
79/// Selects auxiliary cluster params from thin params for a given super cluster
80pub struct AuxMixtureParams<'a, D: ThinParams>(pub(crate) &'a D, pub(crate) usize);
81
82impl<'a, D: ThinParams> MixtureParams for AuxMixtureParams<'a, D> {
83    fn n_clusters(&self) -> usize {
84        2
85    }
86
87    fn dist(&self, cluster_id: usize) -> &MultivariateNormal {
88        self.0.cluster_aux_dist(self.1, cluster_id)
89    }
90
91    fn weights(&self) -> &[f64] {
92        self.0.cluster_aux_weights(self.1)
93    }
94}
95
96pub trait MixtureParams {
97    /// Number of clusters.
98    fn n_clusters(&self) -> usize;
99
100    /// Distribution of the primary cluster.
101    fn dist(&self, cluster_id: usize) -> &MultivariateNormal;
102
103    /// Weights of the primary clusters.
104    fn weights(&self) -> &[f64];
105
106    /// Log-likelihood of the data points (columns) given the model.
107    fn log_likelihood(&self, data: DMatrix<f64>) -> DMatrix<f64> {
108        let mut ll = DMatrix::zeros(self.n_clusters(), data.ncols());
109        // Add cluster log probabilities
110        for (cluster_id, data) in repeat_n(data, self.n_clusters()).enumerate() {
111            ll.row_mut(cluster_id).copy_from_slice(
112                self.dist(cluster_id)
113                    .batchwise_ln_pdf(data)
114                    .as_slice()
115            );
116        }
117
118        // Add mixture weights
119        let weights = self.weights();
120        for (prim, mut row) in ll.row_iter_mut().enumerate() {
121            let ln_weight = weights[prim].ln();
122            row.apply(|x| *x += ln_weight);
123        }
124
125        ll
126    }
127
128    /// Predict the cluster labels for the data points (columns).
129    fn predict(&self, data: DMatrix<f64>) -> (DMatrix<f64>, RowDVector<usize>) {
130        let mut labels = RowDVector::zeros(data.ncols());
131        let log_likelihood = self.log_likelihood(data);
132        hard_assignment(&log_likelihood, labels.as_mut_slice());
133        let probs = col_normalize_log_weights(log_likelihood);
134
135        (probs, labels)
136    }
137}
138
139
140/// Assigns each column of `log_likelihood` to the cluster with the highest log probability.
141///
142/// # Arguments
143///
144/// * `log_likelihood`: A matrix of log probabilities of shape (n_clusters, n_samples)
145/// * `labels`: A mutable vector of length `n_samples` the cluster assignments will be written to.
146///
147/// # Examples
148///
149/// ```
150/// use mixturs::params::thin::hard_assignment;
151/// use nalgebra::{DMatrix, RowDVector};
152///
153/// let log_likelihood = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
154/// let mut labels = RowDVector::zeros(3);
155/// hard_assignment(&log_likelihood, labels.as_mut_slice());
156/// assert_eq!(labels, RowDVector::from_row_slice(&[1, 1, 1]));
157/// ```
158pub fn hard_assignment(
159    log_likelihood: &DMatrix<f64>,
160    labels: &mut [usize],
161) {
162    for (i, row) in log_likelihood.column_iter().enumerate() {
163        labels[i] = row.argmax().0;
164    }
165}
166
167
168/// Assigns each column of `log_likelihood` to a cluster according to the probability distribution
169/// defined by the log probabilities.
170///
171/// # Arguments
172///
173/// * `log_likelihood`: A matrix of log probabilities of shape (n_clusters, n_samples)
174/// * `labels`: A mutable vector of length `n_samples` the cluster assignments will be written to.
175/// * `rng`: A random number generator.
176///
177/// # Examples
178///
179/// ```
180/// use mixturs::params::thin::soft_assignment;
181/// use nalgebra::{DMatrix, RowDVector};
182///
183/// let log_likelihood = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
184/// let mut labels = RowDVector::zeros(3);
185/// let mut rng = rand::thread_rng();
186/// soft_assignment(log_likelihood, labels.as_mut_slice(), &mut rng);
187/// ```
188pub fn soft_assignment(
189    log_likelihood: DMatrix<f64>,
190    labels: &mut [usize],
191    rng: &mut impl Rng,
192) {
193    let probs = col_normalize_log_weights(log_likelihood);
194    for (i, col) in probs.column_iter().enumerate() {
195        replacement_sampling_weighted(rng, col.into_iter().cloned(), &mut labels[i..=i]);
196    }
197}