use crate::error::Result;
use crate::faiss_try;
use crate::index::NativeIndex;
use faiss_sys::*;
use std::os::raw::c_int;
use std::{mem, ptr};
/// Configuration for a [`Clustering`], stored by value as FAISS's
/// `FaissClusteringParameters` C struct.
///
/// Pass it to [`Clustering::new_with_params`]; the native object reads it at
/// construction time.
pub struct ClusteringParameters {
    /// The raw C parameter struct handed to FAISS.
    inner: FaissClusteringParameters,
}
impl Default for ClusteringParameters {
fn default() -> Self {
ClusteringParameters::new()
}
}
impl ClusteringParameters {
    /// Creates a parameter set filled in by `faiss_ClusteringParameters_init`.
    pub fn new() -> Self {
        // The struct is zeroed first so it is fully initialised before FAISS
        // writes its values into it.
        unsafe {
            let mut inner: FaissClusteringParameters = mem::zeroed();
            faiss_ClusteringParameters_init(&mut inner);
            ClusteringParameters { inner }
        }
    }

    /// Converts an unsigned count to the C `int` FAISS stores, saturating at
    /// `i32::MAX` rather than wrapping large values into small ones.
    fn saturating_c_int(value: u32) -> i32 {
        value.min(i32::MAX as u32) as i32
    }

    /// Returns the `niter` field.
    pub fn niter(&self) -> i32 {
        self.inner.niter
    }

    /// Returns the `nredo` field.
    pub fn nredo(&self) -> i32 {
        self.inner.nredo
    }

    /// Returns the `min_points_per_centroid` field.
    pub fn min_points_per_centroid(&self) -> i32 {
        self.inner.min_points_per_centroid
    }

    /// Returns the `max_points_per_centroid` field.
    pub fn max_points_per_centroid(&self) -> i32 {
        self.inner.max_points_per_centroid
    }

    /// Returns whether the `frozen_centroids` flag is non-zero.
    pub fn frozen_centroids(&self) -> bool {
        self.inner.frozen_centroids != 0
    }

    /// Returns whether the `spherical` flag is non-zero.
    pub fn spherical(&self) -> bool {
        self.inner.spherical != 0
    }

    /// Returns whether the `int_centroids` flag is non-zero.
    pub fn int_centroids(&self) -> bool {
        self.inner.int_centroids != 0
    }

    /// Returns whether the `update_index` flag is non-zero.
    pub fn update_index(&self) -> bool {
        self.inner.update_index != 0
    }

    /// Returns whether the `verbose` flag is non-zero.
    pub fn verbose(&self) -> bool {
        self.inner.verbose != 0
    }

    /// Returns the `seed` field, reinterpreted as `u32` (round-trips with
    /// [`ClusteringParameters::set_seed`]).
    pub fn seed(&self) -> u32 {
        self.inner.seed as u32
    }

    /// Returns the `decode_block_size` field.
    pub fn decode_block_size(&self) -> usize {
        self.inner.decode_block_size
    }

    /// Sets `niter`; values above `i32::MAX` are clamped to `i32::MAX`.
    pub fn set_niter(&mut self, niter: u32) {
        self.inner.niter = Self::saturating_c_int(niter);
    }

    /// Sets `nredo`; values above `i32::MAX` are clamped to `i32::MAX`.
    pub fn set_nredo(&mut self, nredo: u32) {
        self.inner.nredo = Self::saturating_c_int(nredo);
    }

    /// Sets `min_points_per_centroid`; values above `i32::MAX` are clamped.
    pub fn set_min_points_per_centroid(&mut self, min_points_per_centroid: u32) {
        self.inner.min_points_per_centroid = Self::saturating_c_int(min_points_per_centroid);
    }

    /// Sets `max_points_per_centroid`; values above `i32::MAX` are clamped.
    pub fn set_max_points_per_centroid(&mut self, max_points_per_centroid: u32) {
        self.inner.max_points_per_centroid = Self::saturating_c_int(max_points_per_centroid);
    }

    /// Sets the `frozen_centroids` flag (stored as 0/1).
    pub fn set_frozen_centroids(&mut self, frozen_centroids: bool) {
        self.inner.frozen_centroids = if frozen_centroids { 1 } else { 0 };
    }

    /// Sets the `update_index` flag (stored as 0/1).
    pub fn set_update_index(&mut self, update_index: bool) {
        self.inner.update_index = if update_index { 1 } else { 0 };
    }

