faiss-next 0.6.0

Rust bindings for Faiss (Facebook AI Similarity Search)
Documentation
use std::ptr;

use faiss_next_sys::{self, FaissClustering, FaissClusteringParameters};

use crate::error::{check_return_code, Result};
use crate::index::Index;

/// Handle to a native Faiss clustering object.
///
/// The wrapped pointer is filled in by the `faiss_Clustering_new*` calls in
/// the constructors and is passed to `faiss_Clustering_free` when the value
/// is dropped.
pub struct Clustering {
    /// Raw pointer to the native `FaissClustering` instance.
    inner: *mut FaissClustering,
}

impl Clustering {
    /// Creates a clustering object by calling `faiss_Clustering_new` with
    /// dimension `d` and `k` clusters.
    ///
    /// # Errors
    ///
    /// Returns whatever error `check_return_code` produces for a failing
    /// native return code.
    pub fn new(d: u32, k: u32) -> Result<Self> {
        let mut inner = ptr::null_mut();
        // NOTE(review): `d` and `k` are narrowed with `as i32`; values above
        // `i32::MAX` wrap to negative numbers before reaching the C API.
        // SAFETY: `inner` is a live local, so `&mut inner` is a valid
        // out-pointer for the duration of the call.
        let code = unsafe { faiss_next_sys::faiss_Clustering_new(&mut inner, d as i32, k as i32) };
        check_return_code(code)?;
        Ok(Self { inner })
    }

    /// Creates a clustering object by calling
    /// `faiss_Clustering_new_with_params` with dimension `d`, `k` clusters
    /// and the settings held by `params`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `check_return_code` produces for a failing
    /// native return code.
    pub fn new_with_params(d: u32, k: u32, params: &ClusteringParameters) -> Result<Self> {
        let mut inner = ptr::null_mut();
        // NOTE(review): same wrapping `as i32` narrowing as in `new`.
        // SAFETY: `inner` is a valid out-pointer; `params.inner` is borrowed
        // from a `ClusteringParameters` that outlives this call.
        let code = unsafe {
            faiss_next_sys::faiss_Clustering_new_with_params(
                &mut inner,
                d as i32,
                k as i32,
                params.inner,
            )
        };
        check_return_code(code)?;
        Ok(Self { inner })
    }

    /// Runs `faiss_Clustering_train` on `n` vectors stored contiguously in
    /// `x`, using `index` as the index passed to the native call.
    ///
    /// `x` must contain at least `n * self.d()` floats.
    ///
    /// # Panics
    ///
    /// Panics if `x` is shorter than `n * self.d()` elements. Without this
    /// check the native code would read past the end of the slice.
    ///
    /// # Errors
    ///
    /// Returns whatever error `check_return_code` produces for a failing
    /// native return code.
    pub fn train(&mut self, n: u64, x: &[f32], index: &mut impl Index) -> Result<()> {
        let dim = self.d();
        // Computed in u128 so the product of a u64 and a usize cannot overflow.
        let required = u128::from(n) * dim as u128;
        assert!(
            required <= x.len() as u128,
            "Clustering::train: `x` holds {} floats but n * d = {} * {} are required",
            x.len(),
            n,
            dim
        );

        // SAFETY: the assertion above guarantees `x` covers the `n * d`
        // floats handed to the native call; `self.inner` is the pointer
        // stored by the constructors.
        let code = unsafe {
            faiss_next_sys::faiss_Clustering_train(
                self.inner,
                n as i64,
                x.as_ptr(),
                index.inner_ptr(),
            )
        };
        check_return_code(code)
    }

    /// Returns the value reported by `faiss_Clustering_niter`.
    pub fn niter(&self) -> i32 {
        // SAFETY: `self.inner` is the pointer stored by the constructors.
        unsafe { faiss_next_sys::faiss_Clustering_niter(self.inner) }
    }

    /// Returns the value reported by `faiss_Clustering_k`.
    pub fn k(&self) -> usize {
        // SAFETY: `self.inner` is the pointer stored by the constructors.
        unsafe { faiss_next_sys::faiss_Clustering_k(self.inner) }
    }

    /// Returns the value reported by `faiss_Clustering_d`.
    pub fn d(&self) -> usize {
        // SAFETY: `self.inner` is the pointer stored by the constructors.
        unsafe { faiss_next_sys::faiss_Clustering_d(self.inner) }
    }

    /// Returns a copy of the buffer reported by `faiss_Clustering_centroids`.
    ///
    /// The result is empty when the native call reports a null pointer or a
    /// size of zero.
    pub fn centroids(&self) -> Vec<f32> {
        let mut data = ptr::null_mut();
        let mut len = 0usize;
        // SAFETY: both out-pointers refer to live locals. The slice is only
        // built after the null/empty check and is copied before returning.
        // NOTE(review): assumes the reported size counts `f32` elements and
        // that the buffer stays valid until the copy completes - confirm.
        unsafe {
            faiss_next_sys::faiss_Clustering_centroids(self.inner, &mut data, &mut len);
            if data.is_null() || len == 0 {
                return Vec::new();
            }
            std::slice::from_raw_parts(data, len).to_vec()
        }
    }

