1use std::ptr;
2
3use faiss_next_sys::{self, FaissIndex, FaissIndexBinary, FaissRangeSearchResult};
4
5use crate::error::{check_return_code, Error, Result};
6use crate::id_selector::IDSelector;
7use crate::idx::Idx;
8use crate::metric::MetricType;
9use crate::result::{BinarySearchResult, RangeSearchResult, SearchResult};
10
11pub trait Index {
12 fn inner_ptr(&self) -> *mut FaissIndex;
13
14 fn is_trained(&self) -> bool {
15 unsafe { faiss_next_sys::faiss_Index_is_trained(self.inner_ptr()) != 0 }
16 }
17
18 fn ntotal(&self) -> u64 {
19 unsafe { faiss_next_sys::faiss_Index_ntotal(self.inner_ptr()) as u64 }
20 }
21
22 fn d(&self) -> u32 {
23 unsafe { faiss_next_sys::faiss_Index_d(self.inner_ptr()) as u32 }
24 }
25
26 fn metric_type(&self) -> MetricType {
27 let mt = unsafe { faiss_next_sys::faiss_Index_metric_type(self.inner_ptr()) };
28 MetricType::from_native(mt)
29 }
30
31 fn train(&mut self, x: &[f32]) -> Result<()> {
32 let n = x.len() / self.d() as usize;
33 check_return_code(unsafe {
34 faiss_next_sys::faiss_Index_train(self.inner_ptr(), n as i64, x.as_ptr())
35 })
36 }
37
38 fn add(&mut self, x: &[f32]) -> Result<()> {
39 let n = x.len() / self.d() as usize;
40 check_return_code(unsafe {
41 faiss_next_sys::faiss_Index_add(self.inner_ptr(), n as i64, x.as_ptr())
42 })
43 }
44
45 fn add_with_ids(&mut self, x: &[f32], ids: &[Idx]) -> Result<()> {
46 let n = x.len() / self.d() as usize;
47 if ids.len() < n {
48 return Err(Error::InvalidDimension {
49 expected: n,
50 actual: ids.len(),
51 });
52 }
53 let ids_raw: Vec<i64> = ids.iter().map(|&id| id.as_repr()).collect();
54 check_return_code(unsafe {
55 faiss_next_sys::faiss_Index_add_with_ids(
56 self.inner_ptr(),
57 n as i64,
58 x.as_ptr(),
59 ids_raw.as_ptr(),
60 )
61 })
62 }
63
64 fn search(&mut self, q: &[f32], k: usize) -> Result<SearchResult> {
65 let d = self.d() as usize;
66 let nq = q.len() / d;
67 let mut distances = vec![0.0f32; nq * k];
68 let mut labels = vec![Idx::NONE; nq * k];
69
70 check_return_code(unsafe {
71 faiss_next_sys::faiss_Index_search(
72 self.inner_ptr(),
73 nq as i64,
74 q.as_ptr(),
75 k as i64,
76 distances.as_mut_ptr(),
77 labels.as_mut_ptr() as *mut i64,
78 )
79 })?;
80
81 Ok(SearchResult::new(distances, labels))
82 }
83
84 fn search_with_params<P: crate::search_params::SearchParams>(
85 &mut self,
86 q: &[f32],
87 k: usize,
88 params: &P,
89 ) -> Result<SearchResult> {
90 let d = self.d() as usize;
91 let nq = q.len() / d;
92 let mut distances = vec![0.0f32; nq * k];
93 let mut labels = vec![Idx::NONE; nq * k];
94
95 check_return_code(unsafe {
96 faiss_next_sys::faiss_Index_search_with_params(
97 self.inner_ptr(),
98 nq as i64,
99 q.as_ptr(),
100 k as i64,
101 params.as_ptr(),
102 distances.as_mut_ptr(),
103 labels.as_mut_ptr() as *mut i64,
104 )
105 })?;
106
107 Ok(SearchResult::new(distances, labels))
108 }
109
110 fn range_search(&mut self, q: &[f32], radius: f32) -> Result<RangeSearchResult> {
111 let d = self.d() as usize;
112 let nq = q.len() / d;
113
114 unsafe {
115 let mut result: *mut FaissRangeSearchResult = ptr::null_mut();
116 check_return_code(faiss_next_sys::faiss_RangeSearchResult_new(
117 &mut result,
118 nq as i64,
119 ))?;
120
121 check_return_code(faiss_next_sys::faiss_Index_range_search(
122 self.inner_ptr(),
123 nq as i64,
124 q.as_ptr(),
125 radius,
126 result,
127 ))?;
128
129 let mut lims = ptr::null_mut();
130 let mut distances = ptr::null_mut();
131 let mut labels = ptr::null_mut();
132
133 faiss_next_sys::faiss_RangeSearchResult_lims(result, &mut lims);
134 faiss_next_sys::faiss_RangeSearchResult_labels(result, &mut labels, &mut distances);
135
136 let nq_actual = nq + 1;
137 let lims_slice = std::slice::from_raw_parts(lims, nq_actual).to_vec();
138 let total = *lims_slice.last().unwrap_or(&0);
139
140 let labels_slice = std::slice::from_raw_parts(labels as *const i64, total)
141 .iter()
142 .map(|&l| Idx(l))
143 .collect();
144 let distances_slice = std::slice::from_raw_parts(distances, total).to_vec();
145
146 faiss_next_sys::faiss_RangeSearchResult_free(result);
147
148 Ok(RangeSearchResult::new(
149 labels_slice,
150 distances_slice,
151 lims_slice,
152 ))
153 }
154 }
155
156 fn assign(&mut self, q: &[f32], k: usize) -> Result<Vec<Idx>> {
157 let d = self.d() as usize;
158 let nq = q.len() / d;
159 let mut labels = vec![Idx::NONE; nq * k];
160
161 check_return_code(unsafe {
162 faiss_next_sys::faiss_Index_assign(
163 self.inner_ptr(),
164 nq as i64,
165 q.as_ptr(),
166 labels.as_mut_ptr() as *mut i64,
167 k as i64,
168 )
169 })?;
170
171 Ok(labels)
172 }
173
174 fn reset(&mut self) -> Result<()> {
175 check_return_code(unsafe { faiss_next_sys::faiss_Index_reset(self.inner_ptr()) })
176 }
177
178 fn remove_ids<S: IDSelector>(&mut self, sel: &S) -> Result<usize> {
179 let mut n_removed: usize = 0;
180 check_return_code(unsafe {
181 faiss_next_sys::faiss_Index_remove_ids(self.inner_ptr(), sel.as_ptr(), &mut n_removed)
182 })?;
183 Ok(n_removed)
184 }
185
186 fn reconstruct(&self, id: Idx) -> Result<Vec<f32>> {
187 let d = self.d() as usize;
188 let mut recons = vec![0.0f32; d];
189 check_return_code(unsafe {
190 faiss_next_sys::faiss_Index_reconstruct(
191 self.inner_ptr(),
192 id.as_repr(),
193 recons.as_mut_ptr(),
194 )
195 })?;
196 Ok(recons)
197 }
198
199 fn reconstruct_n(&self, i0: Idx, ni: usize) -> Result<Vec<f32>> {
200 let d = self.d() as usize;
201 let mut recons = vec![0.0f32; ni * d];
202 check_return_code(unsafe {
203 faiss_next_sys::faiss_Index_reconstruct_n(
204 self.inner_ptr(),
205 i0.as_repr(),
206 ni as i64,
207 recons.as_mut_ptr(),
208 )
209 })?;
210 Ok(recons)
211 }
212
213 fn sa_code_size(&self) -> Result<usize> {
214 let mut size: usize = 0;
215 check_return_code(unsafe {
216 faiss_next_sys::faiss_Index_sa_code_size(self.inner_ptr(), &mut size)
217 })?;
218 Ok(size)
219 }
220
221 fn sa_encode(&self, x: &[f32]) -> Result<Vec<u8>> {
222 let d = self.d() as usize;
223 let n = x.len() / d;
224 let code_size = self.sa_code_size()?;
225 let mut bytes = vec![0u8; n * code_size];
226 check_return_code(unsafe {
227 faiss_next_sys::faiss_Index_sa_encode(
228 self.inner_ptr(),
229 n as i64,
230 x.as_ptr(),
231 bytes.as_mut_ptr(),
232 )
233 })?;
234 Ok(bytes)
235 }
236
237 fn sa_decode(&self, bytes: &[u8]) -> Result<Vec<f32>> {
238 let d = self.d() as usize;
239 let code_size = self.sa_code_size()?;
240 let n = bytes.len() / code_size;
241 let mut x = vec![0.0f32; n * d];
242 check_return_code(unsafe {
243 faiss_next_sys::faiss_Index_sa_decode(
244 self.inner_ptr(),
245 n as i64,
246 bytes.as_ptr(),
247 x.as_mut_ptr(),
248 )
249 })?;
250 Ok(x)
251 }
252
253 fn verbose(&self) -> bool {
254 unsafe { faiss_next_sys::faiss_Index_verbose(self.inner_ptr()) != 0 }
255 }
256
257 fn set_verbose(&mut self, verbose: bool) {
258 unsafe { faiss_next_sys::faiss_Index_set_verbose(self.inner_ptr(), verbose as i32) }
259 }
260
261 fn compute_residual(&self, x: &[f32], key: Idx) -> Result<Vec<f32>> {
262 let d = self.d() as usize;
263 let mut residual = vec![0.0f32; d];
264 check_return_code(unsafe {
265 faiss_next_sys::faiss_Index_compute_residual(
266 self.inner_ptr(),
267 x.as_ptr(),
268 residual.as_mut_ptr(),
269 key.as_repr(),
270 )
271 })?;
272 Ok(residual)
273 }
274
275 fn compute_residual_n(&self, x: &[f32], keys: &[Idx]) -> Result<Vec<f32>> {
276 let d = self.d() as usize;
277 let n = x.len() / d;
278 let mut residuals = vec![0.0f32; x.len()];
279 let keys_raw: Vec<i64> = keys.iter().map(|&id| id.as_repr()).collect();
280 check_return_code(unsafe {
281 faiss_next_sys::faiss_Index_compute_residual_n(
282 self.inner_ptr(),
283 n as i64,
284 x.as_ptr(),
285 residuals.as_mut_ptr(),
286 keys_raw.as_ptr(),
287 )
288 })?;
289 Ok(residuals)
290 }
291}
292
293pub trait IvfIndex: Index {
294 fn nlist(&self) -> usize;
295 fn nprobe(&self) -> usize;
296 fn set_nprobe(&mut self, nprobe: usize);
297
298 fn get_list_size(&self, list_no: usize) -> usize {
299 unsafe { faiss_next_sys::faiss_IndexIVF_get_list_size(self.inner_ptr(), list_no) }
300 }
301
302 fn make_direct_map(&mut self, new_type: i32) -> Result<()> {
303 check_return_code(unsafe {
304 faiss_next_sys::faiss_IndexIVF_make_direct_map(self.inner_ptr(), new_type)
305 })
306 }
307
308 fn merge_from(&mut self, other: &mut impl Index, add_ids: bool) -> Result<()> {
309 check_return_code(unsafe {
310 faiss_next_sys::faiss_IndexIVF_merge_from(
311 self.inner_ptr(),
312 other.inner_ptr(),
313 add_ids as i64,
314 )
315 })
316 }
317
318 fn search_preassigned(
319 &mut self,
320 q: &[f32],
321 k: usize,
322 assign: &[i64],
323 centroid_dis: &[f32],
324 store_pairs: bool,
325 ) -> Result<SearchResult> {
326 let d = self.d() as usize;
327 let nq = q.len() / d;
328 let mut distances = vec![0.0f32; nq * k];
329 let mut labels = vec![Idx::NONE; nq * k];
330
331 check_return_code(unsafe {
332 faiss_next_sys::faiss_IndexIVF_search_preassigned(
333 self.inner_ptr(),
334 nq as i64,
335 q.as_ptr(),
336 k as i64,
337 assign.as_ptr(),
338 centroid_dis.as_ptr(),
339 distances.as_mut_ptr(),
340 labels.as_mut_ptr() as *mut i64,
341 store_pairs as i32,
342 )
343 })?;
344
345 Ok(SearchResult::new(distances, labels))
346 }
347}
348
349pub trait BinaryIndex {
350 fn inner_ptr(&self) -> *mut FaissIndexBinary;
351
352 fn is_trained(&self) -> bool {
353 unsafe { faiss_next_sys::faiss_IndexBinary_is_trained(self.inner_ptr()) != 0 }
354 }
355
356 fn ntotal(&self) -> u64 {
357 unsafe { faiss_next_sys::faiss_IndexBinary_ntotal(self.inner_ptr()) as u64 }
358 }
359
360 fn d(&self) -> u32 {
361 unsafe { faiss_next_sys::faiss_IndexBinary_d(self.inner_ptr()) as u32 }
362 }
363
364 fn metric_type(&self) -> MetricType {
365 let mt = unsafe { faiss_next_sys::faiss_IndexBinary_metric_type(self.inner_ptr()) };
366 MetricType::from_native(mt)
367 }
368
369 fn train(&mut self, x: &[u8]) -> Result<()> {
370 let d_bytes = self.d() as usize / 8;
371 let n = x.len() / d_bytes;
372 check_return_code(unsafe {
373 faiss_next_sys::faiss_IndexBinary_train(self.inner_ptr(), n as i64, x.as_ptr())
374 })
375 }
376
377 fn add(&mut self, x: &[u8]) -> Result<()> {
378 let d_bytes = self.d() as usize / 8;
379 let n = x.len() / d_bytes;
380 check_return_code(unsafe {
381 faiss_next_sys::faiss_IndexBinary_add(self.inner_ptr(), n as i64, x.as_ptr())
382 })
383 }
384
385 fn add_with_ids(&mut self, x: &[u8], ids: &[Idx]) -> Result<()> {
386 let d_bytes = self.d() as usize / 8;
387 let n = x.len() / d_bytes;
388 if ids.len() < n {
389 return Err(Error::InvalidDimension {
390 expected: n,
391 actual: ids.len(),
392 });
393 }
394 let ids_raw: Vec<i64> = ids.iter().map(|&id| id.as_repr()).collect();
395 check_return_code(unsafe {
396 faiss_next_sys::faiss_IndexBinary_add_with_ids(
397 self.inner_ptr(),
398 n as i64,
399 x.as_ptr(),
400 ids_raw.as_ptr(),
401 )
402 })
403 }
404
405 fn search(&mut self, q: &[u8], k: usize) -> Result<BinarySearchResult> {
406 let d_bytes = self.d() as usize / 8;
407 let nq = q.len() / d_bytes;
408 let mut distances = vec![0i32; nq * k];
409 let mut labels = vec![Idx::NONE; nq * k];
410
411 check_return_code(unsafe {
412 faiss_next_sys::faiss_IndexBinary_search(
413 self.inner_ptr(),
414 nq as i64,
415 q.as_ptr(),
416 k as i64,
417 distances.as_mut_ptr(),
418 labels.as_mut_ptr() as *mut i64,
419 )
420 })?;
421
422 Ok(BinarySearchResult::new(distances, labels))
423 }
424
425 fn range_search(&mut self, q: &[u8], radius: i32) -> Result<RangeSearchResult> {
426 let d_bytes = self.d() as usize / 8;
427 let nq = q.len() / d_bytes;
428
429 unsafe {
430 let mut result: *mut FaissRangeSearchResult = ptr::null_mut();
431 check_return_code(faiss_next_sys::faiss_RangeSearchResult_new(
432 &mut result,
433 nq as i64,
434 ))?;
435
436 check_return_code(faiss_next_sys::faiss_IndexBinary_range_search(
437 self.inner_ptr(),
438 nq as i64,
439 q.as_ptr(),
440 radius,
441 result,
442 ))?;
443
444 let mut lims = ptr::null_mut();
445 let mut distances = ptr::null_mut();
446 let mut labels = ptr::null_mut();
447
448 faiss_next_sys::faiss_RangeSearchResult_lims(result, &mut lims);
449 faiss_next_sys::faiss_RangeSearchResult_labels(result, &mut labels, &mut distances);
450
451 let nq_actual = nq + 1;
452 let lims_slice = std::slice::from_raw_parts(lims, nq_actual).to_vec();
453 let total = *lims_slice.last().unwrap_or(&0);
454
455 let labels_slice = std::slice::from_raw_parts(labels as *const i64, total)
456 .iter()
457 .map(|&l| Idx(l))
458 .collect();
459 let distances_slice = std::slice::from_raw_parts(distances, total).to_vec();
460
461 faiss_next_sys::faiss_RangeSearchResult_free(result);
462
463 Ok(RangeSearchResult::new(
464 labels_slice,
465 distances_slice,
466 lims_slice,
467 ))
468 }
469 }
470
471 fn assign(&mut self, q: &[u8], k: usize) -> Result<Vec<Idx>> {
472 let d_bytes = self.d() as usize / 8;
473 let nq = q.len() / d_bytes;
474 let mut labels = vec![Idx::NONE; nq * k];
475
476 check_return_code(unsafe {
477 faiss_next_sys::faiss_IndexBinary_assign(
478 self.inner_ptr(),
479 nq as i64,
480 q.as_ptr(),
481 labels.as_mut_ptr() as *mut i64,
482 k as i64,
483 )
484 })?;
485
486 Ok(labels)
487 }
488
489 fn reset(&mut self) -> Result<()> {
490 check_return_code(unsafe { faiss_next_sys::faiss_IndexBinary_reset(self.inner_ptr()) })
491 }
492
493 fn remove_ids<S: IDSelector>(&mut self, sel: &S) -> Result<usize> {
494 let mut n_removed: usize = 0;
495 check_return_code(unsafe {
496 faiss_next_sys::faiss_IndexBinary_remove_ids(
497 self.inner_ptr(),
498 sel.as_ptr(),
499 &mut n_removed,
500 )
501 })?;
502 Ok(n_removed)
503 }
504
505 fn reconstruct(&self, id: Idx) -> Result<Vec<u8>> {
506 let d_bytes = self.d() as usize / 8;
507 let mut recons = vec![0u8; d_bytes];
508 check_return_code(unsafe {
509 faiss_next_sys::faiss_IndexBinary_reconstruct(
510 self.inner_ptr(),
511 id.as_repr(),
512 recons.as_mut_ptr(),
513 )
514 })?;
515 Ok(recons)
516 }
517
518 fn reconstruct_n(&self, i0: Idx, ni: usize) -> Result<Vec<u8>> {
519 let d_bytes = self.d() as usize / 8;
520 let mut recons = vec![0u8; ni * d_bytes];
521 check_return_code(unsafe {
522 faiss_next_sys::faiss_IndexBinary_reconstruct_n(
523 self.inner_ptr(),
524 i0.as_repr(),
525 ni as i64,
526 recons.as_mut_ptr(),
527 )
528 })?;
529 Ok(recons)
530 }
531
532 fn verbose(&self) -> bool {
533 unsafe { faiss_next_sys::faiss_IndexBinary_verbose(self.inner_ptr()) != 0 }
534 }
535
536 fn set_verbose(&mut self, verbose: bool) {
537 unsafe { faiss_next_sys::faiss_IndexBinary_set_verbose(self.inner_ptr(), verbose as i32) }
538 }
539}