    /// Sets the `spherical` flag (stored as 0/1).
    pub fn set_spherical(&mut self, spherical: bool) {
        self.inner.spherical = if spherical { 1 } else { 0 };
    }

    /// Sets the `int_centroids` flag (stored as 0/1).
    pub fn set_int_centroids(&mut self, int_centroids: bool) {
        self.inner.int_centroids = if int_centroids { 1 } else { 0 };
    }

    /// Sets the `verbose` flag (stored as 0/1).
    pub fn set_verbose(&mut self, verbose: bool) {
        self.inner.verbose = if verbose { 1 } else { 0 };
    }

    /// Sets `seed`. The bits are stored unchanged in the C `int`, so
    /// [`ClusteringParameters::seed`] returns exactly this value.
    pub fn set_seed(&mut self, seed: u32) {
        self.inner.seed = seed as i32;
    }

    /// Sets `decode_block_size`.
    pub fn set_decode_block_size(&mut self, decode_block_size: usize) {
        self.inner.decode_block_size = decode_block_size;
    }
}
/// Borrowed view of one FAISS per-iteration statistics record.
///
/// This is a zero-sized stand-in for an opaque FAISS type: its accessors pass
/// the address of `self` to the C API. Only references obtained from
/// [`Clustering::iteration_stats`] point at FAISS memory. A value copied out
/// (the type derives `Copy`) lives at some other address and must not be queried.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct ClusteringIterationStats {
    /// Zero-length marker that keeps the type unconstructible outside this module.
    _unused: [u8; 0],
}
impl ClusteringIterationStats {
    /// Reinterprets `self` as the opaque handle the FAISS accessors expect.
    fn as_ffi_ptr(&self) -> *const FaissClusteringIterationStats {
        let this: *const Self = self;
        this as *const FaissClusteringIterationStats
    }

    /// Value returned by `faiss_ClusteringIterationStats_obj` (presumably the
    /// k-means objective for this iteration).
    pub fn obj(&self) -> f32 {
        let handle = self.as_ffi_ptr();
        unsafe { faiss_ClusteringIterationStats_obj(handle) }
    }

    /// Value returned by `faiss_ClusteringIterationStats_time`.
    pub fn time(&self) -> f64 {
        let handle = self.as_ffi_ptr();
        unsafe { faiss_ClusteringIterationStats_time(handle) }
    }

    /// Value returned by `faiss_ClusteringIterationStats_time_search`.
    pub fn time_search(&self) -> f64 {
        let handle = self.as_ffi_ptr();
        unsafe { faiss_ClusteringIterationStats_time_search(handle) }
    }

    /// Value returned by `faiss_ClusteringIterationStats_imbalance_factor`.
    pub fn imbalance_factor(&self) -> f64 {
        let handle = self.as_ffi_ptr();
        unsafe { faiss_ClusteringIterationStats_imbalance_factor(handle) }
    }

    /// Value returned by `faiss_ClusteringIterationStats_nsplit`.
    pub fn nsplit(&self) -> i32 {
        let handle = self.as_ffi_ptr();
        unsafe { faiss_ClusteringIterationStats_nsplit(handle) }
    }
}
/// Owning handle to a native FAISS clustering object.
///
/// The object is created by [`Clustering::new`] / [`Clustering::new_with_params`]
/// and released with `faiss_Clustering_free` when this value is dropped.
pub struct Clustering {
    /// Non-null pointer obtained from the FAISS constructor.
    inner: *mut FaissClustering,
}
// SAFETY(review): `Clustering` exclusively owns `inner`. The pointer is
// produced by its constructors and freed only in `Drop`, so moving it to
// another thread is assumed sound, provided FAISS objects are not
// thread-affine (TODO confirm).
// Shared (`&self`) access reaches FAISS only through getters and
// `centroids`/`iteration_stats`. `Sync` assumes those native calls do not
// mutate the object; verify against the FAISS C API.
unsafe impl Send for Clustering {}
unsafe impl Sync for Clustering {}
impl Drop for Clustering {
    /// Frees the native clustering object owned by this handle.
    fn drop(&mut self) {
        let handle = self.inner;
        unsafe { faiss_Clustering_free(handle) }
    }
}
impl Clustering {
    /// Creates a clustering object for `k` centroids of dimension `d`.
    ///
    /// `d` and `k` are passed to FAISS as C `int`s; values above `i32::MAX`
    /// wrap to negative numbers.
    ///
    /// # Errors
    ///
    /// Returns an error if `faiss_Clustering_new` reports a failure.
    pub fn new(d: u32, k: u32) -> Result<Self> {
        unsafe {
            let d = d as c_int;
            let k = k as c_int;
            let mut inner: *mut FaissClustering = ptr::null_mut();
            faiss_try(faiss_Clustering_new(&mut inner, d, k))?;
            Ok(Clustering { inner })
        }
    }

