1pub fn version() -> &'static str {
22 env!("CARGO_PKG_VERSION")
23}
24
/// Identifier under which a vector is stored in an [`Index`]; passed to the
/// native index as a plain `u64`.
pub type Key = u64;
28
/// Scalar type returned by metric callbacks (see [`VectorType::change_metric`]).
pub type Distance = f32;
32
/// C-ABI signature of a distance callback: two read-only opaque pointers and
/// one mutable opaque pointer in, a [`Distance`] out.
///
/// NOTE(review): presumably the first two pointers are the vectors being
/// compared and the third is caller-owned state — confirm against the native
/// side; nothing in this part of the file uses the alias.
pub type StatefulMetric = unsafe extern "C" fn(
    *const std::ffi::c_void,
    *const std::ffi::c_void,
    *mut std::ffi::c_void,
) -> Distance;
39
/// C-ABI signature of a key predicate: receives a [`Key`] and a mutable opaque
/// pointer and answers with a `bool`.
///
/// NOTE(review): presumably the pointer carries caller-owned state — confirm.
pub type StatefulPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool;
42
/// Error produced by the [`BitAddressable`] accessors.
///
/// The type is a field-less enum, so it is cheap to copy and can be compared
/// directly, e.g. `assert_eq!(bits.get_bit(99), Err(BitAddressableError::IndexOutOfRange))`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitAddressableError {
    /// The requested bit index lies outside the addressed storage.
    IndexOutOfRange,
}

impl std::fmt::Display for BitAddressableError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IndexOutOfRange => f.write_str("Index out of range"),
        }
    }
}

impl std::error::Error for BitAddressableError {}
59
/// Access to individual bits of a value by zero-based index.
pub trait BitAddressable {
    /// Sets the bit at `index` to `value`.
    ///
    /// # Errors
    /// The implementations in this file return
    /// [`BitAddressableError::IndexOutOfRange`] when `index` is past the end
    /// of the storage.
    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError>;

    /// Reads the bit at `index`.
    ///
    /// # Errors
    /// The implementations in this file return
    /// [`BitAddressableError::IndexOutOfRange`] when `index` is past the end
    /// of the storage.
    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError>;
}
80
/// One byte holding eight single-bit values.
///
/// `#[repr(transparent)]` guarantees the wrapper has exactly the layout of
/// `u8`, which is what the slice reinterpretations below rely on.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct b1x8(pub u8);

impl b1x8 {
    /// Views a slice of bytes as a slice of `b1x8` without copying.
    pub fn from_u8s(slice: &[u8]) -> &[Self] {
        let count = slice.len();
        let start = slice.as_ptr().cast::<Self>();
        // SAFETY: `Self` is `repr(transparent)` over `u8`, so size, alignment and
        // validity match; pointer and length come from a live slice whose borrow
        // is carried over to the returned slice.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Views a mutable slice of bytes as a mutable slice of `b1x8` without copying.
    pub fn from_mut_u8s(slice: &mut [u8]) -> &mut [Self] {
        let count = slice.len();
        let start = slice.as_mut_ptr().cast::<Self>();
        // SAFETY: same layout argument as `from_u8s`; the exclusive borrow of
        // `slice` is handed over to the returned slice.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }

    /// Views a slice of `b1x8` as the underlying bytes without copying.
    pub fn to_u8s(slice: &[Self]) -> &[u8] {
        let count = slice.len();
        let start = slice.as_ptr().cast::<u8>();
        // SAFETY: `Self` is `repr(transparent)` over `u8`; pointer and length
        // come from a live slice.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Views a mutable slice of `b1x8` as the underlying bytes without copying.
    pub fn to_mut_u8s(slice: &mut [Self]) -> &mut [u8] {
        let count = slice.len();
        let start = slice.as_mut_ptr().cast::<u8>();
        // SAFETY: `Self` is `repr(transparent)` over `u8`; every byte value is a
        // valid `b1x8`, so writes through the result cannot break an invariant.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }
}
114
/// A 16-bit value stored as a raw `i16` bit pattern.
///
/// `#[repr(transparent)]` guarantees the wrapper has exactly the layout of
/// `i16`, which is what the slice reinterpretations below rely on.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy)]
pub struct f16(i16);

impl f16 {
    /// Views a slice of `i16` bit patterns as a slice of `f16` without copying.
    pub fn from_i16s(slice: &[i16]) -> &[Self] {
        let count = slice.len();
        let start = slice.as_ptr().cast::<Self>();
        // SAFETY: `Self` is `repr(transparent)` over `i16`, so size, alignment
        // and validity match; pointer and length come from a live slice.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Views a mutable slice of `i16` bit patterns as a mutable slice of `f16`.
    pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] {
        let count = slice.len();
        let start = slice.as_mut_ptr().cast::<Self>();
        // SAFETY: same layout argument as `from_i16s`; the exclusive borrow of
        // `slice` is handed over to the returned slice.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }

    /// Views a slice of `f16` as the underlying `i16` bit patterns.
    pub fn to_i16s(slice: &[Self]) -> &[i16] {
        let count = slice.len();
        let start = slice.as_ptr().cast::<i16>();
        // SAFETY: `Self` is `repr(transparent)` over `i16`; pointer and length
        // come from a live slice.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Views a mutable slice of `f16` as the underlying `i16` bit patterns.
    pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] {
        let count = slice.len();
        let start = slice.as_mut_ptr().cast::<i16>();
        // SAFETY: `Self` is `repr(transparent)` over `i16`; every bit pattern is
        // a valid `f16`, so writes through the result cannot break an invariant.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }
}
150
151impl BitAddressable for b1x8 {
152 fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
164 if index >= 8 {
165 Err(BitAddressableError::IndexOutOfRange)
166 } else {
167 if value {
168 self.0 |= 1 << index;
169 } else {
170 self.0 &= !(1 << index);
171 }
172 Ok(())
173 }
174 }
175
176 fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
188 if index >= 8 {
189 Err(BitAddressableError::IndexOutOfRange)
190 } else {
191 Ok(((self.0 >> index) & 1) == 1)
192 }
193 }
194}
195
196impl BitAddressable for [b1x8] {
197 fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
199 let byte_index = index / 8;
200 let bit_index = index % 8;
201 if byte_index >= self.len() {
202 Err(BitAddressableError::IndexOutOfRange)
203 } else {
204 self[byte_index].set_bit(bit_index, value)
205 }
206 }
207
208 fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
210 let byte_index = index / 8;
211 let bit_index = index % 8;
212 if byte_index >= self.len() {
213 Err(BitAddressableError::IndexOutOfRange)
214 } else {
215 self[byte_index].get_bit(bit_index)
216 }
217 }
218}
219
220impl PartialEq for f16 {
221 fn eq(&self, other: &Self) -> bool {
222 let nan_self = (self.0 & 0x7C00) == 0x7C00 && (self.0 & 0x03FF) != 0;
224 let nan_other = (other.0 & 0x7C00) == 0x7C00 && (other.0 & 0x03FF) != 0;
225 if nan_self || nan_other {
226 return false;
227 }
228
229 self.0 == other.0
230 }
231}
232
233impl std::fmt::Debug for b1x8 {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 write!(f, "{:08b}", self.0)
236 }
237}
238
239impl std::fmt::Debug for f16 {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 let bits = self.0;
242 let sign = (bits >> 15) & 1;
243 let exponent = (bits >> 10) & 0x1F;
244 let mantissa = bits & 0x3FF;
245 write!(f, "{}|{:05b}|{:010b}", sign, exponent, mantissa)
246 }
247}
248
// Bridge between Rust and the C++ implementation declared in `lib.hpp`.
// The enums and structs below are shared with C++; the functions inside the
// `extern "C++"` block are implemented on the C++ side.
#[cxx::bridge]
pub mod ffi {

