use crate::distance_type::DistanceType;
use crate::error::{check_cuvs, Result};
use std::fmt;
use std::io::{stderr, Write};
/// Parameters for building a Vamana index.
///
/// Owns a `cuvsVamanaIndexParams_t` handle: it is created by `cuvsVamanaIndexParamsCreate`
/// in [`IndexParams::new`] and released by `cuvsVamanaIndexParamsDestroy` when the wrapper
/// is dropped. The handle is exposed as a public field so it can be passed to the C API;
/// NOTE(review): because the field is `pub`, nothing prevents constructing this wrapper
/// around a handle that did not come from `new` — callers are trusted not to do so.
pub struct IndexParams(pub ffi::cuvsVamanaIndexParams_t);
impl IndexParams {
    /// Allocates a new set of Vamana index-build parameters through
    /// `cuvsVamanaIndexParamsCreate`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check_cuvs` when the underlying create call fails.
    pub fn new() -> Result<IndexParams> {
        let mut params = std::mem::MaybeUninit::<ffi::cuvsVamanaIndexParams_t>::uninit();
        // SAFETY: `params.as_mut_ptr()` points to writable storage for one handle, and the
        // handle is only read (`assume_init`) after the create call has reported success.
        unsafe {
            check_cuvs(ffi::cuvsVamanaIndexParamsCreate(params.as_mut_ptr()))?;
            Ok(IndexParams(params.assume_init()))
        }
    }

    // All setters below consume `self` and hand it back so they can be chained. They are
    // `#[must_use]` because discarding the return value drops the params, which destroys
    // the underlying handle.
    //
    // SAFETY (shared by every setter): each one writes a single field through `self.0`,
    // assuming it is a live handle obtained from `new` (not enforced, since the field is
    // public). The handle stays valid until `self` is dropped.

    /// Sets the `metric` field (the distance type used when building the index).
    #[must_use]
    pub fn set_metric(self, metric: DistanceType) -> IndexParams {
        unsafe {
            (*self.0).metric = metric;
        }
        self
    }

    /// Sets the `graph_degree` field of the underlying parameters.
    #[must_use]
    pub fn set_graph_degree(self, graph_degree: u32) -> IndexParams {
        unsafe {
            (*self.0).graph_degree = graph_degree;
        }
        self
    }

    /// Sets the `visited_size` field of the underlying parameters.
    #[must_use]
    pub fn set_visited_size(self, visited_size: u32) -> IndexParams {
        unsafe {
            (*self.0).visited_size = visited_size;
        }
        self
    }

    /// Sets the `vamana_iters` field of the underlying parameters.
    #[must_use]
    pub fn set_vamana_iters(self, vamana_iters: f32) -> IndexParams {
        unsafe {
            (*self.0).vamana_iters = vamana_iters;
        }
        self
    }

    /// Sets the `alpha` field of the underlying parameters.
    #[must_use]
    pub fn set_alpha(self, alpha: f32) -> IndexParams {
        unsafe {
            (*self.0).alpha = alpha;
        }
        self
    }

    /// Sets the `max_fraction` field of the underlying parameters.
    #[must_use]
    pub fn set_max_fraction(self, max_fraction: f32) -> IndexParams {
        unsafe {
            (*self.0).max_fraction = max_fraction;
        }
        self
    }

    /// Sets the `batch_base` field of the underlying parameters.
    #[must_use]
    pub fn set_batch_base(self, batch_base: f32) -> IndexParams {
        unsafe {
            (*self.0).batch_base = batch_base;
        }
        self
    }

    /// Sets the `queue_size` field of the underlying parameters.
    #[must_use]
    pub fn set_queue_size(self, queue_size: u32) -> IndexParams {
        unsafe {
            (*self.0).queue_size = queue_size;
        }
        self
    }

    /// Sets the `reverse_batchsize` field of the underlying parameters.
    #[must_use]
    pub fn set_reverse_batchsize(self, reverse_batchsize: u32) -> IndexParams {
        unsafe {
            (*self.0).reverse_batchsize = reverse_batchsize;
        }
        self
    }
}
impl fmt::Debug for IndexParams {
    // Renders as `IndexParams(<raw params>)`, using the FFI struct's own `Debug` output
    // for a copy of the pointed-to parameters.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // SAFETY: assumes `self.0` is a live handle from `IndexParams::new`; the pointee
        // is copied out rather than borrowed.
        let raw = unsafe { *self.0 };
        write!(f, "IndexParams({:?})", raw)
    }
}
impl Drop for IndexParams {
    /// Releases the handle with `cuvsVamanaIndexParamsDestroy`.
    ///
    /// A failure is reported on stderr; `drop` cannot return an error.
    fn drop(&mut self) {
        // SAFETY: assumes `self.0` is a live handle from `IndexParams::new`; it is
        // destroyed exactly once, here.
        if let Err(e) = check_cuvs(unsafe { ffi::cuvsVamanaIndexParamsDestroy(self.0) }) {
            // Best effort: if stderr itself is unwritable there is nothing left to do.
            // Panicking here (e.g. via `expect`) would abort the process if this drop runs
            // during unwinding. `writeln!` terminates the message so it does not run into
            // subsequent stderr output.
            let _ = writeln!(
                stderr(),
                "failed to call cuvsVamanaIndexParamsDestroy {:?}",
                e
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Chained setters should write their values through to the underlying params struct.
    #[test]
    fn test_index_params() {
        let created = IndexParams::new().unwrap();
        let params = created.set_alpha(1.0).set_visited_size(128);

        // SAFETY: `params.0` is the live handle created by `IndexParams::new` above.
        let (alpha, visited_size) = unsafe { ((*params.0).alpha, (*params.0).visited_size) };

        assert_eq!(alpha, 1.0);
        assert_eq!(visited_size, 128);
    }
}