gathers/
kmeans.rs

1//! K-means clustering implementation.
2
3use core::panic;
4use std::time::Instant;
5
6use log::debug;
7use rand::Rng;
8use rayon::prelude::*;
9
10use crate::distance::{argmin, neg_dot_product, squared_euclidean, Distance};
11use crate::rabitq::RaBitQ;
12use crate::sampling::subsample;
13use crate::utils::{as_continuous_vec, centroid_residual, normalize};
14
/// Relative perturbation applied to the two halves when an empty cluster is
/// filled by splitting a populated one (keeps the copies from coinciding).
const EPS: f32 = 1.0 / 1024.0;
/// Lower bound on points per centroid used by the auto `n_cluster` heuristic
/// and the "too few samples" validation in `fit`.
const MIN_POINTS_PER_CENTROID: usize = 39;
/// Upper bound on points per centroid; training data beyond
/// `MAX_POINTS_PER_CENTROID * n_cluster` is subsampled before fitting.
const MAX_POINTS_PER_CENTROID: usize = 256;
/// Above this total element count (num * dim) the L2 assignment switches to
/// the approximate RaBitQ path in `fit`.
const LARGE_CLUSTER_THRESHOLD: usize = 1 << 20;
/// Number of vectors handed to each rayon task when assigning in parallel.
const RAYON_BLOCK_SIZE: usize = 1024 * 32;
20
21/// Assign vectors to centroids.
22pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, labels: &mut [u32]) {
23    let mut distances = vec![f32::MAX; centroids.len() / dim];
24
25    match distance {
26        Distance::NegativeDotProduct => {
27            for (i, vec) in vecs.chunks(dim).enumerate() {
28                for (j, centroid) in centroids.chunks(dim).enumerate() {
29                    distances[j] = neg_dot_product(vec, centroid);
30                    if j == 0 || distances[j] < distances[labels[i] as usize] {
31                        labels[i] = j as u32;
32                    }
33                }
34            }
35        }
36        Distance::SquaredEuclidean => {
37            // pre-compute the x**2 & y**2 for L2 distance
38            // let squared_x: Vec<f32> = vecs.chunks(dim).map(l2_norm).collect();
39            // let squared_y: Vec<f32> = centroids.chunks(dim).map(l2_norm).collect();
40
41            labels.copy_from_slice(
42                &vecs
43                    .par_chunks(dim * RAYON_BLOCK_SIZE)
44                    .flat_map(|vec| {
45                        let mut par_labels = vec![0; vec.len() / dim];
46                        let mut par_distances = vec![f32::MAX; centroids.len() / dim];
47                        for (i, v) in vec.chunks(dim).enumerate() {
48                            for (j, centroid) in centroids.chunks(dim).enumerate() {
49                                par_distances[j] = squared_euclidean(v, centroid);
50                            }
51                            par_labels[i] = argmin(&par_distances) as u32;
52                        }
53                        par_labels
54                    })
55                    .collect::<Vec<_>>(),
56            );
57        }
58    }
59}
60
61/// Assign vectors to centroids with RaBitQ.
62///
63/// TODO: support dot product distance
64pub fn rabitq_assign(vecs: &[f32], centroids: &[f32], dim: usize, labels: &mut [u32]) {
65    let rabitq = RaBitQ::new(centroids, dim);
66
67    labels.copy_from_slice(
68        &vecs
69            .par_chunks(dim * RAYON_BLOCK_SIZE)
70            .flat_map(|vec| {
71                vec.chunks(dim)
72                    .map(|v| rabitq.retrieve_top_one(v) as u32)
73                    .collect::<Vec<_>>()
74            })
75            .collect::<Vec<_>>(),
76    );
77
78    let (rough, precise) = rabitq.get_metrics();
79    debug!(
80        "RaBitQ: rough {}, precise {}, ratio: {}",
81        rough,
82        precise,
83        rough as f32 / precise as f32
84    )
85}
86
87/// Update centroids to the mean of assigned vectors.
88pub fn update_centroids(vecs: &[f32], centroids: &mut [f32], dim: usize, labels: &[u32]) -> f32 {
89    let mut means = vec![0.0; centroids.len()];
90    let mut elements = vec![0; centroids.len() / dim];
91    for (i, vec) in vecs.chunks(dim).enumerate() {
92        let label = labels[i] as usize;
93        elements[label] += 1;
94        means[label * dim..(label + 1) * dim]
95            .iter_mut()
96            .zip(vec.iter())
97            .for_each(|(m, &v)| *m += v);
98    }
99    let diff = squared_euclidean(centroids, &means);
100
101    let mut zero_count = 0;
102    for i in 0..elements.len() {
103        if elements[i] == 0 {
104            // need to split another cluster to fill this empty cluster
105            zero_count += 1;
106            let mut target = 0;
107            let mut rng = rand::thread_rng();
108            let base = 1.0 / (vecs.len() / dim - labels.len()) as f32;
109            loop {
110                let p = (elements[target] - 1) as f32 * base;
111                if rng.gen::<f32>() < p {
112                    break;
113                }
114                target = (target + 1) % labels.len();
115            }
116            debug!("split cluster {} to fill empty cluster {}", target, i);
117            if i < target {
118                let (left, right) = centroids.split_at_mut(target * dim);
119                left[i * dim..(i + 1) * dim].copy_from_slice(&right[..dim]);
120            } else {
121                let (left, right) = centroids.split_at_mut(i * dim);
122                right[..dim].copy_from_slice(&left[target * dim..(target + 1) * dim]);
123            }
124            // small symmetric perturbation
125            for j in 0..dim {
126                if j % 2 == 0 {
127                    centroids[i * dim + j] *= 1.0 + EPS;
128                    centroids[target * dim + j] *= 1.0 - EPS;
129                } else {
130                    centroids[i * dim + j] *= 1.0 - EPS;
131                    centroids[target * dim + j] *= 1.0 + EPS;
132                }
133            }
134            // update elements
135            elements[i] = elements[target] / 2;
136            elements[target] -= elements[i];
137        }
138        let divider = (elements[i] as f32).recip();
139        for j in i * dim..(i + 1) * dim {
140            centroids[j] = means[j] * divider;
141        }
142    }
143    if zero_count != 0 {
144        debug!("fixed {} empty clusters", zero_count);
145    }
146    diff
147}
148
/// K-means clustering algorithm.
#[derive(Debug)]
pub struct KMeans {
    /// Number of clusters; when constructed via `Default`, `fit` derives
    /// this from the dataset size instead of using this value.
    n_cluster: u32,
    /// Maximum number of training iterations.
    max_iter: u32,
    /// Convergence tolerance: training stops once the centroid diff
    /// returned by `update_centroids` drops below this value.
    tolerance: f32,
    /// Distance metric used for assignment.
    distance: Distance,
    /// Subtract the global centroid from all vectors before training;
    /// only applied for the squared-Euclidean metric.
    use_residual: bool,
    /// True when built via `Default`; enables the automatic `n_cluster`
    /// heuristic in `fit`.
    use_default_config: bool,
}
159
160impl Default for KMeans {
161    fn default() -> Self {
162        Self {
163            n_cluster: 8,
164            max_iter: 25,
165            tolerance: 1e-4,
166            distance: Distance::default(),
167            use_residual: false,
168            use_default_config: true,
169        }
170    }
171}
172
173impl KMeans {
174    /// Create a new KMeans instance.
175    ///
176    /// # Arguments
177    ///
178    /// * `n_cluster` - number of clusters, recommend to be a number in [sqrt(n) * 4, sqrt(n) * 8]
179    /// * `max_iter` - max number of iterations
180    /// * `tolerance` - convergence tolerance, stop when the diff is less than this value
181    /// * `distance` - distance metric
182    /// * `use_residual` - use residual for more accurate L2 distance computations, only work for L2
183    pub fn new(
184        n_cluster: u32,
185        max_iter: u32,
186        tolerance: f32,
187        distance: Distance,
188        use_residual: bool,
189    ) -> Self {
190        if n_cluster <= 1 {
191            panic!("n_cluster must be greater than 1");
192        }
193        if max_iter <= 1 {
194            panic!("max_iter must be greater than 1");
195        }
196        if tolerance <= 0.0 {
197            panic!("tolerance must be greater than 0.0");
198        }
199        Self {
200            n_cluster,
201            max_iter,
202            tolerance,
203            distance,
204            use_residual,
205            use_default_config: false,
206        }
207    }
208
209    /// Fit the KMeans configurations to the given vectors and return the centroids.
210    pub fn fit(&self, mut vecs: Vec<f32>, dim: usize) -> Vec<f32> {
211        let num = vecs.len() / dim;
212
213        // auto-config the `n_cluster` if it's initialized with `default()`
214        let n_cluster = match self.use_default_config {
215            true => (((num as f32).sqrt() as u32) * 4).min((num / MIN_POINTS_PER_CENTROID) as u32),
216            false => self.n_cluster,
217        };
218        debug!("num of points: {}, num of clusters: {}", num, n_cluster);
219
220        if num < n_cluster as usize {
221            panic!("number of samples must be greater than n_cluster");
222        }
223        if num < n_cluster as usize * MIN_POINTS_PER_CENTROID {
224            panic!("too few samples for n_cluster");
225        }
226
227        // use residual for more accurate L2 distance computations
228        if self.distance == Distance::SquaredEuclidean && self.use_residual {
229            debug!("use residual");
230            centroid_residual(&mut vecs, dim);
231        }
232
233        // subsample
234        if num > MAX_POINTS_PER_CENTROID * n_cluster as usize {
235            let n_sample = MAX_POINTS_PER_CENTROID * n_cluster as usize;
236            debug!("subsample to {} points", n_sample);
237            vecs = as_continuous_vec(&subsample(n_sample, &vecs, dim));
238        }
239
240        let mut centroids = as_continuous_vec(&subsample(n_cluster as usize, &vecs, dim));
241        if self.distance == Distance::NegativeDotProduct {
242            centroids.chunks_mut(dim).for_each(normalize);
243        }
244
245        let mut labels: Vec<u32> = vec![0; num];
246        debug!("start training");
247        for i in 0..self.max_iter {
248            let start_time = Instant::now();
249            if self.distance == Distance::NegativeDotProduct || num * dim <= LARGE_CLUSTER_THRESHOLD
250            {
251                assign(&vecs, &centroids, dim, self.distance, &mut labels);
252            } else {
253                rabitq_assign(&vecs, &centroids, dim, &mut labels);
254            }
255            let diff = update_centroids(&vecs, &mut centroids, dim, &labels);
256            if self.distance == Distance::NegativeDotProduct {
257                centroids.chunks_mut(dim).for_each(normalize);
258            }
259            debug!("iter {} takes {} s", i, start_time.elapsed().as_secs_f32());
260            if diff < self.tolerance {
261                debug!("converged at iter {}", i);
262                break;
263            }
264        }
265
266        centroids
267    }
268}