    /// Selector for a built-in distance function.
    ///
    /// NOTE(review): the meaning of each variant is defined by the native
    /// library, not by this file — consult `lib.hpp` for details.
    #[derive(Debug)]
    #[repr(i32)]
    enum MetricKind {
        Unknown,
        IP,
        L2sq,
        Cos,
        Pearson,
        Haversine,
        Divergence,
        Hamming,
        Tanimoto,
        Sorensen,
    }

    /// Selector for the scalar type named by `IndexOptions::quantization`.
    ///
    /// NOTE(review): presumably the type used for internal storage — confirm
    /// against `lib.hpp`.
    #[derive(Debug)]
    #[repr(i32)]
    enum ScalarKind {
        Unknown,
        F64,
        F32,
        F16,
        BF16,
        I8,
        B1,
    }

    /// Result of the search functions: a list of keys and a list of distances.
    ///
    /// NOTE(review): presumably `keys[i]` corresponds to `distances[i]` —
    /// confirm against the native implementation.
    #[derive(Debug)]
    struct Matches {
        keys: Vec<u64>,
        distances: Vec<f32>,
    }

    /// Parameters handed to `new_native_index` when an index is created.
    #[derive(Debug, PartialEq)]
    struct IndexOptions {
        dimensions: usize,
        metric: MetricKind,
        quantization: ScalarKind,
        connectivity: usize,
        expansion_add: usize,
        expansion_search: usize,
        multi: bool,
    }

