use super::*;
use crate::error::Result;
use crate::faiss_try;
use std::mem;
use std::os::raw::{c_char, c_int};
use std::ptr;
/// Alias for [`IVFFlatIndexImpl`], the native IVF flat index wrapper defined in this module.
pub type IVFFlatIndex = IVFFlatIndexImpl;
/// Native IVF flat index.
///
/// Wraps a raw pointer to a `FaissIndexIVFFlat`; the pointee is released with
/// `faiss_IndexIVFFlat_free` when this value is dropped.
#[derive(Debug)]
pub struct IVFFlatIndexImpl {
// Raw handle to the native index object; freed in `Drop`.
inner: *mut FaissIndexIVFFlat,
}
// The struct only holds a raw pointer, so `Send`/`Sync` are not derived automatically and
// are asserted by hand here.
// NOTE(review): soundness relies on the native Faiss index being safe to move across and
// share between threads for the operations exposed through `&self` — confirm against the
// Faiss threading guarantees.
unsafe impl Send for IVFFlatIndexImpl {}
unsafe impl Sync for IVFFlatIndexImpl {}
// Tags the type as a `CpuIndex`; the impl body is empty, so no behavior is added here.
impl CpuIndex for IVFFlatIndexImpl {}
impl Drop for IVFFlatIndexImpl {
    /// Releases the native index referenced by `inner`.
    fn drop(&mut self) {
        let raw = self.inner;
        // SAFETY (assumed): `raw` is the pointer this wrapper was built around and has not
        // been freed elsewhere; owners that hand the pointer off use `mem::forget` on the
        // wrapper so this destructor does not run for them.
        unsafe {
            faiss_IndexIVFFlat_free(raw);
        }
    }
}
impl IVFFlatIndexImpl {
fn new_helper(
quantizer: &flat::FlatIndex,
d: u32,
nlist: u32,
metric: MetricType,
own_fields: bool,
) -> Result<Self> {
unsafe {
let metric = metric as c_uint;
let mut inner = ptr::null_mut();
faiss_try(faiss_IndexIVFFlat_new_with_metric(
&mut inner,
quantizer.inner_ptr(),
d as usize,
nlist as usize,
metric,
))?;
faiss_IndexIVFFlat_set_own_fields(inner, c_int::from(own_fields));
Ok(IVFFlatIndexImpl { inner })
}
}
pub fn new(quantizer: flat::FlatIndex, d: u32, nlist: u32, metric: MetricType) -> Result<Self> {
let index = IVFFlatIndexImpl::new_helper(&quantizer, d, nlist, metric, true)?;
std::mem::forget(quantizer);
Ok(index)
}
pub fn new_l2(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::L2)
}
pub fn new_ip(quantizer: flat::FlatIndex, d: u32, nlist: u32) -> Result<Self> {
IVFFlatIndexImpl::new(quantizer, d, nlist, MetricType::InnerProduct)
}
pub fn nprobe(&self) -> u32 {
unsafe { faiss_IndexIVFFlat_nprobe(self.inner_ptr()) as u32 }
}
pub fn set_nprobe(&mut self, value: u32) {
unsafe {
faiss_IndexIVFFlat_set_nprobe(self.inner_ptr(), value as usize);
}
}
pub fn nlist(&self) -> u32 {
unsafe { faiss_IndexIVFFlat_nlist(self.inner_ptr()) as u32 }
}
pub fn train_type(&self) -> Option<TrainType> {
unsafe {
let code = faiss_IndexIVFFlat_quantizer_trains_alone(self.inner_ptr());
TrainType::from_code(code)
}
}
}
/// Training mode of an IVF index's quantizer, decoded from the native
/// `quantizer_trains_alone` code (0, 1 or 2).
///
/// The per-variant descriptions follow the Faiss documentation of that field.
#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)]
pub enum TrainType {
    /// Code `0`: the quantizer is used as the index in a k-means training.
    QuantizerAsIndex,
    /// Code `1`: the training set is passed on to the quantizer's own `train()`.
    QuantizerTrainsAlone,
    /// Code `2`: k-means is trained on a flat index and the centroids are added to the
    /// quantizer.
    FlatIndexAndQuantizer,
}

