faiss/
cluster.rs

1//! Vector clustering interface and implementation.
2
3use crate::error::Result;
4use crate::faiss_try;
5use crate::index::NativeIndex;
6use faiss_sys::*;
7use std::os::raw::c_int;
8use std::{mem, ptr};
9
10/// Parameters for the clustering algorithm.
11pub struct ClusteringParameters {
12    inner: FaissClusteringParameters,
13}
14
15impl Default for ClusteringParameters {
16    fn default() -> Self {
17        ClusteringParameters::new()
18    }
19}
20
21impl ClusteringParameters {
22    /// Create a new clustering parameters object.
23    pub fn new() -> Self {
24        unsafe {
25            let mut inner: FaissClusteringParameters = mem::zeroed();
26            faiss_ClusteringParameters_init(&mut inner);
27            ClusteringParameters { inner }
28        }
29    }
30
31    pub fn niter(&self) -> i32 {
32        self.inner.niter
33    }
34
35    pub fn nredo(&self) -> i32 {
36        self.inner.nredo
37    }
38
39    pub fn min_points_per_centroid(&self) -> i32 {
40        self.inner.min_points_per_centroid
41    }
42
43    pub fn max_points_per_centroid(&self) -> i32 {
44        self.inner.max_points_per_centroid
45    }
46
47    pub fn frozen_centroids(&self) -> bool {
48        self.inner.frozen_centroids != 0
49    }
50
51    pub fn spherical(&self) -> bool {
52        self.inner.spherical != 0
53    }
54
55    /// Getter for the `int_centroids` property.
56    /// Round centroids coordinates to integer
57    pub fn int_centroids(&self) -> bool {
58        self.inner.int_centroids != 0
59    }
60
61    pub fn update_index(&self) -> bool {
62        self.inner.update_index != 0
63    }
64
65    pub fn verbose(&self) -> bool {
66        self.inner.verbose != 0
67    }
68
69    pub fn seed(&self) -> u32 {
70        self.inner.seed as u32
71    }
72
73    /// Getter for the `decode_block_size` property.
74    /// How many vectors at a time to decode
75    pub fn decode_block_size(&self) -> usize {
76        self.inner.decode_block_size
77    }
78
79    pub fn set_niter(&mut self, niter: u32) {
80        self.inner.niter = (niter & 0x7FFF_FFFF) as i32;
81    }
82
83    pub fn set_nredo(&mut self, nredo: u32) {
84        self.inner.nredo = (nredo & 0x7FFF_FFFF) as i32;
85    }
86
87    pub fn set_min_points_per_centroid(&mut self, min_points_per_centroid: u32) {
88        self.inner.min_points_per_centroid = (min_points_per_centroid & 0x7FFF_FFFF) as i32;
89    }
90
91    pub fn set_max_points_per_centroid(&mut self, max_points_per_centroid: u32) {
92        self.inner.max_points_per_centroid = (max_points_per_centroid & 0x7FFF_FFFF) as i32;
93    }
94
95    pub fn set_frozen_centroids(&mut self, frozen_centroids: bool) {
96        self.inner.frozen_centroids = if frozen_centroids { 1 } else { 0 };
97    }
98
99    pub fn set_update_index(&mut self, update_index: bool) {
100        self.inner.update_index = if update_index { 1 } else { 0 };
101    }
102
103    pub fn set_spherical(&mut self, spherical: bool) {
104        self.inner.spherical = if spherical { 1 } else { 0 };
105    }
106
107    /// Setter for the `int_centroids` property.
108    /// Round centroids coordinates to integer
109    pub fn set_int_centroids(&mut self, int_centroids: bool) {
110        self.inner.int_centroids = if int_centroids { 1 } else { 0 }
111    }
112
113    pub fn set_verbose(&mut self, verbose: bool) {
114        self.inner.verbose = if verbose { 1 } else { 0 };
115    }
116
117    pub fn set_seed(&mut self, seed: u32) {
118        self.inner.seed = seed as i32;
119    }
120
121    /// Setter for the `decode_block_size` property.
122    /// How many vectors at a time to decode
123    pub fn set_decode_block_size(&mut self, decode_block_size: usize) {
124        self.inner.decode_block_size = decode_block_size;
125    }
126}
127
/// Statistics recorded for a single clustering iteration.
///
/// This is an opaque, zero-sized stand-in for the native
/// `FaissClusteringIterationStats`. Its values are read through the getter
/// methods, which hand the address of `self` to the faiss C API. Meaningful
/// references come from [`Clustering::iteration_stats`].
///
/// NOTE(review): because this type is zero-sized, every element of a
/// `&[ClusteringIterationStats]` shares a single address. In addition, a value
/// produced by `Copy`/`Clone` does not live at native memory, so calling a
/// getter on such a copy passes an unrelated address to faiss. Only use getters
/// through references obtained from the clustering object. TODO: confirm and
/// consider an owned, plain-data representation.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct ClusteringIterationStats {
    _unused: [u8; 0],
}
133
134impl ClusteringIterationStats {
135    /// objective values (sum of distances reported by index)
136    pub fn obj(&self) -> f32 {
137        unsafe {
138            faiss_ClusteringIterationStats_obj(
139                self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
140            )
141        }
142    }
143    /// seconds for iteration
144    pub fn time(&self) -> f64 {
145        unsafe {
146            faiss_ClusteringIterationStats_time(
147                self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
148            )
149        }
150    }
151    /// seconds for just search
152    pub fn time_search(&self) -> f64 {
153        unsafe {
154            faiss_ClusteringIterationStats_time_search(
155                self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
156            )
157        }
158    }
159    /// imbalance factor of iteration
160    pub fn imbalance_factor(&self) -> f64 {
161        unsafe {
162            faiss_ClusteringIterationStats_imbalance_factor(
163                self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
164            )
165        }
166    }
167    /// number of cluster splits
168    pub fn nsplit(&self) -> i32 {
169        unsafe {
170            faiss_ClusteringIterationStats_nsplit(
171                self as *const ClusteringIterationStats as *const FaissClusteringIterationStats,
172            )
173        }
174    }
175}
176
177/// The clustering algorithm.
178pub struct Clustering {
179    inner: *mut FaissClustering,
180}
181
182unsafe impl Send for Clustering {}
183unsafe impl Sync for Clustering {}
184
185impl Drop for Clustering {
186    fn drop(&mut self) {
187        unsafe {
188            faiss_Clustering_free(self.inner);
189        }
190    }
191}
192
193impl Clustering {
194    /**
195     * Obtain a new clustering object with the given dimensionality
196     * `d` and number of centroids `k`.
197     */
198    pub fn new(d: u32, k: u32) -> Result<Self> {
199        unsafe {
200            let d = d as c_int;
201            let k = k as c_int;
202            let mut inner: *mut FaissClustering = ptr::null_mut();
203            faiss_try(faiss_Clustering_new(&mut inner, d, k))?;
204            Ok(Clustering { inner })
205        }
206    }
207
208    /**
209     * Obtain a new clustering object, with the given clustering parameters.
210     */
211    pub fn new_with_params(d: u32, k: u32, params: &ClusteringParameters) -> Result<Self> {
212        unsafe {
213            let d = d as c_int;
214            let k = k as c_int;
215            let mut inner: *mut FaissClustering = ptr::null_mut();
216            faiss_try(faiss_Clustering_new_with_params(
217                &mut inner,
218                d,
219                k,
220                &params.inner,
221            ))?;
222            Ok(Clustering { inner })
223        }
224    }
225
226    /**
227     * Perform the clustering algorithm with the given data and index.
228     * The index is used during the assignment stage.
229     */
230    pub fn train<I>(&mut self, x: &[f32], index: &mut I) -> Result<()>
231    where
232        I: ?Sized + NativeIndex<Inner = FaissIndex>,
233    {
234        unsafe {
235            let n = x.len() / self.d() as usize;
236            faiss_try(faiss_Clustering_train(
237                self.inner,
238                n as idx_t,
239                x.as_ptr(),
240                index.inner_ptr(),
241            ))?;
242            Ok(())
243        }
244    }
245
246    /**
247     * Retrieve the centroids from the clustering process. Returns
248     * a vector of `k` slices of size `d`.
249     */
250    pub fn centroids(&self) -> Result<Vec<&[f32]>> {
251        unsafe {
252            let mut data = ptr::null_mut();
253            let mut size = 0;
254            faiss_Clustering_centroids(self.inner, &mut data, &mut size);
255            Ok(::std::slice::from_raw_parts(data, size)
256                .chunks(self.d() as usize)
257                .collect())
258        }
259    }
260
261    /**
262     * Retrieve the centroids from the clustering process. Returns
263     * a vector of `k` slices of size `d`.
264     */
265    pub fn centroids_mut(&mut self) -> Result<Vec<&mut [f32]>> {
266        unsafe {
267            let mut data = ptr::null_mut();
268            let mut size = 0;
269            faiss_Clustering_centroids(self.inner, &mut data, &mut size);
270            Ok(::std::slice::from_raw_parts_mut(data, size)
271                .chunks_mut(self.d() as usize)
272                .collect())
273        }
274    }
275
276    /**
277     * Retrieve the stats achieved from the clustering process.
278     * Returns as many values as the number of iterations made.
279     */
280    pub fn iteration_stats(&self) -> &[ClusteringIterationStats] {
281        unsafe {
282            let mut data = ptr::null_mut();
283            let mut size = 0;
284            faiss_Clustering_iteration_stats(self.inner, &mut data, &mut size);
285            ::std::slice::from_raw_parts(data as *mut ClusteringIterationStats, size)
286        }
287    }
288
289    /**
290     * Retrieve the stats.
291     * Returns as many values as the number of iterations made.
292     */
293    pub fn iteration_stats_mut(&mut self) -> &mut [ClusteringIterationStats] {
294        unsafe {
295            let mut data = ptr::null_mut();
296            let mut size = 0;
297            faiss_Clustering_iteration_stats(self.inner, &mut data, &mut size);
298            ::std::slice::from_raw_parts_mut(data as *mut ClusteringIterationStats, size)
299        }
300    }
301
302    /** Getter for the clustering object's vector dimensionality. */
303    pub fn d(&self) -> u32 {
304        unsafe { faiss_Clustering_d(self.inner) as u32 }
305    }
306
307    /** Getter for the number of centroids. */
308    pub fn k(&self) -> u32 {
309        unsafe { faiss_Clustering_k(self.inner) as u32 }
310    }
311
312    /** Getter for the number of k-means iterations. */
313    pub fn niter(&self) -> u32 {
314        unsafe { faiss_Clustering_niter(self.inner) as u32 }
315    }
316
317    /** Getter for the `nredo` property of `Clustering`. */
318    pub fn nredo(&self) -> u32 {
319        unsafe { faiss_Clustering_nredo(self.inner) as u32 }
320    }
321
322    /** Getter for the `verbose` property of `Clustering`. */
323    pub fn verbose(&self) -> bool {
324        unsafe { faiss_Clustering_niter(self.inner) != 0 }
325    }
326
327    /** Getter for whether spherical clustering is intended. */
328    pub fn spherical(&self) -> bool {
329        unsafe { faiss_Clustering_spherical(self.inner) != 0 }
330    }
331
332    /// Getter for the `int_centroids` property of `Clustering`.
333    /// Round centroids coordinates to integer
334    pub fn int_centroids(&self) -> bool {
335        unsafe { faiss_Clustering_int_centroids(self.inner) != 0 }
336    }
337
338    /** Getter for the `update_index` property of `Clustering`. */
339    pub fn update_index(&self) -> bool {
340        unsafe { faiss_Clustering_update_index(self.inner) != 0 }
341    }
342
343    /** Getter for the `frozen_centroids` property of `Clustering`. */
344    pub fn frozen_centroids(&self) -> bool {
345        unsafe { faiss_Clustering_frozen_centroids(self.inner) != 0 }
346    }
347
348    /** Getter for the `seed` property of `Clustering`. */
349    pub fn seed(&self) -> u32 {
350        unsafe { faiss_Clustering_seed(self.inner) as u32 }
351    }
352
353    /// Getter for the `decode_block_size` property of `Clustering`.
354    /// How many vectors at a time to decode
355    pub fn decode_block_size(&self) -> usize {
356        unsafe { faiss_Clustering_decode_block_size(self.inner) }
357    }
358
359    /** Getter for the minimum number of points per centroid. */
360    pub fn min_points_per_centroid(&self) -> u32 {
361        unsafe { faiss_Clustering_min_points_per_centroid(self.inner) as u32 }
362    }
363
364    /** Getter for the maximum number of points per centroid. */
365    pub fn max_points_per_centroid(&self) -> u32 {
366        unsafe { faiss_Clustering_max_points_per_centroid(self.inner) as u32 }
367    }
368}
369
/// Plain-data outcome of the simple k-means helper [`kmeans_clustering`].
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct KMeansResult {
    /// Centroids of all clusters in one contiguous vector (`k * d` values).
    pub centroids: Vec<f32>,
    /// Final quantization error.
    pub q_error: f32,
}
381
382/// Simplified interface for k-means clustering.
383///
384/// - `d`: dimension of the data
385/// - `k`: nb of output centroids
386/// - `x`: training set (size `n * d`)
387///
388/// The number of points is inferred from `x` and `k`.
389///
390/// Returns the final quantization error and centroids (size `k * d`).
391///
392pub fn kmeans_clustering(d: u32, k: u32, x: &[f32]) -> Result<KMeansResult> {
393    unsafe {
394        let n = x.len() / d as usize;
395        let mut centroids = vec![0_f32; (d * k) as usize];
396        let mut q_error: f32 = 0.;
397        faiss_try(faiss_kmeans_clustering(
398            d as usize,
399            n,
400            k as usize,
401            x.as_ptr(),
402            centroids.as_mut_ptr(),
403            &mut q_error,
404        ))?;
405        Ok(KMeansResult { centroids, q_error })
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::{kmeans_clustering, Clustering, ClusteringParameters};
    use crate::index::index_factory;
    use crate::MetricType;

    /// Ten 8-dimensional training vectors shared by both tests.
    const SOME_DATA: &[f32] = &[
        7.5, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, //
        -1., 1., 1., 1., 1., 1., 1., -1., //
        0., 0., 0., 1., 1., 0., 0., -1., //
        100., 100., 100., 100., -100., 100., 100., 100., //
        -7., 1., 4., 1., 2., 1., 3., -1., //
        120., 100., 100., 120., -100., 100., 100., 120., //
        0., 0., -12., 1., 1., 0., 6., -1., //
        0., 0., -0.25, 1., 16., 24., 0., -1., //
        100., 10., 100., 100., 10., 100., 50., 10., //
        20., 22., 4.5, -2., -100., 0., 0., 100.,
    ];

    #[test]
    fn test_clustering() {
        const D: u32 = 8;
        const K: u32 = 3;
        const NITER: u32 = 12;

        let mut params = ClusteringParameters::default();
        params.set_niter(NITER);
        params.set_min_points_per_centroid(1);
        params.set_max_points_per_centroid(10);

        let mut clustering = Clustering::new_with_params(D, K, &params).unwrap();
        let mut index = index_factory(D, "Flat", MetricType::L2).unwrap();
        clustering.train(SOME_DATA, &mut index).unwrap();

        let centroids = clustering.centroids().unwrap();
        assert_eq!(centroids.len(), K as usize);
        centroids
            .iter()
            .for_each(|centroid| assert_eq!(centroid.len(), D as usize));

        assert_eq!(clustering.iteration_stats().len(), NITER as usize);
    }

    #[test]
    fn test_simple_clustering() {
        const D: u32 = 8;
        const K: u32 = 2;

        let result = kmeans_clustering(D, K, SOME_DATA).unwrap();
        assert!(result.q_error > 0.);
        assert_eq!(result.centroids.len(), (D * K) as usize);
    }
}
465
#[cfg(feature = "gpu")]
pub mod gpu {
    #[cfg(test)]
    mod tests {
        use super::super::{Clustering, ClusteringParameters};
        use crate::gpu::StandardGpuResources;
        use crate::index::index_factory;
        use crate::MetricType;

        #[test]
        fn test_clustering() {
            const D: u32 = 8;
            const K: u32 = 3;
            const NITER: u32 = 12;

            // Ten 8-dimensional training vectors, one per row.
            let some_data: [f32; 80] = [
                7.5, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, //
                -1., 1., 1., 1., 1., 1., 1., -1., //
                0., 0., 0., 1., 1., 0., 0., -1., //
                100., 100., 100., 100., -100., 100., 100., 100., //
                -7., 1., 4., 1., 2., 1., 3., -1., //
                120., 100., 100., 120., -100., 100., 100., 120., //
                0., 0., -12., 1., 1., 0., 6., -1., //
                0., 0., -0.25, 1., 16., 24., 0., -1., //
                100., 10., 100., 100., 10., 100., 50., 10., //
                20., 22., 4.5, -2., -100., 0., 0., 100.,
            ];

            let mut params = ClusteringParameters::default();
            params.set_niter(NITER);
            params.set_min_points_per_centroid(1);
            params.set_max_points_per_centroid(10);

            let mut clustering = Clustering::new_with_params(D, K, &params).unwrap();
            // The GPU resources must outlive the GPU index built from them.
            let res = StandardGpuResources::new().unwrap();
            let cpu_index = index_factory(D, "Flat", MetricType::L2).unwrap();
            let mut index = cpu_index.into_gpu(&res, 0).unwrap();
            clustering.train(&some_data, &mut index).unwrap();

            let centroids = clustering.centroids().unwrap();
            assert_eq!(centroids.len(), K as usize);
            centroids
                .iter()
                .for_each(|centroid| assert_eq!(centroid.len(), D as usize));

            assert_eq!(clustering.iteration_stats().len(), NITER as usize);
        }
    }
}