/// Identifier type for index entries: an unsigned 64-bit integer.
pub type Key = u64;
/// Scalar type in which distances are expressed: a single-precision float.
pub type Distance = f32;
/// Signature of a C-ABI distance callback that receives three untyped pointers
/// and returns a [`Distance`].
///
/// NOTE(review): presumably the first two pointers are the vectors being compared and
/// the third is opaque user state; the alias is not used in the visible part of this
/// file, so confirm against its callers.
pub type StatefulMetric = unsafe extern "C" fn(
*const std::ffi::c_void,
*const std::ffi::c_void,
*mut std::ffi::c_void,
) -> Distance;
/// Signature of a C-ABI predicate callback that receives a [`Key`] and an untyped
/// mutable pointer and returns a `bool`.
///
/// NOTE(review): the pointer is presumably opaque user state; confirm against callers.
pub type StatefulPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool;
/// Failure modes of the [`BitAddressable`] accessors.
#[derive(Debug)]
pub enum BitAddressableError {
    /// The requested bit position lies outside the addressed storage.
    IndexOutOfRange,
}

impl std::fmt::Display for BitAddressableError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::IndexOutOfRange => "Index out of range",
        };
        f.write_str(message)
    }
}

/// Lets the error travel as a `dyn std::error::Error`.
impl std::error::Error for BitAddressableError {}
/// Read/write access to individual bits of a value, addressed by a zero-based index.
///
/// Both methods report an out-of-range index through
/// [`BitAddressableError::IndexOutOfRange`] instead of panicking.
pub trait BitAddressable {
/// Sets the bit at `index` to `value`.
fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError>;
/// Returns the current value of the bit at `index`.
fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError>;
}
/// One byte holding eight packed single-bit values.
///
/// `#[repr(transparent)]` gives the type exactly the layout of `u8`; the slice
/// reinterpretations below rely on that guarantee.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct b1x8(pub u8);

impl b1x8 {
    /// Views a slice of bytes as a slice of `b1x8` without copying.
    pub fn from_u8s(slice: &[u8]) -> &[Self] {
        let bytes: *const [u8] = slice;
        // SAFETY: `b1x8` is `repr(transparent)` over `u8`, so both slice types share
        // layout and length; the returned borrow is tied to the lifetime of `slice`.
        unsafe { &*(bytes as *const [Self]) }
    }

    /// Views a mutable slice of bytes as a mutable slice of `b1x8` without copying.
    pub fn from_mut_u8s(slice: &mut [u8]) -> &mut [Self] {
        let bytes: *mut [u8] = slice;
        // SAFETY: same layout argument as `from_u8s`; exclusivity is inherited from `slice`.
        unsafe { &mut *(bytes as *mut [Self]) }
    }

    /// Views a slice of `b1x8` as the underlying bytes without copying.
    pub fn to_u8s(slice: &[Self]) -> &[u8] {
        let packed: *const [Self] = slice;
        // SAFETY: `b1x8` is `repr(transparent)` over `u8`; every bit pattern is a valid `u8`.
        unsafe { &*(packed as *const [u8]) }
    }

    /// Views a mutable slice of `b1x8` as the underlying mutable bytes without copying.
    pub fn to_mut_u8s(slice: &mut [Self]) -> &mut [u8] {
        let packed: *mut [Self] = slice;
        // SAFETY: same layout argument as `to_u8s`; exclusivity is inherited from `slice`.
        unsafe { &mut *(packed as *mut [u8]) }
    }
}
/// Raw 16-bit storage for a half-precision value, kept as an `i16` bit pattern.
///
/// `#[repr(transparent)]` gives the type exactly the layout of `i16`; the slice
/// reinterpretations below rely on that guarantee.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy)]
pub struct f16(i16);

impl f16 {
    /// Views a slice of `i16` bit patterns as a slice of `f16` without copying.
    pub fn from_i16s(slice: &[i16]) -> &[Self] {
        let words: *const [i16] = slice;
        // SAFETY: `f16` is `repr(transparent)` over `i16`, so both slice types share
        // layout and length; the returned borrow is tied to the lifetime of `slice`.
        unsafe { &*(words as *const [Self]) }
    }

