use crate::error::{Error, Result};
use crate::faiss_try;
use crate::metric::MetricType;
use crate::search_params::SearchParameters;
use crate::selector::IdSelector;
use std::ffi::CString;
use std::fmt::{self, Display, Formatter, Write};
use std::os::raw::c_uint;
use std::{mem, ptr};
use faiss_sys::*;
pub mod autotune;
pub mod flat;
pub mod id_map;
pub mod io;
pub mod io_flags;
pub mod ivf_flat;
pub mod lsh;
pub mod pretransform;
pub mod refine_flat;
pub mod scalar_quantizer;
#[cfg(feature = "gpu")]
pub mod gpu;
/// Identifier of a vector stored in an index.
///
/// Wraps the native `idx_t`; the raw value `-1` is the "none" sentinel
/// (see `Idx::none`). `#[repr(transparent)]` keeps the layout identical to
/// `idx_t`, which is what allows native label buffers to be reinterpreted as
/// slices of `Idx`.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct Idx(idx_t);
/// Prints the identifier's decimal value (honouring width/fill flags), or a
/// bare `x` when it is the "none" sentinel.
impl Display for Idx {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        if let Some(value) = self.get() {
            Display::fmt(&value, f)
        } else {
            f.write_char('x')
        }
    }
}
impl From<idx_t> for Idx {
fn from(x: idx_t) -> Self {
Idx(x)
}
}
impl Idx {
    /// Raw native value that marks the absence of an identifier.
    const NONE_RAW: idx_t = -1;

    /// Creates an identifier from an unsigned value.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= 2^63`, i.e. if it does not fit in a non-negative
    /// 64-bit signed value.
    #[inline]
    pub fn new(idx: u64) -> Self {
        let limit: u64 = 1 << 63;
        assert!(idx < limit, "too large index value provided to Idx::new");
        Idx(idx as idx_t)
    }

    /// The "none" identifier (raw value `-1`).
    #[inline]
    pub fn none() -> Self {
        Idx(Self::NONE_RAW)
    }

    /// `true` if this is the "none" identifier.
    #[inline]
    pub fn is_none(self) -> bool {
        self.0 == Self::NONE_RAW
    }

    /// `true` unless this is the "none" identifier.
    #[inline]
    pub fn is_some(self) -> bool {
        !self.is_none()
    }

    /// The identifier as an unsigned value, or `None` for the "none" sentinel.
    ///
    /// Other raw values are reinterpreted with an `as` cast (so negative raw
    /// values other than `-1` map to large unsigned numbers).
    #[inline]
    pub fn get(self) -> Option<u64> {
        if self.is_none() {
            None
        } else {
            Some(self.0 as u64)
        }
    }

    /// The raw native value, including `-1` for "none".
    #[inline]
    pub fn to_native(self) -> idx_t {
        self.0
    }
}
/// Identifier equality.
///
/// The "none" sentinel (`-1`) compares unequal to every identifier, including
/// itself, so this relation is deliberately not reflexive (and `Eq` is not
/// implemented).
impl PartialEq<Idx> for Idx {
    fn eq(&self, idx: &Idx) -> bool {
        match (self.0, idx.0) {
            (-1, _) | (_, -1) => false,
            (lhs, rhs) => lhs == rhs,
        }
    }
}
/// Orders identifiers by their unsigned value (as returned by `Idx::get`).
///
/// Any comparison involving the "none" sentinel is unordered (`None`).
impl PartialOrd<Idx> for Idx {
    fn partial_cmp(&self, idx: &Idx) -> Option<std::cmp::Ordering> {
        let lhs = self.get()?;
        let rhs = idx.get()?;
        Some(lhs.cmp(&rhs))
    }
}
/// Common interface of the index types in this crate.
///
/// `Data` is the vector element type (`f32` by default; the binary indexes in
/// this module use `u8`) and `Radius` is the distance / range-search radius
/// type (`f32` by default; `i32` for binary indexes). Multi-vector inputs are
/// flat slices of vectors laid back to back, `d()` elements per vector.
pub trait Index<Data = f32, Radius = f32> {
    // ----- properties -----

