use crate::error::{Error, Result};
use crate::metric::MetricType;
use crate::selector::IdSelector;
use std::ffi::CString;
use std::fmt::{self, Display, Formatter, Write};
use std::os::raw::c_uint;
use std::ptr;
use faiss_sys::*;
pub mod flat;
pub mod id_map;
pub mod io;
pub mod lsh;
#[cfg(feature = "gpu")]
pub mod gpu;
/// Identifier of a stored vector, wrapping the native FAISS `idx_t`.
///
/// The raw value `-1` is a sentinel meaning "no index"; [`Idx::get`] maps it to
/// `None`. The struct is `#[repr(transparent)]`, so a buffer of `idx_t` values can
/// be reinterpreted as a buffer of `Idx`.
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
pub struct Idx(idx_t);

impl Display for Idx {
    /// Writes the numeric index (honouring width/fill flags), or `x` for the sentinel.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        if let Some(value) = self.get() {
            Display::fmt(&value, f)
        } else {
            f.write_char('x')
        }
    }
}

impl From<idx_t> for Idx {
    /// Wraps a raw native index as-is; `-1` becomes the "none" sentinel.
    fn from(raw: idx_t) -> Self {
        Self(raw)
    }
}

impl Idx {
    /// Raw value reserved for "no index".
    const NONE_RAW: idx_t = -1;

    /// Creates an index from an unsigned value.
    ///
    /// # Panics
    ///
    /// Panics if `idx` has its top bit set, i.e. it would not fit in the
    /// signed native representation.
    #[inline]
    pub fn new(idx: u64) -> Self {
        assert!(
            idx < 0x8000_0000_0000_0000,
            "too large index value provided to Idx::new"
        );
        Self(idx as idx_t)
    }

    /// The "no index" sentinel.
    #[inline]
    pub fn none() -> Self {
        Self(Self::NONE_RAW)
    }

    /// Whether this is the "no index" sentinel.
    #[inline]
    pub fn is_none(self) -> bool {
        self.0 == Self::NONE_RAW
    }

    /// Whether this holds an actual index.
    #[inline]
    pub fn is_some(self) -> bool {
        !self.is_none()
    }

    /// The index as an unsigned value, or `None` for the sentinel.
    pub fn get(self) -> Option<u64> {
        if self.is_none() {
            None
        } else {
            Some(self.0 as u64)
        }
    }

    /// The raw native value, sentinel included.
    pub fn to_native(self) -> idx_t {
        self.0
    }
}

impl PartialEq<Idx> for Idx {
    /// Equal only when both sides hold the same actual index; the sentinel is
    /// unequal to everything, itself included (NaN-like semantics).
    fn eq(&self, other: &Idx) -> bool {
        match (self.get(), other.get()) {
            (Some(lhs), Some(rhs)) => lhs == rhs,
            _ => false,
        }
    }
}

