1use crate::error::{Error, Result};
14use crate::faiss_try;
15use crate::metric::MetricType;
16use crate::selector::IdSelector;
17use std::ffi::CString;
18use std::fmt::{self, Display, Formatter, Write};
19use std::os::raw::c_uint;
20use std::{mem, ptr};
21
22use faiss_sys::*;
23
24pub mod autotune;
25pub mod flat;
26pub mod id_map;
27pub mod io;
28pub mod io_flags;
29pub mod ivf_flat;
30pub mod lsh;
31pub mod pretransform;
32pub mod refine_flat;
33pub mod scalar_quantizer;
34
35#[cfg(feature = "gpu")]
36pub mod gpu;
37
/// Identifier of a vector in an index: a transparent newtype over the
/// native `idx_t`.
///
/// The raw value `-1` is treated as "no index" by the methods of this type
/// (see `Idx::none` and `Idx::get`).
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
pub struct Idx(idx_t);
45
46impl Display for Idx {
47 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
48 match self.get() {
49 None => f.write_char('x'),
50 Some(i) => i.fmt(f),
51 }
52 }
53}
54
55impl From<idx_t> for Idx {
56 fn from(x: idx_t) -> Self {
57 Idx(x)
58 }
59}
60
61impl Idx {
62 #[inline]
68 pub fn new(idx: u64) -> Self {
69 assert!(
70 idx < 0x8000_0000_0000_0000,
71 "too large index value provided to Idx::new"
72 );
73 let idx = idx as idx_t;
74 Idx(idx)
75 }
76
77 #[inline]
79 pub fn none() -> Self {
80 Idx(-1)
81 }
82
83 #[inline]
85 pub fn is_none(self) -> bool {
86 self.0 == -1
87 }
88
89 #[inline]
91 pub fn is_some(self) -> bool {
92 self.0 != -1
93 }
94
95 pub fn get(self) -> Option<u64> {
97 match self.0 {
98 -1 => None,
99 x => Some(x as u64),
100 }
101 }
102
103 pub fn to_native(self) -> idx_t {
105 self.0
106 }
107}
108
109impl PartialEq<Idx> for Idx {
112 fn eq(&self, idx: &Idx) -> bool {
113 self.0 != -1 && idx.0 != -1 && self.0 == idx.0
114 }
115}
116
117impl PartialOrd<Idx> for Idx {
120 fn partial_cmp(&self, idx: &Idx) -> Option<std::cmp::Ordering> {
121 match (self.get(), idx.get()) {
122 (None, _) => None,
123 (_, None) => None,
124 (Some(a), Some(b)) => Some(a.cmp(&b)),
125 }
126 }
127}
128
/// Common interface of a vector index.
///
/// Every mutating or searching operation here takes `&mut self`; the
/// `ConcurrentIndex` trait declares search methods over `&self`.
pub trait Index {
    /// Whether the index is in a trained state. In this module's tests a
    /// freshly built "Flat" index reports `true`, an "IVF8,Flat" one `false`.
    fn is_trained(&self) -> bool;

    /// Number of vectors currently held by the index.
    fn ntotal(&self) -> u64;

    /// Dimensionality of the vectors handled by the index.
    fn d(&self) -> u32;

    /// The metric type of the index.
    fn metric_type(&self) -> MetricType;

    /// Adds the vectors in `x`, given back to back as one flat slice (the
    /// tests add 40 floats to an 8-dimensional index and obtain 5 vectors).
    fn add(&mut self, x: &[f32]) -> Result<()>;

    /// Like `add`, with explicit identifiers in `xids`.
    /// NOTE(review): presumably `xids[i]` labels the i-th vector of `x` —
    /// confirm against the implementations.
    fn add_with_ids(&mut self, x: &[f32], xids: &[Idx]) -> Result<()>;

    /// Trains the index on the vectors in `x` (flat layout, as for `add`).
    fn train(&mut self, x: &[f32]) -> Result<()>;

