1use crate::algorithm::{kmeans_double_chunked, predict_labels};
2use crate::config::KMeansConfig;
3use crate::error::KMeansError;
4use ndarray::{Array1, Array2, ArrayView2};
5
6pub struct FastKMeans {
30 config: KMeansConfig,
32
33 d: usize,
35
36 centroids: Option<Array2<f32>>,
38}
39
40impl FastKMeans {
41 pub fn new(d: usize, k: usize) -> Self {
52 assert!(k > 0, "k must be greater than 0");
53
54 Self {
55 config: KMeansConfig::new(k),
56 d,
57 centroids: None,
58 }
59 }
60
61 pub fn with_config(config: KMeansConfig) -> Self {
71 assert!(config.k > 0, "k must be greater than 0");
72
73 Self {
74 d: 0, config,
76 centroids: None,
77 }
78 }
79
80 pub fn train(&mut self, data: &ArrayView2<f32>) -> Result<(), KMeansError> {
94 let n_features = data.ncols();
95
96 if self.d == 0 {
98 self.d = n_features;
99 } else if n_features != self.d {
100 return Err(KMeansError::InvalidDimensions(format!(
101 "Expected {} features, got {}",
102 self.d, n_features
103 )));
104 }
105
106 let result = kmeans_double_chunked(data, &self.config)?;
108
109 self.centroids = Some(result.centroids);
110 Ok(())
111 }
112
113 pub fn fit(&mut self, data: &ArrayView2<f32>) -> Result<&mut Self, KMeansError> {
126 self.train(data)?;
127 Ok(self)
128 }
129
130 pub fn predict(&self, data: &ArrayView2<f32>) -> Result<Array1<i64>, KMeansError> {
148 let centroids = self.centroids.as_ref().ok_or(KMeansError::NotFitted)?;
149
150 let n_features = data.ncols();
151 if n_features != self.d {
152 return Err(KMeansError::InvalidDimensions(format!(
153 "Expected {} features, got {}",
154 self.d, n_features
155 )));
156 }
157
158 let labels = predict_labels(
159 data,
160 ¢roids.view(),
161 self.config.chunk_size_data,
162 self.config.chunk_size_centroids,
163 );
164
165 Ok(labels)
166 }
167
168 pub fn fit_predict(&mut self, data: &ArrayView2<f32>) -> Result<Array1<i64>, KMeansError> {
180 self.train(data)?;
181 self.predict(data)
182 }
183
184 pub fn centroids(&self) -> Option<&Array2<f32>> {
190 self.centroids.as_ref()
191 }
192
193 pub fn k(&self) -> usize {
195 self.config.k
196 }
197
198 pub fn d(&self) -> usize {
200 self.d
201 }
202
203 pub fn config(&self) -> &KMeansConfig {
205 &self.config
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Builds a `rows x cols` matrix of samples drawn uniformly from `[-1, 1)`.
    fn uniform_matrix(rows: usize, cols: usize) -> Array2<f32> {
        Array2::random((rows, cols), Uniform::new(-1.0f32, 1.0))
    }

    /// A freshly constructed model reports its parameters and has no centroids.
    #[test]
    fn test_fastkmeans_new() {
        let model = FastKMeans::new(128, 10);

        assert_eq!(model.k(), 10);
        assert_eq!(model.d(), 128);
        assert!(model.centroids().is_none());
    }

    /// Training stores a centroid matrix with one row per cluster.
    #[test]
    fn test_fastkmeans_train() {
        let samples = uniform_matrix(500, 32);
        let mut model = FastKMeans::new(32, 5);

        model.train(&samples.view()).unwrap();

        assert!(model.centroids().is_some());
        let learned = model.centroids().unwrap();
        assert_eq!(learned.nrows(), 5);
        assert_eq!(learned.ncols(), 32);
    }

    /// `fit` succeeds and leaves the model trained.
    #[test]
    fn test_fastkmeans_fit() {
        let samples = uniform_matrix(500, 32);
        let mut model = FastKMeans::new(32, 5);

        let outcome = model.fit(&samples.view());

        assert!(outcome.is_ok());
        assert!(model.centroids().is_some());
    }

    /// Prediction yields one label per row, each within `0..k`.
    #[test]
    fn test_fastkmeans_predict() {
        let train_samples = uniform_matrix(500, 16);
        let query_samples = uniform_matrix(100, 16);

        let mut model = FastKMeans::new(16, 8);
        model.train(&train_samples.view()).unwrap();

        let assigned = model.predict(&query_samples.view()).unwrap();

        assert_eq!(assigned.len(), 100);
        assert!(assigned.iter().all(|label| (0..8).contains(label)));
    }

    /// `fit_predict` both trains the model and labels the training rows.
    #[test]
    fn test_fastkmeans_fit_predict() {
        let samples = uniform_matrix(300, 8);
        let mut model = FastKMeans::new(8, 4);

        let assigned = model.fit_predict(&samples.view()).unwrap();

        assert_eq!(assigned.len(), 300);
        assert!(model.centroids().is_some());
    }

    /// Predicting with an untrained model reports `NotFitted`.
    #[test]
    fn test_fastkmeans_predict_before_fit() {
        let samples = uniform_matrix(100, 8);
        let model = FastKMeans::new(8, 5);

        let outcome = model.predict(&samples.view());

        assert!(matches!(outcome, Err(KMeansError::NotFitted)));
    }

    /// Predicting on data with the wrong column count reports `InvalidDimensions`.
    #[test]
    fn test_fastkmeans_dimension_mismatch() {
        let train_samples = uniform_matrix(100, 8);
        let wrong_width = uniform_matrix(50, 16);

        let mut model = FastKMeans::new(8, 5);
        model.train(&train_samples.view()).unwrap();

        let outcome = model.predict(&wrong_width.view());

        assert!(matches!(outcome, Err(KMeansError::InvalidDimensions(_))));
    }

    /// Constructing a model with zero clusters panics.
    #[test]
    #[should_panic(expected = "k must be greater than 0")]
    fn test_fastkmeans_k_zero() {
        let _ = FastKMeans::new(8, 0);
    }
}