manifoldb_vector/quantization/
training.rs

1//! K-means clustering for codebook training.
2//!
3//! This module provides k-means clustering used to train Product Quantization codebooks.
4
5use crate::distance::DistanceMetric;
6use crate::error::VectorError;
7
/// Tunable parameters for a k-means clustering run.
#[derive(Debug, Clone)]
pub struct KMeansConfig {
    /// How many clusters (centroids) to fit.
    pub k: usize,
    /// Upper bound on the number of iterations.
    pub max_iterations: usize,
    /// Training stops once centroid movement drops below this value.
    pub convergence_threshold: f32,
    /// Optional random seed; set it to make runs reproducible.
    pub seed: Option<u64>,
}

impl Default for KMeansConfig {
    /// Defaults: 256 clusters, at most 25 iterations, threshold `1e-6`, no seed.
    fn default() -> Self {
        Self {
            k: 256,
            max_iterations: 25,
            convergence_threshold: 1e-6,
            seed: None,
        }
    }
}

impl KMeansConfig {
    /// Build a configuration for `k` clusters, leaving every other field at its default.
    #[must_use]
    pub fn new(k: usize) -> Self {
        let defaults = Self::default();
        Self { k, ..defaults }
    }

    /// Return the configuration with `max_iterations` replaced by `iterations`.
    #[must_use]
    pub const fn with_max_iterations(self, iterations: usize) -> Self {
        Self { max_iterations: iterations, ..self }
    }

    /// Return the configuration with `convergence_threshold` replaced by `threshold`.
    #[must_use]
    pub const fn with_convergence_threshold(self, threshold: f32) -> Self {
        Self { convergence_threshold: threshold, ..self }
    }