    /// Searches `k` neighbours for each query vector in `q` and returns only
    /// their labels, flattened query after query.
    fn assign(&mut self, q: &[f32], k: usize) -> Result<AssignSearchResult>;

    /// Searches `k` neighbours for each query vector in `q`, returning
    /// distances and labels flattened query after query.
    fn search(&mut self, q: &[f32], k: usize) -> Result<SearchResult>;

    /// Searches the neighbours of each query in `q` that lie within `radius`.
    fn range_search(&mut self, q: &[f32], radius: f32) -> Result<RangeSearchResult>;

    /// Empties the index (the tests observe `ntotal() == 0` afterwards).
    fn reset(&mut self) -> Result<()>;

    /// Removes the entries selected by `sel`.
    /// NOTE(review): the returned `usize` is presumably the number of removed
    /// entries — confirm against the implementations.
    fn remove_ids(&mut self, sel: &IdSelector) -> Result<usize>;

    /// Current value of the verbosity flag.
    fn verbose(&self) -> bool;

    /// Sets the verbosity flag.
    fn set_verbose(&mut self, value: bool);
}
189
190impl<I> Index for Box<I>
191where
192 I: Index,
193{
194 fn is_trained(&self) -> bool {
195 (**self).is_trained()
196 }
197
198 fn ntotal(&self) -> u64 {
199 (**self).ntotal()
200 }
201
202 fn d(&self) -> u32 {
203 (**self).d()
204 }
205
206 fn metric_type(&self) -> MetricType {
207 (**self).metric_type()
208 }
209
210 fn add(&mut self, x: &[f32]) -> Result<()> {
211 (**self).add(x)
212 }
213
214 fn add_with_ids(&mut self, x: &[f32], xids: &[Idx]) -> Result<()> {
215 (**self).add_with_ids(x, xids)
216 }
217
218 fn train(&mut self, x: &[f32]) -> Result<()> {
219 (**self).train(x)
220 }
221
222 fn assign(&mut self, q: &[f32], k: usize) -> Result<AssignSearchResult> {
223 (**self).assign(q, k)
224 }
225
226 fn search(&mut self, q: &[f32], k: usize) -> Result<SearchResult> {
227 (**self).search(q, k)
228 }
229
230 fn range_search(&mut self, q: &[f32], radius: f32) -> Result<RangeSearchResult> {
231 (**self).range_search(q, radius)
232 }
233
234 fn reset(&mut self) -> Result<()> {
235 (**self).reset()
236 }
237
238 fn remove_ids(&mut self, sel: &IdSelector) -> Result<usize> {
239 (**self).remove_ids(sel)
240 }
241
242 fn verbose(&self) -> bool {
243 (**self).verbose()
244 }
245
246 fn set_verbose(&mut self, value: bool) {
247 (**self).set_verbose(value)
248 }
249}
250
/// An index backed by a native `FaissIndex` object.
pub trait NativeIndex: Index {
    /// The raw pointer to the underlying native index object.
    fn inner_ptr(&self) -> *mut FaissIndex;
}
256
257impl<NI: NativeIndex> NativeIndex for Box<NI> {
258 fn inner_ptr(&self) -> *mut FaissIndex {
259 (**self).inner_ptr()
260 }
261}
262
/// Search operations that only need shared access (`&self`) to the index.
///
/// The method names mirror the `&mut self` search methods of `Index`.
pub trait ConcurrentIndex: Index {
    /// Same contract as `Index::assign`, over `&self`.
    fn assign(&self, q: &[f32], k: usize) -> Result<AssignSearchResult>;

    /// Same contract as `Index::search`, over `&self`.
    fn search(&self, q: &[f32], k: usize) -> Result<SearchResult>;

