use crate::distance_type::DistanceType;
use crate::error::{check_cuvs, Result};
use std::fmt;
use std::io::{stderr, Write};
pub use ffi::codebook_gen;
/// Owning handle to a cuVS IVF-PQ index-parameters object (`cuvsIvfPqIndexParams_t`).
///
/// Create one with [`IndexParams::new`] and configure it with the consuming `set_*`
/// builder methods, each of which writes one field of the underlying C struct. The
/// object is released with `cuvsIvfPqIndexParamsDestroy` when the value is dropped.
/// The meaning of each field is defined by the cuVS C API (`cuvsIvfPqIndexParams`).
///
/// The inner handle is public for FFI interop. The `unsafe` blocks in this type assume
/// it was produced by [`IndexParams::new`] and has not been destroyed elsewhere.
pub struct IndexParams(pub ffi::cuvsIvfPqIndexParams_t);

impl IndexParams {
    /// Allocates a new IVF-PQ index-parameters object via `cuvsIvfPqIndexParamsCreate`.
    ///
    /// Field values start at whatever the cuVS library initializes them to.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`check_cuvs`] if `cuvsIvfPqIndexParamsCreate`
    /// reports a failure.
    pub fn new() -> Result<IndexParams> {
        let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfPqIndexParams_t>::uninit();
        // SAFETY: `params.as_mut_ptr()` is a valid out-pointer for the handle, and
        // `assume_init` is only reached after `check_cuvs` has confirmed the create call
        // succeeded (on success the C API is expected to have written the handle).
        unsafe {
            check_cuvs(ffi::cuvsIvfPqIndexParamsCreate(params.as_mut_ptr()))?;
            Ok(IndexParams(params.assume_init()))
        }
    }

    /// Sets the `n_lists` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_n_lists(self, n_lists: u32) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs); only a plain
        // field is written through it.
        unsafe {
            (*self.0).n_lists = n_lists;
        }
        self
    }

    /// Sets the distance `metric` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_metric(self, metric: DistanceType) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).metric = metric;
        }
        self
    }

    /// Sets the `metric_arg` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_metric_arg(self, metric_arg: f32) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).metric_arg = metric_arg;
        }
        self
    }

    /// Sets the `kmeans_n_iters` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).kmeans_n_iters = kmeans_n_iters;
        }
        self
    }

    /// Sets the `kmeans_trainset_fraction` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_kmeans_trainset_fraction(self, kmeans_trainset_fraction: f64) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).kmeans_trainset_fraction = kmeans_trainset_fraction;
        }
        self
    }

    /// Sets the `pq_bits` field of the underlying parameters.
    ///
    /// No validation happens here; any range checks are left to cuVS.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_pq_bits(self, pq_bits: u32) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).pq_bits = pq_bits;
        }
        self
    }

    /// Sets the `pq_dim` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_pq_dim(self, pq_dim: u32) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).pq_dim = pq_dim;
        }
        self
    }

    /// Sets the `codebook_kind` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_codebook_kind(self, codebook_kind: codebook_gen) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).codebook_kind = codebook_kind;
        }
        self
    }

    /// Sets the `force_random_rotation` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_force_random_rotation(self, force_random_rotation: bool) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).force_random_rotation = force_random_rotation;
        }
        self
    }

    /// Sets the `add_data_on_build` field of the underlying parameters.
    #[must_use = "this consumes `self` and returns the updated parameters; discarding the result destroys them"]
    pub fn set_add_data_on_build(self, add_data_on_build: bool) -> IndexParams {
        // SAFETY: `self.0` is a live handle (see the type-level docs).
        unsafe {
            (*self.0).add_data_on_build = add_data_on_build;
        }
        self
    }
}
impl fmt::Debug for IndexParams {
    /// Renders as `IndexParams(<pointee>)`, formatting the pointed-to C parameter struct
    /// with its own `Debug` impl rather than printing the raw handle address.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // SAFETY: assumes `self.0` is a live handle from `IndexParams::new`; the pointee
        // is only borrowed for the duration of this call.
        let inner = unsafe { &*self.0 };
        write!(f, "IndexParams({:?})", inner)
    }
}
impl Drop for IndexParams {
    /// Releases the cuVS parameters object via `cuvsIvfPqIndexParamsDestroy`.
    ///
    /// `drop` cannot return an error, so a failed destroy is reported on stderr on a
    /// best-effort basis. If writing that report itself fails, the failure is ignored
    /// instead of panicking. A panic inside `drop` while the thread is already unwinding
    /// aborts the process.
    fn drop(&mut self) {
        // SAFETY: assumes `self.0` is the handle obtained from `IndexParams::new`; `drop`
        // runs once per value, so this wrapper destroys it once.
        if let Err(e) = check_cuvs(unsafe { ffi::cuvsIvfPqIndexParamsDestroy(self.0) }) {
            // `writeln!` terminates the line so the report does not run into later output.
            let _ = writeln!(
                stderr(),
                "failed to call cuvsIvfPqIndexParamsDestroy {:?}",
                e
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that chained setters write through to the underlying C struct.
    #[test]
    fn test_index_params() {
        let params = IndexParams::new()
            .unwrap()
            .set_n_lists(128)
            .set_add_data_on_build(false);
        // SAFETY: `params.0` was created by `IndexParams::new` and is still alive.
        unsafe {
            assert_eq!((*params.0).n_lists, 128);
            assert!(!(*params.0).add_data_on_build);
        }
    }

    /// Covers the remaining plain-value setters. The float values are exactly
    /// representable, so exact comparison is sound.
    #[test]
    fn test_index_params_remaining_setters() {
        let params = IndexParams::new()
            .unwrap()
            .set_kmeans_n_iters(10)
            .set_kmeans_trainset_fraction(0.25)
            .set_metric_arg(2.0)
            .set_pq_bits(8)
            .set_pq_dim(16)
            .set_force_random_rotation(true);
        // SAFETY: `params.0` was created by `IndexParams::new` and is still alive.
        unsafe {
            assert_eq!((*params.0).kmeans_n_iters, 10);
            assert_eq!((*params.0).kmeans_trainset_fraction, 0.25);
            assert_eq!((*params.0).metric_arg, 2.0);
            assert_eq!((*params.0).pq_bits, 8);
            assert_eq!((*params.0).pq_dim, 16);
            assert!((*params.0).force_random_rotation);
        }
    }
}