    /// Whether the index is trained (some index types, e.g. `"IVF8,Flat"`,
    /// start untrained).
    fn is_trained(&self) -> bool;
    /// Number of vectors currently stored.
    fn ntotal(&self) -> u64;
    /// Dimensionality of the stored vectors.
    fn d(&self) -> u32;
    /// Metric used to compare vectors.
    fn metric_type(&self) -> MetricType;
    /// Current verbosity flag of the index.
    fn verbose(&self) -> bool;
    /// Sets the verbosity flag of the index.
    fn set_verbose(&mut self, value: bool);

    // ----- mutation -----

    /// Trains the index on the vectors in `x`.
    fn train(&mut self, x: &[Data]) -> Result<()>;
    /// Adds the vectors in `x`.
    fn add(&mut self, x: &[Data]) -> Result<()>;
    /// Adds the vectors in `x` under the explicit identifiers `xids`
    /// (presumably one identifier per vector).
    fn add_with_ids(&mut self, x: &[Data], xids: &[Idx]) -> Result<()>;
    /// Removes the vectors matched by `sel`; the returned count is presumably
    /// the number of vectors removed.
    fn remove_ids(&mut self, sel: &IdSelector) -> Result<usize>;
    /// Removes every stored vector (`ntotal()` becomes 0).
    fn reset(&mut self) -> Result<()>;

    // ----- queries -----

    /// Labels (without distances) of the `k` nearest neighbours of each query
    /// vector in `q`.
    fn assign(&mut self, q: &[Data], k: usize) -> Result<AssignSearchResult>;
    /// Labels and distances of the `k` nearest neighbours of each query vector in `q`.
    fn search(&mut self, q: &[Data], k: usize) -> Result<SearchResult<Radius>>;
    /// All stored vectors within `radius` of each query vector in `q`.
    fn range_search(&mut self, q: &[Data], radius: Radius) -> Result<RangeSearchResult>;
    /// Writes the stored vector identified by `key` into `output`.
    fn reconstruct(&self, key: Idx, output: &mut [Data]) -> Result<()>;
    /// Writes `count` consecutive stored vectors, starting at `first_key`, into `output`.
    fn reconstruct_n(&self, first_key: Idx, count: usize, output: &mut [Data]) -> Result<()>;
}
/// A boxed index is itself an index: every method forwards to the boxed value
/// (`&Box<I>` / `&mut Box<I>` deref-coerce to `&I` / `&mut I`).
impl<I, Data, Radius> Index<Data, Radius> for Box<I>
where
    I: Index<Data, Radius>,
{
    fn is_trained(&self) -> bool {
        I::is_trained(self)
    }

    fn ntotal(&self) -> u64 {
        I::ntotal(self)
    }

    fn d(&self) -> u32 {
        I::d(self)
    }

    fn metric_type(&self) -> MetricType {
        I::metric_type(self)
    }

    fn add(&mut self, x: &[Data]) -> Result<()> {
        I::add(self, x)
    }

    fn add_with_ids(&mut self, x: &[Data], xids: &[Idx]) -> Result<()> {
        I::add_with_ids(self, x, xids)
    }

    fn train(&mut self, x: &[Data]) -> Result<()> {
        I::train(self, x)
    }

    fn assign(&mut self, q: &[Data], k: usize) -> Result<AssignSearchResult> {
        I::assign(self, q, k)
    }

    fn search(&mut self, q: &[Data], k: usize) -> Result<SearchResult<Radius>> {
        I::search(self, q, k)
    }

    fn range_search(&mut self, q: &[Data], radius: Radius) -> Result<RangeSearchResult> {
        I::range_search(self, q, radius)
    }

    fn reconstruct(&self, key: Idx, output: &mut [Data]) -> Result<()> {
        I::reconstruct(self, key, output)
    }

    fn reconstruct_n(&self, first_key: Idx, count: usize, output: &mut [Data]) -> Result<()> {
        I::reconstruct_n(self, first_key, count, output)
    }

    fn reset(&mut self) -> Result<()> {
        I::reset(self)
    }

    fn remove_ids(&mut self, sel: &IdSelector) -> Result<usize> {
        I::remove_ids(self, sel)
    }

    fn verbose(&self) -> bool {
        I::verbose(self)
    }

    fn set_verbose(&mut self, value: bool) {
        I::set_verbose(self, value)
    }
}
/// An [`Index`] backed by a native faiss object that can be reached through a
/// raw pointer.
pub trait NativeIndex<Data = f32, Radius = f32>: Index<Data, Radius> {
    /// The native (FFI) type of the wrapped object.
    type Inner;

