1use core::panic;
4use std::time::Instant;
5
6use log::debug;
7use rand::Rng;
8use rayon::prelude::*;
9
10use crate::distance::{argmin, neg_dot_product, squared_euclidean, Distance};
11use crate::rabitq::RaBitQ;
12use crate::sampling::subsample;
13use crate::utils::{as_continuous_vec, centroid_residual, normalize};
14
/// Relative perturbation applied to both halves when an empty cluster is
/// refilled by splitting a populated one.
const EPS: f32 = 1.0 / 1024.0;
/// Minimum number of training points `fit` requires per centroid.
const MIN_POINTS_PER_CENTROID: usize = 39;
/// `fit` subsamples the training data down to this many points per centroid.
const MAX_POINTS_PER_CENTROID: usize = 256;
/// For squared-Euclidean training, inputs with more than this many `f32`
/// values (`num * dim`) are assigned with RaBitQ instead of exact search.
const LARGE_CLUSTER_THRESHOLD: usize = 1 << 20;
/// Number of vectors handled by one rayon task during assignment.
const RAYON_BLOCK_SIZE: usize = 32 * 1024;
20
21pub fn assign(vecs: &[f32], centroids: &[f32], dim: usize, distance: Distance, labels: &mut [u32]) {
23 let mut distances = vec![f32::MAX; centroids.len() / dim];
24
25 match distance {
26 Distance::NegativeDotProduct => {
27 for (i, vec) in vecs.chunks(dim).enumerate() {
28 for (j, centroid) in centroids.chunks(dim).enumerate() {
29 distances[j] = neg_dot_product(vec, centroid);
30 if j == 0 || distances[j] < distances[labels[i] as usize] {
31 labels[i] = j as u32;
32 }
33 }
34 }
35 }
36 Distance::SquaredEuclidean => {
37 labels.copy_from_slice(
42 &vecs
43 .par_chunks(dim * RAYON_BLOCK_SIZE)
44 .flat_map(|vec| {
45 let mut par_labels = vec![0; vec.len() / dim];
46 let mut par_distances = vec![f32::MAX; centroids.len() / dim];
47 for (i, v) in vec.chunks(dim).enumerate() {
48 for (j, centroid) in centroids.chunks(dim).enumerate() {
49 par_distances[j] = squared_euclidean(v, centroid);
50 }
51 par_labels[i] = argmin(&par_distances) as u32;
52 }
53 par_labels
54 })
55 .collect::<Vec<_>>(),
56 );
57 }
58 }
59}
60
61pub fn rabitq_assign(vecs: &[f32], centroids: &[f32], dim: usize, labels: &mut [u32]) {
65 let rabitq = RaBitQ::new(centroids, dim);
66
67 labels.copy_from_slice(
68 &vecs
69 .par_chunks(dim * RAYON_BLOCK_SIZE)
70 .flat_map(|vec| {
71 vec.chunks(dim)
72 .map(|v| rabitq.retrieve_top_one(v) as u32)
73 .collect::<Vec<_>>()
74 })
75 .collect::<Vec<_>>(),
76 );
77
78 let (rough, precise) = rabitq.get_metrics();
79 debug!(
80 "RaBitQ: rough {}, precise {}, ratio: {}",
81 rough,
82 precise,
83 rough as f32 / precise as f32
84 )
85}
86
87pub fn update_centroids(vecs: &[f32], centroids: &mut [f32], dim: usize, labels: &[u32]) -> f32 {
89 let mut means = vec![0.0; centroids.len()];
90 let mut elements = vec![0; centroids.len() / dim];
91 for (i, vec) in vecs.chunks(dim).enumerate() {
92 let label = labels[i] as usize;
93 elements[label] += 1;
94 means[label * dim..(label + 1) * dim]
95 .iter_mut()
96 .zip(vec.iter())
97 .for_each(|(m, &v)| *m += v);
98 }
99 let diff = squared_euclidean(centroids, &means);
100
101 let mut zero_count = 0;
102 for i in 0..elements.len() {
103 if elements[i] == 0 {
104 zero_count += 1;
106 let mut target = 0;
107 let mut rng = rand::thread_rng();
108 let base = 1.0 / (vecs.len() / dim - labels.len()) as f32;
109 loop {
110 let p = (elements[target] - 1) as f32 * base;
111 if rng.gen::<f32>() < p {
112 break;
113 }
114 target = (target + 1) % labels.len();
115 }
116 debug!("split cluster {} to fill empty cluster {}", target, i);
117 if i < target {
118 let (left, right) = centroids.split_at_mut(target * dim);
119 left[i * dim..(i + 1) * dim].copy_from_slice(&right[..dim]);
120 } else {
121 let (left, right) = centroids.split_at_mut(i * dim);
122 right[..dim].copy_from_slice(&left[target * dim..(target + 1) * dim]);
123 }
124 for j in 0..dim {
126 if j % 2 == 0 {
127 centroids[i * dim + j] *= 1.0 + EPS;
128 centroids[target * dim + j] *= 1.0 - EPS;
129 } else {
130 centroids[i * dim + j] *= 1.0 - EPS;
131 centroids[target * dim + j] *= 1.0 + EPS;
132 }
133 }
134 elements[i] = elements[target] / 2;
136 elements[target] -= elements[i];
137 }
138 let divider = (elements[i] as f32).recip();
139 for j in i * dim..(i + 1) * dim {
140 centroids[j] = means[j] * divider;
141 }
142 }
143 if zero_count != 0 {
144 debug!("fixed {} empty clusters", zero_count);
145 }
146 diff
147}
148
149#[derive(Debug)]
151pub struct KMeans {
152 n_cluster: u32,
153 max_iter: u32,
154 tolerance: f32,
155 distance: Distance,
156 use_residual: bool,
157 use_default_config: bool,
158}
159
160impl Default for KMeans {
161 fn default() -> Self {
162 Self {
163 n_cluster: 8,
164 max_iter: 25,
165 tolerance: 1e-4,
166 distance: Distance::default(),
167 use_residual: false,
168 use_default_config: true,
169 }
170 }
171}
172
173impl KMeans {
174 pub fn new(
184 n_cluster: u32,
185 max_iter: u32,
186 tolerance: f32,
187 distance: Distance,
188 use_residual: bool,
189 ) -> Self {
190 if n_cluster <= 1 {
191 panic!("n_cluster must be greater than 1");
192 }
193 if max_iter <= 1 {
194 panic!("max_iter must be greater than 1");
195 }
196 if tolerance <= 0.0 {
197 panic!("tolerance must be greater than 0.0");
198 }
199 Self {
200 n_cluster,
201 max_iter,
202 tolerance,
203 distance,
204 use_residual,
205 use_default_config: false,
206 }
207 }
208
209 pub fn fit(&self, mut vecs: Vec<f32>, dim: usize) -> Vec<f32> {
211 let num = vecs.len() / dim;
212
213 let n_cluster = match self.use_default_config {
215 true => (((num as f32).sqrt() as u32) * 4).min((num / MIN_POINTS_PER_CENTROID) as u32),
216 false => self.n_cluster,
217 };
218 debug!("num of points: {}, num of clusters: {}", num, n_cluster);
219
220 if num < n_cluster as usize {
221 panic!("number of samples must be greater than n_cluster");
222 }
223 if num < n_cluster as usize * MIN_POINTS_PER_CENTROID {
224 panic!("too few samples for n_cluster");
225 }
226
227 if self.distance == Distance::SquaredEuclidean && self.use_residual {
229 debug!("use residual");
230 centroid_residual(&mut vecs, dim);
231 }
232
233 if num > MAX_POINTS_PER_CENTROID * n_cluster as usize {
235 let n_sample = MAX_POINTS_PER_CENTROID * n_cluster as usize;
236 debug!("subsample to {} points", n_sample);
237 vecs = as_continuous_vec(&subsample(n_sample, &vecs, dim));
238 }
239
240 let mut centroids = as_continuous_vec(&subsample(n_cluster as usize, &vecs, dim));
241 if self.distance == Distance::NegativeDotProduct {
242 centroids.chunks_mut(dim).for_each(normalize);
243 }
244
245 let mut labels: Vec<u32> = vec![0; num];
246 debug!("start training");
247 for i in 0..self.max_iter {
248 let start_time = Instant::now();
249 if self.distance == Distance::NegativeDotProduct || num * dim <= LARGE_CLUSTER_THRESHOLD
250 {
251 assign(&vecs, ¢roids, dim, self.distance, &mut labels);
252 } else {
253 rabitq_assign(&vecs, ¢roids, dim, &mut labels);
254 }
255 let diff = update_centroids(&vecs, &mut centroids, dim, &labels);
256 if self.distance == Distance::NegativeDotProduct {
257 centroids.chunks_mut(dim).for_each(normalize);
258 }
259 debug!("iter {} takes {} s", i, start_time.elapsed().as_secs_f32());
260 if diff < self.tolerance {
261 debug!("converged at iter {}", i);
262 break;
263 }
264 }
265
266 centroids
267 }
268}