Skip to main content

flow_clustering/clustering/
gmm.rs

1//! Gaussian Mixture Model clustering implementation
2
3use crate::clustering::{ClusteringError, ClusteringResult};
4use linfa::prelude::*;
5use linfa_clustering::GaussianMixtureModel as LinfaGmm;
6use ndarray::Array2;
7
/// Configuration for [`Gmm`] clustering.
///
/// Start from [`GmmConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `GmmConfig { n_components: 3, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct GmmConfig {
    /// Number of mixture components (clusters) to fit.
    ///
    /// `Gmm::fit` rejects input with fewer rows than this.
    pub n_components: usize,
    /// Maximum number of EM iterations; forwarded to linfa's `max_n_iterations`.
    pub max_iterations: usize,
    /// Convergence tolerance; forwarded to linfa's `tolerance`.
    pub tolerance: f64,
    /// Random seed intended for reproducibility.
    ///
    /// NOTE: `Gmm::fit` does not currently forward this value to the underlying
    /// linfa model, so setting it has no effect yet.
    pub seed: Option<u64>,
}

impl Default for GmmConfig {
    /// Two components, at most 100 iterations, tolerance `1e-3`, and no seed.
    fn default() -> Self {
        Self {
            n_components: 2,
            max_iterations: 100,
            tolerance: 1e-3,
            seed: None,
        }
    }
}
31
32/// GMM clustering result
33#[derive(Debug, Clone)]
34pub struct GmmResult {
35    /// Cluster assignments for each point
36    pub assignments: Vec<usize>,
37    /// Component means
38    pub means: Array2<f64>,
39    /// Number of iterations performed
40    pub iterations: usize,
41    /// Log likelihood
42    pub log_likelihood: f64,
43}
44
45/// Gaussian Mixture Model clustering
46pub struct Gmm;
47
48impl Gmm {
49    /// Fit GMM clustering model to data from raw vectors
50    ///
51    /// Helper function to accept Vec<Vec<f64>> for version compatibility
52    ///
53    /// # Arguments
54    /// * `data_rows` - Input data as rows (n_samples × n_features)
55    /// * `config` - Configuration for GMM
56    ///
57    /// # Returns
58    /// GmmResult with component assignments and means
59    pub fn fit_from_rows(
60        data_rows: Vec<Vec<f64>>,
61        config: &GmmConfig,
62    ) -> ClusteringResult<GmmResult> {
63        if data_rows.is_empty() {
64            return Err(ClusteringError::EmptyData);
65        }
66        let n_features = data_rows[0].len();
67        let n_samples = data_rows.len();
68
69        // Flatten and create Array2
70        let flat: Vec<f64> = data_rows.into_iter().flatten().collect();
71        let data = Array2::from_shape_vec((n_samples, n_features), flat).map_err(|e| {
72            ClusteringError::ClusteringFailed(format!("Failed to create array: {:?}", e))
73        })?;
74
75        Self::fit(&data, config)
76    }
77
78    /// Perform GMM clustering
79    ///
80    /// # Arguments
81    /// * `data` - Input data matrix (n_samples × n_features)
82    /// * `config` - Configuration for GMM
83    ///
84    /// # Returns
85    /// GmmResult with cluster assignments and means
86    pub fn fit(data: &Array2<f64>, config: &GmmConfig) -> ClusteringResult<GmmResult> {
87        if data.nrows() == 0 {
88            return Err(ClusteringError::EmptyData);
89        }
90
91        if data.nrows() < config.n_components {
92            return Err(ClusteringError::InsufficientData {
93                min: config.n_components,
94                actual: data.nrows(),
95            });
96        }
97
98        // Use linfa-clustering for GMM
99        // Use DatasetBase::new with empty targets () for unsupervised learning
100        let dataset = DatasetBase::new(data.clone(), ());
101        let model = LinfaGmm::params(config.n_components)
102            .max_n_iterations(config.max_iterations as u64)
103            .tolerance(config.tolerance)
104            .fit(&dataset)
105            .map_err(|e| ClusteringError::ClusteringFailed(format!("{}", e)))?;
106
107        // Extract assignments (hard assignment: most likely component)
108        // GMM predict returns probabilities, we need to find argmax for each point
109        let assignments: Vec<usize> = (0..data.nrows())
110            .map(|i| {
111                let point = data.row(i);
112                let mut max_prob = f64::NEG_INFINITY;
113                let mut best_component = 0;
114                // Calculate probability for each component (simplified - use means distance)
115                for (j, mean) in model.means().rows().into_iter().enumerate() {
116                    let dist: f64 = point
117                        .iter()
118                        .zip(mean.iter())
119                        .map(|(a, b)| (a - b).powi(2))
120                        .sum();
121                    let prob = (-dist).exp(); // Simplified probability
122                    if prob > max_prob {
123                        max_prob = prob;
124                        best_component = j;
125                    }
126                }
127                best_component
128            })
129            .collect();
130
131        // Extract means (convert to Array2<f64>)
132        let means = model.means().to_owned();
133
134        Ok(GmmResult {
135            assignments,
136            means,
137            iterations: config.max_iterations, // linfa doesn't expose n_iterations
138            log_likelihood: 0.0,               // linfa doesn't expose log_likelihood directly
139        })
140    }
141}