Skip to main content

flow_utils/clustering/
kmeans.rs

1//! K-means clustering implementation
2
3use crate::clustering::{ClusteringError, ClusteringResult};
4use linfa::prelude::*;
5use linfa_clustering::KMeans as LinfaKMeans;
6use ndarray::Array2;
7
/// Configuration for K-means clustering.
///
/// Construct with [`KMeansConfig::default`] and override individual fields via struct-update
/// syntax, e.g. `KMeansConfig { n_clusters: 5, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct KMeansConfig {
    /// Number of clusters to find.
    ///
    /// `KMeans::fit` rejects inputs with fewer rows than this value
    /// (`ClusteringError::InsufficientData`).
    pub n_clusters: usize,
    /// Upper bound on the number of iterations; forwarded to linfa's `max_n_iterations`.
    pub max_iterations: usize,
    /// Convergence tolerance; forwarded to linfa's `tolerance` setting.
    pub tolerance: f64,
    /// Intended random seed for reproducibility.
    ///
    /// NOTE(review): `KMeans::fit` does not currently read this field, so setting it has no
    /// effect on the clustering.
    pub seed: Option<u64>,
}

impl Default for KMeansConfig {
    /// Two clusters, at most 300 iterations, tolerance `1e-4`, no seed.
    fn default() -> Self {
        Self {
            n_clusters: 2,
            max_iterations: 300,
            tolerance: 1e-4,
            seed: None,
        }
    }
}
31
32/// K-means clustering result
33#[derive(Debug, Clone)]
34pub struct KMeansResult {
35    /// Cluster assignments for each point
36    pub assignments: Vec<usize>,
37    /// Cluster centroids
38    pub centroids: Array2<f64>,
39    /// Number of iterations performed
40    pub iterations: usize,
41    /// Inertia (sum of squared distances to centroids)
42    pub inertia: f64,
43}
44
/// Stateless entry point for K-means clustering.
///
/// All functionality lives in associated functions such as [`KMeans::fit`] and
/// [`KMeans::fit_from_rows`]; the type itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct KMeans;
47
48impl KMeans {
49    /// Fit K-means clustering model to data from raw vectors
50    /// 
51    /// Helper function to accept Vec<Vec<f64>> for version compatibility
52    ///
53    /// # Arguments
54    /// * `data_rows` - Input data as rows (n_samples × n_features)
55    /// * `config` - Configuration for K-means
56    ///
57    /// # Returns
58    /// KMeansResult with cluster assignments and centroids
59    pub fn fit_from_rows(data_rows: Vec<Vec<f64>>, config: &KMeansConfig) -> ClusteringResult<KMeansResult> {
60        if data_rows.is_empty() {
61            return Err(ClusteringError::EmptyData);
62        }
63        let n_features = data_rows[0].len();
64        let n_samples = data_rows.len();
65        
66        // Flatten and create Array2
67        let flat: Vec<f64> = data_rows.into_iter().flatten().collect();
68        let data = Array2::from_shape_vec((n_samples, n_features), flat)
69            .map_err(|e| ClusteringError::ClusteringFailed(format!("Failed to create array: {:?}", e)))?;
70        
71        Self::fit(&data, config)
72    }
73    
74    /// Perform K-means clustering
75    ///
76    /// # Arguments
77    /// * `data` - Input data matrix (n_samples × n_features)
78    /// * `config` - Configuration for K-means
79    ///
80    /// # Returns
81    /// KMeansResult with cluster assignments and centroids
82    pub fn fit(data: &Array2<f64>, config: &KMeansConfig) -> ClusteringResult<KMeansResult> {
83        if data.nrows() == 0 {
84            return Err(ClusteringError::EmptyData);
85        }
86
87        if data.nrows() < config.n_clusters {
88            return Err(ClusteringError::InsufficientData {
89                min: config.n_clusters,
90                actual: data.nrows(),
91            });
92        }
93
94        // Use linfa-clustering for K-means
95        // Array2 implements Records, so we can create DatasetBase directly
96        // Note: linfa expects data as records (samples × features)
97        // Use DatasetBase::new with empty targets () for unsupervised learning
98        let dataset = DatasetBase::new(data.clone(), ());
99        let model = LinfaKMeans::params(config.n_clusters)
100            .max_n_iterations(config.max_iterations as u64)
101            .tolerance(config.tolerance)
102            .fit(&dataset)
103            .map_err(|e| ClusteringError::ClusteringFailed(format!("{}", e)))?;
104
105        // Extract assignments - KMeans doesn't have predict, use centroids to assign
106        let assignments: Vec<usize> = (0..data.nrows())
107            .map(|i| {
108                let point = data.row(i);
109                let mut min_dist = f64::INFINITY;
110                let mut best_cluster = 0;
111                for (j, centroid) in model.centroids().rows().into_iter().enumerate() {
112                    let dist: f64 = point
113                        .iter()
114                        .zip(centroid.iter())
115                        .map(|(a, b)| (a - b).powi(2))
116                        .sum();
117                    if dist < min_dist {
118                        min_dist = dist;
119                        best_cluster = j;
120                    }
121                }
122                best_cluster
123            })
124            .collect();
125
126        // Extract centroids (convert to Array2<f64>)
127        let centroids = model.centroids().to_owned();
128
129        // Calculate inertia
130        let inertia = Self::calculate_inertia(data, &centroids, &assignments);
131
132        Ok(KMeansResult {
133            assignments,
134            centroids,
135            iterations: config.max_iterations, // linfa doesn't expose n_iterations, use config
136            inertia,
137        })
138    }
139
140    /// Calculate inertia (sum of squared distances to centroids)
141    fn calculate_inertia(
142        data: &Array2<f64>,
143        centroids: &Array2<f64>,
144        assignments: &[usize],
145    ) -> f64 {
146        let mut inertia = 0.0;
147        for (i, assignment) in assignments.iter().enumerate() {
148            let point = data.row(i);
149            let centroid = centroids.row(*assignment);
150            let dist_sq: f64 = point
151                .iter()
152                .zip(centroid.iter())
153                .map(|(a, b)| (a - b).powi(2))
154                .sum();
155            inertia += dist_sq;
156        }
157        inertia
158    }
159}