Skip to main content

faiss_next/index/
traits.rs

1use std::ptr;
2
3use faiss_next_sys::{self, FaissIndex, FaissIndexBinary, FaissRangeSearchResult};
4
5use crate::error::{check_return_code, Error, Result};
6use crate::id_selector::IDSelector;
7use crate::idx::Idx;
8use crate::metric::MetricType;
9use crate::result::{BinarySearchResult, RangeSearchResult, SearchResult};
10
11pub trait Index {
12    fn inner_ptr(&self) -> *mut FaissIndex;
13
14    fn is_trained(&self) -> bool {
15        unsafe { faiss_next_sys::faiss_Index_is_trained(self.inner_ptr()) != 0 }
16    }
17
18    fn ntotal(&self) -> u64 {
19        unsafe { faiss_next_sys::faiss_Index_ntotal(self.inner_ptr()) as u64 }
20    }
21
22    fn d(&self) -> u32 {
23        unsafe { faiss_next_sys::faiss_Index_d(self.inner_ptr()) as u32 }
24    }
25
26    fn metric_type(&self) -> MetricType {
27        let mt = unsafe { faiss_next_sys::faiss_Index_metric_type(self.inner_ptr()) };
28        MetricType::from_native(mt)
29    }
30
31    fn train(&mut self, x: &[f32]) -> Result<()> {
32        let n = x.len() / self.d() as usize;
33        check_return_code(unsafe {
34            faiss_next_sys::faiss_Index_train(self.inner_ptr(), n as i64, x.as_ptr())
35        })
36    }
37
38    fn add(&mut self, x: &[f32]) -> Result<()> {
39        let n = x.len() / self.d() as usize;
40        check_return_code(unsafe {
41            faiss_next_sys::faiss_Index_add(self.inner_ptr(), n as i64, x.as_ptr())
42        })
43    }
44
45    fn add_with_ids(&mut self, x: &[f32], ids: &[Idx]) -> Result<()> {
46        let n = x.len() / self.d() as usize;
47        if ids.len() < n {
48            return Err(Error::InvalidDimension {
49                expected: n,
50                actual: ids.len(),
51            });
52        }
53        let ids_raw: Vec<i64> = ids.iter().map(|&id| id.as_repr()).collect();
54        check_return_code(unsafe {
55            faiss_next_sys::faiss_Index_add_with_ids(
56                self.inner_ptr(),
57                n as i64,
58                x.as_ptr(),
59                ids_raw.as_ptr(),
60            )
61        })
62    }
63
64    fn search(&mut self, q: &[f32], k: usize) -> Result<SearchResult> {
65        let d = self.d() as usize;
66        let nq = q.len() / d;
67        let mut distances = vec![0.0f32; nq * k];
68        let mut labels = vec![Idx::NONE; nq * k];
69
70        check_return_code(unsafe {
71            faiss_next_sys::faiss_Index_search(
72                self.inner_ptr(),
73                nq as i64,
74                q.as_ptr(),
75                k as i64,
76                distances.as_mut_ptr(),
77                labels.as_mut_ptr() as *mut i64,
78            )
79        })?;
80
81        Ok(SearchResult::new(distances, labels))
82    }
83
84    fn search_with_params<P: crate::search_params::SearchParams>(
85        &mut self,
86        q: &[f32],
87        k: usize,
88        params: &P,
89    ) -> Result<SearchResult> {
90        let d = self.d() as usize;
91        let nq = q.len() / d;
92        let mut distances = vec![0.0f32; nq * k];
93        let mut labels = vec![Idx::NONE; nq * k];
94
95        check_return_code(unsafe {
96            faiss_next_sys::faiss_Index_search_with_params(
97                self.inner_ptr(),
98                nq as i64,
99                q.as_ptr(),
100                k as i64,
101                params.as_ptr(),
102                distances.as_mut_ptr(),
103                labels.as_mut_ptr() as *mut i64,
104            )
105        })?;
106
107        Ok(SearchResult::new(distances, labels))
108    }
109
110    fn range_search(&mut self, q: &[f32], radius: f32) -> Result<RangeSearchResult> {
111        let d = self.d() as usize;
112        let nq = q.len() / d;
113
114        unsafe {
115            let mut result: *mut FaissRangeSearchResult = ptr::null_mut();
116            check_return_code(faiss_next_sys::faiss_RangeSearchResult_new(
117                &mut result,
118                nq as i64,
119            ))?;
120
121            check_return_code(faiss_next_sys::faiss_Index_range_search(
122                self.inner_ptr(),
123                nq as i64,
124                q.as_ptr(),
125                radius,
126                result,
127            ))?;
128
129            let mut lims = ptr::null_mut();
130            let mut distances = ptr::null_mut();
131            let mut labels = ptr::null_mut();
132
133            faiss_next_sys::faiss_RangeSearchResult_lims(result, &mut lims);
134            faiss_next_sys::faiss_RangeSearchResult_labels(result, &mut labels, &mut distances);
135
136            let nq_actual = nq + 1;
137            let lims_slice = std::slice::from_raw_parts(lims, nq_actual).to_vec();
138            let total = *lims_slice.last().unwrap_or(&0);
139
140            let labels_slice = std::slice::from_raw_parts(labels as *const i64, total)
141                .iter()
142                .map(|&l| Idx(l))
143                .collect();
144            let distances_slice = std::slice::from_raw_parts(distances, total).to_vec();
145
146            faiss_next_sys::faiss_RangeSearchResult_free(result);
147
148            Ok(RangeSearchResult::new(
149                labels_slice,
150                distances_slice,
151                lims_slice,
152            ))
153        }
154    }
155
156    fn assign(&mut self, q: &[f32], k: usize) -> Result<Vec<Idx>> {
157        let d = self.d() as usize;
158        let nq = q.len() / d;
159        let mut labels = vec![Idx::NONE; nq * k];
160
161        check_return_code(unsafe {
162            faiss_next_sys::faiss_Index_assign(
163                self.inner_ptr(),
164                nq as i64,
165                q.as_ptr(),
166                labels.as_mut_ptr() as *mut i64,
167                k as i64,
168            )
169        })?;
170
171        Ok(labels)
172    }
173
174    fn reset(&mut self) -> Result<()> {
175        check_return_code(unsafe { faiss_next_sys::faiss_Index_reset(self.inner_ptr()) })
176    }
177
178    fn remove_ids<S: IDSelector>(&mut self, sel: &S) -> Result<usize> {
179        let mut n_removed: usize = 0;
180        check_return_code(unsafe {
181            faiss_next_sys::faiss_Index_remove_ids(self.inner_ptr(), sel.as_ptr(), &mut n_removed)
182        })?;
183        Ok(n_removed)
184    }
185
186    fn reconstruct(&self, id: Idx) -> Result<Vec<f32>> {
187        let d = self.d() as usize;
188        let mut recons = vec![0.0f32; d];
189        check_return_code(unsafe {
190            faiss_next_sys::faiss_Index_reconstruct(
191                self.inner_ptr(),
192                id.as_repr(),
193                recons.as_mut_ptr(),
194            )
195        })?;
196        Ok(recons)
197    }
198
199    fn reconstruct_n(&self, i0: Idx, ni: usize) -> Result<Vec<f32>> {
200        let d = self.d() as usize;
201        let mut recons = vec![0.0f32; ni * d];
202        check_return_code(unsafe {
203            faiss_next_sys::faiss_Index_reconstruct_n(
204                self.inner_ptr(),
205                i0.as_repr(),
206                ni as i64,
207                recons.as_mut_ptr(),
208            )
209        })?;
210        Ok(recons)
211    }
212
213    fn sa_code_size(&self) -> Result<usize> {
214        let mut size: usize = 0;
215        check_return_code(unsafe {
216            faiss_next_sys::faiss_Index_sa_code_size(self.inner_ptr(), &mut size)
217        })?;
218        Ok(size)
219    }
220
221    fn sa_encode(&self, x: &[f32]) -> Result<Vec<u8>> {
222        let d = self.d() as usize;
223        let n = x.len() / d;
224        let code_size = self.sa_code_size()?;
225        let mut bytes = vec![0u8; n * code_size];
226        check_return_code(unsafe {
227            faiss_next_sys::faiss_Index_sa_encode(
228                self.inner_ptr(),
229                n as i64,
230                x.as_ptr(),
231                bytes.as_mut_ptr(),
232            )
233        })?;
234        Ok(bytes)
235    }
236
237    fn sa_decode(&self, bytes: &[u8]) -> Result<Vec<f32>> {
238        let d = self.d() as usize;
239        let code_size = self.sa_code_size()?;
240        let n = bytes.len() / code_size;
241        let mut x = vec![0.0f32; n * d];
242        check_return_code(unsafe {
243            faiss_next_sys::faiss_Index_sa_decode(
244                self.inner_ptr(),
245                n as i64,
246                bytes.as_ptr(),
247                x.as_mut_ptr(),
248            )
249        })?;
250        Ok(x)
251    }
252
253    fn verbose(&self) -> bool {
254        unsafe { faiss_next_sys::faiss_Index_verbose(self.inner_ptr()) != 0 }
255    }
256
257    fn set_verbose(&mut self, verbose: bool) {
258        unsafe { faiss_next_sys::faiss_Index_set_verbose(self.inner_ptr(), verbose as i32) }
259    }
260
261    fn compute_residual(&self, x: &[f32], key: Idx) -> Result<Vec<f32>> {
262        let d = self.d() as usize;
263        let mut residual = vec![0.0f32; d];
264        check_return_code(unsafe {
265            faiss_next_sys::faiss_Index_compute_residual(
266                self.inner_ptr(),
267                x.as_ptr(),
268                residual.as_mut_ptr(),
269                key.as_repr(),
270            )
271        })?;
272        Ok(residual)
273    }
274
275    fn compute_residual_n(&self, x: &[f32], keys: &[Idx]) -> Result<Vec<f32>> {
276        let d = self.d() as usize;
277        let n = x.len() / d;
278        let mut residuals = vec![0.0f32; x.len()];
279        let keys_raw: Vec<i64> = keys.iter().map(|&id| id.as_repr()).collect();
280        check_return_code(unsafe {
281            faiss_next_sys::faiss_Index_compute_residual_n(
282                self.inner_ptr(),
283                n as i64,
284                x.as_ptr(),
285                residuals.as_mut_ptr(),
286                keys_raw.as_ptr(),
287            )
288        })?;
289        Ok(residuals)
290    }
291}
292
293pub trait IvfIndex: Index {
294    fn nlist(&self) -> usize;
295    fn nprobe(&self) -> usize;
296    fn set_nprobe(&mut self, nprobe: usize);
297
298    fn get_list_size(&self, list_no: usize) -> usize {
299        unsafe { faiss_next_sys::faiss_IndexIVF_get_list_size(self.inner_ptr(), list_no) }
300    }
301
302    fn make_direct_map(&mut self, new_type: i32) -> Result<()> {
303        check_return_code(unsafe {
304            faiss_next_sys::faiss_IndexIVF_make_direct_map(self.inner_ptr(), new_type)
305        })
306    }
307
308    fn merge_from(&mut self, other: &mut impl Index, add_ids: bool) -> Result<()> {
309        check_return_code(unsafe {
310            faiss_next_sys::faiss_IndexIVF_merge_from(
311                self.inner_ptr(),
312                other.inner_ptr(),
313                add_ids as i64,
314            )
315        })
316    }
317
318    fn search_preassigned(
319        &mut self,
320        q: &[f32],
321        k: usize,
322        assign: &[i64],
323        centroid_dis: &[f32],
324        store_pairs: bool,
325    ) -> Result<SearchResult> {
326        let d = self.d() as usize;
327        let nq = q.len() / d;
328        let mut distances = vec![0.0f32; nq * k];
329        let mut labels = vec![Idx::NONE; nq * k];
330
331        check_return_code(unsafe {
332            faiss_next_sys::faiss_IndexIVF_search_preassigned(
333                self.inner_ptr(),
334                nq as i64,
335                q.as_ptr(),
336                k as i64,
337                assign.as_ptr(),
338                centroid_dis.as_ptr(),
339                distances.as_mut_ptr(),
340                labels.as_mut_ptr() as *mut i64,
341                store_pairs as i32,
342            )
343        })?;
344
345        Ok(SearchResult::new(distances, labels))
346    }
347}
348
349pub trait BinaryIndex {
350    fn inner_ptr(&self) -> *mut FaissIndexBinary;
351
352    fn is_trained(&self) -> bool {
353        unsafe { faiss_next_sys::faiss_IndexBinary_is_trained(self.inner_ptr()) != 0 }
354    }
355
356    fn ntotal(&self) -> u64 {
357        unsafe { faiss_next_sys::faiss_IndexBinary_ntotal(self.inner_ptr()) as u64 }
358    }
359
360    fn d(&self) -> u32 {
361        unsafe { faiss_next_sys::faiss_IndexBinary_d(self.inner_ptr()) as u32 }
362    }
363
364    fn metric_type(&self) -> MetricType {
365        let mt = unsafe { faiss_next_sys::faiss_IndexBinary_metric_type(self.inner_ptr()) };
366        MetricType::from_native(mt)
367    }
368
369    fn train(&mut self, x: &[u8]) -> Result<()> {
370        let d_bytes = self.d() as usize / 8;
371        let n = x.len() / d_bytes;
372        check_return_code(unsafe {
373            faiss_next_sys::faiss_IndexBinary_train(self.inner_ptr(), n as i64, x.as_ptr())
374        })
375    }
376
377    fn add(&mut self, x: &[u8]) -> Result<()> {
378        let d_bytes = self.d() as usize / 8;
379        let n = x.len() / d_bytes;
380        check_return_code(unsafe {
381            faiss_next_sys::faiss_IndexBinary_add(self.inner_ptr(), n as i64, x.as_ptr())
382        })
383    }
384
385    fn add_with_ids(&mut self, x: &[u8], ids: &[Idx]) -> Result<()> {
386        let d_bytes = self.d() as usize / 8;
387        let n = x.len() / d_bytes;
388        if ids.len() < n {
389            return Err(Error::InvalidDimension {
390                expected: n,
391                actual: ids.len(),
392            });
393        }
394        let ids_raw: Vec<i64> = ids.iter().map(|&id| id.as_repr()).collect();
395        check_return_code(unsafe {
396            faiss_next_sys::faiss_IndexBinary_add_with_ids(
397                self.inner_ptr(),
398                n as i64,
399                x.as_ptr(),
400                ids_raw.as_ptr(),
401            )
402        })
403    }
404
405    fn search(&mut self, q: &[u8], k: usize) -> Result<BinarySearchResult> {
406        let d_bytes = self.d() as usize / 8;
407        let nq = q.len() / d_bytes;
408        let mut distances = vec![0i32; nq * k];
409        let mut labels = vec![Idx::NONE; nq * k];
410
411        check_return_code(unsafe {
412            faiss_next_sys::faiss_IndexBinary_search(
413                self.inner_ptr(),
414                nq as i64,
415                q.as_ptr(),
416                k as i64,
417                distances.as_mut_ptr(),
418                labels.as_mut_ptr() as *mut i64,
419            )
420        })?;
421
422        Ok(BinarySearchResult::new(distances, labels))
423    }
424
425    fn range_search(&mut self, q: &[u8], radius: i32) -> Result<RangeSearchResult> {
426        let d_bytes = self.d() as usize / 8;
427        let nq = q.len() / d_bytes;
428
429        unsafe {
430            let mut result: *mut FaissRangeSearchResult = ptr::null_mut();
431            check_return_code(faiss_next_sys::faiss_RangeSearchResult_new(
432                &mut result,
433                nq as i64,
434            ))?;
435
436            check_return_code(faiss_next_sys::faiss_IndexBinary_range_search(
437                self.inner_ptr(),
438                nq as i64,
439                q.as_ptr(),
440                radius,
441                result,
442            ))?;
443
444            let mut lims = ptr::null_mut();
445            let mut distances = ptr::null_mut();
446            let mut labels = ptr::null_mut();
447
448            faiss_next_sys::faiss_RangeSearchResult_lims(result, &mut lims);
449            faiss_next_sys::faiss_RangeSearchResult_labels(result, &mut labels, &mut distances);
450
451            let nq_actual = nq + 1;
452            let lims_slice = std::slice::from_raw_parts(lims, nq_actual).to_vec();
453            let total = *lims_slice.last().unwrap_or(&0);
454
455            let labels_slice = std::slice::from_raw_parts(labels as *const i64, total)
456                .iter()
457                .map(|&l| Idx(l))
458                .collect();
459            let distances_slice = std::slice::from_raw_parts(distances, total).to_vec();
460
461            faiss_next_sys::faiss_RangeSearchResult_free(result);
462
463            Ok(RangeSearchResult::new(
464                labels_slice,
465                distances_slice,
466                lims_slice,
467            ))
468        }
469    }
470
471    fn assign(&mut self, q: &[u8], k: usize) -> Result<Vec<Idx>> {
472        let d_bytes = self.d() as usize / 8;
473        let nq = q.len() / d_bytes;
474        let mut labels = vec![Idx::NONE; nq * k];
475
476        check_return_code(unsafe {
477            faiss_next_sys::faiss_IndexBinary_assign(
478                self.inner_ptr(),
479                nq as i64,
480                q.as_ptr(),
481                labels.as_mut_ptr() as *mut i64,
482                k as i64,
483            )
484        })?;
485
486        Ok(labels)
487    }
488
489    fn reset(&mut self) -> Result<()> {
490        check_return_code(unsafe { faiss_next_sys::faiss_IndexBinary_reset(self.inner_ptr()) })
491    }
492
493    fn remove_ids<S: IDSelector>(&mut self, sel: &S) -> Result<usize> {
494        let mut n_removed: usize = 0;
495        check_return_code(unsafe {
496            faiss_next_sys::faiss_IndexBinary_remove_ids(
497                self.inner_ptr(),
498                sel.as_ptr(),
499                &mut n_removed,
500            )
501        })?;
502        Ok(n_removed)
503    }
504
505    fn reconstruct(&self, id: Idx) -> Result<Vec<u8>> {
506        let d_bytes = self.d() as usize / 8;
507        let mut recons = vec![0u8; d_bytes];
508        check_return_code(unsafe {
509            faiss_next_sys::faiss_IndexBinary_reconstruct(
510                self.inner_ptr(),
511                id.as_repr(),
512                recons.as_mut_ptr(),
513            )
514        })?;
515        Ok(recons)
516    }
517
518    fn reconstruct_n(&self, i0: Idx, ni: usize) -> Result<Vec<u8>> {
519        let d_bytes = self.d() as usize / 8;
520        let mut recons = vec![0u8; ni * d_bytes];
521        check_return_code(unsafe {
522            faiss_next_sys::faiss_IndexBinary_reconstruct_n(
523                self.inner_ptr(),
524                i0.as_repr(),
525                ni as i64,
526                recons.as_mut_ptr(),
527            )
528        })?;
529        Ok(recons)
530    }
531
532    fn verbose(&self) -> bool {
533        unsafe { faiss_next_sys::faiss_IndexBinary_verbose(self.inner_ptr()) != 0 }
534    }
535
536    fn set_verbose(&mut self, verbose: bool) {
537        unsafe { faiss_next_sys::faiss_IndexBinary_set_verbose(self.inner_ptr(), verbose as i32) }
538    }
539}