    /// Same contract as `Index::range_search`, over `&self`.
    fn range_search(&self, q: &[f32], radius: f32) -> Result<RangeSearchResult>;
}
282
283impl<CI: ConcurrentIndex> ConcurrentIndex for Box<CI> {
284 fn assign(&self, q: &[f32], k: usize) -> Result<AssignSearchResult> {
285 (**self).assign(q, k)
286 }
287
288 fn search(&self, q: &[f32], k: usize) -> Result<SearchResult> {
289 (**self).search(q, k)
290 }
291
292 fn range_search(&self, q: &[f32], radius: f32) -> Result<RangeSearchResult> {
293 (**self).range_search(q, radius)
294 }
295}
296
/// Marker trait with no methods of its own.
/// NOTE(review): presumably tags indexes that live in CPU memory, as opposed
/// to the `gpu` module's types — confirm.
pub trait CpuIndex: Index {}
299
/// Boxing a `CpuIndex` keeps the marker.
impl<CI: CpuIndex> CpuIndex for Box<CI> {}
301
/// Construction of an index wrapper directly from a raw `FaissIndex` pointer.
pub trait FromInnerPtr: NativeIndex {
    /// Builds `Self` around `inner_ptr`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not spelled out in this file.
    /// Presumably `inner_ptr` must point to a live native index of the
    /// matching type and must not be owned by any other wrapper — confirm.
    unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self;
}
317
/// Fallible construction of an index wrapper from a raw `FaissIndex` pointer.
pub trait TryFromInnerPtr: NativeIndex {
    /// Attempts to build `Self` around `inner_ptr`, reporting failure through
    /// the returned `Result`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not spelled out in this file.
    /// Presumably `inner_ptr` must be null or point to a live native index
    /// that no other wrapper owns — confirm.
    unsafe fn try_from_inner_ptr(inner_ptr: *mut FaissIndex) -> Result<Self>
    where
        Self: Sized;
}
337
/// Fallible counterpart of `Clone`.
pub trait TryClone {
    /// Attempts to produce an independent copy of `self`.
    fn try_clone(&self) -> Result<Self>
    where
        Self: Sized;
}
350
351pub fn try_clone_from_inner_ptr<T>(val: &T) -> Result<T>
352where
353 T: FromInnerPtr,
354{
355 unsafe {
356 let mut new_index_ptr = ::std::ptr::null_mut();
357 faiss_try(faiss_clone_index(val.inner_ptr(), &mut new_index_ptr))?;
358 Ok(crate::index::FromInnerPtr::from_inner_ptr(new_index_ptr))
359 }
360}
361
/// Outcome of an `assign` call: labels only, no distances.
#[derive(Debug, Clone, PartialEq)]
pub struct AssignSearchResult {
    /// Labels of the neighbours found, flattened query after query.
    pub labels: Vec<Idx>,
}
367
/// Outcome of a `search` call.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Distances of the neighbours found, flattened query after query.
    pub distances: Vec<f32>,
    /// Labels of the neighbours found, in the same layout as `distances`.
    pub labels: Vec<Idx>,
}
374
375#[derive(Debug, Clone, PartialEq)]
377pub struct RangeSearchResult {
378 inner: *mut FaissRangeSearchResult,
379}
380
381impl RangeSearchResult {
382 pub fn nq(&self) -> usize {
383 unsafe { faiss_RangeSearchResult_nq(self.inner) }
384 }
385
386 pub fn lims(&self) -> &[usize] {
387 unsafe {
388 let mut lims_ptr = ptr::null_mut();
389 faiss_RangeSearchResult_lims(self.inner, &mut lims_ptr);
390 ::std::slice::from_raw_parts(lims_ptr, self.nq() + 1)
391 }
392 }
393
394 pub fn distance_and_labels(&self) -> (&[f32], &[Idx]) {
397 let lims = self.lims();
398 let full_len = lims.last().cloned().unwrap_or(0);
399 unsafe {
400 let mut distances_ptr = ptr::null_mut();
401 let mut labels_ptr = ptr::null_mut();
402 faiss_RangeSearchResult_labels(self.inner, &mut labels_ptr, &mut distances_ptr);
403 let distances = ::std::slice::from_raw_parts(distances_ptr, full_len);
404 let labels = ::std::slice::from_raw_parts(labels_ptr as *const Idx, full_len);
405 (distances, labels)
406 }
407 }
408
409 pub fn distance_and_labels_mut(&self) -> (&mut [f32], &mut [Idx]) {
412 unsafe {
413 let buf_size = faiss_RangeSearchResult_buffer_size(self.inner);
414 let mut distances_ptr = ptr::null_mut();
415 let mut labels_ptr = ptr::null_mut();
416 faiss_RangeSearchResult_labels(self.inner, &mut labels_ptr, &mut distances_ptr);
417 let distances = ::std::slice::from_raw_parts_mut(distances_ptr, buf_size);
418 let labels = ::std::slice::from_raw_parts_mut(labels_ptr as *mut Idx, buf_size);
419 (distances, labels)
420 }
421 }
422
423 pub fn distances(&self) -> &[f32] {
426 self.distance_and_labels().0
427 }
428
429 pub fn distances_mut(&mut self) -> &mut [f32] {
432 self.distance_and_labels_mut().0
433 }
434
435 pub fn labels(&self) -> &[Idx] {
438 self.distance_and_labels().1
439 }
440
441 pub fn labels_mut(&mut self) -> &mut [Idx] {
444 self.distance_and_labels_mut().1
445 }
446}
447
448impl Drop for RangeSearchResult {
449 fn drop(&mut self) {
450 unsafe {
451 faiss_RangeSearchResult_free(self.inner);
452 }
453 }
454}
455
/// An index of unspecified concrete kind, held as a raw pointer to a native
/// `FaissIndex`. The native object is freed when the value is dropped.
#[derive(Debug)]
pub struct IndexImpl {
    /// Owned native handle; released in `Drop` via `faiss_Index_free`.
    inner: *mut FaissIndex,
}
462
// The raw pointer field makes `IndexImpl` neither `Send` nor `Sync` by
// default, so both are asserted by hand here.
// NOTE(review): soundness rests on the native library tolerating use of the
// same index from several threads — confirm.
unsafe impl Send for IndexImpl {}
unsafe impl Sync for IndexImpl {}
465
/// Tags `IndexImpl` with the `CpuIndex` marker.
impl CpuIndex for IndexImpl {}
467
468impl Drop for IndexImpl {
469 fn drop(&mut self) {
470 unsafe {
471 faiss_Index_free(self.inner);
472 }
473 }
474}
475
476impl IndexImpl {
477 pub fn inner_ptr(&self) -> *mut FaissIndex {
478 self.inner
479 }
480}
481
482impl NativeIndex for IndexImpl {
483 fn inner_ptr(&self) -> *mut FaissIndex {
484 self.inner
485 }
486}
487
488impl FromInnerPtr for IndexImpl {
489 unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self {
490 IndexImpl { inner: inner_ptr }
491 }
492}
493
494impl TryFromInnerPtr for IndexImpl {
495 unsafe fn try_from_inner_ptr(inner_ptr: *mut FaissIndex) -> Result<Self>
496 where
497 Self: Sized,
498 {
499 if inner_ptr.is_null() {
500 Err(Error::BadCast)
501 } else {
502 Ok(IndexImpl { inner: inner_ptr })
503 }
504 }
505}
506
/// Conversion of a native index into the general-purpose `IndexImpl`.
pub trait UpcastIndex: NativeIndex {
    /// Consumes `self` and returns an `IndexImpl`.
    fn upcast(self) -> IndexImpl;
}
527
528impl<NI: NativeIndex> UpcastIndex for NI {
529 fn upcast(self) -> IndexImpl {
530 let inner_ptr = self.inner_ptr();
531 mem::forget(self);
532
533 unsafe { IndexImpl::from_inner_ptr(inner_ptr) }
534 }
535}
536
// NOTE(review): `impl_native_index!` is defined outside this file; it
// presumably generates the `Index` implementation for `IndexImpl` (none is
// written out here, yet the tests call `Index` methods on it) — confirm.
impl_native_index!(IndexImpl);
538
539impl TryClone for IndexImpl {
540 fn try_clone(&self) -> Result<Self>
541 where
542 Self: Sized,
543 {
544 try_clone_from_inner_ptr(self)
545 }
546}
547
548pub fn index_factory<D>(d: u32, description: D, metric: MetricType) -> Result<IndexImpl>
557where
558 D: AsRef<str>,
559{
560 unsafe {
561 let metric = metric as c_uint;
562 let description =
563 CString::new(description.as_ref()).map_err(|_| Error::IndexDescription)?;
564 let mut index_ptr = ::std::ptr::null_mut();
565 faiss_try(faiss_index_factory(
566 &mut index_ptr,
567 (d & 0x7FFF_FFFF) as i32,
568 description.as_ptr(),
569 metric,
570 ))?;
571 Ok(IndexImpl { inner: index_ptr })
572 }
573}
574
// Tests for `index_factory` and the `Index` operations on the resulting
// `IndexImpl`.
#[cfg(test)]
mod tests {
    use super::{index_factory, Idx, Index, TryClone};
    use crate::metric::MetricType;

