//! Low-level bindings to the Faiss C API, called through `unsafe` FFI.
//!
//! The declarations live in one of two private modules, selected at compile time by
//! the `gpu` Cargo feature (`bindings_gpu` when enabled, `bindings` otherwise). Their
//! contents are glob re-exported at the crate root, so callers use the same paths
//! either way.
//! NOTE(review): these modules look machine-generated (e.g. by bindgen) — confirm.

// The bindings keep C identifiers verbatim (e.g. `FaissMetricType_METRIC_L2`,
// `faiss_IndexFlatL2_new_with`, `idx_t`), which would trip Rust's naming lints
// throughout; silence them crate-wide.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

// Bindings used when the `gpu` feature is enabled.
#[cfg(feature = "gpu")]
mod bindings_gpu;
#[cfg(feature = "gpu")]
pub use bindings_gpu::*;

// Bindings used when the `gpu` feature is disabled (the default build).
#[cfg(not(feature = "gpu"))]
mod bindings;
#[cfg(not(feature = "gpu"))]
pub use bindings::*;
14
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{CStr, CString};
    use std::mem;
    use std::os::raw::c_char;
    use std::ptr;

    /// A malformed index-factory description must make `faiss_index_factory` fail
    /// and leave a non-empty message retrievable through `faiss_get_last_error`.
    #[test]
    fn getting_last_error() {
        // SAFETY: the out-parameter is a local on this stack frame, `desc` is a
        // NUL-terminated `CString` that outlives the call, and the error string is
        // only dereferenced after checking it is non-null.
        unsafe {
            let mut index_ptr: *mut FaissIndex = ptr::null_mut();
            let desc = CString::new("noooo").unwrap();
            let c = faiss_index_factory(
                &mut index_ptr,
                4,
                desc.as_ptr(),
                FaissMetricType_METRIC_L2,
            );
            assert_ne!(c, 0);
            let last_error: *const c_char = faiss_get_last_error();
            assert!(!last_error.is_null());
            assert!(!CStr::from_ptr(last_error).to_bytes().is_empty());
        }
    }

    /// Builds an exact L2 index over five D-dimensional vectors and checks that a
    /// k = 4 search from the origin returns them nearest-first, without writing
    /// past the first `K` result slots.
    #[test]
    fn flat_index() {
        const D: usize = 8;
        const K: usize = 4;
        // SAFETY: all buffers passed to Faiss are locals that outlive each call and
        // are sized for what the call writes (`K` results for a single query); the
        // index handle is created here and freed exactly once at the end.
        unsafe {
            let mut index_ptr: *mut FaissIndexFlatL2 = ptr::null_mut();
            let c = faiss_IndexFlatL2_new_with(&mut index_ptr, D as idx_t);
            assert_eq!(c, 0);
            assert!(!index_ptr.is_null());

            // Row-major: each consecutive run of D floats is one vector.
            let some_data = [
                7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0.,
                0., 0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 200.,
                100., 100., 500., -100., 100., 100., 500.,
            ];
            assert!(faiss_Index_is_trained(index_ptr) != 0);
            let n = (some_data.len() / D) as idx_t;
            let c = faiss_Index_add(index_ptr, n, some_data.as_ptr());
            assert_eq!(c, 0);
            assert_eq!(faiss_Index_ntotal(index_ptr), 5);

            let some_query = [0.0_f32; D];
            // One slot beyond K holds a sentinel (-1) so an out-of-bounds write by
            // the search would be detected.
            let mut distances = [0_f32, 0., 0., 0., -1.];
            let mut labels = [0 as idx_t, 0, 0, 0, -1];
            let c = faiss_Index_search(
                index_ptr,
                1,
                some_query.as_ptr(),
                K as idx_t,
                distances.as_mut_ptr(),
                labels.as_mut_ptr(),
            );
            assert_eq!(c, 0);
            assert_eq!(labels, [2, 1, 0, 3, -1]);
            assert!(distances[..K].iter().all(|&d| d > 0.));
            // Results come back nearest-first.
            assert!(distances[..K].windows(2).all(|w| w[0] <= w[1]));
            assert_eq!(distances[K], -1.);
            faiss_Index_free(index_ptr);
        }
    }

    /// Initializes default clustering parameters, checks a few of their defaults,
    /// then trains a K-means clustering with K = 2 over ten D-dimensional points
    /// using a flat L2 index.
    #[test]
    fn clustering() {
        const D: usize = 8;
        const K: usize = 2;
        // SAFETY: `faiss_ClusteringParameters_init` fully initializes the struct
        // before `assume_init`; all other pointers are locals that outlive each call,
        // and both Faiss handles created here are freed exactly once at the end.
        unsafe {
            let mut params = mem::MaybeUninit::<FaissClusteringParameters>::uninit();
            faiss_ClusteringParameters_init(params.as_mut_ptr());
            let mut params = params.assume_init();
            assert_eq!(params.verbose, 0);
            assert_eq!(params.spherical, 0);
            assert_eq!(params.frozen_centroids, 0);
            assert_eq!(params.update_index, 0);
            assert!(params.niter > 0);
            params.niter = 5;
            assert!(params.min_points_per_centroid > 0);
            assert!(params.max_points_per_centroid > 0);
            params.min_points_per_centroid = 1;
            params.max_points_per_centroid = 10;

            // Row-major: each consecutive run of D floats is one point (10 in total).
            let some_data = [
                7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0.,
                0., 0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., -7.,
                1., 4., 1., 2., 1., 3., -1., 120., 100., 100., 120., -100., 100., 100., 120., 0.,
                0., -12., 1., 1., 0., 6., -1., 0., 0., -0.25, 1., 16., 24., 0., -1., 100., 10.,
                100., 100., 10., 100., 50., 10., 20., 22., 4.5, -2., -100., 0., 0., 100.,
            ];
            let n = (some_data.len() / D) as idx_t;

            let mut clustering_ptr: *mut FaissClustering = ptr::null_mut();
            let c = faiss_Clustering_new_with_params(
                &mut clustering_ptr,
                D as i32,
                K as i32,
                &params,
            );
            assert_eq!(c, 0);
            assert!(!clustering_ptr.is_null());

            let mut index_ptr: *mut FaissIndex = ptr::null_mut();
            let desc = CString::new("Flat").unwrap();
            let c = faiss_index_factory(
                &mut index_ptr,
                D as i32,
                desc.as_ptr(),
                FaissMetricType_METRIC_L2,
            );
            assert_eq!(c, 0);
            assert!(!index_ptr.is_null());

            let c = faiss_Clustering_train(clustering_ptr, n, some_data.as_ptr(), index_ptr);
            assert_eq!(c, 0);

            faiss_Clustering_free(clustering_ptr);
            // The index was created by the factory above and is owned by this test.
            faiss_Index_free(index_ptr);
        }
    }
}