    /// Raw pointer to the wrapped native object. Ownership is not transferred.
    fn inner_ptr(&self) -> *mut Self::Inner;
}

/// A boxed native index exposes the pointer of the boxed value.
impl<Data, Radius, NI> NativeIndex<Data, Radius> for Box<NI>
where
    NI: NativeIndex<Data, Radius>,
{
    type Inner = NI::Inner;

    fn inner_ptr(&self) -> *mut Self::Inner {
        NI::inner_ptr(self)
    }
}
/// An index whose queries only need a shared reference (`&self`), unlike the
/// `&mut self` query methods of [`Index`].
pub trait ConcurrentIndex<Data = f32, Radius = f32>: Index<Data, Radius> {
    /// `&self` counterpart of [`Index::assign`].
    fn assign(&self, q: &[Data], k: usize) -> Result<AssignSearchResult>;
    /// `&self` counterpart of [`Index::search`].
    fn search(&self, q: &[Data], k: usize) -> Result<SearchResult<Radius>>;
    /// `&self` counterpart of [`Index::range_search`].
    fn range_search(&self, q: &[Data], radius: Radius) -> Result<RangeSearchResult>;
}

/// Forwards to the boxed index. The calls are fully qualified because the
/// `Index` supertrait declares methods with the same names.
impl<Data, Radius, CI> ConcurrentIndex<Data, Radius> for Box<CI>
where
    CI: ConcurrentIndex<Data, Radius>,
{
    fn assign(&self, q: &[Data], k: usize) -> Result<AssignSearchResult> {
        <CI as ConcurrentIndex<Data, Radius>>::assign(self, q, k)
    }

    fn search(&self, q: &[Data], k: usize) -> Result<SearchResult<Radius>> {
        <CI as ConcurrentIndex<Data, Radius>>::search(self, q, k)
    }

    fn range_search(&self, q: &[Data], radius: Radius) -> Result<RangeSearchResult> {
        <CI as ConcurrentIndex<Data, Radius>>::range_search(self, q, radius)
    }
}
/// Marker trait (no methods) for CPU-resident indexes — presumably to tell
/// them apart from the types in the `gpu` module.
pub trait CpuIndex<Data = f32, Radius = f32>: Index<Data, Radius> {}

impl<Data, Radius, CI> CpuIndex<Data, Radius> for Box<CI> where CI: CpuIndex<Data, Radius> {}
/// Unchecked construction of an index wrapper from a raw native pointer.
pub trait FromInnerPtr<Data = f32, Radius = f32>: NativeIndex<Data, Radius> {
    /// Wraps `inner_ptr` without any check.
    ///
    /// # Safety
    ///
    /// `inner_ptr` must point to a live native object of the expected kind,
    /// and the returned value takes ownership of it. The wrappers in this
    /// module free the pointer on drop, so it must not be freed or wrapped
    /// elsewhere.
    unsafe fn from_inner_ptr(inner_ptr: *mut Self::Inner) -> Self;
}

/// Null-checked construction of an index wrapper from a raw native pointer.
pub trait TryFromInnerPtr<Data = f32, Radius = f32>: NativeIndex<Data, Radius> {
    /// Like [`FromInnerPtr::from_inner_ptr`], but returns an error instead of
    /// wrapping a null pointer (the implementations in this module return
    /// `Error::BadCast`).
    ///
    /// # Safety
    ///
    /// A non-null `inner_ptr` must satisfy the same contract as
    /// [`FromInnerPtr::from_inner_ptr`].
    unsafe fn try_from_inner_ptr(inner_ptr: *mut Self::Inner) -> Result<Self>
    where
        Self: Sized;
}
/// Fallible duplication: produces an independent copy that can be modified
/// without affecting the original.
pub trait TryClone {
    /// Returns an independent copy of `self`, or the error that prevented it.
    fn try_clone(&self) -> Result<Self>
    where
        Self: Sized;
}
pub fn try_clone_from_inner_ptr<T>(val: &T) -> Result<T>
where
T: FromInnerPtr<f32, f32, Inner = FaissIndex>,
{
unsafe {
let mut new_index_ptr = ::std::ptr::null_mut();
faiss_try(faiss_clone_index(val.inner_ptr(), &mut new_index_ptr))?;
Ok(crate::index::FromInnerPtr::from_inner_ptr(new_index_ptr))
}
}
pub fn try_clone_binary_from_inner_ptr<T>(val: &T) -> Result<T>
where
T: FromInnerPtr<u8, i32, Inner = FaissIndexBinary>,
{
unsafe {
let mut new_index_ptr = ::std::ptr::null_mut();
faiss_try(faiss_clone_index_binary(val.inner_ptr(), &mut new_index_ptr))?;
Ok(crate::index::FromInnerPtr::from_inner_ptr(new_index_ptr))
}
}
/// Output of an assign query: neighbour labels only, `k` per query vector,
/// stored query after query.
#[derive(Clone, Debug, PartialEq)]
pub struct AssignSearchResult {
    /// Neighbour labels; `Idx::none()` where fewer than `k` neighbours exist.
    pub labels: Vec<Idx>,
}

