Skip to main content

flow_utils/clustering/
gmm.rs

1//! Gaussian Mixture Model clustering implementation
2
3use crate::clustering::{ClusteringError, ClusteringResult};
4use linfa::prelude::*;
5use linfa_clustering::GaussianMixtureModel as LinfaGmm;
6use ndarray::Array2;
7
/// Configuration for [`Gmm`] clustering.
///
/// Start from [`GmmConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `GmmConfig { n_components: 3, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct GmmConfig {
    /// Number of mixture components (clusters) to fit.
    pub n_components: usize,
    /// Upper bound on EM iterations; passed to linfa's `max_n_iterations`.
    pub max_iterations: usize,
    /// Convergence tolerance; passed to linfa's `tolerance`.
    pub tolerance: f64,
    /// Intended random seed for reproducible fits.
    ///
    /// NOTE: `Gmm::fit` does not currently read this field, so setting it has no
    /// effect on the fitted model.
    pub seed: Option<u64>,
}

impl Default for GmmConfig {
    /// Two components, at most 100 iterations, tolerance `1e-3`, and no seed.
    fn default() -> Self {
        Self {
            n_components: 2,
            max_iterations: 100,
            tolerance: 1e-3,
            seed: None,
        }
    }
}
31
32/// GMM clustering result
33#[derive(Debug)]
34pub struct GmmResult {
35    /// Cluster assignments for each point
36    pub assignments: Vec<usize>,
37    /// Component means
38    pub means: Array2<f64>,
39    /// Number of iterations performed
40    pub iterations: usize,
41    /// Log likelihood
42    pub log_likelihood: f64,
43}
44
/// Gaussian Mixture Model clustering.
///
/// Stateless entry point: all behavior lives in the associated functions
/// [`Gmm::fit`] and [`Gmm::fit_from_rows`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Gmm;
47
48impl Gmm {
49    /// Fit GMM clustering model to data from raw vectors
50    /// 
51    /// Helper function to accept Vec<Vec<f64>> for version compatibility
52    ///
53    /// # Arguments
54    /// * `data_rows` - Input data as rows (n_samples × n_features)
55    /// * `config` - Configuration for GMM
56    ///
57    /// # Returns
58    /// GmmResult with component assignments and means
59    pub fn fit_from_rows(data_rows: Vec<Vec<f64>>, config: &GmmConfig) -> ClusteringResult<GmmResult> {
60        if data_rows.is_empty() {
61            return Err(ClusteringError::EmptyData);
62        }
63        let n_features = data_rows[0].len();
64        let n_samples = data_rows.len();
65        
66        // Flatten and create Array2
67        let flat: Vec<f64> = data_rows.into_iter().flatten().collect();
68        let data = Array2::from_shape_vec((n_samples, n_features), flat)
69            .map_err(|e| ClusteringError::ClusteringFailed(format!("Failed to create array: {:?}", e)))?;
70        
71        Self::fit(&data, config)
72    }
73    
74    /// Perform GMM clustering
75    ///
76    /// # Arguments
77    /// * `data` - Input data matrix (n_samples × n_features)
78    /// * `config` - Configuration for GMM
79    ///
80    /// # Returns
81    /// GmmResult with cluster assignments and means
82    pub fn fit(data: &Array2<f64>, config: &GmmConfig) -> ClusteringResult<GmmResult> {
83        if data.nrows() == 0 {
84            return Err(ClusteringError::EmptyData);
85        }
86
87        if data.nrows() < config.n_components {
88            return Err(ClusteringError::InsufficientData {
89                min: config.n_components,
90                actual: data.nrows(),
91            });
92        }
93
94        // Use linfa-clustering for GMM
95        // Use DatasetBase::new with empty targets () for unsupervised learning
96        let dataset = DatasetBase::new(data.clone(), ());
97        let model = LinfaGmm::params(config.n_components)
98            .max_n_iterations(config.max_iterations as u64)
99            .tolerance(config.tolerance)
100            .fit(&dataset)
101            .map_err(|e| ClusteringError::ClusteringFailed(format!("{}", e)))?;
102
103        // Extract assignments (hard assignment: most likely component)
104        // GMM predict returns probabilities, we need to find argmax for each point
105        let assignments: Vec<usize> = (0..data.nrows())
106            .map(|i| {
107                let point = data.row(i);
108                let mut max_prob = f64::NEG_INFINITY;
109                let mut best_component = 0;
110                // Calculate probability for each component (simplified - use means distance)
111                for (j, mean) in model.means().rows().into_iter().enumerate() {
112                    let dist: f64 = point
113                        .iter()
114                        .zip(mean.iter())
115                        .map(|(a, b)| (a - b).powi(2))
116                        .sum();
117                    let prob = (-dist).exp(); // Simplified probability
118                    if prob > max_prob {
119                        max_prob = prob;
120                        best_component = j;
121                    }
122                }
123                best_component
124            })
125            .collect();
126
127        // Extract means (convert to Array2<f64>)
128        let means = model.means().to_owned();
129
130        Ok(GmmResult {
131            assignments,
132            means,
133            iterations: config.max_iterations, // linfa doesn't expose n_iterations
134            log_likelihood: 0.0, // linfa doesn't expose log_likelihood directly
135        })
136    }
137}