use std::io::{stderr, Write};
use crate::ivf_flat::{IndexParams, SearchParams};
use crate::dlpack::ManagedTensor;
use crate::error::{check_cuvs, Result};
use crate::resources::Resources;
/// An IVF-Flat approximate nearest-neighbor index.
///
/// Owns an opaque `cuvsIvfFlatIndex_t` handle. The handle is allocated by
/// [`Index::new`] and released by this type's `Drop` implementation. The field
/// is private, so every `Index` owns exactly one handle.
#[derive(Debug)]
pub struct Index(ffi::cuvsIvfFlatIndex_t);

impl Index {
    /// Builds an IVF-Flat index over `dataset`.
    ///
    /// # Arguments
    ///
    /// * `res` - cuVS resources used to run the build.
    /// * `params` - index build parameters.
    /// * `dataset` - the vectors to index; anything convertible into a [`ManagedTensor`].
    ///
    /// NOTE(review): `dataset` is dropped when this function returns. This assumes
    /// `cuvsIvfFlatBuild` copies the vectors into the index rather than referencing
    /// them. That is presumably true for IVF-Flat; confirm against the cuVS docs.
    ///
    /// # Errors
    ///
    /// Returns an error if allocating the index handle or the build call fails.
    pub fn build<T: Into<ManagedTensor>>(
        res: &Resources,
        params: &IndexParams,
        dataset: T,
    ) -> Result<Index> {
        let dataset: ManagedTensor = dataset.into();
        let index = Index::new()?;
        // SAFETY: `res` and `params` are borrowed for the duration of the call.
        // `dataset` is a live local until the end of this function.
        // `index.0` was produced by a successful `cuvsIvfFlatIndexCreate`.
        check_cuvs(unsafe {
            ffi::cuvsIvfFlatBuild(res.0, params.0, dataset.as_ptr(), index.0)
        })?;
        Ok(index)
    }

    /// Allocates an empty, not-yet-built IVF-Flat index handle.
    ///
    /// Most callers want [`Index::build`], which calls this and then builds the index.
    ///
    /// # Errors
    ///
    /// Returns an error if `cuvsIvfFlatIndexCreate` fails.
    pub fn new() -> Result<Index> {
        let mut index = std::mem::MaybeUninit::<ffi::cuvsIvfFlatIndex_t>::uninit();
        // SAFETY: `as_mut_ptr` yields a valid, writable out-pointer for the create call.
        check_cuvs(unsafe { ffi::cuvsIvfFlatIndexCreate(index.as_mut_ptr()) })?;
        // SAFETY: we only reach this line when `check_cuvs` reported success. This
        // relies on `cuvsIvfFlatIndexCreate` having written the handle in that case.
        Ok(Index(unsafe { index.assume_init() }))
    }

    /// Searches the index for the nearest neighbors of each row in `queries`.
    ///
    /// Borrows the index (`&self`), so a single built index can serve any number of
    /// searches. Previously the index was consumed and destroyed after one call.
    ///
    /// # Arguments
    ///
    /// * `res` - cuVS resources used to run the search.
    /// * `params` - search parameters.
    /// * `queries` - the query vectors.
    /// * `neighbors` - output tensor that receives the neighbor ids.
    /// * `distances` - output tensor that receives the neighbor distances.
    ///
    /// # Errors
    ///
    /// Returns an error if `cuvsIvfFlatSearch` reports a failure.
    pub fn search(
        &self,
        res: &Resources,
        params: &SearchParams,
        queries: &ManagedTensor,
        neighbors: &ManagedTensor,
        distances: &ManagedTensor,
    ) -> Result<()> {
        // SAFETY: every handle and tensor passed here is borrowed for the duration of
        // the call. `self.0` is the live handle owned by this `Index`.
        check_cuvs(unsafe {
            ffi::cuvsIvfFlatSearch(
                res.0,
                params.0,
                self.0,
                queries.as_ptr(),
                neighbors.as_ptr(),
                distances.as_ptr(),
            )
        })
    }
}
impl Drop for Index {
    /// Releases the underlying cuVS index handle.
    ///
    /// `drop` cannot return an error, so a failure from `cuvsIvfFlatIndexDestroy`
    /// is reported on stderr instead. That report is best-effort: panicking inside
    /// `drop` would abort the process if it happened during unwinding.
    fn drop(&mut self) {
        // SAFETY: `self.0` is the handle owned by this `Index`. The field is private
        // and `Index` is not `Clone`, and `drop` runs at most once, so the handle is
        // destroyed exactly once.
        let status = unsafe { ffi::cuvsIvfFlatIndexDestroy(self.0) };
        if let Err(e) = check_cuvs(status) {
            // Deliberately ignore a failed stderr write rather than panicking in `drop`.
            let _ = writeln!(stderr(), "failed to call cuvsIvfFlatIndexDestroy {:?}", e);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::s;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// End-to-end check: index random data, query with the first few dataset rows,
    /// and expect each query's closest match to be the row it was copied from.
    #[test]
    fn test_ivf_flat() {
        // Local declaration order mirrors the original so drop order is unchanged.
        let build_params = IndexParams::new().unwrap().set_n_lists(64);
        let res = Resources::new().unwrap();

        let n_rows = 1024;
        let n_cols = 16;
        let host_data = ndarray::Array::<f32, _>::random((n_rows, n_cols), Uniform::new(0., 1.0));
        let device_data = ManagedTensor::from(&host_data).to_device(&res).unwrap();

        let index =
            Index::build(&res, &build_params, device_data).expect("failed to create ivf-flat index");

        let n_queries = 4;
        let query_rows = host_data.slice(s![0..n_queries, ..]);
        let k = 10;
        let query_device = ManagedTensor::from(&query_rows).to_device(&res).unwrap();

        let mut nbr_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
        let nbr_device = ManagedTensor::from(&nbr_host).to_device(&res).unwrap();
        let mut dist_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
        let dist_device = ManagedTensor::from(&dist_host).to_device(&res).unwrap();

        let search_params = SearchParams::new().unwrap();
        index
            .search(&res, &search_params, &query_device, &nbr_device, &dist_device)
            .unwrap();

        dist_device.to_host(&res, &mut dist_host).unwrap();
        nbr_device.to_host(&res, &mut nbr_host).unwrap();

        // Query `row` is dataset row `row`, so its best match must be itself.
        for row in 0..n_queries {
            assert_eq!(nbr_host[[row, 0]], row as i64);
        }
    }
}