/// Output of a k-nearest-neighbour search: `k` entries per query vector,
/// stored query after query. `distances[i]` belongs to `labels[i]`.
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResult<Data = f32> {
    /// Distance of each returned neighbour.
    pub distances: Vec<Data>,
    /// Label of each returned neighbour.
    pub labels: Vec<Idx>,
}
/// Owning handle to the native output of a range search.
///
/// The results of all queries are stored back to back; `lims()` gives the
/// per-query offsets into `distances()` / `labels()`. The native object is
/// freed when this value is dropped.
///
/// NOTE(review): the derived `Clone` copies only the raw pointer, and `Drop`
/// frees that pointer. Cloning and then dropping both values therefore frees
/// the native object twice (use-after-free / double free). The derived
/// `PartialEq` also compares pointer identity, not contents. Consider removing
/// `Clone` (a breaking change) or replacing it with a deep copy.
#[derive(Debug, Clone, PartialEq)]
pub struct RangeSearchResult {
    /// Native result object; released in `Drop`.
    inner: *mut FaissRangeSearchResult,
}
impl RangeSearchResult {
pub fn nq(&self) -> usize {
unsafe { faiss_RangeSearchResult_nq(self.inner) }
}
pub fn lims(&self) -> &[usize] {
unsafe {
let mut lims_ptr = ptr::null_mut();
faiss_RangeSearchResult_lims(self.inner, &mut lims_ptr);
::std::slice::from_raw_parts(lims_ptr, self.nq() + 1)
}
}
pub fn distance_and_labels(&self) -> (&[f32], &[Idx]) {
let lims = self.lims();
let full_len = lims.last().cloned().unwrap_or(0);
unsafe {
let mut distances_ptr = ptr::null_mut();
let mut labels_ptr = ptr::null_mut();
faiss_RangeSearchResult_labels(self.inner, &mut labels_ptr, &mut distances_ptr);
let distances = ::std::slice::from_raw_parts(distances_ptr, full_len);
let labels = ::std::slice::from_raw_parts(labels_ptr as *const Idx, full_len);
(distances, labels)
}
}
pub fn distance_and_labels_mut(&mut self) -> (&mut [f32], &mut [Idx]) {
unsafe {
let buf_size = faiss_RangeSearchResult_buffer_size(self.inner);
let mut distances_ptr = ptr::null_mut();
let mut labels_ptr = ptr::null_mut();
faiss_RangeSearchResult_labels(self.inner, &mut labels_ptr, &mut distances_ptr);
let distances = ::std::slice::from_raw_parts_mut(distances_ptr, buf_size);
let labels = ::std::slice::from_raw_parts_mut(labels_ptr as *mut Idx, buf_size);
(distances, labels)
}
}
pub fn distances(&self) -> &[f32] {
self.distance_and_labels().0
}
pub fn distances_mut(&mut self) -> &mut [f32] {
self.distance_and_labels_mut().0
}
pub fn labels(&self) -> &[Idx] {
self.distance_and_labels().1
}
pub fn labels_mut(&mut self) -> &mut [Idx] {
self.distance_and_labels_mut().1
}
}
impl Drop for RangeSearchResult {
fn drop(&mut self) {
unsafe {
faiss_RangeSearchResult_free(self.inner);
}
}
}
#[derive(Debug)]
pub struct BinaryIndexImpl {
pub(crate) inner: *mut FaissIndexBinary,
}
unsafe impl Send for BinaryIndexImpl {}
unsafe impl Sync for BinaryIndexImpl {}
impl CpuIndex<u8, i32> for BinaryIndexImpl {}
impl Drop for BinaryIndexImpl {
fn drop(&mut self) {
unsafe {
faiss_IndexBinary_free(self.inner);
}
}
}
impl BinaryIndexImpl {
pub fn inner_ptr(&self) -> *mut FaissIndexBinary {
self.inner
}
}
impl NativeIndex<u8, i32> for BinaryIndexImpl {
type Inner = FaissIndexBinary;
fn inner_ptr(&self) -> *mut FaissIndexBinary {
self.inner
}
}
impl FromInnerPtr<u8, i32> for BinaryIndexImpl {
unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndexBinary) -> Self {
BinaryIndexImpl { inner: inner_ptr }
}
}
impl TryFromInnerPtr<u8, i32> for BinaryIndexImpl {
unsafe fn try_from_inner_ptr(inner_ptr: *mut FaissIndexBinary) -> Result<Self>
where
Self: Sized {
if inner_ptr.is_null() {
Err(Error::BadCast)
} else {
Ok(BinaryIndexImpl { inner: inner_ptr })
}
}
}
#[derive(Debug)]
pub struct IndexImpl {
inner: *mut FaissIndex,
}
unsafe impl Send for IndexImpl {}
unsafe impl Sync for IndexImpl {}
impl CpuIndex for IndexImpl {}
impl Drop for IndexImpl {
fn drop(&mut self) {
unsafe {
faiss_Index_free(self.inner);
}
}
}
impl IndexImpl {
pub fn inner_ptr(&self) -> *mut FaissIndex {
self.inner
}
}
impl NativeIndex for IndexImpl {
type Inner = FaissIndex;
fn inner_ptr(&self) -> *mut Self::Inner {
self.inner
}
}
impl FromInnerPtr for IndexImpl {
unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self {
IndexImpl { inner: inner_ptr }
}
}
impl TryFromInnerPtr for IndexImpl {
unsafe fn try_from_inner_ptr(inner_ptr: *mut FaissIndex) -> Result<Self>
where
Self: Sized,
{
if inner_ptr.is_null() {
Err(Error::BadCast)
} else {
Ok(IndexImpl { inner: inner_ptr })
}
}
}
/// Converts a concrete native index into the type-erased wrapper
/// (`IndexImpl` or `BinaryIndexImpl`) that takes over the same native object.
pub trait UpcastIndex<Data = f32, Radius = f32>: NativeIndex<Data, Radius> {
    /// The type-erased wrapper produced by `upcast`.
    type Output: FromInnerPtr<Data, Radius, Inner = Self::Inner>;