    /// Return the configuration with the random seed set to `seed`.
    #[must_use]
    pub const fn with_seed(self, seed: u64) -> Self {
        Self { seed: Some(seed), ..self }
    }
}
55
56/// K-means clustering result.
57#[derive(Debug, Clone)]
58pub struct KMeans {
59    /// The cluster centroids.
60    pub centroids: Vec<Vec<f32>>,
61    /// The dimension of each centroid.
62    pub dimension: usize,
63    /// Number of iterations run.
64    pub iterations: usize,
65    /// Final inertia (sum of squared distances to nearest centroid).
66    pub inertia: f32,
67}
68
69impl KMeans {
70    /// Train k-means on the given data.
71    ///
72    /// # Arguments
73    ///
74    /// - `data`: Training vectors, each with the same dimension
75    /// - `config`: K-means configuration
76    /// - `metric`: Distance metric for clustering
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if:
81    /// - `data` is empty
82    /// - Vectors have inconsistent dimensions
83    /// - `k` is greater than the number of data points
84    pub fn train(
85        data: &[&[f32]],
86        config: &KMeansConfig,
87        metric: DistanceMetric,
88    ) -> Result<Self, VectorError> {
89        if data.is_empty() {
90            return Err(VectorError::Encoding("cannot train k-means on empty data".to_string()));
91        }
92
93        let dimension = data[0].len();
94        if dimension == 0 {
95            return Err(VectorError::InvalidDimension { expected: 1, actual: 0 });
96        }
97
98        // Validate all vectors have same dimension
99        for (i, v) in data.iter().enumerate() {
100            if v.len() != dimension {
101                return Err(VectorError::DimensionMismatch {
102                    expected: dimension,
103                    actual: v.len(),
104                });
105            }
106            // Skip NaN check for training data - assume pre-validated
107            if i > 1000 {
108                break; // Only check first 1000 for performance
109            }
110        }
111
112        let k = config.k.min(data.len());
113        if k == 0 {
114            return Err(VectorError::Encoding("k must be > 0".to_string()));
115        }
116
117        // Initialize centroids using k-means++
118        let mut centroids = Self::kmeans_plus_plus_init(data, k, config.seed);
119
120        // Run iterations
121        let mut assignments = vec![0usize; data.len()];
122        let mut iterations = 0;
123        let mut inertia = f32::MAX;
124
125        for _ in 0..config.max_iterations {
126            iterations += 1;
127
128            // E-step: Assign each point to nearest centroid
129            let new_inertia = Self::assign_clusters(data, &centroids, &mut assignments, metric);
130
131            // M-step: Update centroids
132            let new_centroids = Self::update_centroids(data, &assignments, k, dimension);
133
134            // Check for convergence
135            let max_movement = Self::max_centroid_movement(&centroids, &new_centroids, metric);
136            centroids = new_centroids;
137            inertia = new_inertia;
138
139            if max_movement < config.convergence_threshold {
140                break;
141            }
142        }
143
144        Ok(Self { centroids, dimension, iterations, inertia })
145    }
146
147    /// K-means++ initialization: select initial centroids with probability
148    /// proportional to squared distance from existing centroids.
149    fn kmeans_plus_plus_init(data: &[&[f32]], k: usize, seed: Option<u64>) -> Vec<Vec<f32>> {
150        let mut rng_state = seed.unwrap_or_else(|| {
151            std::time::SystemTime::now()
152                .duration_since(std::time::UNIX_EPOCH)
153                .map(|d| d.as_nanos() as u64)
154                .unwrap_or(42)
155        });
156
157        let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
158
159        // First centroid: random point
160        let first_idx = Self::random_index(&mut rng_state, data.len());
161        centroids.push(data[first_idx].to_vec());
162
163        // Remaining centroids: probability proportional to D(x)^2
164        for _ in 1..k {
165            let mut distances: Vec<f32> = Vec::with_capacity(data.len());
166            let mut total_dist = 0.0f32;
167
168            for point in data {
169                // Find minimum distance to any existing centroid
170                let min_dist = centroids
171                    .iter()
172                    .map(|c| Self::squared_euclidean_distance(point, c))
173                    .fold(f32::MAX, f32::min);
174
175                distances.push(min_dist);
176                total_dist += min_dist;
177            }
178
179            // Sample proportional to distances
180            if total_dist <= 0.0 {
181                // All points are at existing centroids, just pick random
182                let idx = Self::random_index(&mut rng_state, data.len());
183                centroids.push(data[idx].to_vec());
184            } else {
185                let threshold = Self::random_f32(&mut rng_state) * total_dist;
186                let mut cumsum = 0.0f32;
187                let mut selected_idx = data.len() - 1;
188
189                for (i, &d) in distances.iter().enumerate() {
190                    cumsum += d;
191                    if cumsum >= threshold {
192                        selected_idx = i;
193                        break;
194                    }
195                }
196
197                centroids.push(data[selected_idx].to_vec());
198            }
199        }
200
201        centroids
202    }
203
204    /// Assign each data point to its nearest centroid.
205    /// Returns the total inertia (sum of squared distances).
206    fn assign_clusters(
207        data: &[&[f32]],
208        centroids: &[Vec<f32>],
209        assignments: &mut [usize],
210        metric: DistanceMetric,
211    ) -> f32 {
212        let mut total_inertia = 0.0f32;
213
214        for (i, point) in data.iter().enumerate() {
215            let mut min_dist = f32::MAX;
216            let mut min_idx = 0;
217
218            for (j, centroid) in centroids.iter().enumerate() {
219                let dist = Self::compute_distance(point, centroid, metric);
220                if dist < min_dist {
221                    min_dist = dist;
222                    min_idx = j;
223                }
224            }
225
226            assignments[i] = min_idx;
227            total_inertia += min_dist * min_dist;
228        }
229
230        total_inertia
231    }
232
233    /// Update centroids based on current assignments.
234    fn update_centroids(
235        data: &[&[f32]],
236        assignments: &[usize],
237        k: usize,
238        dimension: usize,
239    ) -> Vec<Vec<f32>> {
240        let mut new_centroids = vec![vec![0.0f32; dimension]; k];
241        let mut counts = vec![0usize; k];
242
243        // Sum points per cluster
244        for (point, &cluster) in data.iter().zip(assignments.iter()) {
245            counts[cluster] += 1;
246            for (j, &val) in point.iter().enumerate() {
247                new_centroids[cluster][j] += val;
248            }
249        }
250
251        // Divide by count to get mean
252        for (centroid, &count) in new_centroids.iter_mut().zip(counts.iter()) {
253            if count > 0 {
254                let count_f32 = count as f32;
255                for val in centroid.iter_mut() {
256                    *val /= count_f32;
257                }
258            }
259        }
260
261        // Handle empty clusters by reinitializing to random data point
262        for (i, centroid) in new_centroids.iter_mut().enumerate() {
263            if counts[i] == 0 && !data.is_empty() {
264                // Copy a random data point
265                let idx = i % data.len();
266                centroid.copy_from_slice(data[idx]);
267            }
268        }
269
270        new_centroids
271    }
272
273    /// Compute maximum centroid movement between iterations.
274    fn max_centroid_movement(old: &[Vec<f32>], new: &[Vec<f32>], metric: DistanceMetric) -> f32 {
275        old.iter()
276            .zip(new.iter())
277            .map(|(o, n)| Self::compute_distance(o, n, metric))
278            .fold(0.0f32, f32::max)
279    }
280
281    /// Compute distance between two vectors.
282    #[inline]
283    fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
284        match metric {
285            DistanceMetric::Euclidean => Self::squared_euclidean_distance(a, b).sqrt(),
286            DistanceMetric::Cosine => {
287                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
288                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
289                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
290                if norm_a == 0.0 || norm_b == 0.0 {
291                    1.0
292                } else {
293                    1.0 - (dot / (norm_a * norm_b))
294                }
295            }
296            DistanceMetric::DotProduct => {
297                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
298                -dot
299            }
300            DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
301            DistanceMetric::Chebyshev => {
302                a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max)
303            }
304        }
305    }
306
307    /// Squared Euclidean distance (faster, no sqrt).
308    #[inline]
309    fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
310        a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
311    }
312
313    /// Simple xorshift64 PRNG.
314    #[inline]
315    fn random_u64(state: &mut u64) -> u64 {
316        let mut x = *state;
317        x ^= x << 13;
318        x ^= x >> 7;
319        x ^= x << 17;
320        *state = x;
321        x
322    }
323
324    /// Random index in [0, max).
325    #[inline]
326    #[allow(clippy::cast_possible_truncation)]
327    fn random_index(state: &mut u64, max: usize) -> usize {
328        (Self::random_u64(state) as usize) % max
329    }
330
331    /// Random f32 in [0, 1).
332    #[inline]
333    #[allow(clippy::cast_precision_loss)]
334    fn random_f32(state: &mut u64) -> f32 {
335        (Self::random_u64(state) as f64 / u64::MAX as f64) as f32
336    }
337
338    /// Find the index of the nearest centroid to the given vector.
339    #[must_use]
340    pub fn find_nearest(&self, vector: &[f32], metric: DistanceMetric) -> usize {
341        let mut min_dist = f32::MAX;
342        let mut min_idx = 0;
343
344        for (i, centroid) in self.centroids.iter().enumerate() {
345            let dist = Self::compute_distance(vector, centroid, metric);
346            if dist < min_dist {
347                min_dist = dist;
348                min_idx = i;
349            }
350        }
351
352        min_idx
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrow each owned row as a slice, the shape `KMeans::train` expects.
    fn borrow_rows(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_kmeans_simple() {
        // Two well-separated 2D blobs: one around the origin, one around (10, 10)
        let points: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.0],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
            vec![10.2, 10.0],
        ];
        let views = borrow_rows(&points);

        let model =
            KMeans::train(&views, &KMeansConfig::new(2).with_seed(42), DistanceMetric::Euclidean)
                .unwrap();

        assert_eq!(model.centroids.len(), 2);
        assert_eq!(model.dimension, 2);

        // One centroid must sit on each side of x = 5
        let has_low = model.centroids.iter().any(|c| c[0] < 5.0);
        let has_high = model.centroids.iter().any(|c| c[0] > 5.0);
        assert!(has_low && has_high);
    }

    #[test]
    fn test_kmeans_single_cluster() {
        let points: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![1.1, 2.1], vec![0.9, 1.9]];
        let views = borrow_rows(&points);

        let model =
            KMeans::train(&views, &KMeansConfig::new(1).with_seed(42), DistanceMetric::Euclidean)
                .unwrap();

        assert_eq!(model.centroids.len(), 1);

        // The lone centroid should land close to the mean of the points
        let centre = &model.centroids[0];
        assert!((centre[0] - 1.0).abs() < 0.2);
        assert!((centre[1] - 2.0).abs() < 0.2);
    }

    #[test]
    fn test_kmeans_empty_data() {
        let no_points: Vec<&[f32]> = Vec::new();
        let outcome = KMeans::train(&no_points, &KMeansConfig::new(2), DistanceMetric::Euclidean);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_kmeans_k_larger_than_data() {
        let points: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let views = borrow_rows(&points);

        // Ask for 10 clusters with only 2 points available
        let settings = KMeansConfig::new(10).with_seed(42);
        let model = KMeans::train(&views, &settings, DistanceMetric::Euclidean).unwrap();

        // k is capped at the number of points
        assert_eq!(model.centroids.len(), 2);
    }

    #[test]
    fn test_find_nearest() {
        let points: Vec<Vec<f32>> =
            vec![vec![0.0, 0.0], vec![0.1, 0.0], vec![10.0, 10.0], vec![10.1, 10.0]];
        let views = borrow_rows(&points);

        let settings = KMeansConfig::new(2).with_seed(42);
        let model = KMeans::train(&views, &settings, DistanceMetric::Euclidean).unwrap();

        // One query next to each blob
        let near_origin = vec![0.05, 0.05];
        let near_ten = vec![10.05, 10.05];

        let cluster_a = model.find_nearest(&near_origin, DistanceMetric::Euclidean);
        let cluster_b = model.find_nearest(&near_ten, DistanceMetric::Euclidean);

        // The two queries must land in different clusters
        assert_ne!(cluster_a, cluster_b);
    }

    #[test]
    fn test_cosine_distance_clustering() {
        // Two groups pointing along different axes; cosine distance should tell them apart
        let points: Vec<Vec<f32>> =
            vec![vec![1.0, 0.0], vec![0.9, 0.1], vec![0.0, 1.0], vec![0.1, 0.9]];
        let views = borrow_rows(&points);

        let settings = KMeansConfig::new(2).with_seed(42);
        let model = KMeans::train(&views, &settings, DistanceMetric::Cosine).unwrap();

        assert_eq!(model.centroids.len(), 2);
    }
}