1use crate::error::Result;
4use crate::faiss_try;
5use crate::index::NativeIndex;
6use faiss_sys::*;
7use std::os::raw::c_int;
8use std::{mem, ptr};
9
10pub struct ClusteringParameters {
12 inner: FaissClusteringParameters,
13}
14
15impl Default for ClusteringParameters {
16 fn default() -> Self {
17 ClusteringParameters::new()
18 }
19}
20
21impl ClusteringParameters {
22 pub fn new() -> Self {
24 unsafe {
25 let mut inner: FaissClusteringParameters = mem::zeroed();
26 faiss_ClusteringParameters_init(&mut inner);
27 ClusteringParameters { inner }
28 }
29 }
30
31 pub fn niter(&self) -> i32 {
32 self.inner.niter
33 }
34
35 pub fn nredo(&self) -> i32 {
36 self.inner.nredo
37 }
38
39 pub fn min_points_per_centroid(&self) -> i32 {
40 self.inner.min_points_per_centroid
41 }
42
43 pub fn max_points_per_centroid(&self) -> i32 {
44 self.inner.max_points_per_centroid
45 }
46
47 pub fn frozen_centroids(&self) -> bool {
48 self.inner.frozen_centroids != 0
49 }
50
51 pub fn spherical(&self) -> bool {
52 self.inner.spherical != 0
53 }
54
55 pub fn int_centroids(&self) -> bool {
58 self.inner.int_centroids != 0
59 }
60
61 pub fn update_index(&self) -> bool {
62 self.inner.update_index != 0
63 }
64
65 pub fn verbose(&self) -> bool {
66 self.inner.verbose != 0
67 }
68
69 pub fn seed(&self) -> u32 {
70 self.inner.seed as u32
71 }
72
73 pub fn decode_block_size(&self) -> usize {
76 self.inner.decode_block_size
77 }
78
79 pub fn set_niter(&mut self, niter: u32) {
80 self.inner.niter = (niter & 0x7FFF_FFFF) as i32;
81 }
82
83 pub fn set_nredo(&mut self, nredo: u32) {
84 self.inner.nredo = (nredo & 0x7FFF_FFFF) as i32;
85 }
86
87 pub fn set_min_points_per_centroid(&mut self, min_points_per_centroid: u32) {
88 self.inner.min_points_per_centroid = (min_points_per_centroid & 0x7FFF_FFFF) as i32;
89 }
90
91 pub fn set_max_points_per_centroid(&mut self, max_points_per_centroid: u32) {
92 self.inner.max_points_per_centroid = (max_points_per_centroid & 0x7FFF_FFFF) as i32;
93 }
94
95 pub fn set_frozen_centroids(&mut self, frozen_centroids: bool) {
96 self.inner.frozen_centroids = if frozen_centroids { 1 } else { 0 };
97 }
98
99 pub fn set_update_index(&mut self, update_index: bool) {
100 self.inner.update_index = if update_index { 1 } else { 0 };
101 }
102
103 pub fn set_spherical(&mut self, spherical: bool) {
104 self.inner.spherical = if spherical { 1 } else { 0 };
105 }
106
107 pub fn set_int_centroids(&mut self, int_centroids: bool) {
110 self.inner.int_centroids = if int_centroids { 1 } else { 0 }
111 }
112
113 pub fn set_verbose(&mut self, verbose: bool) {
114 self.inner.verbose = if verbose { 1 } else { 0 };
115 }
116
117 pub fn set_seed(&mut self, seed: u32) {
118 self.inner.seed = seed as i32;
119 }
120
121 pub fn set_decode_block_size(&mut self, decode_block_size: usize) {
124 self.inner.decode_block_size = decode_block_size;
125 }
126}
127
/// Statistics recorded by faiss for one clustering iteration.
///
/// The fields mirror the layout of faiss' C++ `ClusteringIterationStats`
/// (`obj`, `time`, `time_search`, `imbalance_factor`, `nsplit`). The layout matters because
/// `Clustering::iteration_stats` reinterprets the native array as a slice of this type.
/// A zero-sized placeholder would give every slice element the same address, so `stats[i]`
/// would always read the first record.
/// NOTE(review): keep this in sync with the linked faiss version's `Clustering.h`.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
// The fields exist for layout only; values are read through the faiss C accessors.
#[allow(dead_code)]
pub struct ClusteringIterationStats {
    obj: f32,
    time: f64,
    time_search: f64,
    imbalance_factor: f64,
    nsplit: c_int,
}
133
134impl ClusteringIterationStats {
135 pub fn obj(&self) -> f32 {
137 unsafe {
138 faiss_ClusteringIterationStats_obj(
139 self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
140 )
141 }
142 }
143 pub fn time(&self) -> f64 {
145 unsafe {
146 faiss_ClusteringIterationStats_time(
147 self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
148 )
149 }
150 }
151 pub fn time_search(&self) -> f64 {
153 unsafe {
154 faiss_ClusteringIterationStats_time_search(
155 self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
156 )
157 }
158 }
159 pub fn imbalance_factor(&self) -> f64 {
161 unsafe {
162 faiss_ClusteringIterationStats_imbalance_factor(
163 self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
164 )
165 }
166 }
167 pub fn nsplit(&self) -> i32 {
169 unsafe {
170 faiss_ClusteringIterationStats_nsplit(
171 self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
172 )
173 }
174 }
175}
176
177pub struct Clustering {
179 inner: *mut FaissClustering,
180}
181
182unsafe impl Send for Clustering {}
183unsafe impl Sync for Clustering {}
184
185impl Drop for Clustering {
186 fn drop(&mut self) {
187 unsafe {
188 faiss_Clustering_free(self.inner);
189 }
190 }
191}
192
193impl Clustering {
194 pub fn new(d: u32, k: u32) -> Result<Self> {
199 unsafe {
200 let d = d as c_int;
201 let k = k as c_int;
202 let mut inner: *mut FaissClustering = ptr::null_mut();
203 faiss_try(faiss_Clustering_new(&mut inner, d, k))?;
204 Ok(Clustering { inner })
205 }
206 }
207
208 pub fn new_with_params(d: u32, k: u32, params: &ClusteringParameters) -> Result<Self> {
212 unsafe {
213 let d = d as c_int;
214 let k = k as c_int;
215 let mut inner: *mut FaissClustering = ptr::null_mut();
216 faiss_try(faiss_Clustering_new_with_params(
217 &mut inner,
218 d,
219 k,
220 ¶ms.inner,
221 ))?;
222 Ok(Clustering { inner })
223 }
224 }
225
226 pub fn train<I>(&mut self, x: &[f32], index: &mut I) -> Result<()>
231 where
232 I: ?Sized + NativeIndex<Inner = FaissIndex>,
233 {
234 unsafe {
235 let n = x.len() / self.d() as usize;
236 faiss_try(faiss_Clustering_train(
237 self.inner,
238 n as idx_t,
239 x.as_ptr(),
240 index.inner_ptr(),
241 ))?;
242 Ok(())
243 }
244 }
245
246 pub fn centroids(&self) -> Result<Vec<&[f32]>> {
251 unsafe {
252 let mut data = ptr::null_mut();
253 let mut size = 0;
254 faiss_Clustering_centroids(self.inner, &mut data, &mut size);
255 Ok(::std::slice::from_raw_parts(data, size)
256 .chunks(self.d() as usize)
257 .collect())
258 }
259 }
260
261 pub fn centroids_mut(&mut self) -> Result<Vec<&mut [f32]>> {
266 unsafe {
267 let mut data = ptr::null_mut();
268 let mut size = 0;
269 faiss_Clustering_centroids(self.inner, &mut data, &mut size);
270 Ok(::std::slice::from_raw_parts_mut(data, size)
271 .chunks_mut(self.d() as usize)
272 .collect())
273 }
274 }
275
276 pub fn iteration_stats(&self) -> &[ClusteringIterationStats] {
281 unsafe {
282 let mut data = ptr::null_mut();
283 let mut size = 0;
284 faiss_Clustering_iteration_stats(self.inner, &mut data, &mut size);
285 ::std::slice::from_raw_parts(data as *mut ClusteringIterationStats, size)
286 }
287 }
288
289 pub fn iteration_stats_mut(&mut self) -> &mut [ClusteringIterationStats] {
294 unsafe {
295 let mut data = ptr::null_mut();
296 let mut size = 0;
297 faiss_Clustering_iteration_stats(self.inner, &mut data, &mut size);
298 ::std::slice::from_raw_parts_mut(data as *mut ClusteringIterationStats, size)
299 }
300 }
301
302 pub fn d(&self) -> u32 {
304 unsafe { faiss_Clustering_d(self.inner) as u32 }
305 }
306
307 pub fn k(&self) -> u32 {
309 unsafe { faiss_Clustering_k(self.inner) as u32 }
310 }
311
312 pub fn niter(&self) -> u32 {
314 unsafe { faiss_Clustering_niter(self.inner) as u32 }
315 }
316
317 pub fn nredo(&self) -> u32 {
319 unsafe { faiss_Clustering_nredo(self.inner) as u32 }
320 }
321
322 pub fn verbose(&self) -> bool {
324 unsafe { faiss_Clustering_niter(self.inner) != 0 }
325 }
326
327 pub fn spherical(&self) -> bool {
329 unsafe { faiss_Clustering_spherical(self.inner) != 0 }
330 }
331
332 pub fn int_centroids(&self) -> bool {
335 unsafe { faiss_Clustering_int_centroids(self.inner) != 0 }
336 }
337
338 pub fn update_index(&self) -> bool {
340 unsafe { faiss_Clustering_update_index(self.inner) != 0 }
341 }
342
343 pub fn frozen_centroids(&self) -> bool {
345 unsafe { faiss_Clustering_frozen_centroids(self.inner) != 0 }
346 }
347
348 pub fn seed(&self) -> u32 {
350 unsafe { faiss_Clustering_seed(self.inner) as u32 }
351 }
352
353 pub fn decode_block_size(&self) -> usize {
356 unsafe { faiss_Clustering_decode_block_size(self.inner) }
357 }
358
359 pub fn min_points_per_centroid(&self) -> u32 {
361 unsafe { faiss_Clustering_min_points_per_centroid(self.inner) as u32 }
362 }
363
364 pub fn max_points_per_centroid(&self) -> u32 {
366 unsafe { faiss_Clustering_max_points_per_centroid(self.inner) as u32 }
367 }
368}
369
/// Output of [`kmeans_clustering`].
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct KMeansResult {
    /// The `k` centroids stored contiguously, `d` values each (`k * d` floats in total).
    pub centroids: Vec<f32>,
    /// The `q_error` value written by faiss' k-means routine (its quantization error).
    pub q_error: f32,
}
381
382pub fn kmeans_clustering(d: u32, k: u32, x: &[f32]) -> Result<KMeansResult> {
393 unsafe {
394 let n = x.len() / d as usize;
395 let mut centroids = vec![0_f32; (d * k) as usize];
396 let mut q_error: f32 = 0.;
397 faiss_try(faiss_kmeans_clustering(
398 d as usize,
399 n,
400 k as usize,
401 x.as_ptr(),
402 centroids.as_mut_ptr(),
403 &mut q_error,
404 ))?;
405 Ok(KMeansResult { centroids, q_error })
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::{kmeans_clustering, Clustering, ClusteringParameters};
    use crate::index::index_factory;
    use crate::MetricType;

    /// Ten 8-dimensional sample vectors, one per row.
    const SAMPLE_DATA: &[f32] = &[
        7.5, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5,
        -1., 1., 1., 1., 1., 1., 1., -1.,
        0., 0., 0., 1., 1., 0., 0., -1.,
        100., 100., 100., 100., -100., 100., 100., 100.,
        -7., 1., 4., 1., 2., 1., 3., -1.,
        120., 100., 100., 120., -100., 100., 100., 120.,
        0., 0., -12., 1., 1., 0., 6., -1.,
        0., 0., -0.25, 1., 16., 24., 0., -1.,
        100., 10., 100., 100., 10., 100., 50., 10.,
        20., 22., 4.5, -2., -100., 0., 0., 100.,
    ];

    #[test]
    fn test_clustering() {
        const D: u32 = 8;
        const K: u32 = 3;
        const NITER: u32 = 12;

        let mut params = ClusteringParameters::default();
        params.set_niter(NITER);
        params.set_min_points_per_centroid(1);
        params.set_max_points_per_centroid(10);

        let mut clustering = Clustering::new_with_params(D, K, &params).unwrap();
        let mut index = index_factory(D, "Flat", MetricType::L2).unwrap();
        clustering.train(SAMPLE_DATA, &mut index).unwrap();

        let centroids = clustering.centroids().unwrap();
        assert_eq!(centroids.len(), K as usize);
        for centroid in &centroids {
            assert_eq!(centroid.len(), D as usize);
        }

        assert_eq!(clustering.iteration_stats().len(), NITER as usize);
    }

    #[test]
    fn test_simple_clustering() {
        const D: u32 = 8;
        const K: u32 = 2;

        let result = kmeans_clustering(D, K, SAMPLE_DATA).unwrap();
        assert!(result.q_error > 0.);
        assert_eq!(result.centroids.len(), (D * K) as usize);
    }
}
465
#[cfg(feature = "gpu")]
pub mod gpu {
    #[cfg(test)]
    mod tests {
        use super::super::{Clustering, ClusteringParameters};
        use crate::gpu::StandardGpuResources;
        use crate::index::index_factory;
        use crate::MetricType;

        /// Ten 8-dimensional sample vectors, one per row.
        const SAMPLE_DATA: &[f32] = &[
            7.5, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5,
            -1., 1., 1., 1., 1., 1., 1., -1.,
            0., 0., 0., 1., 1., 0., 0., -1.,
            100., 100., 100., 100., -100., 100., 100., 100.,
            -7., 1., 4., 1., 2., 1., 3., -1.,
            120., 100., 100., 120., -100., 100., 100., 120.,
            0., 0., -12., 1., 1., 0., 6., -1.,
            0., 0., -0.25, 1., 16., 24., 0., -1.,
            100., 10., 100., 100., 10., 100., 50., 10.,
            20., 22., 4.5, -2., -100., 0., 0., 100.,
        ];

        #[test]
        fn test_clustering() {
            const D: u32 = 8;
            const K: u32 = 3;
            const NITER: u32 = 12;

            let mut params = ClusteringParameters::default();
            params.set_niter(NITER);
            params.set_min_points_per_centroid(1);
            params.set_max_points_per_centroid(10);

            let mut clustering = Clustering::new_with_params(D, K, &params).unwrap();
            let gpu_resources = StandardGpuResources::new().unwrap();
            let cpu_index = index_factory(D, "Flat", MetricType::L2).unwrap();
            let mut index = cpu_index.into_gpu(&gpu_resources, 0).unwrap();
            clustering.train(SAMPLE_DATA, &mut index).unwrap();

            let centroids = clustering.centroids().unwrap();
            assert_eq!(centroids.len(), K as usize);
            for centroid in &centroids {
                assert_eq!(centroid.len(), D as usize);
            }

            assert_eq!(clustering.iteration_stats().len(), NITER as usize);
        }
    }
}