    /// Moves ownership of the native index into `Self::Output`, without
    /// copying or freeing it.
    fn upcast(self) -> Self::Output;
}

// NOTE(review): `Box<NI>` is itself a `NativeIndex` and therefore gets these
// impls. Forgetting a `Box` leaks the (small) box allocation; the native
// index itself is handed over correctly.
impl<NI> UpcastIndex<f32, f32> for NI
where
    NI: NativeIndex<f32, f32, Inner = FaissIndex>,
{
    type Output = IndexImpl;

    fn upcast(self) -> Self::Output {
        // Read the pointer first, then forget `self` so its destructor does not
        // free the native index that the returned wrapper now owns.
        let raw = self.inner_ptr();
        mem::forget(self);
        // SAFETY: `raw` is live and, with `self` forgotten, has no other owner.
        unsafe { <IndexImpl as FromInnerPtr>::from_inner_ptr(raw) }
    }
}

impl<NI> UpcastIndex<u8, i32> for NI
where
    NI: NativeIndex<u8, i32, Inner = FaissIndexBinary>,
{
    type Output = BinaryIndexImpl;

    fn upcast(self) -> Self::Output {
        // Same hand-over as the float variant above.
        let raw = self.inner_ptr();
        mem::forget(self);
        // SAFETY: `raw` is live and, with `self` forgotten, has no other owner.
        unsafe { <BinaryIndexImpl as FromInnerPtr<u8, i32>>::from_inner_ptr(raw) }
    }
}
// Crate macros (defined elsewhere) that presumably generate the `Index`
// implementations for the type-erased wrappers. The tests below call `Index`
// methods on both types. NOTE(review): confirm the exact set of generated
// impls against the macro definitions.
impl_native_index!(IndexImpl);
impl_native_index_binary!(BinaryIndexImpl);
impl TryClone for IndexImpl {
    /// Deep-copies the native float index (via `try_clone_from_inner_ptr`).
    fn try_clone(&self) -> Result<Self>
    where
        Self: Sized,
    {
        try_clone_from_inner_ptr::<Self>(self)
    }
}

