//! manifoldb_vector/quantization/training.rs

use crate::distance::DistanceMetric;
use crate::error::VectorError;

#[derive(Debug, Clone)]
pub struct KMeansConfig {
    /// Number of clusters to learn; clamped to the data size at training time.
    pub k: usize,
    /// Upper bound on the number of Lloyd iterations.
    pub max_iterations: usize,
    /// Training stops early once no centroid moves farther than this
    /// between consecutive iterations.
    pub convergence_threshold: f32,
    /// Optional RNG seed for reproducible initialization; `None` falls back
    /// to a clock-derived seed.
    pub seed: Option<u64>,
}
20
21impl Default for KMeansConfig {
22 fn default() -> Self {
23 Self { k: 256, max_iterations: 25, convergence_threshold: 1e-6, seed: None }
24 }
25}
26
27impl KMeansConfig {
28 #[must_use]
30 pub fn new(k: usize) -> Self {
31 Self { k, ..Default::default() }
32 }
33
34 #[must_use]
36 pub const fn with_max_iterations(mut self, iterations: usize) -> Self {
37 self.max_iterations = iterations;
38 self
39 }
40
41 #[must_use]
43 pub const fn with_convergence_threshold(mut self, threshold: f32) -> Self {
44 self.convergence_threshold = threshold;
45 self
46 }
47
48 #[must_use]
50 pub const fn with_seed(mut self, seed: u64) -> Self {
51 self.seed = Some(seed);
52 self
53 }
54}
55
#[derive(Debug, Clone)]
pub struct KMeans {
    /// Learned cluster centers, one `Vec<f32>` per cluster.
    pub centroids: Vec<Vec<f32>>,
    /// Dimensionality of the training vectors (and of every centroid).
    pub dimension: usize,
    /// Number of Lloyd iterations actually executed before convergence
    /// or the iteration cap.
    pub iterations: usize,
    /// Sum of squared point-to-assigned-centroid distances from the last
    /// assignment step.
    pub inertia: f32,
}
68
69impl KMeans {
70 pub fn train(
85 data: &[&[f32]],
86 config: &KMeansConfig,
87 metric: DistanceMetric,
88 ) -> Result<Self, VectorError> {
89 if data.is_empty() {
90 return Err(VectorError::Encoding("cannot train k-means on empty data".to_string()));
91 }
92
93 let dimension = data[0].len();
94 if dimension == 0 {
95 return Err(VectorError::InvalidDimension { expected: 1, actual: 0 });
96 }
97
98 for (i, v) in data.iter().enumerate() {
100 if v.len() != dimension {
101 return Err(VectorError::DimensionMismatch {
102 expected: dimension,
103 actual: v.len(),
104 });
105 }
106 if i > 1000 {
108 break; }
110 }
111
112 let k = config.k.min(data.len());
113 if k == 0 {
114 return Err(VectorError::Encoding("k must be > 0".to_string()));
115 }
116
117 let mut centroids = Self::kmeans_plus_plus_init(data, k, config.seed);
119
120 let mut assignments = vec![0usize; data.len()];
122 let mut iterations = 0;
123 let mut inertia = f32::MAX;
124
125 for _ in 0..config.max_iterations {
126 iterations += 1;
127
128 let new_inertia = Self::assign_clusters(data, ¢roids, &mut assignments, metric);
130
131 let new_centroids = Self::update_centroids(data, &assignments, k, dimension);
133
134 let max_movement = Self::max_centroid_movement(¢roids, &new_centroids, metric);
136 centroids = new_centroids;
137 inertia = new_inertia;
138
139 if max_movement < config.convergence_threshold {
140 break;
141 }
142 }
143
144 Ok(Self { centroids, dimension, iterations, inertia })
145 }
146
147 fn kmeans_plus_plus_init(data: &[&[f32]], k: usize, seed: Option<u64>) -> Vec<Vec<f32>> {
150 let mut rng_state = seed.unwrap_or_else(|| {
151 std::time::SystemTime::now()
152 .duration_since(std::time::UNIX_EPOCH)
153 .map(|d| d.as_nanos() as u64)
154 .unwrap_or(42)
155 });
156
157 let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
158
159 let first_idx = Self::random_index(&mut rng_state, data.len());
161 centroids.push(data[first_idx].to_vec());
162
163 for _ in 1..k {
165 let mut distances: Vec<f32> = Vec::with_capacity(data.len());
166 let mut total_dist = 0.0f32;
167
168 for point in data {
169 let min_dist = centroids
171 .iter()
172 .map(|c| Self::squared_euclidean_distance(point, c))
173 .fold(f32::MAX, f32::min);
174
175 distances.push(min_dist);
176 total_dist += min_dist;
177 }
178
179 if total_dist <= 0.0 {
181 let idx = Self::random_index(&mut rng_state, data.len());
183 centroids.push(data[idx].to_vec());
184 } else {
185 let threshold = Self::random_f32(&mut rng_state) * total_dist;
186 let mut cumsum = 0.0f32;
187 let mut selected_idx = data.len() - 1;
188
189 for (i, &d) in distances.iter().enumerate() {
190 cumsum += d;
191 if cumsum >= threshold {
192 selected_idx = i;
193 break;
194 }
195 }
196
197 centroids.push(data[selected_idx].to_vec());
198 }
199 }
200
201 centroids
202 }
203
204 fn assign_clusters(
207 data: &[&[f32]],
208 centroids: &[Vec<f32>],
209 assignments: &mut [usize],
210 metric: DistanceMetric,
211 ) -> f32 {
212 let mut total_inertia = 0.0f32;
213
214 for (i, point) in data.iter().enumerate() {
215 let mut min_dist = f32::MAX;
216 let mut min_idx = 0;
217
218 for (j, centroid) in centroids.iter().enumerate() {
219 let dist = Self::compute_distance(point, centroid, metric);
220 if dist < min_dist {
221 min_dist = dist;
222 min_idx = j;
223 }
224 }
225
226 assignments[i] = min_idx;
227 total_inertia += min_dist * min_dist;
228 }
229
230 total_inertia
231 }
232
233 fn update_centroids(
235 data: &[&[f32]],
236 assignments: &[usize],
237 k: usize,
238 dimension: usize,
239 ) -> Vec<Vec<f32>> {
240 let mut new_centroids = vec![vec![0.0f32; dimension]; k];
241 let mut counts = vec![0usize; k];
242
243 for (point, &cluster) in data.iter().zip(assignments.iter()) {
245 counts[cluster] += 1;
246 for (j, &val) in point.iter().enumerate() {
247 new_centroids[cluster][j] += val;
248 }
249 }
250
251 for (centroid, &count) in new_centroids.iter_mut().zip(counts.iter()) {
253 if count > 0 {
254 let count_f32 = count as f32;
255 for val in centroid.iter_mut() {
256 *val /= count_f32;
257 }
258 }
259 }
260
261 for (i, centroid) in new_centroids.iter_mut().enumerate() {
263 if counts[i] == 0 && !data.is_empty() {
264 let idx = i % data.len();
266 centroid.copy_from_slice(data[idx]);
267 }
268 }
269
270 new_centroids
271 }
272
273 fn max_centroid_movement(old: &[Vec<f32>], new: &[Vec<f32>], metric: DistanceMetric) -> f32 {
275 old.iter()
276 .zip(new.iter())
277 .map(|(o, n)| Self::compute_distance(o, n, metric))
278 .fold(0.0f32, f32::max)
279 }
280
281 #[inline]
283 fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
284 match metric {
285 DistanceMetric::Euclidean => Self::squared_euclidean_distance(a, b).sqrt(),
286 DistanceMetric::Cosine => {
287 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
288 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
289 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
290 if norm_a == 0.0 || norm_b == 0.0 {
291 1.0
292 } else {
293 1.0 - (dot / (norm_a * norm_b))
294 }
295 }
296 DistanceMetric::DotProduct => {
297 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
298 -dot
299 }
300 DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
301 DistanceMetric::Chebyshev => {
302 a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max)
303 }
304 }
305 }
306
307 #[inline]
309 fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
310 a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
311 }
312
313 #[inline]
315 fn random_u64(state: &mut u64) -> u64 {
316 let mut x = *state;
317 x ^= x << 13;
318 x ^= x >> 7;
319 x ^= x << 17;
320 *state = x;
321 x
322 }
323
324 #[inline]
326 #[allow(clippy::cast_possible_truncation)]
327 fn random_index(state: &mut u64, max: usize) -> usize {
328 (Self::random_u64(state) as usize) % max
329 }
330
331 #[inline]
333 #[allow(clippy::cast_precision_loss)]
334 fn random_f32(state: &mut u64) -> f32 {
335 (Self::random_u64(state) as f64 / u64::MAX as f64) as f32
336 }
337
338 #[must_use]
340 pub fn find_nearest(&self, vector: &[f32], metric: DistanceMetric) -> usize {
341 let mut min_dist = f32::MAX;
342 let mut min_idx = 0;
343
344 for (i, centroid) in self.centroids.iter().enumerate() {
345 let dist = Self::compute_distance(vector, centroid, metric);
346 if dist < min_dist {
347 min_dist = dist;
348 min_idx = i;
349 }
350 }
351
352 min_idx
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each row of `data` as a slice, the shape `KMeans::train` expects.
    fn as_refs(data: &[Vec<f32>]) -> Vec<&[f32]> {
        data.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_kmeans_simple() {
        let points = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.0],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
            vec![10.2, 10.0],
        ];

        let cfg = KMeansConfig::new(2).with_seed(42);
        let model = KMeans::train(&as_refs(&points), &cfg, DistanceMetric::Euclidean).unwrap();

        assert_eq!(model.centroids.len(), 2);
        assert_eq!(model.dimension, 2);

        // One centroid should sit near the origin cluster and one near (10, 10).
        let near_origin = model.centroids.iter().any(|c| c[0] < 5.0);
        let near_ten = model.centroids.iter().any(|c| c[0] > 5.0);
        assert!(near_origin && near_ten);
    }

    #[test]
    fn test_kmeans_single_cluster() {
        let points = vec![vec![1.0, 2.0], vec![1.1, 2.1], vec![0.9, 1.9]];

        let cfg = KMeansConfig::new(1).with_seed(42);
        let model = KMeans::train(&as_refs(&points), &cfg, DistanceMetric::Euclidean).unwrap();

        // A single centroid should converge to (roughly) the data mean.
        assert_eq!(model.centroids.len(), 1);
        assert!((model.centroids[0][0] - 1.0).abs() < 0.2);
        assert!((model.centroids[0][1] - 2.0).abs() < 0.2);
    }

    #[test]
    fn test_kmeans_empty_data() {
        let empty: Vec<&[f32]> = Vec::new();
        let cfg = KMeansConfig::new(2);
        assert!(KMeans::train(&empty, &cfg, DistanceMetric::Euclidean).is_err());
    }

    #[test]
    fn test_kmeans_k_larger_than_data() {
        let points = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let cfg = KMeansConfig::new(10).with_seed(42);
        let model = KMeans::train(&as_refs(&points), &cfg, DistanceMetric::Euclidean).unwrap();

        // k is clamped to the number of points.
        assert_eq!(model.centroids.len(), 2);
    }

    #[test]
    fn test_find_nearest() {
        let points =
            vec![vec![0.0, 0.0], vec![0.1, 0.0], vec![10.0, 10.0], vec![10.1, 10.0]];

        let cfg = KMeansConfig::new(2).with_seed(42);
        let model = KMeans::train(&as_refs(&points), &cfg, DistanceMetric::Euclidean).unwrap();

        // Queries near each cluster must resolve to different centroids.
        let near_origin = model.find_nearest(&[0.05, 0.05], DistanceMetric::Euclidean);
        let near_far = model.find_nearest(&[10.05, 10.05], DistanceMetric::Euclidean);
        assert_ne!(near_origin, near_far);
    }

    #[test]
    fn test_cosine_distance_clustering() {
        let points =
            vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0], vec![0.1, 0.9]];

        let cfg = KMeansConfig::new(2).with_seed(42);
        let model = KMeans::train(&as_refs(&points), &cfg, DistanceMetric::Cosine).unwrap();

        assert_eq!(model.centroids.len(), 2);
    }
}