    /// Views a mutable slice of `i16` bit patterns as a mutable slice of `f16`.
    pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] {
        let words: *mut [i16] = slice;
        // SAFETY: same layout argument as `from_i16s`; exclusivity is inherited from `slice`.
        unsafe { &mut *(words as *mut [Self]) }
    }

    /// Views a slice of `f16` as the underlying `i16` bit patterns without copying.
    pub fn to_i16s(slice: &[Self]) -> &[i16] {
        let halves: *const [Self] = slice;
        // SAFETY: `f16` is `repr(transparent)` over `i16`; every bit pattern is a valid `i16`.
        unsafe { &*(halves as *const [i16]) }
    }

    /// Views a mutable slice of `f16` as the underlying mutable `i16` bit patterns.
    pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] {
        let halves: *mut [Self] = slice;
        // SAFETY: same layout argument as `to_i16s`; exclusivity is inherited from `slice`.
        unsafe { &mut *(halves as *mut [i16]) }
    }
}
/// Bit access within a single packed byte; valid indices are `0..=7`, where
/// index 0 is the least significant bit.
impl BitAddressable for b1x8 {
    /// Sets or clears the bit at `index`, rejecting indices of 8 and above.
    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
        if index > 7 {
            return Err(BitAddressableError::IndexOutOfRange);
        }
        let mask = 1u8 << index;
        self.0 = if value { self.0 | mask } else { self.0 & !mask };
        Ok(())
    }

    /// Reports whether the bit at `index` is set, rejecting indices of 8 and above.
    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
        match index {
            0..=7 => Ok(self.0 & (1u8 << index) != 0),
            _ => Err(BitAddressableError::IndexOutOfRange),
        }
    }
}
/// Bit access across a slice of packed bytes: bit `index` lives in element
/// `index / 8` at position `index % 8`.
impl BitAddressable for [b1x8] {
    /// Sets or clears bit `index`, failing when the owning byte is past the end of the slice.
    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
        match self.get_mut(index / 8) {
            Some(byte) => byte.set_bit(index % 8, value),
            None => Err(BitAddressableError::IndexOutOfRange),
        }
    }

    /// Reads bit `index`, failing when the owning byte is past the end of the slice.
    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
        match self.get(index / 8) {
            Some(byte) => byte.get_bit(index % 8),
            None => Err(BitAddressableError::IndexOutOfRange),
        }
    }
}
/// Floating-point style equality on the raw half-precision bit pattern.
///
/// * Any NaN (exponent bits all ones, non-zero mantissa) is unequal to everything,
///   including itself.
/// * Positive and negative zero compare equal even though their sign bits differ.
/// * Every other pair is equal exactly when the bit patterns match.
impl PartialEq for f16 {
    fn eq(&self, other: &Self) -> bool {
        const EXPONENT_MASK: i16 = 0x7C00;
        const MANTISSA_MASK: i16 = 0x03FF;
        // Everything except the sign bit.
        const MAGNITUDE_MASK: i16 = 0x7FFF;

        let is_nan =
            |bits: i16| (bits & EXPONENT_MASK) == EXPONENT_MASK && (bits & MANTISSA_MASK) != 0;
        if is_nan(self.0) || is_nan(other.0) {
            return false;
        }

        // +0.0 (0x0000) and -0.0 (0x8000) differ only in the sign bit; a plain bit
        // comparison would wrongly report them as different values.
        if ((self.0 | other.0) & MAGNITUDE_MASK) == 0 {
            return true;
        }

        self.0 == other.0
    }
}
/// Debug output is the byte rendered as eight binary digits, most significant bit first.
impl std::fmt::Debug for b1x8 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte = self.0;
        f.write_fmt(format_args!("{:08b}", byte))
    }
}
/// Debug output splits the 16-bit pattern into `sign|exponent|mantissa`, printing the
/// 5 exponent bits and the 10 mantissa bits in binary.
impl std::fmt::Debug for f16 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Work on the unsigned view so the shifts are purely logical.
        let raw = self.0 as u16;
        let (sign, exponent, mantissa) = (raw >> 15, (raw >> 10) & 0x1F, raw & 0x3FF);
        write!(f, "{}|{:05b}|{:010b}", sign, exponent, mantissa)
    }
}
// Bridge module processed by `cxx`: the enums and structs declared here are shared with
// C++, and the functions inside `unsafe extern "C++"` are implemented on the C++ side
// (declared through `lib.hpp`).
#[cxx::bridge]
pub mod ffi {
/// Selects a distance measure, shared with C++ as an `i32`-backed enum.
///
/// NOTE(review): the variant names suggest the usual measures (inner product, squared
/// L2, cosine, ...); their exact definitions are not visible in this file.
#[derive(Debug)]
#[repr(i32)]
enum MetricKind {
Unknown,
IP,
L2sq,
Cos,
Pearson,
Haversine,
Divergence,
Hamming,
Tanimoto,
Sorensen,
}
/// Selects the scalar representation used for quantization, shared with C++ as an
/// `i32`-backed enum.
#[derive(Debug)]
#[repr(i32)]
enum ScalarKind {
Unknown,
F64,
F32,
F16,
BF16,
I8,
B1,
}
/// Result of a search: keys and distances returned as two separate vectors.
///
/// NOTE(review): presumably `keys[i]` pairs with `distances[i]`; confirm on the C++ side.
#[derive(Debug)]
struct Matches {
keys: Vec<u64>,
distances: Vec<f32>,
}
/// Parameters handed to `new_native_index` when an index is created.
#[derive(Debug, PartialEq)]
struct IndexOptions {
dimensions: usize,
metric: MetricKind,
quantization: ScalarKind,
connectivity: usize,
expansion_add: usize,
expansion_search: usize,
multi: bool,
}
unsafe extern "C++" {
include!("lib.hpp");
// Opaque C++ type; Rust only ever holds it behind a reference or `UniquePtr`.
type NativeIndex;
// Getters and setters for the expansion parameters and the metric.
pub fn expansion_add(self: &NativeIndex) -> usize;
pub fn expansion_search(self: &NativeIndex) -> usize;
pub fn change_expansion_add(self: &NativeIndex, n: usize);
pub fn change_expansion_search(self: &NativeIndex, n: usize);
pub fn change_metric_kind(self: &NativeIndex, metric: MetricKind);
// The Rust callers in this file pass a trampoline function address as `metric` and a
// state address as `metric_state`, both as plain integers.
pub fn change_metric(self: &NativeIndex, metric: usize, metric_state: usize);
// Construction and sizing.
pub fn new_native_index(options: &IndexOptions) -> Result<UniquePtr<NativeIndex>>;
pub fn reserve(self: &NativeIndex, capacity: usize) -> Result<()>;
pub fn dimensions(self: &NativeIndex) -> usize;
pub fn connectivity(self: &NativeIndex) -> usize;
pub fn size(self: &NativeIndex) -> usize;
pub fn capacity(self: &NativeIndex) -> usize;
pub fn serialized_length(self: &NativeIndex) -> usize;
// Insertion: one entry point per scalar representation. `f16` values cross the bridge
// as `i16` and packed bits as `u8`.
pub fn add_b1x8(self: &NativeIndex, key: u64, vector: &[u8]) -> Result<()>;
pub fn add_i8(self: &NativeIndex, key: u64, vector: &[i8]) -> Result<()>;
pub fn add_f16(self: &NativeIndex, key: u64, vector: &[i16]) -> Result<()>;
pub fn add_f32(self: &NativeIndex, key: u64, vector: &[f32]) -> Result<()>;
pub fn add_f64(self: &NativeIndex, key: u64, vector: &[f64]) -> Result<()>;
// Search: one entry point per scalar representation.
pub fn search_b1x8(self: &NativeIndex, query: &[u8], count: usize) -> Result<Matches>;
pub fn search_i8(self: &NativeIndex, query: &[i8], count: usize) -> Result<Matches>;
pub fn search_f16(self: &NativeIndex, query: &[i16], count: usize) -> Result<Matches>;
pub fn search_f32(self: &NativeIndex, query: &[f32], count: usize) -> Result<Matches>;
pub fn search_f64(self: &NativeIndex, query: &[f64], count: usize) -> Result<Matches>;
// Filtered search: the Rust callers in this file pass a trampoline function address as
// `filter` and the address of the user closure as `filter_state`.
pub fn filtered_search_b1x8(
self: &NativeIndex,
query: &[u8],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_i8(
self: &NativeIndex,
query: &[i8],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_f16(
self: &NativeIndex,
query: &[i16],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_f32(
self: &NativeIndex,
query: &[f32],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_f64(
self: &NativeIndex,
query: &[f64],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
// Retrieval into a caller-provided buffer, one entry point per scalar representation.
// NOTE(review): the returned `usize` is presumably the number of vectors written; confirm.
pub fn get_b1x8(self: &NativeIndex, key: u64, buffer: &mut [u8]) -> Result<usize>;
pub fn get_i8(self: &NativeIndex, key: u64, buffer: &mut [i8]) -> Result<usize>;
pub fn get_f16(self: &NativeIndex, key: u64, buffer: &mut [i16]) -> Result<usize>;
pub fn get_f32(self: &NativeIndex, key: u64, buffer: &mut [f32]) -> Result<usize>;
pub fn get_f64(self: &NativeIndex, key: u64, buffer: &mut [f64]) -> Result<usize>;
// Key-level operations.
pub fn remove(self: &NativeIndex, key: u64) -> Result<usize>;
pub fn rename(self: &NativeIndex, from: u64, to: u64) -> Result<usize>;
pub fn contains(self: &NativeIndex, key: u64) -> bool;
pub fn count(self: &NativeIndex, key: u64) -> usize;
// Persistence through a file path.
pub fn save(self: &NativeIndex, path: &str) -> Result<()>;
pub fn load(self: &NativeIndex, path: &str) -> Result<()>;
pub fn view(self: &NativeIndex, path: &str) -> Result<()>;
pub fn reset(self: &NativeIndex) -> Result<()>;
pub fn memory_usage(self: &NativeIndex) -> usize;
// Returns a raw C string pointer; the Rust wrapper in this file reads it with `CStr::from_ptr`.
pub fn hardware_acceleration(self: &NativeIndex) -> *const c_char;
// Persistence through an in-memory byte buffer.
pub fn save_to_buffer(self: &NativeIndex, buffer: &mut [u8]) -> Result<()>;
pub fn load_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
pub fn view_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
}
}
pub use ffi::{IndexOptions, MetricKind, ScalarKind};
/// A user-supplied distance closure, tagged by the scalar type of the vectors it compares.
///
/// Each variant boxes a `Send + Sync` closure that takes two raw pointers to the first
/// scalar of each vector and returns a [`Distance`]. The closure receives no length
/// argument, so it has to know the vector length by other means.
pub enum MetricFunction {
B1X8Metric(std::boxed::Box<dyn Fn(*const b1x8, *const b1x8) -> Distance + Send + Sync>),
I8Metric(std::boxed::Box<dyn Fn(*const i8, *const i8) -> Distance + Send + Sync>),
F16Metric(std::boxed::Box<dyn Fn(*const f16, *const f16) -> Distance + Send + Sync>),
F32Metric(std::boxed::Box<dyn Fn(*const f32, *const f32) -> Distance + Send + Sync>),
F64Metric(std::boxed::Box<dyn Fn(*const f64, *const f64) -> Distance + Send + Sync>),
}
/// Owning handle to a native index, together with the custom metric (if any) installed on it.
pub struct Index {
/// The native index instance, owned through a `cxx::UniquePtr`.
inner: cxx::UniquePtr<ffi::NativeIndex>,
/// Storage that keeps a user-supplied metric closure alive; `None` right after `new`.
metric_fn: Option<MetricFunction>,
}
// NOTE(review): these impls assert that the native index may be sent to and shared
// between threads. Nothing in this file establishes that; confirm against the C++ side.
unsafe impl Send for Index {}
unsafe impl Sync for Index {}
/// Baseline options: 256 dimensions, cosine metric, `BF16` quantization, single-valued keys.
///
/// `connectivity` and both expansion factors start at zero. The `integration` test in this
/// file expects an index built from such options to report non-zero values for them.
impl Default for ffi::IndexOptions {
    fn default() -> Self {
        ffi::IndexOptions {
            dimensions: 256,
            metric: ffi::MetricKind::Cos,
            quantization: ffi::ScalarKind::BF16,
            // Zero for the three tuning knobs below.
            connectivity: 0,
            expansion_add: 0,
            expansion_search: 0,
            multi: false,
        }
    }
}
impl Clone for ffi::IndexOptions {
fn clone(&self) -> Self {
ffi::IndexOptions {
dimensions: (self.dimensions),
metric: (self.metric),
quantization: (self.quantization),
connectivity: (self.connectivity),
expansion_add: (self.expansion_add),
expansion_search: (self.expansion_search),
multi: (self.multi),
}
}
}
/// Scalar types that can be stored in and queried from an [`Index`].
///
/// Each implementation in this file routes the generic operations to the entry points of
/// the native index that match its scalar representation.
pub trait VectorType {
/// Inserts `vector` under `key`.
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception>
where
Self: Sized;
/// Retrieves the data stored under `key` into `buffer` and returns the count reported
/// by the native index.
fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result<usize, cxx::Exception>
where
Self: Sized;
/// Searches the index for `query`, passing `count` to the native search.
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized;
/// Like [`VectorType::search`], but additionally hands `filter` to the native index as a
/// per-key predicate callback.
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool;
/// Installs `metric` as the distance function of `index`; the closure receives raw
/// pointers to the first scalar of each vector.
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception>
where
Self: Sized;
}
/// [`VectorType`] glue for `f32` vectors: every method forwards to the matching `*_f32`
/// entry point of the native index.
impl VectorType for f32 {
    /// Forwards `query` and `count` to the native `search_f32`.
    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
        index.inner.search_f32(query, count)
    }

    /// Forwards to the native `get_f32`, which fills `vector` and reports a count.
    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
        index.inner.get_f32(key, vector)
    }

    /// Forwards `key` and `vector` to the native `add_f32`.
    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
        index.inner.add_f32(key, vector)
    }

    /// Runs the native `filtered_search_f32`, handing it `filter` as a callback.
    ///
    /// The native side receives two plain integers: the address of a trampoline
    /// monomorphised for `F`, and the address of `filter` itself. `filter` lives in this
    /// stack frame until the native call returns.
    fn filtered_search<F>(
        index: &Index,
        query: &[Self],
        count: usize,
        filter: F,
    ) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized,
        F: Fn(Key) -> bool,
    {
        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
            // SAFETY: `closure_address` is the address of the live `F` passed below.
            // This assumes the native side only calls back during that search call.
            let closure = unsafe { &*(closure_address as *const F) };
            closure(key)
        }

        let trampoline_fn = trampoline::<F> as *const () as usize;
        let closure_address = &filter as *const F as usize;
        index
            .inner
            .filtered_search_f32(query, count, trampoline_fn, closure_address)
    }

    /// Installs `metric` as the distance function of `index`.
    ///
    /// The closure is placed in a reference-counted heap slot whose address is given to
    /// the native index as callback state. The slot is owned by `index.metric_fn`, so its
    /// address does not change when the `Index` value is moved, and it stays alive until
    /// another metric is installed or the index is dropped. Always returns `Ok(())`.
    fn change_metric(
        index: &mut Index,
        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
    ) -> Result<(), cxx::Exception> {
        type Metric = std::boxed::Box<dyn Fn(*const f32, *const f32) -> Distance + Send + Sync>;

        extern "C" fn trampoline(first: usize, second: usize, state: usize) -> Distance {
            // SAFETY: `state` is the address of the `Metric` slot registered below. It is
            // a pointer to a boxed closure (data), not a function address, so it must be
            // dereferenced and called through the box rather than cast to a `fn` pointer.
            let metric = unsafe { &*(state as *const Metric) };
            metric(first as *const f32, second as *const f32)
        }

        let shared = std::sync::Arc::new(metric);
        let state = std::sync::Arc::as_ptr(&shared) as usize;
        // The stored closure exists to own `shared`; the native index calls the user
        // closure through `state` directly.
        let keep_alive: Metric =
            std::boxed::Box::new(move |a: *const f32, b: *const f32| -> Distance {
                (**shared)(a, b)
            });
        index.metric_fn = Some(MetricFunction::F32Metric(keep_alive));
        index
            .inner
            .change_metric(trampoline as *const () as usize, state);
        Ok(())
    }
}
/// [`VectorType`] glue for `i8` vectors: every method forwards to the matching `*_i8`
/// entry point of the native index.
impl VectorType for i8 {
    /// Forwards `query` and `count` to the native `search_i8`.
    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
        index.inner.search_i8(query, count)
    }