    unsafe extern "C++" {
        include!("lib.hpp");

        /// Opaque C++ index type; wrapped on the Rust side by `Index`.
        type NativeIndex;

        // Tunables.
        pub fn expansion_add(self: &NativeIndex) -> usize;
        pub fn expansion_search(self: &NativeIndex) -> usize;
        pub fn change_expansion_add(self: &NativeIndex, n: usize);
        pub fn change_expansion_search(self: &NativeIndex, n: usize);
        pub fn change_metric_kind(self: &NativeIndex, metric: MetricKind);

        // Installs a custom metric. Callers in this file pass the address of
        // an `extern "C"` trampoline as `metric` and the address of a boxed
        // closure as `metric_state`.
        pub fn change_metric(self: &NativeIndex, metric: usize, metric_state: usize);

        // Construction and capacity.
        pub fn new_native_index(options: &IndexOptions) -> Result<UniquePtr<NativeIndex>>;
        pub fn reserve(self: &NativeIndex, capacity: usize) -> Result<()>;
        pub fn reserve_capacity_and_threads(
            self: &NativeIndex,
            capacity: usize,
            threads: usize,
        ) -> Result<()>;

        // Introspection.
        pub fn dimensions(self: &NativeIndex) -> usize;
        pub fn connectivity(self: &NativeIndex) -> usize;
        pub fn size(self: &NativeIndex) -> usize;
        pub fn capacity(self: &NativeIndex) -> usize;
        pub fn serialized_length(self: &NativeIndex) -> usize;

        // Insertion, one entry point per scalar type. Packed bits travel as
        // `u8` and half-precision values as `i16`.
        pub fn add_b1x8(self: &NativeIndex, key: u64, vector: &[u8]) -> Result<()>;
        pub fn add_i8(self: &NativeIndex, key: u64, vector: &[i8]) -> Result<()>;
        pub fn add_f16(self: &NativeIndex, key: u64, vector: &[i16]) -> Result<()>;
        pub fn add_f32(self: &NativeIndex, key: u64, vector: &[f32]) -> Result<()>;
        pub fn add_f64(self: &NativeIndex, key: u64, vector: &[f64]) -> Result<()>;

        // Search.
        pub fn search_b1x8(self: &NativeIndex, query: &[u8], count: usize) -> Result<Matches>;
        pub fn search_i8(self: &NativeIndex, query: &[i8], count: usize) -> Result<Matches>;
        pub fn search_f16(self: &NativeIndex, query: &[i16], count: usize) -> Result<Matches>;
        pub fn search_f32(self: &NativeIndex, query: &[f32], count: usize) -> Result<Matches>;
        pub fn search_f64(self: &NativeIndex, query: &[f64], count: usize) -> Result<Matches>;

        // Exact search.
        pub fn exact_search_b1x8(self: &NativeIndex, query: &[u8], count: usize)
            -> Result<Matches>;
        pub fn exact_search_i8(self: &NativeIndex, query: &[i8], count: usize) -> Result<Matches>;
        pub fn exact_search_f16(self: &NativeIndex, query: &[i16], count: usize)
            -> Result<Matches>;
        pub fn exact_search_f32(self: &NativeIndex, query: &[f32], count: usize)
            -> Result<Matches>;
        pub fn exact_search_f64(self: &NativeIndex, query: &[f64], count: usize)
            -> Result<Matches>;

        // Filtered search. Callers in this file pass the address of an
        // `extern "C"` trampoline as `filter` and the address of the user's
        // closure as `filter_state`.
        pub fn filtered_search_b1x8(
            self: &NativeIndex,
            query: &[u8],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_i8(
            self: &NativeIndex,
            query: &[i8],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_f16(
            self: &NativeIndex,
            query: &[i16],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_f32(
            self: &NativeIndex,
            query: &[f32],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_f64(
            self: &NativeIndex,
            query: &[f64],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;

        // Retrieval of stored vectors into a caller-provided buffer.
        pub fn get_b1x8(self: &NativeIndex, key: u64, buffer: &mut [u8]) -> Result<usize>;
        pub fn get_i8(self: &NativeIndex, key: u64, buffer: &mut [i8]) -> Result<usize>;
        pub fn get_f16(self: &NativeIndex, key: u64, buffer: &mut [i16]) -> Result<usize>;
        pub fn get_f32(self: &NativeIndex, key: u64, buffer: &mut [f32]) -> Result<usize>;
        pub fn get_f64(self: &NativeIndex, key: u64, buffer: &mut [f64]) -> Result<usize>;

        // Key management.
        pub fn remove(self: &NativeIndex, key: u64) -> Result<usize>;
        pub fn rename(self: &NativeIndex, from: u64, to: u64) -> Result<usize>;
        pub fn contains(self: &NativeIndex, key: u64) -> bool;
        pub fn count(self: &NativeIndex, key: u64) -> usize;

        // Persistence to and from a file path, plus housekeeping.
        pub fn save(self: &NativeIndex, path: &str) -> Result<()>;
        pub fn load(self: &NativeIndex, path: &str) -> Result<()>;
        pub fn view(self: &NativeIndex, path: &str) -> Result<()>;
        pub fn reset(self: &NativeIndex) -> Result<()>;
        pub fn memory_usage(self: &NativeIndex) -> usize;
        // Returns a C string; `Index::hardware_acceleration` reads it with `CStr::from_ptr`.
        pub fn hardware_acceleration(self: &NativeIndex) -> *const c_char;

        // Persistence to and from an in-memory buffer.
        pub fn save_to_buffer(self: &NativeIndex, buffer: &mut [u8]) -> Result<()>;
        pub fn load_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
        pub fn view_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
    }
}
433
434pub use ffi::{IndexOptions, MetricKind, ScalarKind};
436
/// A user-supplied distance closure, tagged by the scalar type it operates on.
///
/// Each variant holds a raw pointer to a heap-allocated `Box<dyn Fn ...>`.
/// In this file the pointers are created with `Box::into_raw` in
/// `VectorType::change_metric` and released with `Box::from_raw` when the
/// owning [`Index`] is dropped.
pub enum MetricFunction {
    B1X8Metric(*mut std::boxed::Box<dyn Fn(*const b1x8, *const b1x8) -> Distance + Send + Sync>),
    I8Metric(*mut std::boxed::Box<dyn Fn(*const i8, *const i8) -> Distance + Send + Sync>),
    F16Metric(*mut std::boxed::Box<dyn Fn(*const f16, *const f16) -> Distance + Send + Sync>),
    F32Metric(*mut std::boxed::Box<dyn Fn(*const f32, *const f32) -> Distance + Send + Sync>),
    F64Metric(*mut std::boxed::Box<dyn Fn(*const f64, *const f64) -> Distance + Send + Sync>),
}
491
/// Safe wrapper around the native index.
pub struct Index {
    /// Owning handle to the C++ index object.
    inner: cxx::UniquePtr<ffi::NativeIndex>,
    /// The custom metric installed through `change_metric`, if any; kept here
    /// so that the boxed closure can be freed when the index is dropped.
    metric_fn: Option<MetricFunction>,
}
532
// `Index` holds raw pointers (inside `MetricFunction`) and a C++ handle, so the
// compiler does not derive `Send`/`Sync` automatically. The stored closures
// are bounded by `Send + Sync`.
// NOTE(review): soundness also depends on the native index being safe to use
// from several threads through `&NativeIndex` — confirm against the C++ side.
unsafe impl Send for Index {}
unsafe impl Sync for Index {}
535
536impl Drop for Index {
537 fn drop(&mut self) {
538 if let Some(metric) = &self.metric_fn {
539 match metric {
540 MetricFunction::B1X8Metric(pointer) => unsafe {
541 drop(Box::from_raw(*pointer));
542 },
543 MetricFunction::I8Metric(pointer) => unsafe {
544 drop(Box::from_raw(*pointer));
545 },
546 MetricFunction::F16Metric(pointer) => unsafe {
547 drop(Box::from_raw(*pointer));
548 },
549 MetricFunction::F32Metric(pointer) => unsafe {
550 drop(Box::from_raw(*pointer));
551 },
552 MetricFunction::F64Metric(pointer) => unsafe {
553 drop(Box::from_raw(*pointer));
554 },
555 }
556 }
557 }
558}
559
560impl Default for ffi::IndexOptions {
561 fn default() -> Self {
562 Self {
563 dimensions: 256,
564 metric: MetricKind::Cos,
565 quantization: ScalarKind::BF16,
566 connectivity: 0,
567 expansion_add: 0,
568 expansion_search: 0,
569 multi: false,
570 }
571 }
572}
573
574impl Clone for ffi::IndexOptions {
575 fn clone(&self) -> Self {
576 ffi::IndexOptions {
577 dimensions: (self.dimensions),
578 metric: (self.metric),
579 quantization: (self.quantization),
580 connectivity: (self.connectivity),
581 expansion_add: (self.expansion_add),
582 expansion_search: (self.expansion_search),
583 multi: (self.multi),
584 }
585 }
586}
587
/// Scalar types that can be stored in and queried from an [`Index`].
///
/// Each implementation forwards to the entry points of the native index that
/// match its scalar type; the generic methods on [`Index`] dispatch through
/// this trait.
pub trait VectorType {
    /// Adds `vector` to `index` under `key`.
    ///
    /// # Errors
    /// Returns the exception raised by the native call.
    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception>
    where
        Self: Sized;

    /// Copies what is stored under `key` into `buffer` and returns the count
    /// reported by the native call.
    ///
    /// # Errors
    /// Returns the exception raised by the native call.
    fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result<usize, cxx::Exception>
    where
        Self: Sized;

    /// Searches `index` for up to `count` matches for `query`.
    ///
    /// # Errors
    /// Returns the exception raised by the native call.
    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized;

    /// Like [`VectorType::search`], but routed to the native `exact_search_*`
    /// entry point.
    ///
    /// # Errors
    /// Returns the exception raised by the native call.
    fn exact_search(
        index: &Index,
        query: &[Self],
        count: usize,
    ) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized;

    /// Like [`VectorType::search`], with `filter` handed to the native side as
    /// a per-key predicate.
    ///
    /// # Errors
    /// Returns the exception raised by the native call.
    fn filtered_search<F>(
        index: &Index,
        query: &[Self],
        count: usize,
        filter: F,
    ) -> Result<ffi::Matches, cxx::Exception>
    where
        Self: Sized,
        F: Fn(Key) -> bool;

    /// Installs `metric` as the distance function of `index`. The closure
    /// receives two raw pointers to the vectors being compared.
    fn change_metric(
        index: &mut Index,
        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
    ) -> Result<(), cxx::Exception>
    where
        Self: Sized;
}
696
697impl VectorType for f32 {
698 fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
699 index.inner.search_f32(query, count)
700 }
701
702 fn exact_search(
703 index: &Index,
704 query: &[Self],
705 count: usize,
706 ) -> Result<ffi::Matches, cxx::Exception> {
707 index.inner.exact_search_f32(query, count)
708 }
709
710 fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
711 index.inner.get_f32(key, vector)
712 }
713
714 fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
715 index.inner.add_f32(key, vector)
716 }
717
718 fn filtered_search<F>(
719 index: &Index,
720 query: &[Self],
721 count: usize,
722 filter: F,
723 ) -> Result<ffi::Matches, cxx::Exception>
724 where
725 Self: Sized,
726 F: Fn(Key) -> bool,
727 {
728 extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
730 let closure = closure_address as *const F;
731 unsafe { (*closure)(key) }
732 }
733
734 let trampoline_fn: usize = trampoline::<F> as *const () as usize;
736 let closure_address: usize = &filter as *const F as usize;
737 index
738 .inner
739 .filtered_search_f32(query, count, trampoline_fn, closure_address)
740 }
741
742 fn change_metric(
743 index: &mut Index,
744 metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
745 ) -> Result<(), cxx::Exception> {
746 type MetricFn = Box<dyn Fn(*const f32, *const f32) -> Distance>;
748 index.metric_fn = Some(MetricFunction::F32Metric(Box::into_raw(Box::new(metric))));
749
750 extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
755 let first_ptr = first as *const f32;
756 let second_ptr = second as *const f32;
757 let closure: *mut MetricFn = closure_address as *mut MetricFn;
758 unsafe { (*closure)(first_ptr, second_ptr) }
759 }
760
761 let trampoline_fn: usize = trampoline as *const () as usize;
762 let closure_address = match index.metric_fn {
763 Some(MetricFunction::F32Metric(metric)) => metric as *mut () as usize,
764 _ => panic!("Expected F32Metric"),
765 };
766 index.inner.change_metric(trampoline_fn, closure_address);
767
768 Ok(())
769 }
770}
771
772impl VectorType for i8 {
773 fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
774 index.inner.search_i8(query, count)
775 }
776
777 fn exact_search(
778 index: &Index,
779 query: &[Self],
780 count: usize,
781 ) -> Result<ffi::Matches, cxx::Exception> {
782 index.inner.exact_search_i8(query, count)
783 }
784
785 fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
786 index.inner.get_i8(key, vector)
787 }
788
789 fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
790 index.inner.add_i8(key, vector)
791 }
792
793 fn filtered_search<F>(
794 index: &Index,
795 query: &[Self],
796 count: usize,
797 filter: F,
798 ) -> Result<ffi::Matches, cxx::Exception>
799 where
800 Self: Sized,
801 F: Fn(Key) -> bool,
802 {
803 extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
805 let closure = closure_address as *const F;
806 unsafe { (*closure)(key) }
807 }
808
809 let trampoline_fn: usize = trampoline::<F> as *const () as usize;
811 let closure_address: usize = &filter as *const F as usize;
812 index
813 .inner
814 .filtered_search_i8(query, count, trampoline_fn, closure_address)
815 }
816 fn change_metric(
817 index: &mut Index,
818 metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
819 ) -> Result<(), cxx::Exception> {
820 type MetricFn = Box<dyn Fn(*const i8, *const i8) -> Distance>;
822 index.metric_fn = Some(MetricFunction::I8Metric(Box::into_raw(Box::new(metric))));
823
824 extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
829 let first_ptr = first as *const i8;
830 let second_ptr = second as *const i8;
831 let closure: *mut MetricFn = closure_address as *mut MetricFn;
832 unsafe { (*closure)(first_ptr, second_ptr) }
833 }
834
835 let trampoline_fn: usize = trampoline as *const () as usize;
836 let closure_address = match index.metric_fn {
837 Some(MetricFunction::I8Metric(metric)) => metric as *mut () as usize,
838 _ => panic!("Expected I8Metric"),
839 };
840 index.inner.change_metric(trampoline_fn, closure_address);
841
842 Ok(())
843 }
844}
845
846impl VectorType for f64 {
847 fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
848 index.inner.search_f64(query, count)
849 }
850
851 fn exact_search(
852 index: &Index,
853 query: &[Self],
854 count: usize,
855 ) -> Result<ffi::Matches, cxx::Exception> {
856 index.inner.exact_search_f64(query, count)
857 }
858
859 fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
860 index.inner.get_f64(key, vector)
861 }
862
863 fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
864 index.inner.add_f64(key, vector)
865 }
866
867 fn filtered_search<F>(
868 index: &Index,
869 query: &[Self],
870 count: usize,
871 filter: F,
872 ) -> Result<ffi::Matches, cxx::Exception>
873 where
874 Self: Sized,
875 F: Fn(Key) -> bool,
876 {
877 extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
879 let closure = closure_address as *const F;
880 unsafe { (*closure)(key) }
881 }
882
883 let trampoline_fn: usize = trampoline::<F> as *const () as usize;
885 let closure_address: usize = &filter as *const F as usize;
886 index
887 .inner
888 .filtered_search_f64(query, count, trampoline_fn, closure_address)
889 }
890 fn change_metric(
891 index: &mut Index,
892 metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
893 ) -> Result<(), cxx::Exception> {
894 type MetricFn = Box<dyn Fn(*const f64, *const f64) -> Distance>;
896 index.metric_fn = Some(MetricFunction::F64Metric(Box::into_raw(Box::new(metric))));
897
898 extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
903 let first_ptr = first as *const f64;
904 let second_ptr = second as *const f64;
905 let closure: *mut MetricFn = closure_address as *mut MetricFn;
906 unsafe { (*closure)(first_ptr, second_ptr) }
907 }
908
909 let trampoline_fn: usize = trampoline as *const () as usize;
910 let closure_address = match index.metric_fn {
911 Some(MetricFunction::F64Metric(metric)) => metric as *mut () as usize,
912 _ => panic!("Expected F64Metric"),
913 };
914 index.inner.change_metric(trampoline_fn, closure_address);
915
916 Ok(())
917 }
918}
919
920impl VectorType for f16 {
921 fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
922 index.inner.search_f16(f16::to_i16s(query), count)
923 }
924
925 fn exact_search(
926 index: &Index,
927 query: &[Self],
928 count: usize,
929 ) -> Result<ffi::Matches, cxx::Exception> {
930 index.inner.exact_search_f16(f16::to_i16s(query), count)
931 }
932
933 fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
934 index.inner.get_f16(key, f16::to_mut_i16s(vector))
935 }
936
937 fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
938 index.inner.add_f16(key, f16::to_i16s(vector))
939 }
940
941 fn filtered_search<F>(
942 index: &Index,
943 query: &[Self],
944 count: usize,
945 filter: F,
946 ) -> Result<ffi::Matches, cxx::Exception>
947 where
948 Self: Sized,
949 F: Fn(Key) -> bool,
950 {
951 extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
953 let closure = closure_address as *const F;
954 unsafe { (*closure)(key) }
955 }
956
957 let trampoline_fn: usize = trampoline::<F> as *const () as usize;
959 let closure_address: usize = &filter as *const F as usize;
960 index.inner.filtered_search_f16(
961 f16::to_i16s(query),
962 count,
963 trampoline_fn,
964 closure_address,
965 )
966 }
967
968 fn change_metric(
969 index: &mut Index,
970 metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
971 ) -> Result<(), cxx::Exception> {
972 type MetricFn = Box<dyn Fn(*const f16, *const f16) -> Distance>;
974 index.metric_fn = Some(MetricFunction::F16Metric(Box::into_raw(Box::new(metric))));
975
976 extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
981 let first_ptr = first as *const f16;
982 let second_ptr = second as *const f16;
983 let closure: *mut MetricFn = closure_address as *mut MetricFn;
984 unsafe { (*closure)(first_ptr, second_ptr) }
985 }
986
987 let trampoline_fn: usize = trampoline as *const () as usize;
988 let closure_address = match index.metric_fn {
989 Some(MetricFunction::F16Metric(metric)) => metric as *mut () as usize,
990 _ => panic!("Expected F16Metric"),
991 };
992 index.inner.change_metric(trampoline_fn, closure_address);
993
994 Ok(())
995 }
996}
997
998impl VectorType for b1x8 {
999 fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
1000 index.inner.search_b1x8(b1x8::to_u8s(query), count)
1001 }
1002
1003 fn exact_search(
1004 index: &Index,
1005 query: &[Self],
1006 count: usize,
1007 ) -> Result<ffi::Matches, cxx::Exception> {
1008 index.inner.exact_search_b1x8(b1x8::to_u8s(query), count)
1009 }
1010
1011 fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
1012 index.inner.get_b1x8(key, b1x8::to_mut_u8s(vector))
1013 }
1014
1015 fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
1016 index.inner.add_b1x8(key, b1x8::to_u8s(vector))
1017 }
1018
1019 fn filtered_search<F>(
1020 index: &Index,
1021 query: &[Self],
1022 count: usize,
1023 filter: F,
1024 ) -> Result<ffi::Matches, cxx::Exception>
1025 where
1026 Self: Sized,
1027 F: Fn(Key) -> bool,
1028 {
1029 extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
1031 let closure = closure_address as *const F;
1032 unsafe { (*closure)(key) }
1033 }
1034
1035 let trampoline_fn: usize = trampoline::<F> as *const () as usize;
1037 let closure_address: usize = &filter as *const F as usize;
1038 index.inner.filtered_search_b1x8(
1039 b1x8::to_u8s(query),
1040 count,
1041 trampoline_fn,
1042 closure_address,
1043 )
1044 }
1045
1046 fn change_metric(
1047 index: &mut Index,
1048 metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
1049 ) -> Result<(), cxx::Exception> {
1050 type MetricFn = Box<dyn Fn(*const b1x8, *const b1x8) -> Distance>;
1052 index.metric_fn = Some(MetricFunction::B1X8Metric(Box::into_raw(Box::new(metric))));
1053
1054 extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
1059 let first_ptr = first as *const b1x8;
1060 let second_ptr = second as *const b1x8;
1061 let closure: *mut MetricFn = closure_address as *mut MetricFn;
1062 unsafe { (*closure)(first_ptr, second_ptr) }
1063 }
1064
1065 let trampoline_fn: usize = trampoline as *const () as usize;
1066 let closure_address = match index.metric_fn {
1067 Some(MetricFunction::B1X8Metric(metric)) => metric as *mut () as usize,
1068 _ => panic!("Expected F1X8Metric"),
1069 };
1070 index.inner.change_metric(trampoline_fn, closure_address);
1071
1072 Ok(())
1073 }
1074}
1075
1076impl Index {
1077 pub fn new(options: &ffi::IndexOptions) -> Result<Self, cxx::Exception> {
1078 match ffi::new_native_index(options) {
1079 Ok(inner) => Result::Ok(Self {
1080 inner,
1081 metric_fn: None,
1082 }),
1083 Err(err) => Err(err),
1084 }
1085 }
1086
1087 pub fn expansion_add(self: &Index) -> usize {
1089 self.inner.expansion_add()
1090 }
1091
1092 pub fn expansion_search(self: &Index) -> usize {
1094 self.inner.expansion_search()
1095 }
1096
1097 pub fn change_expansion_add(self: &Index, n: usize) {
1099 self.inner.change_expansion_add(n)
1100 }
1101
1102 pub fn change_expansion_search(self: &Index, n: usize) {
1104 self.inner.change_expansion_search(n)
1105 }
1106
1107 pub fn change_metric_kind(self: &Index, metric: ffi::MetricKind) {
1109 self.inner.change_metric_kind(metric)
1110 }
1111
1112 pub fn change_metric<T: VectorType>(
1114 self: &mut Index,
1115 metric: std::boxed::Box<dyn Fn(*const T, *const T) -> Distance + Send + Sync>,
1116 ) {
1117 T::change_metric(self, metric).unwrap();
1118 }
1119
1120 pub fn hardware_acceleration(&self) -> String {
1122 use core::ffi::CStr;
1123 unsafe {
1124 let c_str = CStr::from_ptr(self.inner.hardware_acceleration());
1125 c_str.to_string_lossy().into_owned()
1126 }
1127 }
1128
1129 pub fn search<T: VectorType>(
1140 self: &Index,
1141 query: &[T],
1142 count: usize,
1143 ) -> Result<ffi::Matches, cxx::Exception> {
1144 T::search(self, query, count)
1145 }
1146
1147 pub fn exact_search<T: VectorType>(
1160 self: &Index,
1161 query: &[T],
1162 count: usize,
1163 ) -> Result<ffi::Matches, cxx::Exception> {
1164 T::exact_search(self, query, count)
1165 }
1166
1167 pub fn filtered_search<T: VectorType, F>(
1180 self: &Index,
1181 query: &[T],
1182 count: usize,
1183 filter: F,
1184 ) -> Result<ffi::Matches, cxx::Exception>
1185 where
1186 F: Fn(Key) -> bool,
1187 {
1188 T::filtered_search(self, query, count, filter)
1189 }
1190
1191 pub fn add<T: VectorType>(self: &Index, key: Key, vector: &[T]) -> Result<(), cxx::Exception> {
1198 T::add(self, key, vector)
1199 }
1200
1201 pub fn get<T: VectorType>(
1213 self: &Index,
1214 key: Key,
1215 vector: &mut [T],
1216 ) -> Result<usize, cxx::Exception> {
1217 T::get(self, key, vector)
1218 }
1219
1220 pub fn export<T: VectorType + Default + Clone>(
1228 self: &Index,
1229 key: Key,
1230 vector: &mut Vec<T>,
1231 ) -> Result<usize, cxx::Exception> {
1232 let dim = self.dimensions();
1233 let max_matches = self.count(key);
1234 vector.resize(dim * max_matches, T::default());
1235 let matches = T::get(self, key, &mut vector[..])?;
1236 vector.resize(dim * matches, T::default());
1237 Ok(matches)
1238 }
1239
1240 pub fn reserve(self: &Index, capacity: usize) -> Result<(), cxx::Exception> {
1246 self.inner.reserve(capacity)
1247 }
1248
1249 pub fn reserve_capacity_and_threads(
1256 self: &Index,
1257 capacity: usize,
1258 threads: usize,
1259 ) -> Result<(), cxx::Exception> {
1260 self.inner.reserve_capacity_and_threads(capacity, threads)
1261 }
1262
1263 pub fn dimensions(self: &Index) -> usize {
1265 self.inner.dimensions()
1266 }
1267
1268 pub fn connectivity(self: &Index) -> usize {
1270 self.inner.connectivity()
1271 }
1272
1273 pub fn size(self: &Index) -> usize {
1275 self.inner.size()
1276 }
1277
1278 pub fn capacity(self: &Index) -> usize {
1280 self.inner.capacity()
1281 }
1282
1283 pub fn serialized_length(self: &Index) -> usize {
1285 self.inner.serialized_length()
1286 }
1287
1288 pub fn remove(self: &Index, key: Key) -> Result<usize, cxx::Exception> {
1298 self.inner.remove(key)
1299 }
1300
1301 pub fn rename(self: &Index, from: Key, to: Key) -> Result<usize, cxx::Exception> {
1312 self.inner.rename(from, to)
1313 }
1314
1315 pub fn contains(self: &Index, key: Key) -> bool {
1325 self.inner.contains(key)
1326 }
1327
1328 pub fn count(self: &Index, key: Key) -> usize {
1338 self.inner.count(key)
1339 }
1340
1341 pub fn save(self: &Index, path: &str) -> Result<(), cxx::Exception> {
1347 self.inner.save(path)
1348 }
1349
1350 pub fn load(self: &Index, path: &str) -> Result<(), cxx::Exception> {
1356 self.inner.load(path)
1357 }
1358
1359 pub fn view(self: &Index, path: &str) -> Result<(), cxx::Exception> {
1365 self.inner.view(path)
1366 }
1367
1368 pub fn reset(self: &Index) -> Result<(), cxx::Exception> {
1370 self.inner.reset()
1371 }
1372
1373 pub fn memory_usage(self: &Index) -> usize {
1376 self.inner.memory_usage()
1377 }
1378
1379 pub fn save_to_buffer(self: &Index, buffer: &mut [u8]) -> Result<(), cxx::Exception> {
1385 self.inner.save_to_buffer(buffer)
1386 }
1387
1388 pub fn load_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
1394 self.inner.load_from_buffer(buffer)
1395 }
1396
1397 pub unsafe fn view_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
1426 self.inner.view_from_buffer(buffer)
1427 }
1428}
1429
/// Creates an [`Index`] from `options`; a free-function spelling of
/// [`Index::new`].
///
/// # Errors
/// Returns whatever error [`Index::new`] returns.
pub fn new_index(options: &ffi::IndexOptions) -> Result<Index, cxx::Exception> {
    Index::new(options)
}
1433
1434#[cfg(test)]
1435mod tests {
1436 use crate::ffi::IndexOptions;
1437 use crate::ffi::MetricKind;
1438 use crate::ffi::ScalarKind;
1439
1440 use crate::b1x8;
1441 use crate::new_index;
1442 use crate::Index;
1443 use crate::Key;
1444
1445 use std::env;
1446
1447 #[test]
1448 fn print_specs() {
1449 println!("--------------------------------------------------");
1450 println!("OS: {}", env::consts::OS);
1451 println!(
1452 "Rust version: {}",
1453 env::var("RUST_VERSION").unwrap_or_else(|_| "unknown".into())
1454 );
1455
1456 let f64_index = Index::new(&IndexOptions {
1458 dimensions: 256,
1459 metric: MetricKind::Cos,
1460 quantization: ScalarKind::F64,
1461 ..Default::default()
1462 })
1463 .unwrap();
1464
1465 let f32_index = Index::new(&IndexOptions {
1466 dimensions: 256,
1467 metric: MetricKind::Cos,
1468 quantization: ScalarKind::F32,
1469 ..Default::default()
1470 })
1471 .unwrap();
1472
1473 let f16_index = Index::new(&IndexOptions {
1474 dimensions: 256,
1475 metric: MetricKind::Cos,
1476 quantization: ScalarKind::F16,
1477 ..Default::default()
1478 })
1479 .unwrap();
1480
1481 let i8_index = Index::new(&IndexOptions {
1482 dimensions: 256,
1483 metric: MetricKind::Cos,
1484 quantization: ScalarKind::I8,
1485 ..Default::default()
1486 })
1487 .unwrap();
1488
1489 let b1_index = Index::new(&IndexOptions {
1490 dimensions: 256,
1491 metric: MetricKind::Hamming,
1492 quantization: ScalarKind::B1,
1493 ..Default::default()
1494 })
1495 .unwrap();
1496
1497 println!(
1498 "f64 hardware acceleration: {}",
1499 f64_index.hardware_acceleration()
1500 );
1501 println!(
1502 "f32 hardware acceleration: {}",
1503 f32_index.hardware_acceleration()
1504 );
1505 println!(
1506 "f16 hardware acceleration: {}",
1507 f16_index.hardware_acceleration()
1508 );
1509 println!(
1510 "i8 hardware acceleration: {}",
1511 i8_index.hardware_acceleration()
1512 );
1513 println!(
1514 "b1 hardware acceleration: {}",
1515 b1_index.hardware_acceleration()
1516 );
1517 println!("--------------------------------------------------");
1518 }
1519
/// Inserts vectors into a 5-dimensional `f32` index and reads one of them back.
///
/// Checks that `add` succeeds for 5-element inputs and fails for 4- and
/// 6-element inputs, and that key `1` can be fetched both through `export`
/// (into a `Vec`) and through `get` (into a fixed-size buffer).
#[test]
fn test_add_get_vector() {
    let index = Index::new(&IndexOptions {
        dimensions: 5,
        quantization: ScalarKind::F32,
        ..Default::default()
    })
    .unwrap();
    assert!(index.reserve(10).is_ok());

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
    let oversized: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1];
    let undersized: [f32; 4] = [0.3, 0.2, 0.4, 0.0];

    // Inputs matching the index dimensions are accepted...
    assert!(index.add(1, &first).is_ok());
    assert!(index.add(2, &second).is_ok());
    // ...while mismatched lengths are expected to be rejected.
    assert!(index.add(3, &oversized).is_err());
    assert!(index.add(4, &undersized).is_err());
    // Only the two successful inserts may be counted.
    assert_eq!(index.size(), 2);

    // Round-trip through `export`, which fills a growable vector.
    let mut exported: Vec<f32> = Vec::new();
    let exported_count = index.export(1, &mut exported).unwrap();
    assert_eq!(exported_count, 1);
    assert_eq!(exported.len(), 5);
    assert_eq!(&exported[..], &first[..]);

    // Round-trip through `get` with a buffer of exactly the right size.
    let mut exact_buffer = [0.0f32; 5];
    let fetched_count = index.get(1, &mut exact_buffer).unwrap();
    assert_eq!(fetched_count, 1);
    assert_eq!(exact_buffer, first);

    // A buffer of the wrong size must make `get` fail.
    let mut wrong_buffer = [0.0f32; 6];
    assert!(index.get(1, &mut wrong_buffer).is_err());
}
/// Verifies that `search` fails for queries whose length does not match the
/// 5 dimensions the index was created with.
#[test]
fn test_search_vector() {
    let index = Index::new(&IndexOptions {
        dimensions: 5,
        quantization: ScalarKind::F32,
        ..Default::default()
    })
    .unwrap();
    assert!(index.reserve(10).is_ok());

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
    let oversized_query: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1];
    let undersized_query: [f32; 4] = [0.3, 0.2, 0.4, 0.0];

    // Populate the index so the searches below run against real entries.
    assert!(index.add(1, &first).is_ok());
    assert!(index.add(2, &second).is_ok());
    assert_eq!(index.size(), 2);

    // Both a longer and a shorter query are expected to be rejected.
    let too_long_outcome = index.search(&oversized_query, 1);
    assert!(too_long_outcome.is_err());
    let too_short_outcome = index.search(&undersized_query, 1);
    assert!(too_short_outcome.is_err());
}
1579
/// Repeatedly adds a vector under a large 64-bit key, reads it back, and
/// removes it again, finishing with an empty index.
#[test]
fn test_add_remove_vector() {
    let index = Index::new(&IndexOptions {
        dimensions: 4,
        metric: MetricKind::IP,
        quantization: ScalarKind::F64,
        connectivity: 10,
        expansion_add: 128,
        expansion_search: 3,
        ..Default::default()
    })
    .unwrap();
    assert!(index.reserve(10).is_ok());
    assert!(index.capacity() >= 10);

    let first: [f32; 4] = [0.2, 0.1, 0.2, 0.1];
    let second: [f32; 4] = [0.3, 0.2, 0.4, 0.0];

    // Each key is paired with the vector stored under it; the keys are
    // processed strictly one after another in this order.
    let keyed_vectors: [(Key, &[f32; 4]); 4] = [
        (483367403120493160, &first),
        (483367403120558696, &second),
        (483367403120624232, &second),
        (483367403120624233, &second),
    ];

    for &(key, vector) in keyed_vectors.iter() {
        // Insert, confirm exactly one vector is stored under the key, remove.
        assert!(index.add(key, vector).is_ok());
        let mut found_slice = [0.0f32; 4];
        assert_eq!(index.get(key, &mut found_slice).unwrap(), 1);
        assert!(index.remove(key).is_ok());
    }

    // Every insertion was paired with a removal.
    assert_eq!(index.size(), 0);
}
1626
/// End-to-end smoke test that walks one index through its whole lifecycle.
///
/// Covers, in order: construction and basic accessors, tuning of the
/// `expansion_add` / `expansion_search` parameters, insertion and search,
/// file round-trips (`save` / `load` / `view`), construction with several
/// metrics, buffer round-trips (`save_to_buffer` / `load_from_buffer`),
/// `reset`, and finally cloning and comparing `IndexOptions`.
///
/// The file written by `save` is removed again once the index has been reset,
/// so a passing run leaves no artifact behind in the working directory.
#[test]
fn integration() {
    // Relative path: the file is created in whatever directory the test
    // process runs in.
    const INDEX_PATH: &str = "index.rust.usearch";

    let mut options = IndexOptions {
        dimensions: 5,
        ..Default::default()
    };

    let index = Index::new(&options).unwrap();

    // A freshly built index must report non-zero expansion factors.
    assert!(index.expansion_add() > 0);
    assert!(index.expansion_search() > 0);

    assert!(index.reserve(10).is_ok());
    assert!(index.capacity() >= 10);
    assert!(index.connectivity() != 0);
    assert_eq!(index.dimensions(), 5);
    assert_eq!(index.size(), 0);

    let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];

    println!("--------------------------------------------------");
    println!(
        "before add, memory_usage: {} \
         cap: {} \
         ",
        index.memory_usage(),
        index.capacity(),
    );

    // The add-time expansion factor can be changed between insertions.
    index.change_expansion_add(10);
    assert_eq!(index.expansion_add(), 10);
    assert!(index.add(42, &first).is_ok());
    index.change_expansion_add(12);
    assert_eq!(index.expansion_add(), 12);
    assert!(index.add(43, &second).is_ok());
    assert_eq!(index.size(), 2);
    println!(
        "after add, memory_usage: {} \
         cap: {} \
         ",
        index.memory_usage(),
        index.capacity(),
    );

    // Asking for 10 neighbours out of 2 stored vectors must yield both,
    // regardless of the search-time expansion factor.
    index.change_expansion_search(10);
    assert_eq!(index.expansion_search(), 10);
    let results = index.search(&first, 10).unwrap();
    println!("{:?}", results);
    assert_eq!(results.keys.len(), 2);

    index.change_expansion_search(12);
    assert_eq!(index.expansion_search(), 12);
    let results = index.search(&first, 10).unwrap();
    println!("{:?}", results);
    assert_eq!(results.keys.len(), 2);
    println!("--------------------------------------------------");

    // File round-trip: write the index out, then read it back both ways.
    assert!(index.save(INDEX_PATH).is_ok());
    assert!(index.load(INDEX_PATH).is_ok());
    assert!(index.view(INDEX_PATH).is_ok());

    // Index construction must succeed for each of these metrics.
    assert!(new_index(&options).is_ok());
    options.metric = MetricKind::L2sq;
    assert!(new_index(&options).is_ok());
    options.metric = MetricKind::Cos;
    assert!(new_index(&options).is_ok());
    options.metric = MetricKind::Haversine;
    options.quantization = ScalarKind::F32;
    options.dimensions = 2;
    assert!(new_index(&options).is_ok());

    // Buffer round-trip: serialize into memory and load into a second index.
    let mut serialization_buffer = vec![0; index.serialized_length()];
    assert!(index.save_to_buffer(&mut serialization_buffer).is_ok());

    let deserialized_index = new_index(&options).unwrap();
    assert!(deserialized_index
        .load_from_buffer(&serialization_buffer)
        .is_ok());
    assert_eq!(index.size(), deserialized_index.size());

    // `reset` must drop all entries and release the reported memory.
    assert_ne!(index.memory_usage(), 0);
    assert!(index.reset().is_ok());
    assert_eq!(index.size(), 0);
    assert_eq!(index.memory_usage(), 0);

    // Best-effort cleanup of the file written by `save` above. Done after
    // `reset`, when the index presumably no longer references the file
    // (TODO confirm for `view`); any error is deliberately ignored so that
    // cleanup can never fail the test.
    let _ = std::fs::remove_file(INDEX_PATH);

    // `IndexOptions` must clone into an equal value that can then diverge.
    options.metric = MetricKind::Haversine;
    let mut opts = options.clone();
    assert_eq!(opts.metric, options.metric);
    assert_eq!(opts.quantization, options.quantization);
    assert_eq!(opts, options);
    opts.metric = MetricKind::Cos;
    assert_ne!(opts.metric, options.metric);
    assert!(new_index(&opts).is_ok());
}
1726
/// Runs `filtered_search` with a capture-free predicate and checks that every
/// returned key satisfies it.
#[test]
fn test_search_with_stateless_filter() {
    /// Predicate handed to the index: accepts odd keys only.
    fn is_odd(key: Key) -> bool {
        key % 2 == 1
    }

    let index = Index::new(&IndexOptions {
        dimensions: 5,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    // One vector under an odd key and one under an even key.
    let odd_keyed: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    let even_keyed: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
    index.add(1, &odd_keyed).unwrap();
    index.add(2, &even_keyed).unwrap();

    let query = vec![0.2, 0.1, 0.2, 0.1, 0.3];
    let results = index.filtered_search(&query, 10, is_odd).unwrap();

    let only_odd_keys = results.keys.iter().all(|&key| is_odd(key));
    assert!(only_odd_keys, "All keys must be odd");
}
1751
/// Runs `filtered_search` with a predicate that owns captured state (a set of
/// allowed keys) and checks that every returned key belongs to that set.
#[test]
fn test_search_with_stateful_filter() {
    use std::collections::HashSet;

    let index = Index::new(&IndexOptions {
        dimensions: 5,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    // The same vector is stored under two different keys.
    let stored: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
    index.add(1, &stored).unwrap();
    index.add(2, &stored).unwrap();

    let allowed_keys: HashSet<Key> = [1, 2, 3].iter().copied().collect();

    // The predicate takes ownership of its own copy of the set, leaving
    // `allowed_keys` available for the verification below.
    let stateful_filter = {
        let permitted = allowed_keys.clone();
        move |key: Key| permitted.contains(&key)
    };

    let query = vec![0.2, 0.1, 0.2, 0.1, 0.3];
    let results = index.filtered_search(&query, 10, stateful_filter).unwrap();

    let all_allowed = results.keys.iter().all(|&key| allowed_keys.contains(&key));
    assert!(all_allowed, "All keys must be in the allowed set");
}
1782
/// Searches an `F16`-quantized squared-L2 index with a query that differs
/// from every stored vector and asserts that no reported distance is zero.
#[test]
fn test_zero_distances() {
    let index = new_index(&IndexOptions {
        dimensions: 8,
        metric: MetricKind::L2sq,
        quantization: ScalarKind::F16,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    // The stored vectors differ only in their first component.
    for (key, leading) in [(0, 0.4), (1, 0.5), (2, 0.6)] {
        index
            .add(key, &[leading, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0])
            .unwrap();
    }

    // The query's first component matches none of the stored ones.
    let query = [0.05, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0];
    let matches = index.search(&query, 2).unwrap();

    for &distance in &matches.distances {
        assert_ne!(distance, 0.0);
    }
}
1812
/// Compares `search` against `exact_search` on 100 generated 4-dimensional
/// vectors.
///
/// Both calls must return 10 results, and the best exact distance must not be
/// worse than the best distance from `search`. The overlap between the two
/// result sets is printed for information only.
#[test]
fn test_exact_search() {
    use std::collections::HashSet;

    let index = new_index(&IndexOptions {
        dimensions: 4,
        metric: MetricKind::L2sq,
        quantization: ScalarKind::F32,
        ..Default::default()
    })
    .unwrap();
    index.reserve(100).unwrap();

    // Deterministic data set: each key doubles as the parameter of its point.
    for i in 0..100 {
        let step = i as f32;
        let point = [step * 0.1, (step * 0.05).sin(), (step * 0.05).cos(), 0.0];
        index.add(i, &point).unwrap();
    }

    let query = vec![4.5, 0.0, 1.0, 0.0];
    let approx_matches = index.search(&query, 10).unwrap();
    let exact_matches = index.exact_search(&query, 10).unwrap();

    let approx_keys: HashSet<Key> = approx_matches.keys.iter().copied().collect();
    let exact_keys: HashSet<Key> = exact_matches.keys.iter().copied().collect();

    assert_eq!(approx_matches.keys.len(), 10);
    assert_eq!(exact_matches.keys.len(), 10);

    // The exact search can never report a worse best match.
    let best_exact = exact_matches.distances[0];
    let best_approx = approx_matches.distances[0];
    assert!(best_exact <= best_approx);

    println!(
        "Approximate search first match: key={}, distance={}",
        approx_matches.keys[0], approx_matches.distances[0]
    );
    println!(
        "Exact search first match: key={}, distance={}",
        exact_matches.keys[0], exact_matches.distances[0]
    );

    let shared_count = approx_keys.intersection(&exact_keys).count();
    println!(
        "Number of common results between approximate and exact search: {}",
        shared_count
    );
}
1867
/// Replaces the metric of a populated 2-dimensional index with a boxed
/// closure that captures state, then adds another vector with the new metric
/// installed.
#[test]
fn test_change_distance_function() {
    let mut index = Index::new(&IndexOptions {
        dimensions: 2,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();

    // First vector goes in while the default metric is still active.
    let vector: [f32; 2] = [1.0, 0.0];
    index.add(1, &vector).unwrap();

    // Per-component weights captured by the custom metric below.
    let first_factor: f32 = 2.0;
    let second_factor: f32 = 0.7;
    let stateful_distance = Box::new(move |a: *const f32, b: *const f32| {
        // SAFETY: assumes the index passes pointers to at least 2 readable
        // `f32` values each, matching `dimensions: 2` above — TODO confirm
        // against the contract of `change_metric`.
        let (lhs, rhs) = unsafe {
            (
                std::slice::from_raw_parts(a, 2),
                std::slice::from_raw_parts(b, 2),
            )
        };
        let first_term = (lhs[0] - rhs[0]).abs() * first_factor;
        let second_term = (lhs[1] - rhs[1]).abs() * second_factor;
        first_term + second_term
    });
    index.change_metric(stateful_distance);

    // Second vector goes in after the metric has been swapped.
    let another_vector: [f32; 2] = [0.0, 1.0];
    index.add(2, &another_vector).unwrap();
}
1895
/// Stores two 8-bit binary vectors in a Hamming-metric index and checks the
/// order and distances of the search results.
#[test]
fn test_binary_vectors_and_hamming_distance() {
    let index = Index::new(&IndexOptions {
        dimensions: 8,
        metric: MetricKind::Hamming,
        quantization: ScalarKind::B1,
        ..Default::default()
    })
    .unwrap();

    // Eight dimensions pack into a single `b1x8` word per vector.
    let low_nibble = [b1x8(0b00001111)];
    let high_nibble = [b1x8(0b11110000)];
    let query = [b1x8(0b01111000)];

    index.reserve(10).unwrap();
    index.add(42, &low_nibble).unwrap();
    index.add(43, &high_nibble).unwrap();

    let results = index.search(&query, 5).unwrap();

    // Expected ranking as (key, distance): the query differs from key 43 in
    // 2 bits and from key 42 in 6 bits.
    let expected: [(Key, f32); 2] = [(43, 2.0), (42, 6.0)];
    assert_eq!(results.keys.len(), 2);
    for (rank, &(key, distance)) in expected.iter().enumerate() {
        assert_eq!(results.keys[rank], key);
        assert_eq!(results.distances[rank], distance);
    }
}
1925
/// Exercises one shared index from a thread pool in three phases: parallel
/// insertion, parallel retrieval, and parallel exact search.
///
/// The input vectors come from a fixed-seed generator, so every run works on
/// the same data.
#[test]
fn test_concurrency() {
    use fork_union as fu;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;
    use rand_distr::Uniform;
    use std::sync::Arc;

    const DIMENSIONS: usize = 128;
    const VECTOR_COUNT: usize = 1000;
    const THREAD_COUNT: usize = 4;
    const QUERY_COUNT: usize = 100;

    let index = Arc::new(
        Index::new(&IndexOptions {
            dimensions: DIMENSIONS,
            metric: MetricKind::Cos,
            quantization: ScalarKind::F32,
            ..Default::default()
        })
        .unwrap(),
    );
    index
        .reserve_capacity_and_threads(VECTOR_COUNT, THREAD_COUNT)
        .unwrap();

    // Fixed seed keeps the generated data identical between runs.
    let mut rng = ChaCha8Rng::seed_from_u64(42);
    let uniform = Uniform::new(-1.0f32, 1.0f32).unwrap();

    let mut reference_vectors: Vec<[f32; DIMENSIONS]> = Vec::with_capacity(VECTOR_COUNT);
    for _ in 0..VECTOR_COUNT {
        let mut vector = [0.0f32; DIMENSIONS];
        for component in vector.iter_mut() {
            *component = rng.sample(uniform);
        }
        reference_vectors.push(vector);
    }

    // Phase 1: every task inserts the vector whose position equals its index.
    let mut pool = fu::spawn(THREAD_COUNT);
    pool.for_n(VECTOR_COUNT, |prong| {
        let task = prong.task_index;
        index.add(task as u64, &reference_vectors[task]).unwrap();
    });

    assert_eq!(index.size(), VECTOR_COUNT);

    // Phase 2: every task reads its vector back and records whether it
    // matches the reference within a small tolerance.
    let mut pool = fu::spawn(THREAD_COUNT);
    let validation_results = Arc::new(std::sync::Mutex::new(Vec::new()));

    pool.for_n(VECTOR_COUNT, |prong| {
        let task = prong.task_index;
        let expected_vector = &reference_vectors[task];

        let mut retrieved_vector = [0.0f32; DIMENSIONS];
        let count = index.get(task as u64, &mut retrieved_vector).unwrap();
        assert_eq!(count, 1);

        let matches = retrieved_vector
            .iter()
            .zip(expected_vector.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6);

        validation_results.lock().unwrap().push(matches);
    });

    let validation_results = validation_results.lock().unwrap();
    assert_eq!(validation_results.len(), VECTOR_COUNT);
    assert!(
        validation_results.iter().all(|&x| x),
        "All retrieved vectors should match the original ones"
    );

    // Phase 3: query with stored vectors; each must come back as its own
    // nearest neighbour at (nearly) zero distance.
    let mut pool = fu::spawn(THREAD_COUNT);
    let search_results = Arc::new(std::sync::Mutex::new(Vec::new()));

    pool.for_n(QUERY_COUNT, |prong| {
        let query_idx = prong.task_index % VECTOR_COUNT;
        let query_vector = &reference_vectors[query_idx];

        let matches = index.exact_search(query_vector, 10).unwrap();

        // `first()` yields `None` for an empty result, which short-circuits
        // the distance check just like an explicit emptiness test.
        let exact_match_found = matches.keys.first() == Some(&(query_idx as u64))
            && matches.distances[0] < 1e-6;

        search_results.lock().unwrap().push(exact_match_found);
    });

    let search_results = search_results.lock().unwrap();
    assert_eq!(search_results.len(), QUERY_COUNT);
    assert!(
        search_results.iter().all(|&x| x),
        "All searches should find exact matches"
    );
}
2038}