use std::ffi::CString;
use std::io::{stderr, Write};
use crate::dlpack::ManagedTensor;
use crate::error::{check_cuvs, Result};
use crate::resources::Resources;
use crate::vamana::IndexParams;
/// A cuVS Vamana index.
///
/// Owns a `cuvsVamanaIndex_t` handle, which is released through
/// `cuvsVamanaIndexDestroy` when the value is dropped.
#[derive(Debug)]
pub struct Index(ffi::cuvsVamanaIndex_t);

impl Index {
    /// Builds a Vamana index over `dataset`.
    ///
    /// # Arguments
    ///
    /// * `res` - cuVS resources used for the build
    /// * `params` - Vamana build parameters
    /// * `dataset` - the vectors to index; anything convertible into a [`ManagedTensor`]
    ///
    /// # Errors
    ///
    /// Returns an error if creating the empty index handle or `cuvsVamanaBuild` fails.
    /// If the build step fails, the handle created by [`Index::new`] is released when
    /// it is dropped on the early return.
    pub fn build<T: Into<ManagedTensor>>(
        res: &Resources,
        params: &IndexParams,
        dataset: T,
    ) -> Result<Index> {
        let dataset: ManagedTensor = dataset.into();
        let index = Index::new()?;
        // NOTE(review): `dataset` is dropped when this function returns; confirm that
        // `cuvsVamanaBuild` does not retain a reference to the tensor it is given.
        //
        // SAFETY: `res` and `params` are borrowed for the whole call (their handles are
        // assumed valid while the wrappers are alive), `dataset` outlives the call, and
        // `index.0` was just produced by a successful `cuvsVamanaIndexCreate`.
        let status = unsafe {
            ffi::cuvsVamanaBuild(res.0, params.0, dataset.as_ptr(), index.0)
        };
        check_cuvs(status)?;
        Ok(index)
    }

    /// Creates an empty Vamana index handle.
    ///
    /// # Errors
    ///
    /// Returns an error if `cuvsVamanaIndexCreate` fails.
    pub fn new() -> Result<Index> {
        let mut index = std::mem::MaybeUninit::<ffi::cuvsVamanaIndex_t>::uninit();
        // SAFETY: `index.as_mut_ptr()` points to writable storage for exactly one handle,
        // which `cuvsVamanaIndexCreate` writes into.
        check_cuvs(unsafe { ffi::cuvsVamanaIndexCreate(index.as_mut_ptr()) })?;
        // SAFETY: the create call reported success, so the handle is assumed to have been
        // written; on failure we returned above without reading it.
        Ok(Index(unsafe { index.assume_init() }))
    }

    /// Serializes the index to `filename` via `cuvsVamanaSerialize`.
    ///
    /// The index is taken by value, so it is destroyed (through `Drop`) as soon as this
    /// call returns, whether or not serialization succeeded.
    ///
    /// # Arguments
    ///
    /// * `res` - cuVS resources used for serialization
    /// * `filename` - destination path handed to the C API
    /// * `include_dataset` - forwarded unchanged to `cuvsVamanaSerialize`
    ///
    /// # Errors
    ///
    /// Returns an error if `cuvsVamanaSerialize` fails.
    ///
    /// # Panics
    ///
    /// Panics if `filename` contains an interior NUL byte, because such a string cannot
    /// be represented as a C string.
    pub fn serialize(self, res: &Resources, filename: &str, include_dataset: bool) -> Result<()> {
        let c_filename =
            CString::new(filename).expect("vamana index filename must not contain NUL bytes");
        // SAFETY: `c_filename` is a valid NUL-terminated string that outlives the call,
        // `self.0` is a live index handle owned by `self`, and `res` is borrowed for the
        // duration of the call.
        let status = unsafe {
            ffi::cuvsVamanaSerialize(res.0, c_filename.as_ptr(), self.0, include_dataset)
        };
        check_cuvs(status)
    }
}
impl Drop for Index {
    /// Releases the underlying Vamana index handle via `cuvsVamanaIndexDestroy`.
    ///
    /// `Drop` cannot return errors, so a failure reported by the destroy call is only
    /// logged to stderr, on a best-effort basis.
    fn drop(&mut self) {
        // SAFETY: `self.0` is the handle this `Index` owns; `drop` runs at most once per
        // value, so the handle is destroyed exactly once.
        let status = unsafe { ffi::cuvsVamanaIndexDestroy(self.0) };
        if let Err(e) = check_cuvs(status) {
            // Panicking inside `drop` aborts the process if it happens during unwinding,
            // so a failure to write the diagnostic is deliberately ignored.
            let _ = writeln!(stderr(), "failed to call cuvsVamanaIndexDestroy {:?}", e);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Smoke test: building a Vamana index over a small random `f32` dataset
    /// that has been copied to the device must succeed.
    #[test]
    fn test_vamana() {
        const ROWS: usize = 1024;
        const COLS: usize = 16;

        let build_params = IndexParams::new().unwrap();
        let res = Resources::new().unwrap();

        // Random host-side dataset, then a device copy of it for the build.
        let host_data = ndarray::Array::<f32, _>::random((ROWS, COLS), Uniform::new(0., 1.0));
        let device_data = ManagedTensor::from(&host_data).to_device(&res).unwrap();

        let built = Index::build(&res, &build_params, device_data);
        let _index = built.expect("failed to create vamana index");
    }
}