impl TrainType {
    /// Decodes a native train-type code.
    ///
    /// Returns `None` for any code other than `0`, `1` or `2`.
    pub(crate) fn from_code(code: c_char) -> Option<Self> {
        let train_type = match code {
            0 => Self::QuantizerAsIndex,
            1 => Self::QuantizerTrainsAlone,
            2 => Self::FlatIndexAndQuantizer,
            _ => return None,
        };
        Some(train_type)
    }
}
impl NativeIndex for IVFFlatIndexImpl {
/// Returns the wrapped native pointer without transferring ownership; the wrapper still
/// frees it on drop.
fn inner_ptr(&self) -> *mut FaissIndex {
self.inner
}
}
impl FromInnerPtr for IVFFlatIndexImpl {
    /// Wraps `inner_ptr` as an IVF flat index without any runtime type check.
    ///
    /// # Safety
    ///
    /// The pointer is reinterpreted as `*mut FaissIndexIVFFlat` and will be passed to
    /// `faiss_IndexIVFFlat_free` when the returned value is dropped. The caller must make
    /// sure it really refers to a native IVF flat index and that nothing else frees it.
    unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self {
        let inner = inner_ptr as *mut FaissIndexIVFFlat;
        Self { inner }
    }
}
// Expands the crate's `impl_native_index!` macro for this type (named through the
// `IVFFlatIndex` alias). The macro is defined elsewhere; presumably it supplies the `Index`
// trait methods (`train`, `add`, `search`, ...) used by the tests — verify in its definition.
impl_native_index!(IVFFlatIndex);
impl TryClone for IVFFlatIndexImpl {
/// Attempts to produce an independent copy of this index by delegating to
/// `try_clone_from_inner_ptr`; any error from that helper is returned unchanged.
fn try_clone(&self) -> Result<Self>
where
Self: Sized,
{
try_clone_from_inner_ptr(self)
}
}
// Expands the crate's `impl_concurrent_index!` macro for this type. The macro is defined
// elsewhere; presumably it supplies the `ConcurrentIndex` methods that work through `&self`
// (the tests call `(&index).search(..)` / `(&index).assign(..)`) — verify in its definition.
impl_concurrent_index!(IVFFlatIndexImpl);
impl IndexImpl {
pub fn into_ivf_flat(self) -> Result<IVFFlatIndexImpl> {
unsafe {
let new_inner = faiss_IndexIVFFlat_cast(self.inner_ptr());
if new_inner.is_null() {
Err(Error::BadCast)
} else {
mem::forget(self);
Ok(IVFFlatIndexImpl { inner: new_inner })
}
}
}
}
#[cfg(test)]
mod tests {
    use super::IVFFlatIndexImpl;
    use crate::index::flat::FlatIndexImpl;
    use crate::index::{index_factory, ConcurrentIndex, Idx, Index, UpcastIndex};
    use crate::MetricType;

    /// Dimensionality of every vector used in these tests.
    const D: u32 = 8;

    /// Five `D`-dimensional vectors, stored contiguously, shared by the search and assign
    /// tests.
    const SOME_DATA: &[f32] = &[
        7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 4.,
        -4., -8., 1., 1., 2., 4., -1., 8., 8., 10., -10., -10., 10., -10., 10., 16., 16., 32.,
        25., 20., 20., 40., 15.,
    ];

    /// Builds a single-list L2 index over a flat quantizer, checks that it starts empty,
    /// then trains it on and fills it with `SOME_DATA`.
    fn trained_index() -> IVFFlatIndexImpl {
        let q = FlatIndexImpl::new_l2(D).unwrap();
        let mut index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap();
        assert_eq!(index.d(), D);
        assert_eq!(index.ntotal(), 0);
        index.train(SOME_DATA).unwrap();
        index.add(SOME_DATA).unwrap();
        assert_eq!(index.ntotal(), 5);
        index
    }

    /// Searches a populated index with two queries (once through the owned value, once
    /// through a shared reference) and then resets it.
    fn check_search_and_reset() {
        let mut index = trained_index();

        let my_query = [0.; D as usize];
        let result = index.search(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));
        assert_eq!(result.distances.len(), 3);
        assert!(result.distances.iter().all(|x| *x > 0.));

        let my_query = [100.; D as usize];
        // Go through `&index` explicitly, as the original tests did.
        let result = (&index).search(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));
        assert_eq!(result.distances.len(), 3);
        assert!(result.distances.iter().all(|x| *x > 0.));

        index.reset().unwrap();
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_search() {
        check_search_and_reset();
    }

    // Kept as a separate test so the existing test name stays available; it exercises the
    // same quantizer-owning constructor path as `index_search`.
    #[test]
    fn index_search_own() {
        check_search_and_reset();
    }

    #[test]
    fn index_assign() {
        let mut index = trained_index();

        let my_query = [0.; D as usize];
        let result = index.assign(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));

        let my_query = [100.; D as usize];
        let result = (&index).assign(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));

        index.reset().unwrap();
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn ivf_flat_index_from_cast() {
        let mut index = index_factory(D, "IVF1,Flat", MetricType::L2).unwrap();
        let some_data = &[
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
            0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
            100., 105., -100., 100., 100., 105.,
        ];
        index.train(some_data).unwrap();
        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 5);

        // The cast must preserve the trained state and the stored vectors.
        let index: IVFFlatIndexImpl = index.into_ivf_flat().unwrap();
        assert!(index.is_trained());
        assert_eq!(index.ntotal(), 5);
    }

    #[test]
    fn index_upcast() {
        let q = FlatIndexImpl::new_l2(D).unwrap();
        let index = IVFFlatIndexImpl::new_l2(q, D, 1).unwrap();
        assert_eq!(index.d(), D);
        assert_eq!(index.ntotal(), 0);

        let index_impl = index.upcast();
        assert_eq!(index_impl.d(), D);
    }
}