    // A flat index needs no training and starts out empty.
    #[test]
    fn index_factory_flat() {
        let index = index_factory(64, "Flat", MetricType::L2).unwrap();
        assert_eq!(index.is_trained(), true);
        assert_eq!(index.ntotal(), 0);
    }

    // Same checks through the `Index` implementation for `Box`.
    #[test]
    fn index_factory_flat_boxed() {
        let index = index_factory(64, "Flat", MetricType::L2).unwrap();
        let boxed = Box::new(index);
        assert_eq!(boxed.is_trained(), true);
        assert_eq!(boxed.ntotal(), 0);
    }

    // An IVF index starts out untrained and empty.
    #[test]
    fn index_factory_ivf_flat() {
        let index = index_factory(64, "IVF8,Flat", MetricType::L2).unwrap();
        assert_eq!(index.is_trained(), false);
        assert_eq!(index.ntotal(), 0);
    }

    // A scalar-quantizer index starts out untrained and empty.
    #[test]
    fn index_factory_sq() {
        let index = index_factory(64, "SQ8", MetricType::L2).unwrap();
        assert_eq!(index.is_trained(), false);
        assert_eq!(index.ntotal(), 0);
    }

    // A product-quantizer index starts out untrained and empty.
    #[test]
    fn index_factory_pq() {
        let index = index_factory(64, "PQ8", MetricType::L2).unwrap();
        assert_eq!(index.is_trained(), false);
        assert_eq!(index.ntotal(), 0);
    }

