Skip to main content

faiss_next/
transform.rs

1use std::ptr;
2
3use faiss_next_sys::{
4    self, FaissCenteringTransform, FaissITQMatrix, FaissITQTransform, FaissLinearTransform,
5    FaissNormalizationTransform, FaissOPQMatrix, FaissPCAMatrix, FaissRandomRotationMatrix,
6    FaissRemapDimensionsTransform, FaissVectorTransform,
7};
8
9use crate::error::{check_return_code, Result};
10
11pub trait VectorTransform {
12    fn inner_ptr(&self) -> *mut FaissVectorTransform;
13
14    fn is_trained(&self) -> bool {
15        unsafe { faiss_next_sys::faiss_VectorTransform_is_trained(self.inner_ptr()) != 0 }
16    }
17
18    fn d_in(&self) -> u32 {
19        unsafe { faiss_next_sys::faiss_VectorTransform_d_in(self.inner_ptr()) as u32 }
20    }
21
22    fn d_out(&self) -> u32 {
23        unsafe { faiss_next_sys::faiss_VectorTransform_d_out(self.inner_ptr()) as u32 }
24    }
25
26    fn train(&mut self, n: usize, x: &[f32]) -> Result<()> {
27        check_return_code(unsafe {
28            faiss_next_sys::faiss_VectorTransform_train(self.inner_ptr(), n as i64, x.as_ptr())
29        })
30    }
31
32    fn apply(&self, n: usize, x: &[f32]) -> Result<Vec<f32>> {
33        let d_out = self.d_out() as usize;
34        let mut xt = vec![0.0f32; n * d_out];
35        unsafe {
36            faiss_next_sys::faiss_VectorTransform_apply_noalloc(
37                self.inner_ptr(),
38                n as i64,
39                x.as_ptr(),
40                xt.as_mut_ptr(),
41            )
42        }
43        Ok(xt)
44    }
45
46    fn apply_noalloc(&self, n: usize, x: &[f32], xt: &mut [f32]) {
47        unsafe {
48            faiss_next_sys::faiss_VectorTransform_apply_noalloc(
49                self.inner_ptr(),
50                n as i64,
51                x.as_ptr(),
52                xt.as_mut_ptr(),
53            )
54        }
55    }
56
57    fn reverse_transform(&self, n: usize, xt: &[f32], x: &mut [f32]) {
58        unsafe {
59            faiss_next_sys::faiss_VectorTransform_reverse_transform(
60                self.inner_ptr(),
61                n as i64,
62                xt.as_ptr(),
63                x.as_mut_ptr(),
64            )
65        }
66    }
67}
68
69pub trait LinearTransform: VectorTransform {
70    fn inner_linear_ptr(&self) -> *mut FaissLinearTransform;
71
72    fn transform_transpose(&self, n: usize, y: &[f32], x: &mut [f32]) {
73        unsafe {
74            faiss_next_sys::faiss_LinearTransform_transform_transpose(
75                self.inner_linear_ptr(),
76                n as i64,
77                y.as_ptr(),
78                x.as_mut_ptr(),
79            )
80        }
81    }
82
83    fn set_is_orthonormal(&mut self) {
84        unsafe { faiss_next_sys::faiss_LinearTransform_set_is_orthonormal(self.inner_linear_ptr()) }
85    }
86
87    fn have_bias(&self) -> bool {
88        unsafe { faiss_next_sys::faiss_LinearTransform_have_bias(self.inner_linear_ptr()) != 0 }
89    }
90
91    fn is_orthonormal(&self) -> bool {
92        unsafe {
93            faiss_next_sys::faiss_LinearTransform_is_orthonormal(self.inner_linear_ptr()) != 0
94        }
95    }
96}
97
98pub struct PcaMatrix {
99    ptr: *mut FaissPCAMatrix,
100}
101
102impl PcaMatrix {
103    pub fn new(d_in: u32, d_out: u32, eigen_power: f32, random_rotation: bool) -> Result<Self> {
104        unsafe {
105            let mut ptr: *mut FaissPCAMatrix = ptr::null_mut();
106            check_return_code(faiss_next_sys::faiss_PCAMatrix_new_with(
107                &mut ptr,
108                d_in as i32,
109                d_out as i32,
110                eigen_power,
111                random_rotation as i32,
112            ))?;
113            Ok(Self { ptr })
114        }
115    }
116
117    pub fn eigen_power(&self) -> f32 {
118        unsafe { faiss_next_sys::faiss_PCAMatrix_eigen_power(self.ptr) }
119    }
120
121    pub fn random_rotation(&self) -> bool {
122        unsafe { faiss_next_sys::faiss_PCAMatrix_random_rotation(self.ptr) != 0 }
123    }
124}
125
126impl VectorTransform for PcaMatrix {
127    fn inner_ptr(&self) -> *mut FaissVectorTransform {
128        self.ptr as *mut FaissVectorTransform
129    }
130}
131
132impl LinearTransform for PcaMatrix {
133    fn inner_linear_ptr(&self) -> *mut FaissLinearTransform {
134        self.ptr as *mut FaissLinearTransform
135    }
136}
137
138impl Drop for PcaMatrix {
139    fn drop(&mut self) {
140        if !self.ptr.is_null() {
141            unsafe {
142                faiss_next_sys::faiss_PCAMatrix_free(self.ptr);
143            }
144        }
145    }
146}
147
148pub struct RandomRotationMatrix {
149    ptr: *mut FaissRandomRotationMatrix,
150}
151
152impl RandomRotationMatrix {
153    pub fn new(d_in: u32, d_out: u32) -> Result<Self> {
154        unsafe {
155            let mut ptr: *mut FaissRandomRotationMatrix = ptr::null_mut();
156            check_return_code(faiss_next_sys::faiss_RandomRotationMatrix_new_with(
157                &mut ptr,
158                d_in as i32,
159                d_out as i32,
160            ))?;
161            Ok(Self { ptr })
162        }
163    }
164}
165
166impl VectorTransform for RandomRotationMatrix {
167    fn inner_ptr(&self) -> *mut FaissVectorTransform {
168        self.ptr as *mut FaissVectorTransform
169    }
170}
171
172impl LinearTransform for RandomRotationMatrix {
173    fn inner_linear_ptr(&self) -> *mut FaissLinearTransform {
174        self.ptr as *mut FaissLinearTransform
175    }
176}
177
178impl Drop for RandomRotationMatrix {
179    fn drop(&mut self) {
180        if !self.ptr.is_null() {
181            unsafe {
182                faiss_next_sys::faiss_RandomRotationMatrix_free(self.ptr);
183            }
184        }
185    }
186}
187
188pub struct ItqMatrix {
189    ptr: *mut FaissITQMatrix,
190}
191
192impl ItqMatrix {
193    pub fn new(d: u32) -> Result<Self> {
194        unsafe {
195            let mut ptr: *mut FaissITQMatrix = ptr::null_mut();
196            check_return_code(faiss_next_sys::faiss_ITQMatrix_new_with(&mut ptr, d as i32))?;
197            Ok(Self { ptr })
198        }
199    }
200}
201
202impl VectorTransform for ItqMatrix {
203    fn inner_ptr(&self) -> *mut FaissVectorTransform {
204        self.ptr as *mut FaissVectorTransform
205    }
206}
207
208impl LinearTransform for ItqMatrix {
209    fn inner_linear_ptr(&self) -> *mut FaissLinearTransform {
210        self.ptr as *mut FaissLinearTransform
211    }
212}
213
214impl Drop for ItqMatrix {
215    fn drop(&mut self) {
216        if !self.ptr.is_null() {
217            unsafe {
218                faiss_next_sys::faiss_ITQMatrix_free(self.ptr);
219            }
220        }
221    }
222}
223
224pub struct ItqTransform {
225    ptr: *mut FaissITQTransform,
226}
227
228impl ItqTransform {
229    pub fn new(d_in: u32, d_out: u32, do_pca: bool) -> Result<Self> {
230        unsafe {
231            let mut ptr: *mut FaissITQTransform = ptr::null_mut();
232            check_return_code(faiss_next_sys::faiss_ITQTransform_new_with(
233                &mut ptr,
234                d_in as i32,
235                d_out as i32,
236                do_pca as i32,
237            ))?;
238            Ok(Self { ptr })
239        }
240    }
241
242    pub fn do_pca(&self) -> bool {
243        unsafe { faiss_next_sys::faiss_ITQTransform_do_pca(self.ptr) != 0 }
244    }
245}
246
247impl VectorTransform for ItqTransform {
248    fn inner_ptr(&self) -> *mut FaissVectorTransform {
249        self.ptr as *mut FaissVectorTransform
250    }
251}
252
253impl Drop for ItqTransform {
254    fn drop(&mut self) {
255        if !self.ptr.is_null() {
256            unsafe {
257                faiss_next_sys::faiss_ITQTransform_free(self.ptr);
258            }
259        }
260    }
261}
262
263pub struct OpqMatrix {
264    ptr: *mut FaissOPQMatrix,
265}
266
267impl OpqMatrix {
268    pub fn new(d_in: u32, d_out: u32, m: u32) -> Result<Self> {
269        unsafe {
270            let mut ptr: *mut FaissOPQMatrix = ptr::null_mut();
271            check_return_code(faiss_next_sys::faiss_OPQMatrix_new_with(
272                &mut ptr,
273                d_in as i32,
274                d_out as i32,
275                m as i32,
276            ))?;
277            Ok(Self { ptr })
278        }
279    }
280
281    pub fn verbose(&self) -> bool {
282        unsafe { faiss_next_sys::faiss_OPQMatrix_verbose(self.ptr) != 0 }
283    }
284
285    pub fn set_verbose(&mut self, verbose: bool) {
286        unsafe { faiss_next_sys::faiss_OPQMatrix_set_verbose(self.ptr, verbose as i32) }
287    }
288
289    pub fn niter(&self) -> i32 {
290        unsafe { faiss_next_sys::faiss_OPQMatrix_niter(self.ptr) }
291    }
292
293    pub fn set_niter(&mut self, niter: i32) {
294        unsafe { faiss_next_sys::faiss_OPQMatrix_set_niter(self.ptr, niter) }
295    }
296
297    pub fn niter_pq(&self) -> i32 {
298        unsafe { faiss_next_sys::faiss_OPQMatrix_niter_pq(self.ptr) }
299    }
300
301    pub fn set_niter_pq(&mut self, niter: i32) {
302        unsafe { faiss_next_sys::faiss_OPQMatrix_set_niter_pq(self.ptr, niter) }
303    }
304}
305
306impl VectorTransform for OpqMatrix {
307    fn inner_ptr(&self) -> *mut FaissVectorTransform {
308        self.ptr as *mut FaissVectorTransform
309    }
310}
311
312impl LinearTransform for OpqMatrix {
313    fn inner_linear_ptr(&self) -> *mut FaissLinearTransform {
314        self.ptr as *mut FaissLinearTransform
315    }
316}
317
318impl Drop for OpqMatrix {
319    fn drop(&mut self) {
320        if !self.ptr.is_null() {
321            unsafe {
322                faiss_next_sys::faiss_OPQMatrix_free(self.ptr);
323            }
324        }
325    }
326}
327
328pub struct NormalizationTransform {
329    ptr: *mut FaissNormalizationTransform,
330}
331
332impl NormalizationTransform {
333    pub fn new(d: u32, norm: f32) -> Result<Self> {
334        unsafe {
335            let mut ptr: *mut FaissNormalizationTransform = ptr::null_mut();
336            check_return_code(faiss_next_sys::faiss_NormalizationTransform_new_with(
337                &mut ptr, d as i32, norm,
338            ))?;
339            Ok(Self { ptr })
340        }
341    }
342
343    pub fn norm(&self) -> f32 {
344        unsafe { faiss_next_sys::faiss_NormalizationTransform_norm(self.ptr) }
345    }
346}
347
348impl VectorTransform for NormalizationTransform {
349    fn inner_ptr(&self) -> *mut FaissVectorTransform {
350        self.ptr as *mut FaissVectorTransform
351    }
352}
353
354impl Drop for NormalizationTransform {
355    fn drop(&mut self) {
356        if !self.ptr.is_null() {
357            unsafe {
358                faiss_next_sys::faiss_NormalizationTransform_free(self.ptr);
359            }
360        }
361    }
362}
363
364pub struct CenteringTransform {
365    ptr: *mut FaissCenteringTransform,
366}
367
368impl CenteringTransform {
369    pub fn new(d: u32) -> Result<Self> {
370        unsafe {
371            let mut ptr: *mut FaissCenteringTransform = ptr::null_mut();
372            check_return_code(faiss_next_sys::faiss_CenteringTransform_new_with(
373                &mut ptr, d as i32,
374            ))?;
375            Ok(Self { ptr })
376        }
377    }
378}
379
380impl VectorTransform for CenteringTransform {
381    fn inner_ptr(&self) -> *mut FaissVectorTransform {
382        self.ptr as *mut FaissVectorTransform
383    }
384}
385
386impl Drop for CenteringTransform {
387    fn drop(&mut self) {
388        if !self.ptr.is_null() {
389            unsafe {
390                faiss_next_sys::faiss_CenteringTransform_free(self.ptr);
391            }
392        }
393    }
394}
395
396pub struct RemapDimensionsTransform {
397    ptr: *mut FaissRemapDimensionsTransform,
398}
399
400impl RemapDimensionsTransform {
401    pub fn new(d_in: u32, d_out: u32, uniform: bool) -> Result<Self> {
402        unsafe {
403            let mut ptr: *mut FaissRemapDimensionsTransform = ptr::null_mut();
404            check_return_code(faiss_next_sys::faiss_RemapDimensionsTransform_new_with(
405                &mut ptr,
406                d_in as i32,
407                d_out as i32,
408                uniform as i32,
409            ))?;
410            Ok(Self { ptr })
411        }
412    }
413}
414
415impl VectorTransform for RemapDimensionsTransform {
416    fn inner_ptr(&self) -> *mut FaissVectorTransform {
417        self.ptr as *mut FaissVectorTransform
418    }
419}
420
421impl Drop for RemapDimensionsTransform {
422    fn drop(&mut self) {
423        if !self.ptr.is_null() {
424            unsafe {
425                faiss_next_sys::faiss_RemapDimensionsTransform_free(self.ptr);
426            }
427        }
428    }
429}