impl PartialOrd<Idx> for Idx {
    /// Orders actual indices numerically; any comparison involving the sentinel
    /// is unordered (`None`).
    fn partial_cmp(&self, other: &Idx) -> Option<std::cmp::Ordering> {
        let lhs = self.get()?;
        let rhs = other.get()?;
        Some(lhs.cmp(&rhs))
    }
}
/// Operations shared by the FAISS index wrappers in this crate.
///
/// Vector data is passed as a flat `f32` slice of consecutive rows, each `d()`
/// values long (e.g. 24 floats added to an index with `d() == 4` give
/// `ntotal() == 6`). The query methods here take `&mut self`; `ConcurrentIndex`
/// declares `&self` counterparts.
pub trait Index {
/// Whether the index has been trained; some index types report `false` until
/// `train` has been called.
fn is_trained(&self) -> bool;
/// Number of vectors currently stored in the index.
fn ntotal(&self) -> u64;
/// Dimensionality of the vectors handled by the index.
fn d(&self) -> u32;
/// Metric the index uses to compare vectors.
fn metric_type(&self) -> MetricType;
/// Adds the vectors in `x` (consecutive rows of `d()` floats).
fn add(&mut self, x: &[f32]) -> Result<()>;
/// Adds the vectors in `x` under caller-chosen ids; presumably `xids` holds one
/// id per row of `x` — TODO confirm against the native implementation.
fn add_with_ids(&mut self, x: &[f32], xids: &[Idx]) -> Result<()>;
/// Trains the index on the sample vectors in `x` (consecutive rows of `d()` floats).
fn train(&mut self, x: &[f32]) -> Result<()>;
/// Returns, for each query row in `q`, the labels of its `k` nearest stored
/// vectors (nearest first); the per-query lists are concatenated in query order.
fn assign(&mut self, q: &[f32], k: usize) -> Result<AssignSearchResult>;
/// Like `assign`, but also reports the distance to each returned neighbour.
fn search(&mut self, q: &[f32], k: usize) -> Result<SearchResult>;
/// Finds, for each query row in `q`, every stored vector within `radius`.
fn range_search(&mut self, q: &[f32], radius: f32) -> Result<RangeSearchResult>;
/// Removes every stored vector, so that `ntotal()` becomes 0.
fn reset(&mut self) -> Result<()>;
/// Removes the vectors selected by `sel`; the returned value is presumably the
/// number of vectors removed — TODO confirm.
fn remove_ids(&mut self, sel: &IdSelector) -> Result<i64>;
}
/// An index backed by a native FAISS object.
pub trait NativeIndex: Index {
/// Raw pointer to the underlying native index. Implementors presumably keep
/// ownership of it (`IndexImpl` frees it on drop), so callers must not free it.
fn inner_ptr(&self) -> *mut FaissIndex;
}
/// An index whose query operations can be invoked through a shared reference
/// (`&self`), as counterparts of the `&mut self` query methods of `Index`.
pub trait ConcurrentIndex: Index {
/// Shared-reference counterpart of `Index::assign`.
fn assign(&self, q: &[f32], k: usize) -> Result<AssignSearchResult>;
/// Shared-reference counterpart of `Index::search`.
fn search(&self, q: &[f32], k: usize) -> Result<SearchResult>;
/// Shared-reference counterpart of `Index::range_search`.
fn range_search(&self, q: &[f32], radius: f32) -> Result<RangeSearchResult>;
}
/// Marker trait with no methods; presumably distinguishes CPU-resident indices
/// from those in the feature-gated `gpu` module — TODO confirm.
pub trait CpuIndex: Index {}
/// Construction of an index wrapper from a raw native pointer.
pub trait FromInnerPtr: NativeIndex {
/// Wraps `inner_ptr` in a new index value.
///
/// # Safety
///
/// `inner_ptr` must point to a valid native index of a kind the implementing
/// type can represent. The returned value may take ownership of the pointer
/// (`IndexImpl` frees it when dropped), so the caller must neither free it nor
/// wrap it a second time.
unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self;
}
/// Result of `Index::assign`: `k` labels per query, with the per-query lists
/// concatenated in query order.
///
/// Derived equality compares labels with `Idx`'s `PartialEq`, under which the
/// "none" sentinel is unequal even to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct AssignSearchResult {
// Neighbour labels, nearest first within each query's group of `k`.
pub labels: Vec<Idx>,
}
/// Result of `Index::search`: `k` neighbours per query, with the per-query lists
/// concatenated in query order.
///
/// Derived equality compares labels with `Idx`'s `PartialEq`, under which the
/// "none" sentinel is unequal even to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
// Distances to the returned neighbours; presumably aligned one-to-one with
// `labels` — TODO confirm.
pub distances: Vec<f32>,
// Neighbour labels, nearest first within each query's group of `k`.
pub labels: Vec<Idx>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct RangeSearchResult {
inner: *mut FaissRangeSearchResult,
}
impl RangeSearchResult {
pub fn nq(&self) -> usize {
unsafe { faiss_RangeSearchResult_nq(self.inner) }
}
pub fn lims(&self) -> &[usize] {
unsafe {
let mut lims_ptr = ptr::null_mut();
faiss_RangeSearchResult_lims(self.inner, &mut lims_ptr);
::std::slice::from_raw_parts(lims_ptr, self.nq() + 1)
}
}
pub fn distance_and_labels(&self) -> (&[f32], &[Idx]) {
let lims = self.lims();
let full_len = lims.last().cloned().unwrap_or(0);
unsafe {
let mut distances_ptr = ptr::null_mut();
let mut labels_ptr = ptr::null_mut();
faiss_RangeSearchResult_labels(self.inner, &mut labels_ptr, &mut distances_ptr);
let distances = ::std::slice::from_raw_parts(distances_ptr, full_len);
let labels = ::std::slice::from_raw_parts(labels_ptr as *const Idx, full_len);
(distances, labels)
}
}
pub fn distance_and_labels_mut(&self) -> (&mut [f32], &mut [Idx]) {
unsafe {
let buf_size = faiss_RangeSearchResult_buffer_size(self.inner);
let mut distances_ptr = ptr::null_mut();
let mut labels_ptr = ptr::null_mut();
faiss_RangeSearchResult_labels(self.inner, &mut labels_ptr, &mut distances_ptr);
let distances = ::std::slice::from_raw_parts_mut(distances_ptr, buf_size);
let labels = ::std::slice::from_raw_parts_mut(labels_ptr as *mut Idx, buf_size);
(distances, labels)
}
}
pub fn distances(&self) -> &[f32] {
self.distance_and_labels().0
}
pub fn distances_mut(&mut self) -> &mut [f32] {
self.distance_and_labels_mut().0
}
pub fn labels(&self) -> &[Idx] {
self.distance_and_labels().1
}
pub fn labels_mut(&mut self) -> &mut [Idx] {
self.distance_and_labels_mut().1
}
}
impl Drop for RangeSearchResult {
fn drop(&mut self) {
unsafe {
faiss_RangeSearchResult_free(self.inner);
}
}
}
/// Type-erased native FAISS index, as returned by [`index_factory`].
///
/// The wrapped `FaissIndex` pointer is owned by this value and released with
/// `faiss_Index_free` when it is dropped.
#[derive(Debug)]
pub struct IndexImpl {
    inner: *mut FaissIndex,
}

