// fastkmeans_rs/kmeans.rs

1use crate::algorithm::{kmeans_double_chunked, predict_labels};
2use crate::config::KMeansConfig;
3use crate::error::KMeansError;
4use ndarray::{Array1, Array2, ArrayView2};
5
6/// Fast k-means clustering implementation compatible with ndarray.
7///
8/// This implementation uses double-chunking to process large datasets efficiently
9/// without running out of memory. It provides an API similar to FAISS and scikit-learn.
10///
11/// # Example
12///
13/// ```
14/// use fastkmeans_rs::FastKMeans;
15/// use ndarray::Array2;
16/// use ndarray_rand::RandomExt;
17/// use ndarray_rand::rand_distr::Uniform;
18///
19/// // Generate random data
20/// let data = Array2::random((1000, 128), Uniform::new(-1.0f32, 1.0));
21///
22/// // Create and train the model
23/// let mut kmeans = FastKMeans::new(128, 10);
24/// kmeans.train(&data.view()).unwrap();
25///
26/// // Get cluster assignments
27/// let labels = kmeans.predict(&data.view()).unwrap();
28/// ```
29pub struct FastKMeans {
30    /// Model configuration
31    config: KMeansConfig,
32
33    /// Number of features (dimensions)
34    d: usize,
35
36    /// Trained centroids (None if not yet fitted)
37    centroids: Option<Array2<f32>>,
38}
39
40impl FastKMeans {
41    /// Create a new FastKMeans instance with default configuration.
42    ///
43    /// # Arguments
44    ///
45    /// * `d` - Number of features (dimensions) in the data
46    /// * `k` - Number of clusters
47    ///
48    /// # Panics
49    ///
50    /// Panics if `k` is 0.
51    pub fn new(d: usize, k: usize) -> Self {
52        assert!(k > 0, "k must be greater than 0");
53
54        Self {
55            config: KMeansConfig::new(k),
56            d,
57            centroids: None,
58        }
59    }
60
61    /// Create a new FastKMeans instance with custom configuration.
62    ///
63    /// # Arguments
64    ///
65    /// * `config` - Custom configuration for the k-means algorithm
66    ///
67    /// # Panics
68    ///
69    /// Panics if `config.k` is 0.
70    pub fn with_config(config: KMeansConfig) -> Self {
71        assert!(config.k > 0, "k must be greater than 0");
72
73        Self {
74            d: 0, // Will be set on first train call
75            config,
76            centroids: None,
77        }
78    }
79
80    /// Train the k-means model on the given data.
81    ///
82    /// This method mimics the FAISS `train()` API.
83    ///
84    /// # Arguments
85    ///
86    /// * `data` - Training data of shape (n_samples, n_features)
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if:
91    /// - Number of samples is less than k
92    /// - Data dimensions don't match (for subsequent calls)
93    pub fn train(&mut self, data: &ArrayView2<f32>) -> Result<(), KMeansError> {
94        let n_features = data.ncols();
95
96        // Set dimensions on first call, validate on subsequent calls
97        if self.d == 0 {
98            self.d = n_features;
99        } else if n_features != self.d {
100            return Err(KMeansError::InvalidDimensions(format!(
101                "Expected {} features, got {}",
102                self.d, n_features
103            )));
104        }
105
106        // Run the k-means algorithm
107        let result = kmeans_double_chunked(data, &self.config)?;
108
109        self.centroids = Some(result.centroids);
110        Ok(())
111    }
112
113    /// Fit the model to the data.
114    ///
115    /// This method mimics the scikit-learn `fit()` API.
116    /// It is equivalent to `train()`.
117    ///
118    /// # Arguments
119    ///
120    /// * `data` - Training data of shape (n_samples, n_features)
121    ///
122    /// # Returns
123    ///
124    /// Returns `&mut Self` for method chaining.
125    pub fn fit(&mut self, data: &ArrayView2<f32>) -> Result<&mut Self, KMeansError> {
126        self.train(data)?;
127        Ok(self)
128    }
129
130    /// Predict cluster assignments for new data.
131    ///
132    /// This method mimics the scikit-learn `predict()` API.
133    ///
134    /// # Arguments
135    ///
136    /// * `data` - Data to predict, of shape (n_samples, n_features)
137    ///
138    /// # Returns
139    ///
140    /// Returns an array of cluster labels of shape (n_samples,).
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if:
145    /// - The model has not been fitted yet
146    /// - Data dimensions don't match the training data
147    pub fn predict(&self, data: &ArrayView2<f32>) -> Result<Array1<i64>, KMeansError> {
148        let centroids = self.centroids.as_ref().ok_or(KMeansError::NotFitted)?;
149
150        let n_features = data.ncols();
151        if n_features != self.d {
152            return Err(KMeansError::InvalidDimensions(format!(
153                "Expected {} features, got {}",
154                self.d, n_features
155            )));
156        }
157
158        let labels = predict_labels(
159            data,
160            &centroids.view(),
161            self.config.chunk_size_data,
162            self.config.chunk_size_centroids,
163        );
164
165        Ok(labels)
166    }
167
168    /// Fit the model and predict cluster assignments in one call.
169    ///
170    /// This method mimics the scikit-learn `fit_predict()` API.
171    ///
172    /// # Arguments
173    ///
174    /// * `data` - Training data of shape (n_samples, n_features)
175    ///
176    /// # Returns
177    ///
178    /// Returns an array of cluster labels of shape (n_samples,).
179    pub fn fit_predict(&mut self, data: &ArrayView2<f32>) -> Result<Array1<i64>, KMeansError> {
180        self.train(data)?;
181        self.predict(data)
182    }
183
184    /// Get the centroids of the fitted model.
185    ///
186    /// # Returns
187    ///
188    /// Returns `Some(&Array2<f32>)` if the model has been fitted, `None` otherwise.
189    pub fn centroids(&self) -> Option<&Array2<f32>> {
190        self.centroids.as_ref()
191    }
192
193    /// Get the number of clusters.
194    pub fn k(&self) -> usize {
195        self.config.k
196    }
197
198    /// Get the number of features (dimensions).
199    pub fn d(&self) -> usize {
200        self.d
201    }
202
203    /// Get the configuration.
204    pub fn config(&self) -> &KMeansConfig {
205        &self.config
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Builds a `rows` x `cols` matrix of random samples from `Uniform::new(-1.0, 1.0)`.
    fn uniform_samples(rows: usize, cols: usize) -> Array2<f32> {
        Array2::random((rows, cols), Uniform::new(-1.0f32, 1.0))
    }

    /// A freshly constructed model reports its parameters and has no centroids.
    #[test]
    fn test_fastkmeans_new() {
        let model = FastKMeans::new(128, 10);

        assert_eq!((model.k(), model.d()), (10, 128));
        assert!(model.centroids().is_none());
    }

    /// Training yields one centroid per cluster, each with the data's width.
    #[test]
    fn test_fastkmeans_train() {
        let samples = uniform_samples(500, 32);
        let mut model = FastKMeans::new(32, 5);

        model.train(&samples.view()).unwrap();

        let fitted = model.centroids();
        assert!(fitted.is_some());
        let fitted = fitted.unwrap();
        assert_eq!((fitted.nrows(), fitted.ncols()), (5, 32));
    }

    /// `fit` succeeds on valid data and leaves centroids behind.
    #[test]
    fn test_fastkmeans_fit() {
        let samples = uniform_samples(500, 32);
        let mut model = FastKMeans::new(32, 5);

        assert!(model.fit(&samples.view()).is_ok());
        assert!(model.centroids().is_some());
    }

    /// Predicted labels cover every input row and stay within `0..k`.
    #[test]
    fn test_fastkmeans_predict() {
        let fit_samples = uniform_samples(500, 16);
        let query_samples = uniform_samples(100, 16);

        let mut model = FastKMeans::new(16, 8);
        model.train(&fit_samples.view()).unwrap();

        let assigned = model.predict(&query_samples.view()).unwrap();
        assert_eq!(assigned.len(), 100);
        assert!(assigned.iter().all(|label| (0..8).contains(label)));
    }

    /// `fit_predict` both trains the model and labels every training row.
    #[test]
    fn test_fastkmeans_fit_predict() {
        let samples = uniform_samples(300, 8);
        let mut model = FastKMeans::new(8, 4);

        let assigned = model.fit_predict(&samples.view()).unwrap();

        assert_eq!(assigned.len(), 300);
        assert!(model.centroids().is_some());
    }

    /// Predicting with an untrained model is reported as `NotFitted`.
    #[test]
    fn test_fastkmeans_predict_before_fit() {
        let samples = uniform_samples(100, 8);
        let model = FastKMeans::new(8, 5);

        assert!(matches!(
            model.predict(&samples.view()),
            Err(KMeansError::NotFitted)
        ));
    }

    /// Predicting on data of a different width is reported as `InvalidDimensions`.
    #[test]
    fn test_fastkmeans_dimension_mismatch() {
        let fit_samples = uniform_samples(100, 8);
        let wider_samples = uniform_samples(50, 16);

        let mut model = FastKMeans::new(8, 5);
        model.train(&fit_samples.view()).unwrap();

        assert!(matches!(
            model.predict(&wider_samples.view()),
            Err(KMeansError::InvalidDimensions(_))
        ));
    }

    /// Constructing a model with zero clusters panics.
    #[test]
    #[should_panic(expected = "k must be greater than 0")]
    fn test_fastkmeans_k_zero() {
        let _ = FastKMeans::new(8, 0);
    }
}