/// Returns the crate version recorded by Cargo at compile time.
pub fn version() -> &'static str {
    // Resolved by the compiler from the `CARGO_PKG_VERSION` environment variable.
    const PACKAGE_VERSION: &str = env!("CARGO_PKG_VERSION");
    PACKAGE_VERSION
}
/// Returns the string reported by the native `hardware_acceleration_compiled` function.
///
/// Bytes that are not valid UTF-8 are replaced lossily; the result is an owned copy.
pub fn hardware_acceleration_compiled() -> String {
    // SAFETY: assumes the native side returns a non-null, NUL-terminated string that
    // stays alive while it is copied — NOTE(review): confirm against the C++ side.
    let reported = unsafe { core::ffi::CStr::from_ptr(ffi::hardware_acceleration_compiled()) };
    reported.to_string_lossy().into_owned()
}
/// Returns the string reported by the native `hardware_acceleration_available` function.
///
/// Bytes that are not valid UTF-8 are replaced lossily; the result is an owned copy.
pub fn hardware_acceleration_available() -> String {
    // SAFETY: assumes the native side returns a non-null, NUL-terminated string that
    // stays alive while it is copied — NOTE(review): confirm against the C++ side.
    let reported = unsafe { core::ffi::CStr::from_ptr(ffi::hardware_acceleration_available()) };
    reported.to_string_lossy().into_owned()
}
/// Identifier under which a vector is stored in, and looked up from, an index.
pub type Key = u64;
/// Scalar type returned by metric callbacks and reported for search matches.
pub type Distance = f32;
/// C-ABI metric callback: two `const void` pointers and one mutable `void` pointer in,
/// a [`Distance`] out.
///
/// NOTE(review): presumably the first two arguments are the vectors being compared and
/// the third is caller-provided state — confirm against the native side.
pub type StatefulMetric = unsafe extern "C" fn(
*const std::ffi::c_void,
*const std::ffi::c_void,
*mut std::ffi::c_void,
) -> Distance;
/// C-ABI predicate callback: a [`Key`] and a mutable `void` pointer in, a `bool` out.
///
/// NOTE(review): presumably the pointer is caller-provided state — confirm.
pub type StatefulPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool;
/// Failure modes of the [`BitAddressable`] accessors.
#[derive(Debug)]
pub enum BitAddressableError {
    /// The requested bit position lies outside the addressed storage.
    IndexOutOfRange,
}

impl std::fmt::Display for BitAddressableError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a new message here.
        let description = match self {
            Self::IndexOutOfRange => "Index out of range",
        };
        f.write_str(description)
    }
}

impl std::error::Error for BitAddressableError {}
/// Storage whose individual bits can be read and written by position.
pub trait BitAddressable {
/// Sets the bit at `index` to `value`, or returns
/// [`BitAddressableError::IndexOutOfRange`] when `index` is not addressable.
fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError>;
/// Reads the bit at `index`, or returns
/// [`BitAddressableError::IndexOutOfRange`] when `index` is not addressable.
fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError>;
}
/// One byte viewed as eight individually addressable bits.
///
/// `#[repr(transparent)]` gives it the same layout as `u8`, which is what makes the
/// slice reinterpretations below valid.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct b1x8(pub u8);