    /// Forwards to the native `get_i8`, which fills `vector` and reports a count.
    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
        index.inner.get_i8(key, vector)
    }

    /// Forwards `key` and `vector` to the native `add_i8`.
    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
        index.inner.add_i8(key, vector)
    }

    /// Runs the native `filtered_search_i8`, handing it `filter` as a callback.
    ///
    /// The native side receives two plain integers: the address of a trampoline
    /// monomorphised for `F`, and the address of `filter` itself. `filter` lives in this
    /// stack frame until the native call returns.
    fn filtered_search<F>(
        index: &Index,
        query: &[Self],
        count: usize,
        filter: F,
    ) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized,
        F: Fn(Key) -> bool,
    {
        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
            // SAFETY: `closure_address` is the address of the live `F` passed below.
            // This assumes the native side only calls back during that search call.
            let closure = unsafe { &*(closure_address as *const F) };
            closure(key)
        }

        let trampoline_fn = trampoline::<F> as *const () as usize;
        let closure_address = &filter as *const F as usize;
        index
            .inner
            .filtered_search_i8(query, count, trampoline_fn, closure_address)
    }

    /// Installs `metric` as the distance function of `index`.
    ///
    /// The closure is placed in a reference-counted heap slot whose address is given to
    /// the native index as callback state. The slot is owned by `index.metric_fn`, so its
    /// address does not change when the `Index` value is moved, and it stays alive until
    /// another metric is installed or the index is dropped. Always returns `Ok(())`.
    fn change_metric(
        index: &mut Index,
        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
    ) -> Result<(), cxx::Exception> {
        type Metric = std::boxed::Box<dyn Fn(*const i8, *const i8) -> Distance + Send + Sync>;

        extern "C" fn trampoline(first: usize, second: usize, state: usize) -> Distance {
            // SAFETY: `state` is the address of the `Metric` slot registered below. It is
            // a pointer to a boxed closure (data), not a function address, so it must be
            // dereferenced and called through the box rather than cast to a `fn` pointer.
            let metric = unsafe { &*(state as *const Metric) };
            metric(first as *const i8, second as *const i8)
        }

        let shared = std::sync::Arc::new(metric);
        let state = std::sync::Arc::as_ptr(&shared) as usize;
        // The stored closure exists to own `shared`; the native index calls the user
        // closure through `state` directly.
        let keep_alive: Metric =
            std::boxed::Box::new(move |a: *const i8, b: *const i8| -> Distance {
                (**shared)(a, b)
            });
        index.metric_fn = Some(MetricFunction::I8Metric(keep_alive));
        index
            .inner
            .change_metric(trampoline as *const () as usize, state);
        Ok(())
    }
}
/// [`VectorType`] glue for `f64` vectors: every method forwards to the matching `*_f64`
/// entry point of the native index.
impl VectorType for f64 {
    /// Forwards `query` and `count` to the native `search_f64`.
    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
        index.inner.search_f64(query, count)
    }

    /// Forwards to the native `get_f64`, which fills `vector` and reports a count.
    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
        index.inner.get_f64(key, vector)
    }

    /// Forwards `key` and `vector` to the native `add_f64`.
    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
        index.inner.add_f64(key, vector)
    }

    /// Runs the native `filtered_search_f64`, handing it `filter` as a callback.
    ///
    /// The native side receives two plain integers: the address of a trampoline
    /// monomorphised for `F`, and the address of `filter` itself. `filter` lives in this
    /// stack frame until the native call returns.
    fn filtered_search<F>(
        index: &Index,
        query: &[Self],
        count: usize,
        filter: F,
    ) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized,
        F: Fn(Key) -> bool,
    {
        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
            // SAFETY: `closure_address` is the address of the live `F` passed below.
            // This assumes the native side only calls back during that search call.
            let closure = unsafe { &*(closure_address as *const F) };
            closure(key)
        }

        let trampoline_fn = trampoline::<F> as *const () as usize;
        let closure_address = &filter as *const F as usize;
        index
            .inner
            .filtered_search_f64(query, count, trampoline_fn, closure_address)
    }

    /// Installs `metric` as the distance function of `index`.
    ///
    /// The closure is placed in a reference-counted heap slot whose address is given to
    /// the native index as callback state. The slot is owned by `index.metric_fn`, so its
    /// address does not change when the `Index` value is moved, and it stays alive until
    /// another metric is installed or the index is dropped. Always returns `Ok(())`.
    fn change_metric(
        index: &mut Index,
        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
    ) -> Result<(), cxx::Exception> {
        type Metric = std::boxed::Box<dyn Fn(*const f64, *const f64) -> Distance + Send + Sync>;

        extern "C" fn trampoline(first: usize, second: usize, state: usize) -> Distance {
            // SAFETY: `state` is the address of the `Metric` slot registered below. It is
            // a pointer to a boxed closure (data), not a function address, so it must be
            // dereferenced and called through the box rather than cast to a `fn` pointer.
            let metric = unsafe { &*(state as *const Metric) };
            metric(first as *const f64, second as *const f64)
        }

        let shared = std::sync::Arc::new(metric);
        let state = std::sync::Arc::as_ptr(&shared) as usize;
        // The stored closure exists to own `shared`; the native index calls the user
        // closure through `state` directly.
        let keep_alive: Metric =
            std::boxed::Box::new(move |a: *const f64, b: *const f64| -> Distance {
                (**shared)(a, b)
            });
        index.metric_fn = Some(MetricFunction::F64Metric(keep_alive));
        index
            .inner
            .change_metric(trampoline as *const () as usize, state);
        Ok(())
    }
}
/// [`VectorType`] glue for `f16` vectors: slices are reinterpreted as `i16` bit patterns
/// and forwarded to the matching `*_f16` entry point of the native index.
impl VectorType for f16 {
    /// Forwards `query` (viewed as `i16`) and `count` to the native `search_f16`.
    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
        index.inner.search_f16(f16::to_i16s(query), count)
    }

    /// Forwards to the native `get_f16`, which fills `vector` (viewed as `i16`) and
    /// reports a count.
    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
        index.inner.get_f16(key, f16::to_mut_i16s(vector))
    }

    /// Forwards `key` and `vector` (viewed as `i16`) to the native `add_f16`.
    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
        index.inner.add_f16(key, f16::to_i16s(vector))
    }

    /// Runs the native `filtered_search_f16`, handing it `filter` as a callback.
    ///
    /// The native side receives two plain integers: the address of a trampoline
    /// monomorphised for `F`, and the address of `filter` itself. `filter` lives in this
    /// stack frame until the native call returns.
    fn filtered_search<F>(
        index: &Index,
        query: &[Self],
        count: usize,
        filter: F,
    ) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized,
        F: Fn(Key) -> bool,
    {
        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
            // SAFETY: `closure_address` is the address of the live `F` passed below.
            // This assumes the native side only calls back during that search call.
            let closure = unsafe { &*(closure_address as *const F) };
            closure(key)
        }

        let trampoline_fn = trampoline::<F> as *const () as usize;
        let closure_address = &filter as *const F as usize;
        index.inner.filtered_search_f16(
            f16::to_i16s(query),
            count,
            trampoline_fn,
            closure_address,
        )
    }

    /// Installs `metric` as the distance function of `index`.
    ///
    /// The closure is placed in a reference-counted heap slot whose address is given to
    /// the native index as callback state. The slot is owned by `index.metric_fn`, so its
    /// address does not change when the `Index` value is moved, and it stays alive until
    /// another metric is installed or the index is dropped. Always returns `Ok(())`.
    fn change_metric(
        index: &mut Index,
        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
    ) -> Result<(), cxx::Exception> {
        type Metric = std::boxed::Box<dyn Fn(*const f16, *const f16) -> Distance + Send + Sync>;

        extern "C" fn trampoline(first: usize, second: usize, state: usize) -> Distance {
            // SAFETY: `state` is the address of the `Metric` slot registered below. It is
            // a pointer to a boxed closure (data), not a function address, so it must be
            // dereferenced and called through the box rather than cast to a `fn` pointer.
            let metric = unsafe { &*(state as *const Metric) };
            metric(first as *const f16, second as *const f16)
        }

        let shared = std::sync::Arc::new(metric);
        let state = std::sync::Arc::as_ptr(&shared) as usize;
        // The stored closure exists to own `shared`; the native index calls the user
        // closure through `state` directly.
        let keep_alive: Metric =
            std::boxed::Box::new(move |a: *const f16, b: *const f16| -> Distance {
                (**shared)(a, b)
            });
        index.metric_fn = Some(MetricFunction::F16Metric(keep_alive));
        index
            .inner
            .change_metric(trampoline as *const () as usize, state);
        Ok(())
    }
}
/// [`VectorType`] glue for packed-bit vectors: slices are reinterpreted as bytes and
/// forwarded to the matching `*_b1x8` entry point of the native index.
impl VectorType for b1x8 {
    /// Forwards `query` (viewed as bytes) and `count` to the native `search_b1x8`.
    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
        index.inner.search_b1x8(b1x8::to_u8s(query), count)
    }

    /// Forwards to the native `get_b1x8`, which fills `vector` (viewed as bytes) and
    /// reports a count.
    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
        index.inner.get_b1x8(key, b1x8::to_mut_u8s(vector))
    }

    /// Forwards `key` and `vector` (viewed as bytes) to the native `add_b1x8`.
    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
        index.inner.add_b1x8(key, b1x8::to_u8s(vector))
    }

    /// Runs the native `filtered_search_b1x8`, handing it `filter` as a callback.
    ///
    /// The native side receives two plain integers: the address of a trampoline
    /// monomorphised for `F`, and the address of `filter` itself. `filter` lives in this
    /// stack frame until the native call returns.
    fn filtered_search<F>(
        index: &Index,
        query: &[Self],
        count: usize,
        filter: F,
    ) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized,
        F: Fn(Key) -> bool,
    {
        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
            // SAFETY: `closure_address` is the address of the live `F` passed below.
            // This assumes the native side only calls back during that search call.
            let closure = unsafe { &*(closure_address as *const F) };
            closure(key)
        }

        let trampoline_fn = trampoline::<F> as *const () as usize;
        let closure_address = &filter as *const F as usize;
        index.inner.filtered_search_b1x8(
            b1x8::to_u8s(query),
            count,
            trampoline_fn,
            closure_address,
        )
    }

    /// Installs `metric` as the distance function of `index`.
    ///
    /// The closure is placed in a reference-counted heap slot whose address is given to
    /// the native index as callback state. The slot is owned by `index.metric_fn`, so its
    /// address does not change when the `Index` value is moved, and it stays alive until
    /// another metric is installed or the index is dropped. Always returns `Ok(())`.
    fn change_metric(
        index: &mut Index,
        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
    ) -> Result<(), cxx::Exception> {
        type Metric = std::boxed::Box<dyn Fn(*const b1x8, *const b1x8) -> Distance + Send + Sync>;

        extern "C" fn trampoline(first: usize, second: usize, state: usize) -> Distance {
            // SAFETY: `state` is the address of the `Metric` slot registered below. It is
            // a pointer to a boxed closure (data), not a function address, so it must be
            // dereferenced and called through the box rather than cast to a `fn` pointer.
            let metric = unsafe { &*(state as *const Metric) };
            metric(first as *const b1x8, second as *const b1x8)
        }

        let shared = std::sync::Arc::new(metric);
        let state = std::sync::Arc::as_ptr(&shared) as usize;
        // The stored closure exists to own `shared`; the native index calls the user
        // closure through `state` directly.
        let keep_alive: Metric =
            std::boxed::Box::new(move |a: *const b1x8, b: *const b1x8| -> Distance {
                (**shared)(a, b)
            });
        index.metric_fn = Some(MetricFunction::B1X8Metric(keep_alive));
        index
            .inner
            .change_metric(trampoline as *const () as usize, state);
        Ok(())
    }
}
impl Index {
pub fn new(options: &ffi::IndexOptions) -> Result<Self, cxx::Exception> {
match ffi::new_native_index(options) {
Ok(inner) => Result::Ok(Self {
inner,
metric_fn: None,
}),
Err(err) => Err(err),
}
}
pub fn expansion_add(self: &Index) -> usize {
self.inner.expansion_add()
}
pub fn expansion_search(self: &Index) -> usize {
self.inner.expansion_search()
}
pub fn change_expansion_add(self: &Index, n: usize) {
self.inner.change_expansion_add(n)
}
pub fn change_expansion_search(self: &Index, n: usize) {
self.inner.change_expansion_search(n)
}
pub fn change_metric_kind(self: &Index, metric: ffi::MetricKind) {
self.inner.change_metric_kind(metric)
}
pub fn change_metric<T: VectorType>(
self: &mut Index,
metric: std::boxed::Box<dyn Fn(*const T, *const T) -> Distance + Send + Sync>,
) {
T::change_metric(self, metric).unwrap();
}
pub fn hardware_acceleration(&self) -> String {
use core::ffi::CStr;
unsafe {
let c_str = CStr::from_ptr(self.inner.hardware_acceleration());
c_str.to_string_lossy().into_owned()
}
}
pub fn search<T: VectorType>(
self: &Index,
query: &[T],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
T::search(self, query, count)
}
pub fn filtered_search<T: VectorType, F>(
self: &Index,
query: &[T],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
F: Fn(Key) -> bool,
{
T::filtered_search(self, query, count, filter)
}
pub fn add<T: VectorType>(self: &Index, key: Key, vector: &[T]) -> Result<(), cxx::Exception> {
T::add(self, key, vector)
}
pub fn get<T: VectorType>(
self: &Index,
key: Key,
vector: &mut [T],
) -> Result<usize, cxx::Exception> {
T::get(self, key, vector)
}
pub fn export<T: VectorType + Default + Clone>(
self: &Index,
key: Key,
vector: &mut Vec<T>,
) -> Result<usize, cxx::Exception> {
let dim = self.dimensions();
let max_matches = self.count(key);
vector.resize(dim * max_matches, T::default());
let matches = T::get(self, key, &mut vector[..]);
if matches.is_err() {
return matches;
}
vector.resize(dim * matches.as_ref().unwrap(), T::default());
return matches;
}
pub fn reserve(self: &Index, capacity: usize) -> Result<(), cxx::Exception> {
self.inner.reserve(capacity)
}
pub fn dimensions(self: &Index) -> usize {
self.inner.dimensions()
}
pub fn connectivity(self: &Index) -> usize {
self.inner.connectivity()
}
pub fn size(self: &Index) -> usize {
self.inner.size()
}
pub fn capacity(self: &Index) -> usize {
self.inner.capacity()
}
pub fn serialized_length(self: &Index) -> usize {
self.inner.serialized_length()
}
pub fn remove(self: &Index, key: Key) -> Result<usize, cxx::Exception> {
self.inner.remove(key)
}
pub fn rename(self: &Index, from: Key, to: Key) -> Result<usize, cxx::Exception> {
self.inner.rename(from, to)
}
pub fn contains(self: &Index, key: Key) -> bool {
self.inner.contains(key)
}
pub fn count(self: &Index, key: Key) -> usize {
self.inner.count(key)
}
pub fn save(self: &Index, path: &str) -> Result<(), cxx::Exception> {
self.inner.save(path)
}
pub fn load(self: &Index, path: &str) -> Result<(), cxx::Exception> {
self.inner.load(path)
}
pub fn view(self: &Index, path: &str) -> Result<(), cxx::Exception> {
self.inner.view(path)
}
pub fn reset(self: &Index) -> Result<(), cxx::Exception> {
self.inner.reset()
}
pub fn memory_usage(self: &Index) -> usize {
self.inner.memory_usage()
}
pub fn save_to_buffer(self: &Index, buffer: &mut [u8]) -> Result<(), cxx::Exception> {
self.inner.save_to_buffer(buffer)
}
pub fn load_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
self.inner.load_from_buffer(buffer)
}
pub unsafe fn view_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
self.inner.view_from_buffer(buffer)
}
}
/// Free-function form of [`Index::new`]: builds an index from `options`.
pub fn new_index(options: &ffi::IndexOptions) -> Result<Index, cxx::Exception> {
Index::new(options)
}
#[cfg(test)]
mod tests {
use crate::ffi::IndexOptions;
use crate::ffi::MetricKind;
use crate::ffi::ScalarKind;
use crate::b1x8;
use crate::f16;
use crate::new_index;
use crate::Distance;
use crate::Index;
use crate::Key;
use std::env;
#[test]
fn print_specs() {
    // Prints the platform and what each scalar kind reports for hardware acceleration.
    // Builds a 256-dimensional index for the given scalar kind and metric.
    let build = |quantization: ScalarKind, metric: MetricKind| {
        let options = IndexOptions {
            dimensions: 256,
            metric,
            quantization,
            ..Default::default()
        };
        Index::new(&options).unwrap()
    };

    print!("--------------------------------------------------\n");
    println!("OS: {}", env::consts::OS);
    let rust_version = env::var("RUST_VERSION").unwrap_or_else(|_| "unknown".into());
    println!("Rust version: {}", rust_version);

    // Create every index up front, then report on each of them.
    let f64_index = build(ScalarKind::F64, MetricKind::Cos);
    let f32_index = build(ScalarKind::F32, MetricKind::Cos);
    let f16_index = build(ScalarKind::F16, MetricKind::Cos);
    let i8_index = build(ScalarKind::I8, MetricKind::Cos);
    let b1_index = build(ScalarKind::B1, MetricKind::Hamming);

    let f64_acceleration = f64_index.hardware_acceleration();
    println!("f64 hardware acceleration: {}", f64_acceleration);
    let f32_acceleration = f32_index.hardware_acceleration();
    println!("f32 hardware acceleration: {}", f32_acceleration);
    let f16_acceleration = f16_index.hardware_acceleration();
    println!("f16 hardware acceleration: {}", f16_acceleration);
    let i8_acceleration = i8_index.hardware_acceleration();
    println!("i8 hardware acceleration: {}", i8_acceleration);
    let b1_acceleration = b1_index.hardware_acceleration();
    println!("b1 hardware acceleration: {}", b1_acceleration);
    print!("--------------------------------------------------\n");
}
#[test]
fn test_add_get_vector() {
    // Vectors of the wrong length must be rejected on insert and on retrieval.
    let options = IndexOptions {
        dimensions: 5,
        quantization: ScalarKind::F32,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();
    assert!(index.reserve(10).is_ok());

    let first = [0.2_f32, 0.1, 0.2, 0.1, 0.3];
    let second = [0.3_f32, 0.2, 0.4, 0.0, 0.1];
    let too_long = [0.3_f32, 0.2, 0.4, 0.0, 0.1, 0.1];
    let too_short = [0.3_f32, 0.2, 0.4, 0.0];

    // Only the two vectors of the right length are accepted.
    assert!(index.add(1, &first).is_ok());
    assert!(index.add(2, &second).is_ok());
    assert!(index.add(3, &too_long).is_err());
    assert!(index.add(4, &too_short).is_err());
    assert_eq!(index.size(), 2);

    // `export` sizes the output vector by itself.
    let mut exported: Vec<f32> = Vec::new();
    assert_eq!(index.export(1, &mut exported).unwrap(), 1);
    assert_eq!(exported.len(), 5);
    assert_eq!(exported, first.to_vec());

    // `get` works with a correctly sized buffer.
    let mut exact_buffer = [0.0_f32; 5];
    assert_eq!(index.get(1, &mut exact_buffer).unwrap(), 1);
    assert_eq!(exact_buffer, first);

    // A buffer of the wrong size is rejected.
    let mut oversized_buffer = [0.0_f32; 6];
    let outcome = index.get(1, &mut oversized_buffer);
    assert!(outcome.is_err());
}
#[test]
fn test_search_vector() {
    // Queries whose length differs from the index dimensions must fail.
    let options = IndexOptions {
        dimensions: 5,
        quantization: ScalarKind::F32,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();
    assert!(index.reserve(10).is_ok());

    let first = [0.2_f32, 0.1, 0.2, 0.1, 0.3];
    let second = [0.3_f32, 0.2, 0.4, 0.0, 0.1];
    let too_long = [0.3_f32, 0.2, 0.4, 0.0, 0.1, 0.1];
    let too_short = [0.3_f32, 0.2, 0.4, 0.0];

    assert!(index.add(1, &first).is_ok());
    assert!(index.add(2, &second).is_ok());
    assert_eq!(index.size(), 2);

    let long_outcome = index.search(&too_long, 1);
    assert!(long_outcome.is_err());
    let short_outcome = index.search(&too_short, 1);
    assert!(short_outcome.is_err());
}
#[test]
fn test_add_remove_vector() {
    // Each key is added, read back and removed again; the index ends up empty.
    let options = IndexOptions {
        dimensions: 4,
        metric: MetricKind::IP,
        quantization: ScalarKind::F64,
        connectivity: 10,
        expansion_add: 128,
        expansion_search: 3,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();
    assert!(index.reserve(10).is_ok());
    assert!(index.capacity() >= 10);

    let first = [0.2_f32, 0.1, 0.2, 0.1];
    let second = [0.3_f32, 0.2, 0.4, 0.0];

    // Key / vector pairs, processed one after another.
    let rounds = [
        (483367403120493160, &first),
        (483367403120558696, &second),
        (483367403120624232, &second),
        (483367403120624233, &second),
    ];
    for &(id, vector) in &rounds {
        assert!(index.add(id, vector).is_ok());
        let mut found_slice = [0.0_f32; 4];
        assert_eq!(index.get(id, &mut found_slice).unwrap(), 1);
        assert!(index.remove(id).is_ok());
    }

    assert_eq!(index.size(), 0);
}
// End-to-end walk through the wrapper: construction, tuning knobs, add/search,
// file and buffer persistence, reset, and option cloning.
#[test]
fn integration() {
let mut options = IndexOptions::default();
options.dimensions = 5;
let index = Index::new(&options).unwrap();
// The defaults leave connectivity and both expansion values at zero; the created index
// is expected to report non-zero values for all three.
assert!(index.expansion_add() > 0);
assert!(index.expansion_search() > 0);
assert!(index.reserve(10).is_ok());
assert!(index.capacity() >= 10);
assert!(index.connectivity() != 0);
assert_eq!(index.dimensions(), 5);
assert_eq!(index.size(), 0);
let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
print!("--------------------------------------------------\n");
println!(
"before add, memory_usage: {} \
cap: {} \
",
index.memory_usage(),
index.capacity(),
);
// Changing the add-time expansion must be observable before each insertion.
index.change_expansion_add(10);
assert_eq!(index.expansion_add(), 10);
assert!(index.add(42, &first).is_ok());
index.change_expansion_add(12);
assert_eq!(index.expansion_add(), 12);
assert!(index.add(43, &second).is_ok());
assert_eq!(index.size(), 2);
println!(
"after add, memory_usage: {} \
cap: {} \
",
index.memory_usage(),
index.capacity(),
);
// Asking for 10 results from an index holding 2 entries must return both entries.
index.change_expansion_search(10);
assert_eq!(index.expansion_search(), 10);
let results = index.search(&first, 10).unwrap();
println!("{:?}", results);
assert_eq!(results.keys.len(), 2);
index.change_expansion_search(12);
assert_eq!(index.expansion_search(), 12);
let results = index.search(&first, 10).unwrap();
println!("{:?}", results);
assert_eq!(results.keys.len(), 2);
print!("--------------------------------------------------\n");
// File round trip; the file is written relative to the current working directory
// and is not removed afterwards.
assert!(index.save("index.rust.usearch").is_ok());
assert!(index.load("index.rust.usearch").is_ok());
assert!(index.view("index.rust.usearch").is_ok());
// Index construction must succeed for several metric settings.
assert!(new_index(&options).is_ok());
options.metric = MetricKind::L2sq;
assert!(new_index(&options).is_ok());
options.metric = MetricKind::Cos;
assert!(new_index(&options).is_ok());
options.metric = MetricKind::Haversine;
options.quantization = ScalarKind::F32;
options.dimensions = 2;
assert!(new_index(&options).is_ok());
// Buffer round trip: serialize `index` and load it into a freshly created index.
// Note that `options` has been changed to Haversine / F32 / 2 dimensions by this point.
let mut serialization_buffer = Vec::new();
serialization_buffer.resize(index.serialized_length(), 0);
assert!(index.save_to_buffer(&mut serialization_buffer).is_ok());
let deserialized_index = new_index(&options).unwrap();
assert!(deserialized_index
.load_from_buffer(&serialization_buffer)
.is_ok());
assert_eq!(index.size(), deserialized_index.size());
// Resetting must drop all entries and bring the reported memory usage to zero.
assert_ne!(index.memory_usage(), 0);
assert!(index.reset().is_ok());
assert_eq!(index.size(), 0);
assert_eq!(index.memory_usage(), 0);
// Cloned options compare equal until one of them is modified.
options.metric = MetricKind::Haversine;
let mut opts = options.clone();
assert_eq!(opts.metric, options.metric);
assert_eq!(opts.quantization, options.quantization);
assert_eq!(opts, options);
opts.metric = MetricKind::Cos;
assert_ne!(opts.metric, options.metric);
assert!(new_index(&opts).is_ok());
}
#[test]
fn test_search_with_stateless_filter() {
    // A capture-free predicate must restrict the results to odd keys.
    let options = IndexOptions {
        dimensions: 5,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();
    index.reserve(10).unwrap();

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
    index.add(1, &first).unwrap();
    index.add(2, &second).unwrap();

    let query = vec![0.2, 0.1, 0.2, 0.1, 0.3];
    let results = index
        .filtered_search(&query, 10, |key: Key| key % 2 == 1)
        .unwrap();

    let every_key_is_odd = results.keys.iter().all(|&key| key % 2 == 1);
    assert!(every_key_is_odd, "All keys must be odd");
}
#[test]
fn test_search_with_stateful_filter() {
    // A predicate that owns captured state must restrict the results to the allowed keys.
    use std::collections::HashSet;

    let options = IndexOptions {
        dimensions: 5,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();
    index.reserve(10).unwrap();

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    index.add(1, &first).unwrap();
    index.add(2, &first).unwrap();

    let allowed_keys: HashSet<Key> = vec![1, 2, 3].into_iter().collect();
    // The closure owns its own copy of the set.
    let filter_keys = allowed_keys.clone();
    let stateful_filter = move |key: Key| filter_keys.contains(&key);

    let query = vec![0.2, 0.1, 0.2, 0.1, 0.3];
    let results = index.filtered_search(&query, 10, stateful_filter).unwrap();

    let every_key_is_allowed = results.keys.iter().all(|&key| allowed_keys.contains(&key));
    assert!(every_key_is_allowed, "All keys must be in the allowed set");
}
#[test]
fn test_zero_distances() {
    // With F16 quantization and squared-L2, a query that differs from every stored
    // vector must not come back with a distance of exactly zero.
    let index = new_index(&IndexOptions {
        dimensions: 8,
        metric: MetricKind::L2sq,
        quantization: ScalarKind::F16,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    index.add(0, &[0.4, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(1, &[0.5, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(2, &[0.6, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap();

    let matches = index
        .search(&[0.05, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0], 2)
        .unwrap();

    for &distance in &matches.distances {
        assert_ne!(distance, 0.0);
    }
}
#[test]
fn test_change_distance_function() {
    // Installs a custom weighted-L1 metric on a 2-dimensional index. Nothing is added or
    // searched after the installation, so the metric itself is never invoked here.
    let options = IndexOptions {
        dimensions: 2,
        ..Default::default()
    };
    let mut index = Index::new(&options).unwrap();
    index.reserve(10).unwrap();

    let vector: [f32; 2] = [1.0, 0.0];
    index.add(1, &vector).unwrap();

    // Per-dimension weights captured by value in the closure.
    let first_factor: f32 = 2.0;
    let second_factor: f32 = 0.7;
    let stateful_distance = Box::new(move |a: *const f32, b: *const f32| unsafe {
        let lhs = std::slice::from_raw_parts(a, 2);
        let rhs = std::slice::from_raw_parts(b, 2);
        let first_term = (lhs[0] - rhs[0]).abs() * first_factor;
        let second_term = (lhs[1] - rhs[1]).abs() * second_factor;
        first_term + second_term
    });
    index.change_metric(stateful_distance);
}
#[test]
fn test_binary_vectors_and_hamming_distance() {
    // Two 8-bit vectors are ranked by Hamming distance to the query.
    let options = IndexOptions {
        dimensions: 8,
        metric: MetricKind::Hamming,
        quantization: ScalarKind::B1,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();

    let vector42: Vec<b1x8> = vec![b1x8(0b00001111)];
    let vector43: Vec<b1x8> = vec![b1x8(0b11110000)];
    let query: Vec<b1x8> = vec![b1x8(0b01111000)];

    index.reserve(10).unwrap();
    index.add(42, &vector42).unwrap();
    index.add(43, &vector43).unwrap();

    let results = index.search(&query, 5).unwrap();
    assert_eq!(results.keys.len(), 2);

    // The query differs from key 43 in 2 bits and from key 42 in 6 bits.
    let (nearest_key, nearest_distance) = (results.keys[0], results.distances[0]);
    assert_eq!(nearest_key, 43);
    assert_eq!(nearest_distance, 2.0);
    let (farthest_key, farthest_distance) = (results.keys[1], results.distances[1]);
    assert_eq!(farthest_key, 42);
    assert_eq!(farthest_distance, 6.0);
}
}