impl TryClone for BinaryIndexImpl {
    /// Deep-copies the native binary index (via `try_clone_binary_from_inner_ptr`).
    fn try_clone(&self) -> Result<Self>
    where
        Self: Sized,
    {
        try_clone_binary_from_inner_ptr::<Self>(self)
    }
}
/// k-nearest-neighbour search with per-call [`SearchParameters`], for
/// indexes that are queried through `&mut self`.
pub trait SearchWithParamsMut<Data, Radius> {
    /// Searches the `k` nearest neighbours of every query vector in `query`
    /// (vectors laid back to back), using `params` for this call only.
    fn search_with_params(
        &mut self,
        query: &[Data],
        k: usize,
        params: &SearchParameters,
    ) -> Result<SearchResult<Radius>>;
}

impl<I> SearchWithParamsMut<f32, f32> for I
where
    I: NativeIndex<f32, f32, Inner = FaissIndex>,
{
    fn search_with_params(
        &mut self,
        query: &[f32],
        k: usize,
        params: &SearchParameters,
    ) -> Result<SearchResult<f32>> {
        // SAFETY: `inner_ptr` is the live native index owned by `self`.
        unsafe { native_search_with_params(self.inner_ptr(), self.d(), query, k, params) }
    }
}

/// k-nearest-neighbour search with per-call [`SearchParameters`], for
/// indexes that can be queried through `&self` (see [`ConcurrentIndex`]).
pub trait SearchWithParams<Data, Radius> {
    /// Searches the `k` nearest neighbours of every query vector in `query`
    /// (vectors laid back to back), using `params` for this call only.
    fn search_with_params(
        &self,
        query: &[Data],
        k: usize,
        params: &SearchParameters,
    ) -> Result<SearchResult<Radius>>;
}

impl<I> SearchWithParams<f32, f32> for I
where
    I: NativeIndex<f32, f32, Inner = FaissIndex>,
    I: ConcurrentIndex<f32, f32>,
{
    fn search_with_params(
        &self,
        query: &[f32],
        k: usize,
        params: &SearchParameters,
    ) -> Result<SearchResult<f32>> {
        // SAFETY: `inner_ptr` is the live native index owned by `self`.
        unsafe { native_search_with_params(self.inner_ptr(), self.d(), query, k, params) }
    }
}

/// Shared body of both `search_with_params` flavours.
///
/// Treats `query` as `query.len() / d` vectors (a trailing partial vector is
/// ignored) and returns `k` distances and labels per query, stored query
/// after query. Unfilled label slots keep their `Idx::none()` initial value
/// unless the native search overwrites them.
///
/// # Panics
///
/// Panics (division by zero) if `d == 0`.
///
/// # Safety
///
/// `index` must point to a live native float index of dimension `d`.
unsafe fn native_search_with_params(
    index: *mut FaissIndex,
    d: u32,
    query: &[f32],
    k: usize,
    params: &SearchParameters,
) -> Result<SearchResult<f32>> {
    let nq = query.len() / (d as usize);
    let mut distances = vec![0.; k * nq];
    let mut labels = vec![Idx::none(); k * nq];
    // SAFETY: `index` is valid (caller contract); `query` holds at least
    // `nq * d` floats, and both output buffers hold `k * nq` elements. `Idx` is
    // `#[repr(transparent)]` over `idx_t`, so the labels buffer can be cast.
    unsafe {
        faiss_try(faiss_Index_search_with_params(
            index,
            nq as idx_t,
            query.as_ptr(),
            k as idx_t,
            params.inner_ptr() as *const _,
            distances.as_mut_ptr(),
            labels.as_mut_ptr() as *mut _,
        ))?;
    }
    Ok(SearchResult { distances, labels })
}
pub fn index_factory<D>(d: u32, description: D, metric: MetricType) -> Result<IndexImpl>
where
D: AsRef<str>,
{
unsafe {
let metric = metric as c_uint;
let description =
CString::new(description.as_ref()).map_err(|_| Error::IndexDescription)?;
let mut index_ptr = ::std::ptr::null_mut();
faiss_try(faiss_index_factory(
&mut index_ptr,
(d & 0x7FFF_FFFF) as i32,
description.as_ptr(),
metric,
))?;
Ok(IndexImpl { inner: index_ptr })
}
}
pub fn index_binary_factory<D>(d: u32, description: D) -> Result<BinaryIndexImpl>
where
D: AsRef<str>,
{
unsafe {
let description =
CString::new(description.as_ref()).map_err(|_| Error::IndexDescription)?;
let mut index_ptr = ::std::ptr::null_mut();
faiss_try(faiss_index_binary_factory(
&mut index_ptr,
(d & 0x7FFF_FFFF) as i32,
description.as_ptr(),
))?;
Ok(BinaryIndexImpl { inner: index_ptr })
}
}
#[cfg(test)]
mod tests {
    use super::{index_binary_factory, index_factory, Idx, Index, TryClone};
    use crate::metric::MetricType;