    /// Returns `true` when `faiss_Clustering_verbose` reports a non-zero value.
    pub fn verbose(&self) -> bool {
        // SAFETY: `self.inner` is the pointer stored by the constructors.
        let flag = unsafe { faiss_next_sys::faiss_Clustering_verbose(self.inner) };
        flag != 0
    }

    /// Returns the value reported by `faiss_Clustering_seed`.
    pub fn seed(&self) -> i32 {
        // SAFETY: `self.inner` is the pointer stored by the constructors.
        unsafe { faiss_next_sys::faiss_Clustering_seed(self.inner) }
    }
}

impl Drop for Clustering {
    /// Hands the native object to `faiss_Clustering_free` unless the stored
    /// pointer is null.
    fn drop(&mut self) {
        if self.inner.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null and this is the only place in the
        // type that frees it.
        unsafe {
            faiss_next_sys::faiss_Clustering_free(self.inner);
        }
    }
}

/// Settings passed to `Clustering::new_with_params`.
///
/// The native `FaissClusteringParameters` value lives in a `Box` whose
/// ownership is held as a raw pointer: `new` creates it with
/// `Box::into_raw` and `Drop` reclaims it with `Box::from_raw`.
pub struct ClusteringParameters {
    /// Raw pointer to the boxed `FaissClusteringParameters`.
    inner: *mut FaissClusteringParameters,
}

impl ClusteringParameters {
    /// Allocates a parameter block and lets
    /// `faiss_ClusteringParameters_init` fill it in.
    ///
    /// This implementation always returns `Ok`.
    pub fn new() -> Result<Self> {
        // SAFETY: assumes `FaissClusteringParameters` is a plain C struct for
        // which the all-zero bit pattern is valid - TODO confirm against the
        // generated bindings.
        let mut boxed = Box::new(unsafe { std::mem::zeroed::<FaissClusteringParameters>() });
        // SAFETY: the pointer refers to the live boxed value above.
        unsafe {
            faiss_next_sys::faiss_ClusteringParameters_init(&mut *boxed as *mut _);
        }
        Ok(Self {
            inner: Box::into_raw(boxed),
        })
    }

    /// Sets the `niter` field and returns `self` for chaining.
    pub fn niter(&mut self, niter: i32) -> &mut Self {
        // SAFETY: `self.inner` comes from `Box::into_raw` in `new`.
        let raw = unsafe { &mut *self.inner };
        raw.niter = niter;
        self
    }

    /// Sets the `verbose` field (stored as `0`/`1`) and returns `self`.
    pub fn verbose(&mut self, verbose: bool) -> &mut Self {
        // SAFETY: `self.inner` comes from `Box::into_raw` in `new`.
        let raw = unsafe { &mut *self.inner };
        raw.verbose = verbose as i32;
        self
    }

    /// Sets the `spherical` field (stored as `0`/`1`) and returns `self`.
    pub fn spherical(&mut self, spherical: bool) -> &mut Self {
        // SAFETY: `self.inner` comes from `Box::into_raw` in `new`.
        let raw = unsafe { &mut *self.inner };
        raw.spherical = spherical as i32;
        self
    }

    /// Sets the `min_points_per_centroid` field and returns `self`.
    pub fn min_points_per_centroid(&mut self, n: i32) -> &mut Self {
        // SAFETY: `self.inner` comes from `Box::into_raw` in `new`.
        let raw = unsafe { &mut *self.inner };
        raw.min_points_per_centroid = n;
        self
    }

    /// Sets the `max_points_per_centroid` field and returns `self`.
    pub fn max_points_per_centroid(&mut self, n: i32) -> &mut Self {
        // SAFETY: `self.inner` comes from `Box::into_raw` in `new`.
        let raw = unsafe { &mut *self.inner };
        raw.max_points_per_centroid = n;
        self
    }

    /// Sets the `seed` field and returns `self` for chaining.
    pub fn seed(&mut self, seed: i32) -> &mut Self {
        // SAFETY: `self.inner` comes from `Box::into_raw` in `new`.
        let raw = unsafe { &mut *self.inner };
        raw.seed = seed;
        self
    }

    /// Sets the `nredo` field and returns `self` for chaining.
    pub fn nredo(&mut self, nredo: i32) -> &mut Self {
        // SAFETY: `self.inner` comes from `Box::into_raw` in `new`.
        let raw = unsafe { &mut *self.inner };
        raw.nredo = nredo;
        self
    }
}

impl Default for ClusteringParameters {
    /// Builds the parameters through [`ClusteringParameters::new`].
    ///
    /// # Panics
    ///
    /// Panics if `new` returns an error.
    fn default() -> Self {
        let created = Self::new();
        created.expect("failed to create ClusteringParameters")
    }
}

impl Drop for ClusteringParameters {
    /// Reclaims the boxed parameter block unless the stored pointer is null.
    fn drop(&mut self) {
        if self.inner.is_null() {
            return;
        }
        // SAFETY: the pointer was produced by `Box::into_raw` in `new`, so
        // rebuilding the `Box` here returns the allocation to Rust exactly once.
        unsafe {
            drop(Box::from_raw(self.inner));
        }
    }
}