    /// Like [`Clustering::new`], but configured from `params`.
    ///
    /// # Errors
    ///
    /// Returns an error if `faiss_Clustering_new_with_params` reports a failure.
    pub fn new_with_params(d: u32, k: u32, params: &ClusteringParameters) -> Result<Self> {
        unsafe {
            let d = d as c_int;
            let k = k as c_int;
            let mut inner: *mut FaissClustering = ptr::null_mut();
            faiss_try(faiss_Clustering_new_with_params(
                &mut inner,
                d,
                k,
                &params.inner,
            ))?;
            Ok(Clustering { inner })
        }
    }

    /// Trains on `x`, a flat buffer of `self.d()`-dimensional vectors, passing
    /// `index` through to `faiss_Clustering_train`.
    ///
    /// The number of vectors is `x.len() / self.d()`; trailing values that do
    /// not form a whole vector are ignored. If `self.d()` is 0, zero vectors are
    /// reported instead of panicking on the division.
    ///
    /// # Errors
    ///
    /// Returns an error if `faiss_Clustering_train` reports a failure.
    pub fn train<I: ?Sized>(&mut self, x: &[f32], index: &mut I) -> Result<()>
    where
        I: NativeIndex,
    {
        let n = x.len().checked_div(self.d() as usize).unwrap_or(0);
        unsafe {
            faiss_try(faiss_Clustering_train(
                self.inner,
                n as idx_t,
                x.as_ptr(),
                index.inner_ptr(),
            ))?;
        }
        Ok(())
    }

    /// Returns the centroids as one `self.d()`-long slice per centroid.
    ///
    /// Returns an empty vector when FAISS holds no centroid data (e.g. before
    /// training). This currently never returns `Err`; the `Result` is kept for
    /// API compatibility.
    pub fn centroids(&self) -> Result<Vec<&[f32]>> {
        let d = self.d() as usize;
        let mut data = ptr::null_mut();
        let mut size = 0;
        unsafe {
            faiss_Clustering_centroids(self.inner, &mut data, &mut size);
        }
        // `slice::from_raw_parts` requires a non-null pointer even for length 0,
        // and `chunks(0)` panics, so bail out early on any empty case.
        if data.is_null() || size == 0 || d == 0 {
            return Ok(Vec::new());
        }
        // SAFETY: FAISS reported `size` floats starting at the non-null `data`,
        // owned by `self.inner`, which outlives the returned borrow.
        let flat = unsafe { ::std::slice::from_raw_parts(data, size) };
        Ok(flat.chunks(d).collect())
    }

    /// Mutable counterpart of [`Clustering::centroids`].
    pub fn centroids_mut(&mut self) -> Result<Vec<&mut [f32]>> {
        let d = self.d() as usize;
        let mut data = ptr::null_mut();
        let mut size = 0;
        unsafe {
            faiss_Clustering_centroids(self.inner, &mut data, &mut size);
        }
        if data.is_null() || size == 0 || d == 0 {
            return Ok(Vec::new());
        }
        // SAFETY: as in `centroids`; `&mut self` guarantees exclusive access.
        let flat = unsafe { ::std::slice::from_raw_parts_mut(data, size) };
        Ok(flat.chunks_mut(d).collect())
    }

    /// Returns the per-iteration statistics recorded by FAISS, or an empty
    /// slice if there are none.
    ///
    /// NOTE(review): `ClusteringIterationStats` is zero-sized, so every element
    /// of the returned slice has the same address as the first one. The
    /// per-element accessors therefore presumably all report the first
    /// record's values. The length is correct. TODO: expose indexed access
    /// instead.
    pub fn iteration_stats(&self) -> &[ClusteringIterationStats] {
        let mut data = ptr::null_mut();
        let mut size = 0;
        unsafe {
            faiss_Clustering_iteration_stats(self.inner, &mut data, &mut size);
        }
        // A null pointer is not allowed in `slice::from_raw_parts`, even for length 0.
        if data.is_null() || size == 0 {
            return &[];
        }
        unsafe { ::std::slice::from_raw_parts(data as *mut ClusteringIterationStats, size) }
    }

