faiss/index/
lsh.rs

//! Interface to and implementation of the Locality-Sensitive Hashing (LSH) index type.
2
3use super::{
4    try_clone_from_inner_ptr, AssignSearchResult, CpuIndex, FromInnerPtr, Idx, Index, IndexImpl,
5    NativeIndex, RangeSearchResult, SearchResult, TryClone, TryFromInnerPtr,
6};
7use crate::error::{Error, Result};
8use crate::faiss_try;
9use crate::selector::IdSelector;
10use faiss_sys::*;
11use std::mem;
12use std::ptr;
13
/// Handle to a native Faiss LSH index.
///
/// The value stores a raw pointer to the underlying `FaissIndexLSH` object;
/// this file's `Drop` implementation passes that pointer to
/// `faiss_IndexLSH_free` when the handle goes out of scope.
#[derive(Debug)]
pub struct LshIndex {
    /// Raw pointer to the native index object.
    inner: *mut FaissIndexLSH,
}
18
// Raw pointers are neither `Send` nor `Sync` by default, so both traits are
// asserted manually for the wrapper.
// NOTE(review): the justification (thread-safety of the native index behind
// the pointer) is not visible in this file — confirm against the Faiss
// documentation before relying on it.
unsafe impl Send for LshIndex {}
unsafe impl Sync for LshIndex {}
21
// Marker implementation: `CpuIndex` (declared in the parent module) requires
// no items to be provided here.
impl CpuIndex for LshIndex {}
23
24impl Drop for LshIndex {
25    fn drop(&mut self) {
26        unsafe {
27            faiss_IndexLSH_free(self.inner);
28        }
29    }
30}
31
32impl NativeIndex for LshIndex {
33    type Inner = FaissIndex;
34    fn inner_ptr(&self) -> *mut FaissIndex {
35        self.inner
36    }
37}
38
39impl FromInnerPtr for LshIndex {
40    unsafe fn from_inner_ptr(inner_ptr: *mut FaissIndex) -> Self {
41        LshIndex { inner: inner_ptr }
42    }
43}
44
45impl TryFromInnerPtr for LshIndex {
46    unsafe fn try_from_inner_ptr(inner_ptr: *mut FaissIndex) -> Result<Self>
47    where
48        Self: Sized,
49    {
50        // safety: `inner_ptr` is documented to be a valid pointer to an index,
51        // so the dynamic cast should be safe.
52        #[allow(unused_unsafe)]
53        unsafe {
54            let new_inner = faiss_IndexLSH_cast(inner_ptr);
55            if new_inner.is_null() {
56                Err(Error::BadCast)
57            } else {
58                Ok(LshIndex { inner: new_inner })
59            }
60        }
61    }
62}
63
64impl LshIndex {
65    /// Create a new LSH index.
66    pub fn new(d: u32, nbits: u32) -> Result<Self> {
67        unsafe {
68            let mut inner = ptr::null_mut();
69            faiss_try(faiss_IndexLSH_new(
70                &mut inner,
71                d as idx_t,
72                nbits as ::std::os::raw::c_int,
73            ))?;
74            Ok(LshIndex { inner })
75        }
76    }
77
78    /// Create a new LSH index.
79    pub fn new_with_options(
80        d: u32,
81        nbits: u32,
82        rotate_data: bool,
83        train_thresholds: bool,
84    ) -> Result<Self> {
85        unsafe {
86            let mut inner = ptr::null_mut();
87            faiss_try(faiss_IndexLSH_new_with_options(
88                &mut inner,
89                d as idx_t,
90                nbits as ::std::os::raw::c_int,
91                rotate_data as ::std::os::raw::c_int,
92                train_thresholds as ::std::os::raw::c_int,
93            ))?;
94            Ok(LshIndex { inner })
95        }
96    }
97
98    pub fn nbits(&self) -> u32 {
99        unsafe { faiss_IndexLSH_nbits(self.inner) as u32 }
100    }
101
102    pub fn rotate_data(&self) -> bool {
103        unsafe { faiss_IndexLSH_rotate_data(self.inner) != 0 }
104    }
105
106    pub fn train_thresholds(&self) -> bool {
107        unsafe { faiss_IndexLSH_rotate_data(self.inner) != 0 }
108    }
109
110    pub fn code_size(&self) -> usize {
111        unsafe { faiss_IndexLSH_code_size(self.inner) as usize }
112    }
113}
114
// Expands the crate's `impl_native_index!` macro for `LshIndex`.
// NOTE(review): the macro is defined outside this file; presumably it supplies
// the `Index` trait implementation (`d`, `ntotal`, `train`, `add`, `search`,
// ...) that the tests below exercise — confirm at the macro definition.
impl_native_index!(LshIndex);
116
impl TryClone for LshIndex {
    /// Attempts to clone this index by delegating to
    /// `try_clone_from_inner_ptr` from the parent module.
    ///
    /// Any error returned by that helper is passed through unchanged.
    fn try_clone(&self) -> Result<Self>
    where
        Self: Sized,
    {
        try_clone_from_inner_ptr(self)
    }
}
125
126impl IndexImpl {
127    /// Attempt a dynamic cast of an index to the LSH index type.
128    #[deprecated(
129        since = "0.8.0",
130        note = "Non-idiomatic name, prefer `into_lsh` instead"
131    )]
132    pub fn as_lsh(self) -> Result<LshIndex> {
133        self.into_lsh()
134    }
135
136    /// Attempt a dynamic cast of an index to the LSH index type.
137    pub fn into_lsh(self) -> Result<LshIndex> {
138        unsafe {
139            let new_inner = faiss_IndexLSH_cast(self.inner_ptr());
140            if new_inner.is_null() {
141                Err(Error::BadCast)
142            } else {
143                mem::forget(self);
144                Ok(LshIndex { inner: new_inner })
145            }
146        }
147    }
148}
149
// Expands the crate's `impl_concurrent_index!` macro for `LshIndex`.
// NOTE(review): the macro is defined outside this file; presumably it supplies
// the `ConcurrentIndex` implementation that lets the tests below call
// `search`/`assign` through a shared reference — confirm at the macro
// definition.
impl_concurrent_index!(LshIndex);
151
#[cfg(test)]
mod tests {
    use super::LshIndex;
    use crate::error::Result;
    use crate::index::{index_factory, ConcurrentIndex, FromInnerPtr, Idx, Index, NativeIndex};
    use crate::metric::MetricType;

