// faiss_sys/lib.rs

1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4
5#[cfg(feature = "gpu")]
6mod bindings_gpu;
7#[cfg(feature = "gpu")]
8pub use bindings_gpu::*;
9
10#[cfg(not(feature = "gpu"))]
11mod bindings;
12#[cfg(not(feature = "gpu"))]
13pub use bindings::*;
14
#[cfg(test)]
mod tests {
    //! Smoke tests that drive the raw FFI bindings directly.
    //!
    //! Each `unsafe` block relies on the same conventions: out-parameters point
    //! at locals of the current test, input buffers outlive the call they are
    //! passed to, and every handle obtained from a successful constructor is
    //! released with the matching `*_free` function before the test returns.

    use super::*;
    use std::ffi::{CStr, CString};
    use std::mem;
    use std::os::raw::c_char;
    use std::ptr;

    /// A failing call must report a non-zero status and leave a non-empty
    /// message retrievable through `faiss_get_last_error`.
    #[test]
    fn getting_last_error() {
        // "noooo" is not a valid index factory description.
        let desc = CString::new("noooo").unwrap();
        unsafe {
            let mut index_ptr: *mut FaissIndexFlatL2 = ptr::null_mut();
            let c = faiss_index_factory(&mut index_ptr, 4, desc.as_ptr(), 0);
            assert_ne!(c, 0);
            let last_error: *const c_char = faiss_get_last_error();
            assert!(!last_error.is_null());
            // SAFETY: `last_error` is non-null (checked above). It is assumed to be
            // a NUL-terminated C string that stays alive while we read it.
            let message = CStr::from_ptr(last_error);
            assert!(!message.to_bytes().is_empty());
        }
    }

    /// Adds five 8-dimensional vectors to a flat L2 index and checks that a
    /// 4-NN query from the origin returns the expected labels, with each
    /// distance equal to the squared norm of the returned vector, nearest first.
    #[test]
    fn flat_index() {
        const D: usize = 8;
        const K: usize = 4;
        let some_data = [
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0.,
            0., 0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 200.,
            100., 100., 500., -100., 100., 100., 500.,
        ];
        // The buffer must hold a whole number of D-dimensional rows.
        assert_eq!(some_data.len() % D, 0);
        let n = some_data.len() / D;
        unsafe {
            let mut index_ptr: *mut FaissIndexFlatL2 = ptr::null_mut();
            let c = faiss_IndexFlatL2_new_with(&mut index_ptr as *mut _, D as idx_t);
            assert_eq!(c, 0);
            assert!(!index_ptr.is_null());
            assert_ne!(faiss_Index_is_trained(index_ptr), 0);
            let c = faiss_Index_add(index_ptr, n as idx_t, some_data.as_ptr());
            assert_eq!(c, 0);
            assert_eq!(faiss_Index_ntotal(index_ptr), 5);

            let some_query = [0.0_f32; D];
            // Output buffers have one extra canary slot after the K results; the
            // search must leave it untouched.
            let mut distances = [0_f32, 0., 0., 0., -1.];
            let mut labels = [0 as idx_t, 0, 0, 0, -1];
            // Search for the K vectors closest to the origin.
            let c = faiss_Index_search(
                index_ptr,
                1,
                some_query.as_ptr(),
                K as idx_t,
                distances.as_mut_ptr(),
                labels.as_mut_ptr(),
            );
            assert_eq!(c, 0);
            assert_eq!(labels, [2, 1, 0, 3, -1]);
            assert_eq!(distances[K], -1.);

            // The query is the origin, so each reported distance must equal the
            // squared norm of the vector it is paired with.
            for (&label, &dist) in labels[..K].iter().zip(&distances[..K]) {
                let start = label as usize * D;
                let expected: f32 = some_data[start..start + D].iter().map(|v| v * v).sum();
                assert!(
                    (dist - expected).abs() <= 1e-4 * expected.max(1.0),
                    "label {}: got distance {}, expected {}",
                    label,
                    dist,
                    expected
                );
            }
            // Results come back nearest-first.
            assert!(distances[..K].windows(2).all(|w| w[0] <= w[1]));

            faiss_Index_free(index_ptr);
        }
    }

    /// Trains a 2-centroid clustering over ten 8-dimensional vectors, using a
    /// flat L2 index built by the index factory, and releases every handle.
    #[test]
    fn clustering() {
        const D: usize = 8;
        const K: i32 = 2;
        let some_data = [
            7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0.,
            0., 0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., -7.,
            1., 4., 1., 2., 1., 3., -1., 120., 100., 100., 120., -100., 100., 100., 120., 0.,
            0., -12., 1., 1., 0., 6., -1., 0., 0., -0.25, 1., 16., 24., 0., -1., 100., 10.,
            100., 100., 10., 100., 50., 10., 20., 22., 4.5, -2., -100., 0., 0., 100.,
        ];
        // The buffer must hold a whole number of D-dimensional rows.
        assert_eq!(some_data.len() % D, 0);
        let n = some_data.len() / D;
        unsafe {
            let mut params = mem::MaybeUninit::<FaissClusteringParameters>::uninit();
            faiss_ClusteringParameters_init(params.as_mut_ptr());
            // SAFETY: relies on `faiss_ClusteringParameters_init` filling in every
            // field of the struct it is given.
            let mut params = params.assume_init();
            assert_eq!(params.verbose, 0);
            assert_eq!(params.spherical, 0);
            assert_eq!(params.frozen_centroids, 0);
            assert_eq!(params.update_index, 0);
            assert!(params.niter > 0);
            params.niter = 5;
            assert!(params.min_points_per_centroid > 0);
            assert!(params.max_points_per_centroid > 0);
            params.min_points_per_centroid = 1;
            params.max_points_per_centroid = 10;

            let mut clustering_ptr: *mut FaissClustering = ptr::null_mut();
            let c = faiss_Clustering_new_with_params(&mut clustering_ptr, D as i32, K, &params);
            assert_eq!(c, 0);
            assert!(!clustering_ptr.is_null());

            let mut index_ptr: *mut FaissIndexFlatL2 = ptr::null_mut();
            let desc = CString::new("Flat").unwrap();
            let c = faiss_index_factory(
                &mut index_ptr,
                D as i32,
                desc.as_ptr(),
                FaissMetricType_METRIC_L2,
            );
            assert_eq!(c, 0);
            assert!(!index_ptr.is_null());

            let c = faiss_Clustering_train(clustering_ptr, n as idx_t, some_data.as_ptr(), index_ptr);
            assert_eq!(c, 0);

            // Both handles were created in this test, so both are released here
            // (previously the index leaked).
            faiss_Clustering_free(clustering_ptr);
            faiss_Index_free(index_ptr);
        }
    }
}