use crate::error::{check_cuvs, Result};
use crate::distance_type::DistanceType;
use std::fmt;
use std::io::{stderr, Write};
/// Owned handle to a cuVS IVF-Flat index-parameter object.
///
/// The handle is allocated by [`IndexParams::new`] (via `cuvsIvfFlatIndexParamsCreate`)
/// and released by the `Drop` impl (via `cuvsIvfFlatIndexParamsDestroy`). The setters
/// write directly into the pointed-to C struct.
///
/// NOTE(review): the inner pointer is public so it can be handed to the raw `ffi` API;
/// replacing it by hand would bypass `new`/`Drop` ownership — confirm no caller does so.
pub struct IndexParams(pub ffi::cuvsIvfFlatIndexParams_t);
impl IndexParams {
    /// Creates a new set of IVF-Flat index parameters initialised by
    /// `cuvsIvfFlatIndexParamsCreate`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `check_cuvs` if `cuvsIvfFlatIndexParamsCreate`
    /// fails; in that case no handle is wrapped and nothing needs to be destroyed.
    pub fn new() -> Result<IndexParams> {
        let mut params = std::mem::MaybeUninit::<ffi::cuvsIvfFlatIndexParams_t>::uninit();
        // SAFETY: `params.as_mut_ptr()` is a valid, writable out-pointer for the C call.
        // `assume_init` is only reached after `check_cuvs` reported success (the `?`
        // returns early otherwise); this relies on the C API having written the handle
        // whenever it reports success.
        unsafe {
            check_cuvs(ffi::cuvsIvfFlatIndexParamsCreate(params.as_mut_ptr()))?;
            Ok(IndexParams(params.assume_init()))
        }
    }

    /// Sets `n_lists`, the number of inverted lists (clusters) of the index.
    ///
    /// Consumes and returns `self` so calls can be chained; the result must be used,
    /// otherwise the parameters are dropped and the change is lost.
    #[must_use]
    pub fn set_n_lists(self, n_lists: u32) -> IndexParams {
        // SAFETY: assumes `self.0` is the live handle created by `IndexParams::new`.
        unsafe { (*self.0).n_lists = n_lists };
        self
    }

    /// Sets `metric`, the distance type used by the index.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn set_metric(self, metric: DistanceType) -> IndexParams {
        // SAFETY: assumes `self.0` is the live handle created by `IndexParams::new`.
        unsafe { (*self.0).metric = metric };
        self
    }

    /// Sets `metric_arg`, an extra argument consumed by some distance metrics.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn set_metric_arg(self, metric_arg: f32) -> IndexParams {
        // SAFETY: assumes `self.0` is the live handle created by `IndexParams::new`.
        unsafe { (*self.0).metric_arg = metric_arg };
        self
    }

    /// Sets `kmeans_n_iters`, the number of k-means iterations used while building.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn set_kmeans_n_iters(self, kmeans_n_iters: u32) -> IndexParams {
        // SAFETY: assumes `self.0` is the live handle created by `IndexParams::new`.
        unsafe { (*self.0).kmeans_n_iters = kmeans_n_iters };
        self
    }

    /// Sets `kmeans_trainset_fraction`, the fraction of the dataset used to train
    /// k-means.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn set_kmeans_trainset_fraction(self, kmeans_trainset_fraction: f64) -> IndexParams {
        // SAFETY: assumes `self.0` is the live handle created by `IndexParams::new`.
        unsafe { (*self.0).kmeans_trainset_fraction = kmeans_trainset_fraction };
        self
    }

    /// Sets `add_data_on_build`, whether the dataset is added to the index during
    /// build.
    ///
    /// Consumes and returns `self` so calls can be chained.
    #[must_use]
    pub fn set_add_data_on_build(self, add_data_on_build: bool) -> IndexParams {
        // SAFETY: assumes `self.0` is the live handle created by `IndexParams::new`.
        unsafe { (*self.0).add_data_on_build = add_data_on_build };
        self
    }
}
impl fmt::Debug for IndexParams {
    /// Renders the wrapper as `IndexParams(<Debug output of the C parameter struct>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // SAFETY: assumes `self.0` is the live handle created by `IndexParams::new`;
        // the pointee is only borrowed for the duration of this call.
        let inner = unsafe { &*self.0 };
        write!(f, "IndexParams({:?})", inner)
    }
}
impl Drop for IndexParams {
    /// Releases the C parameter object via `cuvsIvfFlatIndexParamsDestroy`.
    ///
    /// `drop` cannot return an error, so a failure is reported on stderr instead of
    /// being propagated.
    fn drop(&mut self) {
        // SAFETY: assumes `self.0` is the handle created by `IndexParams::new`; it is
        // destroyed exactly once, here.
        let status = unsafe { ffi::cuvsIvfFlatIndexParamsDestroy(self.0) };
        if let Err(e) = check_cuvs(status) {
            // Best effort only: panicking inside `drop` (e.g. via `expect`) aborts the
            // process if we are already unwinding, so a failed stderr write is ignored.
            let _ = writeln!(
                stderr(),
                "failed to call cuvsIvfFlatIndexParamsDestroy {:?}",
                e
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that each chained setter writes its value into the underlying C struct.
    #[test]
    fn test_index_params() {
        let params = IndexParams::new()
            .unwrap()
            .set_n_lists(128)
            .set_metric_arg(3.0)
            .set_kmeans_n_iters(10)
            .set_kmeans_trainset_fraction(0.25)
            .set_add_data_on_build(false);
        // SAFETY: `params` owns the live handle created by `IndexParams::new` above.
        unsafe {
            assert_eq!((*params.0).n_lists, 128);
            // 3.0 and 0.25 are exactly representable, so exact comparison is sound.
            assert_eq!((*params.0).metric_arg, 3.0);
            assert_eq!((*params.0).kmeans_n_iters, 10);
            assert_eq!((*params.0).kmeans_trainset_fraction, 0.25);
            assert!(!(*params.0).add_data_on_build);
        }
    }
}