    /// Dimensionality of every vector used in these tests.
    const D: u32 = 8;

    /// Number of vectors stored in `SOME_DATA`.
    const N: usize = 5;

    /// `N` vectors of dimension `D`, laid out contiguously (one per row).
    #[rustfmt::skip]
    const SOME_DATA: [f32; N * D as usize] = [
        7.5, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5,
        -1., 1., 1., 1., 1., 1., 1., -1.,
        4., -4., -8., 1., 1., 2., 4., -1.,
        8., 8., 10., -10., -10., 10., -10., 10.,
        16., 16., 32., 25., 20., 20., 40., 15.,
    ];

    /// Builds an LSH index with 16 bits, trains it on `SOME_DATA` and adds the
    /// same vectors, checking the basic bookkeeping along the way.
    fn populated_index() -> LshIndex {
        let mut index = LshIndex::new(D, 16).unwrap();
        assert_eq!(index.d(), D);
        assert_eq!(index.ntotal(), 0);

        index.train(&SOME_DATA).unwrap();
        assert!(index.is_trained());
        index.add(&SOME_DATA).unwrap();
        assert_eq!(index.ntotal(), N as u64);
        index
    }

    /// Casting a flat index to an LSH index must fail.
    #[test]
    fn index_from_cast() {
        let index = index_factory(D, "Flat", MetricType::L2).unwrap();
        let r: Result<LshIndex> = index.into_lsh();
        assert!(r.is_err());
    }

    #[test]
    fn index_search() {
        let mut index = populated_index();

        let my_query = [0.; D as usize];
        let result = index.search(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));
        assert_eq!(result.distances.len(), 3);
        assert!(result.distances.iter().all(|x| *x > 0.));

        let my_query = [100.; D as usize];
        // the LSH index can also be searched behind an immutable ref
        let result = (&index).search(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));
        assert_eq!(result.distances.len(), 3);
        assert!(result.distances.iter().all(|x| *x > 0.));

        index.reset().unwrap();
        assert_eq!(index.ntotal(), 0);
    }

    #[test]
    fn index_assign() {
        let mut index = populated_index();

        let my_query = [0.; D as usize];
        let result = index.assign(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));

        let my_query = [100.; D as usize];
        // the LSH index can also be used for assignment behind an immutable ref
        let result = (&index).assign(&my_query, 3).unwrap();
        assert_eq!(result.labels.len(), 3);
        assert!(result.labels.into_iter().all(Idx::is_some));

        index.reset().unwrap();
        assert_eq!(index.ntotal(), 0);
    }

    /// The native index must survive being moved from one Rust handle to
    /// another through its raw pointer.
    #[test]
    fn index_transition() {
        let index = {
            let index = populated_index();

            unsafe {
                let inner = index.inner_ptr();
                // forget index, rebuild it into another object
                ::std::mem::forget(index);
                LshIndex::from_inner_ptr(inner)
            }
        };
        assert!(index.is_trained());
        assert_eq!(index.ntotal(), N as u64);
    }
}
256}