    // IVF combined with 4-bit and 8-bit scalar quantizers.
    #[test]
    fn index_factory_ivf_sq() {
        let index = index_factory(64, "IVF8,SQ4", MetricType::L2).unwrap();
        assert_eq!(index.is_trained(), false);
        assert_eq!(index.ntotal(), 0);

        let index = index_factory(64, "IVF8,SQ8", MetricType::L2).unwrap();
        assert_eq!(index.is_trained(), false);
        assert_eq!(index.ntotal(), 0);
    }

    // An HNSW index reports itself as trained right after construction.
    #[test]
    fn index_factory_hnsw() {
        let index = index_factory(64, "HNSW8", MetricType::L2).unwrap();
        assert_eq!(index.is_trained(), true);
        assert_eq!(index.ntotal(), 0);
    }

    // Both an unknown description and one with an interior NUL byte must be
    // rejected with an error rather than a panic.
    #[test]
    fn bad_index_factory_description() {
        let r = index_factory(64, "fdnoyq", MetricType::L2);
        assert!(r.is_err());
        let r = index_factory(64, "Flat\0Flat", MetricType::L2);
        assert!(r.is_err());
    }

    // A clone is independent: adding to it leaves the original untouched.
    #[test]
    fn index_clone() {
        let mut index = index_factory(4, "Flat", MetricType::L2).unwrap();
        // 24 floats = 6 vectors of dimension 4.
        let some_data = &[
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
            0., 1., 1., 0., 0., -1.,
        ];

        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 6);

