1use ghostflow_core::Tensor;
7use rand::prelude::*;
8use std::f32::consts::PI;
9
/// Gaussian mixture model fit with the expectation-maximization (EM) algorithm.
pub struct GaussianMixture {
    /// Number of mixture components.
    pub n_components: usize,
    /// Covariance structure used for each component.
    pub covariance_type: CovarianceType,
    /// Maximum number of EM iterations per initialization.
    pub max_iter: usize,
    /// Convergence threshold on the change in log-likelihood between iterations.
    pub tol: f32,
    /// Regularization added to covariance diagonals for numerical stability.
    pub reg_covar: f32,
    /// Number of random restarts; the best run (by log-likelihood) is kept.
    pub n_init: usize,

    // Fitted parameters: mixing weights, per-component means, covariances
    // (storage layout depends on covariance_type), and whether EM converged.
    weights: Vec<f32>, means: Vec<Vec<f32>>, covariances: Vec<Vec<f32>>, converged: bool,
}
25
/// How component covariances are parameterized:
/// `Full` (per-component matrix), `Tied` (one matrix shared by all components),
/// `Diag` (per-component diagonal), `Spherical` (per-component scalar variance).
#[derive(Clone, Copy)]
pub enum CovarianceType {
    Full, Tied, Diag, Spherical, }