// NOTE(review): these impls assert that the native index may be moved across and
// shared between threads; this relies on the thread-safety of the wrapped FAISS
// index type — TODO confirm for every index the factory can produce.
unsafe impl Send for IndexImpl {}
unsafe impl Sync for IndexImpl {}

impl IndexImpl {
    /// Raw pointer to the native index; ownership stays with `self`.
    pub fn inner_ptr(&self) -> *mut FaissIndex {
        <Self as NativeIndex>::inner_ptr(self)
    }
}

impl NativeIndex for IndexImpl {
    fn inner_ptr(&self) -> *mut FaissIndex {
        self.inner
    }
}

impl FromInnerPtr for IndexImpl {
    /// Takes ownership of `inner_ptr`; it is freed when the returned value drops.
    unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self {
        Self { inner: inner_ptr }
    }
}

impl CpuIndex for IndexImpl {}

impl Drop for IndexImpl {
    fn drop(&mut self) {
        unsafe { faiss_Index_free(self.inner) }
    }
}

impl_native_index!(IndexImpl);
impl_native_index_clone!(IndexImpl);
pub fn index_factory<D>(d: u32, description: D, metric: MetricType) -> Result<IndexImpl>
where
D: AsRef<str>,
{
unsafe {
let metric = metric as c_uint;
let description =
CString::new(description.as_ref()).map_err(|_| Error::IndexDescription)?;
let mut index_ptr = ::std::ptr::null_mut();
faiss_try!(faiss_index_factory(
&mut index_ptr,
(d & 0x7FFF_FFFF) as i32,
description.as_ptr(),
metric
));
Ok(IndexImpl { inner: index_ptr })
}
}
#[cfg(test)]
mod tests {
    use super::{index_factory, Idx, Index};
    use crate::metric::MetricType;

    /// Five 8-dimensional vectors (ids 0..=4 once added) shared by the flat-index tests.
    const FIVE_VECTORS_D8: [f32; 40] = [
        7.5, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, // id 0
        -1., 1., 1., 1., 1., 1., 1., -1., // id 1
        0., 0., 0., 1., 1., 0., 0., -1., // id 2
        100., 100., 100., 100., -100., 100., 100., 100., // id 3
        120., 100., 100., 105., -100., 100., 100., 105., // id 4
    ];