    /// Mutable counterpart of [`Clustering::iteration_stats`] (same aliasing caveat).
    pub fn iteration_stats_mut(&mut self) -> &mut [ClusteringIterationStats] {
        let mut data = ptr::null_mut();
        let mut size = 0;
        unsafe {
            faiss_Clustering_iteration_stats(self.inner, &mut data, &mut size);
        }
        if data.is_null() || size == 0 {
            return &mut [];
        }
        unsafe { ::std::slice::from_raw_parts_mut(data as *mut ClusteringIterationStats, size) }
    }

    /// Vector dimension given at construction.
    pub fn d(&self) -> u32 {
        unsafe { faiss_Clustering_d(self.inner) as u32 }
    }

    /// Number of centroids given at construction.
    pub fn k(&self) -> u32 {
        unsafe { faiss_Clustering_k(self.inner) as u32 }
    }

    /// FAISS's `niter` setting.
    pub fn niter(&self) -> u32 {
        unsafe { faiss_Clustering_niter(self.inner) as u32 }
    }

    /// FAISS's `nredo` setting.
    pub fn nredo(&self) -> u32 {
        unsafe { faiss_Clustering_nredo(self.inner) as u32 }
    }

    /// Whether FAISS's `verbose` flag is set.
    pub fn verbose(&self) -> bool {
        unsafe { faiss_Clustering_verbose(self.inner) != 0 }
    }

    /// Whether FAISS's `spherical` flag is set.
    pub fn spherical(&self) -> bool {
        unsafe { faiss_Clustering_spherical(self.inner) != 0 }
    }

    /// Whether FAISS's `int_centroids` flag is set.
    pub fn int_centroids(&self) -> bool {
        unsafe { faiss_Clustering_int_centroids(self.inner) != 0 }
    }

    /// Whether FAISS's `update_index` flag is set.
    pub fn update_index(&self) -> bool {
        unsafe { faiss_Clustering_update_index(self.inner) != 0 }
    }

    /// Whether FAISS's `frozen_centroids` flag is set.
    pub fn frozen_centroids(&self) -> bool {
        unsafe { faiss_Clustering_frozen_centroids(self.inner) != 0 }
    }

    /// FAISS's `seed` setting, reinterpreted as `u32`.
    pub fn seed(&self) -> u32 {
        unsafe { faiss_Clustering_seed(self.inner) as u32 }
    }

    /// FAISS's `decode_block_size` setting.
    pub fn decode_block_size(&self) -> usize {
        unsafe { faiss_Clustering_decode_block_size(self.inner) }
    }

    /// FAISS's `min_points_per_centroid` setting.
    pub fn min_points_per_centroid(&self) -> u32 {
        unsafe { faiss_Clustering_min_points_per_centroid(self.inner) as u32 }
    }

    /// FAISS's `max_points_per_centroid` setting.
    pub fn max_points_per_centroid(&self) -> u32 {
        unsafe { faiss_Clustering_max_points_per_centroid(self.inner) as u32 }
    }
}
/// Output of [`kmeans_clustering`].
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct KMeansResult {
    /// Centroids laid out contiguously, `d` values per centroid (`d * k` in total).
    pub centroids: Vec<f32>,
    /// Quantization error written back by FAISS.
    pub q_error: f32,
}
/// Runs FAISS's one-shot k-means (`faiss_kmeans_clustering`) on `x`, a flat
/// buffer of `d`-dimensional vectors, producing `k` centroids.
///
/// The number of vectors is `x.len() / d`; trailing values that do not form a
/// whole vector are ignored. If `d` is 0, zero vectors are reported instead of
/// panicking on the division.
///
/// # Errors
///
/// Returns an error if `faiss_kmeans_clustering` reports a failure.
///
/// # Panics
///
/// Panics if `d * k` does not fit in `usize`. Only 32-bit targets can hit this.
pub fn kmeans_clustering(d: u32, k: u32, x: &[f32]) -> Result<KMeansResult> {
    let d = d as usize;
    let k = k as usize;
    let n = x.len().checked_div(d).unwrap_or(0);
    // Compute the buffer size in `usize` with an explicit check. In `u32`,
    // `d * k` could wrap and allocate a buffer smaller than FAISS writes into.
    let len = d
        .checked_mul(k)
        .expect("kmeans_clustering: d * k overflows usize");
    let mut centroids = vec![0_f32; len];
    let mut q_error: f32 = 0.;
    unsafe {
        faiss_try(faiss_kmeans_clustering(
            d,
            n,
            k,
            x.as_ptr(),
            centroids.as_mut_ptr(),
            &mut q_error,
        ))?;
    }
    Ok(KMeansResult { centroids, q_error })
}
#[cfg(test)]
mod tests {
    use super::{kmeans_clustering, Clustering, ClusteringParameters};
    use crate::index::index_factory;
    use crate::MetricType;