33
34impl GaussianMixture {
35 pub fn new(n_components: usize) -> Self {
36 Self {
37 n_components,
38 covariance_type: CovarianceType::Full,
39 max_iter: 100,
40 tol: 1e-3,
41 reg_covar: 1e-6,
42 n_init: 1,
43 weights: Vec::new(),
44 means: Vec::new(),
45 covariances: Vec::new(),
46 converged: false,
47 }
48 }
49
50 pub fn covariance_type(mut self, cov_type: CovarianceType) -> Self {
51 self.covariance_type = cov_type;
52 self
53 }
54
55 pub fn max_iter(mut self, iter: usize) -> Self {
56 self.max_iter = iter;
57 self
58 }
59
60 pub fn tol(mut self, tolerance: f32) -> Self {
61 self.tol = tolerance;
62 self
63 }
64
65 pub fn fit(&mut self, x: &Tensor) {
67 let n_samples = x.dims()[0];
68 let n_features = x.dims()[1];
69 let x_data = x.data_f32();
70
71 let mut best_log_likelihood = f32::NEG_INFINITY;
72 let mut best_weights = Vec::new();
73 let mut best_means = Vec::new();
74 let mut best_covariances = Vec::new();
75
76 for _ in 0..self.n_init {
78 self.initialize_parameters(&x_data, n_samples, n_features);
80
81 let mut prev_log_likelihood = f32::NEG_INFINITY;
82
83 for _iteration in 0..self.max_iter {
85 let responsibilities = self.e_step(&x_data, n_samples, n_features);
87
88 self.m_step(&x_data, &responsibilities, n_samples, n_features);
90
91 let log_likelihood = self.compute_log_likelihood(&x_data, n_samples, n_features);
93
94 if (log_likelihood - prev_log_likelihood).abs() < self.tol {
96 self.converged = true;
97 break;
98 }
99
100 prev_log_likelihood = log_likelihood;
101 }
102
103 let final_log_likelihood = self.compute_log_likelihood(&x_data, n_samples, n_features);
105 if final_log_likelihood > best_log_likelihood {
106 best_log_likelihood = final_log_likelihood;
107 best_weights = self.weights.clone();
108 best_means = self.means.clone();
109 best_covariances = self.covariances.clone();
110 }
111 }
112
113 self.weights = best_weights;
115 self.means = best_means;
116 self.covariances = best_covariances;
117 }
118
119 fn initialize_parameters(&mut self, x_data: &[f32], n_samples: usize, n_features: usize) {
121 let mut rng = thread_rng();
122
123 self.weights = vec![1.0 / self.n_components as f32; self.n_components];
125
126 self.means = Vec::with_capacity(self.n_components);
128
129 let first_idx = rng.gen_range(0..n_samples);
131 self.means.push(x_data[first_idx * n_features..(first_idx + 1) * n_features].to_vec());
132
133 for _ in 1..self.n_components {
135 let mut distances = vec![f32::MAX; n_samples];
136
137 for i in 0..n_samples {
138 let sample = &x_data[i * n_features..(i + 1) * n_features];
139 let min_dist = self.means.iter()
140 .map(|mean| {
141 sample.iter().zip(mean.iter())
142 .map(|(x, m)| (x - m).powi(2))
143 .sum::<f32>()
144 })
145 .min_by(|a, b| a.partial_cmp(b).unwrap())
146 .unwrap();
147 distances[i] = min_dist;
148 }
149
150 let total_dist: f32 = distances.iter().sum();
152 let mut cumsum = 0.0;
153 let rand_val = rng.gen::<f32>() * total_dist;
154
155 let mut selected_idx = 0;
156 for (i, &dist) in distances.iter().enumerate() {
157 cumsum += dist;
158 if cumsum >= rand_val {
159 selected_idx = i;
160 break;
161 }
162 }
163
164 self.means.push(x_data[selected_idx * n_features..(selected_idx + 1) * n_features].to_vec());
165 }
166
167 self.covariances = match self.covariance_type {
169 CovarianceType::Full => {
170 (0..self.n_components)
171 .map(|_| {
172 let mut cov = vec![0.0; n_features * n_features];
173 for i in 0..n_features {
174 cov[i * n_features + i] = 1.0;
175 }
176 cov
177 })
178 .collect()
179 }
180 CovarianceType::Diag => {
181 (0..self.n_components)
182 .map(|_| vec![1.0; n_features])
183 .collect()
184 }
185 CovarianceType::Spherical => {
186 (0..self.n_components)
187 .map(|_| vec![1.0])
188 .collect()
189 }
190 CovarianceType::Tied => {
191 let mut cov = vec![0.0; n_features * n_features];
192 for i in 0..n_features {
193 cov[i * n_features + i] = 1.0;
194 }
195 vec![cov]
196 }
197 };
198 }
199
200 fn e_step(&self, x_data: &[f32], n_samples: usize, n_features: usize) -> Vec<Vec<f32>> {
202 let mut responsibilities = vec![vec![0.0; self.n_components]; n_samples];
203
204 for i in 0..n_samples {
205 let sample = &x_data[i * n_features..(i + 1) * n_features];
206 let mut total = 0.0;
207
208 for k in 0..self.n_components {
209 let prob = self.weights[k] * self.gaussian_pdf(sample, k, n_features);
210 responsibilities[i][k] = prob;
211 total += prob;
212 }
213
214 if total > 0.0 {
216 for k in 0..self.n_components {
217 responsibilities[i][k] /= total;
218 }
219 }
220 }
221
222 responsibilities
223 }
224
225 fn m_step(&mut self, x_data: &[f32], responsibilities: &[Vec<f32>], n_samples: usize, n_features: usize) {
227 for k in 0..self.n_components {
229 let n_k: f32 = responsibilities.iter().map(|r| r[k]).sum();
230 self.weights[k] = n_k / n_samples as f32;
231
232 let mut new_mean = vec![0.0; n_features];
234 for i in 0..n_samples {
235 let sample = &x_data[i * n_features..(i + 1) * n_features];
236 for j in 0..n_features {
237 new_mean[j] += responsibilities[i][k] * sample[j];
238 }
239 }
240 for j in 0..n_features {
241 new_mean[j] /= n_k;
242 }
243 self.means[k] = new_mean;
244
245 match self.covariance_type {
247 CovarianceType::Diag => {
248 let mut new_cov = vec![0.0; n_features];
249 for i in 0..n_samples {
250 let sample = &x_data[i * n_features..(i + 1) * n_features];
251 for j in 0..n_features {
252 let diff = sample[j] - self.means[k][j];
253 new_cov[j] += responsibilities[i][k] * diff * diff;
254 }
255 }
256 for j in 0..n_features {
257 new_cov[j] = (new_cov[j] / n_k) + self.reg_covar;
258 }
259 self.covariances[k] = new_cov;
260 }
261 CovarianceType::Spherical => {
262 let mut variance = 0.0;
263 for i in 0..n_samples {
264 let sample = &x_data[i * n_features..(i + 1) * n_features];
265 for j in 0..n_features {
266 let diff = sample[j] - self.means[k][j];
267 variance += responsibilities[i][k] * diff * diff;
268 }
269 }
270 variance = (variance / (n_k * n_features as f32)) + self.reg_covar;
271 self.covariances[k] = vec![variance];
272 }
273 _ => {
274 let mut new_cov = vec![0.0; n_features];
276 for i in 0..n_samples {
277 let sample = &x_data[i * n_features..(i + 1) * n_features];
278 for j in 0..n_features {
279 let diff = sample[j] - self.means[k][j];
280 new_cov[j] += responsibilities[i][k] * diff * diff;
281 }
282 }
283 for j in 0..n_features {
284 new_cov[j] = (new_cov[j] / n_k) + self.reg_covar;
285 }
286 self.covariances[k] = new_cov;
287 }
288 }
289 }
290 }
291
292 fn gaussian_pdf(&self, sample: &[f32], component: usize, n_features: usize) -> f32 {
294 let mean = &self.means[component];
295 let cov = &self.covariances[component];
296
297 match self.covariance_type {
298 CovarianceType::Diag | CovarianceType::Full => {
299 let mut exponent = 0.0;
300 let mut det = 1.0;
301
302 for i in 0..n_features {
303 let diff = sample[i] - mean[i];
304 exponent += diff * diff / cov[i];
305 det *= cov[i];
306 }
307
308 let norm = 1.0 / ((2.0 * PI).powf(n_features as f32 / 2.0) * det.sqrt());
309 norm * (-0.5 * exponent).exp()
310 }
311 CovarianceType::Spherical => {
312 let variance = cov[0];
313 let mut exponent = 0.0;
314
315 for i in 0..n_features {
316 let diff = sample[i] - mean[i];
317 exponent += diff * diff;
318 }
319
320 let norm = 1.0 / ((2.0 * PI * variance).powf(n_features as f32 / 2.0));
321 norm * (-exponent / (2.0 * variance)).exp()
322 }
323 CovarianceType::Tied => {
324 let mut exponent = 0.0;
326 let mut det = 1.0;
327
328 for i in 0..n_features {
329 let diff = sample[i] - mean[i];
330 let var = if component == 0 { cov[i] } else { self.covariances[0][i] };
331 exponent += diff * diff / var;
332 det *= var;
333 }
334
335 let norm = 1.0 / ((2.0 * PI).powf(n_features as f32 / 2.0) * det.sqrt());
336 norm * (-0.5 * exponent).exp()
337 }
338 }
339 }
340
341 fn compute_log_likelihood(&self, x_data: &[f32], n_samples: usize, n_features: usize) -> f32 {
343 let mut log_likelihood = 0.0;
344
345 for i in 0..n_samples {
346 let sample = &x_data[i * n_features..(i + 1) * n_features];
347 let mut prob = 0.0;
348
349 for k in 0..self.n_components {
350 prob += self.weights[k] * self.gaussian_pdf(sample, k, n_features);
351 }
352
353 log_likelihood += prob.max(1e-10).ln();
354 }
355
356 log_likelihood
357 }
358
359 pub fn predict(&self, x: &Tensor) -> Tensor {
361 let n_samples = x.dims()[0];
362 let n_features = x.dims()[1];
363 let x_data = x.data_f32();
364
365 let labels: Vec<f32> = (0..n_samples)
366 .map(|i| {
367 let sample = &x_data[i * n_features..(i + 1) * n_features];
368 let mut max_prob = 0.0;
369 let mut best_component = 0;
370
371 for k in 0..self.n_components {
372 let prob = self.weights[k] * self.gaussian_pdf(sample, k, n_features);
373 if prob > max_prob {
374 max_prob = prob;
375 best_component = k;
376 }
377 }
378
379 best_component as f32
380 })
381 .collect();
382
383 Tensor::from_slice(&labels, &[n_samples]).unwrap()
384 }
385
386 pub fn predict_proba(&self, x: &Tensor) -> Tensor {
388 let n_samples = x.dims()[0];
389 let n_features = x.dims()[1];
390 let x_data = x.data_f32();
391
392 let mut probabilities = Vec::with_capacity(n_samples * self.n_components);
393
394 for i in 0..n_samples {
395 let sample = &x_data[i * n_features..(i + 1) * n_features];
396 let mut total = 0.0;
397 let mut probs = vec![0.0; self.n_components];
398
399 for k in 0..self.n_components {
400 probs[k] = self.weights[k] * self.gaussian_pdf(sample, k, n_features);
401 total += probs[k];
402 }
403
404 for k in 0..self.n_components {
406 probabilities.push(probs[k] / total);
407 }
408 }
409
410 Tensor::from_slice(&probabilities, &[n_samples, self.n_components]).unwrap()
411 }
412
413 pub fn sample(&self, n_samples: usize) -> Tensor {
415 let mut rng = thread_rng();
416 let n_features = self.means[0].len();
417 let mut samples = Vec::with_capacity(n_samples * n_features);
418
419 for _ in 0..n_samples {
420 let rand_val = rng.gen::<f32>();
422 let mut cumsum = 0.0;
423 let mut component = 0;
424
425 for (k, &weight) in self.weights.iter().enumerate() {
426 cumsum += weight;
427 if cumsum >= rand_val {
428 component = k;
429 break;
430 }
431 }
432
433 let mean = &self.means[component];
435 let cov = &self.covariances[component];
436
437 for j in 0..n_features {
438 let std = match self.covariance_type {
439 CovarianceType::Spherical => cov[0].sqrt(),
440 _ => cov[j].sqrt(),
441 };
442 let sample = mean[j] + rng.gen::<f32>() * std;
443 samples.push(sample);
444 }
445 }
446
447 Tensor::from_slice(&samples, &[n_samples, n_features]).unwrap()
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: fit a 2-component diagonal GMM on two well-separated
    // 2-D clusters and check that predict returns one label per sample.
    #[test]
    fn test_gaussian_mixture() {
        let x = Tensor::from_slice(
            &[
                0.0f32, 0.0,
                0.1, 0.1,
                5.0, 5.0,
                5.1, 5.1,
            ],
            &[4, 2],
        ).unwrap();

        let mut gmm = GaussianMixture::new(2)
            .covariance_type(CovarianceType::Diag)
            .max_iter(50);

        gmm.fit(&x);
        let labels = gmm.predict(&x);

        // One predicted label per input row.
        assert_eq!(labels.dims()[0], 4); }

    // Shape check: predict_proba returns one probability row per sample.
    #[test]
    fn test_gmm_predict_proba() {
        let x = Tensor::from_slice(
            &[0.0f32, 0.0, 1.0, 1.0],
            &[2, 2],
        ).unwrap();

        let mut gmm = GaussianMixture::new(2);
        gmm.fit(&x);
        let proba = gmm.predict_proba(&x);

        assert_eq!(proba.dims()[0], 2); }
}
492
493