    #[test]
    fn index_factory_flat() {
        let index = index_factory(64, "Flat", MetricType::L2).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_ivf_flat() {
        let index = index_factory(64, "IVF8,Flat", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_sq() {
        let index = index_factory(64, "SQ8", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_pq() {
        let index = index_factory(64, "PQ8", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_ivf_sq() {
        let index = index_factory(64, "IVF8,SQ4", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
        let index = index_factory(64, "IVF8,SQ8", MetricType::L2).unwrap();
        assert!(!index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_factory_hnsw() {
        let index = index_factory(64, "HNSW8", MetricType::L2).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn bad_index_factory_description() {
        // Unknown description: rejected by the native factory.
        let r = index_factory(64, "fdnoyq", MetricType::L2);
        assert!(r.is_err());
        // Interior NUL byte: rejected before reaching native code.
        let r = index_factory(64, "Flat\0Flat", MetricType::L2);
        assert!(r.is_err());
    }

    #[test]
    fn index_clone() {
        let mut index = index_factory(4, "Flat", MetricType::L2).unwrap();
        let some_data: [f32; 24] = [
            7.5, -7.5, 7.5, -7.5, //
            7.5, 7.5, 7.5, 7.5, //
            -1., 1., 1., 1., //
            1., 1., 1., -1., //
            0., 0., 0., 1., //
            1., 0., 0., -1., //
        ];
        index.add(&some_data).unwrap();
        assert_eq!(index.ntotal(), 6);
        let mut index2 = index.try_clone().unwrap();
        assert_eq!(index2.ntotal(), 6);
        let some_more_data: [f32; 16] = [
            100., 100., 100., 100., //
            -100., 100., 100., 100., //
            120., 100., 100., 105., //
            -100., 100., 100., 105., //
        ];
        index2.add(&some_more_data).unwrap();
        // The clone is independent: adding to it leaves the original untouched.
        assert_eq!(index.ntotal(), 6);
        assert_eq!(index2.ntotal(), 10);
    }

    #[test]
    fn flat_index_search() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        index.add(&FIVE_VECTORS_D8).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(2), Idx(1), Idx(0), Idx(3), Idx(4)]);
        assert!(result.distances.iter().all(|x| *x > 0.));

        let my_query = [100.; 8];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(3), Idx(4), Idx(0), Idx(1), Idx(2)]);
        assert!(result.distances.iter().all(|x| *x > 0.));

        // Two queries at once: results are concatenated in query order.
        let my_query = [
            0., 0., 0., 0., 0., 0., 0., 0., // query 0
            100., 100., 100., 100., 100., 100., 100., 100., // query 1
        ];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![
                Idx(2),
                Idx(1),
                Idx(0),
                Idx(3),
                Idx(4),
                Idx(3),
                Idx(4),
                Idx(0),
                Idx(1),
                Idx(2)
            ]
        );
        assert!(result.distances.iter().all(|x| *x > 0.));
    }

    #[test]
    fn flat_index_assign() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        assert_eq!(index.d(), 8);
        assert_eq!(index.ntotal(), 0);
        index.add(&FIVE_VECTORS_D8).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(2), Idx(1), Idx(0), Idx(3), Idx(4)]);

        // Four identical zero queries.
        let my_query = [0.; 32];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![2, 1, 0, 3, 4, 2, 1, 0, 3, 4, 2, 1, 0, 3, 4, 2, 1, 0, 3, 4]
                .into_iter()
                .map(Idx)
                .collect::<Vec<_>>()
        );

        let my_query = [100.; 8];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![3, 4, 0, 1, 2].into_iter().map(Idx).collect::<Vec<_>>()
        );

        let my_query = [
            0., 0., 0., 0., 0., 0., 0., 0., // query 0
            100., 100., 100., 100., 100., 100., 100., 100., // query 1
        ];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![2, 1, 0, 3, 4, 3, 4, 0, 1, 2]
                .into_iter()
                .map(Idx)
                .collect::<Vec<_>>()
        );

        index.reset().unwrap();
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn flat_index_range_search() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        index.add(&FIVE_VECTORS_D8).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.range_search(&my_query, 8.125).unwrap();
        let (distances, labels) = result.distance_and_labels();
        // Range-search matches are not sorted, so accept either order.
        assert!(labels == &[Idx(1), Idx(2)] || labels == &[Idx(2), Idx(1)]);
        assert!(distances.iter().all(|x| *x > 0.));
    }
}