    /// Ten 8-dimensional vectors shared by the tests below.
    const SOME_DATA: &[f32] = &[
        7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
        0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., -7., 1., 4.,
        1., 2., 1., 3., -1., 120., 100., 100., 120., -100., 100., 100., 120., 0., 0., -12., 1.,
        1., 0., 6., -1., 0., 0., -0.25, 1., 16., 24., 0., -1., 100., 10., 100., 100., 10.,
        100., 50., 10., 20., 22., 4.5, -2., -100., 0., 0., 100.,
    ];

    #[test]
    fn test_clustering() {
        const D: u32 = 8;
        const K: u32 = 3;
        const NITER: u32 = 12;
        let mut params = ClusteringParameters::default();
        params.set_niter(NITER);
        params.set_min_points_per_centroid(1);
        params.set_max_points_per_centroid(10);

        let mut clustering = Clustering::new_with_params(D, K, &params).unwrap();
        assert_eq!(clustering.niter(), NITER);
        let mut index = index_factory(D, "Flat", MetricType::L2).unwrap();
        clustering.train(SOME_DATA, &mut index).unwrap();

        let centroids: Vec<_> = clustering.centroids().unwrap();
        assert_eq!(centroids.len(), K as usize);
        for c in centroids {
            assert_eq!(c.len(), D as usize);
        }
        let stats = clustering.iteration_stats();
        assert_eq!(stats.len(), NITER as usize);
    }

    #[test]
    fn test_simple_clustering() {
        const D: u32 = 8;
        const K: u32 = 2;
        let out = kmeans_clustering(D, K, SOME_DATA).unwrap();
        assert!(out.q_error > 0.);
        assert_eq!(out.centroids.len(), (D * K) as usize);
    }
}
#[cfg(feature = "gpu")]
pub mod gpu {
    #[cfg(test)]
    mod tests {
        use super::super::{Clustering, ClusteringParameters};
        use crate::gpu::StandardGpuResources;
        use crate::index::index_factory;
        use crate::MetricType;

        /// Trains a clustering against a GPU flat index and checks output shapes.
        #[test]
        fn test_clustering() {
            const D: u32 = 8;
            const K: u32 = 3;
            const NITER: u32 = 12;
            let mut params = ClusteringParameters::default();
            params.set_niter(NITER);
            params.set_min_points_per_centroid(1);
            params.set_max_points_per_centroid(10);
            let some_data = [
                7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0.,
                0., 0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., -7.,
                1., 4., 1., 2., 1., 3., -1., 120., 100., 100., 120., -100., 100., 100., 120., 0.,
                0., -12., 1., 1., 0., 6., -1., 0., 0., -0.25, 1., 16., 24., 0., -1., 100., 10.,
                100., 100., 10., 100., 50., 10., 20., 22., 4.5, -2., -100., 0., 0., 100.,
            ];
            let mut clustering = Clustering::new_with_params(D, K, &params).unwrap();
            let res = StandardGpuResources::new().unwrap();
            let mut index = index_factory(D, "Flat", MetricType::L2)
                .unwrap()
                .into_gpu(&res, 0)
                .unwrap();
            clustering.train(&some_data, &mut index).unwrap();
            let centroids: Vec<_> = clustering.centroids().unwrap();
            assert_eq!(centroids.len(), K as usize);
            for c in centroids {
                assert_eq!(c.len(), D as usize);
            }
            let stats = clustering.iteration_stats();
            assert_eq!(stats.len(), NITER as usize);
        }
    }
}