    /// Five 8-dimensional vectors shared by the flat-index tests
    /// (ids 0..=4 in insertion order).
    const DATA_8D: [f32; 40] = [
        7.5, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, // id 0
        -1., 1., 1., 1., 1., 1., 1., -1., // id 1
        0., 0., 0., 1., 1., 0., 0., -1., // id 2
        100., 100., 100., 100., -100., 100., 100., 100., // id 3
        120., 100., 100., 105., -100., 100., 100., 105., // id 4
    ];

    #[test]
    fn index_factory_flat() {
        let index = index_factory(64, "Flat", MetricType::L2).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_binary_factory_flat() {
        let index = index_binary_factory(256, "BFlat").unwrap();
        assert!(index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn bad_dimension_index_binary_factory_not_divisible_by_8_flat() {
        let r = index_binary_factory(9, "BFlat");
        assert!(r.is_err());
    }

    #[test]
    fn index_factory_flat_boxed() {
        let index = index_factory(64, "Flat", MetricType::L2).unwrap();
        let boxed = Box::new(index);
        assert!(boxed.is_trained());
        assert_eq!(boxed.ntotal(), 0);
    }

    #[test]
    fn index_binary_factory_flat_boxed() {
        let index = index_binary_factory(256, "BFlat").unwrap();
        let boxed = Box::new(index);
        assert!(boxed.is_trained());
        assert_eq!(boxed.ntotal(), 0);
    }

    #[test]
    fn index_factory_ivf_flat() {
        let index = index_factory(64, "IVF8,Flat", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_binary_factory_ivf_flat() {
        let index = index_binary_factory(256, "BIVF1024").unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_sq() {
        let index = index_factory(64, "SQ8", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_pq() {
        let index = index_factory(64, "PQ8", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_ivf_sq() {
        let index = index_factory(64, "IVF8,SQ4", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);

        let index = index_factory(64, "IVF8,SQ8", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_hnsw() {
        let index = index_factory(64, "HNSW8", MetricType::L2).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn bad_index_factory_description() {
        let r = index_factory(64, "fdnoyq", MetricType::L2);
        assert!(r.is_err());
        // Interior NUL bytes are rejected before reaching the native factory.
        let r = index_factory(64, "Flat\0Flat", MetricType::L2);
        assert!(r.is_err());
    }

    #[test]
    fn bad_index_binary_factory_description() {
        let r = index_binary_factory(64, "Bjkads");
        assert!(r.is_err());
        // A float-index description is not valid for the binary factory.
        let r = index_binary_factory(64, "Flat");
        assert!(r.is_err());
    }

    #[test]
    fn index_clone() {
        let mut index = index_factory(4, "Flat", MetricType::L2).unwrap();
        let some_data = &[
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
            0., 1., 1., 0., 0., -1.,
        ];
        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 6);

        let mut index2 = index.try_clone().unwrap();
        assert_eq!(index2.ntotal(), 6);

        let some_more_data = &[
            100., 100., 100., 100., -100., 100., 100., 100., 120., 100., 100., 105., -100., 100.,
            100., 105.,
        ];
        index2.add(some_more_data).unwrap();
        // The clone is independent: adding to it leaves the original untouched.
        assert_eq!(index.ntotal(), 6);
        assert_eq!(index2.ntotal(), 10);
    }

    #[test]
    fn index_binary_clone() {
        let mut index = index_binary_factory(16, "BFlat").unwrap();
        let some_data = &[255u8, 0, 1, 16];
        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 2);

        let mut index2 = index.try_clone().unwrap();
        assert_eq!(index2.ntotal(), 2);

        let some_more_data = &[2u8, 3, 4, 5];
        index2.add(some_more_data).unwrap();
        assert_eq!(index.ntotal(), 2);
        assert_eq!(index2.ntotal(), 4);
    }

    #[test]
    fn flat_index_search() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        index.add(&DATA_8D).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(2), Idx(1), Idx(0), Idx(3), Idx(4)]);
        assert!(result.distances.iter().all(|x| *x > 0.));

        let my_query = [100.; 8];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(3), Idx(4), Idx(0), Idx(1), Idx(2)]);
        assert!(result.distances.iter().all(|x| *x > 0.));

        // Two queries in one call: results are stored query after query.
        let my_query = vec![
            0., 0., 0., 0., 0., 0., 0., 0., 100., 100., 100., 100., 100., 100., 100., 100.,
        ];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![
                Idx(2),
                Idx(1),
                Idx(0),
                Idx(3),
                Idx(4),
                Idx(3),
                Idx(4),
                Idx(0),
                Idx(1),
                Idx(2)
            ]
        );
        assert!(result.distances.iter().all(|x| *x > 0.));
    }

    #[test]
    fn flat_index_assign() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        assert_eq!(index.d(), 8);
        assert_eq!(index.ntotal(), 0);
        index.add(&DATA_8D).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(2), Idx(1), Idx(0), Idx(3), Idx(4)]);

        // Four identical queries.
        let my_query = [0.; 32];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![2, 1, 0, 3, 4, 2, 1, 0, 3, 4, 2, 1, 0, 3, 4, 2, 1, 0, 3, 4]
                .into_iter()
                .map(Idx)
                .collect::<Vec<_>>()
        );

        let my_query = [100.; 8];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![3, 4, 0, 1, 2].into_iter().map(Idx).collect::<Vec<_>>()
        );

        let my_query = vec![
            0., 0., 0., 0., 0., 0., 0., 0., 100., 100., 100., 100., 100., 100., 100., 100.,
        ];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![2, 1, 0, 3, 4, 3, 4, 0, 1, 2]
                .into_iter()
                .map(Idx)
                .collect::<Vec<_>>()
        );