impl b1x8 {
    /// Reinterprets a byte slice as a slice of `b1x8` without copying.
    pub fn from_u8s(slice: &[u8]) -> &[Self] {
        let (start, count) = (slice.as_ptr().cast::<Self>(), slice.len());
        // SAFETY: `Self` is a transparent wrapper over `u8`; pointer and length come
        // from a live slice whose borrow the result inherits.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Reinterprets a mutable byte slice as a mutable slice of `b1x8` without copying.
    pub fn from_mut_u8s(slice: &mut [u8]) -> &mut [Self] {
        let (start, count) = (slice.as_mut_ptr().cast::<Self>(), slice.len());
        // SAFETY: same layout argument as `from_u8s`; the exclusive borrow is carried over.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }

    /// Reinterprets a slice of `b1x8` as the underlying bytes without copying.
    pub fn to_u8s(slice: &[Self]) -> &[u8] {
        let (start, count) = (slice.as_ptr().cast::<u8>(), slice.len());
        // SAFETY: `Self` is a transparent wrapper over `u8`.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Reinterprets a mutable slice of `b1x8` as the underlying bytes without copying.
    pub fn to_mut_u8s(slice: &mut [Self]) -> &mut [u8] {
        let (start, count) = (slice.as_mut_ptr().cast::<u8>(), slice.len());
        // SAFETY: `Self` is a transparent wrapper over `u8`; the exclusive borrow is carried over.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }
}
/// A 16-bit value held as raw bits in an `i16`.
///
/// `#[repr(transparent)]` gives it the same layout as `i16`, which is what makes the
/// slice reinterpretations below valid.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy)]
pub struct f16(i16);

impl f16 {
    /// Reinterprets a slice of raw `i16` bit patterns as `f16` values without copying.
    pub fn from_i16s(slice: &[i16]) -> &[Self] {
        let (start, count) = (slice.as_ptr().cast::<Self>(), slice.len());
        // SAFETY: `Self` is a transparent wrapper over `i16`; pointer and length come
        // from a live slice whose borrow the result inherits.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Reinterprets a mutable slice of raw `i16` bit patterns as `f16` values without copying.
    pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] {
        let (start, count) = (slice.as_mut_ptr().cast::<Self>(), slice.len());
        // SAFETY: same layout argument as `from_i16s`; the exclusive borrow is carried over.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }

    /// Reinterprets a slice of `f16` values as their raw `i16` bit patterns without copying.
    pub fn to_i16s(slice: &[Self]) -> &[i16] {
        let (start, count) = (slice.as_ptr().cast::<i16>(), slice.len());
        // SAFETY: `Self` is a transparent wrapper over `i16`.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Reinterprets a mutable slice of `f16` values as raw `i16` bit patterns without copying.
    pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] {
        let (start, count) = (slice.as_mut_ptr().cast::<i16>(), slice.len());
        // SAFETY: `Self` is a transparent wrapper over `i16`; the exclusive borrow is carried over.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }
}
/// Bit access within a single byte; valid positions are `0..8`, with position 0 being
/// the least significant bit.
impl BitAddressable for b1x8 {
    /// Sets or clears bit `index`; fails with `IndexOutOfRange` when `index >= 8`.
    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
        if index >= 8 {
            return Err(BitAddressableError::IndexOutOfRange);
        }
        let mask = 1u8 << index;
        if value {
            self.0 |= mask;
        } else {
            self.0 &= !mask;
        }
        Ok(())
    }

    /// Reads bit `index`; fails with `IndexOutOfRange` when `index >= 8`.
    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
        if index >= 8 {
            return Err(BitAddressableError::IndexOutOfRange);
        }
        let mask = 1u8 << index;
        Ok((self.0 & mask) != 0)
    }
}
/// Bit access across a slice of bytes: bit `index` lives in byte `index / 8` at
/// position `index % 8`.
impl BitAddressable for [b1x8] {
    /// Sets or clears bit `index`; fails with `IndexOutOfRange` when the byte holding
    /// it lies beyond the end of the slice.
    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
        match self.get_mut(index / 8) {
            Some(byte) => byte.set_bit(index % 8, value),
            None => Err(BitAddressableError::IndexOutOfRange),
        }
    }

    /// Reads bit `index`; fails with `IndexOutOfRange` when the byte holding it lies
    /// beyond the end of the slice.
    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
        match self.get(index / 8) {
            Some(byte) => byte.get_bit(index % 8),
            None => Err(BitAddressableError::IndexOutOfRange),
        }
    }
}
/// Bitwise equality, except that a value whose exponent bits are all ones and whose
/// mantissa is non-zero (a NaN pattern) never equals anything, including itself.
impl PartialEq for f16 {
    fn eq(&self, other: &Self) -> bool {
        const EXPONENT_BITS: i16 = 0x7C00;
        const MANTISSA_BITS: i16 = 0x03FF;
        let is_nan =
            |bits: i16| (bits & EXPONENT_BITS) == EXPONENT_BITS && (bits & MANTISSA_BITS) != 0;
        // NOTE(review): comparison is on raw bits, so patterns differing only in the
        // sign bit (e.g. 0x0000 vs 0x8000) compare unequal.
        !is_nan(self.0) && !is_nan(other.0) && self.0 == other.0
    }
}
impl std::fmt::Debug for b1x8 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:08b}", self.0)
}
}
impl std::fmt::Debug for f16 {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let bits = self.0;
let sign = (bits >> 15) & 1;
let exponent = (bits >> 10) & 0x1F;
let mantissa = bits & 0x3FF;
write!(f, "{}|{:05b}|{:010b}", sign, exponent, mantissa)
}
}
// Bridge between Rust and the native index implementation declared in `lib.hpp`.
// Shared enums and structs are defined here; everything inside `unsafe extern "C++"`
// is implemented on the native side.
#[cxx::bridge]
pub mod ffi {
// Distance metric selector. Variant order fixes the `i32` discriminants shared with
// the native side, so variants must not be reordered.
#[derive(Debug)]
#[repr(i32)]
enum MetricKind {
Unknown,
IP,
L2sq,
Cos,
Pearson,
Haversine,
Divergence,
Hamming,
Tanimoto,
Sorensen,
}
// Scalar representation selector. Variant order fixes the `i32` discriminants shared
// with the native side, so variants must not be reordered.
#[derive(Debug)]
#[repr(i32)]
enum ScalarKind {
Unknown,
F64,
F32,
BF16,
F16,
E5M2,
E4M3,
E3M2,
E2M3,
I8,
U8,
B1,
}
// Result of a search call: keys and distances as two vectors.
// NOTE(review): presumably parallel arrays of equal length — confirm on the native side.
#[derive(Debug)]
struct Matches {
keys: Vec<u64>,
distances: Vec<f32>,
}
// Byte counters returned by `memory_stats`, split into graph and vector storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MemoryStats {
graph_allocated: usize,
graph_wasted: usize,
graph_reserved: usize,
vectors_allocated: usize,
vectors_wasted: usize,
vectors_reserved: usize,
}
// Construction parameters passed to `new_native_index`.
#[derive(Debug, PartialEq)]
struct IndexOptions {
dimensions: usize,
metric: MetricKind,
quantization: ScalarKind,
connectivity: usize,
expansion_add: usize,
expansion_search: usize,
multi: bool,
}
unsafe extern "C++" {
include!("lib.hpp");
// Process-wide capability strings, returned as C strings.
pub fn hardware_acceleration_compiled() -> *const c_char;
pub fn hardware_acceleration_available() -> *const c_char;
// Opaque native index; only reachable through `UniquePtr<NativeIndex>`.
type NativeIndex;
// Tuning parameters and metric replacement. All methods take `&NativeIndex`, including
// the mutating ones.
pub fn expansion_add(self: &NativeIndex) -> usize;
pub fn expansion_search(self: &NativeIndex) -> usize;
pub fn change_expansion_add(self: &NativeIndex, n: usize);
pub fn change_expansion_search(self: &NativeIndex, n: usize);
pub fn change_metric_kind(self: &NativeIndex, metric: MetricKind);
// `metric` is a function address and `metric_state` an opaque pointer, both passed as `usize`.
pub fn change_metric(self: &NativeIndex, metric: usize, metric_state: usize);
// Construction and capacity management. `Result` return types surface native
// failures as `cxx::Exception`.
pub fn new_native_index(options: &IndexOptions) -> Result<UniquePtr<NativeIndex>>;
pub fn reserve(self: &NativeIndex, capacity: usize) -> Result<()>;
pub fn reserve_capacity_and_threads(
self: &NativeIndex,
capacity: usize,
threads: usize,
) -> Result<()>;
// Introspection.
pub fn dimensions(self: &NativeIndex) -> usize;
pub fn connectivity(self: &NativeIndex) -> usize;
pub fn size(self: &NativeIndex) -> usize;
pub fn capacity(self: &NativeIndex) -> usize;
pub fn serialized_length(self: &NativeIndex) -> usize;
// Insertion, one entry point per scalar type. `f16` vectors travel as `i16` bit
// patterns and `b1x8` vectors as bytes.
pub fn add_f64(self: &NativeIndex, key: u64, vector: &[f64]) -> Result<()>;
pub fn add_f32(self: &NativeIndex, key: u64, vector: &[f32]) -> Result<()>;
pub fn add_f16(self: &NativeIndex, key: u64, vector: &[i16]) -> Result<()>;
pub fn add_i8(self: &NativeIndex, key: u64, vector: &[i8]) -> Result<()>;
pub fn add_u8(self: &NativeIndex, key: u64, vector: &[u8]) -> Result<()>;
pub fn add_b1x8(self: &NativeIndex, key: u64, vector: &[u8]) -> Result<()>;
// Search, one entry point per scalar type.
pub fn search_f64(self: &NativeIndex, query: &[f64], count: usize) -> Result<Matches>;
pub fn search_f32(self: &NativeIndex, query: &[f32], count: usize) -> Result<Matches>;
pub fn search_f16(self: &NativeIndex, query: &[i16], count: usize) -> Result<Matches>;
pub fn search_i8(self: &NativeIndex, query: &[i8], count: usize) -> Result<Matches>;
pub fn search_u8(self: &NativeIndex, query: &[u8], count: usize) -> Result<Matches>;
pub fn search_b1x8(self: &NativeIndex, query: &[u8], count: usize) -> Result<Matches>;
// Exact search, one entry point per scalar type.
pub fn exact_search_f64(self: &NativeIndex, query: &[f64], count: usize)
-> Result<Matches>;
pub fn exact_search_f32(self: &NativeIndex, query: &[f32], count: usize)
-> Result<Matches>;
pub fn exact_search_f16(self: &NativeIndex, query: &[i16], count: usize)
-> Result<Matches>;
pub fn exact_search_i8(self: &NativeIndex, query: &[i8], count: usize) -> Result<Matches>;
pub fn exact_search_u8(self: &NativeIndex, query: &[u8], count: usize) -> Result<Matches>;
pub fn exact_search_b1x8(self: &NativeIndex, query: &[u8], count: usize)
-> Result<Matches>;
// Filtered search. `filter` is the address of a callback and `filter_state` an
// opaque pointer handed back to it, both passed as `usize`.
pub fn filtered_search_f64(
self: &NativeIndex,
query: &[f64],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_f32(
self: &NativeIndex,
query: &[f32],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_f16(
self: &NativeIndex,
query: &[i16],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_i8(
self: &NativeIndex,
query: &[i8],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_u8(
self: &NativeIndex,
query: &[u8],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
pub fn filtered_search_b1x8(
self: &NativeIndex,
query: &[u8],
count: usize,
filter: usize,
filter_state: usize,
) -> Result<Matches>;
// Retrieval into a caller-provided buffer; the returned `usize` is a count.
// NOTE(review): presumably the number of vectors written — confirm on the native side.
pub fn get_f64(self: &NativeIndex, key: u64, buffer: &mut [f64]) -> Result<usize>;
pub fn get_f32(self: &NativeIndex, key: u64, buffer: &mut [f32]) -> Result<usize>;
pub fn get_f16(self: &NativeIndex, key: u64, buffer: &mut [i16]) -> Result<usize>;
pub fn get_i8(self: &NativeIndex, key: u64, buffer: &mut [i8]) -> Result<usize>;
pub fn get_u8(self: &NativeIndex, key: u64, buffer: &mut [u8]) -> Result<usize>;
pub fn get_b1x8(self: &NativeIndex, key: u64, buffer: &mut [u8]) -> Result<usize>;
// Key management.
pub fn remove(self: &NativeIndex, key: u64) -> Result<usize>;
pub fn rename(self: &NativeIndex, from: u64, to: u64) -> Result<usize>;
pub fn contains(self: &NativeIndex, key: u64) -> bool;
pub fn count(self: &NativeIndex, key: u64) -> usize;
// Persistence to and from a file path.
pub fn save(self: &NativeIndex, path: &str) -> Result<()>;
pub fn load(self: &NativeIndex, path: &str) -> Result<()>;
pub fn view(self: &NativeIndex, path: &str) -> Result<()>;
pub fn reset(self: &NativeIndex) -> Result<()>;
// Diagnostics.
pub fn memory_usage(self: &NativeIndex) -> usize;
pub fn memory_stats(self: &NativeIndex) -> MemoryStats;
pub fn hardware_acceleration(self: &NativeIndex) -> *const c_char;
// Persistence to and from an in-memory buffer.
pub fn save_to_buffer(self: &NativeIndex, buffer: &mut [u8]) -> Result<()>;
pub fn load_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
pub fn view_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
}
}
pub use ffi::{IndexOptions, MemoryStats, MetricKind, ScalarKind};
/// Raw pointer to a heap-allocated, boxed metric closure, tagged by the scalar type
/// of the vectors the closure compares.
///
/// The pointers carry no ownership semantics of their own; within this file they are
/// created with `Box::into_raw` in the `change_metric` implementations and released
/// with `Box::from_raw` when the owning [`Index`] is dropped.
pub enum MetricFunction {
B1X8Metric(*mut std::boxed::Box<dyn Fn(*const b1x8, *const b1x8) -> Distance + Send + Sync>),
I8Metric(*mut std::boxed::Box<dyn Fn(*const i8, *const i8) -> Distance + Send + Sync>),
U8Metric(*mut std::boxed::Box<dyn Fn(*const u8, *const u8) -> Distance + Send + Sync>),
F16Metric(*mut std::boxed::Box<dyn Fn(*const f16, *const f16) -> Distance + Send + Sync>),
F32Metric(*mut std::boxed::Box<dyn Fn(*const f32, *const f32) -> Distance + Send + Sync>),
F64Metric(*mut std::boxed::Box<dyn Fn(*const f64, *const f64) -> Distance + Send + Sync>),
}
/// Safe handle around a native index.
pub struct Index {
// Owning pointer to the native index object.
inner: cxx::UniquePtr<ffi::NativeIndex>,
// Custom metric closure currently installed via `change_metric`, if any; its
// allocation is released in `Drop`.
metric_fn: Option<MetricFunction>,
}
// NOTE(review): `Index` holds raw pointers and a native object, so these impls are a
// promise that the native index may be moved to and shared between threads. That cannot
// be verified from this file — confirm against the native implementation.
unsafe impl Send for Index {}
unsafe impl Sync for Index {}
/// Releases the boxed custom metric closure, if one is installed.
impl Drop for Index {
    fn drop(&mut self) {
        // SAFETY: within this file the stored pointer always originates from
        // `Box::into_raw` in a `change_metric` implementation, and this is the only
        // place it is turned back into a `Box`.
        unsafe {
            match self.metric_fn {
                Some(MetricFunction::B1X8Metric(pointer)) => drop(Box::from_raw(pointer)),
                Some(MetricFunction::I8Metric(pointer)) => drop(Box::from_raw(pointer)),
                Some(MetricFunction::U8Metric(pointer)) => drop(Box::from_raw(pointer)),
                Some(MetricFunction::F16Metric(pointer)) => drop(Box::from_raw(pointer)),
                Some(MetricFunction::F32Metric(pointer)) => drop(Box::from_raw(pointer)),
                Some(MetricFunction::F64Metric(pointer)) => drop(Box::from_raw(pointer)),
                None => {}
            }
        }
    }
}
/// Default options: 256 dimensions, cosine metric, `BF16` quantization, single-valued
/// keys, and `0` for connectivity and both expansion factors.
impl Default for ffi::IndexOptions {
    fn default() -> Self {
        ffi::IndexOptions {
            dimensions: 256,
            metric: MetricKind::Cos,
            quantization: ScalarKind::BF16,
            // NOTE(review): zero presumably asks the native side to pick its own
            // defaults for the next three fields — confirm.
            connectivity: 0,
            expansion_add: 0,
            expansion_search: 0,
            multi: false,
        }
    }
}
impl Clone for ffi::IndexOptions {
fn clone(&self) -> Self {
ffi::IndexOptions {
dimensions: (self.dimensions),
metric: (self.metric),
quantization: (self.quantization),
connectivity: (self.connectivity),
expansion_add: (self.expansion_add),
expansion_search: (self.expansion_search),
multi: (self.multi),
}
}
}
/// Scalar types that can be stored in and queried from an [`Index`].
///
/// Each implementation routes the generic operation to the native entry point for its
/// scalar type. Failures reported by the native side come back as `cxx::Exception`.
pub trait VectorType {
/// Adds `vector` to `index` under `key`.
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception>
where
Self: Sized;
/// Copies the data stored under `key` into `buffer` and returns the count reported
/// by the native side.
fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result<usize, cxx::Exception>
where
Self: Sized;
/// Searches `index` for up to `count` matches of `query`.
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized;
/// Runs the native exact search for up to `count` matches of `query`.
fn exact_search(
index: &Index,
query: &[Self],
count: usize,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized;
/// Searches `index` for up to `count` matches of `query`, passing `filter` to the
/// native side as a per-key predicate.
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool;
/// Installs `metric` as the distance function of `index`.
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception>
where
Self: Sized;
}
impl VectorType for f32 {
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
index.inner.search_f32(query, count)
}
fn exact_search(
index: &Index,
query: &[Self],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
index.inner.exact_search_f32(query, count)
}
fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
index.inner.get_f32(key, vector)
}
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
index.inner.add_f32(key, vector)
}
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool,
{
extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
let closure = closure_address as *const F;
unsafe { (*closure)(key) }
}
let trampoline_fn: usize = trampoline::<F> as *const () as usize;
let closure_address: usize = &filter as *const F as usize;
index
.inner
.filtered_search_f32(query, count, trampoline_fn, closure_address)
}
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
type MetricFn = Box<dyn Fn(*const f32, *const f32) -> Distance>;
index.metric_fn = Some(MetricFunction::F32Metric(Box::into_raw(Box::new(metric))));
extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
let first_ptr = first as *const f32;
let second_ptr = second as *const f32;
let closure: *mut MetricFn = closure_address as *mut MetricFn;
unsafe { (*closure)(first_ptr, second_ptr) }
}
let trampoline_fn: usize = trampoline as *const () as usize;
let closure_address = match index.metric_fn {
Some(MetricFunction::F32Metric(metric)) => metric as *mut () as usize,
_ => panic!("Expected F32Metric"),
};
index.inner.change_metric(trampoline_fn, closure_address);
Ok(())
}
}
impl VectorType for i8 {
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
index.inner.search_i8(query, count)
}
fn exact_search(
index: &Index,
query: &[Self],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
index.inner.exact_search_i8(query, count)
}
fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
index.inner.get_i8(key, vector)
}
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
index.inner.add_i8(key, vector)
}
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool,
{
extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
let closure = closure_address as *const F;
unsafe { (*closure)(key) }
}
let trampoline_fn: usize = trampoline::<F> as *const () as usize;
let closure_address: usize = &filter as *const F as usize;
index
.inner
.filtered_search_i8(query, count, trampoline_fn, closure_address)
}
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
type MetricFn = Box<dyn Fn(*const i8, *const i8) -> Distance>;
index.metric_fn = Some(MetricFunction::I8Metric(Box::into_raw(Box::new(metric))));
extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
let first_ptr = first as *const i8;
let second_ptr = second as *const i8;
let closure: *mut MetricFn = closure_address as *mut MetricFn;
unsafe { (*closure)(first_ptr, second_ptr) }
}
let trampoline_fn: usize = trampoline as *const () as usize;
let closure_address = match index.metric_fn {
Some(MetricFunction::I8Metric(metric)) => metric as *mut () as usize,
_ => panic!("Expected I8Metric"),
};
index.inner.change_metric(trampoline_fn, closure_address);
Ok(())
}
}
impl VectorType for u8 {
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
index.inner.search_u8(query, count)
}
fn exact_search(
index: &Index,
query: &[Self],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
index.inner.exact_search_u8(query, count)
}
fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
index.inner.get_u8(key, vector)
}
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
index.inner.add_u8(key, vector)
}
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool,
{
extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
let closure = closure_address as *const F;
unsafe { (*closure)(key) }
}
let trampoline_fn: usize = trampoline::<F> as *const () as usize;
let closure_address: usize = &filter as *const F as usize;
index
.inner
.filtered_search_u8(query, count, trampoline_fn, closure_address)
}
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
type MetricFn = Box<dyn Fn(*const u8, *const u8) -> Distance>;
index.metric_fn = Some(MetricFunction::U8Metric(Box::into_raw(Box::new(metric))));
extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
let first_ptr = first as *const u8;
let second_ptr = second as *const u8;
let closure: *mut MetricFn = closure_address as *mut MetricFn;
unsafe { (*closure)(first_ptr, second_ptr) }
}
let trampoline_fn: usize = trampoline as *const () as usize;
let closure_address = match index.metric_fn {
Some(MetricFunction::U8Metric(metric)) => metric as *mut () as usize,
_ => panic!("Expected U8Metric"),
};
index.inner.change_metric(trampoline_fn, closure_address);
Ok(())
}
}
impl VectorType for f64 {
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
index.inner.search_f64(query, count)
}
fn exact_search(
index: &Index,
query: &[Self],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
index.inner.exact_search_f64(query, count)
}
fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
index.inner.get_f64(key, vector)
}
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
index.inner.add_f64(key, vector)
}
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool,
{
extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
let closure = closure_address as *const F;
unsafe { (*closure)(key) }
}
let trampoline_fn: usize = trampoline::<F> as *const () as usize;
let closure_address: usize = &filter as *const F as usize;
index
.inner
.filtered_search_f64(query, count, trampoline_fn, closure_address)
}
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
type MetricFn = Box<dyn Fn(*const f64, *const f64) -> Distance>;
index.metric_fn = Some(MetricFunction::F64Metric(Box::into_raw(Box::new(metric))));
extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
let first_ptr = first as *const f64;
let second_ptr = second as *const f64;
let closure: *mut MetricFn = closure_address as *mut MetricFn;
unsafe { (*closure)(first_ptr, second_ptr) }
}
let trampoline_fn: usize = trampoline as *const () as usize;
let closure_address = match index.metric_fn {
Some(MetricFunction::F64Metric(metric)) => metric as *mut () as usize,
_ => panic!("Expected F64Metric"),
};
index.inner.change_metric(trampoline_fn, closure_address);
Ok(())
}
}
impl VectorType for f16 {
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
index.inner.search_f16(f16::to_i16s(query), count)
}
fn exact_search(
index: &Index,
query: &[Self],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
index.inner.exact_search_f16(f16::to_i16s(query), count)
}
fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
index.inner.get_f16(key, f16::to_mut_i16s(vector))
}
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
index.inner.add_f16(key, f16::to_i16s(vector))
}
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool,
{
extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
let closure = closure_address as *const F;
unsafe { (*closure)(key) }
}
let trampoline_fn: usize = trampoline::<F> as *const () as usize;
let closure_address: usize = &filter as *const F as usize;
index
.inner
.filtered_search_f16(f16::to_i16s(query), count, trampoline_fn, closure_address)
}
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
type MetricFn = Box<dyn Fn(*const f16, *const f16) -> Distance>;
index.metric_fn = Some(MetricFunction::F16Metric(Box::into_raw(Box::new(metric))));
extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
let first_ptr = first as *const f16;
let second_ptr = second as *const f16;
let closure: *mut MetricFn = closure_address as *mut MetricFn;
unsafe { (*closure)(first_ptr, second_ptr) }
}
let trampoline_fn: usize = trampoline as *const () as usize;
let closure_address = match index.metric_fn {
Some(MetricFunction::F16Metric(metric)) => metric as *mut () as usize,
_ => panic!("Expected F16Metric"),
};
index.inner.change_metric(trampoline_fn, closure_address);
Ok(())
}
}
impl VectorType for b1x8 {
fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
index.inner.search_b1x8(b1x8::to_u8s(query), count)
}
fn exact_search(
index: &Index,
query: &[Self],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
index.inner.exact_search_b1x8(b1x8::to_u8s(query), count)
}
fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
index.inner.get_b1x8(key, b1x8::to_mut_u8s(vector))
}
fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
index.inner.add_b1x8(key, b1x8::to_u8s(vector))
}
fn filtered_search<F>(
index: &Index,
query: &[Self],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
Self: Sized,
F: Fn(Key) -> bool,
{
extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
let closure = closure_address as *const F;
unsafe { (*closure)(key) }
}
let trampoline_fn: usize = trampoline::<F> as *const () as usize;
let closure_address: usize = &filter as *const F as usize;
index
.inner
.filtered_search_b1x8(b1x8::to_u8s(query), count, trampoline_fn, closure_address)
}
fn change_metric(
index: &mut Index,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
type MetricFn = Box<dyn Fn(*const b1x8, *const b1x8) -> Distance>;
index.metric_fn = Some(MetricFunction::B1X8Metric(Box::into_raw(Box::new(metric))));
extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
let first_ptr = first as *const b1x8;
let second_ptr = second as *const b1x8;
let closure: *mut MetricFn = closure_address as *mut MetricFn;
unsafe { (*closure)(first_ptr, second_ptr) }
}
let trampoline_fn: usize = trampoline as *const () as usize;
let closure_address = match index.metric_fn {
Some(MetricFunction::B1X8Metric(metric)) => metric as *mut () as usize,
_ => panic!("Expected F1X8Metric"),
};
index.inner.change_metric(trampoline_fn, closure_address);
Ok(())
}
}
impl Index {
pub fn new(options: &ffi::IndexOptions) -> Result<Self, cxx::Exception> {
match ffi::new_native_index(options) {
Ok(inner) => Result::Ok(Self {
inner,
metric_fn: None,
}),
Err(err) => Err(err),
}
}
pub fn expansion_add(self: &Index) -> usize {
self.inner.expansion_add()
}
pub fn expansion_search(self: &Index) -> usize {
self.inner.expansion_search()
}
pub fn change_expansion_add(self: &Index, n: usize) {
self.inner.change_expansion_add(n)
}
pub fn change_expansion_search(self: &Index, n: usize) {
self.inner.change_expansion_search(n)
}
pub fn change_metric_kind(self: &Index, metric: ffi::MetricKind) {
self.inner.change_metric_kind(metric)
}
pub fn change_metric<T: VectorType>(
self: &mut Index,
metric: std::boxed::Box<dyn Fn(*const T, *const T) -> Distance + Send + Sync>,
) {
T::change_metric(self, metric).unwrap();
}
pub fn hardware_acceleration(&self) -> String {
use core::ffi::CStr;
unsafe {
let c_str = CStr::from_ptr(self.inner.hardware_acceleration());
c_str.to_string_lossy().into_owned()
}
}
pub fn search<T: VectorType>(
self: &Index,
query: &[T],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
T::search(self, query, count)
}
pub fn exact_search<T: VectorType>(
self: &Index,
query: &[T],
count: usize,
) -> Result<ffi::Matches, cxx::Exception> {
T::exact_search(self, query, count)
}
pub fn filtered_search<T: VectorType, F>(
self: &Index,
query: &[T],
count: usize,
filter: F,
) -> Result<ffi::Matches, cxx::Exception>
where
F: Fn(Key) -> bool,
{
T::filtered_search(self, query, count, filter)
}
pub fn add<T: VectorType>(self: &Index, key: Key, vector: &[T]) -> Result<(), cxx::Exception> {
T::add(self, key, vector)
}
pub fn get<T: VectorType>(
self: &Index,
key: Key,
vector: &mut [T],
) -> Result<usize, cxx::Exception> {
T::get(self, key, vector)
}
pub fn export<T: VectorType + Default + Clone>(
self: &Index,
key: Key,
vector: &mut Vec<T>,
) -> Result<usize, cxx::Exception> {
let dim = self.dimensions();
let max_matches = self.count(key);
vector.resize(dim * max_matches, T::default());
let matches = T::get(self, key, &mut vector[..])?;
vector.resize(dim * matches, T::default());
Ok(matches)
}
pub fn reserve(self: &Index, capacity: usize) -> Result<(), cxx::Exception> {
self.inner.reserve(capacity)
}
pub fn reserve_capacity_and_threads(
self: &Index,
capacity: usize,
threads: usize,
) -> Result<(), cxx::Exception> {
self.inner.reserve_capacity_and_threads(capacity, threads)
}
pub fn dimensions(self: &Index) -> usize {
self.inner.dimensions()
}
pub fn connectivity(self: &Index) -> usize {
self.inner.connectivity()
}
pub fn size(self: &Index) -> usize {
self.inner.size()
}
pub fn capacity(self: &Index) -> usize {
self.inner.capacity()
}
pub fn serialized_length(self: &Index) -> usize {
self.inner.serialized_length()
}
pub fn remove(self: &Index, key: Key) -> Result<usize, cxx::Exception> {
self.inner.remove(key)
}
pub fn rename(self: &Index, from: Key, to: Key) -> Result<usize, cxx::Exception> {
self.inner.rename(from, to)
}
pub fn contains(self: &Index, key: Key) -> bool {
self.inner.contains(key)
}
pub fn count(self: &Index, key: Key) -> usize {
self.inner.count(key)
}
pub fn save(self: &Index, path: &str) -> Result<(), cxx::Exception> {
self.inner.save(path)
}
pub fn load(self: &Index, path: &str) -> Result<(), cxx::Exception> {
self.inner.load(path)
}
pub fn view(self: &Index, path: &str) -> Result<(), cxx::Exception> {
self.inner.view(path)
}
pub fn reset(self: &Index) -> Result<(), cxx::Exception> {
self.inner.reset()
}
pub fn memory_usage(self: &Index) -> usize {
self.inner.memory_usage()
}
pub fn memory_stats(self: &Index) -> ffi::MemoryStats {
self.inner.memory_stats()
}
pub fn save_to_buffer(self: &Index, buffer: &mut [u8]) -> Result<(), cxx::Exception> {
self.inner.save_to_buffer(buffer)
}
pub fn load_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
self.inner.load_from_buffer(buffer)
}
pub unsafe fn view_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
self.inner.view_from_buffer(buffer)
}
}
/// Free-function equivalent of [`Index::new`].
pub fn new_index(options: &ffi::IndexOptions) -> Result<Index, cxx::Exception> {
    let index = Index::new(options)?;
    Ok(index)
}
#[cfg(test)]
mod tests {
use crate::ffi::IndexOptions;
use crate::ffi::MetricKind;
use crate::ffi::ScalarKind;
use crate::b1x8;
use crate::new_index;
use crate::Index;
use crate::Key;
use std::env;
/// Prints the OS, the `RUST_VERSION` environment variable and the hardware
/// acceleration reported for one 256-dimensional index per quantization.
#[test]
fn print_specs() {
    println!("--------------------------------------------------");
    println!("OS: {}", env::consts::OS);
    println!(
        "Rust version: {}",
        env::var("RUST_VERSION").unwrap_or_else(|_| "unknown".into())
    );

    let build = |metric: MetricKind, quantization: ScalarKind| -> Index {
        Index::new(&IndexOptions {
            dimensions: 256,
            metric,
            quantization,
            ..Default::default()
        })
        .unwrap()
    };

    // Every index is built before anything is reported, so a construction failure
    // aborts the test before any acceleration line is printed.
    let indexes = [
        ("f64", build(MetricKind::Cos, ScalarKind::F64)),
        ("f32", build(MetricKind::Cos, ScalarKind::F32)),
        ("bf16", build(MetricKind::Cos, ScalarKind::BF16)),
        ("f16", build(MetricKind::Cos, ScalarKind::F16)),
        ("e5m2", build(MetricKind::Cos, ScalarKind::E5M2)),
        ("e4m3", build(MetricKind::Cos, ScalarKind::E4M3)),
        ("i8", build(MetricKind::Cos, ScalarKind::I8)),
        ("u8", build(MetricKind::Cos, ScalarKind::U8)),
        ("b1", build(MetricKind::Hamming, ScalarKind::B1)),
    ];

    for (label, index) in &indexes {
        println!(
            "{} hardware acceleration: {}",
            label,
            index.hardware_acceleration()
        );
    }
    println!("--------------------------------------------------");
}
/// A freshly constructed index must report zero capacity.
#[test]
fn new_index_does_not_preallocate_members() {
    let index = Index::new(&IndexOptions {
        dimensions: 8,
        quantization: ScalarKind::F32,
        ..Default::default()
    })
    .unwrap();
    assert_eq!(index.capacity(), 0);
}
/// An index must stay usable after being moved into a `Box` or shared through an `Arc`.
#[test]
fn index_survives_box_and_arc_moves_after_construction() {
    let options = IndexOptions {
        dimensions: 4,
        quantization: ScalarKind::F32,
        ..Default::default()
    };
    let vector = [0.25f32, 0.5, 0.75, 1.0];

    // Moved into a `Box` right after construction.
    let boxed = Box::new(Index::new(&options).unwrap());
    boxed.reserve(8).unwrap();
    boxed.add(7, &vector).unwrap();
    let from_box = boxed.search(&vector, 1).unwrap();
    assert_eq!(from_box.keys.first().copied(), Some(7));

    // Shared through an `Arc`: written through one handle, read through the other.
    let shared = std::sync::Arc::new(Index::new(&options).unwrap());
    let writer = std::sync::Arc::clone(&shared);
    writer.reserve_capacity_and_threads(8, 2).unwrap();
    writer.add(9, &vector).unwrap();
    let from_arc = shared.search(&vector, 1).unwrap();
    assert_eq!(from_arc.keys.first().copied(), Some(9));
}
/// Adding vectors of the wrong length must fail, and stored vectors must come back
/// unchanged through both `export` and `get`.
#[test]
fn add_get_vector() {
    let index = Index::new(&IndexOptions {
        dimensions: 5,
        quantization: ScalarKind::F32,
        ..Default::default()
    })
    .unwrap();
    assert!(index.reserve(10).is_ok());

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
    let too_long: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1];
    let too_short: [f32; 4] = [0.3, 0.2, 0.4, 0.0];

