Skip to main content

flow_clustering/clustering/
kmeans.rs

1//! K-means clustering implementation
2
3use crate::clustering::{ClusteringError, ClusteringResult};
4use linfa::prelude::*;
5use linfa_clustering::KMeans as LinfaKMeans;
6use ndarray::Array2;
7
/// Tunable parameters for a K-means run.
///
/// A baseline configuration is available through the `Default`
/// implementation; individual fields can then be overridden with struct
/// update syntax (`KMeansConfig { n_clusters: 5, ..Default::default() }`).
#[derive(Debug, Clone)]
pub struct KMeansConfig {
    /// How many clusters the data should be partitioned into.
    pub n_clusters: usize,
    /// Upper bound on the number of iterations the solver may run.
    pub max_iterations: usize,
    /// Convergence tolerance handed to the solver.
    pub tolerance: f64,
    /// Optional random seed, intended for reproducible runs.
    ///
    /// NOTE(review): nothing in this file reads this field, so setting it
    /// currently has no visible effect on `KMeans::fit` — confirm whether it
    /// is consumed elsewhere.
    pub seed: Option<u64>,
}
20
21impl Default for KMeansConfig {
22    fn default() -> Self {
23        Self {
24            n_clusters: 2,
25            max_iterations: 300,
26            tolerance: 1e-4,
27            seed: None,
28        }
29    }
30}
31
32/// K-means clustering result
33#[derive(Debug, Clone)]
34pub struct KMeansResult {
35    /// Cluster assignments for each point
36    pub assignments: Vec<usize>,
37    /// Cluster centroids
38    pub centroids: Array2<f64>,
39    /// Number of iterations performed
40    pub iterations: usize,
41    /// Inertia (sum of squared distances to centroids)
42    pub inertia: f64,
43}
44
45/// K-means clustering
46pub struct KMeans;
47
48impl KMeans {
49    /// Fit K-means clustering model to data from raw vectors
50    ///
51    /// Helper function to accept Vec<Vec<f64>> for version compatibility
52    ///
53    /// # Arguments
54    /// * `data_rows` - Input data as rows (n_samples × n_features)
55    /// * `config` - Configuration for K-means
56    ///
57    /// # Returns
58    /// KMeansResult with cluster assignments and centroids
59    pub fn fit_from_rows(
60        data_rows: Vec<Vec<f64>>,
61        config: &KMeansConfig,
62    ) -> ClusteringResult<KMeansResult> {
63        if data_rows.is_empty() {
64            return Err(ClusteringError::EmptyData);
65        }
66        let n_features = data_rows[0].len();
67        let n_samples = data_rows.len();
68
69        // Flatten and create Array2
70        let flat: Vec<f64> = data_rows.into_iter().flatten().collect();
71        let data = Array2::from_shape_vec((n_samples, n_features), flat).map_err(|e| {
72            ClusteringError::ClusteringFailed(format!("Failed to create array: {:?}", e))
73        })?;
74
75        Self::fit(&data, config)
76    }
77
78    /// Perform K-means clustering
79    ///
80    /// # Arguments
81    /// * `data` - Input data matrix (n_samples × n_features)
82    /// * `config` - Configuration for K-means
83    ///
84    /// # Returns
85    /// KMeansResult with cluster assignments and centroids
86    pub fn fit(data: &Array2<f64>, config: &KMeansConfig) -> ClusteringResult<KMeansResult> {
87        if data.nrows() == 0 {
88            return Err(ClusteringError::EmptyData);
89        }
90
91        if data.nrows() < config.n_clusters {
92            return Err(ClusteringError::InsufficientData {
93                min: config.n_clusters,
94                actual: data.nrows(),
95            });
96        }
97
98        // Use linfa-clustering for K-means
99        // Array2 implements Records, so we can create DatasetBase directly
100        // Note: linfa expects data as records (samples × features)
101        // Use DatasetBase::new with empty targets () for unsupervised learning
102        let dataset = DatasetBase::new(data.clone(), ());
103        let model = LinfaKMeans::params(config.n_clusters)
104            .max_n_iterations(config.max_iterations as u64)
105            .tolerance(config.tolerance)
106            .fit(&dataset)
107            .map_err(|e| ClusteringError::ClusteringFailed(format!("{}", e)))?;
108
109        // Extract assignments - KMeans doesn't have predict, use centroids to assign
110        let assignments: Vec<usize> = (0..data.nrows())
111            .map(|i| {
112                let point = data.row(i);
113                let mut min_dist = f64::INFINITY;
114                let mut best_cluster = 0;
115                for (j, centroid) in model.centroids().rows().into_iter().enumerate() {
116                    let dist: f64 = point
117                        .iter()
118                        .zip(centroid.iter())
119                        .map(|(a, b)| (a - b).powi(2))
120                        .sum();
121                    if dist < min_dist {
122                        min_dist = dist;
123                        best_cluster = j;
124                    }
125                }
126                best_cluster
127            })
128            .collect();
129
130        // Extract centroids (convert to Array2<f64>)
131        let centroids = model.centroids().to_owned();
132
133        // Calculate inertia
134        let inertia = Self::calculate_inertia(data, &centroids, &assignments);
135
136        Ok(KMeansResult {
137            assignments,
138            centroids,
139            iterations: config.max_iterations, // linfa doesn't expose n_iterations, use config
140            inertia,
141        })
142    }
143
144    /// Calculate inertia (sum of squared distances to centroids)
145    fn calculate_inertia(
146        data: &Array2<f64>,
147        centroids: &Array2<f64>,
148        assignments: &[usize],
149    ) -> f64 {
150        let mut inertia = 0.0;
151        for (i, assignment) in assignments.iter().enumerate() {
152            let point = data.row(i);
153            let centroid = centroids.row(*assignment);
154            let dist_sq: f64 = point
155                .iter()
156                .zip(centroid.iter())
157                .map(|(a, b)| (a - b).powi(2))
158                .sum();
159            inertia += dist_sq;
160        }
161        inertia
162    }
163}