        index.reset().unwrap();
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn flat_index_range_search() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        index.add(&DATA_8D).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.range_search(&my_query, 8.125).unwrap();
        let (distances, labels) = result.distance_and_labels();
        // Range-search results are not guaranteed to be in any particular order.
        assert!(labels == &[Idx(1), Idx(2)] || labels == &[Idx(2), Idx(1)]);
        assert!(distances.iter().all(|x| *x > 0.));
    }

    #[test]
    fn flat_index_reconstruct() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        index.add(&DATA_8D).unwrap();
        assert_eq!(index.ntotal(), 5);

        let mut output = vec![0.; 8];
        index.reconstruct(Idx(0), &mut output).unwrap();
        assert_eq!(&output[..], &DATA_8D[..8]);

        let mut output = vec![0.; 16];
        index.reconstruct_n(Idx(0), 2, &mut output).unwrap();
        assert_eq!(&output[..], &DATA_8D[..16]);
    }

    #[test]
    fn flat_index_binary_reconstruct() {
        let mut index = index_binary_factory(16, "BFlat").unwrap();
        let some_data = &[255u8, 0, 1, 16];
        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 2);

        let mut output = vec![0; 2];
        index.reconstruct(Idx(0), &mut output).unwrap();
        assert_eq!(output, vec![255u8, 0]);

        let mut output = vec![0; 4];
        index.reconstruct_n(Idx(0), 2, &mut output).unwrap();
        assert_eq!(output, vec![255u8, 0, 1, 16]);
    }
}