    // Only vectors matching the configured dimensionality are accepted.
    assert!(index.add(1, &first).is_ok());
    assert!(index.add(2, &second).is_ok());
    assert!(index.add(3, &too_long).is_err());
    assert!(index.add(4, &too_short).is_err());
    assert_eq!(index.size(), 2);

    // Export into a growable vector.
    let mut exported: Vec<f32> = Vec::new();
    assert_eq!(index.export(1, &mut exported).unwrap(), 1);
    assert_eq!(exported.len(), 5);
    assert_eq!(exported, first.to_vec());

    // Read into an exactly sized buffer.
    let mut exact = [0.0f32; 5];
    assert_eq!(index.get(1, &mut exact).unwrap(), 1);
    assert_eq!(exact, first);

    // A buffer of six elements is rejected.
    let mut oversized = [0.0f32; 6];
    assert!(index.get(1, &mut oversized).is_err());
}
/// For every metric/quantization pair listed below, inserts two 64-dimensional vectors
/// and checks that searching with the first vector ranks its own key first.
#[test]
fn quantized_add_search() {
    let metrics = [MetricKind::Cos, MetricKind::L2sq, MetricKind::IP];
    let quantizations = [
        ScalarKind::F32,
        ScalarKind::F64,
        ScalarKind::F16,
        ScalarKind::BF16,
        ScalarKind::I8,
        ScalarKind::E5M2,
        ScalarKind::E4M3,
        ScalarKind::E3M2,
        ScalarKind::E2M3,
    ];
    let dimensions: usize = 64;

    // `first` ramps up from 0.0; `second` ramps down from `dimensions * 0.1`.
    let first: Vec<f32> = (0..dimensions).map(|step| step as f32 * 0.1).collect();
    let second: Vec<f32> = (1..=dimensions)
        .rev()
        .map(|step| step as f32 * 0.1)
        .collect();

    // Metric-major order, identical to nesting the quantization loop inside the metric loop.
    let combinations = metrics
        .iter()
        .flat_map(|&metric| quantizations.iter().map(move |&quantization| (metric, quantization)));

    for (metric, quantization) in combinations {
        let options = IndexOptions {
            dimensions,
            metric,
            quantization,
            ..Default::default()
        };
        let index = Index::new(&options).unwrap();
        assert!(index.reserve(10).is_ok());
        assert!(index.add(1, &first).is_ok());
        assert!(index.add(2, &second).is_ok());
        assert_eq!(index.size(), 2, "{metric:?}/{quantization:?}: wrong size");

        let hits = index.search(&first, 2).unwrap();
        assert_eq!(
            hits.keys[0], 1,
            "self-match failed for {metric:?}/{quantization:?}"
        );
    }
}
/// Searching an empty index yields no keys, and queries whose length differs from the
/// index dimensionality are rejected with an error.
#[test]
fn search_vector() {
    let index = Index::new(&IndexOptions {
        dimensions: 5,
        quantization: ScalarKind::F32,
        ..Default::default()
    })
    .unwrap();
    assert!(index.reserve(10).is_ok());

    let first = [0.2f32, 0.1, 0.2, 0.1, 0.3];
    let second = [0.3f32, 0.2, 0.4, 0.0, 0.1];
    let too_long = [0.3f32, 0.2, 0.4, 0.0, 0.1, 0.1];
    let too_short = [0.3f32, 0.2, 0.4, 0.0];

    // Nothing has been inserted yet, so the search succeeds but finds nothing.
    let hits_before_insert = index.search(&first, 10).unwrap();
    assert!(hits_before_insert.keys.is_empty());

    assert!(index.add(1, &first).is_ok());
    assert!(index.add(2, &second).is_ok());
    assert_eq!(index.size(), 2);

    // Queries with 6 or 4 elements do not match the 5 configured dimensions.
    let long_query = index.search(&too_long, 1);
    assert!(long_query.is_err());
    let short_query = index.search(&too_short, 1);
    assert!(short_query.is_err());
}
/// Walks one key through add / contains / count / rename / get / remove, then adds and
/// removes two further keys and checks that the index ends up empty.
#[test]
fn add_remove_vector() {
    let index = Index::new(&IndexOptions {
        dimensions: 4,
        metric: MetricKind::IP,
        quantization: ScalarKind::F64,
        connectivity: 10,
        expansion_add: 128,
        expansion_search: 3,
        ..Default::default()
    })
    .unwrap();
    assert!(index.reserve(10).is_ok());
    assert!(index.capacity() >= 10);

    let first = [0.2f32, 0.1, 0.2, 0.1];
    let second = [0.3f32, 0.2, 0.4, 0.0];
    let original_key: Key = 483367403120493160;
    let renamed_key: Key = 483367403120558696;
    let later_keys: [Key; 2] = [483367403120624232, 483367403120624233];

    // The key is unknown until it is added.
    assert!(!index.contains(original_key));
    assert_eq!(index.count(original_key), 0);
    assert!(index.add(original_key, &first).is_ok());
    assert!(index.contains(original_key));
    assert_eq!(index.count(original_key), 1);

    // Renaming moves the entry to the new key and keeps the stored vector.
    assert_eq!(index.rename(original_key, renamed_key).unwrap(), 1);
    assert!(!index.contains(original_key));
    assert!(index.contains(renamed_key));
    let mut fetched = [0.0f32; 4];
    assert_eq!(index.get(renamed_key, &mut fetched).unwrap(), 1);
    assert_eq!(fetched, first);

    // Removal makes the key unknown again.
    assert!(index.remove(renamed_key).is_ok());
    assert!(!index.contains(renamed_key));
    assert_eq!(index.count(renamed_key), 0);

    // Each later key is added, read back and removed before the next one is touched.
    for key in later_keys {
        assert!(index.add(key, &second).is_ok());
        let mut fetched = [0.0f32; 4];
        assert_eq!(index.get(key, &mut fetched).unwrap(), 1);
        assert!(index.remove(key).is_ok());
    }
    assert_eq!(index.size(), 0);
}
/// End-to-end walk through the index API on a small 5-dimensional index:
/// configuration getters and setters, add/search, file and in-memory buffer
/// serialization, `reset`, and `IndexOptions` cloning/equality.
///
/// The statement order matters: `options` is mutated several times and each
/// `new_index(&options)` call sees the state accumulated up to that point.
#[test]
fn integration() {
let mut options = IndexOptions {
dimensions: 5,
..Default::default()
};
let index = Index::new(&options).unwrap();
// With default options the index must still report non-zero expansion factors
// and connectivity, the requested dimensionality, and no stored vectors.
assert!(index.expansion_add() > 0);
assert!(index.expansion_search() > 0);
assert!(index.reserve(10).is_ok());
assert!(index.capacity() >= 10);
assert!(index.connectivity() != 0);
assert_eq!(index.dimensions(), 5);
assert_eq!(index.size(), 0);
let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
println!("--------------------------------------------------");
println!(
"before add, memory_usage: {} \
cap: {} \
",
index.memory_usage(),
index.capacity(),
);
// `expansion_add` can be changed between insertions and is read back immediately.
index.change_expansion_add(10);
assert_eq!(index.expansion_add(), 10);
assert!(index.add(42, &first).is_ok());
index.change_expansion_add(12);
assert_eq!(index.expansion_add(), 12);
assert!(index.add(43, &second).is_ok());
assert_eq!(index.size(), 2);
println!(
"after add, memory_usage: {} \
cap: {} \
",
index.memory_usage(),
index.capacity(),
);
// Asking for 10 neighbours in a 2-entry index returns both entries, for either
// `expansion_search` setting.
index.change_expansion_search(10);
assert_eq!(index.expansion_search(), 10);
let results = index.search(&first, 10).unwrap();
println!("{:?}", results);
assert_eq!(results.keys.len(), 2);
index.change_expansion_search(12);
assert_eq!(index.expansion_search(), 12);
let results = index.search(&first, 10).unwrap();
println!("{:?}", results);
assert_eq!(results.keys.len(), 2);
println!("--------------------------------------------------");
let stats = index.memory_stats();
assert!(
stats.vectors_allocated > 0,
"vectors should have allocated memory"
);
// File round-trip: save, load back into the same index, then view the same file.
// NOTE(review): "index.rust.usearch" is a relative path and nothing in this test
// removes the file afterwards, so it is left behind in the working directory.
assert!(index.save("index.rust.usearch").is_ok());
assert!(index.load("index.rust.usearch").is_ok());
let results = index.search(&first, 10).unwrap();
assert!(results.keys.contains(&42), "key 42 survives save/load");
assert!(index.view("index.rust.usearch").is_ok());
// Construction through `new_index` must succeed for each metric tried below.
assert!(new_index(&options).is_ok());
options.metric = MetricKind::L2sq;
assert!(new_index(&options).is_ok());
options.metric = MetricKind::Cos;
assert!(new_index(&options).is_ok());
// The Haversine case is built with F32 scalars and 2 dimensions
// (presumably latitude/longitude pairs; confirm against the metric's definition).
options.metric = MetricKind::Haversine;
options.quantization = ScalarKind::F32;
options.dimensions = 2;
assert!(new_index(&options).is_ok());
// In-memory round-trip: the buffer is sized by `serialized_length()` and zero-filled.
let mut serialization_buffer = vec![0; index.serialized_length()];
assert!(index.save_to_buffer(&mut serialization_buffer).is_ok());
// NOTE(review): `options` still describes a 2-dimensional Haversine index here, while
// the buffer was produced by the 5-dimensional `index`. The 5-element query below
// succeeding suggests that loading replaces the target's configuration; confirm.
let deserialized_index = new_index(&options).unwrap();
assert!(deserialized_index
.load_from_buffer(&serialization_buffer)
.is_ok());
assert_eq!(index.size(), deserialized_index.size());
let results = deserialized_index.search(&first, 10).unwrap();
assert!(
results.keys.contains(&42),
"key 42 survives buffer round-trip"
);
let viewed_index = Index::new(&IndexOptions {
dimensions: 5,
..Default::default()
})
.unwrap();
// `view_from_buffer` is an `unsafe` call. `serialization_buffer` stays alive until the
// end of this function, i.e. for every later use of `viewed_index`.
// NOTE(review): presumably the view borrows the buffer without copying; confirm.
assert!(unsafe { viewed_index.view_from_buffer(&serialization_buffer) }.is_ok());
assert_eq!(viewed_index.size(), index.size());
let results = viewed_index.search(&first, 10).unwrap();
assert!(
results.keys.contains(&42),
"key 42 visible via view_from_buffer"
);
// `reset` must drop all entries and bring the reported memory usage down to zero,
// after which the index can be reserved and filled again.
assert_ne!(index.memory_usage(), 0);
assert!(index.reset().is_ok());
assert_eq!(index.size(), 0);
assert_eq!(index.memory_usage(), 0);
assert!(index.reserve(10).is_ok());
assert!(index.add(100, &first).is_ok());
assert!(index.add(101, &second).is_ok());
assert_eq!(index.size(), 2);
let results = index.search(&first, 10).unwrap();
assert_eq!(results.keys.len(), 2);
// `options.metric` is already Haversine at this point, so this assignment is a no-op.
options.metric = MetricKind::Haversine;
// A clone compares equal field-by-field and as a whole until one field is changed.
let mut opts = options.clone();
assert_eq!(opts.metric, options.metric);
assert_eq!(opts.quantization, options.quantization);
assert_eq!(opts, options);
opts.metric = MetricKind::Cos;
assert_ne!(opts.metric, options.metric);
assert!(new_index(&opts).is_ok());
}
/// Checks `filtered_search` with a capture-free predicate that accepts only odd keys.
///
/// Keys 1 and 2 are stored; the query equals the vector stored under key 1, so the
/// filtered result must contain key 1 and must not contain any even key.
#[test]
fn search_with_stateless_filter() {
    let options = IndexOptions {
        dimensions: 5,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();
    index.reserve(10).unwrap();

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
    index.add(1, &first).unwrap();
    index.add(2, &second).unwrap();

    let is_odd = |key: Key| key % 2 == 1;
    let query = vec![0.2, 0.1, 0.2, 0.1, 0.3];
    let results = index.filtered_search(&query, 10, is_odd).unwrap();

    // `Iterator::all` is vacuously true for an empty result, so first require that the
    // odd key stored above actually came back; otherwise a search that returns nothing
    // would pass this test.
    assert!(
        results.keys.contains(&1),
        "Odd key 1 must be returned by the filtered search"
    );
    assert!(
        results.keys.iter().all(|&key| key % 2 == 1),
        "All keys must be odd"
    );
}
/// Checks `filtered_search` with a predicate that owns captured state: a `HashSet` of
/// allowed keys moved into the closure.
///
/// Both stored keys (1 and 2) are members of the allowed set, so the search must
/// return at least one key, and every returned key must be in the set.
#[test]
fn search_with_stateful_filter() {
    use std::collections::HashSet;

    let options = IndexOptions {
        dimensions: 5,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();
    index.reserve(10).unwrap();

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    index.add(1, &first).unwrap();
    index.add(2, &first).unwrap();

    let allowed_keys: HashSet<Key> = HashSet::from([1, 2, 3]);
    // The closure takes ownership of its own copy; the original set is kept for the
    // assertions below.
    let filter_keys = allowed_keys.clone();
    let stateful_filter = move |key: Key| filter_keys.contains(&key);

    let query = vec![0.2, 0.1, 0.2, 0.1, 0.3];
    let results = index
        .filtered_search(&query, 10, stateful_filter)
        .unwrap();

    // `Iterator::all` is vacuously true for an empty result, so an index that returned
    // nothing would otherwise pass; both stored keys are allowed, so require a match.
    assert!(
        !results.keys.is_empty(),
        "Filtered search must return at least one allowed key"
    );
    assert!(
        results.keys.iter().all(|&key| allowed_keys.contains(&key)),
        "All keys must be in the allowed set"
    );
}
/// With F16 quantization and the squared-L2 metric, a query that differs from every
/// stored vector must never be reported at a distance of exactly zero.
#[test]
fn zero_distances() {
    let index = new_index(&IndexOptions {
        dimensions: 8,
        metric: MetricKind::L2sq,
        quantization: ScalarKind::F16,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    // The stored vectors differ only in their first component.
    let stored = [
        (0, [0.4, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]),
        (1, [0.5, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]),
        (2, [0.6, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]),
    ];
    for (key, vector) in stored {
        index.add(key, &vector).unwrap();
    }

    let query = [0.05, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0];
    let matches = index.search(&query, 2).unwrap();
    for &distance in matches.distances.iter() {
        assert_ne!(distance, 0.0);
    }
}
/// Compares `search` (approximate) with `exact_search` on 100 four-dimensional points:
/// both must return 10 matches, and the best exact distance must not exceed the best
/// approximate one. The overlap between the two result sets is only printed.
#[test]
fn exact_search() {
    use std::collections::HashSet;

    let index = new_index(&IndexOptions {
        dimensions: 4,
        metric: MetricKind::L2sq,
        quantization: ScalarKind::F32,
        ..Default::default()
    })
    .unwrap();
    index.reserve(100).unwrap();

    // Point `key` is (0.1·key, sin(0.05·key), cos(0.05·key), 0).
    for key in 0..100 {
        let t = key as f32;
        let point = [t * 0.1, (t * 0.05).sin(), (t * 0.05).cos(), 0.0];
        index.add(key, &point).unwrap();
    }

    let query = [4.5, 0.0, 1.0, 0.0];
    let approx_matches = index.search(&query, 10).unwrap();
    let exact_matches = index.exact_search(&query, 10).unwrap();

    assert_eq!(approx_matches.keys.len(), 10);
    assert_eq!(exact_matches.keys.len(), 10);
    let best_exact = exact_matches.distances[0];
    let best_approx = approx_matches.distances[0];
    assert!(best_exact <= best_approx);

    println!(
        "Approximate search first match: key={}, distance={}",
        approx_matches.keys[0], approx_matches.distances[0]
    );
    println!(
        "Exact search first match: key={}, distance={}",
        exact_matches.keys[0], exact_matches.distances[0]
    );

    // Informational only: how many keys both searches agree on.
    let approx_keys: HashSet<Key> = approx_matches.keys.iter().copied().collect();
    let exact_keys: HashSet<Key> = exact_matches.keys.iter().copied().collect();
    let shared = approx_keys.intersection(&exact_keys).count();
    println!(
        "Number of common results between approximate and exact search: {}",
        shared
    );
}
/// Replaces the metric of a 2-dimensional index with a boxed closure that captures two
/// weights, and checks that a vector can still be added afterwards.
#[test]
fn change_distance_function() {
    let mut index = Index::new(&IndexOptions {
        dimensions: 2,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    let unit_x: [f32; 2] = [1.0, 0.0];
    index.add(1, &unit_x).unwrap();

    // Weighted L1 distance: |Δx| · first_factor + |Δy| · second_factor.
    let first_factor: f32 = 2.0;
    let second_factor: f32 = 0.7;
    let weighted_l1 = Box::new(move |lhs: *const f32, rhs: *const f32| {
        // SAFETY: assumes the index passes pointers to at least two readable `f32`
        // values each, matching `dimensions: 2` above; confirm against `change_metric`.
        let (lhs, rhs) = unsafe {
            (
                std::slice::from_raw_parts(lhs, 2),
                std::slice::from_raw_parts(rhs, 2),
            )
        };
        let x_term = (lhs[0] - rhs[0]).abs() * first_factor;
        let y_term = (lhs[1] - rhs[1]).abs() * second_factor;
        x_term + y_term
    });
    index.change_metric(weighted_l1);

    // Insertion after the swap goes through the custom metric.
    let unit_y: [f32; 2] = [0.0, 1.0];
    index.add(2, &unit_y).unwrap();
}
/// Stores two 8-bit binary vectors under the Hamming metric and checks that a query is
/// ranked by the number of differing bits.
#[test]
fn binary_vectors_and_hamming_distance() {
    let options = IndexOptions {
        dimensions: 8,
        metric: MetricKind::Hamming,
        quantization: ScalarKind::B1,
        ..Default::default()
    };
    let index = Index::new(&options).unwrap();

    let low_nibble = [b1x8(0b0000_1111)];
    let high_nibble = [b1x8(0b1111_0000)];
    let query = [b1x8(0b0111_1000)];

    index.reserve(10).unwrap();
    index.add(42, &low_nibble).unwrap();
    index.add(43, &high_nibble).unwrap();

    let results = index.search(&query, 5).unwrap();
    assert_eq!(results.keys.len(), 2);

    // query ^ high_nibble = 0b1000_1000 (2 bits set): the closest match.
    let (nearest_key, nearest_distance) = (results.keys[0], results.distances[0]);
    assert_eq!(nearest_key, 43);
    assert_eq!(nearest_distance, 2.0);

    // query ^ low_nibble = 0b0111_0111 (6 bits set): the farther match.
    let (farther_key, farther_distance) = (results.keys[1], results.distances[1]);
    assert_eq!(farther_key, 42);
    assert_eq!(farther_distance, 6.0);
}
/// With `multi: true`, one key may hold several vectors: size, count, `get` and
/// `export` must all report both vectors stored under the same key.
#[test]
fn multi_index() {
    let index = Index::new(&IndexOptions {
        dimensions: 4,
        metric: MetricKind::L2sq,
        quantization: ScalarKind::F32,
        multi: true,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    let key: Key = 42;
    let vec_a = [1.0f32, 0.0, 0.0, 0.0];
    let vec_b = [0.0f32, 1.0, 0.0, 0.0];
    for vector in [&vec_a, &vec_b] {
        index.add(key, vector).unwrap();
    }

    assert_eq!(index.size(), 2);
    assert_eq!(index.count(key), 2);
    assert!(index.contains(key));

    // The fixed buffer has room for two 4-dimensional vectors.
    let mut two_vectors = [0.0f32; 8];
    let fetched = index.get(key, &mut two_vectors).unwrap();
    assert_eq!(fetched, 2);

    // `export` grows the vector to hold everything stored under the key.
    let mut exported: Vec<f32> = Vec::new();
    let exported_count = index.export(key, &mut exported).unwrap();
    assert_eq!(exported_count, 2);
    assert_eq!(exported.len(), 8);

    let hits = index.search(&vec_a, 5).unwrap();
    assert!(hits.keys.contains(&key));
}
/// Exercises one shared index from a `fork_union` pool in three phases:
/// 1. concurrent `add` of `VECTOR_COUNT` random vectors,
/// 2. concurrent `get` of every key, comparing against the reference data,
/// 3. 100 concurrent `exact_search` calls that must each find their own query vector.
#[test]
fn concurrency() {
use fork_union as fu;
use rand::{RngExt, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rand_distr::Uniform;
use std::sync::Arc;
const DIMENSIONS: usize = 128;
const VECTOR_COUNT: usize = 1000;
const THREAD_COUNT: usize = 4;
let options = IndexOptions {
dimensions: DIMENSIONS,
metric: MetricKind::Cos,
quantization: ScalarKind::F32,
..Default::default()
};
let index = Arc::new(Index::new(&options).unwrap());
// Capacity and thread count are reserved up front, before any parallel work starts.
index
.reserve_capacity_and_threads(VECTOR_COUNT, THREAD_COUNT)
.unwrap();
// The generator is seeded with a fixed value, so every run draws the same vectors.
let seed = 42; let mut rng = ChaCha8Rng::seed_from_u64(seed);
let uniform = Uniform::new(-1.0f32, 1.0f32).unwrap();
let mut reference_vectors: Vec<[f32; DIMENSIONS]> = Vec::with_capacity(VECTOR_COUNT);
for _ in 0..VECTOR_COUNT {
let mut vector = [0.0f32; DIMENSIONS];
// `take(DIMENSIONS)` does not limit anything here: the array has exactly
// `DIMENSIONS` elements.
for item in vector.iter_mut().take(DIMENSIONS) {
*item = rng.sample(uniform);
}
reference_vectors.push(vector);
}
// Phase 1: insert vector `i` under key `i`, one task per vector.
// NOTE(review): presumably `fu::spawn(n)` creates a pool of `n` threads and `for_n`
// invokes the closure once per task index; confirm against the fork_union docs.
let mut pool = fu::spawn(THREAD_COUNT);
pool.for_n(VECTOR_COUNT, |prong| {
let index_clone = Arc::clone(&index);
let i = prong.task_index;
// Copies the 128-element array out of the shared reference data.
let vector = reference_vectors[i];
index_clone.add(i as u64, &vector).unwrap();
});
assert_eq!(index.size(), VECTOR_COUNT);
// Phase 2: read every key back and record whether it matches the reference vector
// component-wise within 1e-6. Outcomes are collected behind a mutex.
let mut pool = fu::spawn(THREAD_COUNT);
let validation_results = Arc::new(std::sync::Mutex::new(Vec::new()));
pool.for_n(VECTOR_COUNT, |prong| {
let index_clone = Arc::clone(&index);
let results_clone = Arc::clone(&validation_results);
let i = prong.task_index;
let expected_vector = &reference_vectors[i];
let mut retrieved_vector = [0.0f32; DIMENSIONS];
let count = index_clone.get(i as u64, &mut retrieved_vector).unwrap();
assert_eq!(count, 1);
let matches = retrieved_vector
.iter()
.zip(expected_vector.iter())
.all(|(a, b)| (a - b).abs() < 1e-6);
let mut results = results_clone.lock().unwrap();
results.push(matches);
});
let validation_results = validation_results.lock().unwrap();
assert_eq!(validation_results.len(), VECTOR_COUNT);
assert!(
validation_results.iter().all(|&x| x),
"All retrieved vectors should match the original ones"
);
// Phase 3: 100 exact searches, each using a stored vector as the query. The top hit
// must be that vector's own key at a distance below 1e-6.
let mut pool = fu::spawn(THREAD_COUNT);
let search_results = Arc::new(std::sync::Mutex::new(Vec::new()));
pool.for_n(100, |prong| {
let index_clone = Arc::clone(&index);
let results_clone = Arc::clone(&search_results);
// With 100 tasks and `VECTOR_COUNT` = 1000 the modulo never wraps.
let query_idx = prong.task_index % VECTOR_COUNT;
let query_vector = &reference_vectors[query_idx];
let matches = index_clone.exact_search(query_vector, 10).unwrap();
let exact_match_found = !matches.keys.is_empty()
&& matches.keys[0] == query_idx as u64
&& matches.distances[0] < 1e-6;
let mut results = results_clone.lock().unwrap();
results.push(exact_match_found);
});
let search_results = search_results.lock().unwrap();
assert_eq!(search_results.len(), 100);
assert!(
search_results.iter().all(|&x| x),
"All searches should find exact matches"
);
}
}