flow_utils/clustering/
gmm.rs1use crate::clustering::{ClusteringError, ClusteringResult};
4use linfa::prelude::*;
5use linfa_clustering::GaussianMixtureModel as LinfaGmm;
6use ndarray::Array2;
7
/// Settings for fitting a Gaussian mixture model with [`Gmm`].
///
/// `PartialEq` is derived so configurations can be compared directly
/// (for example in tests or when caching fitted models).
#[derive(Debug, Clone, PartialEq)]
pub struct GmmConfig {
    /// Number of mixture components to fit. `Gmm::fit` also requires at least
    /// this many samples.
    pub n_components: usize,
    /// Upper bound on the number of fitting iterations, forwarded to the
    /// linfa GMM builder's `max_n_iterations` setting.
    pub max_iterations: usize,
    /// Forwarded to the linfa GMM builder's `tolerance` setting.
    pub tolerance: f64,
    /// Optional RNG seed.
    ///
    /// NOTE(review): `Gmm::fit` does not currently read this field, so setting
    /// it does not make fits reproducible.
    pub seed: Option<u64>,
}
20
21impl Default for GmmConfig {
22 fn default() -> Self {
23 Self {
24 n_components: 2,
25 max_iterations: 100,
26 tolerance: 1e-3,
27 seed: None,
28 }
29 }
30}
31
32#[derive(Debug)]
34pub struct GmmResult {
35 pub assignments: Vec<usize>,
37 pub means: Array2<f64>,
39 pub iterations: usize,
41 pub log_likelihood: f64,
43}
44
/// Stateless entry point for Gaussian mixture model clustering.
///
/// All functionality lives in associated functions; the type carries no data,
/// so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Gmm;
47
48impl Gmm {
49 pub fn fit_from_rows(data_rows: Vec<Vec<f64>>, config: &GmmConfig) -> ClusteringResult<GmmResult> {
60 if data_rows.is_empty() {
61 return Err(ClusteringError::EmptyData);
62 }
63 let n_features = data_rows[0].len();
64 let n_samples = data_rows.len();
65
66 let flat: Vec<f64> = data_rows.into_iter().flatten().collect();
68 let data = Array2::from_shape_vec((n_samples, n_features), flat)
69 .map_err(|e| ClusteringError::ClusteringFailed(format!("Failed to create array: {:?}", e)))?;
70
71 Self::fit(&data, config)
72 }
73
74 pub fn fit(data: &Array2<f64>, config: &GmmConfig) -> ClusteringResult<GmmResult> {
83 if data.nrows() == 0 {
84 return Err(ClusteringError::EmptyData);
85 }
86
87 if data.nrows() < config.n_components {
88 return Err(ClusteringError::InsufficientData {
89 min: config.n_components,
90 actual: data.nrows(),
91 });
92 }
93
94 let dataset = DatasetBase::new(data.clone(), ());
97 let model = LinfaGmm::params(config.n_components)
98 .max_n_iterations(config.max_iterations as u64)
99 .tolerance(config.tolerance)
100 .fit(&dataset)
101 .map_err(|e| ClusteringError::ClusteringFailed(format!("{}", e)))?;
102
103 let assignments: Vec<usize> = (0..data.nrows())
106 .map(|i| {
107 let point = data.row(i);
108 let mut max_prob = f64::NEG_INFINITY;
109 let mut best_component = 0;
110 for (j, mean) in model.means().rows().into_iter().enumerate() {
112 let dist: f64 = point
113 .iter()
114 .zip(mean.iter())
115 .map(|(a, b)| (a - b).powi(2))
116 .sum();
117 let prob = (-dist).exp(); if prob > max_prob {
119 max_prob = prob;
120 best_component = j;
121 }
122 }
123 best_component
124 })
125 .collect();
126
127 let means = model.means().to_owned();
129
130 Ok(GmmResult {
131 assignments,
132 means,
133 iterations: config.max_iterations, log_likelihood: 0.0, })
136 }
137}