        let mut index2 = index.try_clone().unwrap();
        assert_eq!(index2.ntotal(), 6);

        // 16 floats = 4 more vectors, added to the clone only.
        let some_more_data = &[
            100., 100., 100., 100., -100., 100., 100., 100., 120., 100., 100., 105., -100., 100.,
            100., 105.,
        ];

        index2.add(some_more_data).unwrap();
        assert_eq!(index.ntotal(), 6);
        assert_eq!(index2.ntotal(), 10);
    }

    // k-nearest-neighbour search on a flat index, with one and two queries.
    #[test]
    fn flat_index_search() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        // 40 floats = 5 vectors of dimension 8.
        let some_data = &[
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
            0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
            100., 105., -100., 100., 100., 105.,
        ];
        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(2), Idx(1), Idx(0), Idx(3), Idx(4)]);
        assert!(result.distances.iter().all(|x| *x > 0.));

        let my_query = [100.; 8];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(3), Idx(4), Idx(0), Idx(1), Idx(2)]);
        assert!(result.distances.iter().all(|x| *x > 0.));

        // Two queries at once: the labels of each query follow one another.
        let my_query = vec![
            0., 0., 0., 0., 0., 0., 0., 0., 100., 100., 100., 100., 100., 100., 100., 100.,
        ];
        let result = index.search(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![
                Idx(2),
                Idx(1),
                Idx(0),
                Idx(3),
                Idx(4),
                Idx(3),
                Idx(4),
                Idx(0),
                Idx(1),
                Idx(2)
            ]
        );
        assert!(result.distances.iter().all(|x| *x > 0.));
    }

    // `assign` returns labels only; also checks `d()` and `reset()`.
    #[test]
    fn flat_index_assign() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        assert_eq!(index.d(), 8);
        assert_eq!(index.ntotal(), 0);
        // 40 floats = 5 vectors of dimension 8.
        let some_data = &[
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
            0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
            100., 105., -100., 100., 100., 105.,
        ];
        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(result.labels, vec![Idx(2), Idx(1), Idx(0), Idx(3), Idx(4)]);

        // 32 floats = 4 identical queries, hence the same 5 labels 4 times.
        let my_query = [0.; 32];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![2, 1, 0, 3, 4, 2, 1, 0, 3, 4, 2, 1, 0, 3, 4, 2, 1, 0, 3, 4]
                .into_iter()
                .map(Idx)
                .collect::<Vec<_>>()
        );

        let my_query = [100.; 8];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![3, 4, 0, 1, 2].into_iter().map(Idx).collect::<Vec<_>>()
        );

        let my_query = vec![
            0., 0., 0., 0., 0., 0., 0., 0., 100., 100., 100., 100., 100., 100., 100., 100.,
        ];
        let result = index.assign(&my_query, 5).unwrap();
        assert_eq!(
            result.labels,
            vec![2, 1, 0, 3, 4, 3, 4, 0, 1, 2]
                .into_iter()
                .map(Idx)
                .collect::<Vec<_>>()
        );

        index.reset().unwrap();
        assert_eq!(index.ntotal(), 0);
    }

    // Range search: only vectors 1 and 2 are expected within the radius; the
    // order of the results is not fixed, so both orders are accepted.
    #[test]
    fn flat_index_range_search() {
        let mut index = index_factory(8, "Flat", MetricType::L2).unwrap();
        // 40 floats = 5 vectors of dimension 8.
        let some_data = &[
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
            0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
            100., 105., -100., 100., 100., 105.,
        ];
        index.add(some_data).unwrap();
        assert_eq!(index.ntotal(), 5);

        let my_query = [0.; 8];
        let result = index.range_search(&my_query, 8.125).unwrap();
        let (distances, labels) = result.distance_and_labels();
        assert!(labels == &[Idx(1), Idx(2)] || labels == &[Idx(2), Idx(1)]);
        assert!(distances.iter().all(|x| *x > 0.));
    }
}